diff --git a/redirects.go b/redirects.go index 9ea8f4f..3ad41e5 100644 --- a/redirects.go +++ b/redirects.go @@ -10,7 +10,7 @@ import ( var errRedirectNotFound = errors.New("redirect not found") func serveRedirect(w http.ResponseWriter, r *http.Request) { - redirect, more, err := getRedirect(r.Context(), slashTrimmedPath(r)) + redirect, err := getRedirect(r.Context(), slashTrimmedPath(r)) if err == errRedirectNotFound { serve404(w, r) return @@ -18,12 +18,6 @@ func serveRedirect(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusInternalServerError) return } - // Flatten redirects - if more { - for more == true { - redirect, more, _ = getRedirect(r.Context(), trimSlash(redirect)) - } - } // Send redirect w.Header().Set("Location", redirect) render(w, templateRedirect, struct { @@ -34,17 +28,16 @@ func serveRedirect(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusFound) } -func getRedirect(context context.Context, fromPath string) (string, bool, error) { +func getRedirect(context context.Context, fromPath string) (string, error) { var toPath string - var moreRedirects int - row := appDb.QueryRowContext(context, "select toPath, (select count(*) from redirects where fromPath=(select toPath from redirects where fromPath=?)) as more from redirects where fromPath=?", fromPath, fromPath) - err := row.Scan(&toPath, &moreRedirects) + row := appDb.QueryRowContext(context, "with recursive f (i, fp, tp) as (select 1, fromPath, toPath from redirects where fromPath = ? union all select f.i + 1, r.fromPath, r.toPath from redirects as r join f on f.tp = r.fromPath) select tp from f order by i desc limit 1", fromPath) + err := row.Scan(&toPath) if err == sql.ErrNoRows { - return "", false, errRedirectNotFound + return "", errRedirectNotFound } else if err != nil { - return "", false, err + return "", err } - return toPath, moreRedirects > 0, nil + return toPath, nil } func allRedirectPaths() ([]string, error) {