diff --git a/cache.go b/cache.go index b66bef7..13ff59c 100644 --- a/cache.go +++ b/cache.go @@ -9,14 +9,14 @@ import ( "time" ) -func CacheMiddleware(next http.Handler) http.Handler { +func cacheMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { requestUrl, _ := url.ParseRequestURI(r.RequestURI) - path := SlashTrimmedPath(r) + path := slashTrimmedPath(r) if appConfig.cache.enable && // check bypass query !(requestUrl != nil && requestUrl.Query().Get("cache") == "0") { - cacheTime, header, body := getCache(path, r.Context()) + cacheTime, header, body := getCache(r.Context(), path) if cacheTime == 0 { recorder := httptest.NewRecorder() next.ServeHTTP(recorder, r) @@ -36,31 +36,29 @@ func CacheMiddleware(next http.Handler) http.Handler { saveCache(path, now, recorder.Header(), recorder.Body.Bytes()) } return - } else { - cacheTimeString := time.Unix(cacheTime, 0).Format(time.RFC1123) - expiresTimeString := time.Unix(cacheTime+appConfig.cache.expiration, 0).Format(time.RFC1123) - // check conditional request - ifModifiedSinceHeader := r.Header.Get("If-Modified-Since") - if ifModifiedSinceHeader != "" && ifModifiedSinceHeader == cacheTimeString { - setCacheHeaders(w, cacheTimeString, expiresTimeString) - // send 304 - w.WriteHeader(http.StatusNotModified) - return - } - // copy cached headers - for k, v := range header { - w.Header()[k] = v - } + } + cacheTimeString := time.Unix(cacheTime, 0).Format(time.RFC1123) + expiresTimeString := time.Unix(cacheTime+appConfig.cache.expiration, 0).Format(time.RFC1123) + // check conditional request + ifModifiedSinceHeader := r.Header.Get("If-Modified-Since") + if ifModifiedSinceHeader != "" && ifModifiedSinceHeader == cacheTimeString { setCacheHeaders(w, cacheTimeString, expiresTimeString) - w.Header().Set("GoBlog-Cache", "HIT") - // write cached body - _, _ = w.Write(body) + // send 304 + w.WriteHeader(http.StatusNotModified) return } - } else { - next.ServeHTTP(w, r) + // copy cached headers + for k, v := range header { + w.Header()[k] = v + } + setCacheHeaders(w, cacheTimeString, expiresTimeString) + w.Header().Set("GoBlog-Cache", "HIT") + // write cached body + _, _ = w.Write(body) return } + next.ServeHTTP(w, r) + return }) } @@ -70,7 +68,7 @@ func setCacheHeaders(w http.ResponseWriter, cacheTimeString string, expiresTimeS w.Header().Set("Expires", expiresTimeString) } -func getCache(path string, context context.Context) (creationTime int64, header map[string][]string, body []byte) { +func getCache(context context.Context, path string) (creationTime int64, header map[string][]string, body []byte) { var headerBytes []byte allowedTime := time.Now().Unix() - appConfig.cache.expiration row := appDb.QueryRowContext(context, "select COALESCE(time, 0), header, body from cache where path=? and time>=?", path, allowedTime) diff --git a/database.go b/database.go index dad44b9..012493d 100644 --- a/database.go +++ b/database.go @@ -30,7 +30,7 @@ func closeDb() error { return appDb.Close() } -func vacuumDb() { +func vacuumDb() { startWritingToDb() _, _ = appDb.Exec("VACUUM;") finishWritingToDb() diff --git a/http.go b/http.go index 95b85f8..0490e0f 100644 --- a/http.go +++ b/http.go @@ -61,22 +61,20 @@ func buildHandler() (http.Handler, error) { allPostPaths, err := allPostPaths() if err != nil { return nil, err - } else { - for _, path := range allPostPaths { - if path != "" { - r.With(CacheMiddleware).Get(path, servePost) - } + } + for _, path := range allPostPaths { + if path != "" { + r.With(cacheMiddleware).Get(path, servePost) } } allRedirectPaths, err := allRedirectPaths() if err != nil { return nil, err - } else { - for _, path := range allRedirectPaths { - if path != "" { - r.Get(path, serveRedirect) - } + } + for _, path := range allRedirectPaths { + if path != "" { + r.Get(path, serveRedirect) } } @@ -104,7 +102,7 @@ func (d *dynamicHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { d.realHandler.ServeHTTP(w, r) } -func SlashTrimmedPath(r *http.Request) string { +func slashTrimmedPath(r *http.Request) string { path := r.URL.Path if len(path) > 1 { path = strings.TrimSuffix(path, "/") diff --git a/posts.go b/posts.go index a2f9e82..0e4e594 100644 --- a/posts.go +++ b/posts.go @@ -7,7 +7,7 @@ import ( "net/http" ) -var postNotFound = errors.New("post not found") +var errPostNotFound = errors.New("post not found") type post struct { path string @@ -18,9 +18,9 @@ type post struct { } func servePost(w http.ResponseWriter, r *http.Request) { - path := SlashTrimmedPath(r) - post, err := getPost(path, r.Context()) - if err == postNotFound { + path := slashTrimmedPath(r) + post, err := getPost(r.Context(), path) + if err == errPostNotFound { http.NotFound(w, r) return } else if err != nil { @@ -36,12 +36,12 @@ func servePost(w http.ResponseWriter, r *http.Request) { _, _ = w.Write(htmlContent) } -func getPost(path string, context context.Context) (*post, error) { +func getPost(context context.Context, path string) (*post, error) { queriedPost := &post{} row := appDb.QueryRowContext(context, "select path, COALESCE(content, ''), COALESCE(published, ''), COALESCE(updated, '') from posts where path=?", path) err := row.Scan(&queriedPost.path, &queriedPost.content, &queriedPost.published, &queriedPost.updated) if err == sql.ErrNoRows { - return nil, postNotFound + return nil, errPostNotFound } else if err != nil { return nil, err } diff --git a/redirects.go b/redirects.go index 144fb3d..80c11e8 100644 --- a/redirects.go +++ b/redirects.go @@ -7,27 +7,26 @@ import ( "net/http" ) -var redirectNotFound = errors.New("redirect not found") +var errRedirectNotFound = errors.New("redirect not found") func serveRedirect(w http.ResponseWriter, r *http.Request) { - redirect, err := getRedirect(SlashTrimmedPath(r), r.Context()) - if err == redirectNotFound { + redirect, err := getRedirect(r.Context(), slashTrimmedPath(r)) + if err == errRedirectNotFound { http.NotFound(w, r) return } else if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } - // TODO: Change status code - http.Redirect(w, r, redirect, http.StatusTemporaryRedirect) + http.Redirect(w, r, redirect, http.StatusFound) } -func getRedirect(fromPath string, context context.Context) (string, error) { +func getRedirect(context context.Context, fromPath string) (string, error) { var toPath string row := appDb.QueryRowContext(context, "select toPath from redirects where fromPath=?", fromPath) err := row.Scan(&toPath) if err == sql.ErrNoRows { - return "", redirectNotFound + return "", errRedirectNotFound } else if err != nil { return "", err }