diff --git a/cache.go b/cache.go index bf2925e..063bc8b 100644 --- a/cache.go +++ b/cache.go @@ -2,7 +2,9 @@ package main import ( "context" + "encoding/json" "net/http" + "net/http/httptest" "net/url" "time" ) @@ -10,18 +12,36 @@ import ( func CacheMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { requestUrl, _ := url.ParseRequestURI(r.RequestURI) + path := SlashTrimmedPath(r) if appConfig.cache.enable && // Check bypass query !(requestUrl != nil && requestUrl.Query().Get("cache") == "0") { - mime, t, cache := getCache(SlashTrimmedPath(r), r.Context()) - if cache == nil { - next.ServeHTTP(w, r) + cacheTime, header, body := getCache(path, r.Context()) + if cacheTime == 0 { + recorder := httptest.NewRecorder() + next.ServeHTTP(recorder, r) + // Copy values from recorder + code := recorder.Code + // Send response + for k, v := range recorder.Header() { + w.Header()[k] = v + } + w.Header().Set("GoBlog-Cache", "MISS") + w.WriteHeader(code) + _, _ = w.Write(recorder.Body.Bytes()) + // Save cache + if code == http.StatusOK { + saveCache(path, recorder.Header(), recorder.Body.Bytes()) + } return } else { - expiresTime := time.Unix(t+appConfig.cache.expiration, 0).Format(time.RFC1123) + expiresTime := time.Unix(cacheTime+appConfig.cache.expiration, 0).Format(time.RFC1123) + for k, v := range header { + w.Header()[k] = v + } w.Header().Set("Expires", expiresTime) - w.Header().Set("Content-Type", mime) - _, _ = w.Write(cache) + w.Header().Set("GoBlog-Cache", "HIT") + _, _ = w.Write(body) } } else { next.ServeHTTP(w, r) @@ -29,25 +49,26 @@ func CacheMiddleware(next http.Handler) http.Handler { }) } -func getCache(path string, context context.Context) (string, int64, []byte) { - var mime string - var t int64 - var cache []byte +func getCache(path string, context context.Context) (creationTime int64, header map[string][]string, body []byte) { + var headerBytes []byte allowedTime := time.Now().Unix() - appConfig.cache.expiration - row := appDb.QueryRowContext(context, "select COALESCE(mime, ''), COALESCE(time, 0), value from cache where path=? and time>=?", path, allowedTime) - _ = row.Scan(&mime, &t, &cache) - return mime, t, cache + row := appDb.QueryRowContext(context, "select COALESCE(time, 0), header, body from cache where path=? and time>=?", path, allowedTime) + _ = row.Scan(&creationTime, &headerBytes, &body) + header = make(map[string][]string) + _ = json.Unmarshal(headerBytes, &header) + return } -func saveCache(path string, mime string, value []byte) { +func saveCache(path string, header map[string][]string, body []byte) { now := time.Now().Unix() + headerBytes, _ := json.Marshal(header) startWritingToDb() tx, err := appDb.Begin() if err != nil { return } _, _ = tx.Exec("delete from cache where time