Use mutexes to prevent cache stampede

This commit is contained in:
Jan-Lukas Else 2020-09-22 16:42:36 +02:00
parent 1c8da99620
commit e1c362ac2f
2 changed files with 18 additions and 0 deletions

View File

@ -6,9 +6,16 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"sync"
"time" "time"
) )
var cacheMutexes map[string]*sync.Mutex
func initCache() {
cacheMutexes = map[string]*sync.Mutex{}
}
func cacheMiddleware(next http.Handler) http.Handler { func cacheMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestUrl, _ := url.ParseRequestURI(r.RequestURI) requestUrl, _ := url.ParseRequestURI(r.RequestURI)
@ -16,8 +23,16 @@ func cacheMiddleware(next http.Handler) http.Handler {
if appConfig.Cache.Enable && if appConfig.Cache.Enable &&
// check bypass query // check bypass query
!(requestUrl != nil && requestUrl.Query().Get("cache") == "0") { !(requestUrl != nil && requestUrl.Query().Get("cache") == "0") {
// Check cache mutex
if cacheMutexes[path] == nil {
cacheMutexes[path] = &sync.Mutex{}
}
// Lock mutex - prevents multiple new renderings
cacheMutexes[path].Lock()
// Get cache
cacheTime, header, body := getCache(r.Context(), path) cacheTime, header, body := getCache(r.Context(), path)
if cacheTime == 0 { if cacheTime == 0 {
// No cache available
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
next.ServeHTTP(recorder, r) next.ServeHTTP(recorder, r)
// copy values from recorder // copy values from recorder
@ -35,8 +50,10 @@ func cacheMiddleware(next http.Handler) http.Handler {
if code == http.StatusOK { if code == http.StatusOK {
saveCache(path, now, recorder.Header(), recorder.Body.Bytes()) saveCache(path, now, recorder.Header(), recorder.Body.Bytes())
} }
cacheMutexes[path].Unlock()
return return
} }
cacheMutexes[path].Unlock()
cacheTimeString := time.Unix(cacheTime, 0).Format(time.RFC1123) cacheTimeString := time.Unix(cacheTime, 0).Format(time.RFC1123)
expiresTimeString := time.Unix(cacheTime+appConfig.Cache.Expiration, 0).Format(time.RFC1123) expiresTimeString := time.Unix(cacheTime+appConfig.Cache.Expiration, 0).Format(time.RFC1123)
// check conditional request // check conditional request

View File

@ -25,6 +25,7 @@ func main() {
log.Println("Initialize server components...") log.Println("Initialize server components...")
initMinify() initMinify()
initMarkdown() initMarkdown()
initCache()
err = initTemplateAssets() // Needs minify err = initTemplateAssets() // Needs minify
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)