Flatten redirect chains using recursive sql query

This commit is contained in:
Jan-Lukas Else 2020-08-01 11:39:24 +02:00
parent 484ca8948a
commit b2072fddd2
1 changed files with 7 additions and 14 deletions

View File

@ -10,7 +10,7 @@ import (
var errRedirectNotFound = errors.New("redirect not found") var errRedirectNotFound = errors.New("redirect not found")
func serveRedirect(w http.ResponseWriter, r *http.Request) { func serveRedirect(w http.ResponseWriter, r *http.Request) {
redirect, more, err := getRedirect(r.Context(), slashTrimmedPath(r)) redirect, err := getRedirect(r.Context(), slashTrimmedPath(r))
if err == errRedirectNotFound { if err == errRedirectNotFound {
serve404(w, r) serve404(w, r)
return return
@ -18,12 +18,6 @@ func serveRedirect(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
} }
// Flatten redirects
if more {
for more == true {
redirect, more, _ = getRedirect(r.Context(), trimSlash(redirect))
}
}
// Send redirect // Send redirect
w.Header().Set("Location", redirect) w.Header().Set("Location", redirect)
render(w, templateRedirect, struct { render(w, templateRedirect, struct {
@ -34,17 +28,16 @@ func serveRedirect(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusFound) w.WriteHeader(http.StatusFound)
} }
func getRedirect(context context.Context, fromPath string) (string, bool, error) { func getRedirect(context context.Context, fromPath string) (string, error) {
var toPath string var toPath string
var moreRedirects int row := appDb.QueryRowContext(context, "with recursive f (i, fp, tp) as (select 1, fromPath, toPath from redirects where fromPath = ? union all select f.i + 1, r.fromPath, r.toPath from redirects as r join f on f.tp = r.fromPath) select tp from f order by i desc limit 1", fromPath)
row := appDb.QueryRowContext(context, "select toPath, (select count(*) from redirects where fromPath=(select toPath from redirects where fromPath=?)) as more from redirects where fromPath=?", fromPath, fromPath) err := row.Scan(&toPath)
err := row.Scan(&toPath, &moreRedirects)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return "", false, errRedirectNotFound return "", errRedirectNotFound
} else if err != nil { } else if err != nil {
return "", false, err return "", err
} }
return toPath, moreRedirects > 0, nil return toPath, nil
} }
func allRedirectPaths() ([]string, error) { func allRedirectPaths() ([]string, error) {