jlelse
/
kis3
Archived
1
Fork 0

Replace environment var based config with config file

This commit is contained in:
Jan-Lukas Else 2019-04-30 15:27:42 +02:00
parent 903dab597f
commit 45d04b9af4
5 changed files with 50 additions and 169 deletions

1
.gitignore vendored
View File

@ -1,2 +1,3 @@
.idea/ .idea/
data/ data/
config.json

View File

@ -1,68 +1,47 @@
package main package main
import ( import (
"os" "encoding/json"
"strconv" "flag"
"io/ioutil"
) )
type config struct { type config struct {
port string Port string `json:"port"`
dnt bool Dnt bool `json:"dnt"`
dbPath string DbPath string `json:"dbPath"`
statsAuth bool StatsUsername string `json:"statsUsername"`
statsUsername string StatsPassword string `json:"statsPassword"`
statsPassword string
} }
var ( var (
appConfig = &config{} appConfig = &config{
Port: "8080",
Dnt: true,
DbPath: "data/kis3.db",
StatsUsername: "",
StatsPassword: "",
}
) )
func init() { func init() {
appConfig.port = port() parseConfigFile(appConfig)
appConfig.dnt = dnt()
appConfig.dbPath = dbPath()
appConfig.statsUsername = statsUsername()
appConfig.statsPassword = statsPassword()
appConfig.statsAuth = statsAuth(appConfig)
} }
func port() string { func parseConfigFile(appConfig *config) {
port := os.Getenv("PORT") configFile := flag.String("c", "config.json", "Config file")
if len(port) != 0 { flag.Parse()
return port configJson, e := ioutil.ReadFile(*configFile)
} else {
return "8080"
}
}
func dnt() bool {
dnt := os.Getenv("DNT")
dntBool, e := strconv.ParseBool(dnt)
if e != nil { if e != nil {
dntBool = true return
} }
return dntBool e = json.Unmarshal([]byte(configJson), appConfig)
} if e != nil {
return
func dbPath() (dbPath string) {
dbPath = os.Getenv("DB_PATH")
if len(dbPath) == 0 {
dbPath = "data/kis3.db"
} }
return return
} }
func statsUsername() (username string) { func (ac *config) statsAuth() bool {
username = os.Getenv("STATS_USERNAME") return len(ac.StatsUsername) > 0 && len(ac.StatsPassword) > 0
return
}
func statsPassword() (password string) {
password = os.Getenv("STATS_PASSWORD")
return
}
func statsAuth(ac *config) bool {
return len(ac.statsUsername) > 0 && len(ac.statsPassword) > 0
} }

View File

