1
Fork 0

Improvements (refactoring, test coverage)
continuous-integration/drone/push Build is passing Details

This commit is contained in:
Jan-Lukas Else 2021-08-11 10:09:22 +02:00
parent 2389deb85f
commit 5f62857286
3 changed files with 237 additions and 134 deletions

View File

@ -2,14 +2,34 @@ package main
import (
"database/sql"
"errors"
"log"
"os"
"path/filepath"
"github.com/lopezator/migrator"
)
func migrateDatabase() {
dbWriteLock.Lock()
defer dbWriteLock.Unlock()
func (a *app) openDatabase() (err error) {
if a.config.DBPath == "" {
return errors.New("empty database path")
}
_ = os.MkdirAll(filepath.Dir(a.config.DBPath), 0644)
a.database, err = sql.Open("sqlite3", a.config.DBPath+"?cache=shared&mode=rwc&_journal_mode=WAL&_busy_timeout=100")
if err != nil {
return err
}
a.shutdown.Add(func() {
_ = a.database.Close()
log.Println("Closed database")
})
a.migrateDatabase()
return nil
}
func (a *app) migrateDatabase() {
a.write.Lock()
defer a.write.Unlock()
m, err := migrator.New(
migrator.Migrations(
&migrator.Migration{
@ -29,8 +49,40 @@ func migrateDatabase() {
log.Fatal(err.Error())
return
}
if err := m.Migrate(appDb); err != nil {
if err := m.Migrate(a.database); err != nil {
log.Fatal(err.Error())
return
}
}
const (
typUrl = "url"
typText = "text"
)
func (a *app) insertRedirect(slug string, url string, typ string) error {
a.write.Lock()
defer a.write.Unlock()
_, err := a.database.Exec("INSERT INTO redirect (slug, url, type) VALUES (?, ?, ?)", slug, url, typ)
return err
}
func (a *app) deleteSlug(slug string) error {
a.write.Lock()
defer a.write.Unlock()
_, err := a.database.Exec("DELETE FROM redirect WHERE slug = ?", slug)
return err
}
func (a *app) increaseHits(slug string) {
go func() {
a.write.Lock()
defer a.write.Unlock()
_, _ = a.database.Exec("UPDATE redirect SET hits = hits + 1 WHERE slug = ?", slug)
}()
}
func (a *app) slugExists(slug string) (exists bool, err error) {
err = a.database.QueryRow("SELECT EXISTS(SELECT 1 FROM redirect WHERE slug = ?)", slug).Scan(&exists)
return
}

191
main.go
View File

@ -8,43 +8,51 @@ import (
"math/rand"
"net/http"
"os"
"path/filepath"
"strconv"
"sync"
"time"
goshutdowner "git.jlel.se/jlelse/go-shutdowner"
gsd "git.jlel.se/jlelse/go-shutdowner"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
_ "github.com/mattn/go-sqlite3"
"github.com/spf13/viper"
)
var (
appDb *sql.DB
appRouter *chi.Mux
dbWriteLock sync.Mutex
shutdown goshutdowner.Shutdowner
)
type app struct {
config *config
database *sql.DB
write sync.Mutex
shutdown gsd.Shutdowner
}
func initRouter() {
appRouter = chi.NewRouter()
appRouter.Use(middleware.GetHead)
appRouter.Group(func(r chi.Router) {
r.Use(loginMiddleware)
type config struct {
Port int `mapstructure:"port"`
DBPath string `mapstructure:"dbPath"`
Password string `mapstructure:"password"`
ShortUrl string `mapstructure:"shortUrl"`
DefaultUrl string `mapstructure:"defaultUrl"`
}
func (a *app) initRouter() (router *chi.Mux) {
router = chi.NewMux()
router.Use(middleware.GetHead)
router.Group(func(r chi.Router) {
r.Use(a.loginMiddleware)
r.Get("/s", shortenFormHandler)
r.Post("/s", shortenHandler)
r.Post("/s", a.shortenHandler)
r.Get("/t", shortenTextFormHandler)
r.Post("/t", shortenTextHandler)
r.Post("/t", a.shortenTextHandler)
r.Get("/u", updateFormHandler)
r.Get("/ut", updateTextFormHandler)
r.Post("/u", updateHandler)
r.Post("/u", a.updateHandler)
r.Get("/d", deleteFormHandler)
r.Post("/d", deleteHandler)
r.Get("/l", listHandler)
r.Post("/d", a.deleteHandler)
r.Get("/l", a.listHandler)
})
appRouter.Get("/{slug}", shortenedURLHandler)
appRouter.Get("/", defaultURLRedirectHandler)
router.Get("/{slug}", a.shortenedURLHandler)
router.Get("/", a.defaultURLRedirectHandler)
return
}
func main() {
@ -58,46 +66,43 @@ func main() {
viper.AddConfigPath(".")
_ = viper.ReadInConfig()
if !viper.IsSet("dbPath") {
log.Fatal("No database path (dbPath) is configured.")
}
if !viper.IsSet("password") {
log.Fatal("No password (password) is configured.")
return
}
if !viper.IsSet("shortUrl") {
log.Fatal("No short URL (shortUrl) is configured.")
return
}
if !viper.IsSet("defaultUrl") {
log.Fatal("No default URL (defaultUrl) is configured.")
return
}
var err error
dbPath := viper.GetString("dbPath")
_ = os.MkdirAll(filepath.Dir(dbPath), 0644)
appDb, err = sql.Open("sqlite3", dbPath+"?cache=shared&mode=rwc&_journal_mode=WAL&_busy_timeout=100")
app := &app{}
app.config = &config{}
err := viper.Unmarshal(app.config)
if err != nil {
log.Fatal("Failed to unmarshal config:", err.Error())
return
}
err = app.openDatabase()
if err != nil {
log.Println("Error opening database:", err.Error())
app.shutdown.ShutdownAndWait()
os.Exit(1)
return
}
shutdown.Add(func() {
_ = appDb.Close()
log.Println("Closed database")
})
migrateDatabase()
defer func() {
_ = appDb.Close()
}()
initRouter()
httpServer := &http.Server{
Addr: ":" + strconv.Itoa(viper.GetInt("port")),
Handler: appRouter,
Addr: ":" + strconv.Itoa(app.config.Port),
Handler: app.initRouter(),
ReadTimeout: 5 * time.Minute,
WriteTimeout: 5 * time.Minute,
}
shutdown.Add(func() {
app.shutdown.Add(func() {
toc, c := context.WithTimeout(context.Background(), 5*time.Second)
defer c()
if err := httpServer.Shutdown(toc); err != nil {
@ -112,12 +117,12 @@ func main() {
}
}()
shutdown.Wait()
app.shutdown.Wait()
}
func loginMiddleware(next http.Handler) http.Handler {
func (a *app) loginMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !checkPassword(w, r) {
if !a.checkPassword(w, r) {
return
}
next.ServeHTTP(w, r)
@ -171,9 +176,9 @@ func generateTextForm(w http.ResponseWriter, title string, url string, fields []
})
}
func shortenHandler(w http.ResponseWriter, r *http.Request) {
func (a *app) shortenHandler(w http.ResponseWriter, r *http.Request) {
writeShortenedURL := func(w http.ResponseWriter, slug string) {
_, _ = w.Write([]byte(viper.GetString("shortUrl") + "/" + slug))
_, _ = w.Write([]byte(a.config.ShortUrl + "/" + slug))
}
requestURL := r.FormValue("url")
@ -185,13 +190,13 @@ func shortenHandler(w http.ResponseWriter, r *http.Request) {
slug := r.FormValue("slug")
manualSlug := false
if slug == "" {
_ = appDb.QueryRow("SELECT slug FROM redirect WHERE url = ?", requestURL).Scan(&slug)
_ = a.database.QueryRow("SELECT slug FROM redirect WHERE url = ?", requestURL).Scan(&slug)
} else {
manualSlug = true
}
if slug != "" {
if e, _ := slugExists(slug); e {
if e, _ := a.slugExists(slug); e {
if manualSlug {
http.Error(w, "slug already in use", http.StatusBadRequest)
return
@ -204,7 +209,7 @@ func shortenHandler(w http.ResponseWriter, r *http.Request) {
var err error
for exists {
slug = generateSlug()
exists, err = slugExists(slug)
exists, err = a.slugExists(slug)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
@ -212,21 +217,18 @@ func shortenHandler(w http.ResponseWriter, r *http.Request) {
}
}
dbWriteLock.Lock()
if _, err := appDb.Exec("INSERT INTO redirect (slug, url) VALUES (?, ?)", slug, requestURL); err != nil {
dbWriteLock.Unlock()
if err := a.insertRedirect(slug, requestURL, typUrl); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
dbWriteLock.Unlock()
w.WriteHeader(http.StatusCreated)
writeShortenedURL(w, slug)
}
func shortenTextHandler(w http.ResponseWriter, r *http.Request) {
func (a *app) shortenTextHandler(w http.ResponseWriter, r *http.Request) {
writeShortenedURL := func(w http.ResponseWriter, slug string) {
_, _ = w.Write([]byte(viper.GetString("shortUrl") + "/" + slug))
_, _ = w.Write([]byte(a.config.ShortUrl + "/" + slug))
}
requestText := r.FormValue("text")
@ -238,13 +240,13 @@ func shortenTextHandler(w http.ResponseWriter, r *http.Request) {
slug := r.FormValue("slug")
manualSlug := false
if slug == "" {
_ = appDb.QueryRow("SELECT slug FROM redirect WHERE url = ? and type = 'text'", requestText).Scan(&slug)
_ = a.database.QueryRow("SELECT slug FROM redirect WHERE url = ? and type = 'text'", requestText).Scan(&slug)
} else {
manualSlug = true
}
if slug != "" {
if e, _ := slugExists(slug); e {
if e, _ := a.slugExists(slug); e {
if manualSlug {
http.Error(w, "slug already in use", http.StatusBadRequest)
return
@ -257,7 +259,7 @@ func shortenTextHandler(w http.ResponseWriter, r *http.Request) {
var err error
for exists {
slug = generateSlug()
exists, err = slugExists(slug)
exists, err = a.slugExists(slug)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
@ -265,19 +267,16 @@ func shortenTextHandler(w http.ResponseWriter, r *http.Request) {
}
}
dbWriteLock.Lock()
if _, err := appDb.Exec("INSERT INTO redirect (slug, url, type) VALUES (?, ?, 'text')", slug, requestText); err != nil {
dbWriteLock.Unlock()
if err := a.insertRedirect(slug, requestText, typText); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
dbWriteLock.Unlock()
w.WriteHeader(http.StatusCreated)
writeShortenedURL(w, slug)
}
func updateHandler(w http.ResponseWriter, r *http.Request) {
func (a *app) updateHandler(w http.ResponseWriter, r *http.Request) {
slug := r.FormValue("slug")
if slug == "" {
http.Error(w, "Specify the slug to update", http.StatusBadRequest)
@ -295,55 +294,52 @@ func updateHandler(w http.ResponseWriter, r *http.Request) {
typeString = "url"
}
if e, err := slugExists(slug); err != nil || !e {
if e, err := a.slugExists(slug); err != nil || !e {
http.NotFound(w, r)
return
}
dbWriteLock.Lock()
if _, err := appDb.Exec("UPDATE redirect SET url = ?, type = ? WHERE slug = ?", newURL, typeString, slug); err != nil {
dbWriteLock.Unlock()
a.write.Lock()
if _, err := a.database.Exec("UPDATE redirect SET url = ?, type = ? WHERE slug = ?", newURL, typeString, slug); err != nil {
a.write.Unlock()
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
dbWriteLock.Unlock()
a.write.Unlock()
w.WriteHeader(http.StatusAccepted)
_, _ = w.Write([]byte("Slug updated"))
}
func deleteHandler(w http.ResponseWriter, r *http.Request) {
func (a *app) deleteHandler(w http.ResponseWriter, r *http.Request) {
slug := r.FormValue("slug")
if slug == "" {
http.Error(w, "Specify the slug to delete", http.StatusBadRequest)
return
}
if e, err := slugExists(slug); !e || err != nil {
if e, err := a.slugExists(slug); !e || err != nil {
http.NotFound(w, r)
return
}
dbWriteLock.Lock()
if _, err := appDb.Exec("DELETE FROM redirect WHERE slug = ?", slug); err != nil {
dbWriteLock.Unlock()
if err := a.deleteSlug(slug); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
dbWriteLock.Unlock()
w.WriteHeader(http.StatusAccepted)
_, _ = w.Write([]byte("Slug deleted"))
}
func listHandler(w http.ResponseWriter, r *http.Request) {
func (a *app) listHandler(w http.ResponseWriter, r *http.Request) {
type row struct {
Slug string
URL string
Hits int
}
var list []row
rows, err := appDb.Query("SELECT slug, url, hits FROM redirect")
rows, err := a.database.Query("SELECT slug, url, hits FROM redirect")
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
@ -364,16 +360,19 @@ func listHandler(w http.ResponseWriter, r *http.Request) {
}
}
func checkPassword(w http.ResponseWriter, r *http.Request) bool {
if r.FormValue("password") == viper.GetString("password") {
func (a *app) checkPassword(w http.ResponseWriter, r *http.Request) bool {
// Check basic auth
if _, pass, ok := r.BasicAuth(); ok && pass == a.config.Password {
return true
}
if _, pass, ok := r.BasicAuth(); !ok || pass != viper.GetString("password") {
w.Header().Set("WWW-Authenticate", `Basic realm="Please enter a password!"`)
http.Error(w, "Not authenticated", http.StatusUnauthorized)
return false
// Check query or form param
if r.FormValue("password") == a.config.Password {
return true
}
return true
// Require password
w.Header().Set("WWW-Authenticate", `Basic realm="Please enter a password!"`)
http.Error(w, "Not authenticated", http.StatusUnauthorized)
return false
}
func generateSlug() string {
@ -385,34 +384,26 @@ func generateSlug() string {
return string(s)
}
func slugExists(slug string) (exists bool, err error) {
err = appDb.QueryRow("SELECT EXISTS(SELECT 1 FROM redirect WHERE slug = ?)", slug).Scan(&exists)
return
}
func shortenedURLHandler(w http.ResponseWriter, r *http.Request) {
func (a *app) shortenedURLHandler(w http.ResponseWriter, r *http.Request) {
slug := chi.URLParam(r, "slug")
var redirectURL, typeString string
err := appDb.QueryRow("SELECT url, type FROM redirect WHERE slug = ?", slug).Scan(&redirectURL, &typeString)
err := a.database.QueryRow("SELECT url, type FROM redirect WHERE slug = ?", slug).Scan(&redirectURL, &typeString)
if err != nil {
http.NotFound(w, r)
return
}
go func() {
dbWriteLock.Lock()
_, _ = appDb.Exec("UPDATE redirect SET hits = hits + 1 WHERE slug = ?", slug)
dbWriteLock.Unlock()
}()
a.increaseHits(slug)
if typeString == "text" {
switch typeString {
case typText:
_, _ = w.Write([]byte(redirectURL))
} else {
default:
http.Redirect(w, r, redirectURL, http.StatusTemporaryRedirect)
}
}
func defaultURLRedirectHandler(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, viper.GetString("defaultUrl"), http.StatusTemporaryRedirect)
func (a *app) defaultURLRedirectHandler(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, a.config.DefaultUrl, http.StatusTemporaryRedirect)
}

View File

@ -1,43 +1,43 @@
package main
import (
"database/sql"
"io"
"net/http"
"net/http/httptest"
"path/filepath"
"testing"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func setupFakeDB(t *testing.T) {
var err error
appDb, err = sql.Open("sqlite3", filepath.Join(t.TempDir(), "data.db")+"?cache=shared&mode=rwc&_journal_mode=WAL&_busy_timeout=100")
if err != nil {
t.Fatal(err)
func testApp(t *testing.T) *app {
app := &app{
config: &config{
DBPath: filepath.Join(t.TempDir(), "data.db"),
},
}
migrateDatabase()
err := app.openDatabase()
require.NoError(t, err)
return app
}
func closeFakeDB(t *testing.T) {
err := appDb.Close()
require.NoError(t, err)
func closeTestApp(_ *testing.T, app *app) {
app.shutdown.ShutdownAndWait()
}
func Test_slugExists(t *testing.T) {
t.Run("Test slugs", func(t *testing.T) {
setupFakeDB(t)
app := testApp(t)
exists, err := slugExists("source")
exists, err := app.slugExists("source")
assert.NoError(t, err)
assert.True(t, exists)
exists, err = slugExists("test")
exists, err = app.slugExists("test")
assert.NoError(t, err)
assert.False(t, exists)
closeFakeDB(t)
closeTestApp(t, app)
})
}
@ -48,22 +48,25 @@ func Test_generateSlug(t *testing.T) {
}
func TestShortenedUrlHandler(t *testing.T) {
viper.Set("defaultUrl", "http://long.example.com")
t.Run("Test ShortenedUrlHandler", func(t *testing.T) {
setupFakeDB(t)
initRouter()
app := testApp(t)
app.config.DefaultUrl = "http://long.example.com"
router := app.initRouter()
t.Run("Test redirect code", func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/source", nil)
w := httptest.NewRecorder()
appRouter.ServeHTTP(w, req)
router.ServeHTTP(w, req)
resp := w.Result()
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
})
t.Run("Test redirect location header", func(t *testing.T) {
t.Run("Test default redirect location header", func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/source", nil)
w := httptest.NewRecorder()
appRouter.ServeHTTP(w, req)
router.ServeHTTP(w, req)
resp := w.Result()
assert.Equal(t, "https://git.jlel.se/jlelse/GoShort", resp.Header.Get("Location"))
@ -71,7 +74,7 @@ func TestShortenedUrlHandler(t *testing.T) {
t.Run("Test missing slug redirect code", func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/test", nil)
w := httptest.NewRecorder()
appRouter.ServeHTTP(w, req)
router.ServeHTTP(w, req)
resp := w.Result()
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
@ -79,43 +82,100 @@ func TestShortenedUrlHandler(t *testing.T) {
t.Run("Test no slug mux var", func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/", nil)
w := httptest.NewRecorder()
appRouter.ServeHTTP(w, req)
router.ServeHTTP(w, req)
resp := w.Result()
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
assert.Equal(t, "http://long.example.com", resp.Header.Get("Location"))
})
closeFakeDB(t)
t.Run("Test custom url redirect", func(t *testing.T) {
err := app.insertRedirect("customurl", "https://example.net", typUrl)
require.NoError(t, err)
req := httptest.NewRequest("GET", "http://example.com/customurl", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
resp := w.Result()
assert.Equal(t, "https://example.net", resp.Header.Get("Location"))
})
t.Run("Test custom text", func(t *testing.T) {
err := app.insertRedirect("customtext", "Hello!", typText)
require.NoError(t, err)
req := httptest.NewRequest("GET", "http://example.com/customtext", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
resp := w.Result()
respBody, err := io.ReadAll(resp.Body)
require.NoError(t, err)
_ = resp.Body.Close()
assert.Equal(t, "Hello!", string(respBody))
})
closeTestApp(t, app)
})
}
func Test_checkPassword(t *testing.T) {
viper.Set("password", "abc")
app := testApp(t)
app.config.Password = "abc"
t.Run("No password", func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/test", nil)
assert.False(t, checkPassword(httptest.NewRecorder(), req))
assert.False(t, app.checkPassword(httptest.NewRecorder(), req))
})
t.Run("Password via query", func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/test?password=abc", nil)
assert.True(t, checkPassword(httptest.NewRecorder(), req))
assert.True(t, app.checkPassword(httptest.NewRecorder(), req))
})
t.Run("Wrong password via query", func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/test?password=wrong", nil)
assert.False(t, checkPassword(httptest.NewRecorder(), req))
assert.False(t, app.checkPassword(httptest.NewRecorder(), req))
})
t.Run("Password via BasicAuth", func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/test", nil)
req.SetBasicAuth("username", "abc")
assert.True(t, checkPassword(httptest.NewRecorder(), req))
assert.True(t, app.checkPassword(httptest.NewRecorder(), req))
})
t.Run("Wrong password via BasicAuth", func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/test", nil)
req.SetBasicAuth("username", "wrong")
assert.False(t, checkPassword(httptest.NewRecorder(), req))
assert.False(t, app.checkPassword(httptest.NewRecorder(), req))
})
t.Run("Test login middleware success", func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/test", nil)
req.SetBasicAuth("username", "abc")
w := httptest.NewRecorder()
app.loginMiddleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusNotModified)
})).ServeHTTP(w, req)
resp := w.Result()
assert.Equal(t, http.StatusNotModified, resp.StatusCode)
})
t.Run("Test login middleware fail", func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/test", nil)
req.SetBasicAuth("username", "xyz")
w := httptest.NewRecorder()
app.loginMiddleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusNotModified)
})).ServeHTTP(w, req)
resp := w.Result()
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
})
closeTestApp(t, app)
}