mirror of https://github.com/jlelse/GoBlog
Various code improvements
This commit is contained in:
parent
e8bf6af11b
commit
c0c4fa04e0
46
cache.go
46
cache.go
|
@ -9,14 +9,14 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func CacheMiddleware(next http.Handler) http.Handler {
|
func cacheMiddleware(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
requestUrl, _ := url.ParseRequestURI(r.RequestURI)
|
requestUrl, _ := url.ParseRequestURI(r.RequestURI)
|
||||||
path := SlashTrimmedPath(r)
|
path := slashTrimmedPath(r)
|
||||||
if appConfig.cache.enable &&
|
if appConfig.cache.enable &&
|
||||||
// check bypass query
|
// check bypass query
|
||||||
!(requestUrl != nil && requestUrl.Query().Get("cache") == "0") {
|
!(requestUrl != nil && requestUrl.Query().Get("cache") == "0") {
|
||||||
cacheTime, header, body := getCache(path, r.Context())
|
cacheTime, header, body := getCache(r.Context(), path)
|
||||||
if cacheTime == 0 {
|
if cacheTime == 0 {
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
next.ServeHTTP(recorder, r)
|
next.ServeHTTP(recorder, r)
|
||||||
|
@ -36,31 +36,29 @@ func CacheMiddleware(next http.Handler) http.Handler {
|
||||||
saveCache(path, now, recorder.Header(), recorder.Body.Bytes())
|
saveCache(path, now, recorder.Header(), recorder.Body.Bytes())
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
} else {
|
}
|
||||||
cacheTimeString := time.Unix(cacheTime, 0).Format(time.RFC1123)
|
cacheTimeString := time.Unix(cacheTime, 0).Format(time.RFC1123)
|
||||||
expiresTimeString := time.Unix(cacheTime+appConfig.cache.expiration, 0).Format(time.RFC1123)
|
expiresTimeString := time.Unix(cacheTime+appConfig.cache.expiration, 0).Format(time.RFC1123)
|
||||||
// check conditional request
|
// check conditional request
|
||||||
ifModifiedSinceHeader := r.Header.Get("If-Modified-Since")
|
ifModifiedSinceHeader := r.Header.Get("If-Modified-Since")
|
||||||
if ifModifiedSinceHeader != "" && ifModifiedSinceHeader == cacheTimeString {
|
if ifModifiedSinceHeader != "" && ifModifiedSinceHeader == cacheTimeString {
|
||||||
setCacheHeaders(w, cacheTimeString, expiresTimeString)
|
|
||||||
// send 304
|
|
||||||
w.WriteHeader(http.StatusNotModified)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// copy cached headers
|
|
||||||
for k, v := range header {
|
|
||||||
w.Header()[k] = v
|
|
||||||
}
|
|
||||||
setCacheHeaders(w, cacheTimeString, expiresTimeString)
|
setCacheHeaders(w, cacheTimeString, expiresTimeString)
|
||||||
w.Header().Set("GoBlog-Cache", "HIT")
|
// send 304
|
||||||
// write cached body
|
w.WriteHeader(http.StatusNotModified)
|
||||||
_, _ = w.Write(body)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
// copy cached headers
|
||||||
next.ServeHTTP(w, r)
|
for k, v := range header {
|
||||||
|
w.Header()[k] = v
|
||||||
|
}
|
||||||
|
setCacheHeaders(w, cacheTimeString, expiresTimeString)
|
||||||
|
w.Header().Set("GoBlog-Cache", "HIT")
|
||||||
|
// write cached body
|
||||||
|
_, _ = w.Write(body)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -70,7 +68,7 @@ func setCacheHeaders(w http.ResponseWriter, cacheTimeString string, expiresTimeS
|
||||||
w.Header().Set("Expires", expiresTimeString)
|
w.Header().Set("Expires", expiresTimeString)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getCache(path string, context context.Context) (creationTime int64, header map[string][]string, body []byte) {
|
func getCache(context context.Context, path string) (creationTime int64, header map[string][]string, body []byte) {
|
||||||
var headerBytes []byte
|
var headerBytes []byte
|
||||||
allowedTime := time.Now().Unix() - appConfig.cache.expiration
|
allowedTime := time.Now().Unix() - appConfig.cache.expiration
|
||||||
row := appDb.QueryRowContext(context, "select COALESCE(time, 0), header, body from cache where path=? and time>=?", path, allowedTime)
|
row := appDb.QueryRowContext(context, "select COALESCE(time, 0), header, body from cache where path=? and time>=?", path, allowedTime)
|
||||||
|
|
|
@ -30,7 +30,7 @@ func closeDb() error {
|
||||||
return appDb.Close()
|
return appDb.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func vacuumDb() {
|
func vacuumDb() {
|
||||||
startWritingToDb()
|
startWritingToDb()
|
||||||
_, _ = appDb.Exec("VACUUM;")
|
_, _ = appDb.Exec("VACUUM;")
|
||||||
finishWritingToDb()
|
finishWritingToDb()
|
||||||
|
|
20
http.go
20
http.go
|
@ -61,22 +61,20 @@ func buildHandler() (http.Handler, error) {
|
||||||
allPostPaths, err := allPostPaths()
|
allPostPaths, err := allPostPaths()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else {
|
}
|
||||||
for _, path := range allPostPaths {
|
for _, path := range allPostPaths {
|
||||||
if path != "" {
|
if path != "" {
|
||||||
r.With(CacheMiddleware).Get(path, servePost)
|
r.With(cacheMiddleware).Get(path, servePost)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
allRedirectPaths, err := allRedirectPaths()
|
allRedirectPaths, err := allRedirectPaths()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else {
|
}
|
||||||
for _, path := range allRedirectPaths {
|
for _, path := range allRedirectPaths {
|
||||||
if path != "" {
|
if path != "" {
|
||||||
r.Get(path, serveRedirect)
|
r.Get(path, serveRedirect)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -104,7 +102,7 @@ func (d *dynamicHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
d.realHandler.ServeHTTP(w, r)
|
d.realHandler.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func SlashTrimmedPath(r *http.Request) string {
|
func slashTrimmedPath(r *http.Request) string {
|
||||||
path := r.URL.Path
|
path := r.URL.Path
|
||||||
if len(path) > 1 {
|
if len(path) > 1 {
|
||||||
path = strings.TrimSuffix(path, "/")
|
path = strings.TrimSuffix(path, "/")
|
||||||
|
|
12
posts.go
12
posts.go
|
@ -7,7 +7,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
var postNotFound = errors.New("post not found")
|
var errPostNotFound = errors.New("post not found")
|
||||||
|
|
||||||
type post struct {
|
type post struct {
|
||||||
path string
|
path string
|
||||||
|
@ -18,9 +18,9 @@ type post struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func servePost(w http.ResponseWriter, r *http.Request) {
|
func servePost(w http.ResponseWriter, r *http.Request) {
|
||||||
path := SlashTrimmedPath(r)
|
path := slashTrimmedPath(r)
|
||||||
post, err := getPost(path, r.Context())
|
post, err := getPost(r.Context(), path)
|
||||||
if err == postNotFound {
|
if err == errPostNotFound {
|
||||||
http.NotFound(w, r)
|
http.NotFound(w, r)
|
||||||
return
|
return
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
|
@ -36,12 +36,12 @@ func servePost(w http.ResponseWriter, r *http.Request) {
|
||||||
_, _ = w.Write(htmlContent)
|
_, _ = w.Write(htmlContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getPost(path string, context context.Context) (*post, error) {
|
func getPost(context context.Context, path string) (*post, error) {
|
||||||
queriedPost := &post{}
|
queriedPost := &post{}
|
||||||
row := appDb.QueryRowContext(context, "select path, COALESCE(content, ''), COALESCE(published, ''), COALESCE(updated, '') from posts where path=?", path)
|
row := appDb.QueryRowContext(context, "select path, COALESCE(content, ''), COALESCE(published, ''), COALESCE(updated, '') from posts where path=?", path)
|
||||||
err := row.Scan(&queriedPost.path, &queriedPost.content, &queriedPost.published, &queriedPost.updated)
|
err := row.Scan(&queriedPost.path, &queriedPost.content, &queriedPost.published, &queriedPost.updated)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, postNotFound
|
return nil, errPostNotFound
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
13
redirects.go
13
redirects.go
|
@ -7,27 +7,26 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
var redirectNotFound = errors.New("redirect not found")
|
var errRedirectNotFound = errors.New("redirect not found")
|
||||||
|
|
||||||
func serveRedirect(w http.ResponseWriter, r *http.Request) {
|
func serveRedirect(w http.ResponseWriter, r *http.Request) {
|
||||||
redirect, err := getRedirect(SlashTrimmedPath(r), r.Context())
|
redirect, err := getRedirect(r.Context(), slashTrimmedPath(r))
|
||||||
if err == redirectNotFound {
|
if err == errRedirectNotFound {
|
||||||
http.NotFound(w, r)
|
http.NotFound(w, r)
|
||||||
return
|
return
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// TODO: Change status code
|
http.Redirect(w, r, redirect, http.StatusFound)
|
||||||
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func getRedirect(fromPath string, context context.Context) (string, error) {
|
func getRedirect(context context.Context, fromPath string) (string, error) {
|
||||||
var toPath string
|
var toPath string
|
||||||
row := appDb.QueryRowContext(context, "select toPath from redirects where fromPath=?", fromPath)
|
row := appDb.QueryRowContext(context, "select toPath from redirects where fromPath=?", fromPath)
|
||||||
err := row.Scan(&toPath)
|
err := row.Scan(&toPath)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return "", redirectNotFound
|
return "", errRedirectNotFound
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue