Various code improvements

This commit is contained in:
Jan-Lukas Else 2020-07-30 21:18:13 +02:00
parent e8bf6af11b
commit c0c4fa04e0
5 changed files with 44 additions and 49 deletions

View File

@ -9,14 +9,14 @@ import (
"time" "time"
) )
func CacheMiddleware(next http.Handler) http.Handler { func cacheMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestUrl, _ := url.ParseRequestURI(r.RequestURI) requestUrl, _ := url.ParseRequestURI(r.RequestURI)
path := SlashTrimmedPath(r) path := slashTrimmedPath(r)
if appConfig.cache.enable && if appConfig.cache.enable &&
// check bypass query // check bypass query
!(requestUrl != nil && requestUrl.Query().Get("cache") == "0") { !(requestUrl != nil && requestUrl.Query().Get("cache") == "0") {
cacheTime, header, body := getCache(path, r.Context()) cacheTime, header, body := getCache(r.Context(), path)
if cacheTime == 0 { if cacheTime == 0 {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
next.ServeHTTP(recorder, r) next.ServeHTTP(recorder, r)
@ -36,31 +36,29 @@ func CacheMiddleware(next http.Handler) http.Handler {
saveCache(path, now, recorder.Header(), recorder.Body.Bytes()) saveCache(path, now, recorder.Header(), recorder.Body.Bytes())
} }
return return
} else { }
cacheTimeString := time.Unix(cacheTime, 0).Format(time.RFC1123) cacheTimeString := time.Unix(cacheTime, 0).Format(time.RFC1123)
expiresTimeString := time.Unix(cacheTime+appConfig.cache.expiration, 0).Format(time.RFC1123) expiresTimeString := time.Unix(cacheTime+appConfig.cache.expiration, 0).Format(time.RFC1123)
// check conditional request // check conditional request
ifModifiedSinceHeader := r.Header.Get("If-Modified-Since") ifModifiedSinceHeader := r.Header.Get("If-Modified-Since")
if ifModifiedSinceHeader != "" && ifModifiedSinceHeader == cacheTimeString { if ifModifiedSinceHeader != "" && ifModifiedSinceHeader == cacheTimeString {
setCacheHeaders(w, cacheTimeString, expiresTimeString)
// send 304
w.WriteHeader(http.StatusNotModified)
return
}
// copy cached headers
for k, v := range header {
w.Header()[k] = v
}
setCacheHeaders(w, cacheTimeString, expiresTimeString) setCacheHeaders(w, cacheTimeString, expiresTimeString)
w.Header().Set("GoBlog-Cache", "HIT") // send 304
// write cached body w.WriteHeader(http.StatusNotModified)
_, _ = w.Write(body)
return return
} }
} else { // copy cached headers
next.ServeHTTP(w, r) for k, v := range header {
w.Header()[k] = v
}
setCacheHeaders(w, cacheTimeString, expiresTimeString)
w.Header().Set("GoBlog-Cache", "HIT")
// write cached body
_, _ = w.Write(body)
return return
} }
next.ServeHTTP(w, r)
return
}) })
} }
@ -70,7 +68,7 @@ func setCacheHeaders(w http.ResponseWriter, cacheTimeString string, expiresTimeS
w.Header().Set("Expires", expiresTimeString) w.Header().Set("Expires", expiresTimeString)
} }
func getCache(path string, context context.Context) (creationTime int64, header map[string][]string, body []byte) { func getCache(context context.Context, path string) (creationTime int64, header map[string][]string, body []byte) {
var headerBytes []byte var headerBytes []byte
allowedTime := time.Now().Unix() - appConfig.cache.expiration allowedTime := time.Now().Unix() - appConfig.cache.expiration
row := appDb.QueryRowContext(context, "select COALESCE(time, 0), header, body from cache where path=? and time>=?", path, allowedTime) row := appDb.QueryRowContext(context, "select COALESCE(time, 0), header, body from cache where path=? and time>=?", path, allowedTime)

View File

@ -30,7 +30,7 @@ func closeDb() error {
return appDb.Close() return appDb.Close()
} }
func vacuumDb() { func vacuumDb() {
startWritingToDb() startWritingToDb()
_, _ = appDb.Exec("VACUUM;") _, _ = appDb.Exec("VACUUM;")
finishWritingToDb() finishWritingToDb()

20
http.go
View File

@ -61,22 +61,20 @@ func buildHandler() (http.Handler, error) {
allPostPaths, err := allPostPaths() allPostPaths, err := allPostPaths()
if err != nil { if err != nil {
return nil, err return nil, err
} else { }
for _, path := range allPostPaths { for _, path := range allPostPaths {
if path != "" { if path != "" {
r.With(CacheMiddleware).Get(path, servePost) r.With(cacheMiddleware).Get(path, servePost)
}
} }
} }
allRedirectPaths, err := allRedirectPaths() allRedirectPaths, err := allRedirectPaths()
if err != nil { if err != nil {
return nil, err return nil, err
} else { }
for _, path := range allRedirectPaths { for _, path := range allRedirectPaths {
if path != "" { if path != "" {
r.Get(path, serveRedirect) r.Get(path, serveRedirect)
}
} }
} }
@ -104,7 +102,7 @@ func (d *dynamicHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
d.realHandler.ServeHTTP(w, r) d.realHandler.ServeHTTP(w, r)
} }
func SlashTrimmedPath(r *http.Request) string { func slashTrimmedPath(r *http.Request) string {
path := r.URL.Path path := r.URL.Path
if len(path) > 1 { if len(path) > 1 {
path = strings.TrimSuffix(path, "/") path = strings.TrimSuffix(path, "/")

View File

@ -7,7 +7,7 @@ import (
"net/http" "net/http"
) )
var postNotFound = errors.New("post not found") var errPostNotFound = errors.New("post not found")
type post struct { type post struct {
path string path string
@ -18,9 +18,9 @@ type post struct {
} }
func servePost(w http.ResponseWriter, r *http.Request) { func servePost(w http.ResponseWriter, r *http.Request) {
path := SlashTrimmedPath(r) path := slashTrimmedPath(r)
post, err := getPost(path, r.Context()) post, err := getPost(r.Context(), path)
if err == postNotFound { if err == errPostNotFound {
http.NotFound(w, r) http.NotFound(w, r)
return return
} else if err != nil { } else if err != nil {
@ -36,12 +36,12 @@ func servePost(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write(htmlContent) _, _ = w.Write(htmlContent)
} }
func getPost(path string, context context.Context) (*post, error) { func getPost(context context.Context, path string) (*post, error) {
queriedPost := &post{} queriedPost := &post{}
row := appDb.QueryRowContext(context, "select path, COALESCE(content, ''), COALESCE(published, ''), COALESCE(updated, '') from posts where path=?", path) row := appDb.QueryRowContext(context, "select path, COALESCE(content, ''), COALESCE(published, ''), COALESCE(updated, '') from posts where path=?", path)
err := row.Scan(&queriedPost.path, &queriedPost.content, &queriedPost.published, &queriedPost.updated) err := row.Scan(&queriedPost.path, &queriedPost.content, &queriedPost.published, &queriedPost.updated)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, postNotFound return nil, errPostNotFound
} else if err != nil { } else if err != nil {
return nil, err return nil, err
} }

View File

@ -7,27 +7,26 @@ import (
"net/http" "net/http"
) )
var redirectNotFound = errors.New("redirect not found") var errRedirectNotFound = errors.New("redirect not found")
func serveRedirect(w http.ResponseWriter, r *http.Request) { func serveRedirect(w http.ResponseWriter, r *http.Request) {
redirect, err := getRedirect(SlashTrimmedPath(r), r.Context()) redirect, err := getRedirect(r.Context(), slashTrimmedPath(r))
if err == redirectNotFound { if err == errRedirectNotFound {
http.NotFound(w, r) http.NotFound(w, r)
return return
} else if err != nil { } else if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
} }
// TODO: Change status code http.Redirect(w, r, redirect, http.StatusFound)
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
} }
func getRedirect(fromPath string, context context.Context) (string, error) { func getRedirect(context context.Context, fromPath string) (string, error) {
var toPath string var toPath string
row := appDb.QueryRowContext(context, "select toPath from redirects where fromPath=?", fromPath) row := appDb.QueryRowContext(context, "select toPath from redirects where fromPath=?", fromPath)
err := row.Scan(&toPath) err := row.Scan(&toPath)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return "", redirectNotFound return "", errRedirectNotFound
} else if err != nil { } else if err != nil {
return "", err return "", err
} }