diff --git a/http.go b/http.go index 8112aa1..5a2e8e4 100644 --- a/http.go +++ b/http.go @@ -9,27 +9,23 @@ import ( "os" "os/signal" "strconv" + "sync" "syscall" "time" ) -func startServer() { - r := chi.NewRouter() - - if appConfig.server.logging { - r.Use(middleware.RealIP) - r.Use(middleware.Logger) +func startServer() error { + d := newDynamicHandler() + h, err := buildHandler() + if err != nil { + return err } - r.Use(middleware.Recoverer) - r.Use(middleware.StripSlashes) - - r.Get("/", hello) - r.Get("/*", servePost) + d.swapHandler(h) address := ":" + strconv.Itoa(appConfig.server.port) srv := &http.Server{ Addr: address, - Handler: r, + Handler: d, } go func() { @@ -45,10 +41,58 @@ func startServer() { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if err := srv.Shutdown(ctx); err != nil { - log.Fatal(err) + return err } + return nil +} + +func buildHandler() (http.Handler, error) { + r := chi.NewRouter() + + if appConfig.server.logging { + r.Use(middleware.RealIP) + r.Use(middleware.Logger) + } + r.Use(middleware.Recoverer) + r.Use(middleware.StripSlashes) + + r.Get("/", hello) + + allPostPaths, err := allPostPaths() + if err != nil { + return nil, err + } else { + for _, path := range allPostPaths { + if path != "" { + r.Get("/"+path, servePost) + } + } + } + + return r, nil } func hello(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("Hello World!")) } + +type dynamicHandler struct { + realHandler http.Handler + changeMutex *sync.Mutex +} + +func newDynamicHandler() *dynamicHandler { + return &dynamicHandler{ + changeMutex: &sync.Mutex{}, + } +} + +func (d *dynamicHandler) swapHandler(h http.Handler) { + d.changeMutex.Lock() + d.realHandler = h + d.changeMutex.Unlock() +} + +func (d *dynamicHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + d.realHandler.ServeHTTP(w, r) +} diff --git a/main.go b/main.go index e7adfe6..c044471 100644 --- a/main.go +++ b/main.go @@ -23,5 +23,8 @@ func main() { }() log.Println("Loaded database") log.Println("Start server") - startServer() + err = startServer() + if err != nil { + log.Fatal(err) + } } diff --git a/posts.go b/posts.go index c6dc52f..817837b 100644 --- a/posts.go +++ b/posts.go @@ -1,6 +1,7 @@ package main import ( + "context" "database/sql" "errors" "net/http" @@ -17,7 +18,7 @@ type post struct { } func servePost(w http.ResponseWriter, r *http.Request) { - post, err := getPost(strings.TrimSuffix(strings.TrimPrefix(r.RequestURI, "/"), "/")) + post, err := getPost(strings.TrimSuffix(strings.TrimPrefix(r.RequestURI, "/"), "/"), r.Context()) if err == postNotFound { http.NotFound(w, r) return @@ -34,9 +35,9 @@ func servePost(w http.ResponseWriter, r *http.Request) { _, _ = w.Write(htmlContent) } -func getPost(path string) (*post, error) { +func getPost(path string, context context.Context) (*post, error) { queriedPost := &post{} - row := appDb.QueryRow("select path, COALESCE(content, ''), COALESCE(published, ''), COALESCE(updated, '') from posts where path=?", path) + row := appDb.QueryRowContext(context, "select path, COALESCE(content, ''), COALESCE(published, ''), COALESCE(updated, '') from posts where path=?", path) err := row.Scan(&queriedPost.path, &queriedPost.content, &queriedPost.published, &queriedPost.updated) if err == sql.ErrNoRows { return nil, postNotFound @@ -45,3 +46,17 @@ func getPost(path string) (*post, error) { } return queriedPost, nil } + +func allPostPaths() ([]string, error) { + var postPaths []string + rows, err := appDb.Query("select path from posts") + if err != nil { + return nil, err + } + for rows.Next() { + var path string + _ = rows.Scan(&path) + postPaths = append(postPaths, path) + } + return postPaths, nil +}