Replace environment var based config with config file
This commit is contained in:
parent
903dab597f
commit
45d04b9af4
|
@ -1,2 +1,3 @@
|
|||
.idea/
|
||||
data/
|
||||
config.json
|
||||
|
|
73
config.go
73
config.go
|
@ -1,68 +1,47 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"io/ioutil"
|
||||
)
|
||||
|
||||
type config struct {
|
||||
port string
|
||||
dnt bool
|
||||
dbPath string
|
||||
statsAuth bool
|
||||
statsUsername string
|
||||
statsPassword string
|
||||
Port string `json:"port"`
|
||||
Dnt bool `json:"dnt"`
|
||||
DbPath string `json:"dbPath"`
|
||||
StatsUsername string `json:"statsUsername"`
|
||||
StatsPassword string `json:"statsPassword"`
|
||||
}
|
||||
|
||||
var (
|
||||
appConfig = &config{}
|
||||
appConfig = &config{
|
||||
Port: "8080",
|
||||
Dnt: true,
|
||||
DbPath: "data/kis3.db",
|
||||
StatsUsername: "",
|
||||
StatsPassword: "",
|
||||
}
|
||||
)
|
||||
|
||||
func init() {
|
||||
appConfig.port = port()
|
||||
appConfig.dnt = dnt()
|
||||
appConfig.dbPath = dbPath()
|
||||
appConfig.statsUsername = statsUsername()
|
||||
appConfig.statsPassword = statsPassword()
|
||||
appConfig.statsAuth = statsAuth(appConfig)
|
||||
parseConfigFile(appConfig)
|
||||
}
|
||||
|
||||
func port() string {
|
||||
port := os.Getenv("PORT")
|
||||
if len(port) != 0 {
|
||||
return port
|
||||
} else {
|
||||
return "8080"
|
||||
}
|
||||
}
|
||||
|
||||
func dnt() bool {
|
||||
dnt := os.Getenv("DNT")
|
||||
dntBool, e := strconv.ParseBool(dnt)
|
||||
func parseConfigFile(appConfig *config) {
|
||||
configFile := flag.String("c", "config.json", "Config file")
|
||||
flag.Parse()
|
||||
configJson, e := ioutil.ReadFile(*configFile)
|
||||
if e != nil {
|
||||
dntBool = true
|
||||
return
|
||||
}
|
||||
return dntBool
|
||||
}
|
||||
|
||||
func dbPath() (dbPath string) {
|
||||
dbPath = os.Getenv("DB_PATH")
|
||||
if len(dbPath) == 0 {
|
||||
dbPath = "data/kis3.db"
|
||||
e = json.Unmarshal([]byte(configJson), appConfig)
|
||||
if e != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func statsUsername() (username string) {
|
||||
username = os.Getenv("STATS_USERNAME")
|
||||
return
|
||||
}
|
||||
|
||||
func statsPassword() (password string) {
|
||||
password = os.Getenv("STATS_PASSWORD")
|
||||
return
|
||||
}
|
||||
|
||||
func statsAuth(ac *config) bool {
|
||||
return len(ac.statsUsername) > 0 && len(ac.statsPassword) > 0
|
||||
func (ac *config) statsAuth() bool {
|
||||
return len(ac.StatsUsername) > 0 && len(ac.StatsPassword) > 0
|
||||
}
|
||||
|
|
131
config_test.go
131
config_test.go
|
@ -1,131 +1,32 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_port(t *testing.T) {
|
||||
func Test_config_statsAuth(t *testing.T) {
|
||||
type fields struct {
|
||||
StatsUsername string
|
||||
StatsPassword string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
envVar string
|
||||
want string
|
||||
}{
|
||||
{name: "default", envVar: "", want: "8080"},
|
||||
{name: "custom", envVar: "1234", want: "1234"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_ = os.Setenv("PORT", tt.envVar)
|
||||
if got := port(); got != tt.want {
|
||||
t.Errorf("port() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_dnt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
envVar string
|
||||
fields fields
|
||||
want bool
|
||||
}{
|
||||
{name: "default", envVar: "", want: true},
|
||||
{envVar: "true", want: true},
|
||||
{envVar: "t", want: true},
|
||||
{envVar: "TRUE", want: true},
|
||||
{envVar: "1", want: true},
|
||||
{envVar: "false", want: false},
|
||||
{envVar: "f", want: false},
|
||||
{envVar: "0", want: false},
|
||||
{envVar: "abc", want: true},
|
||||
{"No username nor password", fields{"", ""}, false},
|
||||
{"Only username", fields{"abc", ""}, false},
|
||||
{"Only password", fields{"", "abc"}, false},
|
||||
{"Username and password", fields{"abc", "abc"}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_ = os.Setenv("DNT", tt.envVar)
|
||||
if got := dnt(); got != tt.want {
|
||||
t.Errorf("dnt() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_dbPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
envVar string
|
||||
wantDbPath string
|
||||
}{
|
||||
{name: "default", envVar: "", wantDbPath: "data/kis3.db"},
|
||||
{envVar: "kis3.db", wantDbPath: "kis3.db"},
|
||||
{envVar: "data.db", wantDbPath: "data.db"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_ = os.Setenv("DB_PATH", tt.envVar)
|
||||
if gotDbPath := dbPath(); gotDbPath != tt.wantDbPath {
|
||||
t.Errorf("dbPath() = %v, want %v", gotDbPath, tt.wantDbPath)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_statsUsername(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
envVar string
|
||||
wantUsername string
|
||||
}{
|
||||
{name: "default", envVar: "", wantUsername: ""},
|
||||
{envVar: "abc", wantUsername: "abc"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_ = os.Setenv("STATS_USERNAME", tt.envVar)
|
||||
if gotUsername := statsUsername(); gotUsername != tt.wantUsername {
|
||||
t.Errorf("statsUsername() = %v, want %v", gotUsername, tt.wantUsername)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_statsPassword(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
envVar string
|
||||
wantPassword string
|
||||
}{
|
||||
{name: "default", envVar: "", wantPassword: ""},
|
||||
{envVar: "def", wantPassword: "def"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_ = os.Setenv("STATS_PASSWORD", tt.envVar)
|
||||
if gotPassword := statsPassword(); gotPassword != tt.wantPassword {
|
||||
t.Errorf("statsPassword() = %v, want %v", gotPassword, tt.wantPassword)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_statsAuth(t *testing.T) {
|
||||
type args struct {
|
||||
ac *config
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
{name: "default", args: struct{ ac *config }{ac: &config{}}, want: false},
|
||||
{name: "only username set", args: struct{ ac *config }{ac: &config{statsUsername: "abc"}}, want: false},
|
||||
{name: "only password set", args: struct{ ac *config }{ac: &config{statsPassword: "def"}}, want: false},
|
||||
{name: "username and password set", args: struct{ ac *config }{ac: &config{statsUsername: "abc", statsPassword: "def"}}, want: true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := statsAuth(tt.args.ac); got != tt.want {
|
||||
t.Errorf("statsAuth() = %v, want %v", got, tt.want)
|
||||
ac := &config{
|
||||
StatsUsername: tt.fields.StatsUsername,
|
||||
StatsPassword: tt.fields.StatsPassword,
|
||||
}
|
||||
if got := ac.statsAuth(); got != tt.want {
|
||||
t.Errorf("config.statsAuth() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -19,10 +19,10 @@ type Database struct {
|
|||
|
||||
func initDatabase() (database *Database, e error) {
|
||||
database = &Database{}
|
||||
if _, err := os.Stat(appConfig.dbPath); os.IsNotExist(err) {
|
||||
_ = os.MkdirAll(filepath.Dir(appConfig.dbPath), os.ModePerm)
|
||||
if _, err := os.Stat(appConfig.DbPath); os.IsNotExist(err) {
|
||||
_ = os.MkdirAll(filepath.Dir(appConfig.DbPath), os.ModePerm)
|
||||
}
|
||||
database.sqlDB, e = sql.Open("sqlite3", appConfig.dbPath)
|
||||
database.sqlDB, e = sql.Open("sqlite3", appConfig.DbPath)
|
||||
if e != nil {
|
||||
return
|
||||
}
|
||||
|
|
8
main.go
8
main.go
|
@ -60,7 +60,7 @@ func setupRouter() {
|
|||
}
|
||||
|
||||
func startListening() {
|
||||
port := appConfig.port
|
||||
port := appConfig.Port
|
||||
addr := ":" + port
|
||||
fmt.Printf("Listening to %s\n", addr)
|
||||
log.Fatal(http.ListenAndServe(addr, app.router))
|
||||
|
@ -71,7 +71,7 @@ func trackView(w http.ResponseWriter, r *http.Request) {
|
|||
url := r.URL.Query().Get("url")
|
||||
ref := r.URL.Query().Get("ref")
|
||||
ua := r.Header.Get("User-Agent")
|
||||
if !(r.Header.Get("DNT") == "1" && appConfig.dnt) {
|
||||
if !(r.Header.Get("DNT") == "1" && appConfig.Dnt) {
|
||||
go app.db.trackView(url, ref, ua) // run with goroutine for awesome speed!
|
||||
_, _ = fmt.Fprint(w, "true")
|
||||
}
|
||||
|
@ -99,8 +99,8 @@ func serveTrackingScript(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
func requestStats(w http.ResponseWriter, r *http.Request) {
|
||||
// Require authentication
|
||||
if appConfig.statsAuth {
|
||||
if !helpers.CheckAuth(w, r, appConfig.statsUsername, appConfig.statsPassword) {
|
||||
if appConfig.statsAuth() {
|
||||
if !helpers.CheckAuth(w, r, appConfig.StatsUsername, appConfig.StatsPassword) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
Reference in New Issue