This is my new blog CMS https://jlelse.blog
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

194 lines
5.1 KiB

4 months ago
2 weeks ago
4 months ago
2 weeks ago
4 months ago
4 months ago
4 months ago
4 months ago
4 months ago
2 weeks ago
2 weeks ago
2 weeks ago
2 months ago
2 months ago
4 months ago
4 months ago
2 months ago
4 months ago
2 months ago
4 months ago
4 months ago
2 months ago
3 weeks ago
2 months ago
2 months ago
2 months ago
4 months ago
2 months ago
5 months ago
4 months ago
  1. package main
  2. import (
  3. "bytes"
  4. "crypto/sha256"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "net/http/httptest"
  9. "net/url"
  10. "strconv"
  11. "strings"
  12. "time"
  13. "github.com/araddon/dateparse"
  14. lru "github.com/hashicorp/golang-lru"
  15. "golang.org/x/sync/singleflight"
  16. )
  17. const (
  18. cacheInternalExpirationHeader = "GoBlog-Expire"
  19. )
  20. var (
  21. cacheGroup singleflight.Group
  22. cacheLru *lru.Cache
  23. )
  24. func initCache() (err error) {
  25. cacheLru, err = lru.New(500)
  26. return
  27. }
  28. func cacheMiddleware(next http.Handler) http.Handler {
  29. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  30. // Do checks
  31. if !appConfig.Cache.Enable {
  32. next.ServeHTTP(w, r)
  33. return
  34. }
  35. if !(r.Method == http.MethodGet || r.Method == http.MethodHead) {
  36. next.ServeHTTP(w, r)
  37. return
  38. }
  39. if r.URL.Query().Get("cache") == "0" || r.URL.Query().Get("cache") == "false" {
  40. next.ServeHTTP(w, r)
  41. return
  42. }
  43. if loggedIn, ok := r.Context().Value(loggedInKey).(bool); ok && loggedIn {
  44. next.ServeHTTP(w, r)
  45. return
  46. }
  47. // Search and serve cache
  48. key := cacheKey(r)
  49. // Get cache or render it
  50. cacheInterface, _, _ := cacheGroup.Do(key, func() (interface{}, error) {
  51. return getCache(key, next, r), nil
  52. })
  53. cache := cacheInterface.(*cacheItem)
  54. // copy cached headers
  55. for k, v := range cache.header {
  56. w.Header()[k] = v
  57. }
  58. setCacheHeaders(w, cache)
  59. // check conditional request
  60. if ifNoneMatchHeader := r.Header.Get("If-None-Match"); ifNoneMatchHeader != "" && ifNoneMatchHeader == cache.eTag {
  61. // send 304
  62. w.WriteHeader(http.StatusNotModified)
  63. return
  64. }
  65. if ifModifiedSinceHeader := r.Header.Get("If-Modified-Since"); ifModifiedSinceHeader != "" {
  66. if t, err := dateparse.ParseAny(ifModifiedSinceHeader); err == nil && t.After(cache.creationTime) {
  67. // send 304
  68. w.WriteHeader(http.StatusNotModified)
  69. return
  70. }
  71. }
  72. // set status code
  73. w.WriteHeader(cache.code)
  74. // write cached body
  75. _, _ = w.Write(cache.body)
  76. })
  77. }
  78. func cacheKey(r *http.Request) string {
  79. def := cacheURLString(r.URL)
  80. // Special cases
  81. if asRequest, ok := r.Context().Value(asRequestKey).(bool); ok && asRequest {
  82. return "as-" + def
  83. }
  84. // Default
  85. return def
  86. }
  87. func cacheURLString(u *url.URL) string {
  88. var buf strings.Builder
  89. _, _ = buf.WriteString(u.EscapedPath())
  90. if q := u.Query(); len(q) > 0 {
  91. _ = buf.WriteByte('?')
  92. _, _ = buf.WriteString(q.Encode())
  93. }
  94. return buf.String()
  95. }
  96. func setCacheHeaders(w http.ResponseWriter, cache *cacheItem) {
  97. w.Header().Set("ETag", cache.eTag)
  98. w.Header().Set("Last-Modified", cache.creationTime.UTC().Format(http.TimeFormat))
  99. if w.Header().Get("Cache-Control") == "" {
  100. if cache.expiration != 0 {
  101. w.Header().Set("Cache-Control", fmt.Sprintf("public,max-age=%d,stale-while-revalidate=%d", cache.expiration, cache.expiration))
  102. } else {
  103. w.Header().Set("Cache-Control", fmt.Sprintf("public,max-age=%d,s-max-age=%d,stale-while-revalidate=%d", appConfig.Cache.Expiration, appConfig.Cache.Expiration/3, appConfig.Cache.Expiration))
  104. }
  105. }
  106. }
  107. type cacheItem struct {
  108. expiration int
  109. creationTime time.Time
  110. eTag string
  111. code int
  112. header http.Header
  113. body []byte
  114. }
  115. func (c *cacheItem) expired() bool {
  116. if c.expiration != 0 {
  117. return time.Now().After(c.creationTime.Add(time.Duration(c.expiration) * time.Second))
  118. }
  119. return false
  120. }
  121. func getCache(key string, next http.Handler, r *http.Request) (item *cacheItem) {
  122. if lruItem, ok := cacheLru.Get(key); ok {
  123. item = lruItem.(*cacheItem)
  124. }
  125. if item == nil || item.expired() {
  126. // No cache available
  127. // Remove problematic headers
  128. r.Header.Del("If-Modified-Since")
  129. r.Header.Del("If-Unmodified-Since")
  130. r.Header.Del("If-None-Match")
  131. r.Header.Del("If-Match")
  132. r.Header.Del("If-Range")
  133. r.Header.Del("Range")
  134. // Record request
  135. recorder := httptest.NewRecorder()
  136. next.ServeHTTP(recorder, r)
  137. // Cache values from recorder
  138. result := recorder.Result()
  139. body, _ := io.ReadAll(result.Body)
  140. _ = result.Body.Close()
  141. eTag := result.Header.Get("ETag")
  142. if eTag == "" {
  143. h := sha256.New()
  144. _, _ = io.Copy(h, bytes.NewReader(body))
  145. eTag = fmt.Sprintf("%x", h.Sum(nil))
  146. }
  147. lastMod := time.Now()
  148. if lm := result.Header.Get("Last-Modified"); lm != "" {
  149. if parsedTime, te := dateparse.ParseLocal(lm); te == nil {
  150. lastMod = parsedTime
  151. }
  152. }
  153. exp, _ := strconv.Atoi(result.Header.Get(cacheInternalExpirationHeader))
  154. // Remove problematic headers
  155. result.Header.Del(cacheInternalExpirationHeader)
  156. result.Header.Del("Accept-Ranges")
  157. result.Header.Del("ETag")
  158. result.Header.Del("Last-Modified")
  159. // Create cache item
  160. item = &cacheItem{
  161. expiration: exp,
  162. creationTime: lastMod,
  163. eTag: eTag,
  164. code: result.StatusCode,
  165. header: result.Header,
  166. body: body,
  167. }
  168. // Save cache
  169. if cch := item.header.Get("Cache-Control"); !strings.Contains(cch, "no-store") && !strings.Contains(cch, "private") && !strings.Contains(cch, "no-cache") {
  170. cacheLru.Add(key, item)
  171. }
  172. }
  173. return item
  174. }
  175. func purgeCache() {
  176. cacheLru.Purge()
  177. }
  178. func setInternalCacheExpirationHeader(w http.ResponseWriter, expiration int) {
  179. w.Header().Set(cacheInternalExpirationHeader, strconv.Itoa(expiration))
  180. }