Improvements (refactoring, test coverage)
continuous-integration/drone/push Build is passing
Details
continuous-integration/drone/push Build is passing
Details
This commit is contained in:
parent
2389deb85f
commit
5f62857286
60
database.go
60
database.go
|
@ -2,14 +2,34 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"errors"
|
||||||
"log"
|
"log"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
"github.com/lopezator/migrator"
|
"github.com/lopezator/migrator"
|
||||||
)
|
)
|
||||||
|
|
||||||
func migrateDatabase() {
|
func (a *app) openDatabase() (err error) {
|
||||||
dbWriteLock.Lock()
|
if a.config.DBPath == "" {
|
||||||
defer dbWriteLock.Unlock()
|
return errors.New("empty database path")
|
||||||
|
}
|
||||||
|
_ = os.MkdirAll(filepath.Dir(a.config.DBPath), 0644)
|
||||||
|
a.database, err = sql.Open("sqlite3", a.config.DBPath+"?cache=shared&mode=rwc&_journal_mode=WAL&_busy_timeout=100")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
a.shutdown.Add(func() {
|
||||||
|
_ = a.database.Close()
|
||||||
|
log.Println("Closed database")
|
||||||
|
})
|
||||||
|
a.migrateDatabase()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *app) migrateDatabase() {
|
||||||
|
a.write.Lock()
|
||||||
|
defer a.write.Unlock()
|
||||||
m, err := migrator.New(
|
m, err := migrator.New(
|
||||||
migrator.Migrations(
|
migrator.Migrations(
|
||||||
&migrator.Migration{
|
&migrator.Migration{
|
||||||
|
@ -29,8 +49,40 @@ func migrateDatabase() {
|
||||||
log.Fatal(err.Error())
|
log.Fatal(err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := m.Migrate(appDb); err != nil {
|
if err := m.Migrate(a.database); err != nil {
|
||||||
log.Fatal(err.Error())
|
log.Fatal(err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
typUrl = "url"
|
||||||
|
typText = "text"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (a *app) insertRedirect(slug string, url string, typ string) error {
|
||||||
|
a.write.Lock()
|
||||||
|
defer a.write.Unlock()
|
||||||
|
_, err := a.database.Exec("INSERT INTO redirect (slug, url, type) VALUES (?, ?, ?)", slug, url, typ)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *app) deleteSlug(slug string) error {
|
||||||
|
a.write.Lock()
|
||||||
|
defer a.write.Unlock()
|
||||||
|
_, err := a.database.Exec("DELETE FROM redirect WHERE slug = ?", slug)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *app) increaseHits(slug string) {
|
||||||
|
go func() {
|
||||||
|
a.write.Lock()
|
||||||
|
defer a.write.Unlock()
|
||||||
|
_, _ = a.database.Exec("UPDATE redirect SET hits = hits + 1 WHERE slug = ?", slug)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *app) slugExists(slug string) (exists bool, err error) {
|
||||||
|
err = a.database.QueryRow("SELECT EXISTS(SELECT 1 FROM redirect WHERE slug = ?)", slug).Scan(&exists)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
187
main.go
187
main.go
|
@ -8,43 +8,51 @@ import (
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
goshutdowner "git.jlel.se/jlelse/go-shutdowner"
|
gsd "git.jlel.se/jlelse/go-shutdowner"
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
"github.com/go-chi/chi/v5/middleware"
|
"github.com/go-chi/chi/v5/middleware"
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
type app struct {
|
||||||
appDb *sql.DB
|
config *config
|
||||||
appRouter *chi.Mux
|
database *sql.DB
|
||||||
dbWriteLock sync.Mutex
|
write sync.Mutex
|
||||||
shutdown goshutdowner.Shutdowner
|
shutdown gsd.Shutdowner
|
||||||
)
|
}
|
||||||
|
|
||||||
func initRouter() {
|
type config struct {
|
||||||
appRouter = chi.NewRouter()
|
Port int `mapstructure:"port"`
|
||||||
appRouter.Use(middleware.GetHead)
|
DBPath string `mapstructure:"dbPath"`
|
||||||
appRouter.Group(func(r chi.Router) {
|
Password string `mapstructure:"password"`
|
||||||
r.Use(loginMiddleware)
|
ShortUrl string `mapstructure:"shortUrl"`
|
||||||
|
DefaultUrl string `mapstructure:"defaultUrl"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *app) initRouter() (router *chi.Mux) {
|
||||||
|
router = chi.NewMux()
|
||||||
|
router.Use(middleware.GetHead)
|
||||||
|
router.Group(func(r chi.Router) {
|
||||||
|
r.Use(a.loginMiddleware)
|
||||||
r.Get("/s", shortenFormHandler)
|
r.Get("/s", shortenFormHandler)
|
||||||
r.Post("/s", shortenHandler)
|
r.Post("/s", a.shortenHandler)
|
||||||
r.Get("/t", shortenTextFormHandler)
|
r.Get("/t", shortenTextFormHandler)
|
||||||
r.Post("/t", shortenTextHandler)
|
r.Post("/t", a.shortenTextHandler)
|
||||||
r.Get("/u", updateFormHandler)
|
r.Get("/u", updateFormHandler)
|
||||||
r.Get("/ut", updateTextFormHandler)
|
r.Get("/ut", updateTextFormHandler)
|
||||||
r.Post("/u", updateHandler)
|
r.Post("/u", a.updateHandler)
|
||||||
r.Get("/d", deleteFormHandler)
|
r.Get("/d", deleteFormHandler)
|
||||||
r.Post("/d", deleteHandler)
|
r.Post("/d", a.deleteHandler)
|
||||||
r.Get("/l", listHandler)
|
r.Get("/l", a.listHandler)
|
||||||
})
|
})
|
||||||
appRouter.Get("/{slug}", shortenedURLHandler)
|
router.Get("/{slug}", a.shortenedURLHandler)
|
||||||
appRouter.Get("/", defaultURLRedirectHandler)
|
router.Get("/", a.defaultURLRedirectHandler)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
@ -58,46 +66,43 @@ func main() {
|
||||||
viper.AddConfigPath(".")
|
viper.AddConfigPath(".")
|
||||||
_ = viper.ReadInConfig()
|
_ = viper.ReadInConfig()
|
||||||
|
|
||||||
if !viper.IsSet("dbPath") {
|
|
||||||
log.Fatal("No database path (dbPath) is configured.")
|
|
||||||
}
|
|
||||||
if !viper.IsSet("password") {
|
if !viper.IsSet("password") {
|
||||||
log.Fatal("No password (password) is configured.")
|
log.Fatal("No password (password) is configured.")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
if !viper.IsSet("shortUrl") {
|
if !viper.IsSet("shortUrl") {
|
||||||
log.Fatal("No short URL (shortUrl) is configured.")
|
log.Fatal("No short URL (shortUrl) is configured.")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
if !viper.IsSet("defaultUrl") {
|
if !viper.IsSet("defaultUrl") {
|
||||||
log.Fatal("No default URL (defaultUrl) is configured.")
|
log.Fatal("No default URL (defaultUrl) is configured.")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
app := &app{}
|
||||||
dbPath := viper.GetString("dbPath")
|
|
||||||
_ = os.MkdirAll(filepath.Dir(dbPath), 0644)
|
app.config = &config{}
|
||||||
appDb, err = sql.Open("sqlite3", dbPath+"?cache=shared&mode=rwc&_journal_mode=WAL&_busy_timeout=100")
|
err := viper.Unmarshal(app.config)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("Failed to unmarshal config:", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = app.openDatabase()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("Error opening database:", err.Error())
|
log.Println("Error opening database:", err.Error())
|
||||||
|
app.shutdown.ShutdownAndWait()
|
||||||
|
os.Exit(1)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
shutdown.Add(func() {
|
|
||||||
_ = appDb.Close()
|
|
||||||
log.Println("Closed database")
|
|
||||||
})
|
|
||||||
|
|
||||||
migrateDatabase()
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
_ = appDb.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
initRouter()
|
|
||||||
|
|
||||||
httpServer := &http.Server{
|
httpServer := &http.Server{
|
||||||
Addr: ":" + strconv.Itoa(viper.GetInt("port")),
|
Addr: ":" + strconv.Itoa(app.config.Port),
|
||||||
Handler: appRouter,
|
Handler: app.initRouter(),
|
||||||
ReadTimeout: 5 * time.Minute,
|
ReadTimeout: 5 * time.Minute,
|
||||||
WriteTimeout: 5 * time.Minute,
|
WriteTimeout: 5 * time.Minute,
|
||||||
}
|
}
|
||||||
shutdown.Add(func() {
|
app.shutdown.Add(func() {
|
||||||
toc, c := context.WithTimeout(context.Background(), 5*time.Second)
|
toc, c := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer c()
|
defer c()
|
||||||
if err := httpServer.Shutdown(toc); err != nil {
|
if err := httpServer.Shutdown(toc); err != nil {
|
||||||
|
@ -112,12 +117,12 @@ func main() {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
shutdown.Wait()
|
app.shutdown.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func loginMiddleware(next http.Handler) http.Handler {
|
func (a *app) loginMiddleware(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if !checkPassword(w, r) {
|
if !a.checkPassword(w, r) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
|
@ -171,9 +176,9 @@ func generateTextForm(w http.ResponseWriter, title string, url string, fields []
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func shortenHandler(w http.ResponseWriter, r *http.Request) {
|
func (a *app) shortenHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
writeShortenedURL := func(w http.ResponseWriter, slug string) {
|
writeShortenedURL := func(w http.ResponseWriter, slug string) {
|
||||||
_, _ = w.Write([]byte(viper.GetString("shortUrl") + "/" + slug))
|
_, _ = w.Write([]byte(a.config.ShortUrl + "/" + slug))
|
||||||
}
|
}
|
||||||
|
|
||||||
requestURL := r.FormValue("url")
|
requestURL := r.FormValue("url")
|
||||||
|
@ -185,13 +190,13 @@ func shortenHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
slug := r.FormValue("slug")
|
slug := r.FormValue("slug")
|
||||||
manualSlug := false
|
manualSlug := false
|
||||||
if slug == "" {
|
if slug == "" {
|
||||||
_ = appDb.QueryRow("SELECT slug FROM redirect WHERE url = ?", requestURL).Scan(&slug)
|
_ = a.database.QueryRow("SELECT slug FROM redirect WHERE url = ?", requestURL).Scan(&slug)
|
||||||
} else {
|
} else {
|
||||||
manualSlug = true
|
manualSlug = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if slug != "" {
|
if slug != "" {
|
||||||
if e, _ := slugExists(slug); e {
|
if e, _ := a.slugExists(slug); e {
|
||||||
if manualSlug {
|
if manualSlug {
|
||||||
http.Error(w, "slug already in use", http.StatusBadRequest)
|
http.Error(w, "slug already in use", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
|
@ -204,7 +209,7 @@ func shortenHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
var err error
|
var err error
|
||||||
for exists {
|
for exists {
|
||||||
slug = generateSlug()
|
slug = generateSlug()
|
||||||
exists, err = slugExists(slug)
|
exists, err = a.slugExists(slug)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
|
@ -212,21 +217,18 @@ func shortenHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
dbWriteLock.Lock()
|
if err := a.insertRedirect(slug, requestURL, typUrl); err != nil {
|
||||||
if _, err := appDb.Exec("INSERT INTO redirect (slug, url) VALUES (?, ?)", slug, requestURL); err != nil {
|
|
||||||
dbWriteLock.Unlock()
|
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
dbWriteLock.Unlock()
|
|
||||||
|
|
||||||
w.WriteHeader(http.StatusCreated)
|
w.WriteHeader(http.StatusCreated)
|
||||||
writeShortenedURL(w, slug)
|
writeShortenedURL(w, slug)
|
||||||
}
|
}
|
||||||
|
|
||||||
func shortenTextHandler(w http.ResponseWriter, r *http.Request) {
|
func (a *app) shortenTextHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
writeShortenedURL := func(w http.ResponseWriter, slug string) {
|
writeShortenedURL := func(w http.ResponseWriter, slug string) {
|
||||||
_, _ = w.Write([]byte(viper.GetString("shortUrl") + "/" + slug))
|
_, _ = w.Write([]byte(a.config.ShortUrl + "/" + slug))
|
||||||
}
|
}
|
||||||
|
|
||||||
requestText := r.FormValue("text")
|
requestText := r.FormValue("text")
|
||||||
|
@ -238,13 +240,13 @@ func shortenTextHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
slug := r.FormValue("slug")
|
slug := r.FormValue("slug")
|
||||||
manualSlug := false
|
manualSlug := false
|
||||||
if slug == "" {
|
if slug == "" {
|
||||||
_ = appDb.QueryRow("SELECT slug FROM redirect WHERE url = ? and type = 'text'", requestText).Scan(&slug)
|
_ = a.database.QueryRow("SELECT slug FROM redirect WHERE url = ? and type = 'text'", requestText).Scan(&slug)
|
||||||
} else {
|
} else {
|
||||||
manualSlug = true
|
manualSlug = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if slug != "" {
|
if slug != "" {
|
||||||
if e, _ := slugExists(slug); e {
|
if e, _ := a.slugExists(slug); e {
|
||||||
if manualSlug {
|
if manualSlug {
|
||||||
http.Error(w, "slug already in use", http.StatusBadRequest)
|
http.Error(w, "slug already in use", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
|
@ -257,7 +259,7 @@ func shortenTextHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
var err error
|
var err error
|
||||||
for exists {
|
for exists {
|
||||||
slug = generateSlug()
|
slug = generateSlug()
|
||||||
exists, err = slugExists(slug)
|
exists, err = a.slugExists(slug)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
|
@ -265,19 +267,16 @@ func shortenTextHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
dbWriteLock.Lock()
|
if err := a.insertRedirect(slug, requestText, typText); err != nil {
|
||||||
if _, err := appDb.Exec("INSERT INTO redirect (slug, url, type) VALUES (?, ?, 'text')", slug, requestText); err != nil {
|
|
||||||
dbWriteLock.Unlock()
|
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
dbWriteLock.Unlock()
|
|
||||||
|
|
||||||
w.WriteHeader(http.StatusCreated)
|
w.WriteHeader(http.StatusCreated)
|
||||||
writeShortenedURL(w, slug)
|
writeShortenedURL(w, slug)
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateHandler(w http.ResponseWriter, r *http.Request) {
|
func (a *app) updateHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
slug := r.FormValue("slug")
|
slug := r.FormValue("slug")
|
||||||
if slug == "" {
|
if slug == "" {
|
||||||
http.Error(w, "Specify the slug to update", http.StatusBadRequest)
|
http.Error(w, "Specify the slug to update", http.StatusBadRequest)
|
||||||
|
@ -295,55 +294,52 @@ func updateHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
typeString = "url"
|
typeString = "url"
|
||||||
}
|
}
|
||||||
|
|
||||||
if e, err := slugExists(slug); err != nil || !e {
|
if e, err := a.slugExists(slug); err != nil || !e {
|
||||||
http.NotFound(w, r)
|
http.NotFound(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dbWriteLock.Lock()
|
a.write.Lock()
|
||||||
if _, err := appDb.Exec("UPDATE redirect SET url = ?, type = ? WHERE slug = ?", newURL, typeString, slug); err != nil {
|
if _, err := a.database.Exec("UPDATE redirect SET url = ?, type = ? WHERE slug = ?", newURL, typeString, slug); err != nil {
|
||||||
dbWriteLock.Unlock()
|
a.write.Unlock()
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
dbWriteLock.Unlock()
|
a.write.Unlock()
|
||||||
|
|
||||||
w.WriteHeader(http.StatusAccepted)
|
w.WriteHeader(http.StatusAccepted)
|
||||||
_, _ = w.Write([]byte("Slug updated"))
|
_, _ = w.Write([]byte("Slug updated"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func deleteHandler(w http.ResponseWriter, r *http.Request) {
|
func (a *app) deleteHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
slug := r.FormValue("slug")
|
slug := r.FormValue("slug")
|
||||||
if slug == "" {
|
if slug == "" {
|
||||||
http.Error(w, "Specify the slug to delete", http.StatusBadRequest)
|
http.Error(w, "Specify the slug to delete", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if e, err := slugExists(slug); !e || err != nil {
|
if e, err := a.slugExists(slug); !e || err != nil {
|
||||||
http.NotFound(w, r)
|
http.NotFound(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dbWriteLock.Lock()
|
if err := a.deleteSlug(slug); err != nil {
|
||||||
if _, err := appDb.Exec("DELETE FROM redirect WHERE slug = ?", slug); err != nil {
|
|
||||||
dbWriteLock.Unlock()
|
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
dbWriteLock.Unlock()
|
|
||||||
|
|
||||||
w.WriteHeader(http.StatusAccepted)
|
w.WriteHeader(http.StatusAccepted)
|
||||||
_, _ = w.Write([]byte("Slug deleted"))
|
_, _ = w.Write([]byte("Slug deleted"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func listHandler(w http.ResponseWriter, r *http.Request) {
|
func (a *app) listHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
type row struct {
|
type row struct {
|
||||||
Slug string
|
Slug string
|
||||||
URL string
|
URL string
|
||||||
Hits int
|
Hits int
|
||||||
}
|
}
|
||||||
var list []row
|
var list []row
|
||||||
rows, err := appDb.Query("SELECT slug, url, hits FROM redirect")
|
rows, err := a.database.Query("SELECT slug, url, hits FROM redirect")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
|
@ -364,16 +360,19 @@ func listHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkPassword(w http.ResponseWriter, r *http.Request) bool {
|
func (a *app) checkPassword(w http.ResponseWriter, r *http.Request) bool {
|
||||||
if r.FormValue("password") == viper.GetString("password") {
|
// Check basic auth
|
||||||
|
if _, pass, ok := r.BasicAuth(); ok && pass == a.config.Password {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if _, pass, ok := r.BasicAuth(); !ok || pass != viper.GetString("password") {
|
// Check query or form param
|
||||||
|
if r.FormValue("password") == a.config.Password {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// Require password
|
||||||
w.Header().Set("WWW-Authenticate", `Basic realm="Please enter a password!"`)
|
w.Header().Set("WWW-Authenticate", `Basic realm="Please enter a password!"`)
|
||||||
http.Error(w, "Not authenticated", http.StatusUnauthorized)
|
http.Error(w, "Not authenticated", http.StatusUnauthorized)
|
||||||
return false
|
return false
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateSlug() string {
|
func generateSlug() string {
|
||||||
|
@ -385,34 +384,26 @@ func generateSlug() string {
|
||||||
return string(s)
|
return string(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
func slugExists(slug string) (exists bool, err error) {
|
func (a *app) shortenedURLHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
err = appDb.QueryRow("SELECT EXISTS(SELECT 1 FROM redirect WHERE slug = ?)", slug).Scan(&exists)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func shortenedURLHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
slug := chi.URLParam(r, "slug")
|
slug := chi.URLParam(r, "slug")
|
||||||
|
|
||||||
var redirectURL, typeString string
|
var redirectURL, typeString string
|
||||||
err := appDb.QueryRow("SELECT url, type FROM redirect WHERE slug = ?", slug).Scan(&redirectURL, &typeString)
|
err := a.database.QueryRow("SELECT url, type FROM redirect WHERE slug = ?", slug).Scan(&redirectURL, &typeString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.NotFound(w, r)
|
http.NotFound(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
a.increaseHits(slug)
|
||||||
dbWriteLock.Lock()
|
|
||||||
_, _ = appDb.Exec("UPDATE redirect SET hits = hits + 1 WHERE slug = ?", slug)
|
|
||||||
dbWriteLock.Unlock()
|
|
||||||
}()
|
|
||||||
|
|
||||||
if typeString == "text" {
|
switch typeString {
|
||||||
|
case typText:
|
||||||
_, _ = w.Write([]byte(redirectURL))
|
_, _ = w.Write([]byte(redirectURL))
|
||||||
} else {
|
default:
|
||||||
http.Redirect(w, r, redirectURL, http.StatusTemporaryRedirect)
|
http.Redirect(w, r, redirectURL, http.StatusTemporaryRedirect)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func defaultURLRedirectHandler(w http.ResponseWriter, r *http.Request) {
|
func (a *app) defaultURLRedirectHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
http.Redirect(w, r, viper.GetString("defaultUrl"), http.StatusTemporaryRedirect)
|
http.Redirect(w, r, a.config.DefaultUrl, http.StatusTemporaryRedirect)
|
||||||
}
|
}
|
||||||
|
|
120
main_test.go
120
main_test.go
|
@ -1,43 +1,43 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/spf13/viper"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func setupFakeDB(t *testing.T) {
|
func testApp(t *testing.T) *app {
|
||||||
var err error
|
app := &app{
|
||||||
appDb, err = sql.Open("sqlite3", filepath.Join(t.TempDir(), "data.db")+"?cache=shared&mode=rwc&_journal_mode=WAL&_busy_timeout=100")
|
config: &config{
|
||||||
if err != nil {
|
DBPath: filepath.Join(t.TempDir(), "data.db"),
|
||||||
t.Fatal(err)
|
},
|
||||||
}
|
}
|
||||||
migrateDatabase()
|
err := app.openDatabase()
|
||||||
|
require.NoError(t, err)
|
||||||
|
return app
|
||||||
}
|
}
|
||||||
|
|
||||||
func closeFakeDB(t *testing.T) {
|
func closeTestApp(_ *testing.T, app *app) {
|
||||||
err := appDb.Close()
|
app.shutdown.ShutdownAndWait()
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_slugExists(t *testing.T) {
|
func Test_slugExists(t *testing.T) {
|
||||||
t.Run("Test slugs", func(t *testing.T) {
|
t.Run("Test slugs", func(t *testing.T) {
|
||||||
setupFakeDB(t)
|
app := testApp(t)
|
||||||
|
|
||||||
exists, err := slugExists("source")
|
exists, err := app.slugExists("source")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.True(t, exists)
|
assert.True(t, exists)
|
||||||
exists, err = slugExists("test")
|
exists, err = app.slugExists("test")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.False(t, exists)
|
assert.False(t, exists)
|
||||||
|
|
||||||
closeFakeDB(t)
|
closeTestApp(t, app)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -48,22 +48,25 @@ func Test_generateSlug(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestShortenedUrlHandler(t *testing.T) {
|
func TestShortenedUrlHandler(t *testing.T) {
|
||||||
viper.Set("defaultUrl", "http://long.example.com")
|
|
||||||
t.Run("Test ShortenedUrlHandler", func(t *testing.T) {
|
t.Run("Test ShortenedUrlHandler", func(t *testing.T) {
|
||||||
setupFakeDB(t)
|
app := testApp(t)
|
||||||
initRouter()
|
|
||||||
|
app.config.DefaultUrl = "http://long.example.com"
|
||||||
|
|
||||||
|
router := app.initRouter()
|
||||||
|
|
||||||
t.Run("Test redirect code", func(t *testing.T) {
|
t.Run("Test redirect code", func(t *testing.T) {
|
||||||
req := httptest.NewRequest("GET", "http://example.com/source", nil)
|
req := httptest.NewRequest("GET", "http://example.com/source", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
appRouter.ServeHTTP(w, req)
|
router.ServeHTTP(w, req)
|
||||||
resp := w.Result()
|
resp := w.Result()
|
||||||
|
|
||||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||||
})
|
})
|
||||||
t.Run("Test redirect location header", func(t *testing.T) {
|
t.Run("Test default redirect location header", func(t *testing.T) {
|
||||||
req := httptest.NewRequest("GET", "http://example.com/source", nil)
|
req := httptest.NewRequest("GET", "http://example.com/source", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
appRouter.ServeHTTP(w, req)
|
router.ServeHTTP(w, req)
|
||||||
resp := w.Result()
|
resp := w.Result()
|
||||||
|
|
||||||
assert.Equal(t, "https://git.jlel.se/jlelse/GoShort", resp.Header.Get("Location"))
|
assert.Equal(t, "https://git.jlel.se/jlelse/GoShort", resp.Header.Get("Location"))
|
||||||
|
@ -71,7 +74,7 @@ func TestShortenedUrlHandler(t *testing.T) {
|
||||||
t.Run("Test missing slug redirect code", func(t *testing.T) {
|
t.Run("Test missing slug redirect code", func(t *testing.T) {
|
||||||
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
appRouter.ServeHTTP(w, req)
|
router.ServeHTTP(w, req)
|
||||||
resp := w.Result()
|
resp := w.Result()
|
||||||
|
|
||||||
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
||||||
|
@ -79,43 +82,100 @@ func TestShortenedUrlHandler(t *testing.T) {
|
||||||
t.Run("Test no slug mux var", func(t *testing.T) {
|
t.Run("Test no slug mux var", func(t *testing.T) {
|
||||||
req := httptest.NewRequest("GET", "http://example.com/", nil)
|
req := httptest.NewRequest("GET", "http://example.com/", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
appRouter.ServeHTTP(w, req)
|
router.ServeHTTP(w, req)
|
||||||
resp := w.Result()
|
resp := w.Result()
|
||||||
|
|
||||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||||
assert.Equal(t, "http://long.example.com", resp.Header.Get("Location"))
|
assert.Equal(t, "http://long.example.com", resp.Header.Get("Location"))
|
||||||
})
|
})
|
||||||
closeFakeDB(t)
|
t.Run("Test custom url redirect", func(t *testing.T) {
|
||||||
|
err := app.insertRedirect("customurl", "https://example.net", typUrl)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "http://example.com/customurl", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
resp := w.Result()
|
||||||
|
|
||||||
|
assert.Equal(t, "https://example.net", resp.Header.Get("Location"))
|
||||||
|
})
|
||||||
|
t.Run("Test custom text", func(t *testing.T) {
|
||||||
|
err := app.insertRedirect("customtext", "Hello!", typText)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "http://example.com/customtext", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
resp := w.Result()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
|
assert.Equal(t, "Hello!", string(respBody))
|
||||||
|
})
|
||||||
|
|
||||||
|
closeTestApp(t, app)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_checkPassword(t *testing.T) {
|
func Test_checkPassword(t *testing.T) {
|
||||||
viper.Set("password", "abc")
|
app := testApp(t)
|
||||||
|
|
||||||
|
app.config.Password = "abc"
|
||||||
|
|
||||||
t.Run("No password", func(t *testing.T) {
|
t.Run("No password", func(t *testing.T) {
|
||||||
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||||
|
|
||||||
assert.False(t, checkPassword(httptest.NewRecorder(), req))
|
assert.False(t, app.checkPassword(httptest.NewRecorder(), req))
|
||||||
})
|
})
|
||||||
t.Run("Password via query", func(t *testing.T) {
|
t.Run("Password via query", func(t *testing.T) {
|
||||||
req := httptest.NewRequest("GET", "http://example.com/test?password=abc", nil)
|
req := httptest.NewRequest("GET", "http://example.com/test?password=abc", nil)
|
||||||
|
|
||||||
assert.True(t, checkPassword(httptest.NewRecorder(), req))
|
assert.True(t, app.checkPassword(httptest.NewRecorder(), req))
|
||||||
})
|
})
|
||||||
t.Run("Wrong password via query", func(t *testing.T) {
|
t.Run("Wrong password via query", func(t *testing.T) {
|
||||||
req := httptest.NewRequest("GET", "http://example.com/test?password=wrong", nil)
|
req := httptest.NewRequest("GET", "http://example.com/test?password=wrong", nil)
|
||||||
|
|
||||||
assert.False(t, checkPassword(httptest.NewRecorder(), req))
|
assert.False(t, app.checkPassword(httptest.NewRecorder(), req))
|
||||||
})
|
})
|
||||||
t.Run("Password via BasicAuth", func(t *testing.T) {
|
t.Run("Password via BasicAuth", func(t *testing.T) {
|
||||||
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||||
req.SetBasicAuth("username", "abc")
|
req.SetBasicAuth("username", "abc")
|
||||||
|
|
||||||
assert.True(t, checkPassword(httptest.NewRecorder(), req))
|
assert.True(t, app.checkPassword(httptest.NewRecorder(), req))
|
||||||
})
|
})
|
||||||
t.Run("Wrong password via BasicAuth", func(t *testing.T) {
|
t.Run("Wrong password via BasicAuth", func(t *testing.T) {
|
||||||
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||||
req.SetBasicAuth("username", "wrong")
|
req.SetBasicAuth("username", "wrong")
|
||||||
|
|
||||||
assert.False(t, checkPassword(httptest.NewRecorder(), req))
|
assert.False(t, app.checkPassword(httptest.NewRecorder(), req))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("Test login middleware success", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||||
|
req.SetBasicAuth("username", "abc")
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
app.loginMiddleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||||
|
rw.WriteHeader(http.StatusNotModified)
|
||||||
|
})).ServeHTTP(w, req)
|
||||||
|
resp := w.Result()
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusNotModified, resp.StatusCode)
|
||||||
|
})
|
||||||
|
t.Run("Test login middleware fail", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||||
|
req.SetBasicAuth("username", "xyz")
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
app.loginMiddleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||||
|
rw.WriteHeader(http.StatusNotModified)
|
||||||
|
})).ServeHTTP(w, req)
|
||||||
|
resp := w.Result()
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||||
|
})
|
||||||
|
|
||||||
|
closeTestApp(t, app)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue