diff --git a/authentication.go b/authentication.go index bf56678..c56e291 100644 --- a/authentication.go +++ b/authentication.go @@ -120,24 +120,24 @@ func (a *goBlog) checkLogin(w http.ResponseWriter, r *http.Request) bool { return true } // Serve original request - setLoggedIn(req) + setLoggedIn(req, true) a.d.ServeHTTP(w, req) return true } func (a *goBlog) isLoggedIn(r *http.Request) bool { // Check if context key already set - if loggedIn, ok := r.Context().Value(loggedInKey).(bool); ok && loggedIn { - return true + if loggedIn, ok := r.Context().Value(loggedInKey).(bool); ok { + return loggedIn } // Check app passwords if username, password, ok := r.BasicAuth(); ok && a.checkAppPasswords(username, password) { - setLoggedIn(r) + setLoggedIn(r, true) return true } // Check session cookie if a.checkLoginCookie(r) { - setLoggedIn(r) + setLoggedIn(r, true) return true } // Not logged in @@ -145,9 +145,9 @@ func (a *goBlog) isLoggedIn(r *http.Request) bool { } // Set request context value -func setLoggedIn(r *http.Request) { - newRequest := r.WithContext(context.WithValue(r.Context(), loggedInKey, true)) - (*r) = *newRequest +func setLoggedIn(r *http.Request, loggedIn bool) { + // Overwrite the value of r (r is a pointer) + (*r) = *(r.WithContext(context.WithValue(r.Context(), loggedInKey, loggedIn))) } // HandlerFunc to redirect to home after login diff --git a/authentication_test.go b/authentication_test.go index e93ab95..027c68a 100644 --- a/authentication_test.go +++ b/authentication_test.go @@ -187,8 +187,14 @@ func Test_authMiddleware(t *testing.T) { func Test_setLoggedIn(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/abc", nil) - setLoggedIn(req) + setLoggedIn(req, true) loggedIn, ok := req.Context().Value(loggedInKey).(bool) assert.True(t, ok) assert.True(t, loggedIn) + + req = httptest.NewRequest(http.MethodGet, "/abc", nil) + setLoggedIn(req, false) + loggedIn, ok = req.Context().Value(loggedInKey).(bool) + assert.True(t, ok) + assert.False(t, loggedIn) } diff --git a/blogroll.go b/blogroll.go index 297ee6e..14eb408 100644 --- a/blogroll.go +++ b/blogroll.go @@ -27,9 +27,6 @@ func (a *goBlog) serveBlogroll(w http.ResponseWriter, r *http.Request) { a.serveError(w, r, "", http.StatusInternalServerError) return } - if a.cfg.Cache != nil && a.cfg.Cache.Enable { - a.setInternalCacheExpirationHeader(w, r, int(a.cfg.Cache.Expiration)) - } c := a.cfg.Blogs[blog].Blogroll can := a.getRelativePath(blog, defaultIfEmpty(c.Path, defaultBlogrollPath)) a.render(w, r, templateBlogroll, &renderData{ @@ -54,9 +51,6 @@ func (a *goBlog) serveBlogrollExport(w http.ResponseWriter, r *http.Request) { a.serveError(w, r, "", http.StatusInternalServerError) return } - if a.cfg.Cache != nil && a.cfg.Cache.Enable { - a.setInternalCacheExpirationHeader(w, r, int(a.cfg.Cache.Expiration)) - } w.Header().Set(contentType, contenttype.XMLUTF8) var opmlBytes bytes.Buffer _ = opml.Render(&opmlBytes, &opml.OPML{ diff --git a/cache.go b/cache.go index 8ade9e4..257eab0 100644 --- a/cache.go +++ b/cache.go @@ -2,13 +2,13 @@ package main import ( "bytes" + "context" "crypto/sha256" "encoding/binary" "fmt" "io" "net/http" "net/http/httptest" - "strconv" "strings" "time" @@ -17,7 +17,10 @@ import ( "golang.org/x/sync/singleflight" ) -const cacheInternalExpirationHeader = "Goblog-Expire" +const ( + cacheLoggedInKey contextKey = "cacheLoggedIn" + cacheExpirationKey contextKey = "cacheExpiration" +) type cache struct { g singleflight.Group @@ -40,10 +43,15 @@ func (a *goBlog) initCache() (err error) { return } -func (a *goBlog) cacheMiddleware(next http.Handler) http.Handler { - c := a.cache +func cacheLoggedIn(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if c.c == nil { + next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), cacheLoggedInKey, true))) + }) +} + +func (a *goBlog) cacheMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if a.cache.c == nil { // No cache configured next.ServeHTTP(w, r) return @@ -57,22 +65,28 @@ func (a *goBlog) cacheMiddleware(next http.Handler) http.Handler { next.ServeHTTP(w, r) return } - if a.isLoggedIn(r) { - next.ServeHTTP(w, r) - return + // Check login + if cli, ok := r.Context().Value(cacheLoggedInKey).(bool); ok && cli { + // Continue caching, but remove login + setLoggedIn(r, false) + } else { + if a.isLoggedIn(r) { + next.ServeHTTP(w, r) + return + } } // Search and serve cache key := cacheKey(r) // Get cache or render it - cacheInterface, _, _ := c.g.Do(key, func() (interface{}, error) { - return c.getCache(key, next, r), nil + cacheInterface, _, _ := a.cache.g.Do(key, func() (interface{}, error) { + return a.cache.getCache(key, next, r), nil }) ci := cacheInterface.(*cacheItem) // copy cached headers for k, v := range ci.header { w.Header()[k] = v } - c.setCacheHeaders(w, ci) + a.cache.setCacheHeaders(w, ci) // check conditional request if ifNoneMatchHeader := r.Header.Get("If-None-Match"); ifNoneMatchHeader != "" && ifNoneMatchHeader == ci.eTag { // send 304 @@ -148,16 +162,18 @@ func (c *cache) getCache(key string, next http.Handler, r *http.Request) (item * } if item == nil { // No cache available + // Make and use copy of r + cr := r.Clone(r.Context()) // Remove problematic headers - r.Header.Del("If-Modified-Since") - r.Header.Del("If-Unmodified-Since") - r.Header.Del("If-None-Match") - r.Header.Del("If-Match") - r.Header.Del("If-Range") - r.Header.Del("Range") + cr.Header.Del("If-Modified-Since") + cr.Header.Del("If-Unmodified-Since") + cr.Header.Del("If-None-Match") + cr.Header.Del("If-Match") + cr.Header.Del("If-Range") + cr.Header.Del("Range") // Record request recorder := httptest.NewRecorder() - next.ServeHTTP(recorder, r) + next.ServeHTTP(recorder, cr) // Cache values from recorder result := recorder.Result() body, _ := io.ReadAll(result.Body) @@ -174,9 +190,8 @@ func (c *cache) getCache(key string, next http.Handler, r *http.Request) (item * lastMod = parsedTime } } - exp, _ := strconv.Atoi(result.Header.Get(cacheInternalExpirationHeader)) + exp, _ := cr.Context().Value(cacheExpirationKey).(int) // Remove problematic headers - result.Header.Del(cacheInternalExpirationHeader) result.Header.Del("Accept-Ranges") result.Header.Del("ETag") result.Header.Del("Last-Modified") @@ -205,9 +220,9 @@ func (c *cache) purge() { c.c.Clear() } -func (a *goBlog) setInternalCacheExpirationHeader(w http.ResponseWriter, r *http.Request, expiration int) { - if a.isLoggedIn(r) { - return +func (a *goBlog) defaultCacheExpiration() int { + if a.cfg.Cache != nil { + return a.cfg.Cache.Expiration } - w.Header().Set(cacheInternalExpirationHeader, strconv.Itoa(expiration)) + return 0 } diff --git a/customPages.go b/customPages.go index f35c7a8..0565a6d 100644 --- a/customPages.go +++ b/customPages.go @@ -6,13 +6,6 @@ const customPageContextKey = "custompage" func (a *goBlog) serveCustomPage(w http.ResponseWriter, r *http.Request) { page := r.Context().Value(customPageContextKey).(*configCustomPage) - if a.cfg.Cache != nil && a.cfg.Cache.Enable && page.Cache { - if page.CacheExpiration != 0 { - a.setInternalCacheExpirationHeader(w, r, page.CacheExpiration) - } else { - a.setInternalCacheExpirationHeader(w, r, int(a.cfg.Cache.Expiration)) - } - } a.render(w, r, page.Template, &renderData{ BlogString: r.Context().Value(blogContextKey).(string), Canonical: a.getFullAddress(page.Path), diff --git a/http.go b/http.go index ce80c7f..87afb33 100644 --- a/http.go +++ b/http.go @@ -183,7 +183,7 @@ func (a *goBlog) buildRouter() (*chi.Mux, error) { r.Post("/{blog}/inbox", a.apHandleInbox) }) r.Group(func(r chi.Router) { - r.Use(a.cacheMiddleware) + r.Use(cacheLoggedIn, a.cacheMiddleware) r.Get("/.well-known/webfinger", a.apHandleWebfinger) r.Get("/.well-known/host-meta", handleWellKnownHostMeta) r.Get("/.well-known/nodeinfo", a.serveNodeInfoDiscover) @@ -232,7 +232,7 @@ func (a *goBlog) buildRouter() (*chi.Mux, error) { r.Handle("/captcha/*", captcha.Server(500, 250)) // Short paths - r.With(privateModeHandler...).With(a.cacheMiddleware).Get("/s/{id:[0-9a-fA-F]+}", a.redirectToLongPath) + r.With(privateModeHandler...).With(cacheLoggedIn, a.cacheMiddleware).Get("/s/{id:[0-9a-fA-F]+}", a.redirectToLongPath) for blog, blogConfig := range a.cfg.Blogs { sbm := middleware.WithValue(blogContextKey, blog) @@ -321,9 +321,8 @@ func (a *goBlog) buildRouter() (*chi.Mux, error) { statsPath := blogConfig.getRelativePath(defaultIfEmpty(bsc.Path, defaultBlogStatsPath)) r.Group(func(r chi.Router) { r.Use(privateModeHandler...) - r.Use(a.cacheMiddleware, sbm) - r.Get(statsPath, a.serveBlogStats) - r.Get(statsPath+".table.html", a.serveBlogStatsTable) + r.With(a.cacheMiddleware, sbm).Get(statsPath, a.serveBlogStats) + r.With(cacheLoggedIn, a.cacheMiddleware, sbm).Get(statsPath+".table.html", a.serveBlogStatsTable) }) } @@ -364,14 +363,28 @@ func (a *goBlog) buildRouter() (*chi.Mux, error) { } // Custom pages - for _, cp := range blogConfig.CustomPages { - scp := middleware.WithValue(customPageContextKey, cp) - if cp.Cache { - r.With(privateModeHandler...).With(a.cacheMiddleware, sbm, scp).Get(cp.Path, a.serveCustomPage) - } else { - r.With(privateModeHandler...).With(sbm, scp).Get(cp.Path, a.serveCustomPage) + r.Group(func(r chi.Router) { + r.Use(privateModeHandler...) + r.Use(sbm) + for _, cp := range blogConfig.CustomPages { + r.Group(func(r chi.Router) { + scp := middleware.WithValue(customPageContextKey, cp) + if cp.Cache { + ce := cp.CacheExpiration + if ce == 0 { + ce = a.defaultCacheExpiration() + } + r.With( + a.cacheMiddleware, + middleware.WithValue(cacheExpirationKey, ce), + scp, + ).Get(cp.Path, a.serveCustomPage) + } else { + r.With(scp).Get(cp.Path, a.serveCustomPage) + } + }) } - } + }) // Random post if rp := blogConfig.RandomPost; rp != nil && rp.Enabled { @@ -420,6 +433,7 @@ func (a *goBlog) buildRouter() (*chi.Mux, error) { brPath := blogConfig.getRelativePath(defaultIfEmpty(brConfig.Path, defaultBlogrollPath)) r.Group(func(r chi.Router) { r.Use(privateModeHandler...) + r.Use(middleware.WithValue(cacheExpirationKey, a.defaultCacheExpiration())) r.Use(a.cacheMiddleware, sbm) r.Get(brPath, a.serveBlogroll) r.Get(brPath+".opml", a.serveBlogrollExport) @@ -432,9 +446,8 @@ func (a *goBlog) buildRouter() (*chi.Mux, error) { r.Route(mapPath, func(r chi.Router) { r.Use(privateModeHandler...) r.Group(func(r chi.Router) { - r.Use(a.cacheMiddleware, sbm) - r.Get("/", a.serveGeoMap) - r.HandleFunc("/leaflet/*", a.serveLeaflet(mapPath+"/")) + r.With(a.cacheMiddleware, sbm).Get("/", a.serveGeoMap) + r.With(cacheLoggedIn, a.cacheMiddleware).HandleFunc("/leaflet/*", a.serveLeaflet(mapPath+"/")) }) r.Get("/tiles/{z}/{x}/{y}.png", a.proxyTiles(mapPath+"/tiles")) }) @@ -454,7 +467,7 @@ func (a *goBlog) buildRouter() (*chi.Mux, error) { } // Sitemap - r.With(privateModeHandler...).With(a.cacheMiddleware).Get(sitemapPath, a.serveSitemap) + r.With(privateModeHandler...).With(cacheLoggedIn, a.cacheMiddleware).Get(sitemapPath, a.serveSitemap) // Robots.txt - doesn't need cache, because it's too simple if !privateMode { @@ -521,7 +534,7 @@ func (a *goBlog) servePostsAliasesRedirects(pmh ...func(http.Handler) http.Handl } case "alias": // Is alias, redirect - alicePrivate.Append(a.cacheMiddleware).ThenFunc(func(w http.ResponseWriter, r *http.Request) { + alicePrivate.Append(cacheLoggedIn, a.cacheMiddleware).ThenFunc(func(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, value, http.StatusFound) }).ServeHTTP(w, r) return diff --git a/render.go b/render.go index e652995..906f65a 100644 --- a/render.go +++ b/render.go @@ -102,11 +102,17 @@ type renderData struct { Blog *configBlog User *configUser Data interface{} - LoggedIn bool CommentsEnabled bool WebmentionReceivingEnabled bool TorUsed bool EasterEgg bool + // Not directly accessible + app *goBlog + req *http.Request +} + +func (d *renderData) LoggedIn() bool { + return d.app.isLoggedIn(d.req) } func (a *goBlog) render(w http.ResponseWriter, r *http.Request, template string, data *renderData) { @@ -134,6 +140,12 @@ func (a *goBlog) renderWithStatusCode(w http.ResponseWriter, r *http.Request, st } func (a *goBlog) checkRenderData(r *http.Request, data *renderData) { + if data.app == nil { + data.app = a + } + if data.req == nil { + data.req = r + } // User if data.User == nil { data.User = a.cfg.User @@ -160,10 +172,6 @@ func (a *goBlog) checkRenderData(r *http.Request, data *renderData) { if torUsed, ok := r.Context().Value(torUsedKey).(bool); ok && torUsed { data.TorUsed = true } - // Check login - if a.isLoggedIn(r) { - data.LoggedIn = true - } // Check if comments enabled data.CommentsEnabled = data.Blog.Comments != nil && data.Blog.Comments.Enabled // Check if able to receive webmentions diff --git a/webmentionVerification.go b/webmentionVerification.go index 03a8f17..495c2dd 100644 --- a/webmentionVerification.go +++ b/webmentionVerification.go @@ -85,7 +85,7 @@ func (a *goBlog) verifyMention(m *mention) error { // Server not yet started time.Sleep(1 * time.Second) } - setLoggedIn(req) + setLoggedIn(req, true) a.d.ServeHTTP(rec, req) resp = rec.Result() } else {