mirror of https://github.com/jlelse/GoBlog
Add redirects table and simple redirect functionality
This commit is contained in:
parent
8dbf795902
commit
afab7686f8
|
@ -17,6 +17,13 @@ func migrateDb() error {
|
|||
return err
|
||||
},
|
||||
},
|
||||
&migrator.Migration{
|
||||
Name: "00002",
|
||||
Func: func(tx *sql.Tx) error {
|
||||
_, err := tx.Exec("create table redirects (fromPath text not null, toPath text not null, primary key (fromPath, toPath));")
|
||||
return err
|
||||
},
|
||||
},
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
|
|
23
http.go
23
http.go
|
@ -9,6 +9,7 @@ import (
|
|||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
@ -62,7 +63,18 @@ func buildHandler() (http.Handler, error) {
|
|||
} else {
|
||||
for _, path := range allPostPaths {
|
||||
if path != "" {
|
||||
r.Get(path, servePost)
|
||||
r.With(TrimSlash).Get(path, servePost)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
allRedirectPaths, err := allRedirectPaths()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
for _, path := range allRedirectPaths {
|
||||
if path != "" {
|
||||
r.With(TrimSlash).Get(path, serveRedirect)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -90,3 +102,12 @@ func (d *dynamicHandler) swapHandler(h http.Handler) {
|
|||
func (d *dynamicHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
d.realHandler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func TrimSlash(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if len(r.RequestURI) > 1 {
|
||||
r.RequestURI = strings.TrimSuffix(r.RequestURI, "/")
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
|
7
posts.go
7
posts.go
|
@ -5,7 +5,6 @@ import (
|
|||
"database/sql"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var postNotFound = errors.New("post not found")
|
||||
|
@ -18,11 +17,7 @@ type post struct {
|
|||
}
|
||||
|
||||
func servePost(w http.ResponseWriter, r *http.Request) {
|
||||
path := r.RequestURI
|
||||
if len(path) > 1 {
|
||||
path = strings.TrimSuffix(path, "/")
|
||||
}
|
||||
post, err := getPost(path, r.Context())
|
||||
post, err := getPost(r.RequestURI, r.Context())
|
||||
if err == postNotFound {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
var redirectNotFound = errors.New("redirect not found")
|
||||
|
||||
type redirect struct {
|
||||
fromPath string
|
||||
toPath string
|
||||
}
|
||||
|
||||
func serveRedirect(w http.ResponseWriter, r *http.Request) {
|
||||
redirect, err := getRedirect(r.RequestURI, r.Context())
|
||||
if err == redirectNotFound {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
} else if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
// TODO: Change status code
|
||||
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
func getRedirect(fromPath string, context context.Context) (string, error) {
|
||||
var toPath string
|
||||
row := appDb.QueryRowContext(context, "select toPath from redirects where fromPath=?", fromPath)
|
||||
err := row.Scan(&toPath)
|
||||
if err == sql.ErrNoRows {
|
||||
return "", redirectNotFound
|
||||
} else if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return toPath, nil
|
||||
}
|
||||
|
||||
func allRedirectPaths() ([]string, error) {
|
||||
var redirectPaths []string
|
||||
rows, err := appDb.Query("select fromPath from redirects")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for rows.Next() {
|
||||
var path string
|
||||
_ = rows.Scan(&path)
|
||||
redirectPaths = append(redirectPaths, path)
|
||||
}
|
||||
return redirectPaths, nil
|
||||
}
|
Loading…
Reference in New Issue