Various code improvements

This commit is contained in:
Jan-Lukas Else 2020-07-30 21:18:13 +02:00
parent e8bf6af11b
commit c0c4fa04e0
5 changed files with 44 additions and 49 deletions

View File

@ -9,14 +9,14 @@ import (
"time"
)
func CacheMiddleware(next http.Handler) http.Handler {
func cacheMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestUrl, _ := url.ParseRequestURI(r.RequestURI)
path := SlashTrimmedPath(r)
path := slashTrimmedPath(r)
if appConfig.cache.enable &&
// check bypass query
!(requestUrl != nil && requestUrl.Query().Get("cache") == "0") {
cacheTime, header, body := getCache(path, r.Context())
cacheTime, header, body := getCache(r.Context(), path)
if cacheTime == 0 {
recorder := httptest.NewRecorder()
next.ServeHTTP(recorder, r)
@ -36,31 +36,29 @@ func CacheMiddleware(next http.Handler) http.Handler {
saveCache(path, now, recorder.Header(), recorder.Body.Bytes())
}
return
} else {
cacheTimeString := time.Unix(cacheTime, 0).Format(time.RFC1123)
expiresTimeString := time.Unix(cacheTime+appConfig.cache.expiration, 0).Format(time.RFC1123)
// check conditional request
ifModifiedSinceHeader := r.Header.Get("If-Modified-Since")
if ifModifiedSinceHeader != "" && ifModifiedSinceHeader == cacheTimeString {
setCacheHeaders(w, cacheTimeString, expiresTimeString)
// send 304
w.WriteHeader(http.StatusNotModified)
return
}
// copy cached headers
for k, v := range header {
w.Header()[k] = v
}
}
cacheTimeString := time.Unix(cacheTime, 0).Format(time.RFC1123)
expiresTimeString := time.Unix(cacheTime+appConfig.cache.expiration, 0).Format(time.RFC1123)
// check conditional request
ifModifiedSinceHeader := r.Header.Get("If-Modified-Since")
if ifModifiedSinceHeader != "" && ifModifiedSinceHeader == cacheTimeString {
setCacheHeaders(w, cacheTimeString, expiresTimeString)
w.Header().Set("GoBlog-Cache", "HIT")
// write cached body
_, _ = w.Write(body)
// send 304
w.WriteHeader(http.StatusNotModified)
return
}
} else {
next.ServeHTTP(w, r)
// copy cached headers
for k, v := range header {
w.Header()[k] = v
}
setCacheHeaders(w, cacheTimeString, expiresTimeString)
w.Header().Set("GoBlog-Cache", "HIT")
// write cached body
_, _ = w.Write(body)
return
}
next.ServeHTTP(w, r)
return
})
}
@ -70,7 +68,7 @@ func setCacheHeaders(w http.ResponseWriter, cacheTimeString string, expiresTimeS
w.Header().Set("Expires", expiresTimeString)
}
func getCache(path string, context context.Context) (creationTime int64, header map[string][]string, body []byte) {
func getCache(context context.Context, path string) (creationTime int64, header map[string][]string, body []byte) {
var headerBytes []byte
allowedTime := time.Now().Unix() - appConfig.cache.expiration
row := appDb.QueryRowContext(context, "select COALESCE(time, 0), header, body from cache where path=? and time>=?", path, allowedTime)

View File

@ -30,7 +30,7 @@ func closeDb() error {
return appDb.Close()
}
func vacuumDb() {
func vacuumDb() {
startWritingToDb()
_, _ = appDb.Exec("VACUUM;")
finishWritingToDb()

20
http.go
View File

@ -61,22 +61,20 @@ func buildHandler() (http.Handler, error) {
allPostPaths, err := allPostPaths()
if err != nil {
return nil, err
} else {
for _, path := range allPostPaths {
if path != "" {
r.With(CacheMiddleware).Get(path, servePost)
}
}
for _, path := range allPostPaths {
if path != "" {
r.With(cacheMiddleware).Get(path, servePost)
}
}
allRedirectPaths, err := allRedirectPaths()
if err != nil {
return nil, err
} else {
for _, path := range allRedirectPaths {
if path != "" {
r.Get(path, serveRedirect)
}
}
for _, path := range allRedirectPaths {
if path != "" {
r.Get(path, serveRedirect)
}
}
@ -104,7 +102,7 @@ func (d *dynamicHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
d.realHandler.ServeHTTP(w, r)
}
func SlashTrimmedPath(r *http.Request) string {
func slashTrimmedPath(r *http.Request) string {
path := r.URL.Path
if len(path) > 1 {
path = strings.TrimSuffix(path, "/")

View File

@ -7,7 +7,7 @@ import (
"net/http"
)
var postNotFound = errors.New("post not found")
var errPostNotFound = errors.New("post not found")
type post struct {
path string
@ -18,9 +18,9 @@ type post struct {
}
func servePost(w http.ResponseWriter, r *http.Request) {
path := SlashTrimmedPath(r)
post, err := getPost(path, r.Context())
if err == postNotFound {
path := slashTrimmedPath(r)
post, err := getPost(r.Context(), path)
if err == errPostNotFound {
http.NotFound(w, r)
return
} else if err != nil {
@ -36,12 +36,12 @@ func servePost(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write(htmlContent)
}
func getPost(path string, context context.Context) (*post, error) {
func getPost(context context.Context, path string) (*post, error) {
queriedPost := &post{}
row := appDb.QueryRowContext(context, "select path, COALESCE(content, ''), COALESCE(published, ''), COALESCE(updated, '') from posts where path=?", path)
err := row.Scan(&queriedPost.path, &queriedPost.content, &queriedPost.published, &queriedPost.updated)
if err == sql.ErrNoRows {
return nil, postNotFound
return nil, errPostNotFound
} else if err != nil {
return nil, err
}

View File

@ -7,27 +7,26 @@ import (
"net/http"
)
var redirectNotFound = errors.New("redirect not found")
var errRedirectNotFound = errors.New("redirect not found")
func serveRedirect(w http.ResponseWriter, r *http.Request) {
redirect, err := getRedirect(SlashTrimmedPath(r), r.Context())
if err == redirectNotFound {
redirect, err := getRedirect(r.Context(), slashTrimmedPath(r))
if err == errRedirectNotFound {
http.NotFound(w, r)
return
} else if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// TODO: Change status code
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
http.Redirect(w, r, redirect, http.StatusFound)
}
func getRedirect(fromPath string, context context.Context) (string, error) {
func getRedirect(context context.Context, fromPath string) (string, error) {
var toPath string
row := appDb.QueryRowContext(context, "select toPath from redirects where fromPath=?", fromPath)
err := row.Scan(&toPath)
if err == sql.ErrNoRows {
return "", redirectNotFound
return "", errRedirectNotFound
} else if err != nil {
return "", err
}