jlelse
/
kis3
Archived
1
Fork 0
This commit is contained in:
Jan-Lukas Else 2019-05-26 22:32:02 +02:00
parent 86152eafb9
commit 94251f1968
8 changed files with 197 additions and 207 deletions

View File

@ -27,7 +27,7 @@ var (
} }
) )
func init() { func initConfig() {
parseConfigFile(appConfig) parseConfigFile(appConfig)
// Replace values that are set via environment vars (to make it compatible with old method) // Replace values that are set via environment vars (to make it compatible with old method)
overwriteEnvVarValues(appConfig) overwriteEnvVarValues(appConfig)

View File

@ -19,17 +19,20 @@ type Database struct {
trackingStmt *sql.Stmt trackingStmt *sql.Stmt
} }
func initDatabase() (database *Database, e error) { var (
database = &Database{} db = &Database{}
)
func initDatabase() (e error) {
if _, err := os.Stat(appConfig.DbPath); os.IsNotExist(err) { if _, err := os.Stat(appConfig.DbPath); os.IsNotExist(err) {
_ = os.MkdirAll(filepath.Dir(appConfig.DbPath), os.ModePerm) _ = os.MkdirAll(filepath.Dir(appConfig.DbPath), os.ModePerm)
} }
database.sqlDB, e = sql.Open("sqlite3", appConfig.DbPath) db.sqlDB, e = sql.Open("sqlite3", appConfig.DbPath)
if e != nil { if e != nil {
return return
} }
e = migrateDatabase(database.sqlDB) e = migrateDatabase(db.sqlDB)
database.trackingStmt, e = database.sqlDB.Prepare("insert into views(url, ref, useragent) values(:url, :ref, :ua)") db.trackingStmt, e = db.sqlDB.Prepare("insert into views(url, ref, useragent) values(:url, :ref, :ua)")
if e != nil { if e != nil {
return return
} }
@ -46,7 +49,7 @@ func migrateDatabase(database *sql.DB) (e error) {
// Tracking // Tracking
func (db *Database) trackView(urlString string, ref string, ua string) { func trackView(urlString string, ref string, ua string) {
if len(urlString) == 0 { if len(urlString) == 0 {
// Don't track empty urls // Don't track empty urls
return return
@ -104,7 +107,7 @@ type RequestResultRow struct {
Second int `json:"second"` Second int `json:"second"`
} }
func (db *Database) request(request *ViewsRequest) (resultRows []*RequestResultRow, e error) { func request(request *ViewsRequest) (resultRows []*RequestResultRow, e error) {
statement, parameters := request.buildStatement() statement, parameters := request.buildStatement()
namedArgs := make([]interface{}, len(parameters)) namedArgs := make([]interface{}, len(parameters))
for i, v := range parameters { for i, v := range parameters {

View File

@ -5,8 +5,6 @@ import (
"testing" "testing"
) )
func TestViewsRequest_buildDateTimeFilter(t *testing.T) { func TestViewsRequest_buildDateTimeFilter(t *testing.T) {
t.Run("No DateTime filter", func(t *testing.T) { t.Run("No DateTime filter", func(t *testing.T) {
request := &ViewsRequest{ request := &ViewsRequest{

187
main.go
View File

@ -1,21 +1,17 @@
package main package main
import ( import (
"encoding/json"
"fmt" "fmt"
"github.com/gobuffalo/packr/v2" "github.com/gobuffalo/packr/v2"
"github.com/gorilla/handlers"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/kis3/kis3/helpers"
"html/template"
"log" "log"
"net/http" "net/http"
"strconv" "os"
"strings" "os/signal"
"syscall"
) )
type kis3 struct { type kis3 struct {
db *Database
router *mux.Router router *mux.Router
staticBox *packr.Box staticBox *packr.Box
} }
@ -27,181 +23,34 @@ var (
) )
func init() { func init() {
e := setupDB() initConfig()
e := initDatabase()
if e != nil { if e != nil {
log.Fatal("Database setup failed:", e) log.Fatal("Database setup failed:", e)
} }
setupRouter() initRouter()
setupReports()
} }
func main() { func main() {
startListening() go startListeningToWeb()
go startReports()
// Graceful stop
var gracefulStop = make(chan os.Signal, 1)
signal.Notify(gracefulStop, os.Interrupt, syscall.SIGTERM)
sig := <-gracefulStop
fmt.Printf("Received signal: %+v", sig)
os.Exit(0)
} }
func setupDB() (e error) { func initRouter() {
app.db, e = initDatabase()
return
}
func setupRouter() {
app.router = mux.NewRouter() app.router = mux.NewRouter()
initStatsRouter()
corsHandler := handlers.CORS(handlers.AllowedOrigins([]string{"*"})) initTrackingRouter()
viewRouter := app.router.PathPrefix("/view").Subrouter()
viewRouter.Use(corsHandler)
viewRouter.Path("").HandlerFunc(TrackingHandler)
app.router.HandleFunc("/stats", StatsHandler)
staticRouter := app.router.PathPrefix("").Subrouter()
staticRouter.Use(corsHandler)
staticRouter.HandleFunc("/kis3.js", TrackingScriptHandler)
staticRouter.PathPrefix("").Handler(http.HandlerFunc(HelloResponseHandler))
} }
func startListening() { func startListeningToWeb() {
port := appConfig.Port port := appConfig.Port
addr := ":" + port addr := ":" + port
fmt.Printf("Listening to %s\n", addr) fmt.Printf("Listening to %s\n", addr)
log.Fatal(http.ListenAndServe(addr, app.router)) log.Fatal(http.ListenAndServe(addr, app.router))
} }
func TrackingHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate, max-age=0")
url := r.URL.Query().Get("url")
ref := r.URL.Query().Get("ref")
ua := r.Header.Get("User-Agent")
if !(r.Header.Get("DNT") == "1" && appConfig.Dnt) {
go app.db.trackView(url, ref, ua) // run with goroutine for awesome speed!
_, _ = fmt.Fprint(w, "true")
}
}
func HelloResponseHandler(w http.ResponseWriter, _ *http.Request) {
_, _ = fmt.Fprint(w, "Hello from KISSS")
}
func TrackingScriptHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/javascript")
w.Header().Set("Cache-Control", "public, max-age=432000") // 5 days
filename := "kis3.js"
file, err := app.staticBox.Open(filename)
if err != nil {
return
}
defer file.Close()
stat, err := file.Stat()
if err != nil {
return
}
http.ServeContent(w, r, filename, stat.ModTime(), file)
}
func StatsHandler(w http.ResponseWriter, r *http.Request) {
// Require authentication
if appConfig.statsAuth() {
if !helpers.CheckAuth(w, r, appConfig.StatsUsername, appConfig.StatsPassword) {
return
}
}
// Do request
queries := r.URL.Query()
view := PAGES
switch strings.ToLower(queries.Get("view")) {
case "pages":
view = PAGES
case "referrers":
view = REFERRERS
case "useragents":
view = USERAGENTS
case "useragentnames":
view = USERAGENTNAMES
case "hours":
view = HOURS
case "days":
view = DAYS
case "weeks":
view = WEEKS
case "months":
view = MONTHS
case "allhours":
view = ALLHOURS
case "alldays":
view = ALLDAYS
case "count":
view = COUNT
}
result, e := app.db.request(&ViewsRequest{
view: view,
from: queries.Get("from"),
fromRel: queries.Get("fromrel"),
to: queries.Get("to"),
toRel: queries.Get("torel"),
url: queries.Get("url"),
ref: queries.Get("ref"),
ua: queries.Get("ua"),
ordercol: strings.ToLower(queries.Get("ordercol")),
order: strings.ToUpper(queries.Get("order")),
limit: queries.Get("limit"),
})
if e != nil {
fmt.Println("Database request failed:", e)
w.WriteHeader(500)
} else if result != nil {
w.Header().Set("Cache-Control", "max-age=0")
switch queries.Get("format") {
case "json":
sendJsonResponse(result, w)
case "chart":
sendChartResponse(result, w)
default: // "plain"
sendPlainResponse(result, w)
}
}
}
func sendPlainResponse(result []*RequestResultRow, w http.ResponseWriter) {
w.Header().Set("Content-Type", "text/plain")
for _, row := range result {
_, _ = fmt.Fprintln(w, (*row).First+": "+strconv.Itoa((*row).Second))
}
}
func sendJsonResponse(result []*RequestResultRow, w http.ResponseWriter) {
w.Header().Set("Content-Type", "application/json")
jsonBytes, _ := json.Marshal(result)
_, _ = fmt.Fprintln(w, string(jsonBytes))
}
func sendChartResponse(result []*RequestResultRow, w http.ResponseWriter) {
labels := make([]string, len(result))
values := make([]int, len(result))
for i, row := range result {
labels[i] = row.First
values[i] = row.Second
}
chartJSString, e := app.staticBox.FindString("Chart.min.js")
if e != nil {
return
}
data := struct {
Labels []string
Values []int
ChartJS template.JS
}{
Labels: labels,
Values: values,
ChartJS: template.JS(chartJSString),
}
chartTemplateString, e := app.staticBox.FindString("chart.html")
if e != nil {
return
}
t, e := template.New("chart").Parse(chartTemplateString)
if e != nil {
return
}
_ = t.Execute(w, data)
}

View File

@ -1,26 +0,0 @@
package main
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestHelloResponseHandler(t *testing.T) {
t.Run("Hello response", func(t *testing.T) {
req, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}
rr := httptest.NewRecorder()
handler := http.HandlerFunc(HelloResponseHandler)
handler.ServeHTTP(rr, req)
if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK)
}
expected := "Hello from KISSS"
if rr.Body.String() != expected {
t.Errorf("handler returned unexpected body: got %v want %v", rr.Body.String(), expected)
}
})
}

View File

@ -25,7 +25,7 @@ type report struct {
TGUserId int64 `json:"tgUserId"` TGUserId int64 `json:"tgUserId"`
} }
func setupReports() { func startReports() {
scheduler := clockwork.NewScheduler() scheduler := clockwork.NewScheduler()
for _, r := range appConfig.Reports { for _, r := range appConfig.Reports {
scheduledReport := r scheduledReport := r
@ -33,7 +33,7 @@ func setupReports() {
executeReport(&scheduledReport) executeReport(&scheduledReport)
}) })
} }
go scheduler.Run() scheduler.Run()
} }
func executeReport(r *report) { func executeReport(r *report) {

122
stats.go Normal file
View File

@ -0,0 +1,122 @@
package main
import (
"encoding/json"
"fmt"
"github.com/kis3/kis3/helpers"
"html/template"
"net/http"
"strconv"
"strings"
)
func initStatsRouter() {
app.router.HandleFunc("/stats", StatsHandler)
}
func StatsHandler(w http.ResponseWriter, r *http.Request) {
// Require authentication
if appConfig.statsAuth() {
if !helpers.CheckAuth(w, r, appConfig.StatsUsername, appConfig.StatsPassword) {
return
}
}
// Do request
queries := r.URL.Query()
view := PAGES
switch strings.ToLower(queries.Get("view")) {
case "pages":
view = PAGES
case "referrers":
view = REFERRERS
case "useragents":
view = USERAGENTS
case "useragentnames":
view = USERAGENTNAMES
case "hours":
view = HOURS
case "days":
view = DAYS
case "weeks":
view = WEEKS
case "months":
view = MONTHS
case "allhours":
view = ALLHOURS
case "alldays":
view = ALLDAYS
case "count":
view = COUNT
}
result, e := request(&ViewsRequest{
view: view,
from: queries.Get("from"),
fromRel: queries.Get("fromrel"),
to: queries.Get("to"),
toRel: queries.Get("torel"),
url: queries.Get("url"),
ref: queries.Get("ref"),
ua: queries.Get("ua"),
ordercol: strings.ToLower(queries.Get("ordercol")),
order: strings.ToUpper(queries.Get("order")),
limit: queries.Get("limit"),
})
if e != nil {
fmt.Println("Database request failed:", e)
w.WriteHeader(500)
} else if result != nil {
w.Header().Set("Cache-Control", "max-age=0")
switch queries.Get("format") {
case "json":
sendJsonResponse(result, w)
case "chart":
sendChartResponse(result, w)
default: // "plain"
sendPlainResponse(result, w)
}
}
}
func sendPlainResponse(result []*RequestResultRow, w http.ResponseWriter) {
w.Header().Set("Content-Type", "text/plain")
for _, row := range result {
_, _ = fmt.Fprintln(w, (*row).First+": "+strconv.Itoa((*row).Second))
}
}
func sendJsonResponse(result []*RequestResultRow, w http.ResponseWriter) {
w.Header().Set("Content-Type", "application/json")
jsonBytes, _ := json.Marshal(result)
_, _ = fmt.Fprintln(w, string(jsonBytes))
}
func sendChartResponse(result []*RequestResultRow, w http.ResponseWriter) {
labels := make([]string, len(result))
values := make([]int, len(result))
for i, row := range result {
labels[i] = row.First
values[i] = row.Second
}
chartJSString, e := app.staticBox.FindString("Chart.min.js")
if e != nil {
return
}
data := struct {
Labels []string
Values []int
ChartJS template.JS
}{
Labels: labels,
Values: values,
ChartJS: template.JS(chartJSString),
}
chartTemplateString, e := app.staticBox.FindString("chart.html")
if e != nil {
return
}
t, e := template.New("chart").Parse(chartTemplateString)
if e != nil {
return
}
_ = t.Execute(w, data)
}

44
tracking.go Normal file
View File

@ -0,0 +1,44 @@
package main
import (
"fmt"
"github.com/gorilla/handlers"
"net/http"
)
func initTrackingRouter() {
corsHandler := handlers.CORS(handlers.AllowedOrigins([]string{"*"}))
viewRouter := app.router.Path("/view").Subrouter()
viewRouter.Use(corsHandler)
viewRouter.Path("").HandlerFunc(TrackingHandler)
scriptRouter := app.router.Path("/kis3.js").Subrouter()
scriptRouter.Use(corsHandler)
scriptRouter.HandleFunc("", TrackingScriptHandler)
}
func TrackingHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate, max-age=0")
url := r.URL.Query().Get("url")
ref := r.URL.Query().Get("ref")
ua := r.Header.Get("User-Agent")
if !(r.Header.Get("DNT") == "1" && appConfig.Dnt) {
go trackView(url, ref, ua) // run with goroutine for awesome speed!
_, _ = fmt.Fprint(w, "true")
}
}
func TrackingScriptHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/javascript")
w.Header().Set("Cache-Control", "public, max-age=432000") // 5 days
filename := "kis3.js"
file, err := app.staticBox.Open(filename)
if err != nil {
return
}
defer file.Close()
stat, err := file.Stat()
if err != nil {
return
}
http.ServeContent(w, r, filename, stat.ModTime(), file)
}