diff --git a/cache.go b/cache.go index 063bc8b..b66bef7 100644 --- a/cache.go +++ b/cache.go @@ -14,41 +14,62 @@ func CacheMiddleware(next http.Handler) http.Handler { requestUrl, _ := url.ParseRequestURI(r.RequestURI) path := SlashTrimmedPath(r) if appConfig.cache.enable && - // Check bypass query + // check bypass query !(requestUrl != nil && requestUrl.Query().Get("cache") == "0") { cacheTime, header, body := getCache(path, r.Context()) if cacheTime == 0 { recorder := httptest.NewRecorder() next.ServeHTTP(recorder, r) - // Copy values from recorder + // copy values from recorder code := recorder.Code - // Send response + // send response for k, v := range recorder.Header() { w.Header()[k] = v } + now := time.Now() + setCacheHeaders(w, now.Format(time.RFC1123), time.Unix(now.Unix()+appConfig.cache.expiration, 0).Format(time.RFC1123)) w.Header().Set("GoBlog-Cache", "MISS") w.WriteHeader(code) _, _ = w.Write(recorder.Body.Bytes()) // Save cache if code == http.StatusOK { - saveCache(path, recorder.Header(), recorder.Body.Bytes()) + saveCache(path, now, recorder.Header(), recorder.Body.Bytes()) } return } else { - expiresTime := time.Unix(cacheTime+appConfig.cache.expiration, 0).Format(time.RFC1123) + cacheTimeString := time.Unix(cacheTime, 0).Format(time.RFC1123) + expiresTimeString := time.Unix(cacheTime+appConfig.cache.expiration, 0).Format(time.RFC1123) + // check conditional request + ifModifiedSinceHeader := r.Header.Get("If-Modified-Since") + if ifModifiedSinceHeader != "" && ifModifiedSinceHeader == cacheTimeString { + setCacheHeaders(w, cacheTimeString, expiresTimeString) + // send 304 + w.WriteHeader(http.StatusNotModified) + return + } + // copy cached headers for k, v := range header { w.Header()[k] = v } - w.Header().Set("Expires", expiresTime) + setCacheHeaders(w, cacheTimeString, expiresTimeString) w.Header().Set("GoBlog-Cache", "HIT") + // write cached body _, _ = w.Write(body) + return } } else { next.ServeHTTP(w, r) + return } }) } +func setCacheHeaders(w http.ResponseWriter, cacheTimeString string, expiresTimeString string) { + w.Header().Set("Cache-Control", "public") + w.Header().Set("Last-Modified", cacheTimeString) + w.Header().Set("Expires", expiresTimeString) +} + func getCache(path string, context context.Context) (creationTime int64, header map[string][]string, body []byte) { var headerBytes []byte allowedTime := time.Now().Unix() - appConfig.cache.expiration @@ -59,16 +80,15 @@ func getCache(path string, context context.Context) (creationTime int64, header return } -func saveCache(path string, header map[string][]string, body []byte) { - now := time.Now().Unix() +func saveCache(path string, now time.Time, header map[string][]string, body []byte) { headerBytes, _ := json.Marshal(header) startWritingToDb() tx, err := appDb.Begin() if err != nil { return } - _, _ = tx.Exec("delete from cache where time