mirror of https://github.com/jlelse/GoBlog
Various code improvements
This commit is contained in:
parent
e8bf6af11b
commit
c0c4fa04e0
46
cache.go
46
cache.go
|
@ -9,14 +9,14 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
func CacheMiddleware(next http.Handler) http.Handler {
|
||||
func cacheMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestUrl, _ := url.ParseRequestURI(r.RequestURI)
|
||||
path := SlashTrimmedPath(r)
|
||||
path := slashTrimmedPath(r)
|
||||
if appConfig.cache.enable &&
|
||||
// check bypass query
|
||||
!(requestUrl != nil && requestUrl.Query().Get("cache") == "0") {
|
||||
cacheTime, header, body := getCache(path, r.Context())
|
||||
cacheTime, header, body := getCache(r.Context(), path)
|
||||
if cacheTime == 0 {
|
||||
recorder := httptest.NewRecorder()
|
||||
next.ServeHTTP(recorder, r)
|
||||
|
@ -36,31 +36,29 @@ func CacheMiddleware(next http.Handler) http.Handler {
|
|||
saveCache(path, now, recorder.Header(), recorder.Body.Bytes())
|
||||
}
|
||||
return
|
||||
} else {
|
||||
cacheTimeString := time.Unix(cacheTime, 0).Format(time.RFC1123)
|
||||
expiresTimeString := time.Unix(cacheTime+appConfig.cache.expiration, 0).Format(time.RFC1123)
|
||||
// check conditional request
|
||||
ifModifiedSinceHeader := r.Header.Get("If-Modified-Since")
|
||||
if ifModifiedSinceHeader != "" && ifModifiedSinceHeader == cacheTimeString {
|
||||
setCacheHeaders(w, cacheTimeString, expiresTimeString)
|
||||
// send 304
|
||||
w.WriteHeader(http.StatusNotModified)
|
||||
return
|
||||
}
|
||||
// copy cached headers
|
||||
for k, v := range header {
|
||||
w.Header()[k] = v
|
||||
}
|
||||
}
|
||||
cacheTimeString := time.Unix(cacheTime, 0).Format(time.RFC1123)
|
||||
expiresTimeString := time.Unix(cacheTime+appConfig.cache.expiration, 0).Format(time.RFC1123)
|
||||
// check conditional request
|
||||
ifModifiedSinceHeader := r.Header.Get("If-Modified-Since")
|
||||
if ifModifiedSinceHeader != "" && ifModifiedSinceHeader == cacheTimeString {
|
||||
setCacheHeaders(w, cacheTimeString, expiresTimeString)
|
||||
w.Header().Set("GoBlog-Cache", "HIT")
|
||||
// write cached body
|
||||
_, _ = w.Write(body)
|
||||
// send 304
|
||||
w.WriteHeader(http.StatusNotModified)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
next.ServeHTTP(w, r)
|
||||
// copy cached headers
|
||||
for k, v := range header {
|
||||
w.Header()[k] = v
|
||||
}
|
||||
setCacheHeaders(w, cacheTimeString, expiresTimeString)
|
||||
w.Header().Set("GoBlog-Cache", "HIT")
|
||||
// write cached body
|
||||
_, _ = w.Write(body)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -70,7 +68,7 @@ func setCacheHeaders(w http.ResponseWriter, cacheTimeString string, expiresTimeS
|
|||
w.Header().Set("Expires", expiresTimeString)
|
||||
}
|
||||
|
||||
func getCache(path string, context context.Context) (creationTime int64, header map[string][]string, body []byte) {
|
||||
func getCache(context context.Context, path string) (creationTime int64, header map[string][]string, body []byte) {
|
||||
var headerBytes []byte
|
||||
allowedTime := time.Now().Unix() - appConfig.cache.expiration
|
||||
row := appDb.QueryRowContext(context, "select COALESCE(time, 0), header, body from cache where path=? and time>=?", path, allowedTime)
|
||||
|
|
|
@ -30,7 +30,7 @@ func closeDb() error {
|
|||
return appDb.Close()
|
||||
}
|
||||
|
||||
func vacuumDb() {
|
||||
func vacuumDb() {
|
||||
startWritingToDb()
|
||||
_, _ = appDb.Exec("VACUUM;")
|
||||
finishWritingToDb()
|
||||
|
|
20
http.go
20
http.go
|
@ -61,22 +61,20 @@ func buildHandler() (http.Handler, error) {
|
|||
allPostPaths, err := allPostPaths()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
for _, path := range allPostPaths {
|
||||
if path != "" {
|
||||
r.With(CacheMiddleware).Get(path, servePost)
|
||||
}
|
||||
}
|
||||
for _, path := range allPostPaths {
|
||||
if path != "" {
|
||||
r.With(cacheMiddleware).Get(path, servePost)
|
||||
}
|
||||
}
|
||||
|
||||
allRedirectPaths, err := allRedirectPaths()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
for _, path := range allRedirectPaths {
|
||||
if path != "" {
|
||||
r.Get(path, serveRedirect)
|
||||
}
|
||||
}
|
||||
for _, path := range allRedirectPaths {
|
||||
if path != "" {
|
||||
r.Get(path, serveRedirect)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -104,7 +102,7 @@ func (d *dynamicHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
d.realHandler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func SlashTrimmedPath(r *http.Request) string {
|
||||
func slashTrimmedPath(r *http.Request) string {
|
||||
path := r.URL.Path
|
||||
if len(path) > 1 {
|
||||
path = strings.TrimSuffix(path, "/")
|
||||
|
|
12
posts.go
12
posts.go
|
@ -7,7 +7,7 @@ import (
|
|||
"net/http"
|
||||
)
|
||||
|
||||
var postNotFound = errors.New("post not found")
|
||||
var errPostNotFound = errors.New("post not found")
|
||||
|
||||
type post struct {
|
||||
path string
|
||||
|
@ -18,9 +18,9 @@ type post struct {
|
|||
}
|
||||
|
||||
func servePost(w http.ResponseWriter, r *http.Request) {
|
||||
path := SlashTrimmedPath(r)
|
||||
post, err := getPost(path, r.Context())
|
||||
if err == postNotFound {
|
||||
path := slashTrimmedPath(r)
|
||||
post, err := getPost(r.Context(), path)
|
||||
if err == errPostNotFound {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
} else if err != nil {
|
||||
|
@ -36,12 +36,12 @@ func servePost(w http.ResponseWriter, r *http.Request) {
|
|||
_, _ = w.Write(htmlContent)
|
||||
}
|
||||
|
||||
func getPost(path string, context context.Context) (*post, error) {
|
||||
func getPost(context context.Context, path string) (*post, error) {
|
||||
queriedPost := &post{}
|
||||
row := appDb.QueryRowContext(context, "select path, COALESCE(content, ''), COALESCE(published, ''), COALESCE(updated, '') from posts where path=?", path)
|
||||
err := row.Scan(&queriedPost.path, &queriedPost.content, &queriedPost.published, &queriedPost.updated)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, postNotFound
|
||||
return nil, errPostNotFound
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
13
redirects.go
13
redirects.go
|
@ -7,27 +7,26 @@ import (
|
|||
"net/http"
|
||||
)
|
||||
|
||||
var redirectNotFound = errors.New("redirect not found")
|
||||
var errRedirectNotFound = errors.New("redirect not found")
|
||||
|
||||
func serveRedirect(w http.ResponseWriter, r *http.Request) {
|
||||
redirect, err := getRedirect(SlashTrimmedPath(r), r.Context())
|
||||
if err == redirectNotFound {
|
||||
redirect, err := getRedirect(r.Context(), slashTrimmedPath(r))
|
||||
if err == errRedirectNotFound {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
} else if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
// TODO: Change status code
|
||||
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
|
||||
http.Redirect(w, r, redirect, http.StatusFound)
|
||||
}
|
||||
|
||||
func getRedirect(fromPath string, context context.Context) (string, error) {
|
||||
func getRedirect(context context.Context, fromPath string) (string, error) {
|
||||
var toPath string
|
||||
row := appDb.QueryRowContext(context, "select toPath from redirects where fromPath=?", fromPath)
|
||||
err := row.Scan(&toPath)
|
||||
if err == sql.ErrNoRows {
|
||||
return "", redirectNotFound
|
||||
return "", errRedirectNotFound
|
||||
} else if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue