diff --git a/cache.go b/cache.go index 597d99e..456e557 100644 --- a/cache.go +++ b/cache.go @@ -1,76 +1,36 @@ package main import ( - "database/sql" "net/http" "net/http/httptest" - "net/url" "sync" "time" + + "golang.org/x/sync/singleflight" ) -var cacheMutexMapMutex *sync.Mutex -var cacheMutexes map[string]*sync.Mutex -var cacheDb *sql.DB -var cacheDbWriteMutex = &sync.Mutex{} +var cacheMap = map[string]*cacheItem{} +var cacheMutex = &sync.RWMutex{} -func initCache() (err error) { - cacheMutexMapMutex = &sync.Mutex{} - cacheMutexes = map[string]*sync.Mutex{} - cacheDb, err = sql.Open("sqlite3", ":memory:") - if err != nil { - return err - } - tx, err := cacheDb.Begin() - if err != nil { - return - } - _, err = tx.Exec("CREATE TABLE cache (path text not null primary key, time integer, header blob, body blob);") - if err != nil { - return - } - err = tx.Commit() - if err != nil { - return - } - return -} - -func startWritingToCacheDb() { - cacheDbWriteMutex.Lock() -} - -func finishWritingToCacheDb() { - cacheDbWriteMutex.Unlock() -} +var requestGroup singleflight.Group func cacheMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - requestURL, _ := url.ParseRequestURI(r.RequestURI) - path := slashTrimmedPath(r) if appConfig.Cache.Enable && // check bypass query - !(requestURL != nil && requestURL.Query().Get("cache") == "0") { - // Check cache mutex - cacheMutexMapMutex.Lock() - if cacheMutexes[path] == nil { - cacheMutexes[path] = &sync.Mutex{} - } - cacheMutexMapMutex.Unlock() - // Get cache - cm := cacheMutexes[path] - cm.Lock() - cacheTime, header, body := getCache(path) - cm.Unlock() - if cacheTime == 0 { - cm.Lock() - // Render cache - renderCache(path, next, w, r) - cm.Unlock() - return - } - cacheTimeString := time.Unix(cacheTime, 0).Format(time.RFC1123) - expiresTimeString := time.Unix(cacheTime+appConfig.Cache.Expiration, 0).Format(time.RFC1123) + !(r.URL.Query().Get("cache") == "0") && + // check method + (r.Method == http.MethodGet || r.Method == http.MethodHead) { + // Fix path + path := slashTrimmedPath(r) + // Get cache or render it + cacheInterface, _, _ := requestGroup.Do(path, func() (interface{}, error) { + return getCache(path, next, r), nil + }) + cache := cacheInterface.(*cacheItem) + // log.Println(string(cache.body)) + cacheTimeString := time.Unix(cache.creationTime, 0).Format(time.RFC1123) + expiresTimeString := time.Unix(cache.creationTime+appConfig.Cache.Expiration, 0).Format(time.RFC1123) // check conditional request ifModifiedSinceHeader := r.Header.Get("If-Modified-Since") if ifModifiedSinceHeader != "" && ifModifiedSinceHeader == cacheTimeString { @@ -80,13 +40,14 @@ func cacheMiddleware(next http.Handler) http.Handler { return } // copy cached headers - for k, v := range header { + for k, v := range cache.header { w.Header()[k] = v } setCacheHeaders(w, cacheTimeString, expiresTimeString) - w.Header().Set("GoBlog-Cache", "HIT") + // set status code + w.WriteHeader(cache.code) // write cached body - _, _ = w.Write(body) + _, _ = w.Write(cache.body) return } next.ServeHTTP(w, r) @@ -100,46 +61,38 @@ func setCacheHeaders(w http.ResponseWriter, cacheTimeString string, expiresTimeS w.Header().Set("Expires", expiresTimeString) } -func renderCache(path string, next http.Handler, w http.ResponseWriter, r *http.Request) { - // No cache available - recorder := httptest.NewRecorder() - next.ServeHTTP(recorder, r) - // copy values from recorder - code := recorder.Code - // send response - for k, v := range recorder.Header() { - w.Header()[k] = v - } - now := time.Now() - setCacheHeaders(w, now.Format(time.RFC1123), time.Unix(now.Unix()+appConfig.Cache.Expiration, 0).Format(time.RFC1123)) - w.Header().Set("GoBlog-Cache", "MISS") - w.WriteHeader(code) - _, _ = w.Write(recorder.Body.Bytes()) - // Save cache - if code == http.StatusOK { - saveCache(path, now, recorder.Header(), recorder.Body.Bytes()) - } +type cacheItem struct { + creationTime int64 + code int + header http.Header + body []byte } -func getCache(path string) (creationTime int64, header map[string][]string, body []byte) { - var headerBytes []byte - allowedTime := time.Now().Unix() - appConfig.Cache.Expiration - row := cacheDb.QueryRow("select COALESCE(time, 0), header, body from cache where path=? and time>=?", path, allowedTime) - _ = row.Scan(&creationTime, &headerBytes, &body) - header = make(map[string][]string) - _ = json.Unmarshal(headerBytes, &header) - return -} - -func saveCache(path string, now time.Time, header map[string][]string, body []byte) { - headerBytes, _ := json.Marshal(header) - startWritingToCacheDb() - defer finishWritingToCacheDb() - _, _ = cacheDb.Exec("insert or replace into cache (path, time, header, body) values (?, ?, ?, ?);", path, now.Unix(), headerBytes, body) +func getCache(path string, next http.Handler, r *http.Request) *cacheItem { + cacheMutex.RLock() + item, ok := cacheMap[path] + cacheMutex.RUnlock() + if !ok || item.creationTime < time.Now().Unix()-appConfig.Cache.Expiration { + item = &cacheItem{} + // No cache available + recorder := httptest.NewRecorder() + next.ServeHTTP(recorder, r) + // copy values from recorder + now := time.Now() + item.creationTime = now.Unix() + item.code = recorder.Code + item.header = recorder.Header() + item.body = recorder.Body.Bytes() + // Save cache + cacheMutex.Lock() + cacheMap[path] = item + cacheMutex.Unlock() + } + return item } func purgeCache() { - startWritingToCacheDb() - defer finishWritingToCacheDb() - _, _ = cacheDb.Exec("delete from cache; vacuum;") + cacheMutex.Lock() + cacheMap = map[string]*cacheItem{} + cacheMutex.Unlock() } diff --git a/go.mod b/go.mod index 0907574..6f279ef 100644 --- a/go.mod +++ b/go.mod @@ -37,6 +37,7 @@ require ( github.com/yuin/goldmark-emoji v1.0.1 golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897 golang.org/x/net v0.0.0-20201016165138-7b1cca2348c0 // indirect + golang.org/x/sync v0.0.0-20201008141435-b3e1573b7520 golang.org/x/sys v0.0.0-20201018230417-eeed37f84f13 // indirect gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b // indirect gopkg.in/ini.v1 v1.62.0 // indirect diff --git a/go.sum b/go.sum index cf8fe27..566b9cc 100644 --- a/go.sum +++ b/go.sum @@ -352,7 +352,10 @@ golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58 h1:8gQV6CLnAEikrhgkHFbMAEhagSSnXWGV915qUMm9mrU= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201008141435-b3e1573b7520 h1:Bx6FllMpG4NWDOfhMBz1VR2QYNp/SAOHPIAsaVmxfPo= +golang.org/x/sync v0.0.0-20201008141435-b3e1573b7520/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/http.go b/http.go index eec8589..2632a7f 100644 --- a/http.go +++ b/http.go @@ -2,7 +2,9 @@ package main import ( "compress/flate" + "log" "net/http" + "os" "strconv" "strings" "sync" @@ -73,7 +75,10 @@ func buildHandler() (http.Handler, error) { r.Use(middleware.Recoverer) if appConfig.Server.Logging { r.Use(middleware.RealIP) - r.Use(middleware.Logger) + r.Use(middleware.RequestLogger(&middleware.DefaultLogFormatter{ + Logger: log.New(os.Stdout, "", log.LstdFlags), + NoColor: true, + })) } r.Use(middleware.Compress(flate.DefaultCompression)) r.Use(middleware.StripSlashes) @@ -205,7 +210,7 @@ func buildHandler() (http.Handler, error) { r.With(cacheMiddleware, minifier.Middleware).Get(sitemapPath, serveSitemap) // Check redirects, then serve 404 - r.With(checkRegexRedirects, minifier.Middleware).NotFound(serve404) + r.With(checkRegexRedirects, cacheMiddleware, minifier.Middleware).NotFound(serve404) return r, nil } diff --git a/main.go b/main.go index 33e800f..bbaf17d 100644 --- a/main.go +++ b/main.go @@ -29,7 +29,6 @@ func main() { log.Println("Initialize server components...") initMinify() initMarkdown() - initCache() err = initTemplateAssets() // Needs minify if err != nil { log.Fatal(err)