diff --git a/cache.go b/cache.go index feeecc1..b8b2bf9 100644 --- a/cache.go +++ b/cache.go @@ -1,14 +1,16 @@ package main import ( + "bytes" + "crypto/sha256" "fmt" + "io" "io/ioutil" "net/http" "net/http/httptest" "strconv" "time" - "github.com/araddon/dateparse" lru "github.com/hashicorp/golang-lru" "golang.org/x/sync/singleflight" ) @@ -40,25 +42,21 @@ func cacheMiddleware(next http.Handler) http.Handler { return getCache(key, next, r), nil }) cache := cacheInterface.(*cacheItem) - cacheTimeString := time.Unix(cache.creationTime, 0).Format(time.RFC1123) - expiresTimeString := "" + var expiresIn int64 = 0 if cache.expiration != 0 { - expiresTimeString = time.Unix(cache.creationTime+int64(cache.expiration), 0).Format(time.RFC1123) - } - // check conditional request - if ifModifiedSinceHeader := r.Header.Get("If-Modified-Since"); ifModifiedSinceHeader != "" { - if t, _ := dateparse.ParseIn(ifModifiedSinceHeader, time.Local); t.Unix() == cache.creationTime { - // send 304 - setCacheHeaders(w, cacheTimeString, expiresTimeString) - w.WriteHeader(http.StatusNotModified) - return - } + expiresIn = cache.creationTime + int64(cache.expiration) - time.Now().Unix() } // copy cached headers for k, v := range cache.header { w.Header()[k] = v } - setCacheHeaders(w, cacheTimeString, expiresTimeString) + setCacheHeaders(w, cache.hash, expiresIn) + // check conditional request + if ifNoneMatchHeader := r.Header.Get("If-None-Match"); ifNoneMatchHeader != "" && ifNoneMatchHeader == cache.hash { + // send 304 + w.WriteHeader(http.StatusNotModified) + return + } // set status code w.WriteHeader(cache.code) // write cached body @@ -74,21 +72,21 @@ func cacheKey(r *http.Request) string { return r.URL.String() } -func setCacheHeaders(w http.ResponseWriter, cacheTimeString string, expiresTimeString string) { +func setCacheHeaders(w http.ResponseWriter, hash string, expiresIn int64) { w.Header().Del(cacheInternalExpirationHeader) - w.Header().Set("Last-Modified", cacheTimeString) - if expiresTimeString != "" { + w.Header().Set("ETag", hash) + if expiresIn != 0 { // Set expires time - w.Header().Set("Cache-Control", "public") - w.Header().Set("Expires", expiresTimeString) + w.Header().Set("Cache-Control", fmt.Sprintf("public,max-age=%d", expiresIn)) } else { w.Header().Set("Cache-Control", fmt.Sprintf("public,max-age=%d,s-max-age=%d", appConfig.Cache.Expiration, appConfig.Cache.Expiration/3)) } } type cacheItem struct { - creationTime int64 expiration int + creationTime int64 + hash string code int header http.Header body []byte @@ -113,10 +111,15 @@ func getCache(key string, next http.Handler, r *http.Request) (item *cacheItem) // Cache values from recorder result := recorder.Result() body, _ := ioutil.ReadAll(result.Body) + _ = result.Body.Close() + h := sha256.New() + _, _ = io.Copy(h, bytes.NewReader(body)) + hash := fmt.Sprintf("%x", h.Sum(nil)) exp, _ := strconv.Atoi(result.Header.Get(cacheInternalExpirationHeader)) item = &cacheItem{ - creationTime: time.Now().Unix(), expiration: exp, + creationTime: time.Now().Unix(), + hash: hash, code: result.StatusCode, header: result.Header, body: body,