diff --git a/database.go b/database.go index fa87c7b..17ffd35 100644 --- a/database.go +++ b/database.go @@ -2,14 +2,34 @@ package main import ( "database/sql" + "errors" "log" + "os" + "path/filepath" "github.com/lopezator/migrator" ) -func migrateDatabase() { - dbWriteLock.Lock() - defer dbWriteLock.Unlock() +func (a *app) openDatabase() (err error) { + if a.config.DBPath == "" { + return errors.New("empty database path") + } + _ = os.MkdirAll(filepath.Dir(a.config.DBPath), 0644) + a.database, err = sql.Open("sqlite3", a.config.DBPath+"?cache=shared&mode=rwc&_journal_mode=WAL&_busy_timeout=100") + if err != nil { + return err + } + a.shutdown.Add(func() { + _ = a.database.Close() + log.Println("Closed database") + }) + a.migrateDatabase() + return nil +} + +func (a *app) migrateDatabase() { + a.write.Lock() + defer a.write.Unlock() m, err := migrator.New( migrator.Migrations( &migrator.Migration{ @@ -29,8 +49,40 @@ func migrateDatabase() { log.Fatal(err.Error()) return } - if err := m.Migrate(appDb); err != nil { + if err := m.Migrate(a.database); err != nil { log.Fatal(err.Error()) return } } + +const ( + typUrl = "url" + typText = "text" +) + +func (a *app) insertRedirect(slug string, url string, typ string) error { + a.write.Lock() + defer a.write.Unlock() + _, err := a.database.Exec("INSERT INTO redirect (slug, url, type) VALUES (?, ?, ?)", slug, url, typ) + return err +} + +func (a *app) deleteSlug(slug string) error { + a.write.Lock() + defer a.write.Unlock() + _, err := a.database.Exec("DELETE FROM redirect WHERE slug = ?", slug) + return err +} + +func (a *app) increaseHits(slug string) { + go func() { + a.write.Lock() + defer a.write.Unlock() + _, _ = a.database.Exec("UPDATE redirect SET hits = hits + 1 WHERE slug = ?", slug) + }() +} + +func (a *app) slugExists(slug string) (exists bool, err error) { + err = a.database.QueryRow("SELECT EXISTS(SELECT 1 FROM redirect WHERE slug = ?)", slug).Scan(&exists) + return +} diff --git a/main.go b/main.go index 77e9fa2..2ef41f3 100644 --- a/main.go +++ b/main.go @@ -8,43 +8,51 @@ import ( "math/rand" "net/http" "os" - "path/filepath" "strconv" "sync" "time" - goshutdowner "git.jlel.se/jlelse/go-shutdowner" + gsd "git.jlel.se/jlelse/go-shutdowner" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" _ "github.com/mattn/go-sqlite3" "github.com/spf13/viper" ) -var ( - appDb *sql.DB - appRouter *chi.Mux - dbWriteLock sync.Mutex - shutdown goshutdowner.Shutdowner -) +type app struct { + config *config + database *sql.DB + write sync.Mutex + shutdown gsd.Shutdowner +} -func initRouter() { - appRouter = chi.NewRouter() - appRouter.Use(middleware.GetHead) - appRouter.Group(func(r chi.Router) { - r.Use(loginMiddleware) +type config struct { + Port int `mapstructure:"port"` + DBPath string `mapstructure:"dbPath"` + Password string `mapstructure:"password"` + ShortUrl string `mapstructure:"shortUrl"` + DefaultUrl string `mapstructure:"defaultUrl"` +} + +func (a *app) initRouter() (router *chi.Mux) { + router = chi.NewMux() + router.Use(middleware.GetHead) + router.Group(func(r chi.Router) { + r.Use(a.loginMiddleware) r.Get("/s", shortenFormHandler) - r.Post("/s", shortenHandler) + r.Post("/s", a.shortenHandler) r.Get("/t", shortenTextFormHandler) - r.Post("/t", shortenTextHandler) + r.Post("/t", a.shortenTextHandler) r.Get("/u", updateFormHandler) r.Get("/ut", updateTextFormHandler) - r.Post("/u", updateHandler) + r.Post("/u", a.updateHandler) r.Get("/d", deleteFormHandler) - r.Post("/d", deleteHandler) - r.Get("/l", listHandler) + r.Post("/d", a.deleteHandler) + r.Get("/l", a.listHandler) }) - appRouter.Get("/{slug}", shortenedURLHandler) - appRouter.Get("/", defaultURLRedirectHandler) + router.Get("/{slug}", a.shortenedURLHandler) + router.Get("/", a.defaultURLRedirectHandler) + return } func main() { @@ -58,46 +66,43 @@ func main() { viper.AddConfigPath(".") _ = viper.ReadInConfig() - if !viper.IsSet("dbPath") { - log.Fatal("No database path (dbPath) is configured.") - } if !viper.IsSet("password") { log.Fatal("No password (password) is configured.") + return } if !viper.IsSet("shortUrl") { log.Fatal("No short URL (shortUrl) is configured.") + return } if !viper.IsSet("defaultUrl") { log.Fatal("No default URL (defaultUrl) is configured.") + return } - var err error - dbPath := viper.GetString("dbPath") - _ = os.MkdirAll(filepath.Dir(dbPath), 0644) - appDb, err = sql.Open("sqlite3", dbPath+"?cache=shared&mode=rwc&_journal_mode=WAL&_busy_timeout=100") + app := &app{} + + app.config = &config{} + err := viper.Unmarshal(app.config) + if err != nil { + log.Fatal("Failed to unmarshal config:", err.Error()) + return + } + + err = app.openDatabase() if err != nil { log.Println("Error opening database:", err.Error()) + app.shutdown.ShutdownAndWait() + os.Exit(1) + return } - shutdown.Add(func() { - _ = appDb.Close() - log.Println("Closed database") - }) - - migrateDatabase() - - defer func() { - _ = appDb.Close() - }() - - initRouter() httpServer := &http.Server{ - Addr: ":" + strconv.Itoa(viper.GetInt("port")), - Handler: appRouter, + Addr: ":" + strconv.Itoa(app.config.Port), + Handler: app.initRouter(), ReadTimeout: 5 * time.Minute, WriteTimeout: 5 * time.Minute, } - shutdown.Add(func() { + app.shutdown.Add(func() { toc, c := context.WithTimeout(context.Background(), 5*time.Second) defer c() if err := httpServer.Shutdown(toc); err != nil { @@ -112,12 +117,12 @@ func main() { } }() - shutdown.Wait() + app.shutdown.Wait() } -func loginMiddleware(next http.Handler) http.Handler { +func (a *app) loginMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if !checkPassword(w, r) { + if !a.checkPassword(w, r) { return } next.ServeHTTP(w, r) @@ -171,9 +176,9 @@ func generateTextForm(w http.ResponseWriter, title string, url string, fields [] }) } -func shortenHandler(w http.ResponseWriter, r *http.Request) { +func (a *app) shortenHandler(w http.ResponseWriter, r *http.Request) { writeShortenedURL := func(w http.ResponseWriter, slug string) { - _, _ = w.Write([]byte(viper.GetString("shortUrl") + "/" + slug)) + _, _ = w.Write([]byte(a.config.ShortUrl + "/" + slug)) } requestURL := r.FormValue("url") @@ -185,13 +190,13 @@ func shortenHandler(w http.ResponseWriter, r *http.Request) { slug := r.FormValue("slug") manualSlug := false if slug == "" { - _ = appDb.QueryRow("SELECT slug FROM redirect WHERE url = ?", requestURL).Scan(&slug) + _ = a.database.QueryRow("SELECT slug FROM redirect WHERE url = ?", requestURL).Scan(&slug) } else { manualSlug = true } if slug != "" { - if e, _ := slugExists(slug); e { + if e, _ := a.slugExists(slug); e { if manualSlug { http.Error(w, "slug already in use", http.StatusBadRequest) return @@ -204,7 +209,7 @@ func shortenHandler(w http.ResponseWriter, r *http.Request) { var err error for exists { slug = generateSlug() - exists, err = slugExists(slug) + exists, err = a.slugExists(slug) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -212,21 +217,18 @@ func shortenHandler(w http.ResponseWriter, r *http.Request) { } } - dbWriteLock.Lock() - if _, err := appDb.Exec("INSERT INTO redirect (slug, url) VALUES (?, ?)", slug, requestURL); err != nil { - dbWriteLock.Unlock() + if err := a.insertRedirect(slug, requestURL, typUrl); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } - dbWriteLock.Unlock() w.WriteHeader(http.StatusCreated) writeShortenedURL(w, slug) } -func shortenTextHandler(w http.ResponseWriter, r *http.Request) { +func (a *app) shortenTextHandler(w http.ResponseWriter, r *http.Request) { writeShortenedURL := func(w http.ResponseWriter, slug string) { - _, _ = w.Write([]byte(viper.GetString("shortUrl") + "/" + slug)) + _, _ = w.Write([]byte(a.config.ShortUrl + "/" + slug)) } requestText := r.FormValue("text") @@ -238,13 +240,13 @@ func shortenTextHandler(w http.ResponseWriter, r *http.Request) { slug := r.FormValue("slug") manualSlug := false if slug == "" { - _ = appDb.QueryRow("SELECT slug FROM redirect WHERE url = ? and type = 'text'", requestText).Scan(&slug) + _ = a.database.QueryRow("SELECT slug FROM redirect WHERE url = ? and type = 'text'", requestText).Scan(&slug) } else { manualSlug = true } if slug != "" { - if e, _ := slugExists(slug); e { + if e, _ := a.slugExists(slug); e { if manualSlug { http.Error(w, "slug already in use", http.StatusBadRequest) return @@ -257,7 +259,7 @@ func shortenTextHandler(w http.ResponseWriter, r *http.Request) { var err error for exists { slug = generateSlug() - exists, err = slugExists(slug) + exists, err = a.slugExists(slug) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -265,19 +267,16 @@ func shortenTextHandler(w http.ResponseWriter, r *http.Request) { } } - dbWriteLock.Lock() - if _, err := appDb.Exec("INSERT INTO redirect (slug, url, type) VALUES (?, ?, 'text')", slug, requestText); err != nil { - dbWriteLock.Unlock() + if err := a.insertRedirect(slug, requestText, typText); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } - dbWriteLock.Unlock() w.WriteHeader(http.StatusCreated) writeShortenedURL(w, slug) } -func updateHandler(w http.ResponseWriter, r *http.Request) { +func (a *app) updateHandler(w http.ResponseWriter, r *http.Request) { slug := r.FormValue("slug") if slug == "" { http.Error(w, "Specify the slug to update", http.StatusBadRequest) @@ -295,55 +294,52 @@ func updateHandler(w http.ResponseWriter, r *http.Request) { typeString = "url" } - if e, err := slugExists(slug); err != nil || !e { + if e, err := a.slugExists(slug); err != nil || !e { http.NotFound(w, r) return } - dbWriteLock.Lock() - if _, err := appDb.Exec("UPDATE redirect SET url = ?, type = ? WHERE slug = ?", newURL, typeString, slug); err != nil { - dbWriteLock.Unlock() + a.write.Lock() + if _, err := a.database.Exec("UPDATE redirect SET url = ?, type = ? WHERE slug = ?", newURL, typeString, slug); err != nil { + a.write.Unlock() http.Error(w, err.Error(), http.StatusInternalServerError) return } - dbWriteLock.Unlock() + a.write.Unlock() w.WriteHeader(http.StatusAccepted) _, _ = w.Write([]byte("Slug updated")) } -func deleteHandler(w http.ResponseWriter, r *http.Request) { +func (a *app) deleteHandler(w http.ResponseWriter, r *http.Request) { slug := r.FormValue("slug") if slug == "" { http.Error(w, "Specify the slug to delete", http.StatusBadRequest) return } - if e, err := slugExists(slug); !e || err != nil { + if e, err := a.slugExists(slug); !e || err != nil { http.NotFound(w, r) return } - dbWriteLock.Lock() - if _, err := appDb.Exec("DELETE FROM redirect WHERE slug = ?", slug); err != nil { - dbWriteLock.Unlock() + if err := a.deleteSlug(slug); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } - dbWriteLock.Unlock() w.WriteHeader(http.StatusAccepted) _, _ = w.Write([]byte("Slug deleted")) } -func listHandler(w http.ResponseWriter, r *http.Request) { +func (a *app) listHandler(w http.ResponseWriter, r *http.Request) { type row struct { Slug string URL string Hits int } var list []row - rows, err := appDb.Query("SELECT slug, url, hits FROM redirect") + rows, err := a.database.Query("SELECT slug, url, hits FROM redirect") if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -364,16 +360,19 @@ func listHandler(w http.ResponseWriter, r *http.Request) { } } -func checkPassword(w http.ResponseWriter, r *http.Request) bool { - if r.FormValue("password") == viper.GetString("password") { +func (a *app) checkPassword(w http.ResponseWriter, r *http.Request) bool { + // Check basic auth + if _, pass, ok := r.BasicAuth(); ok && pass == a.config.Password { return true } - if _, pass, ok := r.BasicAuth(); !ok || pass != viper.GetString("password") { - w.Header().Set("WWW-Authenticate", `Basic realm="Please enter a password!"`) - http.Error(w, "Not authenticated", http.StatusUnauthorized) - return false + // Check query or form param + if r.FormValue("password") == a.config.Password { + return true } - return true + // Require password + w.Header().Set("WWW-Authenticate", `Basic realm="Please enter a password!"`) + http.Error(w, "Not authenticated", http.StatusUnauthorized) + return false } func generateSlug() string { @@ -385,34 +384,26 @@ func generateSlug() string { return string(s) } -func slugExists(slug string) (exists bool, err error) { - err = appDb.QueryRow("SELECT EXISTS(SELECT 1 FROM redirect WHERE slug = ?)", slug).Scan(&exists) - return -} - -func shortenedURLHandler(w http.ResponseWriter, r *http.Request) { +func (a *app) shortenedURLHandler(w http.ResponseWriter, r *http.Request) { slug := chi.URLParam(r, "slug") var redirectURL, typeString string - err := appDb.QueryRow("SELECT url, type FROM redirect WHERE slug = ?", slug).Scan(&redirectURL, &typeString) + err := a.database.QueryRow("SELECT url, type FROM redirect WHERE slug = ?", slug).Scan(&redirectURL, &typeString) if err != nil { http.NotFound(w, r) return } - go func() { - dbWriteLock.Lock() - _, _ = appDb.Exec("UPDATE redirect SET hits = hits + 1 WHERE slug = ?", slug) - dbWriteLock.Unlock() - }() + a.increaseHits(slug) - if typeString == "text" { + switch typeString { + case typText: _, _ = w.Write([]byte(redirectURL)) - } else { + default: http.Redirect(w, r, redirectURL, http.StatusTemporaryRedirect) } } -func defaultURLRedirectHandler(w http.ResponseWriter, r *http.Request) { - http.Redirect(w, r, viper.GetString("defaultUrl"), http.StatusTemporaryRedirect) +func (a *app) defaultURLRedirectHandler(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, a.config.DefaultUrl, http.StatusTemporaryRedirect) } diff --git a/main_test.go b/main_test.go index 613e95c..661ecc7 100644 --- a/main_test.go +++ b/main_test.go @@ -1,43 +1,43 @@ package main import ( - "database/sql" + "io" "net/http" "net/http/httptest" "path/filepath" "testing" - "github.com/spf13/viper" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func setupFakeDB(t *testing.T) { - var err error - appDb, err = sql.Open("sqlite3", filepath.Join(t.TempDir(), "data.db")+"?cache=shared&mode=rwc&_journal_mode=WAL&_busy_timeout=100") - if err != nil { - t.Fatal(err) +func testApp(t *testing.T) *app { + app := &app{ + config: &config{ + DBPath: filepath.Join(t.TempDir(), "data.db"), + }, } - migrateDatabase() + err := app.openDatabase() + require.NoError(t, err) + return app } -func closeFakeDB(t *testing.T) { - err := appDb.Close() - require.NoError(t, err) +func closeTestApp(_ *testing.T, app *app) { + app.shutdown.ShutdownAndWait() } func Test_slugExists(t *testing.T) { t.Run("Test slugs", func(t *testing.T) { - setupFakeDB(t) + app := testApp(t) - exists, err := slugExists("source") + exists, err := app.slugExists("source") assert.NoError(t, err) assert.True(t, exists) - exists, err = slugExists("test") + exists, err = app.slugExists("test") assert.NoError(t, err) assert.False(t, exists) - closeFakeDB(t) + closeTestApp(t, app) }) } @@ -48,22 +48,25 @@ func Test_generateSlug(t *testing.T) { } func TestShortenedUrlHandler(t *testing.T) { - viper.Set("defaultUrl", "http://long.example.com") t.Run("Test ShortenedUrlHandler", func(t *testing.T) { - setupFakeDB(t) - initRouter() + app := testApp(t) + + app.config.DefaultUrl = "http://long.example.com" + + router := app.initRouter() + t.Run("Test redirect code", func(t *testing.T) { req := httptest.NewRequest("GET", "http://example.com/source", nil) w := httptest.NewRecorder() - appRouter.ServeHTTP(w, req) + router.ServeHTTP(w, req) resp := w.Result() assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) }) - t.Run("Test redirect location header", func(t *testing.T) { + t.Run("Test default redirect location header", func(t *testing.T) { req := httptest.NewRequest("GET", "http://example.com/source", nil) w := httptest.NewRecorder() - appRouter.ServeHTTP(w, req) + router.ServeHTTP(w, req) resp := w.Result() assert.Equal(t, "https://git.jlel.se/jlelse/GoShort", resp.Header.Get("Location")) @@ -71,7 +74,7 @@ func TestShortenedUrlHandler(t *testing.T) { t.Run("Test missing slug redirect code", func(t *testing.T) { req := httptest.NewRequest("GET", "http://example.com/test", nil) w := httptest.NewRecorder() - appRouter.ServeHTTP(w, req) + router.ServeHTTP(w, req) resp := w.Result() assert.Equal(t, http.StatusNotFound, resp.StatusCode) @@ -79,43 +82,100 @@ func TestShortenedUrlHandler(t *testing.T) { t.Run("Test no slug mux var", func(t *testing.T) { req := httptest.NewRequest("GET", "http://example.com/", nil) w := httptest.NewRecorder() - appRouter.ServeHTTP(w, req) + router.ServeHTTP(w, req) resp := w.Result() assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) assert.Equal(t, "http://long.example.com", resp.Header.Get("Location")) }) - closeFakeDB(t) + t.Run("Test custom url redirect", func(t *testing.T) { + err := app.insertRedirect("customurl", "https://example.net", typUrl) + require.NoError(t, err) + + req := httptest.NewRequest("GET", "http://example.com/customurl", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + resp := w.Result() + + assert.Equal(t, "https://example.net", resp.Header.Get("Location")) + }) + t.Run("Test custom text", func(t *testing.T) { + err := app.insertRedirect("customtext", "Hello!", typText) + require.NoError(t, err) + + req := httptest.NewRequest("GET", "http://example.com/customtext", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + resp := w.Result() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + _ = resp.Body.Close() + + assert.Equal(t, "Hello!", string(respBody)) + }) + + closeTestApp(t, app) }) } func Test_checkPassword(t *testing.T) { - viper.Set("password", "abc") + app := testApp(t) + + app.config.Password = "abc" + t.Run("No password", func(t *testing.T) { req := httptest.NewRequest("GET", "http://example.com/test", nil) - assert.False(t, checkPassword(httptest.NewRecorder(), req)) + assert.False(t, app.checkPassword(httptest.NewRecorder(), req)) }) t.Run("Password via query", func(t *testing.T) { req := httptest.NewRequest("GET", "http://example.com/test?password=abc", nil) - assert.True(t, checkPassword(httptest.NewRecorder(), req)) + assert.True(t, app.checkPassword(httptest.NewRecorder(), req)) }) t.Run("Wrong password via query", func(t *testing.T) { req := httptest.NewRequest("GET", "http://example.com/test?password=wrong", nil) - assert.False(t, checkPassword(httptest.NewRecorder(), req)) + assert.False(t, app.checkPassword(httptest.NewRecorder(), req)) }) t.Run("Password via BasicAuth", func(t *testing.T) { req := httptest.NewRequest("GET", "http://example.com/test", nil) req.SetBasicAuth("username", "abc") - assert.True(t, checkPassword(httptest.NewRecorder(), req)) + assert.True(t, app.checkPassword(httptest.NewRecorder(), req)) }) t.Run("Wrong password via BasicAuth", func(t *testing.T) { req := httptest.NewRequest("GET", "http://example.com/test", nil) req.SetBasicAuth("username", "wrong") - assert.False(t, checkPassword(httptest.NewRecorder(), req)) + assert.False(t, app.checkPassword(httptest.NewRecorder(), req)) }) + + t.Run("Test login middleware success", func(t *testing.T) { + req := httptest.NewRequest("GET", "http://example.com/test", nil) + req.SetBasicAuth("username", "abc") + + w := httptest.NewRecorder() + app.loginMiddleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusNotModified) + })).ServeHTTP(w, req) + resp := w.Result() + + assert.Equal(t, http.StatusNotModified, resp.StatusCode) + }) + t.Run("Test login middleware fail", func(t *testing.T) { + req := httptest.NewRequest("GET", "http://example.com/test", nil) + req.SetBasicAuth("username", "xyz") + + w := httptest.NewRecorder() + app.loginMiddleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusNotModified) + })).ServeHTTP(w, req) + resp := w.Result() + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + }) + + closeTestApp(t, app) }