diff --git a/cache.go b/cache.go index 2b1b7bc..dcd60ba 100644 --- a/cache.go +++ b/cache.go @@ -49,7 +49,7 @@ func cacheMiddleware(next http.Handler) http.Handler { } setCacheHeaders(w, cache) // check conditional request - if ifNoneMatchHeader := r.Header.Get("If-None-Match"); ifNoneMatchHeader != "" && ifNoneMatchHeader == cache.hash { + if ifNoneMatchHeader := r.Header.Get("If-None-Match"); ifNoneMatchHeader != "" && ifNoneMatchHeader == cache.eTag { // send 304 w.WriteHeader(http.StatusNotModified) return @@ -78,8 +78,7 @@ func cacheKey(r *http.Request) string { } func setCacheHeaders(w http.ResponseWriter, cache *cacheItem) { - w.Header().Del(cacheInternalExpirationHeader) - w.Header().Set("ETag", cache.hash) + w.Header().Set("ETag", cache.eTag) w.Header().Set("Last-Modified", cache.creationTime.UTC().Format(http.TimeFormat)) if w.Header().Get("Cache-Control") == "" { if cache.expiration != 0 { @@ -93,7 +92,7 @@ func setCacheHeaders(w http.ResponseWriter, cache *cacheItem) { type cacheItem struct { expiration int creationTime time.Time - hash string + eTag string code int header http.Header body []byte @@ -126,20 +125,29 @@ func getCache(key string, next http.Handler, r *http.Request) (item *cacheItem) result := recorder.Result() body, _ := ioutil.ReadAll(result.Body) _ = result.Body.Close() - h := sha256.New() - _, _ = io.Copy(h, bytes.NewReader(body)) - hash := fmt.Sprintf("%x", h.Sum(nil)) - exp, _ := strconv.Atoi(result.Header.Get(cacheInternalExpirationHeader)) + eTag := result.Header.Get("ETag") + if eTag == "" { + h := sha256.New() + _, _ = io.Copy(h, bytes.NewReader(body)) + eTag = fmt.Sprintf("%x", h.Sum(nil)) + } lastMod := time.Now() if lm := result.Header.Get("Last-Modified"); lm != "" { if parsedTime, te := dateparse.ParseLocal(lm); te == nil { lastMod = parsedTime } } + exp, _ := strconv.Atoi(result.Header.Get(cacheInternalExpirationHeader)) + // Remove problematic headers + result.Header.Del(cacheInternalExpirationHeader) + result.Header.Del("Accept-Ranges") + result.Header.Del("ETag") + result.Header.Del("Last-Modified") + // Create cache item item = &cacheItem{ expiration: exp, creationTime: lastMod, - hash: hash, + eTag: eTag, code: result.StatusCode, header: result.Header, body: body, diff --git a/http.go b/http.go index 6929dc6..2a05997 100644 --- a/http.go +++ b/http.go @@ -172,16 +172,16 @@ func buildHandler() (http.Handler, error) { // Assets for _, path := range allAssetPaths() { - r.With(cacheMiddleware).Get(path, serveAsset) + r.Get(path, serveAsset) } // Static files for _, path := range allStaticPaths() { - r.With(cacheMiddleware).Get(path, serveStaticFile) + r.Get(path, serveStaticFile) } // Media files - r.With(cacheMiddleware).Get(`/m/{file:[0-9a-fA-F]+(\.[0-9a-zA-Z]+)?}`, serveMediaFile) + r.Get(`/m/{file:[0-9a-fA-F]+(\.[0-9a-zA-Z]+)?}`, serveMediaFile) // Short paths r.With(cacheMiddleware).Get("/s/{id:[0-9a-fA-F]+}", redirectToLongPath) diff --git a/render.go b/render.go index 38ca5fc..3e424c3 100644 --- a/render.go +++ b/render.go @@ -130,7 +130,7 @@ func initRendering() error { } return d.Before(b) }, - "asset": assetFile, + "asset": assetFileName, "string": getTemplateStringVariant, "include": func(templateName string, data ...interface{}) (template.HTML, error) { if len(data) == 1 { diff --git a/templateAssets.go b/templateAssets.go index 3bd01ba..f1a45d4 100644 --- a/templateAssets.go +++ b/templateAssets.go @@ -4,6 +4,7 @@ import ( "crypto/sha1" "fmt" "io/ioutil" + "mime" "net/http" "os" "path" @@ -13,23 +14,23 @@ import ( const assetsFolder = "templates/assets" -var compiledAssetsFolder string -var assetFiles map[string]string +var assetFileNames map[string]string = map[string]string{} +var assetFiles map[string]*assetFile = map[string]*assetFile{} + +type assetFile struct { + contentType string + body []byte +} func initTemplateAssets() (err error) { - compiledAssetsFolder, err = ioutil.TempDir("", "goblog-assets-*") - if err != nil { - return - } - assetFiles = map[string]string{} err = filepath.Walk(assetsFolder, func(path string, info os.FileInfo, err error) error { if info.Mode().IsRegular() { - compiled, err := compileAssets(path) + compiled, err := compileAsset(path) if err != nil { return err } if compiled != "" { - assetFiles[strings.TrimPrefix(path, assetsFolder+"/")] = compiled + assetFileNames[strings.TrimPrefix(path, assetsFolder+"/")] = compiled } } return nil @@ -40,7 +41,7 @@ func initTemplateAssets() (err error) { return nil } -func compileAssets(name string) (compiledFileName string, err error) { +func compileAsset(name string) (compiledFileName string, err error) { originalContent, err := ioutil.ReadFile(name) if err != nil { return @@ -67,21 +68,21 @@ func compileAssets(name string) (compiledFileName string, err error) { sha.Write(compiledContent) hash := fmt.Sprintf("%x", sha.Sum(nil)) compiledFileName = hash + compiledExt - err = ioutil.WriteFile(path.Join(compiledAssetsFolder, compiledFileName), compiledContent, 0644) - if err != nil { - return + assetFiles[compiledFileName] = &assetFile{ + contentType: mime.TypeByExtension(compiledExt), + body: compiledContent, } return } // Function for templates -func assetFile(fileName string) string { - return "/" + assetFiles[fileName] +func assetFileName(fileName string) string { + return "/" + assetFileNames[fileName] } func allAssetPaths() []string { var paths []string - for _, name := range assetFiles { + for _, name := range assetFileNames { paths = append(paths, "/"+name) } return paths @@ -89,6 +90,13 @@ func allAssetPaths() []string { // Gets only called by registered paths func serveAsset(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Cache-Control", "public,max-age=31536000,immutable") - http.ServeFile(w, r, filepath.Join(compiledAssetsFolder, r.URL.Path)) + f := strings.TrimPrefix(r.URL.Path, "/") + af, ok := assetFiles[f] + if !ok { + serve404(w, r) + return + } + w.Header().Set("Cache-Control", "public,max-age=31536000,immutable") + w.Header().Set(contentType, af.contentType) + w.Write(af.body) }