Some cache fixes

This commit is contained in:
Jan-Lukas Else 2021-07-25 10:53:12 +02:00
parent bfa2d302c9
commit 91e2b268c7
8 changed files with 98 additions and 69 deletions

View File

@ -120,24 +120,24 @@ func (a *goBlog) checkLogin(w http.ResponseWriter, r *http.Request) bool {
return true return true
} }
// Serve original request // Serve original request
setLoggedIn(req) setLoggedIn(req, true)
a.d.ServeHTTP(w, req) a.d.ServeHTTP(w, req)
return true return true
} }
func (a *goBlog) isLoggedIn(r *http.Request) bool { func (a *goBlog) isLoggedIn(r *http.Request) bool {
// Check if context key already set // Check if context key already set
if loggedIn, ok := r.Context().Value(loggedInKey).(bool); ok && loggedIn { if loggedIn, ok := r.Context().Value(loggedInKey).(bool); ok {
return true return loggedIn
} }
// Check app passwords // Check app passwords
if username, password, ok := r.BasicAuth(); ok && a.checkAppPasswords(username, password) { if username, password, ok := r.BasicAuth(); ok && a.checkAppPasswords(username, password) {
setLoggedIn(r) setLoggedIn(r, true)
return true return true
} }
// Check session cookie // Check session cookie
if a.checkLoginCookie(r) { if a.checkLoginCookie(r) {
setLoggedIn(r) setLoggedIn(r, true)
return true return true
} }
// Not logged in // Not logged in
@ -145,9 +145,9 @@ func (a *goBlog) isLoggedIn(r *http.Request) bool {
} }
// Set request context value // Set request context value
func setLoggedIn(r *http.Request) { func setLoggedIn(r *http.Request, loggedIn bool) {
newRequest := r.WithContext(context.WithValue(r.Context(), loggedInKey, true)) // Overwrite the value of r (r is a pointer)
(*r) = *newRequest (*r) = *(r.WithContext(context.WithValue(r.Context(), loggedInKey, loggedIn)))
} }
// HandlerFunc to redirect to home after login // HandlerFunc to redirect to home after login

View File

@ -187,8 +187,14 @@ func Test_authMiddleware(t *testing.T) {
func Test_setLoggedIn(t *testing.T) { func Test_setLoggedIn(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/abc", nil) req := httptest.NewRequest(http.MethodGet, "/abc", nil)
setLoggedIn(req) setLoggedIn(req, true)
loggedIn, ok := req.Context().Value(loggedInKey).(bool) loggedIn, ok := req.Context().Value(loggedInKey).(bool)
assert.True(t, ok) assert.True(t, ok)
assert.True(t, loggedIn) assert.True(t, loggedIn)
req = httptest.NewRequest(http.MethodGet, "/abc", nil)
setLoggedIn(req, false)
loggedIn, ok = req.Context().Value(loggedInKey).(bool)
assert.True(t, ok)
assert.False(t, loggedIn)
} }

View File

@ -27,9 +27,6 @@ func (a *goBlog) serveBlogroll(w http.ResponseWriter, r *http.Request) {
a.serveError(w, r, "", http.StatusInternalServerError) a.serveError(w, r, "", http.StatusInternalServerError)
return return
} }
if a.cfg.Cache != nil && a.cfg.Cache.Enable {
a.setInternalCacheExpirationHeader(w, r, int(a.cfg.Cache.Expiration))
}
c := a.cfg.Blogs[blog].Blogroll c := a.cfg.Blogs[blog].Blogroll
can := a.getRelativePath(blog, defaultIfEmpty(c.Path, defaultBlogrollPath)) can := a.getRelativePath(blog, defaultIfEmpty(c.Path, defaultBlogrollPath))
a.render(w, r, templateBlogroll, &renderData{ a.render(w, r, templateBlogroll, &renderData{
@ -54,9 +51,6 @@ func (a *goBlog) serveBlogrollExport(w http.ResponseWriter, r *http.Request) {
a.serveError(w, r, "", http.StatusInternalServerError) a.serveError(w, r, "", http.StatusInternalServerError)
return return
} }
if a.cfg.Cache != nil && a.cfg.Cache.Enable {
a.setInternalCacheExpirationHeader(w, r, int(a.cfg.Cache.Expiration))
}
w.Header().Set(contentType, contenttype.XMLUTF8) w.Header().Set(contentType, contenttype.XMLUTF8)
var opmlBytes bytes.Buffer var opmlBytes bytes.Buffer
_ = opml.Render(&opmlBytes, &opml.OPML{ _ = opml.Render(&opmlBytes, &opml.OPML{

View File

@ -2,13 +2,13 @@ package main
import ( import (
"bytes" "bytes"
"context"
"crypto/sha256" "crypto/sha256"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strconv"
"strings" "strings"
"time" "time"
@ -17,7 +17,10 @@ import (
"golang.org/x/sync/singleflight" "golang.org/x/sync/singleflight"
) )
const cacheInternalExpirationHeader = "Goblog-Expire" const (
cacheLoggedInKey contextKey = "cacheLoggedIn"
cacheExpirationKey contextKey = "cacheExpiration"
)
type cache struct { type cache struct {
g singleflight.Group g singleflight.Group
@ -40,10 +43,15 @@ func (a *goBlog) initCache() (err error) {
return return
} }
func (a *goBlog) cacheMiddleware(next http.Handler) http.Handler { func cacheLoggedIn(next http.Handler) http.Handler {
c := a.cache
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if c.c == nil { next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), cacheLoggedInKey, true)))
})
}
func (a *goBlog) cacheMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if a.cache.c == nil {
// No cache configured // No cache configured
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
return return
@ -57,22 +65,28 @@ func (a *goBlog) cacheMiddleware(next http.Handler) http.Handler {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
return return
} }
if a.isLoggedIn(r) { // Check login
next.ServeHTTP(w, r) if cli, ok := r.Context().Value(cacheLoggedInKey).(bool); ok && cli {
return // Continue caching, but remove login
setLoggedIn(r, false)
} else {
if a.isLoggedIn(r) {
next.ServeHTTP(w, r)
return
}
} }
// Search and serve cache // Search and serve cache
key := cacheKey(r) key := cacheKey(r)
// Get cache or render it // Get cache or render it
cacheInterface, _, _ := c.g.Do(key, func() (interface{}, error) { cacheInterface, _, _ := a.cache.g.Do(key, func() (interface{}, error) {
return c.getCache(key, next, r), nil return a.cache.getCache(key, next, r), nil
}) })
ci := cacheInterface.(*cacheItem) ci := cacheInterface.(*cacheItem)
// copy cached headers // copy cached headers
for k, v := range ci.header { for k, v := range ci.header {
w.Header()[k] = v w.Header()[k] = v
} }
c.setCacheHeaders(w, ci) a.cache.setCacheHeaders(w, ci)
// check conditional request // check conditional request
if ifNoneMatchHeader := r.Header.Get("If-None-Match"); ifNoneMatchHeader != "" && ifNoneMatchHeader == ci.eTag { if ifNoneMatchHeader := r.Header.Get("If-None-Match"); ifNoneMatchHeader != "" && ifNoneMatchHeader == ci.eTag {
// send 304 // send 304
@ -148,16 +162,18 @@ func (c *cache) getCache(key string, next http.Handler, r *http.Request) (item *
} }
if item == nil { if item == nil {
// No cache available // No cache available
// Make and use copy of r
cr := r.Clone(r.Context())
// Remove problematic headers // Remove problematic headers
r.Header.Del("If-Modified-Since") cr.Header.Del("If-Modified-Since")
r.Header.Del("If-Unmodified-Since") cr.Header.Del("If-Unmodified-Since")
r.Header.Del("If-None-Match") cr.Header.Del("If-None-Match")
r.Header.Del("If-Match") cr.Header.Del("If-Match")
r.Header.Del("If-Range") cr.Header.Del("If-Range")
r.Header.Del("Range") cr.Header.Del("Range")
// Record request // Record request
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
next.ServeHTTP(recorder, r) next.ServeHTTP(recorder, cr)
// Cache values from recorder // Cache values from recorder
result := recorder.Result() result := recorder.Result()
body, _ := io.ReadAll(result.Body) body, _ := io.ReadAll(result.Body)
@ -174,9 +190,8 @@ func (c *cache) getCache(key string, next http.Handler, r *http.Request) (item *
lastMod = parsedTime lastMod = parsedTime
} }
} }
exp, _ := strconv.Atoi(result.Header.Get(cacheInternalExpirationHeader)) exp, _ := cr.Context().Value(cacheExpirationKey).(int)
// Remove problematic headers // Remove problematic headers
result.Header.Del(cacheInternalExpirationHeader)
result.Header.Del("Accept-Ranges") result.Header.Del("Accept-Ranges")
result.Header.Del("ETag") result.Header.Del("ETag")
result.Header.Del("Last-Modified") result.Header.Del("Last-Modified")
@ -205,9 +220,9 @@ func (c *cache) purge() {
c.c.Clear() c.c.Clear()
} }
func (a *goBlog) setInternalCacheExpirationHeader(w http.ResponseWriter, r *http.Request, expiration int) { func (a *goBlog) defaultCacheExpiration() int {
if a.isLoggedIn(r) { if a.cfg.Cache != nil {
return return a.cfg.Cache.Expiration
} }
w.Header().Set(cacheInternalExpirationHeader, strconv.Itoa(expiration)) return 0
} }

View File

@ -6,13 +6,6 @@ const customPageContextKey = "custompage"
func (a *goBlog) serveCustomPage(w http.ResponseWriter, r *http.Request) { func (a *goBlog) serveCustomPage(w http.ResponseWriter, r *http.Request) {
page := r.Context().Value(customPageContextKey).(*configCustomPage) page := r.Context().Value(customPageContextKey).(*configCustomPage)
if a.cfg.Cache != nil && a.cfg.Cache.Enable && page.Cache {
if page.CacheExpiration != 0 {
a.setInternalCacheExpirationHeader(w, r, page.CacheExpiration)
} else {
a.setInternalCacheExpirationHeader(w, r, int(a.cfg.Cache.Expiration))
}
}
a.render(w, r, page.Template, &renderData{ a.render(w, r, page.Template, &renderData{
BlogString: r.Context().Value(blogContextKey).(string), BlogString: r.Context().Value(blogContextKey).(string),
Canonical: a.getFullAddress(page.Path), Canonical: a.getFullAddress(page.Path),

47
http.go
View File

@ -183,7 +183,7 @@ func (a *goBlog) buildRouter() (*chi.Mux, error) {
r.Post("/{blog}/inbox", a.apHandleInbox) r.Post("/{blog}/inbox", a.apHandleInbox)
}) })
r.Group(func(r chi.Router) { r.Group(func(r chi.Router) {
r.Use(a.cacheMiddleware) r.Use(cacheLoggedIn, a.cacheMiddleware)
r.Get("/.well-known/webfinger", a.apHandleWebfinger) r.Get("/.well-known/webfinger", a.apHandleWebfinger)
r.Get("/.well-known/host-meta", handleWellKnownHostMeta) r.Get("/.well-known/host-meta", handleWellKnownHostMeta)
r.Get("/.well-known/nodeinfo", a.serveNodeInfoDiscover) r.Get("/.well-known/nodeinfo", a.serveNodeInfoDiscover)
@ -232,7 +232,7 @@ func (a *goBlog) buildRouter() (*chi.Mux, error) {
r.Handle("/captcha/*", captcha.Server(500, 250)) r.Handle("/captcha/*", captcha.Server(500, 250))
// Short paths // Short paths
r.With(privateModeHandler...).With(a.cacheMiddleware).Get("/s/{id:[0-9a-fA-F]+}", a.redirectToLongPath) r.With(privateModeHandler...).With(cacheLoggedIn, a.cacheMiddleware).Get("/s/{id:[0-9a-fA-F]+}", a.redirectToLongPath)
for blog, blogConfig := range a.cfg.Blogs { for blog, blogConfig := range a.cfg.Blogs {
sbm := middleware.WithValue(blogContextKey, blog) sbm := middleware.WithValue(blogContextKey, blog)
@ -321,9 +321,8 @@ func (a *goBlog) buildRouter() (*chi.Mux, error) {
statsPath := blogConfig.getRelativePath(defaultIfEmpty(bsc.Path, defaultBlogStatsPath)) statsPath := blogConfig.getRelativePath(defaultIfEmpty(bsc.Path, defaultBlogStatsPath))
r.Group(func(r chi.Router) { r.Group(func(r chi.Router) {
r.Use(privateModeHandler...) r.Use(privateModeHandler...)
r.Use(a.cacheMiddleware, sbm) r.With(a.cacheMiddleware, sbm).Get(statsPath, a.serveBlogStats)
r.Get(statsPath, a.serveBlogStats) r.With(cacheLoggedIn, a.cacheMiddleware, sbm).Get(statsPath+".table.html", a.serveBlogStatsTable)
r.Get(statsPath+".table.html", a.serveBlogStatsTable)
}) })
} }
@ -364,14 +363,28 @@ func (a *goBlog) buildRouter() (*chi.Mux, error) {
} }
// Custom pages // Custom pages
for _, cp := range blogConfig.CustomPages { r.Group(func(r chi.Router) {
scp := middleware.WithValue(customPageContextKey, cp) r.Use(privateModeHandler...)
if cp.Cache { r.Use(sbm)
r.With(privateModeHandler...).With(a.cacheMiddleware, sbm, scp).Get(cp.Path, a.serveCustomPage) for _, cp := range blogConfig.CustomPages {
} else { r.Group(func(r chi.Router) {
r.With(privateModeHandler...).With(sbm, scp).Get(cp.Path, a.serveCustomPage) scp := middleware.WithValue(customPageContextKey, cp)
if cp.Cache {
ce := cp.CacheExpiration
if ce == 0 {
ce = a.defaultCacheExpiration()
}
r.With(
a.cacheMiddleware,
middleware.WithValue(cacheExpirationKey, ce),
scp,
).Get(cp.Path, a.serveCustomPage)
} else {
r.With(scp).Get(cp.Path, a.serveCustomPage)
}
})
} }
} })
// Random post // Random post
if rp := blogConfig.RandomPost; rp != nil && rp.Enabled { if rp := blogConfig.RandomPost; rp != nil && rp.Enabled {
@ -420,6 +433,7 @@ func (a *goBlog) buildRouter() (*chi.Mux, error) {
brPath := blogConfig.getRelativePath(defaultIfEmpty(brConfig.Path, defaultBlogrollPath)) brPath := blogConfig.getRelativePath(defaultIfEmpty(brConfig.Path, defaultBlogrollPath))
r.Group(func(r chi.Router) { r.Group(func(r chi.Router) {
r.Use(privateModeHandler...) r.Use(privateModeHandler...)
r.Use(middleware.WithValue(cacheExpirationKey, a.defaultCacheExpiration()))
r.Use(a.cacheMiddleware, sbm) r.Use(a.cacheMiddleware, sbm)
r.Get(brPath, a.serveBlogroll) r.Get(brPath, a.serveBlogroll)
r.Get(brPath+".opml", a.serveBlogrollExport) r.Get(brPath+".opml", a.serveBlogrollExport)
@ -432,9 +446,8 @@ func (a *goBlog) buildRouter() (*chi.Mux, error) {
r.Route(mapPath, func(r chi.Router) { r.Route(mapPath, func(r chi.Router) {
r.Use(privateModeHandler...) r.Use(privateModeHandler...)
r.Group(func(r chi.Router) { r.Group(func(r chi.Router) {
r.Use(a.cacheMiddleware, sbm) r.With(a.cacheMiddleware, sbm).Get("/", a.serveGeoMap)
r.Get("/", a.serveGeoMap) r.With(cacheLoggedIn, a.cacheMiddleware).HandleFunc("/leaflet/*", a.serveLeaflet(mapPath+"/"))
r.HandleFunc("/leaflet/*", a.serveLeaflet(mapPath+"/"))
}) })
r.Get("/tiles/{z}/{x}/{y}.png", a.proxyTiles(mapPath+"/tiles")) r.Get("/tiles/{z}/{x}/{y}.png", a.proxyTiles(mapPath+"/tiles"))
}) })
@ -454,7 +467,7 @@ func (a *goBlog) buildRouter() (*chi.Mux, error) {
} }
// Sitemap // Sitemap
r.With(privateModeHandler...).With(a.cacheMiddleware).Get(sitemapPath, a.serveSitemap) r.With(privateModeHandler...).With(cacheLoggedIn, a.cacheMiddleware).Get(sitemapPath, a.serveSitemap)
// Robots.txt - doesn't need cache, because it's too simple // Robots.txt - doesn't need cache, because it's too simple
if !privateMode { if !privateMode {
@ -521,7 +534,7 @@ func (a *goBlog) servePostsAliasesRedirects(pmh ...func(http.Handler) http.Handl
} }
case "alias": case "alias":
// Is alias, redirect // Is alias, redirect
alicePrivate.Append(a.cacheMiddleware).ThenFunc(func(w http.ResponseWriter, r *http.Request) { alicePrivate.Append(cacheLoggedIn, a.cacheMiddleware).ThenFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, value, http.StatusFound) http.Redirect(w, r, value, http.StatusFound)
}).ServeHTTP(w, r) }).ServeHTTP(w, r)
return return

View File

@ -102,11 +102,17 @@ type renderData struct {
Blog *configBlog Blog *configBlog
User *configUser User *configUser
Data interface{} Data interface{}
LoggedIn bool
CommentsEnabled bool CommentsEnabled bool
WebmentionReceivingEnabled bool WebmentionReceivingEnabled bool
TorUsed bool TorUsed bool
EasterEgg bool EasterEgg bool
// Not directly accessible
app *goBlog
req *http.Request
}
func (d *renderData) LoggedIn() bool {
return d.app.isLoggedIn(d.req)
} }
func (a *goBlog) render(w http.ResponseWriter, r *http.Request, template string, data *renderData) { func (a *goBlog) render(w http.ResponseWriter, r *http.Request, template string, data *renderData) {
@ -134,6 +140,12 @@ func (a *goBlog) renderWithStatusCode(w http.ResponseWriter, r *http.Request, st
} }
func (a *goBlog) checkRenderData(r *http.Request, data *renderData) { func (a *goBlog) checkRenderData(r *http.Request, data *renderData) {
if data.app == nil {
data.app = a
}
if data.req == nil {
data.req = r
}
// User // User
if data.User == nil { if data.User == nil {
data.User = a.cfg.User data.User = a.cfg.User
@ -160,10 +172,6 @@ func (a *goBlog) checkRenderData(r *http.Request, data *renderData) {
if torUsed, ok := r.Context().Value(torUsedKey).(bool); ok && torUsed { if torUsed, ok := r.Context().Value(torUsedKey).(bool); ok && torUsed {
data.TorUsed = true data.TorUsed = true
} }
// Check login
if a.isLoggedIn(r) {
data.LoggedIn = true
}
// Check if comments enabled // Check if comments enabled
data.CommentsEnabled = data.Blog.Comments != nil && data.Blog.Comments.Enabled data.CommentsEnabled = data.Blog.Comments != nil && data.Blog.Comments.Enabled
// Check if able to receive webmentions // Check if able to receive webmentions

View File

@ -85,7 +85,7 @@ func (a *goBlog) verifyMention(m *mention) error {
// Server not yet started // Server not yet started
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
} }
setLoggedIn(req) setLoggedIn(req, true)
a.d.ServeHTTP(rec, req) a.d.ServeHTTP(rec, req)
resp = rec.Result() resp = rec.Result()
} else { } else {