diff --git a/config.go b/config.go index bce1aa0..5ce6d6e 100644 --- a/config.go +++ b/config.go @@ -27,7 +27,7 @@ var ( } ) -func init() { +func initConfig() { parseConfigFile(appConfig) // Replace values that are set via environment vars (to make it compatible with old method) overwriteEnvVarValues(appConfig) diff --git a/database.go b/database.go index 990c5b5..2db75f0 100644 --- a/database.go +++ b/database.go @@ -19,17 +19,20 @@ type Database struct { trackingStmt *sql.Stmt } -func initDatabase() (database *Database, e error) { - database = &Database{} +var ( + db = &Database{} +) + +func initDatabase() (e error) { if _, err := os.Stat(appConfig.DbPath); os.IsNotExist(err) { _ = os.MkdirAll(filepath.Dir(appConfig.DbPath), os.ModePerm) } - database.sqlDB, e = sql.Open("sqlite3", appConfig.DbPath) + db.sqlDB, e = sql.Open("sqlite3", appConfig.DbPath) if e != nil { return } - e = migrateDatabase(database.sqlDB) - database.trackingStmt, e = database.sqlDB.Prepare("insert into views(url, ref, useragent) values(:url, :ref, :ua)") + e = migrateDatabase(db.sqlDB) + db.trackingStmt, e = db.sqlDB.Prepare("insert into views(url, ref, useragent) values(:url, :ref, :ua)") if e != nil { return } @@ -46,7 +49,7 @@ func migrateDatabase(database *sql.DB) (e error) { // Tracking -func (db *Database) trackView(urlString string, ref string, ua string) { +func trackView(urlString string, ref string, ua string) { if len(urlString) == 0 { // Don't track empty urls return @@ -104,7 +107,7 @@ type RequestResultRow struct { Second int `json:"second"` } -func (db *Database) request(request *ViewsRequest) (resultRows []*RequestResultRow, e error) { +func request(request *ViewsRequest) (resultRows []*RequestResultRow, e error) { statement, parameters := request.buildStatement() namedArgs := make([]interface{}, len(parameters)) for i, v := range parameters { diff --git a/database_test.go b/database_test.go index c7ad0a7..f4ead8a 100644 --- a/database_test.go +++ b/database_test.go @@ -5,8 +5,6 @@ import ( "testing" ) - - func TestViewsRequest_buildDateTimeFilter(t *testing.T) { t.Run("No DateTime filter", func(t *testing.T) { request := &ViewsRequest{ diff --git a/main.go b/main.go index 6f005e6..afb67bf 100644 --- a/main.go +++ b/main.go @@ -1,21 +1,17 @@ package main import ( - "encoding/json" "fmt" "github.com/gobuffalo/packr/v2" - "github.com/gorilla/handlers" "github.com/gorilla/mux" - "github.com/kis3/kis3/helpers" - "html/template" "log" "net/http" - "strconv" - "strings" + "os" + "os/signal" + "syscall" ) type kis3 struct { - db *Database router *mux.Router staticBox *packr.Box } @@ -27,181 +23,34 @@ var ( ) func init() { - e := setupDB() + initConfig() + e := initDatabase() if e != nil { log.Fatal("Database setup failed:", e) } - setupRouter() - setupReports() + initRouter() } func main() { - startListening() + go startListeningToWeb() + go startReports() + // Graceful stop + var gracefulStop = make(chan os.Signal, 1) + signal.Notify(gracefulStop, os.Interrupt, syscall.SIGTERM) + sig := <-gracefulStop + fmt.Printf("Received signal: %+v", sig) + os.Exit(0) } -func setupDB() (e error) { - app.db, e = initDatabase() - return -} - -func setupRouter() { +func initRouter() { app.router = mux.NewRouter() - - corsHandler := handlers.CORS(handlers.AllowedOrigins([]string{"*"})) - - viewRouter := app.router.PathPrefix("/view").Subrouter() - viewRouter.Use(corsHandler) - viewRouter.Path("").HandlerFunc(TrackingHandler) - - app.router.HandleFunc("/stats", StatsHandler) - - staticRouter := app.router.PathPrefix("").Subrouter() - staticRouter.Use(corsHandler) - staticRouter.HandleFunc("/kis3.js", TrackingScriptHandler) - staticRouter.PathPrefix("").Handler(http.HandlerFunc(HelloResponseHandler)) + initStatsRouter() + initTrackingRouter() } -func startListening() { +func startListeningToWeb() { port := appConfig.Port addr := ":" + port fmt.Printf("Listening to %s\n", addr) log.Fatal(http.ListenAndServe(addr, app.router)) } - -func TrackingHandler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate, max-age=0") - url := r.URL.Query().Get("url") - ref := r.URL.Query().Get("ref") - ua := r.Header.Get("User-Agent") - if !(r.Header.Get("DNT") == "1" && appConfig.Dnt) { - go app.db.trackView(url, ref, ua) // run with goroutine for awesome speed! - _, _ = fmt.Fprint(w, "true") - } -} - -func HelloResponseHandler(w http.ResponseWriter, _ *http.Request) { - _, _ = fmt.Fprint(w, "Hello from KISSS") -} - -func TrackingScriptHandler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/javascript") - w.Header().Set("Cache-Control", "public, max-age=432000") // 5 days - filename := "kis3.js" - file, err := app.staticBox.Open(filename) - if err != nil { - return - } - defer file.Close() - stat, err := file.Stat() - if err != nil { - return - } - http.ServeContent(w, r, filename, stat.ModTime(), file) -} - -func StatsHandler(w http.ResponseWriter, r *http.Request) { - // Require authentication - if appConfig.statsAuth() { - if !helpers.CheckAuth(w, r, appConfig.StatsUsername, appConfig.StatsPassword) { - return - } - } - // Do request - queries := r.URL.Query() - view := PAGES - switch strings.ToLower(queries.Get("view")) { - case "pages": - view = PAGES - case "referrers": - view = REFERRERS - case "useragents": - view = USERAGENTS - case "useragentnames": - view = USERAGENTNAMES - case "hours": - view = HOURS - case "days": - view = DAYS - case "weeks": - view = WEEKS - case "months": - view = MONTHS - case "allhours": - view = ALLHOURS - case "alldays": - view = ALLDAYS - case "count": - view = COUNT - } - result, e := app.db.request(&ViewsRequest{ - view: view, - from: queries.Get("from"), - fromRel: queries.Get("fromrel"), - to: queries.Get("to"), - toRel: queries.Get("torel"), - url: queries.Get("url"), - ref: queries.Get("ref"), - ua: queries.Get("ua"), - ordercol: strings.ToLower(queries.Get("ordercol")), - order: strings.ToUpper(queries.Get("order")), - limit: queries.Get("limit"), - }) - if e != nil { - fmt.Println("Database request failed:", e) - w.WriteHeader(500) - } else if result != nil { - w.Header().Set("Cache-Control", "max-age=0") - switch queries.Get("format") { - case "json": - sendJsonResponse(result, w) - case "chart": - sendChartResponse(result, w) - default: // "plain" - sendPlainResponse(result, w) - } - } -} - -func sendPlainResponse(result []*RequestResultRow, w http.ResponseWriter) { - w.Header().Set("Content-Type", "text/plain") - for _, row := range result { - _, _ = fmt.Fprintln(w, (*row).First+": "+strconv.Itoa((*row).Second)) - } -} - -func sendJsonResponse(result []*RequestResultRow, w http.ResponseWriter) { - w.Header().Set("Content-Type", "application/json") - jsonBytes, _ := json.Marshal(result) - _, _ = fmt.Fprintln(w, string(jsonBytes)) -} - -func sendChartResponse(result []*RequestResultRow, w http.ResponseWriter) { - labels := make([]string, len(result)) - values := make([]int, len(result)) - for i, row := range result { - labels[i] = row.First - values[i] = row.Second - } - chartJSString, e := app.staticBox.FindString("Chart.min.js") - if e != nil { - return - } - data := struct { - Labels []string - Values []int - ChartJS template.JS - }{ - Labels: labels, - Values: values, - ChartJS: template.JS(chartJSString), - } - chartTemplateString, e := app.staticBox.FindString("chart.html") - if e != nil { - return - } - t, e := template.New("chart").Parse(chartTemplateString) - if e != nil { - return - } - _ = t.Execute(w, data) -} diff --git a/main_test.go b/main_test.go deleted file mode 100644 index 33a2186..0000000 --- a/main_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package main - -import ( - "net/http" - "net/http/httptest" - "testing" -) - -func TestHelloResponseHandler(t *testing.T) { - t.Run("Hello response", func(t *testing.T) { - req, err := http.NewRequest("GET", "/", nil) - if err != nil { - t.Fatal(err) - } - rr := httptest.NewRecorder() - handler := http.HandlerFunc(HelloResponseHandler) - handler.ServeHTTP(rr, req) - if status := rr.Code; status != http.StatusOK { - t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK) - } - expected := "Hello from KISSS" - if rr.Body.String() != expected { - t.Errorf("handler returned unexpected body: got %v want %v", rr.Body.String(), expected) - } - }) -} diff --git a/reports.go b/reports.go index a0ed1dc..efd019b 100644 --- a/reports.go +++ b/reports.go @@ -25,7 +25,7 @@ type report struct { TGUserId int64 `json:"tgUserId"` } -func setupReports() { +func startReports() { scheduler := clockwork.NewScheduler() for _, r := range appConfig.Reports { scheduledReport := r @@ -33,7 +33,7 @@ func setupReports() { executeReport(&scheduledReport) }) } - go scheduler.Run() + scheduler.Run() } func executeReport(r *report) { diff --git a/stats.go b/stats.go new file mode 100644 index 0000000..d5dca93 --- /dev/null +++ b/stats.go @@ -0,0 +1,122 @@ +package main + +import ( + "encoding/json" + "fmt" + "github.com/kis3/kis3/helpers" + "html/template" + "net/http" + "strconv" + "strings" +) + +func initStatsRouter() { + app.router.HandleFunc("/stats", StatsHandler) +} + +func StatsHandler(w http.ResponseWriter, r *http.Request) { + // Require authentication + if appConfig.statsAuth() { + if !helpers.CheckAuth(w, r, appConfig.StatsUsername, appConfig.StatsPassword) { + return + } + } + // Do request + queries := r.URL.Query() + view := PAGES + switch strings.ToLower(queries.Get("view")) { + case "pages": + view = PAGES + case "referrers": + view = REFERRERS + case "useragents": + view = USERAGENTS + case "useragentnames": + view = USERAGENTNAMES + case "hours": + view = HOURS + case "days": + view = DAYS + case "weeks": + view = WEEKS + case "months": + view = MONTHS + case "allhours": + view = ALLHOURS + case "alldays": + view = ALLDAYS + case "count": + view = COUNT + } + result, e := request(&ViewsRequest{ + view: view, + from: queries.Get("from"), + fromRel: queries.Get("fromrel"), + to: queries.Get("to"), + toRel: queries.Get("torel"), + url: queries.Get("url"), + ref: queries.Get("ref"), + ua: queries.Get("ua"), + ordercol: strings.ToLower(queries.Get("ordercol")), + order: strings.ToUpper(queries.Get("order")), + limit: queries.Get("limit"), + }) + if e != nil { + fmt.Println("Database request failed:", e) + w.WriteHeader(500) + } else if result != nil { + w.Header().Set("Cache-Control", "max-age=0") + switch queries.Get("format") { + case "json": + sendJsonResponse(result, w) + case "chart": + sendChartResponse(result, w) + default: // "plain" + sendPlainResponse(result, w) + } + } +} + +func sendPlainResponse(result []*RequestResultRow, w http.ResponseWriter) { + w.Header().Set("Content-Type", "text/plain") + for _, row := range result { + _, _ = fmt.Fprintln(w, (*row).First+": "+strconv.Itoa((*row).Second)) + } +} + +func sendJsonResponse(result []*RequestResultRow, w http.ResponseWriter) { + w.Header().Set("Content-Type", "application/json") + jsonBytes, _ := json.Marshal(result) + _, _ = fmt.Fprintln(w, string(jsonBytes)) +} + +func sendChartResponse(result []*RequestResultRow, w http.ResponseWriter) { + labels := make([]string, len(result)) + values := make([]int, len(result)) + for i, row := range result { + labels[i] = row.First + values[i] = row.Second + } + chartJSString, e := app.staticBox.FindString("Chart.min.js") + if e != nil { + return + } + data := struct { + Labels []string + Values []int + ChartJS template.JS + }{ + Labels: labels, + Values: values, + ChartJS: template.JS(chartJSString), + } + chartTemplateString, e := app.staticBox.FindString("chart.html") + if e != nil { + return + } + t, e := template.New("chart").Parse(chartTemplateString) + if e != nil { + return + } + _ = t.Execute(w, data) +} diff --git a/tracking.go b/tracking.go new file mode 100644 index 0000000..94cc3be --- /dev/null +++ b/tracking.go @@ -0,0 +1,44 @@ +package main + +import ( + "fmt" + "github.com/gorilla/handlers" + "net/http" +) + +func initTrackingRouter() { + corsHandler := handlers.CORS(handlers.AllowedOrigins([]string{"*"})) + viewRouter := app.router.Path("/view").Subrouter() + viewRouter.Use(corsHandler) + viewRouter.Path("").HandlerFunc(TrackingHandler) + scriptRouter := app.router.Path("/kis3.js").Subrouter() + scriptRouter.Use(corsHandler) + scriptRouter.HandleFunc("", TrackingScriptHandler) +} + +func TrackingHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate, max-age=0") + url := r.URL.Query().Get("url") + ref := r.URL.Query().Get("ref") + ua := r.Header.Get("User-Agent") + if !(r.Header.Get("DNT") == "1" && appConfig.Dnt) { + go trackView(url, ref, ua) // run with goroutine for awesome speed! + _, _ = fmt.Fprint(w, "true") + } +} + +func TrackingScriptHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/javascript") + w.Header().Set("Cache-Control", "public, max-age=432000") // 5 days + filename := "kis3.js" + file, err := app.staticBox.Open(filename) + if err != nil { + return + } + defer file.Close() + stat, err := file.Stat() + if err != nil { + return + } + http.ServeContent(w, r, filename, stat.ModTime(), file) +}