@ -1,131 +1,32 @@
package main package main
import ( import (
"os"
"testing" "testing"
) )
func Test_port(t *testing.T) { func Test_config_statsAuth(t *testing.T) {
type fields struct {
StatsUsername string
StatsPassword string
}
tests := []struct { tests := []struct {
name string name string
envVar string fields fields
want string
}{
{name: "default", envVar: "", want: "8080"},
{name: "custom", envVar: "1234", want: "1234"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_ = os.Setenv("PORT", tt.envVar)
if got := port(); got != tt.want {
t.Errorf("port() = %v, want %v", got, tt.want)
}
})
}
}
func Test_dnt(t *testing.T) {
tests := []struct {
name string
envVar string
want bool want bool
}{ }{
{name: "default", envVar: "", want: true}, {"No username nor password", fields{"", ""}, false},
{envVar: "true", want: true}, {"Only username", fields{"abc", ""}, false},
{envVar: "t", want: true}, {"Only password", fields{"", "abc"}, false},
{envVar: "TRUE", want: true}, {"Username and password", fields{"abc", "abc"}, true},
{envVar: "1", want: true},
{envVar: "false", want: false},
{envVar: "f", want: false},
{envVar: "0", want: false},
{envVar: "abc", want: true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
_ = os.Setenv("DNT", tt.envVar) ac := &config{
if got := dnt(); got != tt.want { StatsUsername: tt.fields.StatsUsername,
t.Errorf("dnt() = %v, want %v", got, tt.want) StatsPassword: tt.fields.StatsPassword,
} }
}) if got := ac.statsAuth(); got != tt.want {
} t.Errorf("config.statsAuth() = %v, want %v", got, tt.want)
}
func Test_dbPath(t *testing.T) {
tests := []struct {
name string
envVar string
wantDbPath string
}{
{name: "default", envVar: "", wantDbPath: "data/kis3.db"},
{envVar: "kis3.db", wantDbPath: "kis3.db"},
{envVar: "data.db", wantDbPath: "data.db"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_ = os.Setenv("DB_PATH", tt.envVar)
if gotDbPath := dbPath(); gotDbPath != tt.wantDbPath {
t.Errorf("dbPath() = %v, want %v", gotDbPath, tt.wantDbPath)
}
})
}
}
func Test_statsUsername(t *testing.T) {
tests := []struct {
name string
envVar string
wantUsername string
}{
{name: "default", envVar: "", wantUsername: ""},
{envVar: "abc", wantUsername: "abc"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_ = os.Setenv("STATS_USERNAME", tt.envVar)
if gotUsername := statsUsername(); gotUsername != tt.wantUsername {
t.Errorf("statsUsername() = %v, want %v", gotUsername, tt.wantUsername)
}
})
}
}
func Test_statsPassword(t *testing.T) {
tests := []struct {
name string
envVar string
wantPassword string
}{
{name: "default", envVar: "", wantPassword: ""},
{envVar: "def", wantPassword: "def"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_ = os.Setenv("STATS_PASSWORD", tt.envVar)
if gotPassword := statsPassword(); gotPassword != tt.wantPassword {
t.Errorf("statsPassword() = %v, want %v", gotPassword, tt.wantPassword)
}
})
}
}
func Test_statsAuth(t *testing.T) {
type args struct {
ac *config
}
tests := []struct {
name string
args args
want bool
}{
{name: "default", args: struct{ ac *config }{ac: &config{}}, want: false},
{name: "only username set", args: struct{ ac *config }{ac: &config{statsUsername: "abc"}}, want: false},
{name: "only password set", args: struct{ ac *config }{ac: &config{statsPassword: "def"}}, want: false},
{name: "username and password set", args: struct{ ac *config }{ac: &config{statsUsername: "abc", statsPassword: "def"}}, want: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := statsAuth(tt.args.ac); got != tt.want {
t.Errorf("statsAuth() = %v, want %v", got, tt.want)
} }
}) })
} }

View File

@ -19,10 +19,10 @@ type Database struct {
func initDatabase() (database *Database, e error) { func initDatabase() (database *Database, e error) {
database = &Database{} database = &Database{}
if _, err := os.Stat(appConfig.dbPath); os.IsNotExist(err) { if _, err := os.Stat(appConfig.DbPath); os.IsNotExist(err) {
_ = os.MkdirAll(filepath.Dir(appConfig.dbPath), os.ModePerm) _ = os.MkdirAll(filepath.Dir(appConfig.DbPath), os.ModePerm)
} }
database.sqlDB, e = sql.Open("sqlite3", appConfig.dbPath) database.sqlDB, e = sql.Open("sqlite3", appConfig.DbPath)
if e != nil { if e != nil {
return return
} }

View File

@ -60,7 +60,7 @@ func setupRouter() {
} }
func startListening() { func startListening() {
port := appConfig.port port := appConfig.Port
addr := ":" + port addr := ":" + port
fmt.Printf("Listening to %s\n", addr) fmt.Printf("Listening to %s\n", addr)
log.Fatal(http.ListenAndServe(addr, app.router)) log.Fatal(http.ListenAndServe(addr, app.router))
@ -71,7 +71,7 @@ func trackView(w http.ResponseWriter, r *http.Request) {
url := r.URL.Query().Get("url") url := r.URL.Query().Get("url")
ref := r.URL.Query().Get("ref") ref := r.URL.Query().Get("ref")
ua := r.Header.Get("User-Agent") ua := r.Header.Get("User-Agent")
if !(r.Header.Get("DNT") == "1" && appConfig.dnt) { if !(r.Header.Get("DNT") == "1" && appConfig.Dnt) {
go app.db.trackView(url, ref, ua) // run with goroutine for awesome speed! go app.db.trackView(url, ref, ua) // run with goroutine for awesome speed!
_, _ = fmt.Fprint(w, "true") _, _ = fmt.Fprint(w, "true")
} }
@ -99,8 +99,8 @@ func serveTrackingScript(w http.ResponseWriter, r *http.Request) {
func requestStats(w http.ResponseWriter, r *http.Request) { func requestStats(w http.ResponseWriter, r *http.Request) {
// Require authentication // Require authentication
if appConfig.statsAuth { if appConfig.statsAuth() {
if !helpers.CheckAuth(w, r, appConfig.statsUsername, appConfig.statsPassword) { if !helpers.CheckAuth(w, r, appConfig.StatsUsername, appConfig.StatsPassword) {
return return
} }
} }