diff --git a/database_migrations.go b/database_migrations.go index ab88afb..3b70ec2 100644 --- a/database_migrations.go +++ b/database_migrations.go @@ -17,6 +17,13 @@ func migrateDb() error { return err }, }, + &migrator.Migration{ + Name: "00002", + Func: func(tx *sql.Tx) error { + _, err := tx.Exec("create table redirects (fromPath text not null, toPath text not null, primary key (fromPath, toPath));") + return err + }, + }, ), ) if err != nil { diff --git a/http.go b/http.go index 35202d0..cc61b9f 100644 --- a/http.go +++ b/http.go @@ -9,6 +9,7 @@ import ( "os" "os/signal" "strconv" + "strings" "sync" "syscall" "time" @@ -62,7 +63,18 @@ func buildHandler() (http.Handler, error) { } else { for _, path := range allPostPaths { if path != "" { - r.Get(path, servePost) + r.With(TrimSlash).Get(path, servePost) + } + } + } + + allRedirectPaths, err := allRedirectPaths() + if err != nil { + return nil, err + } else { + for _, path := range allRedirectPaths { + if path != "" { + r.With(TrimSlash).Get(path, serveRedirect) } } } @@ -90,3 +102,12 @@ func (d *dynamicHandler) swapHandler(h http.Handler) { func (d *dynamicHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { d.realHandler.ServeHTTP(w, r) } + +func TrimSlash(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if len(r.RequestURI) > 1 { + r.RequestURI = strings.TrimSuffix(r.RequestURI, "/") + } + next.ServeHTTP(w, r) + }) +} diff --git a/posts.go b/posts.go index 20cdbad..1a8ae84 100644 --- a/posts.go +++ b/posts.go @@ -5,7 +5,6 @@ import ( "database/sql" "errors" "net/http" - "strings" ) var postNotFound = errors.New("post not found") @@ -18,11 +17,7 @@ type post struct { } func servePost(w http.ResponseWriter, r *http.Request) { - path := r.RequestURI - if len(path) > 1 { - path = strings.TrimSuffix(path, "/") - } - post, err := getPost(path, r.Context()) + post, err := getPost(r.RequestURI, r.Context()) if err == postNotFound { http.NotFound(w, r) return diff --git a/redirects.go b/redirects.go new file mode 100644 index 0000000..266e524 --- /dev/null +++ b/redirects.go @@ -0,0 +1,54 @@ +package main + +import ( + "context" + "database/sql" + "errors" + "net/http" +) + +var redirectNotFound = errors.New("redirect not found") + +type redirect struct { + fromPath string + toPath string +} + +func serveRedirect(w http.ResponseWriter, r *http.Request) { + redirect, err := getRedirect(r.RequestURI, r.Context()) + if err == redirectNotFound { + http.NotFound(w, r) + return + } else if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + // TODO: Change status code + http.Redirect(w, r, redirect, http.StatusTemporaryRedirect) +} + +func getRedirect(fromPath string, context context.Context) (string, error) { + var toPath string + row := appDb.QueryRowContext(context, "select toPath from redirects where fromPath=?", fromPath) + err := row.Scan(&toPath) + if err == sql.ErrNoRows { + return "", redirectNotFound + } else if err != nil { + return "", err + } + return toPath, nil +} + +func allRedirectPaths() ([]string, error) { + var redirectPaths []string + rows, err := appDb.Query("select fromPath from redirects") + if err != nil { + return nil, err + } + for rows.Next() { + var path string + _ = rows.Scan(&path) + redirectPaths = append(redirectPaths, path) + } + return redirectPaths, nil +}