From 46548df638b91c2df85ec785780edc186ab8326f Mon Sep 17 00:00:00 2001 From: Jan-Lukas Else Date: Mon, 21 Feb 2022 18:47:41 +0100 Subject: [PATCH] Improved cache efficiency --- cache.go | 144 +++++++++++++++++++++++++---------------------- cacheRecorder.go | 64 +++++++++++++++++++++ cache_test.go | 71 +++++++++++++++++++++++ utils.go | 6 +- 4 files changed, 215 insertions(+), 70 deletions(-) create mode 100644 cacheRecorder.go create mode 100644 cache_test.go diff --git a/cache.go b/cache.go index 8448af7..8450192 100644 --- a/cache.go +++ b/cache.go @@ -3,16 +3,15 @@ package main import ( "context" "crypto/sha256" - "encoding/binary" "fmt" - "io" "net/http" - "net/http/httptest" - "strings" + "net/url" + "sort" "time" "github.com/araddon/dateparse" "github.com/dgraph-io/ristretto" + "go.goblog.app/app/pkgs/bufferpool" "golang.org/x/sync/singleflight" ) @@ -106,23 +105,41 @@ func cacheable(r *http.Request) bool { return true } -func cacheKey(r *http.Request) string { - var buf strings.Builder +func cacheKey(r *http.Request) (key string) { + buf := bufferpool.Get() // Special cases if asRequest, ok := r.Context().Value(asRequestKey).(bool); ok && asRequest { - buf.WriteString("as-") + _, _ = buf.WriteString("as-") } if torUsed, ok := r.Context().Value(torUsedKey).(bool); ok && torUsed { - buf.WriteString("tor-") + _, _ = buf.WriteString("tor-") } // Add cache URL _, _ = buf.WriteString(r.URL.EscapedPath()) - if q := r.URL.Query(); len(q) > 0 { + if query := r.URL.Query(); len(query) > 0 { _ = buf.WriteByte('?') - _, _ = buf.WriteString(q.Encode()) + keys := make([]string, 0, len(query)) + for k := range query { + keys = append(keys, k) + } + sort.Strings(keys) + for i, k := range keys { + keyEscaped := url.QueryEscape(k) + for j, val := range query[k] { + if i > 0 || j > 0 { + buf.WriteByte('&') + } + buf.WriteString(keyEscaped) + buf.WriteByte('=') + buf.WriteString(url.QueryEscape(val)) + } + } } - // Return string - return buf.String() + // Get key as string + key = buf.String() + // Return buffer to pool + bufferpool.Put(buf) + return } func (a *goBlog) setCacheHeaders(w http.ResponseWriter, cache *cacheItem) { @@ -152,67 +169,58 @@ type cacheItem struct { body []byte } -// Calculate byte size of cache item using size of body and header -func (ci *cacheItem) cost() int64 { - var headerBuf strings.Builder - _ = ci.header.Write(&headerBuf) - headerSize := int64(binary.Size(headerBuf.String())) - bodySize := int64(binary.Size(ci.body)) - return headerSize + bodySize +// Calculate byte size of cache item using size of header, body and etag +func (ci *cacheItem) cost() int { + headerBuf := bufferpool.Get() + _ = ci.header.Write(headerBuf) + headerSize := len(headerBuf.Bytes()) + bufferpool.Put(headerBuf) + return headerSize + len(ci.body) + len(ci.eTag) } -func (c *cache) getCache(key string, next http.Handler, r *http.Request) (item *cacheItem) { +func (c *cache) getCache(key string, next http.Handler, r *http.Request) *cacheItem { if rItem, ok := c.c.Get(key); ok { - item = rItem.(*cacheItem) + return rItem.(*cacheItem) } - if item == nil { - // No cache available - // Make and use copy of r - cr := r.Clone(valueOnlyContext{r.Context()}) - // Remove problematic headers - cr.Header.Del("If-Modified-Since") - cr.Header.Del("If-Unmodified-Since") - cr.Header.Del("If-None-Match") - cr.Header.Del("If-Match") - cr.Header.Del("If-Range") - cr.Header.Del("Range") - // Record request - recorder := httptest.NewRecorder() - next.ServeHTTP(recorder, cr) - recorder.Flush() - // Cache result - result := recorder.Result() - eTag := sha256.New() - body, _ := io.ReadAll(io.TeeReader(result.Body, eTag)) - headers := result.Header.Clone() - _ = result.Body.Close() - lastMod := time.Now() - if lm := headers.Get(lastModified); lm != "" { - if parsedTime, te := dateparse.ParseLocal(lm); te == nil { - lastMod = parsedTime - } + // No cache available + // Make and use copy of r + cr := r.Clone(valueOnlyContext{r.Context()}) + // Remove problematic headers + cr.Header.Del("If-Modified-Since") + cr.Header.Del("If-Unmodified-Since") + cr.Header.Del("If-None-Match") + cr.Header.Del("If-Match") + cr.Header.Del("If-Range") + cr.Header.Del("Range") + // Record request + rec := newCacheRecorder() + next.ServeHTTP(rec, cr) + item := rec.finish() + // Set eTag + item.eTag = item.header.Get("ETag") + if item.eTag == "" { + item.eTag = fmt.Sprintf("%x", sha256.Sum256(item.body)) + } + // Set creation time + item.creationTime = time.Now() + if lm := item.header.Get(lastModified); lm != "" { + if parsedTime, te := dateparse.ParseLocal(lm); te == nil { + item.creationTime = parsedTime } - // Remove problematic headers - headers.Del("Accept-Ranges") - headers.Del("ETag") - headers.Del(lastModified) - // Create cache item - exp, _ := cr.Context().Value(cacheExpirationKey).(int) - item = &cacheItem{ - expiration: exp, - creationTime: lastMod, - eTag: fmt.Sprintf("%x", eTag.Sum(nil)), - code: result.StatusCode, - header: headers, - body: body, - } - // Save cache - if cch := item.header.Get(cacheControl); !containsStrings(cch, "no-store", "private", "no-cache") { - if exp == 0 { - c.c.Set(key, item, item.cost()) - } else { - c.c.SetWithTTL(key, item, item.cost(), time.Duration(exp)*time.Second) - } + } + // Set expiration + item.expiration, _ = cr.Context().Value(cacheExpirationKey).(int) + // Remove problematic headers + item.header.Del("Accept-Ranges") + item.header.Del("ETag") + item.header.Del(lastModified) + // Save cache + if cch := item.header.Get(cacheControl); !containsStrings(cch, "no-store", "private", "no-cache") { + cost := int64(item.cost()) + if item.expiration == 0 { + c.c.Set(key, item, cost) + } else { + c.c.SetWithTTL(key, item, cost, time.Duration(item.expiration)*time.Second) } } return item diff --git a/cacheRecorder.go b/cacheRecorder.go new file mode 100644 index 0000000..7b5b232 --- /dev/null +++ b/cacheRecorder.go @@ -0,0 +1,64 @@ +package main + +import ( + "fmt" + "net/http" +) + +// cacheRecorder is an implementation of http.ResponseWriter +type cacheRecorder struct { + item *cacheItem +} + +func newCacheRecorder() *cacheRecorder { + return &cacheRecorder{ + item: &cacheItem{ + code: http.StatusOK, + header: make(http.Header), + }, + } +} + +func (c *cacheRecorder) finish() (ci *cacheItem) { + ci = c.item + c.item = nil + return +} + +// Header implements http.ResponseWriter. +func (rw *cacheRecorder) Header() http.Header { + if rw.item == nil { + return nil + } + return rw.item.header +} + +// Write implements http.ResponseWriter. +func (rw *cacheRecorder) Write(buf []byte) (int, error) { + if rw.item == nil { + return 0, nil + } + rw.item.body = append(rw.item.body, buf...) + return len(buf), nil +} + +// WriteString implements io.StringWriter. +func (rw *cacheRecorder) WriteString(str string) (int, error) { + return rw.Write([]byte(str)) +} + +// WriteHeader implements http.ResponseWriter. +func (rw *cacheRecorder) WriteHeader(code int) { + if rw.item == nil { + return + } + if code < 100 || code > 999 { + panic(fmt.Sprintf("invalid WriteHeader code %v", code)) + } + rw.item.code = code +} + +// Flush implements http.Flusher. +func (rw *cacheRecorder) Flush() { + // Do nothing +} diff --git a/cache_test.go b/cache_test.go new file mode 100644 index 0000000..63e63f7 --- /dev/null +++ b/cache_test.go @@ -0,0 +1,71 @@ +package main + +import ( + "io" + "net/http" + "net/http/httptest" + "strconv" + "testing" + "time" + + "github.com/dgraph-io/ristretto" + "github.com/stretchr/testify/assert" +) + +func Benchmark_cacheItem_cost(b *testing.B) { + ci := &cacheItem{ + creationTime: time.Now(), + eTag: "abc", + code: 200, + header: http.Header{ + "Content-Type": []string{"text/html"}, + }, + body: []byte("abcdefghijklmnopqrstuvwxyz"), + } + b.RunParallel(func(p *testing.PB) { + for p.Next() { + ci.cost() + } + }) +} + +func Test_cacheItem_cost(t *testing.T) { + ci := &cacheItem{ + header: http.Header{ + "Content-Type": []string{"text/html"}, + }, + body: []byte("abcdefghijklmnopqrstuvwxyz"), + eTag: "abc", + } + bodyLen := len(ci.body) + assert.Equal(t, 39, bodyLen) + eTagLen := len(ci.eTag) + assert.Equal(t, 3, eTagLen) + assert.Greater(t, ci.cost(), bodyLen+eTagLen) +} + +func Benchmark_cacheKey(b *testing.B) { + req := httptest.NewRequest(http.MethodGet, "/abc?abc=def&hij=klm", nil) + b.RunParallel(func(p *testing.PB) { + for p.Next() { + cacheKey(req) + } + }) +} + +func Benchmark_cache_getCache(b *testing.B) { + c := &cache{} + c.c, _ = ristretto.NewCache(&ristretto.Config{ + NumCounters: 40 * 1000, + MaxCost: 20 * 1000 * 1000, + BufferItems: 64, + }) + req := httptest.NewRequest(http.MethodGet, "/abc?abc=def&hij=klm", nil) + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "abcdefghijklmnopqrstuvwxyz") + _, _ = w.Write([]byte("abcdefghijklmnopqrstuvwxyz")) + }) + for i := 0; i < b.N; i++ { + c.getCache(strconv.Itoa(i), handler, req) + } +} diff --git a/utils.go b/utils.go index cd5e273..8e85f5e 100644 --- a/utils.go +++ b/utils.go @@ -23,6 +23,7 @@ import ( tdl "github.com/mergestat/timediff/locale" "github.com/microcosm-cc/bluemonday" "github.com/thoas/go-funk" + "go.goblog.app/app/pkgs/bufferpool" "golang.org/x/text/language" ) @@ -237,7 +238,8 @@ func htmlTextFromReader(r io.Reader) (string, error) { if err != nil { return "", err } - var text strings.Builder + text := bufferpool.Get() + defer bufferpool.Put(text) if bodyChild := doc.Find("body").Children(); bodyChild.Length() > 0 { // Input was real HTML, so build the text from the body // Declare recursive function to print childs @@ -249,7 +251,7 @@ func htmlTextFromReader(r io.Reader) (string, error) { _, _ = text.WriteString("\n\n") } if sel.Is("ol > li") { // List item in ordered list - _, _ = fmt.Fprintf(&text, "%d. ", i+1) // Add list item number + _, _ = fmt.Fprintf(text, "%d. ", i+1) // Add list item number } if sel.Children().Length() > 0 { // Has children printChilds(sel.Children()) // Recursive call to print childs