diff --git a/config.go b/config.go index d91ce3f..8c810a5 100644 --- a/config.go +++ b/config.go @@ -38,6 +38,8 @@ type configBlog struct { Title string `mapstructure:"title"` // Number of posts per page Pagination int `mapstructure:"pagination"` + // Sections + Sections []string `mapstructure:"sections"` } type configUser struct { @@ -68,6 +70,7 @@ func initConfig() error { viper.SetDefault("blog.lang", "en") viper.SetDefault("blog.title", "My blog") viper.SetDefault("blog.pagination", 10) + viper.SetDefault("blog.sections", []string{"posts"}) viper.SetDefault("user.nick", "admin") viper.SetDefault("user.name", "Admin") viper.SetDefault("user.password", "secret") diff --git a/example-config.yaml b/example-config.yaml index 0b261f7..3a7950c 100644 --- a/example-config.yaml +++ b/example-config.yaml @@ -13,6 +13,8 @@ cache: blog: lang: en title: My blog + sections: + - posts user: nick: admin name: Admin diff --git a/http.go b/http.go index 0306db0..203fd1c 100644 --- a/http.go +++ b/http.go @@ -85,13 +85,20 @@ func buildHandler() (http.Handler, error) { } } + for _, section := range appConfig.Blog.Sections { + if section != "" { + r.With(cacheMiddleware, minifier.Middleware).Get("/"+section, serveSection("/"+section, section)) + r.With(cacheMiddleware, minifier.Middleware).Get("/"+section+"/page/{page}", serveSection("/"+section, section)) + } + } + routePatterns := routesToStringSlice(r.Routes()) if !routePatterns.has("/") { - r.With(cacheMiddleware, minifier.Middleware).Get("/", serveIndex("/")) - r.With(cacheMiddleware, minifier.Middleware).Get("/page/{page}", serveIndex("/")) + r.With(cacheMiddleware, minifier.Middleware).Get("/", serveHome("/")) + r.With(cacheMiddleware, minifier.Middleware).Get("/page/{page}", serveHome("/")) } else if !routePatterns.has("/blog") { - r.With(cacheMiddleware, minifier.Middleware).Get("/blog", serveIndex("/blog")) - r.With(cacheMiddleware, minifier.Middleware).Get("/blog/page/{page}", serveIndex("/blog")) + r.With(cacheMiddleware, minifier.Middleware).Get("/blog", serveHome("/blog")) + r.With(cacheMiddleware, minifier.Middleware).Get("/blog/page/{page}", serveHome("/blog")) } r.With(minifier.Middleware).NotFound(serve404) diff --git a/posts.go b/posts.go index f901853..062bbcb 100644 --- a/posts.go +++ b/posts.go @@ -46,12 +46,13 @@ type indexTemplateDate struct { type postPaginationAdapter struct { context context.Context + config *postsRequestConfig nums int } func (p *postPaginationAdapter) Nums() int { if p.nums == 0 { - p.nums, _ = countPosts(p.context, &postsRequestConfig{}) + p.nums, _ = countPosts(p.context, p.config) } return p.nums } @@ -62,18 +63,31 @@ func (p *postPaginationAdapter) Slice(offset, length int, data interface{}) erro } posts, err := getPosts(p.context, &postsRequestConfig{ - offset: offset, - limit: length, + sections: p.config.sections, + offset: offset, + limit: length, }) reflect.ValueOf(data).Elem().Set(reflect.ValueOf(&posts).Elem()) return err } -func serveIndex(path string) func(w http.ResponseWriter, r *http.Request) { +func serveHome(path string) func(w http.ResponseWriter, r *http.Request) { + return serveIndex(path, "") +} + +func serveSection(path, section string) func(w http.ResponseWriter, r *http.Request) { + return serveIndex(path, section) +} + +func serveIndex(path string, section string) func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { pageNoString := chi.URLParam(r, "page") pageNo, _ := strconv.Atoi(pageNoString) - p := paginator.New(&postPaginationAdapter{context: r.Context()}, appConfig.Blog.Pagination) + sections := appConfig.Blog.Sections + if len(section) > 0 { + sections = []string{section} + } + p := paginator.New(&postPaginationAdapter{context: r.Context(), config: &postsRequestConfig{sections: sections}}, appConfig.Blog.Pagination) p.SetPage(pageNo) var posts []*Post err := p.Results(&posts) @@ -110,22 +124,34 @@ func getPost(context context.Context, path string) (*Post, error) { } type postsRequestConfig struct { - path string - limit int - offset int + path string + limit int + offset int + sections []string } func getPosts(context context.Context, config *postsRequestConfig) (posts []*Post, err error) { paths := make(map[string]int) var rows *sql.Rows defaultSelection := "select p.path, coalesce(content, ''), coalesce(published, ''), coalesce(updated, ''), coalesce(parameter, ''), coalesce(value, '') " - defaultTables := " from posts p left outer join post_parameters pp on p.path = pp.path " - defaultSorting := " order by p.updated desc " + postsTable := "posts" + if len(config.sections) != 0 { + postsTable = "(select * from posts where" + for i, section := range config.sections { + if i > 0 { + postsTable += " or" + } + postsTable += " path like '/" + section + "/%'" + } + postsTable += ")" + } + defaultTables := " from " + postsTable + " p left outer join post_parameters pp on p.path = pp.path " + defaultSorting := " order by coalesce(p.updated, p.published) desc " if config.path != "" { query := defaultSelection + defaultTables + " where p.path=?" + defaultSorting rows, err = appDb.QueryContext(context, query, config.path) } else if config.limit != 0 || config.offset != 0 { - query := defaultSelection + " from (select * from posts p " + defaultSorting + " limit ? offset ?) p left outer join post_parameters pp on p.path = pp.path " + query := defaultSelection + " from (select * from " + postsTable + " p " + defaultSorting + " limit ? offset ?) p left outer join post_parameters pp on p.path = pp.path " rows, err = appDb.QueryContext(context, query, config.limit, config.offset) } else { query := defaultSelection + defaultTables + defaultSorting @@ -157,9 +183,9 @@ func getPosts(context context.Context, config *postsRequestConfig) (posts []*Pos return posts, nil } -func countPosts(context context.Context, _ *postsRequestConfig) (count int, err error) { - err = appDb.QueryRowContext(context, "select count(*) from posts").Scan(&count) - return +func countPosts(context context.Context, config *postsRequestConfig) (int, error) { + posts, err := getPosts(context, config) + return len(posts), err } func allPostPaths() ([]string, error) {