mirror of https://github.com/jlelse/GoBlog
Simple blogging system written in Go
https://goblog.app
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
245 lines
6.4 KiB
245 lines
6.4 KiB
package main |
|
|
|
import ( |
|
"context" |
|
"crypto/sha256" |
|
"fmt" |
|
"log" |
|
"net/http" |
|
"net/url" |
|
"sort" |
|
"time" |
|
|
|
"github.com/araddon/dateparse" |
|
"github.com/dgraph-io/ristretto" |
|
"go.goblog.app/app/pkgs/bufferpool" |
|
"golang.org/x/sync/singleflight" |
|
) |
|
|
|
const ( |
|
cacheLoggedInKey contextKey = "cacheLoggedIn" |
|
cacheExpirationKey contextKey = "cacheExpiration" |
|
|
|
lastModified = "Last-Modified" |
|
cacheControl = "Cache-Control" |
|
) |
|
|
|
// cache holds the state of the in-memory response cache.
type cache struct {
	// g collapses concurrent renders of the same cache key into one call
	// (used by cacheMiddleware).
	g singleflight.Group
	// c is the ristretto store; it stays nil when caching is disabled in the
	// config, which cacheMiddleware checks before using the cache.
	c *ristretto.Cache
}
|
|
|
// initCache sets up the response cache.
//
// a.cache is always assigned. When the config explicitly disables caching,
// a.cache.c stays nil and cacheMiddleware passes every request straight to
// the next handler. Otherwise a ristretto cache is created and a background
// goroutine logs its metrics every 15 minutes; that goroutine never exits.
//
// It returns an error if the ristretto cache cannot be created.
func (a *goBlog) initCache() (err error) {
	a.cache = &cache{}
	if a.cfg.Cache != nil && !a.cfg.Cache.Enable {
		// Cache disabled
		return nil
	}
	a.cache.c, err = ristretto.NewCache(&ristretto.Config{
		NumCounters: 40 * 1000,        // 4000 items when full with 5 KB items -> x10 = 40.000
		MaxCost:     20 * 1000 * 1000, // 20 MB
		BufferItems: 64,               // recommended
		Metrics:     true,
	})
	if err != nil {
		// a.cache.c is not a usable cache on error, so the metrics logger
		// below (which reads a.cache.c.Metrics) must not be started.
		return fmt.Errorf("creating response cache: %w", err)
	}
	go func() {
		ticker := time.NewTicker(15 * time.Minute)
		for range ticker.C {
			met := a.cache.c.Metrics
			log.Println("\nCache:", met.String())
		}
	}()
	return nil
}
|
|
|
// cacheLoggedIn is a middleware that stores cacheLoggedInKey = true in the
// request context. cacheMiddleware recognises that flag, calls
// setLoggedIn(r, false) and keeps using the shared cache instead of
// bypassing it for a logged-in user.
func cacheLoggedIn(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		flagged := context.WithValue(r.Context(), cacheLoggedInKey, true)
		next.ServeHTTP(w, r.WithContext(flagged))
	})
}
|
|
|
// cacheMiddleware serves cacheable requests from the response cache.
//
// Requests go straight to next when caching is disabled (a.cache.c == nil),
// when cacheable reports false, or when the user is logged in — unless the
// route opted in via cacheLoggedIn, in which case the login is removed and
// the request is cached like an anonymous one. Concurrent misses for the
// same key are collapsed with singleflight, so next renders each key once.
//
// Conditional requests on a successful (2xx) cached response are answered
// with 304 Not Modified: If-None-Match is compared with the cached ETag and,
// only when If-None-Match is absent, If-Modified-Since is compared with the
// cached creation time at one-second precision.
func (a *goBlog) cacheMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Do checks
		if a.cache.c == nil || !cacheable(r) {
			next.ServeHTTP(w, r)
			return
		}
		// Check login
		if cli, ok := r.Context().Value(cacheLoggedInKey).(bool); ok && cli {
			// Continue caching, but remove login
			setLoggedIn(r, false)
		} else if a.isLoggedIn(r) {
			// Don't cache logged in requests
			next.ServeHTTP(w, r)
			return
		}
		// Search and serve cache; concurrent callers for the same key share
		// one getCache call.
		key := cacheKey(r)
		cacheInterface, _, _ := a.cache.g.Do(key, func() (any, error) {
			return a.cache.getCache(key, next, r), nil
		})
		ci := cacheInterface.(*cacheItem)
		// Copy and set headers
		a.setCacheHeaders(w, ci)
		// Preconditions only apply when the response would be 2xx
		// (RFC 9110 §13.2.1).
		if ci.code >= http.StatusOK && ci.code < http.StatusMultipleChoices {
			notModified := false
			if ifNoneMatch := r.Header.Get("If-None-Match"); ifNoneMatch != "" {
				// If-None-Match takes precedence: If-Modified-Since must be
				// ignored when it is present (RFC 9110 §13.1.3, §13.2.2).
				notModified = ifNoneMatch == ci.eTag
			} else if ifModifiedSince := r.Header.Get("If-Modified-Since"); ifModifiedSince != "" {
				// Last-Modified is sent with one-second precision, so compare
				// at that precision; otherwise a client echoing the exact
				// Last-Modified value would never receive a 304.
				t, err := dateparse.ParseAny(ifModifiedSince)
				notModified = err == nil && !ci.creationTime.Truncate(time.Second).After(t)
			}
			if notModified {
				w.WriteHeader(http.StatusNotModified)
				return
			}
		}
		// Set status code and write cached body
		w.WriteHeader(ci.code)
		_, _ = w.Write(ci.body)
	})
}
|
|
|
// cacheable reports whether r may be served from (and stored in) the
// response cache. Only GET and HEAD requests qualify, and a client can opt
// out per request with the query parameter cache=0 or cache=false.
func cacheable(r *http.Request) bool {
	switch r.Method {
	case http.MethodGet, http.MethodHead:
		// eligible, continue with the opt-out check
	default:
		return false
	}
	switch r.URL.Query().Get("cache") {
	case "0", "false":
		return false
	}
	return true
}
|
|
|
// cacheKey derives the cache key for r. It consists of optional "as-" and
// "tor-" prefixes (from the asRequestKey / torUsedKey context flags), the
// escaped URL path and, if present, the query string with parameter names
// sorted so that parameter order does not produce distinct keys.
func cacheKey(r *http.Request) (key string) {
	buf := bufferpool.Get()
	// The string is copied out of buf before the deferred Put runs.
	defer bufferpool.Put(buf)

	ctx := r.Context()
	if isAS, _ := ctx.Value(asRequestKey).(bool); isAS {
		_, _ = buf.WriteString("as-")
	}
	if viaTor, _ := ctx.Value(torUsedKey).(bool); viaTor {
		_, _ = buf.WriteString("tor-")
	}
	_, _ = buf.WriteString(r.URL.EscapedPath())

	params := r.URL.Query()
	if len(params) == 0 {
		return buf.String()
	}
	_ = buf.WriteByte('?')
	names := make([]string, 0, len(params))
	for name := range params {
		names = append(names, name)
	}
	sort.Strings(names)
	for i, name := range names {
		escapedName := url.QueryEscape(name)
		for j, value := range params[name] {
			if i != 0 || j != 0 {
				_ = buf.WriteByte('&')
			}
			_, _ = buf.WriteString(escapedName)
			_ = buf.WriteByte('=')
			_, _ = buf.WriteString(url.QueryEscape(value))
		}
	}
	return buf.String()
}
|
|
|
// setCacheHeaders copies the stored headers of ci onto w. It then sets the
// validators (ETag, Last-Modified) and, if the cached response did not bring
// its own, a Cache-Control header.
//
// With a per-item expiration (seconds), browsers and shared caches may keep
// the response for that long. Otherwise the config default from
// defaultCacheExpiration is used: exp seconds for browsers and exp/3 seconds
// for shared caches (s-maxage).
func (a *goBlog) setCacheHeaders(w http.ResponseWriter, ci *cacheItem) {
	h := w.Header()
	// Clone so w's header slices never alias the cached item's slices.
	for k, v := range ci.header.Clone() {
		h[k] = v
	}
	// Set cache headers
	h.Set("ETag", ci.eTag)
	h.Set(lastModified, ci.creationTime.UTC().Format(http.TimeFormat))
	if h.Get(cacheControl) != "" {
		return
	}
	if ci.expiration != 0 {
		h.Set(cacheControl, fmt.Sprintf("public,max-age=%d,stale-while-revalidate=%d", ci.expiration, ci.expiration))
		return
	}
	// defaultCacheExpiration tolerates a nil cache config, which initCache
	// treats as "cache enabled".
	exp := a.defaultCacheExpiration()
	// The shared-cache directive is spelled "s-maxage" (RFC 9111); the former
	// "s-max-age" is not a recognised directive and was ignored by caches.
	h.Set(cacheControl, fmt.Sprintf("public,max-age=%d,s-maxage=%d,stale-while-revalidate=%d", exp, exp/3, exp))
}
|
|
|
// cacheItem is one recorded response as stored in and served from the cache.
type cacheItem struct {
	// expiration is the per-item lifetime in seconds, taken from the request
	// context (cacheExpirationKey); 0 means no per-item TTL and the config
	// default is used for Cache-Control.
	expiration int
	// creationTime is sent as Last-Modified and checked against
	// If-Modified-Since; it comes from the handler's Last-Modified if
	// parseable, otherwise from the time the response was recorded.
	creationTime time.Time
	// eTag is the handler's ETag or, if none was set, a SHA-256 digest of body.
	eTag string
	// code is the recorded HTTP status code.
	code int
	// header holds the recorded response headers, minus Accept-Ranges, ETag
	// and Last-Modified, which are removed before storing.
	header http.Header
	// body is the recorded response body.
	body []byte
}
|
|
|
// Calculate byte size of cache item using size of header, body and etag |
|
func (ci *cacheItem) cost() int { |
|
headerBuf := bufferpool.Get() |
|
_ = ci.header.Write(headerBuf) |
|
headerSize := len(headerBuf.Bytes()) |
|
bufferpool.Put(headerBuf) |
|
return headerSize + len(ci.body) + len(ci.eTag) |
|
} |
|
|
|
// getCache returns the cached response for key, rendering and storing it on
// a miss.
//
// On a miss, next is served a clone of r in the following form:
//   - its context is wrapped in valueOnlyContext (presumably to keep request
//     values while detaching the original request's cancellation — confirm);
//   - conditional and range headers are removed, so the recorder captures
//     the complete response.
//
// The recorded item is then filled in:
//   - an ETag: the handler's own, or a quoted SHA-256 of the body;
//   - a creation time: the handler's Last-Modified, or now;
//   - an optional per-item expiration in seconds.
//
// The item is stored unless its Cache-Control contains "no-store",
// "private" or "no-cache". It is returned in either case.
func (c *cache) getCache(key string, next http.Handler, r *http.Request) *cacheItem {
	if rItem, ok := c.c.Get(key); ok {
		return rItem.(*cacheItem)
	}
	// No cache available: make and use a copy of r
	cr := r.Clone(valueOnlyContext{r.Context()})
	// Remove problematic headers so the handler renders the full response
	cr.Header.Del("If-Modified-Since")
	cr.Header.Del("If-Unmodified-Since")
	cr.Header.Del("If-None-Match")
	cr.Header.Del("If-Match")
	cr.Header.Del("If-Range")
	cr.Header.Del("Range")
	// Record request
	rec := newCacheRecorder()
	next.ServeHTTP(rec, cr)
	item := rec.finish()
	// Set eTag: prefer the handler's own ETag, otherwise hash the body.
	item.eTag = item.header.Get("ETag")
	if item.eTag == "" {
		// An entity-tag must be a quoted string (RFC 9110 §8.8.3); an
		// unquoted value is not a valid ETag and may be mishandled by
		// browsers or intermediaries.
		item.eTag = fmt.Sprintf("\"%x\"", sha256.Sum256(item.body))
	}
	// Set creation time; a parseable handler-provided Last-Modified wins.
	item.creationTime = time.Now()
	if lm := item.header.Get(lastModified); lm != "" {
		if parsedTime, te := dateparse.ParseLocal(lm); te == nil {
			item.creationTime = parsedTime
		}
	}
	// Set expiration (seconds; 0 = no per-item TTL)
	item.expiration, _ = cr.Context().Value(cacheExpirationKey).(int)
	// ETag and Last-Modified are re-set from the item when serving;
	// Accept-Ranges is dropped because the cache always replays the whole body.
	item.header.Del("Accept-Ranges")
	item.header.Del("ETag")
	item.header.Del(lastModified)
	// Save cache
	if cch := item.header.Get(cacheControl); !containsStrings(cch, "no-store", "private", "no-cache") {
		cost := int64(item.cost())
		if item.expiration == 0 {
			c.c.Set(key, item, cost)
		} else {
			c.c.SetWithTTL(key, item, cost, time.Duration(item.expiration)*time.Second)
		}
	}
	return item
}
|
|
|
func (c *cache) purge() { |
|
c.c.Clear() |
|
} |
|
|
|
func (a *goBlog) defaultCacheExpiration() int { |
|
if a.cfg.Cache != nil { |
|
return a.cfg.Cache.Expiration |
|
} |
|
return 0 |
|
}
|
|
|