mirror of https://github.com/jlelse/GoBlog
Improved cache efficiency
This commit is contained in:
parent
3c6c234233
commit
46548df638
110
cache.go
110
cache.go
|
@ -3,16 +3,15 @@ package main
|
|||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"net/url"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/araddon/dateparse"
|
||||
"github.com/dgraph-io/ristretto"
|
||||
"go.goblog.app/app/pkgs/bufferpool"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
|
@ -106,23 +105,41 @@ func cacheable(r *http.Request) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
func cacheKey(r *http.Request) string {
|
||||
var buf strings.Builder
|
||||
func cacheKey(r *http.Request) (key string) {
|
||||
buf := bufferpool.Get()
|
||||
// Special cases
|
||||
if asRequest, ok := r.Context().Value(asRequestKey).(bool); ok && asRequest {
|
||||
buf.WriteString("as-")
|
||||
_, _ = buf.WriteString("as-")
|
||||
}
|
||||
if torUsed, ok := r.Context().Value(torUsedKey).(bool); ok && torUsed {
|
||||
buf.WriteString("tor-")
|
||||
_, _ = buf.WriteString("tor-")
|
||||
}
|
||||
// Add cache URL
|
||||
_, _ = buf.WriteString(r.URL.EscapedPath())
|
||||
if q := r.URL.Query(); len(q) > 0 {
|
||||
if query := r.URL.Query(); len(query) > 0 {
|
||||
_ = buf.WriteByte('?')
|
||||
_, _ = buf.WriteString(q.Encode())
|
||||
keys := make([]string, 0, len(query))
|
||||
for k := range query {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
// Return string
|
||||
return buf.String()
|
||||
sort.Strings(keys)
|
||||
for i, k := range keys {
|
||||
keyEscaped := url.QueryEscape(k)
|
||||
for j, val := range query[k] {
|
||||
if i > 0 || j > 0 {
|
||||
buf.WriteByte('&')
|
||||
}
|
||||
buf.WriteString(keyEscaped)
|
||||
buf.WriteByte('=')
|
||||
buf.WriteString(url.QueryEscape(val))
|
||||
}
|
||||
}
|
||||
}
|
||||
// Get key as string
|
||||
key = buf.String()
|
||||
// Return buffer to pool
|
||||
bufferpool.Put(buf)
|
||||
return
|
||||
}
|
||||
|
||||
func (a *goBlog) setCacheHeaders(w http.ResponseWriter, cache *cacheItem) {
|
||||
|
@ -152,20 +169,19 @@ type cacheItem struct {
|
|||
body []byte
|
||||
}
|
||||
|
||||
// Calculate byte size of cache item using size of body and header
|
||||
func (ci *cacheItem) cost() int64 {
|
||||
var headerBuf strings.Builder
|
||||
_ = ci.header.Write(&headerBuf)
|
||||
headerSize := int64(binary.Size(headerBuf.String()))
|
||||
bodySize := int64(binary.Size(ci.body))
|
||||
return headerSize + bodySize
|
||||
// Calculate byte size of cache item using size of header, body and etag
|
||||
func (ci *cacheItem) cost() int {
|
||||
headerBuf := bufferpool.Get()
|
||||
_ = ci.header.Write(headerBuf)
|
||||
headerSize := len(headerBuf.Bytes())
|
||||
bufferpool.Put(headerBuf)
|
||||
return headerSize + len(ci.body) + len(ci.eTag)
|
||||
}
|
||||
|
||||
func (c *cache) getCache(key string, next http.Handler, r *http.Request) (item *cacheItem) {
|
||||
func (c *cache) getCache(key string, next http.Handler, r *http.Request) *cacheItem {
|
||||
if rItem, ok := c.c.Get(key); ok {
|
||||
item = rItem.(*cacheItem)
|
||||
return rItem.(*cacheItem)
|
||||
}
|
||||
if item == nil {
|
||||
// No cache available
|
||||
// Make and use copy of r
|
||||
cr := r.Clone(valueOnlyContext{r.Context()})
|
||||
|
@ -177,42 +193,34 @@ func (c *cache) getCache(key string, next http.Handler, r *http.Request) (item *
|
|||
cr.Header.Del("If-Range")
|
||||
cr.Header.Del("Range")
|
||||
// Record request
|
||||
recorder := httptest.NewRecorder()
|
||||
next.ServeHTTP(recorder, cr)
|
||||
recorder.Flush()
|
||||
// Cache result
|
||||
result := recorder.Result()
|
||||
eTag := sha256.New()
|
||||
body, _ := io.ReadAll(io.TeeReader(result.Body, eTag))
|
||||
headers := result.Header.Clone()
|
||||
_ = result.Body.Close()
|
||||
lastMod := time.Now()
|
||||
if lm := headers.Get(lastModified); lm != "" {
|
||||
rec := newCacheRecorder()
|
||||
next.ServeHTTP(rec, cr)
|
||||
item := rec.finish()
|
||||
// Set eTag
|
||||
item.eTag = item.header.Get("ETag")
|
||||
if item.eTag == "" {
|
||||
item.eTag = fmt.Sprintf("%x", sha256.Sum256(item.body))
|
||||
}
|
||||
// Set creation time
|
||||
item.creationTime = time.Now()
|
||||
if lm := item.header.Get(lastModified); lm != "" {
|
||||
if parsedTime, te := dateparse.ParseLocal(lm); te == nil {
|
||||
lastMod = parsedTime
|
||||
item.creationTime = parsedTime
|
||||
}
|
||||
}
|
||||
// Set expiration
|
||||
item.expiration, _ = cr.Context().Value(cacheExpirationKey).(int)
|
||||
// Remove problematic headers
|
||||
headers.Del("Accept-Ranges")
|
||||
headers.Del("ETag")
|
||||
headers.Del(lastModified)
|
||||
// Create cache item
|
||||
exp, _ := cr.Context().Value(cacheExpirationKey).(int)
|
||||
item = &cacheItem{
|
||||
expiration: exp,
|
||||
creationTime: lastMod,
|
||||
eTag: fmt.Sprintf("%x", eTag.Sum(nil)),
|
||||
code: result.StatusCode,
|
||||
header: headers,
|
||||
body: body,
|
||||
}
|
||||
item.header.Del("Accept-Ranges")
|
||||
item.header.Del("ETag")
|
||||
item.header.Del(lastModified)
|
||||
// Save cache
|
||||
if cch := item.header.Get(cacheControl); !containsStrings(cch, "no-store", "private", "no-cache") {
|
||||
if exp == 0 {
|
||||
c.c.Set(key, item, item.cost())
|
||||
cost := int64(item.cost())
|
||||
if item.expiration == 0 {
|
||||
c.c.Set(key, item, cost)
|
||||
} else {
|
||||
c.c.SetWithTTL(key, item, item.cost(), time.Duration(exp)*time.Second)
|
||||
}
|
||||
c.c.SetWithTTL(key, item, cost, time.Duration(item.expiration)*time.Second)
|
||||
}
|
||||
}
|
||||
return item
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// cacheRecorder is an implementation of http.ResponseWriter
|
||||
type cacheRecorder struct {
|
||||
item *cacheItem
|
||||
}
|
||||
|
||||
func newCacheRecorder() *cacheRecorder {
|
||||
return &cacheRecorder{
|
||||
item: &cacheItem{
|
||||
code: http.StatusOK,
|
||||
header: make(http.Header),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cacheRecorder) finish() (ci *cacheItem) {
|
||||
ci = c.item
|
||||
c.item = nil
|
||||
return
|
||||
}
|
||||
|
||||
// Header implements http.ResponseWriter.
|
||||
func (rw *cacheRecorder) Header() http.Header {
|
||||
if rw.item == nil {
|
||||
return nil
|
||||
}
|
||||
return rw.item.header
|
||||
}
|
||||
|
||||
// Write implements http.ResponseWriter.
|
||||
func (rw *cacheRecorder) Write(buf []byte) (int, error) {
|
||||
if rw.item == nil {
|
||||
return 0, nil
|
||||
}
|
||||
rw.item.body = append(rw.item.body, buf...)
|
||||
return len(buf), nil
|
||||
}
|
||||
|
||||
// WriteString implements io.StringWriter.
|
||||
func (rw *cacheRecorder) WriteString(str string) (int, error) {
|
||||
return rw.Write([]byte(str))
|
||||
}
|
||||
|
||||
// WriteHeader implements http.ResponseWriter.
|
||||
func (rw *cacheRecorder) WriteHeader(code int) {
|
||||
if rw.item == nil {
|
||||
return
|
||||
}
|
||||
if code < 100 || code > 999 {
|
||||
panic(fmt.Sprintf("invalid WriteHeader code %v", code))
|
||||
}
|
||||
rw.item.code = code
|
||||
}
|
||||
|
||||
// Flush implements http.Flusher.
|
||||
func (rw *cacheRecorder) Flush() {
|
||||
// Do nothing
|
||||
}
|
|
@ -0,0 +1,71 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/dgraph-io/ristretto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Benchmark_cacheItem_cost(b *testing.B) {
|
||||
ci := &cacheItem{
|
||||
creationTime: time.Now(),
|
||||
eTag: "abc",
|
||||
code: 200,
|
||||
header: http.Header{
|
||||
"Content-Type": []string{"text/html"},
|
||||
},
|
||||
body: []byte("<html>abcdefghijklmnopqrstuvwxyz</html>"),
|
||||
}
|
||||
b.RunParallel(func(p *testing.PB) {
|
||||
for p.Next() {
|
||||
ci.cost()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_cacheItem_cost(t *testing.T) {
|
||||
ci := &cacheItem{
|
||||
header: http.Header{
|
||||
"Content-Type": []string{"text/html"},
|
||||
},
|
||||
body: []byte("<html>abcdefghijklmnopqrstuvwxyz</html>"),
|
||||
eTag: "abc",
|
||||
}
|
||||
bodyLen := len(ci.body)
|
||||
assert.Equal(t, 39, bodyLen)
|
||||
eTagLen := len(ci.eTag)
|
||||
assert.Equal(t, 3, eTagLen)
|
||||
assert.Greater(t, ci.cost(), bodyLen+eTagLen)
|
||||
}
|
||||
|
||||
func Benchmark_cacheKey(b *testing.B) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/abc?abc=def&hij=klm", nil)
|
||||
b.RunParallel(func(p *testing.PB) {
|
||||
for p.Next() {
|
||||
cacheKey(req)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Benchmark_cache_getCache(b *testing.B) {
|
||||
c := &cache{}
|
||||
c.c, _ = ristretto.NewCache(&ristretto.Config{
|
||||
NumCounters: 40 * 1000,
|
||||
MaxCost: 20 * 1000 * 1000,
|
||||
BufferItems: 64,
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodGet, "/abc?abc=def&hij=klm", nil)
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, "abcdefghijklmnopqrstuvwxyz")
|
||||
_, _ = w.Write([]byte("abcdefghijklmnopqrstuvwxyz"))
|
||||
})
|
||||
for i := 0; i < b.N; i++ {
|
||||
c.getCache(strconv.Itoa(i), handler, req)
|
||||
}
|
||||
}
|
6
utils.go
6
utils.go
|
@ -23,6 +23,7 @@ import (
|
|||
tdl "github.com/mergestat/timediff/locale"
|
||||
"github.com/microcosm-cc/bluemonday"
|
||||
"github.com/thoas/go-funk"
|
||||
"go.goblog.app/app/pkgs/bufferpool"
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
|
@ -237,7 +238,8 @@ func htmlTextFromReader(r io.Reader) (string, error) {
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var text strings.Builder
|
||||
text := bufferpool.Get()
|
||||
defer bufferpool.Put(text)
|
||||
if bodyChild := doc.Find("body").Children(); bodyChild.Length() > 0 {
|
||||
// Input was real HTML, so build the text from the body
|
||||
// Declare recursive function to print childs
|
||||
|
@ -249,7 +251,7 @@ func htmlTextFromReader(r io.Reader) (string, error) {
|
|||
_, _ = text.WriteString("\n\n")
|
||||
}
|
||||
if sel.Is("ol > li") { // List item in ordered list
|
||||
_, _ = fmt.Fprintf(&text, "%d. ", i+1) // Add list item number
|
||||
_, _ = fmt.Fprintf(text, "%d. ", i+1) // Add list item number
|
||||
}
|
||||
if sel.Children().Length() > 0 { // Has children
|
||||
printChilds(sel.Children()) // Recursive call to print childs
|
||||
|
|
Loading…
Reference in New Issue