Improved cache efficiency

pull/25/head
Jan-Lukas Else 7 months ago
parent 3c6c234233
commit 46548df638
  1. 144
      cache.go
  2. 64
      cacheRecorder.go
  3. 71
      cache_test.go
  4. 6
      utils.go

@ -3,16 +3,15 @@ package main
import (
"context"
"crypto/sha256"
"encoding/binary"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"net/url"
"sort"
"time"
"github.com/araddon/dateparse"
"github.com/dgraph-io/ristretto"
"go.goblog.app/app/pkgs/bufferpool"
"golang.org/x/sync/singleflight"
)
@ -106,23 +105,41 @@ func cacheable(r *http.Request) bool {
return true
}
func cacheKey(r *http.Request) string {
var buf strings.Builder
func cacheKey(r *http.Request) (key string) {
buf := bufferpool.Get()
// Special cases
if asRequest, ok := r.Context().Value(asRequestKey).(bool); ok && asRequest {
buf.WriteString("as-")
_, _ = buf.WriteString("as-")
}
if torUsed, ok := r.Context().Value(torUsedKey).(bool); ok && torUsed {
buf.WriteString("tor-")
_, _ = buf.WriteString("tor-")
}
// Add cache URL
_, _ = buf.WriteString(r.URL.EscapedPath())
if q := r.URL.Query(); len(q) > 0 {
if query := r.URL.Query(); len(query) > 0 {
_ = buf.WriteByte('?')
_, _ = buf.WriteString(q.Encode())
keys := make([]string, 0, len(query))
for k := range query {
keys = append(keys, k)
}
sort.Strings(keys)
for i, k := range keys {
keyEscaped := url.QueryEscape(k)
for j, val := range query[k] {
if i > 0 || j > 0 {
buf.WriteByte('&')
}
buf.WriteString(keyEscaped)
buf.WriteByte('=')
buf.WriteString(url.QueryEscape(val))
}
}
}
// Return string
return buf.String()
// Get key as string
key = buf.String()
// Return buffer to pool
bufferpool.Put(buf)
return
}
func (a *goBlog) setCacheHeaders(w http.ResponseWriter, cache *cacheItem) {
@ -152,67 +169,58 @@ type cacheItem struct {
body []byte
}
// Calculate byte size of cache item using size of body and header
func (ci *cacheItem) cost() int64 {
var headerBuf strings.Builder
_ = ci.header.Write(&headerBuf)
headerSize := int64(binary.Size(headerBuf.String()))
bodySize := int64(binary.Size(ci.body))
return headerSize + bodySize
// Calculate byte size of cache item using size of header, body and etag
func (ci *cacheItem) cost() int {
headerBuf := bufferpool.Get()
_ = ci.header.Write(headerBuf)
headerSize := len(headerBuf.Bytes())
bufferpool.Put(headerBuf)
return headerSize + len(ci.body) + len(ci.eTag)
}
func (c *cache) getCache(key string, next http.Handler, r *http.Request) (item *cacheItem) {
func (c *cache) getCache(key string, next http.Handler, r *http.Request) *cacheItem {
if rItem, ok := c.c.Get(key); ok {
item = rItem.(*cacheItem)
return rItem.(*cacheItem)
}
if item == nil {
// No cache available
// Make and use copy of r
cr := r.Clone(valueOnlyContext{r.Context()})
// Remove problematic headers
cr.Header.Del("If-Modified-Since")
cr.Header.Del("If-Unmodified-Since")
cr.Header.Del("If-None-Match")
cr.Header.Del("If-Match")
cr.Header.Del("If-Range")
cr.Header.Del("Range")
// Record request
recorder := httptest.NewRecorder()
next.ServeHTTP(recorder, cr)
recorder.Flush()
// Cache result
result := recorder.Result()
eTag := sha256.New()
body, _ := io.ReadAll(io.TeeReader(result.Body, eTag))
headers := result.Header.Clone()
_ = result.Body.Close()
lastMod := time.Now()
if lm := headers.Get(lastModified); lm != "" {
if parsedTime, te := dateparse.ParseLocal(lm); te == nil {
lastMod = parsedTime
}
}
// Remove problematic headers
headers.Del("Accept-Ranges")
headers.Del("ETag")
headers.Del(lastModified)
// Create cache item
exp, _ := cr.Context().Value(cacheExpirationKey).(int)
item = &cacheItem{
expiration: exp,
creationTime: lastMod,
eTag: fmt.Sprintf("%x", eTag.Sum(nil)),
code: result.StatusCode,
header: headers,
body: body,
// No cache available
// Make and use copy of r
cr := r.Clone(valueOnlyContext{r.Context()})
// Remove problematic headers
cr.Header.Del("If-Modified-Since")
cr.Header.Del("If-Unmodified-Since")
cr.Header.Del("If-None-Match")
cr.Header.Del("If-Match")
cr.Header.Del("If-Range")
cr.Header.Del("Range")
// Record request
rec := newCacheRecorder()
next.ServeHTTP(rec, cr)
item := rec.finish()
// Set eTag
item.eTag = item.header.Get("ETag")
if item.eTag == "" {
item.eTag = fmt.Sprintf("%x", sha256.Sum256(item.body))
}
// Set creation time
item.creationTime = time.Now()
if lm := item.header.Get(lastModified); lm != "" {
if parsedTime, te := dateparse.ParseLocal(lm); te == nil {
item.creationTime = parsedTime
}
// Save cache
if cch := item.header.Get(cacheControl); !containsStrings(cch, "no-store", "private", "no-cache") {
if exp == 0 {
c.c.Set(key, item, item.cost())
} else {
c.c.SetWithTTL(key, item, item.cost(), time.Duration(exp)*time.Second)
}
}
// Set expiration
item.expiration, _ = cr.Context().Value(cacheExpirationKey).(int)
// Remove problematic headers
item.header.Del("Accept-Ranges")
item.header.Del("ETag")
item.header.Del(lastModified)
// Save cache
if cch := item.header.Get(cacheControl); !containsStrings(cch, "no-store", "private", "no-cache") {
cost := int64(item.cost())
if item.expiration == 0 {
c.c.Set(key, item, cost)
} else {
c.c.SetWithTTL(key, item, cost, time.Duration(item.expiration)*time.Second)
}
}
return item

@ -0,0 +1,64 @@
package main
import (
"fmt"
"net/http"
)
// cacheRecorder is an implementation of http.ResponseWriter
type cacheRecorder struct {
item *cacheItem
}
func newCacheRecorder() *cacheRecorder {
return &cacheRecorder{
item: &cacheItem{
code: http.StatusOK,
header: make(http.Header),
},
}
}
func (c *cacheRecorder) finish() (ci *cacheItem) {
ci = c.item
c.item = nil
return
}
// Header implements http.ResponseWriter.
func (rw *cacheRecorder) Header() http.Header {
if rw.item == nil {
return nil
}
return rw.item.header
}
// Write implements http.ResponseWriter.
func (rw *cacheRecorder) Write(buf []byte) (int, error) {
if rw.item == nil {
return 0, nil
}
rw.item.body = append(rw.item.body, buf...)
return len(buf), nil
}
// WriteString implements io.StringWriter.
func (rw *cacheRecorder) WriteString(str string) (int, error) {
return rw.Write([]byte(str))
}
// WriteHeader implements http.ResponseWriter.
func (rw *cacheRecorder) WriteHeader(code int) {
if rw.item == nil {
return
}
if code < 100 || code > 999 {
panic(fmt.Sprintf("invalid WriteHeader code %v", code))
}
rw.item.code = code
}
// Flush implements http.Flusher.
func (rw *cacheRecorder) Flush() {
// Do nothing
}

@ -0,0 +1,71 @@
package main
import (
"io"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"
"github.com/dgraph-io/ristretto"
"github.com/stretchr/testify/assert"
)
func Benchmark_cacheItem_cost(b *testing.B) {
ci := &cacheItem{
creationTime: time.Now(),
eTag: "abc",
code: 200,
header: http.Header{
"Content-Type": []string{"text/html"},
},
body: []byte("<html>abcdefghijklmnopqrstuvwxyz</html>"),
}
b.RunParallel(func(p *testing.PB) {
for p.Next() {
ci.cost()
}
})
}
func Test_cacheItem_cost(t *testing.T) {
ci := &cacheItem{
header: http.Header{
"Content-Type": []string{"text/html"},
},
body: []byte("<html>abcdefghijklmnopqrstuvwxyz</html>"),
eTag: "abc",
}
bodyLen := len(ci.body)
assert.Equal(t, 39, bodyLen)
eTagLen := len(ci.eTag)
assert.Equal(t, 3, eTagLen)
assert.Greater(t, ci.cost(), bodyLen+eTagLen)
}
func Benchmark_cacheKey(b *testing.B) {
req := httptest.NewRequest(http.MethodGet, "/abc?abc=def&hij=klm", nil)
b.RunParallel(func(p *testing.PB) {
for p.Next() {
cacheKey(req)
}
})
}
func Benchmark_cache_getCache(b *testing.B) {
c := &cache{}
c.c, _ = ristretto.NewCache(&ristretto.Config{
NumCounters: 40 * 1000,
MaxCost: 20 * 1000 * 1000,
BufferItems: 64,
})
req := httptest.NewRequest(http.MethodGet, "/abc?abc=def&hij=klm", nil)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "abcdefghijklmnopqrstuvwxyz")
_, _ = w.Write([]byte("abcdefghijklmnopqrstuvwxyz"))
})
for i := 0; i < b.N; i++ {
c.getCache(strconv.Itoa(i), handler, req)
}
}

@ -23,6 +23,7 @@ import (
tdl "github.com/mergestat/timediff/locale"
"github.com/microcosm-cc/bluemonday"
"github.com/thoas/go-funk"
"go.goblog.app/app/pkgs/bufferpool"
"golang.org/x/text/language"
)
@ -237,7 +238,8 @@ func htmlTextFromReader(r io.Reader) (string, error) {
if err != nil {
return "", err
}
var text strings.Builder
text := bufferpool.Get()
defer bufferpool.Put(text)
if bodyChild := doc.Find("body").Children(); bodyChild.Length() > 0 {
// Input was real HTML, so build the text from the body
// Declare recursive function to print childs
@ -249,7 +251,7 @@ func htmlTextFromReader(r io.Reader) (string, error) {
_, _ = text.WriteString("\n\n")
}
if sel.Is("ol > li") { // List item in ordered list
_, _ = fmt.Fprintf(&text, "%d. ", i+1) // Add list item number
_, _ = fmt.Fprintf(text, "%d. ", i+1) // Add list item number
}
if sel.Children().Length() > 0 { // Has children
printChilds(sel.Children()) // Recursive call to print childs

Loading…
Cancel
Save