1
Fork 0

Improvements (refactoring, test coverage)
continuous-integration/drone/push Build is passing Details

This commit is contained in:
Jan-Lukas Else 2021-08-11 10:09:22 +02:00
parent 2389deb85f
commit 5f62857286
3 changed files with 237 additions and 134 deletions

View File

@ -2,14 +2,34 @@ package main
import ( import (
"database/sql" "database/sql"
"errors"
"log" "log"
"os"
"path/filepath"
"github.com/lopezator/migrator" "github.com/lopezator/migrator"
) )
func migrateDatabase() { func (a *app) openDatabase() (err error) {
dbWriteLock.Lock() if a.config.DBPath == "" {
defer dbWriteLock.Unlock() return errors.New("empty database path")
}
_ = os.MkdirAll(filepath.Dir(a.config.DBPath), 0644)
a.database, err = sql.Open("sqlite3", a.config.DBPath+"?cache=shared&mode=rwc&_journal_mode=WAL&_busy_timeout=100")
if err != nil {
return err
}
a.shutdown.Add(func() {
_ = a.database.Close()
log.Println("Closed database")
})
a.migrateDatabase()
return nil
}
func (a *app) migrateDatabase() {
a.write.Lock()
defer a.write.Unlock()
m, err := migrator.New( m, err := migrator.New(
migrator.Migrations( migrator.Migrations(
&migrator.Migration{ &migrator.Migration{
@ -29,8 +49,40 @@ func migrateDatabase() {
log.Fatal(err.Error()) log.Fatal(err.Error())
return return
} }
if err := m.Migrate(appDb); err != nil { if err := m.Migrate(a.database); err != nil {
log.Fatal(err.Error()) log.Fatal(err.Error())
return return
} }
} }
const (
typUrl = "url"
typText = "text"
)
func (a *app) insertRedirect(slug string, url string, typ string) error {
a.write.Lock()
defer a.write.Unlock()
_, err := a.database.Exec("INSERT INTO redirect (slug, url, type) VALUES (?, ?, ?)", slug, url, typ)
return err
}
func (a *app) deleteSlug(slug string) error {
a.write.Lock()
defer a.write.Unlock()
_, err := a.database.Exec("DELETE FROM redirect WHERE slug = ?", slug)
return err
}
func (a *app) increaseHits(slug string) {
go func() {
a.write.Lock()
defer a.write.Unlock()
_, _ = a.database.Exec("UPDATE redirect SET hits = hits + 1 WHERE slug = ?", slug)
}()
}
func (a *app) slugExists(slug string) (exists bool, err error) {
err = a.database.QueryRow("SELECT EXISTS(SELECT 1 FROM redirect WHERE slug = ?)", slug).Scan(&exists)
return
}

191
main.go
View File

@ -8,43 +8,51 @@ import (
"math/rand" "math/rand"
"net/http" "net/http"
"os" "os"
"path/filepath"
"strconv" "strconv"
"sync" "sync"
"time" "time"
goshutdowner "git.jlel.se/jlelse/go-shutdowner" gsd "git.jlel.se/jlelse/go-shutdowner"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware" "github.com/go-chi/chi/v5/middleware"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"github.com/spf13/viper" "github.com/spf13/viper"
) )
var ( type app struct {
appDb *sql.DB config *config
appRouter *chi.Mux database *sql.DB
dbWriteLock sync.Mutex write sync.Mutex
shutdown goshutdowner.Shutdowner shutdown gsd.Shutdowner
) }
func initRouter() { type config struct {
appRouter = chi.NewRouter() Port int `mapstructure:"port"`
appRouter.Use(middleware.GetHead) DBPath string `mapstructure:"dbPath"`
appRouter.Group(func(r chi.Router) { Password string `mapstructure:"password"`
r.Use(loginMiddleware) ShortUrl string `mapstructure:"shortUrl"`
DefaultUrl string `mapstructure:"defaultUrl"`
}
func (a *app) initRouter() (router *chi.Mux) {
router = chi.NewMux()
router.Use(middleware.GetHead)
router.Group(func(r chi.Router) {
r.Use(a.loginMiddleware)
r.Get("/s", shortenFormHandler) r.Get("/s", shortenFormHandler)
r.Post("/s", shortenHandler) r.Post("/s", a.shortenHandler)
r.Get("/t", shortenTextFormHandler) r.Get("/t", shortenTextFormHandler)
r.Post("/t", shortenTextHandler) r.Post("/t", a.shortenTextHandler)
r.Get("/u", updateFormHandler) r.Get("/u", updateFormHandler)
r.Get("/ut", updateTextFormHandler) r.Get("/ut", updateTextFormHandler)
r.Post("/u", updateHandler) r.Post("/u", a.updateHandler)
r.Get("/d", deleteFormHandler) r.Get("/d", deleteFormHandler)
r.Post("/d", deleteHandler) r.Post("/d", a.deleteHandler)
r.Get("/l", listHandler) r.Get("/l", a.listHandler)
}) })
appRouter.Get("/{slug}", shortenedURLHandler) router.Get("/{slug}", a.shortenedURLHandler)
appRouter.Get("/", defaultURLRedirectHandler) router.Get("/", a.defaultURLRedirectHandler)
return
} }
func main() { func main() {
@ -58,46 +66,43 @@ func main() {
viper.AddConfigPath(".") viper.AddConfigPath(".")
_ = viper.ReadInConfig() _ = viper.ReadInConfig()
if !viper.IsSet("dbPath") {
log.Fatal("No database path (dbPath) is configured.")
}
if !viper.IsSet("password") { if !viper.IsSet("password") {
log.Fatal("No password (password) is configured.") log.Fatal("No password (password) is configured.")
return
} }
if !viper.IsSet("shortUrl") { if !viper.IsSet("shortUrl") {
log.Fatal("No short URL (shortUrl) is configured.") log.Fatal("No short URL (shortUrl) is configured.")
return
} }
if !viper.IsSet("defaultUrl") { if !viper.IsSet("defaultUrl") {
log.Fatal("No default URL (defaultUrl) is configured.") log.Fatal("No default URL (defaultUrl) is configured.")
return
} }
var err error app := &app{}
dbPath := viper.GetString("dbPath")
_ = os.MkdirAll(filepath.Dir(dbPath), 0644) app.config = &config{}
appDb, err = sql.Open("sqlite3", dbPath+"?cache=shared&mode=rwc&_journal_mode=WAL&_busy_timeout=100") err := viper.Unmarshal(app.config)
if err != nil {
log.Fatal("Failed to unmarshal config:", err.Error())
return
}
err = app.openDatabase()
if err != nil { if err != nil {
log.Println("Error opening database:", err.Error()) log.Println("Error opening database:", err.Error())
app.shutdown.ShutdownAndWait()
os.Exit(1)
return
} }
shutdown.Add(func() {
_ = appDb.Close()
log.Println("Closed database")
})
migrateDatabase()
defer func() {
_ = appDb.Close()
}()
initRouter()
httpServer := &http.Server{ httpServer := &http.Server{
Addr: ":" + strconv.Itoa(viper.GetInt("port")), Addr: ":" + strconv.Itoa(app.config.Port),
Handler: appRouter, Handler: app.initRouter(),
ReadTimeout: 5 * time.Minute, ReadTimeout: 5 * time.Minute,
WriteTimeout: 5 * time.Minute, WriteTimeout: 5 * time.Minute,
} }
shutdown.Add(func() { app.shutdown.Add(func() {
toc, c := context.WithTimeout(context.Background(), 5*time.Second) toc, c := context.WithTimeout(context.Background(), 5*time.Second)
defer c() defer c()
if err := httpServer.Shutdown(toc); err != nil { if err := httpServer.Shutdown(toc); err != nil {
@ -112,12 +117,12 @@ func main() {
} }
}() }()
shutdown.Wait() app.shutdown.Wait()
} }
func loginMiddleware(next http.Handler) http.Handler { func (a *app) loginMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !checkPassword(w, r) { if !a.checkPassword(w, r) {
return return
} }
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
@ -171,9 +176,9 @@ func generateTextForm(w http.ResponseWriter, title string, url string, fields []
}) })
} }
func shortenHandler(w http.ResponseWriter, r *http.Request) { func (a *app) shortenHandler(w http.ResponseWriter, r *http.Request) {
writeShortenedURL := func(w http.ResponseWriter, slug string) { writeShortenedURL := func(w http.ResponseWriter, slug string) {
_, _ = w.Write([]byte(viper.GetString("shortUrl") + "/" + slug)) _, _ = w.Write([]byte(a.config.ShortUrl + "/" + slug))
} }
requestURL := r.FormValue("url") requestURL := r.FormValue("url")
@ -185,13 +190,13 @@ func shortenHandler(w http.ResponseWriter, r *http.Request) {
slug := r.FormValue("slug") slug := r.FormValue("slug")
manualSlug := false manualSlug := false
if slug == "" { if slug == "" {
_ = appDb.QueryRow("SELECT slug FROM redirect WHERE url = ?", requestURL).Scan(&slug) _ = a.database.QueryRow("SELECT slug FROM redirect WHERE url = ?", requestURL).Scan(&slug)
} else { } else {
manualSlug = true manualSlug = true
} }
if slug != "" { if slug != "" {
if e, _ := slugExists(slug); e { if e, _ := a.slugExists(slug); e {
if manualSlug { if manualSlug {
http.Error(w, "slug already in use", http.StatusBadRequest) http.Error(w, "slug already in use", http.StatusBadRequest)
return return
@ -204,7 +209,7 @@ func shortenHandler(w http.ResponseWriter, r *http.Request) {
var err error var err error
for exists { for exists {
slug = generateSlug() slug = generateSlug()
exists, err = slugExists(slug) exists, err = a.slugExists(slug)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
@ -212,21 +217,18 @@ func shortenHandler(w http.ResponseWriter, r *http.Request) {
} }
} }
dbWriteLock.Lock() if err := a.insertRedirect(slug, requestURL, typUrl); err != nil {
if _, err := appDb.Exec("INSERT INTO redirect (slug, url) VALUES (?, ?)", slug, requestURL); err != nil {
dbWriteLock.Unlock()
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
} }
dbWriteLock.Unlock()
w.WriteHeader(http.StatusCreated) w.WriteHeader(http.StatusCreated)
writeShortenedURL(w, slug) writeShortenedURL(w, slug)
} }
func shortenTextHandler(w http.ResponseWriter, r *http.Request) { func (a *app) shortenTextHandler(w http.ResponseWriter, r *http.Request) {
writeShortenedURL := func(w http.ResponseWriter, slug string) { writeShortenedURL := func(w http.ResponseWriter, slug string) {
_, _ = w.Write([]byte(viper.GetString("shortUrl") + "/" + slug)) _, _ = w.Write([]byte(a.config.ShortUrl + "/" + slug))
} }
requestText := r.FormValue("text") requestText := r.FormValue("text")
@ -238,13 +240,13 @@ func shortenTextHandler(w http.ResponseWriter, r *http.Request) {
slug := r.FormValue("slug") slug := r.FormValue("slug")
manualSlug := false manualSlug := false
if slug == "" { if slug == "" {
_ = appDb.QueryRow("SELECT slug FROM redirect WHERE url = ? and type = 'text'", requestText).Scan(&slug) _ = a.database.QueryRow("SELECT slug FROM redirect WHERE url = ? and type = 'text'", requestText).Scan(&slug)
} else { } else {
manualSlug = true manualSlug = true
} }
if slug != "" { if slug != "" {
if e, _ := slugExists(slug); e { if e, _ := a.slugExists(slug); e {
if manualSlug { if manualSlug {
http.Error(w, "slug already in use", http.StatusBadRequest) http.Error(w, "slug already in use", http.StatusBadRequest)
return return
@ -257,7 +259,7 @@ func shortenTextHandler(w http.ResponseWriter, r *http.Request) {
var err error var err error
for exists { for exists {
slug = generateSlug() slug = generateSlug()
exists, err = slugExists(slug) exists, err = a.slugExists(slug)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
@ -265,19 +267,16 @@ func shortenTextHandler(w http.ResponseWriter, r *http.Request) {
} }
} }
dbWriteLock.Lock() if err := a.insertRedirect(slug, requestText, typText); err != nil {
if _, err := appDb.Exec("INSERT INTO redirect (slug, url, type) VALUES (?, ?, 'text')", slug, requestText); err != nil {
dbWriteLock.Unlock()
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
} }
dbWriteLock.Unlock()
w.WriteHeader(http.StatusCreated) w.WriteHeader(http.StatusCreated)
writeShortenedURL(w, slug) writeShortenedURL(w, slug)
} }
func updateHandler(w http.ResponseWriter, r *http.Request) { func (a *app) updateHandler(w http.ResponseWriter, r *http.Request) {
slug := r.FormValue("slug") slug := r.FormValue("slug")
if slug == "" { if slug == "" {
http.Error(w, "Specify the slug to update", http.StatusBadRequest) http.Error(w, "Specify the slug to update", http.StatusBadRequest)
@ -295,55 +294,52 @@ func updateHandler(w http.ResponseWriter, r *http.Request) {
typeString = "url" typeString = "url"
} }
if e, err := slugExists(slug); err != nil || !e { if e, err := a.slugExists(slug); err != nil || !e {
http.NotFound(w, r) http.NotFound(w, r)
return return
} }
dbWriteLock.Lock() a.write.Lock()
if _, err := appDb.Exec("UPDATE redirect SET url = ?, type = ? WHERE slug = ?", newURL, typeString, slug); err != nil { if _, err := a.database.Exec("UPDATE redirect SET url = ?, type = ? WHERE slug = ?", newURL, typeString, slug); err != nil {
dbWriteLock.Unlock() a.write.Unlock()
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
} }
dbWriteLock.Unlock() a.write.Unlock()
w.WriteHeader(http.StatusAccepted) w.WriteHeader(http.StatusAccepted)
_, _ = w.Write([]byte("Slug updated")) _, _ = w.Write([]byte("Slug updated"))
} }
func deleteHandler(w http.ResponseWriter, r *http.Request) { func (a *app) deleteHandler(w http.ResponseWriter, r *http.Request) {
slug := r.FormValue("slug") slug := r.FormValue("slug")
if slug == "" { if slug == "" {
http.Error(w, "Specify the slug to delete", http.StatusBadRequest) http.Error(w, "Specify the slug to delete", http.StatusBadRequest)
return return
} }
if e, err := slugExists(slug); !e || err != nil { if e, err := a.slugExists(slug); !e || err != nil {
http.NotFound(w, r) http.NotFound(w, r)
return return
} }
dbWriteLock.Lock() if err := a.deleteSlug(slug); err != nil {
if _, err := appDb.Exec("DELETE FROM redirect WHERE slug = ?", slug); err != nil {
dbWriteLock.Unlock()
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
} }
dbWriteLock.Unlock()
w.WriteHeader(http.StatusAccepted) w.WriteHeader(http.StatusAccepted)
_, _ = w.Write([]byte("Slug deleted")) _, _ = w.Write([]byte("Slug deleted"))
} }
func listHandler(w http.ResponseWriter, r *http.Request) { func (a *app) listHandler(w http.ResponseWriter, r *http.Request) {
type row struct { type row struct {
Slug string Slug string
URL string URL string
Hits int Hits int
} }
var list []row var list []row
rows, err := appDb.Query("SELECT slug, url, hits FROM redirect") rows, err := a.database.Query("SELECT slug, url, hits FROM redirect")
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
@ -364,16 +360,19 @@ func listHandler(w http.ResponseWriter, r *http.Request) {
} }
} }
func checkPassword(w http.ResponseWriter, r *http.Request) bool { func (a *app) checkPassword(w http.ResponseWriter, r *http.Request) bool {
if r.FormValue("password") == viper.GetString("password") { // Check basic auth
if _, pass, ok := r.BasicAuth(); ok && pass == a.config.Password {
return true return true
} }
if _, pass, ok := r.BasicAuth(); !ok || pass != viper.GetString("password") { // Check query or form param
w.Header().Set("WWW-Authenticate", `Basic realm="Please enter a password!"`) if r.FormValue("password") == a.config.Password {
http.Error(w, "Not authenticated", http.StatusUnauthorized) return true
return false
} }
return true // Require password
w.Header().Set("WWW-Authenticate", `Basic realm="Please enter a password!"`)
http.Error(w, "Not authenticated", http.StatusUnauthorized)
return false
} }
func generateSlug() string { func generateSlug() string {
@ -385,34 +384,26 @@ func generateSlug() string {
return string(s) return string(s)
} }
func slugExists(slug string) (exists bool, err error) { func (a *app) shortenedURLHandler(w http.ResponseWriter, r *http.Request) {
err = appDb.QueryRow("SELECT EXISTS(SELECT 1 FROM redirect WHERE slug = ?)", slug).Scan(&exists)
return
}
func shortenedURLHandler(w http.ResponseWriter, r *http.Request) {
slug := chi.URLParam(r, "slug") slug := chi.URLParam(r, "slug")
var redirectURL, typeString string var redirectURL, typeString string
err := appDb.QueryRow("SELECT url, type FROM redirect WHERE slug = ?", slug).Scan(&redirectURL, &typeString) err := a.database.QueryRow("SELECT url, type FROM redirect WHERE slug = ?", slug).Scan(&redirectURL, &typeString)
if err != nil { if err != nil {
http.NotFound(w, r) http.NotFound(w, r)
return return
} }
go func() { a.increaseHits(slug)
dbWriteLock.Lock()
_, _ = appDb.Exec("UPDATE redirect SET hits = hits + 1 WHERE slug = ?", slug)
dbWriteLock.Unlock()
}()
if typeString == "text" { switch typeString {
case typText:
_, _ = w.Write([]byte(redirectURL)) _, _ = w.Write([]byte(redirectURL))
} else { default:
http.Redirect(w, r, redirectURL, http.StatusTemporaryRedirect) http.Redirect(w, r, redirectURL, http.StatusTemporaryRedirect)
} }
} }
func defaultURLRedirectHandler(w http.ResponseWriter, r *http.Request) { func (a *app) defaultURLRedirectHandler(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, viper.GetString("defaultUrl"), http.StatusTemporaryRedirect) http.Redirect(w, r, a.config.DefaultUrl, http.StatusTemporaryRedirect)
} }

View File

@ -1,43 +1,43 @@
package main package main
import ( import (
"database/sql" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func setupFakeDB(t *testing.T) { func testApp(t *testing.T) *app {
var err error app := &app{
appDb, err = sql.Open("sqlite3", filepath.Join(t.TempDir(), "data.db")+"?cache=shared&mode=rwc&_journal_mode=WAL&_busy_timeout=100") config: &config{
if err != nil { DBPath: filepath.Join(t.TempDir(), "data.db"),
t.Fatal(err) },
} }
migrateDatabase() err := app.openDatabase()
require.NoError(t, err)
return app
} }
func closeFakeDB(t *testing.T) { func closeTestApp(_ *testing.T, app *app) {
err := appDb.Close() app.shutdown.ShutdownAndWait()
require.NoError(t, err)
} }
func Test_slugExists(t *testing.T) { func Test_slugExists(t *testing.T) {
t.Run("Test slugs", func(t *testing.T) { t.Run("Test slugs", func(t *testing.T) {
setupFakeDB(t) app := testApp(t)
exists, err := slugExists("source") exists, err := app.slugExists("source")
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, exists) assert.True(t, exists)
exists, err = slugExists("test") exists, err = app.slugExists("test")
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, exists) assert.False(t, exists)
closeFakeDB(t) closeTestApp(t, app)
}) })
} }
@ -48,22 +48,25 @@ func Test_generateSlug(t *testing.T) {
} }
func TestShortenedUrlHandler(t *testing.T) { func TestShortenedUrlHandler(t *testing.T) {
viper.Set("defaultUrl", "http://long.example.com")
t.Run("Test ShortenedUrlHandler", func(t *testing.T) { t.Run("Test ShortenedUrlHandler", func(t *testing.T) {
setupFakeDB(t) app := testApp(t)
initRouter()
app.config.DefaultUrl = "http://long.example.com"
router := app.initRouter()
t.Run("Test redirect code", func(t *testing.T) { t.Run("Test redirect code", func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/source", nil) req := httptest.NewRequest("GET", "http://example.com/source", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
appRouter.ServeHTTP(w, req) router.ServeHTTP(w, req)
resp := w.Result() resp := w.Result()
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
}) })
t.Run("Test redirect location header", func(t *testing.T) { t.Run("Test default redirect location header", func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/source", nil) req := httptest.NewRequest("GET", "http://example.com/source", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
appRouter.ServeHTTP(w, req) router.ServeHTTP(w, req)
resp := w.Result() resp := w.Result()
assert.Equal(t, "https://git.jlel.se/jlelse/GoShort", resp.Header.Get("Location")) assert.Equal(t, "https://git.jlel.se/jlelse/GoShort", resp.Header.Get("Location"))
@ -71,7 +74,7 @@ func TestShortenedUrlHandler(t *testing.T) {
t.Run("Test missing slug redirect code", func(t *testing.T) { t.Run("Test missing slug redirect code", func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/test", nil) req := httptest.NewRequest("GET", "http://example.com/test", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
appRouter.ServeHTTP(w, req) router.ServeHTTP(w, req)
resp := w.Result() resp := w.Result()
assert.Equal(t, http.StatusNotFound, resp.StatusCode) assert.Equal(t, http.StatusNotFound, resp.StatusCode)
@ -79,43 +82,100 @@ func TestShortenedUrlHandler(t *testing.T) {
t.Run("Test no slug mux var", func(t *testing.T) { t.Run("Test no slug mux var", func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/", nil) req := httptest.NewRequest("GET", "http://example.com/", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
appRouter.ServeHTTP(w, req) router.ServeHTTP(w, req)
resp := w.Result() resp := w.Result()
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
assert.Equal(t, "http://long.example.com", resp.Header.Get("Location")) assert.Equal(t, "http://long.example.com", resp.Header.Get("Location"))
}) })
closeFakeDB(t) t.Run("Test custom url redirect", func(t *testing.T) {
err := app.insertRedirect("customurl", "https://example.net", typUrl)
require.NoError(t, err)
req := httptest.NewRequest("GET", "http://example.com/customurl", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
resp := w.Result()
assert.Equal(t, "https://example.net", resp.Header.Get("Location"))
})
t.Run("Test custom text", func(t *testing.T) {
err := app.insertRedirect("customtext", "Hello!", typText)
require.NoError(t, err)
req := httptest.NewRequest("GET", "http://example.com/customtext", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
resp := w.Result()
respBody, err := io.ReadAll(resp.Body)
require.NoError(t, err)
_ = resp.Body.Close()
assert.Equal(t, "Hello!", string(respBody))
})
closeTestApp(t, app)
}) })
} }
func Test_checkPassword(t *testing.T) { func Test_checkPassword(t *testing.T) {
viper.Set("password", "abc") app := testApp(t)
app.config.Password = "abc"
t.Run("No password", func(t *testing.T) { t.Run("No password", func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/test", nil) req := httptest.NewRequest("GET", "http://example.com/test", nil)
assert.False(t, checkPassword(httptest.NewRecorder(), req)) assert.False(t, app.checkPassword(httptest.NewRecorder(), req))
}) })
t.Run("Password via query", func(t *testing.T) { t.Run("Password via query", func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/test?password=abc", nil) req := httptest.NewRequest("GET", "http://example.com/test?password=abc", nil)
assert.True(t, checkPassword(httptest.NewRecorder(), req)) assert.True(t, app.checkPassword(httptest.NewRecorder(), req))
}) })
t.Run("Wrong password via query", func(t *testing.T) { t.Run("Wrong password via query", func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/test?password=wrong", nil) req := httptest.NewRequest("GET", "http://example.com/test?password=wrong", nil)
assert.False(t, checkPassword(httptest.NewRecorder(), req)) assert.False(t, app.checkPassword(httptest.NewRecorder(), req))
}) })
t.Run("Password via BasicAuth", func(t *testing.T) { t.Run("Password via BasicAuth", func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/test", nil) req := httptest.NewRequest("GET", "http://example.com/test", nil)
req.SetBasicAuth("username", "abc") req.SetBasicAuth("username", "abc")
assert.True(t, checkPassword(httptest.NewRecorder(), req)) assert.True(t, app.checkPassword(httptest.NewRecorder(), req))
}) })
t.Run("Wrong password via BasicAuth", func(t *testing.T) { t.Run("Wrong password via BasicAuth", func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/test", nil) req := httptest.NewRequest("GET", "http://example.com/test", nil)
req.SetBasicAuth("username", "wrong") req.SetBasicAuth("username", "wrong")
assert.False(t, checkPassword(httptest.NewRecorder(), req)) assert.False(t, app.checkPassword(httptest.NewRecorder(), req))
}) })
t.Run("Test login middleware success", func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/test", nil)
req.SetBasicAuth("username", "abc")
w := httptest.NewRecorder()
app.loginMiddleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusNotModified)
})).ServeHTTP(w, req)
resp := w.Result()
assert.Equal(t, http.StatusNotModified, resp.StatusCode)
})
t.Run("Test login middleware fail", func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/test", nil)
req.SetBasicAuth("username", "xyz")
w := httptest.NewRecorder()
app.loginMiddleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusNotModified)
})).ServeHTTP(w, req)
resp := w.Result()
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
})
closeTestApp(t, app)
} }