Refactor
This commit is contained in:
parent
86152eafb9
commit
94251f1968
|
@ -27,7 +27,7 @@ var (
|
|||
}
|
||||
)
|
||||
|
||||
func init() {
|
||||
func initConfig() {
|
||||
parseConfigFile(appConfig)
|
||||
// Replace values that are set via environment vars (to make it compatible with old method)
|
||||
overwriteEnvVarValues(appConfig)
|
||||
|
|
17
database.go
17
database.go
|
@ -19,17 +19,20 @@ type Database struct {
|
|||
trackingStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func initDatabase() (database *Database, e error) {
|
||||
database = &Database{}
|
||||
var (
|
||||
db = &Database{}
|
||||
)
|
||||
|
||||
func initDatabase() (e error) {
|
||||
if _, err := os.Stat(appConfig.DbPath); os.IsNotExist(err) {
|
||||
_ = os.MkdirAll(filepath.Dir(appConfig.DbPath), os.ModePerm)
|
||||
}
|
||||
database.sqlDB, e = sql.Open("sqlite3", appConfig.DbPath)
|
||||
db.sqlDB, e = sql.Open("sqlite3", appConfig.DbPath)
|
||||
if e != nil {
|
||||
return
|
||||
}
|
||||
e = migrateDatabase(database.sqlDB)
|
||||
database.trackingStmt, e = database.sqlDB.Prepare("insert into views(url, ref, useragent) values(:url, :ref, :ua)")
|
||||
e = migrateDatabase(db.sqlDB)
|
||||
db.trackingStmt, e = db.sqlDB.Prepare("insert into views(url, ref, useragent) values(:url, :ref, :ua)")
|
||||
if e != nil {
|
||||
return
|
||||
}
|
||||
|
@ -46,7 +49,7 @@ func migrateDatabase(database *sql.DB) (e error) {
|
|||
|
||||
// Tracking
|
||||
|
||||
func (db *Database) trackView(urlString string, ref string, ua string) {
|
||||
func trackView(urlString string, ref string, ua string) {
|
||||
if len(urlString) == 0 {
|
||||
// Don't track empty urls
|
||||
return
|
||||
|
@ -104,7 +107,7 @@ type RequestResultRow struct {
|
|||
Second int `json:"second"`
|
||||
}
|
||||
|
||||
func (db *Database) request(request *ViewsRequest) (resultRows []*RequestResultRow, e error) {
|
||||
func request(request *ViewsRequest) (resultRows []*RequestResultRow, e error) {
|
||||
statement, parameters := request.buildStatement()
|
||||
namedArgs := make([]interface{}, len(parameters))
|
||||
for i, v := range parameters {
|
||||
|
|
|
@ -5,8 +5,6 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
|
||||
|
||||
func TestViewsRequest_buildDateTimeFilter(t *testing.T) {
|
||||
t.Run("No DateTime filter", func(t *testing.T) {
|
||||
request := &ViewsRequest{
|
||||
|
|
187
main.go
187
main.go
|
@ -1,21 +1,17 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gobuffalo/packr/v2"
|
||||
"github.com/gorilla/handlers"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/kis3/kis3/helpers"
|
||||
"html/template"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
type kis3 struct {
|
||||
db *Database
|
||||
router *mux.Router
|
||||
staticBox *packr.Box
|
||||
}
|
||||
|
@ -27,181 +23,34 @@ var (
|
|||
)
|
||||
|
||||
func init() {
|
||||
e := setupDB()
|
||||
initConfig()
|
||||
e := initDatabase()
|
||||
if e != nil {
|
||||
log.Fatal("Database setup failed:", e)
|
||||
}
|
||||
setupRouter()
|
||||
setupReports()
|
||||
initRouter()
|
||||
}
|
||||
|
||||
func main() {
|
||||
startListening()
|
||||
go startListeningToWeb()
|
||||
go startReports()
|
||||
// Graceful stop
|
||||
var gracefulStop = make(chan os.Signal, 1)
|
||||
signal.Notify(gracefulStop, os.Interrupt, syscall.SIGTERM)
|
||||
sig := <-gracefulStop
|
||||
fmt.Printf("Received signal: %+v", sig)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
func setupDB() (e error) {
|
||||
app.db, e = initDatabase()
|
||||
return
|
||||
}
|
||||
|
||||
func setupRouter() {
|
||||
func initRouter() {
|
||||
app.router = mux.NewRouter()
|
||||
|
||||
corsHandler := handlers.CORS(handlers.AllowedOrigins([]string{"*"}))
|
||||
|
||||
viewRouter := app.router.PathPrefix("/view").Subrouter()
|
||||
viewRouter.Use(corsHandler)
|
||||
viewRouter.Path("").HandlerFunc(TrackingHandler)
|
||||
|
||||
app.router.HandleFunc("/stats", StatsHandler)
|
||||
|
||||
staticRouter := app.router.PathPrefix("").Subrouter()
|
||||
staticRouter.Use(corsHandler)
|
||||
staticRouter.HandleFunc("/kis3.js", TrackingScriptHandler)
|
||||
staticRouter.PathPrefix("").Handler(http.HandlerFunc(HelloResponseHandler))
|
||||
initStatsRouter()
|
||||
initTrackingRouter()
|
||||
}
|
||||
|
||||
func startListening() {
|
||||
func startListeningToWeb() {
|
||||
port := appConfig.Port
|
||||
addr := ":" + port
|
||||
fmt.Printf("Listening to %s\n", addr)
|
||||
log.Fatal(http.ListenAndServe(addr, app.router))
|
||||
}
|
||||
|
||||
func TrackingHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate, max-age=0")
|
||||
url := r.URL.Query().Get("url")
|
||||
ref := r.URL.Query().Get("ref")
|
||||
ua := r.Header.Get("User-Agent")
|
||||
if !(r.Header.Get("DNT") == "1" && appConfig.Dnt) {
|
||||
go app.db.trackView(url, ref, ua) // run with goroutine for awesome speed!
|
||||
_, _ = fmt.Fprint(w, "true")
|
||||
}
|
||||
}
|
||||
|
||||
func HelloResponseHandler(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = fmt.Fprint(w, "Hello from KISSS")
|
||||
}
|
||||
|
||||
func TrackingScriptHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/javascript")
|
||||
w.Header().Set("Cache-Control", "public, max-age=432000") // 5 days
|
||||
filename := "kis3.js"
|
||||
file, err := app.staticBox.Open(filename)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
stat, err := file.Stat()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
http.ServeContent(w, r, filename, stat.ModTime(), file)
|
||||
}
|
||||
|
||||
func StatsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// Require authentication
|
||||
if appConfig.statsAuth() {
|
||||
if !helpers.CheckAuth(w, r, appConfig.StatsUsername, appConfig.StatsPassword) {
|
||||
return
|
||||
}
|
||||
}
|
||||
// Do request
|
||||
queries := r.URL.Query()
|
||||
view := PAGES
|
||||
switch strings.ToLower(queries.Get("view")) {
|
||||
case "pages":
|
||||
view = PAGES
|
||||
case "referrers":
|
||||
view = REFERRERS
|
||||
case "useragents":
|
||||
view = USERAGENTS
|
||||
case "useragentnames":
|
||||
view = USERAGENTNAMES
|
||||
case "hours":
|
||||
view = HOURS
|
||||
case "days":
|
||||
view = DAYS
|
||||
case "weeks":
|
||||
view = WEEKS
|
||||
case "months":
|
||||
view = MONTHS
|
||||
case "allhours":
|
||||
view = ALLHOURS
|
||||
case "alldays":
|
||||
view = ALLDAYS
|
||||
case "count":
|
||||
view = COUNT
|
||||
}
|
||||
result, e := app.db.request(&ViewsRequest{
|
||||
view: view,
|
||||
from: queries.Get("from"),
|
||||
fromRel: queries.Get("fromrel"),
|
||||
to: queries.Get("to"),
|
||||
toRel: queries.Get("torel"),
|
||||
url: queries.Get("url"),
|
||||
ref: queries.Get("ref"),
|
||||
ua: queries.Get("ua"),
|
||||
ordercol: strings.ToLower(queries.Get("ordercol")),
|
||||
order: strings.ToUpper(queries.Get("order")),
|
||||
limit: queries.Get("limit"),
|
||||
})
|
||||
if e != nil {
|
||||
fmt.Println("Database request failed:", e)
|
||||
w.WriteHeader(500)
|
||||
} else if result != nil {
|
||||
w.Header().Set("Cache-Control", "max-age=0")
|
||||
switch queries.Get("format") {
|
||||
case "json":
|
||||
sendJsonResponse(result, w)
|
||||
case "chart":
|
||||
sendChartResponse(result, w)
|
||||
default: // "plain"
|
||||
sendPlainResponse(result, w)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func sendPlainResponse(result []*RequestResultRow, w http.ResponseWriter) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
for _, row := range result {
|
||||
_, _ = fmt.Fprintln(w, (*row).First+": "+strconv.Itoa((*row).Second))
|
||||
}
|
||||
}
|
||||
|
||||
func sendJsonResponse(result []*RequestResultRow, w http.ResponseWriter) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
jsonBytes, _ := json.Marshal(result)
|
||||
_, _ = fmt.Fprintln(w, string(jsonBytes))
|
||||
}
|
||||
|
||||
func sendChartResponse(result []*RequestResultRow, w http.ResponseWriter) {
|
||||
labels := make([]string, len(result))
|
||||
values := make([]int, len(result))
|
||||
for i, row := range result {
|
||||
labels[i] = row.First
|
||||
values[i] = row.Second
|
||||
}
|
||||
chartJSString, e := app.staticBox.FindString("Chart.min.js")
|
||||
if e != nil {
|
||||
return
|
||||
}
|
||||
data := struct {
|
||||
Labels []string
|
||||
Values []int
|
||||
ChartJS template.JS
|
||||
}{
|
||||
Labels: labels,
|
||||
Values: values,
|
||||
ChartJS: template.JS(chartJSString),
|
||||
}
|
||||
chartTemplateString, e := app.staticBox.FindString("chart.html")
|
||||
if e != nil {
|
||||
return
|
||||
}
|
||||
t, e := template.New("chart").Parse(chartTemplateString)
|
||||
if e != nil {
|
||||
return
|
||||
}
|
||||
_ = t.Execute(w, data)
|
||||
}
|
||||
|
|
26
main_test.go
26
main_test.go
|
@ -1,26 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHelloResponseHandler(t *testing.T) {
|
||||
t.Run("Hello response", func(t *testing.T) {
|
||||
req, err := http.NewRequest("GET", "/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rr := httptest.NewRecorder()
|
||||
handler := http.HandlerFunc(HelloResponseHandler)
|
||||
handler.ServeHTTP(rr, req)
|
||||
if status := rr.Code; status != http.StatusOK {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK)
|
||||
}
|
||||
expected := "Hello from KISSS"
|
||||
if rr.Body.String() != expected {
|
||||
t.Errorf("handler returned unexpected body: got %v want %v", rr.Body.String(), expected)
|
||||
}
|
||||
})
|
||||
}
|
|
@ -25,7 +25,7 @@ type report struct {
|
|||
TGUserId int64 `json:"tgUserId"`
|
||||
}
|
||||
|
||||
func setupReports() {
|
||||
func startReports() {
|
||||
scheduler := clockwork.NewScheduler()
|
||||
for _, r := range appConfig.Reports {
|
||||
scheduledReport := r
|
||||
|
@ -33,7 +33,7 @@ func setupReports() {
|
|||
executeReport(&scheduledReport)
|
||||
})
|
||||
}
|
||||
go scheduler.Run()
|
||||
scheduler.Run()
|
||||
}
|
||||
|
||||
func executeReport(r *report) {
|
||||
|
|
|
@ -0,0 +1,122 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/kis3/kis3/helpers"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func initStatsRouter() {
|
||||
app.router.HandleFunc("/stats", StatsHandler)
|
||||
}
|
||||
|
||||
func StatsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// Require authentication
|
||||
if appConfig.statsAuth() {
|
||||
if !helpers.CheckAuth(w, r, appConfig.StatsUsername, appConfig.StatsPassword) {
|
||||
return
|
||||
}
|
||||
}
|
||||
// Do request
|
||||
queries := r.URL.Query()
|
||||
view := PAGES
|
||||
switch strings.ToLower(queries.Get("view")) {
|
||||
case "pages":
|
||||
view = PAGES
|
||||
case "referrers":
|
||||
view = REFERRERS
|
||||
case "useragents":
|
||||
view = USERAGENTS
|
||||
case "useragentnames":
|
||||
view = USERAGENTNAMES
|
||||
case "hours":
|
||||
view = HOURS
|
||||
case "days":
|
||||
view = DAYS
|
||||
case "weeks":
|
||||
view = WEEKS
|
||||
case "months":
|
||||
view = MONTHS
|
||||
case "allhours":
|
||||
view = ALLHOURS
|
||||
case "alldays":
|
||||
view = ALLDAYS
|
||||
case "count":
|
||||
view = COUNT
|
||||
}
|
||||
result, e := request(&ViewsRequest{
|
||||
view: view,
|
||||
from: queries.Get("from"),
|
||||
fromRel: queries.Get("fromrel"),
|
||||
to: queries.Get("to"),
|
||||
toRel: queries.Get("torel"),
|
||||
url: queries.Get("url"),
|
||||
ref: queries.Get("ref"),
|
||||
ua: queries.Get("ua"),
|
||||
ordercol: strings.ToLower(queries.Get("ordercol")),
|
||||
order: strings.ToUpper(queries.Get("order")),
|
||||
limit: queries.Get("limit"),
|
||||
})
|
||||
if e != nil {
|
||||
fmt.Println("Database request failed:", e)
|
||||
w.WriteHeader(500)
|
||||
} else if result != nil {
|
||||
w.Header().Set("Cache-Control", "max-age=0")
|
||||
switch queries.Get("format") {
|
||||
case "json":
|
||||
sendJsonResponse(result, w)
|
||||
case "chart":
|
||||
sendChartResponse(result, w)
|
||||
default: // "plain"
|
||||
sendPlainResponse(result, w)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func sendPlainResponse(result []*RequestResultRow, w http.ResponseWriter) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
for _, row := range result {
|
||||
_, _ = fmt.Fprintln(w, (*row).First+": "+strconv.Itoa((*row).Second))
|
||||
}
|
||||
}
|
||||
|
||||
func sendJsonResponse(result []*RequestResultRow, w http.ResponseWriter) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
jsonBytes, _ := json.Marshal(result)
|
||||
_, _ = fmt.Fprintln(w, string(jsonBytes))
|
||||
}
|
||||
|
||||
func sendChartResponse(result []*RequestResultRow, w http.ResponseWriter) {
|
||||
labels := make([]string, len(result))
|
||||
values := make([]int, len(result))
|
||||
for i, row := range result {
|
||||
labels[i] = row.First
|
||||
values[i] = row.Second
|
||||
}
|
||||
chartJSString, e := app.staticBox.FindString("Chart.min.js")
|
||||
if e != nil {
|
||||
return
|
||||
}
|
||||
data := struct {
|
||||
Labels []string
|
||||
Values []int
|
||||
ChartJS template.JS
|
||||
}{
|
||||
Labels: labels,
|
||||
Values: values,
|
||||
ChartJS: template.JS(chartJSString),
|
||||
}
|
||||
chartTemplateString, e := app.staticBox.FindString("chart.html")
|
||||
if e != nil {
|
||||
return
|
||||
}
|
||||
t, e := template.New("chart").Parse(chartTemplateString)
|
||||
if e != nil {
|
||||
return
|
||||
}
|
||||
_ = t.Execute(w, data)
|
||||
}
|
|
@ -0,0 +1,44 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gorilla/handlers"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func initTrackingRouter() {
|
||||
corsHandler := handlers.CORS(handlers.AllowedOrigins([]string{"*"}))
|
||||
viewRouter := app.router.Path("/view").Subrouter()
|
||||
viewRouter.Use(corsHandler)
|
||||
viewRouter.Path("").HandlerFunc(TrackingHandler)
|
||||
scriptRouter := app.router.Path("/kis3.js").Subrouter()
|
||||
scriptRouter.Use(corsHandler)
|
||||
scriptRouter.HandleFunc("", TrackingScriptHandler)
|
||||
}
|
||||
|
||||
func TrackingHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate, max-age=0")
|
||||
url := r.URL.Query().Get("url")
|
||||
ref := r.URL.Query().Get("ref")
|
||||
ua := r.Header.Get("User-Agent")
|
||||
if !(r.Header.Get("DNT") == "1" && appConfig.Dnt) {
|
||||
go trackView(url, ref, ua) // run with goroutine for awesome speed!
|
||||
_, _ = fmt.Fprint(w, "true")
|
||||
}
|
||||
}
|
||||
|
||||
func TrackingScriptHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/javascript")
|
||||
w.Header().Set("Cache-Control", "public, max-age=432000") // 5 days
|
||||
filename := "kis3.js"
|
||||
file, err := app.staticBox.Open(filename)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
stat, err := file.Stat()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
http.ServeContent(w, r, filename, stat.ModTime(), file)
|
||||
}
|
Reference in New Issue