jlelse
/
kis3
Archived
1
Fork 0

Replace environment var based config with config file

This commit is contained in:
Jan-Lukas Else 2019-04-30 15:27:42 +02:00
parent 903dab597f
commit 45d04b9af4
5 changed files with 50 additions and 169 deletions

1
.gitignore vendored
View File

@ -1,2 +1,3 @@
.idea/
data/
config.json

View File

@ -1,68 +1,47 @@
package main
import (
"os"
"strconv"
"encoding/json"
"flag"
"io/ioutil"
)
type config struct {
port string
dnt bool
dbPath string
statsAuth bool
statsUsername string
statsPassword string
Port string `json:"port"`
Dnt bool `json:"dnt"`
DbPath string `json:"dbPath"`
StatsUsername string `json:"statsUsername"`
StatsPassword string `json:"statsPassword"`
}
var (
appConfig = &config{}
appConfig = &config{
Port: "8080",
Dnt: true,
DbPath: "data/kis3.db",
StatsUsername: "",
StatsPassword: "",
}
)
func init() {
appConfig.port = port()
appConfig.dnt = dnt()
appConfig.dbPath = dbPath()
appConfig.statsUsername = statsUsername()
appConfig.statsPassword = statsPassword()
appConfig.statsAuth = statsAuth(appConfig)
parseConfigFile(appConfig)
}
func port() string {
port := os.Getenv("PORT")
if len(port) != 0 {
return port
} else {
return "8080"
}
}
func dnt() bool {
dnt := os.Getenv("DNT")
dntBool, e := strconv.ParseBool(dnt)
func parseConfigFile(appConfig *config) {
configFile := flag.String("c", "config.json", "Config file")
flag.Parse()
configJson, e := ioutil.ReadFile(*configFile)
if e != nil {
dntBool = true
return
}
return dntBool
}
func dbPath() (dbPath string) {
dbPath = os.Getenv("DB_PATH")
if len(dbPath) == 0 {
dbPath = "data/kis3.db"
e = json.Unmarshal([]byte(configJson), appConfig)
if e != nil {
return
}
return
}
func statsUsername() (username string) {
username = os.Getenv("STATS_USERNAME")
return
}
func statsPassword() (password string) {
password = os.Getenv("STATS_PASSWORD")
return
}
func statsAuth(ac *config) bool {
return len(ac.statsUsername) > 0 && len(ac.statsPassword) > 0
func (ac *config) statsAuth() bool {
return len(ac.StatsUsername) > 0 && len(ac.StatsPassword) > 0
}

View File

@ -1,131 +1,32 @@
package main
import (
"os"
"testing"
)
func Test_port(t *testing.T) {
func Test_config_statsAuth(t *testing.T) {
type fields struct {
StatsUsername string
StatsPassword string
}
tests := []struct {
name string
envVar string
want string
}{
{name: "default", envVar: "", want: "8080"},
{name: "custom", envVar: "1234", want: "1234"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_ = os.Setenv("PORT", tt.envVar)
if got := port(); got != tt.want {
t.Errorf("port() = %v, want %v", got, tt.want)
}
})
}
}
func Test_dnt(t *testing.T) {
tests := []struct {
name string
envVar string
fields fields
want bool
}{
{name: "default", envVar: "", want: true},
{envVar: "true", want: true},
{envVar: "t", want: true},
{envVar: "TRUE", want: true},
{envVar: "1", want: true},
{envVar: "false", want: false},
{envVar: "f", want: false},
{envVar: "0", want: false},
{envVar: "abc", want: true},
{"No username nor password", fields{"", ""}, false},
{"Only username", fields{"abc", ""}, false},
{"Only password", fields{"", "abc"}, false},
{"Username and password", fields{"abc", "abc"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_ = os.Setenv("DNT", tt.envVar)
if got := dnt(); got != tt.want {
t.Errorf("dnt() = %v, want %v", got, tt.want)
}
})
}
}
func Test_dbPath(t *testing.T) {
tests := []struct {
name string
envVar string
wantDbPath string
}{
{name: "default", envVar: "", wantDbPath: "data/kis3.db"},
{envVar: "kis3.db", wantDbPath: "kis3.db"},
{envVar: "data.db", wantDbPath: "data.db"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_ = os.Setenv("DB_PATH", tt.envVar)
if gotDbPath := dbPath(); gotDbPath != tt.wantDbPath {
t.Errorf("dbPath() = %v, want %v", gotDbPath, tt.wantDbPath)
}
})
}
}
func Test_statsUsername(t *testing.T) {
tests := []struct {
name string
envVar string
wantUsername string
}{
{name: "default", envVar: "", wantUsername: ""},
{envVar: "abc", wantUsername: "abc"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_ = os.Setenv("STATS_USERNAME", tt.envVar)
if gotUsername := statsUsername(); gotUsername != tt.wantUsername {
t.Errorf("statsUsername() = %v, want %v", gotUsername, tt.wantUsername)
}
})
}
}
func Test_statsPassword(t *testing.T) {
tests := []struct {
name string
envVar string
wantPassword string
}{
{name: "default", envVar: "", wantPassword: ""},
{envVar: "def", wantPassword: "def"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_ = os.Setenv("STATS_PASSWORD", tt.envVar)
if gotPassword := statsPassword(); gotPassword != tt.wantPassword {
t.Errorf("statsPassword() = %v, want %v", gotPassword, tt.wantPassword)
}
})
}
}
func Test_statsAuth(t *testing.T) {
type args struct {
ac *config
}
tests := []struct {
name string
args args
want bool
}{
{name: "default", args: struct{ ac *config }{ac: &config{}}, want: false},
{name: "only username set", args: struct{ ac *config }{ac: &config{statsUsername: "abc"}}, want: false},
{name: "only password set", args: struct{ ac *config }{ac: &config{statsPassword: "def"}}, want: false},
{name: "username and password set", args: struct{ ac *config }{ac: &config{statsUsername: "abc", statsPassword: "def"}}, want: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := statsAuth(tt.args.ac); got != tt.want {
t.Errorf("statsAuth() = %v, want %v", got, tt.want)
ac := &config{
StatsUsername: tt.fields.StatsUsername,
StatsPassword: tt.fields.StatsPassword,
}
if got := ac.statsAuth(); got != tt.want {
t.Errorf("config.statsAuth() = %v, want %v", got, tt.want)
}
})
}

View File

@ -19,10 +19,10 @@ type Database struct {
func initDatabase() (database *Database, e error) {
database = &Database{}
if _, err := os.Stat(appConfig.dbPath); os.IsNotExist(err) {
_ = os.MkdirAll(filepath.Dir(appConfig.dbPath), os.ModePerm)
if _, err := os.Stat(appConfig.DbPath); os.IsNotExist(err) {
_ = os.MkdirAll(filepath.Dir(appConfig.DbPath), os.ModePerm)
}
database.sqlDB, e = sql.Open("sqlite3", appConfig.dbPath)
database.sqlDB, e = sql.Open("sqlite3", appConfig.DbPath)
if e != nil {
return
}

View File

@ -60,7 +60,7 @@ func setupRouter() {
}
func startListening() {
port := appConfig.port
port := appConfig.Port
addr := ":" + port
fmt.Printf("Listening to %s\n", addr)
log.Fatal(http.ListenAndServe(addr, app.router))
@ -71,7 +71,7 @@ func trackView(w http.ResponseWriter, r *http.Request) {
url := r.URL.Query().Get("url")
ref := r.URL.Query().Get("ref")
ua := r.Header.Get("User-Agent")
if !(r.Header.Get("DNT") == "1" && appConfig.dnt) {
if !(r.Header.Get("DNT") == "1" && appConfig.Dnt) {
go app.db.trackView(url, ref, ua) // run with goroutine for awesome speed!
_, _ = fmt.Fprint(w, "true")
}
@ -99,8 +99,8 @@ func serveTrackingScript(w http.ResponseWriter, r *http.Request) {
func requestStats(w http.ResponseWriter, r *http.Request) {
// Require authentication
if appConfig.statsAuth {
if !helpers.CheckAuth(w, r, appConfig.statsUsername, appConfig.statsPassword) {
if appConfig.statsAuth() {
if !helpers.CheckAuth(w, r, appConfig.StatsUsername, appConfig.StatsPassword) {
return
}
}