Refactor
This commit is contained in:
parent
86152eafb9
commit
94251f1968
|
@ -27,7 +27,7 @@ var (
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func initConfig() {
|
||||||
parseConfigFile(appConfig)
|
parseConfigFile(appConfig)
|
||||||
// Replace values that are set via environment vars (to make it compatible with old method)
|
// Replace values that are set via environment vars (to make it compatible with old method)
|
||||||
overwriteEnvVarValues(appConfig)
|
overwriteEnvVarValues(appConfig)
|
||||||
|
|
17
database.go
17
database.go
|
@ -19,17 +19,20 @@ type Database struct {
|
||||||
trackingStmt *sql.Stmt
|
trackingStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func initDatabase() (database *Database, e error) {
|
var (
|
||||||
database = &Database{}
|
db = &Database{}
|
||||||
|
)
|
||||||
|
|
||||||
|
func initDatabase() (e error) {
|
||||||
if _, err := os.Stat(appConfig.DbPath); os.IsNotExist(err) {
|
if _, err := os.Stat(appConfig.DbPath); os.IsNotExist(err) {
|
||||||
_ = os.MkdirAll(filepath.Dir(appConfig.DbPath), os.ModePerm)
|
_ = os.MkdirAll(filepath.Dir(appConfig.DbPath), os.ModePerm)
|
||||||
}
|
}
|
||||||
database.sqlDB, e = sql.Open("sqlite3", appConfig.DbPath)
|
db.sqlDB, e = sql.Open("sqlite3", appConfig.DbPath)
|
||||||
if e != nil {
|
if e != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
e = migrateDatabase(database.sqlDB)
|
e = migrateDatabase(db.sqlDB)
|
||||||
database.trackingStmt, e = database.sqlDB.Prepare("insert into views(url, ref, useragent) values(:url, :ref, :ua)")
|
db.trackingStmt, e = db.sqlDB.Prepare("insert into views(url, ref, useragent) values(:url, :ref, :ua)")
|
||||||
if e != nil {
|
if e != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -46,7 +49,7 @@ func migrateDatabase(database *sql.DB) (e error) {
|
||||||
|
|
||||||
// Tracking
|
// Tracking
|
||||||
|
|
||||||
func (db *Database) trackView(urlString string, ref string, ua string) {
|
func trackView(urlString string, ref string, ua string) {
|
||||||
if len(urlString) == 0 {
|
if len(urlString) == 0 {
|
||||||
// Don't track empty urls
|
// Don't track empty urls
|
||||||
return
|
return
|
||||||
|
@ -104,7 +107,7 @@ type RequestResultRow struct {
|
||||||
Second int `json:"second"`
|
Second int `json:"second"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *Database) request(request *ViewsRequest) (resultRows []*RequestResultRow, e error) {
|
func request(request *ViewsRequest) (resultRows []*RequestResultRow, e error) {
|
||||||
statement, parameters := request.buildStatement()
|
statement, parameters := request.buildStatement()
|
||||||
namedArgs := make([]interface{}, len(parameters))
|
namedArgs := make([]interface{}, len(parameters))
|
||||||
for i, v := range parameters {
|
for i, v := range parameters {
|
||||||
|
|
|
@ -5,8 +5,6 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
func TestViewsRequest_buildDateTimeFilter(t *testing.T) {
|
func TestViewsRequest_buildDateTimeFilter(t *testing.T) {
|
||||||
t.Run("No DateTime filter", func(t *testing.T) {
|
t.Run("No DateTime filter", func(t *testing.T) {
|
||||||
request := &ViewsRequest{
|
request := &ViewsRequest{
|
||||||
|
|
187
main.go
187
main.go
|
@ -1,21 +1,17 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gobuffalo/packr/v2"
|
"github.com/gobuffalo/packr/v2"
|
||||||
"github.com/gorilla/handlers"
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/kis3/kis3/helpers"
|
|
||||||
"html/template"
|
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"os"
|
||||||
"strings"
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
)
|
)
|
||||||
|
|
||||||
type kis3 struct {
|
type kis3 struct {
|
||||||
db *Database
|
|
||||||
router *mux.Router
|
router *mux.Router
|
||||||
staticBox *packr.Box
|
staticBox *packr.Box
|
||||||
}
|
}
|
||||||
|
@ -27,181 +23,34 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
e := setupDB()
|
initConfig()
|
||||||
|
e := initDatabase()
|
||||||
if e != nil {
|
if e != nil {
|
||||||
log.Fatal("Database setup failed:", e)
|
log.Fatal("Database setup failed:", e)
|
||||||
}
|
}
|
||||||
setupRouter()
|
initRouter()
|
||||||
setupReports()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
startListening()
|
go startListeningToWeb()
|
||||||
|
go startReports()
|
||||||
|
// Graceful stop
|
||||||
|
var gracefulStop = make(chan os.Signal, 1)
|
||||||
|
signal.Notify(gracefulStop, os.Interrupt, syscall.SIGTERM)
|
||||||
|
sig := <-gracefulStop
|
||||||
|
fmt.Printf("Received signal: %+v", sig)
|
||||||
|
os.Exit(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupDB() (e error) {
|
func initRouter() {
|
||||||
app.db, e = initDatabase()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupRouter() {
|
|
||||||
app.router = mux.NewRouter()
|
app.router = mux.NewRouter()
|
||||||
|
initStatsRouter()
|
||||||
corsHandler := handlers.CORS(handlers.AllowedOrigins([]string{"*"}))
|
initTrackingRouter()
|
||||||
|
|
||||||
viewRouter := app.router.PathPrefix("/view").Subrouter()
|
|
||||||
viewRouter.Use(corsHandler)
|
|
||||||
viewRouter.Path("").HandlerFunc(TrackingHandler)
|
|
||||||
|
|
||||||
app.router.HandleFunc("/stats", StatsHandler)
|
|
||||||
|
|
||||||
staticRouter := app.router.PathPrefix("").Subrouter()
|
|
||||||
staticRouter.Use(corsHandler)
|
|
||||||
staticRouter.HandleFunc("/kis3.js", TrackingScriptHandler)
|
|
||||||
staticRouter.PathPrefix("").Handler(http.HandlerFunc(HelloResponseHandler))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func startListening() {
|
func startListeningToWeb() {
|
||||||
port := appConfig.Port
|
port := appConfig.Port
|
||||||
addr := ":" + port
|
addr := ":" + port
|
||||||
fmt.Printf("Listening to %s\n", addr)
|
fmt.Printf("Listening to %s\n", addr)
|
||||||
log.Fatal(http.ListenAndServe(addr, app.router))
|
log.Fatal(http.ListenAndServe(addr, app.router))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TrackingHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate, max-age=0")
|
|
||||||
url := r.URL.Query().Get("url")
|
|
||||||
ref := r.URL.Query().Get("ref")
|
|
||||||
ua := r.Header.Get("User-Agent")
|
|
||||||
if !(r.Header.Get("DNT") == "1" && appConfig.Dnt) {
|
|
||||||
go app.db.trackView(url, ref, ua) // run with goroutine for awesome speed!
|
|
||||||
_, _ = fmt.Fprint(w, "true")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func HelloResponseHandler(w http.ResponseWriter, _ *http.Request) {
|
|
||||||
_, _ = fmt.Fprint(w, "Hello from KISSS")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TrackingScriptHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Header().Set("Content-Type", "application/javascript")
|
|
||||||
w.Header().Set("Cache-Control", "public, max-age=432000") // 5 days
|
|
||||||
filename := "kis3.js"
|
|
||||||
file, err := app.staticBox.Open(filename)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer file.Close()
|
|
||||||
stat, err := file.Stat()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
http.ServeContent(w, r, filename, stat.ModTime(), file)
|
|
||||||
}
|
|
||||||
|
|
||||||
func StatsHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// Require authentication
|
|
||||||
if appConfig.statsAuth() {
|
|
||||||
if !helpers.CheckAuth(w, r, appConfig.StatsUsername, appConfig.StatsPassword) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Do request
|
|
||||||
queries := r.URL.Query()
|
|
||||||
view := PAGES
|
|
||||||
switch strings.ToLower(queries.Get("view")) {
|
|
||||||
case "pages":
|
|
||||||
view = PAGES
|
|
||||||
case "referrers":
|
|
||||||
view = REFERRERS
|
|
||||||
case "useragents":
|
|
||||||
view = USERAGENTS
|
|
||||||
case "useragentnames":
|
|
||||||
view = USERAGENTNAMES
|
|
||||||
case "hours":
|
|
||||||
view = HOURS
|
|
||||||
case "days":
|
|
||||||
view = DAYS
|
|
||||||
case "weeks":
|
|
||||||
view = WEEKS
|
|
||||||
case "months":
|
|
||||||
view = MONTHS
|
|
||||||
case "allhours":
|
|
||||||
view = ALLHOURS
|
|
||||||
case "alldays":
|
|
||||||
view = ALLDAYS
|
|
||||||
case "count":
|
|
||||||
view = COUNT
|
|
||||||
}
|
|
||||||
result, e := app.db.request(&ViewsRequest{
|
|
||||||
view: view,
|
|
||||||
from: queries.Get("from"),
|
|
||||||
fromRel: queries.Get("fromrel"),
|
|
||||||
to: queries.Get("to"),
|
|
||||||
toRel: queries.Get("torel"),
|
|
||||||
url: queries.Get("url"),
|
|
||||||
ref: queries.Get("ref"),
|
|
||||||
ua: queries.Get("ua"),
|
|
||||||
ordercol: strings.ToLower(queries.Get("ordercol")),
|
|
||||||
order: strings.ToUpper(queries.Get("order")),
|
|
||||||
limit: queries.Get("limit"),
|
|
||||||
})
|
|
||||||
if e != nil {
|
|
||||||
fmt.Println("Database request failed:", e)
|
|
||||||
w.WriteHeader(500)
|
|
||||||
} else if result != nil {
|
|
||||||
w.Header().Set("Cache-Control", "max-age=0")
|
|
||||||
switch queries.Get("format") {
|
|
||||||
case "json":
|
|
||||||
sendJsonResponse(result, w)
|
|
||||||
case "chart":
|
|
||||||
sendChartResponse(result, w)
|
|
||||||
default: // "plain"
|
|
||||||
sendPlainResponse(result, w)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func sendPlainResponse(result []*RequestResultRow, w http.ResponseWriter) {
|
|
||||||
w.Header().Set("Content-Type", "text/plain")
|
|
||||||
for _, row := range result {
|
|
||||||
_, _ = fmt.Fprintln(w, (*row).First+": "+strconv.Itoa((*row).Second))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func sendJsonResponse(result []*RequestResultRow, w http.ResponseWriter) {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
jsonBytes, _ := json.Marshal(result)
|
|
||||||
_, _ = fmt.Fprintln(w, string(jsonBytes))
|
|
||||||
}
|
|
||||||
|
|
||||||
func sendChartResponse(result []*RequestResultRow, w http.ResponseWriter) {
|
|
||||||
labels := make([]string, len(result))
|
|
||||||
values := make([]int, len(result))
|
|
||||||
for i, row := range result {
|
|
||||||
labels[i] = row.First
|
|
||||||
values[i] = row.Second
|
|
||||||
}
|
|
||||||
chartJSString, e := app.staticBox.FindString("Chart.min.js")
|
|
||||||
if e != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
data := struct {
|
|
||||||
Labels []string
|
|
||||||
Values []int
|
|
||||||
ChartJS template.JS
|
|
||||||
}{
|
|
||||||
Labels: labels,
|
|
||||||
Values: values,
|
|
||||||
ChartJS: template.JS(chartJSString),
|
|
||||||
}
|
|
||||||
chartTemplateString, e := app.staticBox.FindString("chart.html")
|
|
||||||
if e != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t, e := template.New("chart").Parse(chartTemplateString)
|
|
||||||
if e != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_ = t.Execute(w, data)
|
|
||||||
}
|
|
||||||
|
|
26
main_test.go
26
main_test.go
|
@ -1,26 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestHelloResponseHandler(t *testing.T) {
|
|
||||||
t.Run("Hello response", func(t *testing.T) {
|
|
||||||
req, err := http.NewRequest("GET", "/", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
handler := http.HandlerFunc(HelloResponseHandler)
|
|
||||||
handler.ServeHTTP(rr, req)
|
|
||||||
if status := rr.Code; status != http.StatusOK {
|
|
||||||
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK)
|
|
||||||
}
|
|
||||||
expected := "Hello from KISSS"
|
|
||||||
if rr.Body.String() != expected {
|
|
||||||
t.Errorf("handler returned unexpected body: got %v want %v", rr.Body.String(), expected)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
|
@ -25,7 +25,7 @@ type report struct {
|
||||||
TGUserId int64 `json:"tgUserId"`
|
TGUserId int64 `json:"tgUserId"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupReports() {
|
func startReports() {
|
||||||
scheduler := clockwork.NewScheduler()
|
scheduler := clockwork.NewScheduler()
|
||||||
for _, r := range appConfig.Reports {
|
for _, r := range appConfig.Reports {
|
||||||
scheduledReport := r
|
scheduledReport := r
|
||||||
|
@ -33,7 +33,7 @@ func setupReports() {
|
||||||
executeReport(&scheduledReport)
|
executeReport(&scheduledReport)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
go scheduler.Run()
|
scheduler.Run()
|
||||||
}
|
}
|
||||||
|
|
||||||
func executeReport(r *report) {
|
func executeReport(r *report) {
|
||||||
|
|
|
@ -0,0 +1,122 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"github.com/kis3/kis3/helpers"
|
||||||
|
"html/template"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func initStatsRouter() {
|
||||||
|
app.router.HandleFunc("/stats", StatsHandler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func StatsHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Require authentication
|
||||||
|
if appConfig.statsAuth() {
|
||||||
|
if !helpers.CheckAuth(w, r, appConfig.StatsUsername, appConfig.StatsPassword) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Do request
|
||||||
|
queries := r.URL.Query()
|
||||||
|
view := PAGES
|
||||||
|
switch strings.ToLower(queries.Get("view")) {
|
||||||
|
case "pages":
|
||||||
|
view = PAGES
|
||||||
|
case "referrers":
|
||||||
|
view = REFERRERS
|
||||||
|
case "useragents":
|
||||||
|
view = USERAGENTS
|
||||||
|
case "useragentnames":
|
||||||
|
view = USERAGENTNAMES
|
||||||
|
case "hours":
|
||||||
|
view = HOURS
|
||||||
|
case "days":
|
||||||
|
view = DAYS
|
||||||
|
case "weeks":
|
||||||
|
view = WEEKS
|
||||||
|
case "months":
|
||||||
|
view = MONTHS
|
||||||
|
case "allhours":
|
||||||
|
view = ALLHOURS
|
||||||
|
case "alldays":
|
||||||
|
view = ALLDAYS
|
||||||
|
case "count":
|
||||||
|
view = COUNT
|
||||||
|
}
|
||||||
|
result, e := request(&ViewsRequest{
|
||||||
|
view: view,
|
||||||
|
from: queries.Get("from"),
|
||||||
|
fromRel: queries.Get("fromrel"),
|
||||||
|
to: queries.Get("to"),
|
||||||
|
toRel: queries.Get("torel"),
|
||||||
|
url: queries.Get("url"),
|
||||||
|
ref: queries.Get("ref"),
|
||||||
|
ua: queries.Get("ua"),
|
||||||
|
ordercol: strings.ToLower(queries.Get("ordercol")),
|
||||||
|
order: strings.ToUpper(queries.Get("order")),
|
||||||
|
limit: queries.Get("limit"),
|
||||||
|
})
|
||||||
|
if e != nil {
|
||||||
|
fmt.Println("Database request failed:", e)
|
||||||
|
w.WriteHeader(500)
|
||||||
|
} else if result != nil {
|
||||||
|
w.Header().Set("Cache-Control", "max-age=0")
|
||||||
|
switch queries.Get("format") {
|
||||||
|
case "json":
|
||||||
|
sendJsonResponse(result, w)
|
||||||
|
case "chart":
|
||||||
|
sendChartResponse(result, w)
|
||||||
|
default: // "plain"
|
||||||
|
sendPlainResponse(result, w)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendPlainResponse(result []*RequestResultRow, w http.ResponseWriter) {
|
||||||
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
|
for _, row := range result {
|
||||||
|
_, _ = fmt.Fprintln(w, (*row).First+": "+strconv.Itoa((*row).Second))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendJsonResponse(result []*RequestResultRow, w http.ResponseWriter) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
jsonBytes, _ := json.Marshal(result)
|
||||||
|
_, _ = fmt.Fprintln(w, string(jsonBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendChartResponse(result []*RequestResultRow, w http.ResponseWriter) {
|
||||||
|
labels := make([]string, len(result))
|
||||||
|
values := make([]int, len(result))
|
||||||
|
for i, row := range result {
|
||||||
|
labels[i] = row.First
|
||||||
|
values[i] = row.Second
|
||||||
|
}
|
||||||
|
chartJSString, e := app.staticBox.FindString("Chart.min.js")
|
||||||
|
if e != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data := struct {
|
||||||
|
Labels []string
|
||||||
|
Values []int
|
||||||
|
ChartJS template.JS
|
||||||
|
}{
|
||||||
|
Labels: labels,
|
||||||
|
Values: values,
|
||||||
|
ChartJS: template.JS(chartJSString),
|
||||||
|
}
|
||||||
|
chartTemplateString, e := app.staticBox.FindString("chart.html")
|
||||||
|
if e != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t, e := template.New("chart").Parse(chartTemplateString)
|
||||||
|
if e != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = t.Execute(w, data)
|
||||||
|
}
|
|
@ -0,0 +1,44 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/gorilla/handlers"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func initTrackingRouter() {
|
||||||
|
corsHandler := handlers.CORS(handlers.AllowedOrigins([]string{"*"}))
|
||||||
|
viewRouter := app.router.Path("/view").Subrouter()
|
||||||
|
viewRouter.Use(corsHandler)
|
||||||
|
viewRouter.Path("").HandlerFunc(TrackingHandler)
|
||||||
|
scriptRouter := app.router.Path("/kis3.js").Subrouter()
|
||||||
|
scriptRouter.Use(corsHandler)
|
||||||
|
scriptRouter.HandleFunc("", TrackingScriptHandler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TrackingHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate, max-age=0")
|
||||||
|
url := r.URL.Query().Get("url")
|
||||||
|
ref := r.URL.Query().Get("ref")
|
||||||
|
ua := r.Header.Get("User-Agent")
|
||||||
|
if !(r.Header.Get("DNT") == "1" && appConfig.Dnt) {
|
||||||
|
go trackView(url, ref, ua) // run with goroutine for awesome speed!
|
||||||
|
_, _ = fmt.Fprint(w, "true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TrackingScriptHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/javascript")
|
||||||
|
w.Header().Set("Cache-Control", "public, max-age=432000") // 5 days
|
||||||
|
filename := "kis3.js"
|
||||||
|
file, err := app.staticBox.Open(filename)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
stat, err := file.Stat()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.ServeContent(w, r, filename, stat.ModTime(), file)
|
||||||
|
}
|
Reference in New Issue