package main import ( "context" "net/http" "net/url" "sort" "time" "github.com/dgraph-io/ristretto" "go.goblog.app/app/pkgs/bufferpool" "golang.org/x/sync/singleflight" ) const ( cacheLoggedInKey contextKey = "cacheLoggedIn" cacheExpirationKey contextKey = "cacheExpiration" cacheControl = "Cache-Control" ) type cache struct { g singleflight.Group c *ristretto.Cache } func (a *goBlog) initCache() (err error) { a.cache = &cache{} if a.cfg.Cache != nil && !a.cfg.Cache.Enable { // Cache disabled return nil } a.cache.c, err = ristretto.NewCache(&ristretto.Config{ NumCounters: 40 * 1000, // 4000 items when full with 5 KB items -> x10 = 40.000 MaxCost: 20 * 1000 * 1000, // 20 MB BufferItems: 64, // recommended Metrics: true, }) go func() { ticker := time.NewTicker(15 * time.Minute) for range ticker.C { met := a.cache.c.Metrics a.info("Cache metrics", "metrics", met.String()) } }() return } func cacheLoggedIn(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), cacheLoggedInKey, true))) }) } func (a *goBlog) cacheMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Do checks if a.cache.c == nil || !cacheable(r) { next.ServeHTTP(w, r) return } // Check login if cli, ok := r.Context().Value(cacheLoggedInKey).(bool); ok && cli { // Continue caching, but remove login setLoggedIn(r, false) } else if a.isLoggedIn(r) { // Don't cache logged in requests next.ServeHTTP(w, r) return } // Search and serve cache key := cacheKey(r) // Get cache or render it cacheInterface, _, _ := a.cache.g.Do(key, func() (any, error) { return a.cache.getCache(key, next, r), nil }) ci := cacheInterface.(*cacheItem) // copy and set headers a.setCacheHeaders(w, ci) // check conditional request if ifNoneMatchHeader := r.Header.Get("If-None-Match"); ifNoneMatchHeader != "" && ifNoneMatchHeader == ci.eTag { // send 304 w.WriteHeader(http.StatusNotModified) return } // set status code w.WriteHeader(ci.code) // write cached body _, _ = w.Write(ci.body) }) } func cacheable(r *http.Request) bool { if r.Method != http.MethodGet && r.Method != http.MethodHead { return false } if r.URL.Query().Get("cache") == "0" || r.URL.Query().Get("cache") == "false" { return false } return true } func cacheKey(r *http.Request) (key string) { buf := bufferpool.Get() // Special cases if asRequest, ok := r.Context().Value(asRequestKey).(bool); ok && asRequest { _, _ = buf.WriteString("as-") } if torUsed, ok := r.Context().Value(torUsedKey).(bool); ok && torUsed { _, _ = buf.WriteString("tor-") } // Add cache URL _, _ = buf.WriteString(r.URL.EscapedPath()) if query := r.URL.Query(); len(query) > 0 { _ = buf.WriteByte('?') keys := make([]string, 0, len(query)) for k := range query { keys = append(keys, k) } sort.Strings(keys) for i, k := range keys { keyEscaped := url.QueryEscape(k) for j, val := range query[k] { if i > 0 || j > 0 { buf.WriteByte('&') } buf.WriteString(keyEscaped) buf.WriteByte('=') buf.WriteString(url.QueryEscape(val)) } } } // Get key as string key = buf.String() // Return buffer to pool bufferpool.Put(buf) return } func (a *goBlog) setCacheHeaders(w http.ResponseWriter, cache *cacheItem) { // Copy headers for k, v := range cache.header.Clone() { w.Header()[k] = v } // Set cache headers w.Header().Set("ETag", cache.eTag) w.Header().Set(cacheControl, "public,no-cache") } type cacheItem struct { expiration int eTag string code int header http.Header body []byte } // Calculate byte size of cache item using size of header, body and etag func (ci *cacheItem) cost() int { headerBuf := bufferpool.Get() _ = ci.header.Write(headerBuf) headerSize := len(headerBuf.Bytes()) bufferpool.Put(headerBuf) return headerSize + len(ci.body) + len(ci.eTag) } func (c *cache) getCache(key string, next http.Handler, r *http.Request) *cacheItem { if rItem, ok := c.c.Get(key); ok { return rItem.(*cacheItem) } // No cache available // Make and use copy of r //nolint:contextcheck cr := r.Clone(valueOnlyContext{r.Context()}) // Remove problematic headers cr.Header.Del("If-Modified-Since") cr.Header.Del("If-Unmodified-Since") cr.Header.Del("If-None-Match") cr.Header.Del("If-Match") cr.Header.Del("If-Range") cr.Header.Del("Range") // Record request rec := newCacheRecorder() next.ServeHTTP(rec, cr) item := rec.finish() // Set expiration item.expiration, _ = cr.Context().Value(cacheExpirationKey).(int) // Remove problematic headers item.header.Del("Accept-Ranges") item.header.Del("ETag") item.header.Del("Last-Modified") // Save cache if cch := item.header.Get(cacheControl); !containsStrings(cch, "no-store", "private", "no-cache") { cost := int64(item.cost()) if item.expiration == 0 { c.c.Set(key, item, cost) } else { c.c.SetWithTTL(key, item, cost, time.Duration(item.expiration)*time.Second) } } return item } func (c *cache) purge() { if c == nil { return } c.c.Clear() } func (a *goBlog) defaultCacheExpiration() int { if a.cfg.Cache != nil { return a.cfg.Cache.Expiration } return 0 }