diff --git a/blogroll.go b/blogroll.go index 160f348..218574e 100644 --- a/blogroll.go +++ b/blogroll.go @@ -27,10 +27,9 @@ func (a *goBlog) serveBlogroll(w http.ResponseWriter, r *http.Request) { return } c := bc.Blogroll - can := a.getRelativePath(blog, defaultIfEmpty(c.Path, defaultBlogrollPath)) + can := bc.getRelativePath(defaultIfEmpty(c.Path, defaultBlogrollPath)) a.render(w, r, templateBlogroll, &renderData{ - BlogString: blog, - Canonical: a.getFullAddress(can), + Canonical: a.getFullAddress(can), Data: map[string]interface{}{ "Title": c.Title, "Description": c.Description, diff --git a/blogstats.go b/blogstats.go index 41dcf61..ebe526c 100644 --- a/blogstats.go +++ b/blogstats.go @@ -22,11 +22,10 @@ func (a *goBlog) initBlogStats() { } func (a *goBlog) serveBlogStats(w http.ResponseWriter, r *http.Request) { - blog, bc := a.getBlog(r) + _, bc := a.getBlog(r) canonical := bc.getRelativePath(defaultIfEmpty(bc.BlogStats.Path, defaultBlogStatsPath)) a.render(w, r, templateBlogStats, &renderData{ - BlogString: blog, - Canonical: a.getFullAddress(canonical), + Canonical: a.getFullAddress(canonical), Data: map[string]interface{}{ "TableUrl": canonical + blogStatsTablePath, }, @@ -34,7 +33,7 @@ func (a *goBlog) serveBlogStats(w http.ResponseWriter, r *http.Request) { } func (a *goBlog) serveBlogStatsTable(w http.ResponseWriter, r *http.Request) { - blog := r.Context().Value(blogKey).(string) + blog, _ := a.getBlog(r) data, err, _ := a.blogStatsCacheGroup.Do(blog, func() (interface{}, error) { return a.db.getBlogStats(blog) }) @@ -44,8 +43,7 @@ func (a *goBlog) serveBlogStatsTable(w http.ResponseWriter, r *http.Request) { } // Render a.render(w, r, templateBlogStatsTable, &renderData{ - BlogString: blog, - Data: data, + Data: data, }) } diff --git a/captcha.go b/captcha.go index 5f7c8aa..396fba8 100644 --- a/captcha.go +++ b/captcha.go @@ -8,6 +8,7 @@ import ( "io" "net/http" "strings" + "time" "github.com/dchest/captcha" "go.goblog.app/app/pkgs/contenttype" @@ -15,6 +16,12 @@ import ( const captchaSolvedKey contextKey = "captchaSolved" +var captchaStore = captcha.NewMemoryStore(100, 10*time.Minute) + +func init() { + captcha.SetCustomStore(captchaStore) +} + func (a *goBlog) captchaMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Check if captcha already solved @@ -56,10 +63,9 @@ func (a *goBlog) captchaMiddleware(next http.Handler) http.Handler { b = []byte(r.PostForm.Encode()) } // Render captcha - ses.Save(r, w) + _ = ses.Save(r, w) w.Header().Set("Cache-Control", "no-store,max-age=0") a.renderWithStatusCode(w, r, http.StatusUnauthorized, templateCaptcha, &renderData{ - BlogString: r.Context().Value(blogKey).(string), Data: map[string]string{ "captchamethod": r.Method, "captchaheaders": base64.StdEncoding.EncodeToString(h), diff --git a/captcha_test.go b/captcha_test.go index e911276..25d68a5 100644 --- a/captcha_test.go +++ b/captcha_test.go @@ -1,12 +1,20 @@ package main import ( + "encoding/base64" "io" + "log" "net/http" "net/http/httptest" + "net/url" + "strconv" + "strings" "testing" + "github.com/PuerkitoBio/goquery" + "github.com/justinas/alice" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.goblog.app/app/pkgs/contenttype" ) @@ -19,14 +27,14 @@ func Test_captchaMiddleware(t *testing.T) { _ = app.initDatabase(false) app.initComponents(false) - h := app.captchaMiddleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + app.d = alice.New(app.checkIsCaptcha, app.captchaMiddleware).ThenFunc(func(rw http.ResponseWriter, r *http.Request) { _, _ = rw.Write([]byte("ABC Test")) - })) + }) - t.Run("Default", func(t *testing.T) { + t.Run("Show captcha", func(t *testing.T) { rec := httptest.NewRecorder() - h.ServeHTTP(rec, reqWithDefaultBlog(httptest.NewRequest(http.MethodPost, "/abc", nil))) + app.d.ServeHTTP(rec, httptest.NewRequest(http.MethodPost, "/abc", nil)) res := rec.Result() resBody, _ := io.ReadAll(res.Body) @@ -38,7 +46,7 @@ func Test_captchaMiddleware(t *testing.T) { assert.Contains(t, resString, "name=captchamethod value=POST") }) - t.Run("Captcha session", func(t *testing.T) { + t.Run("Show no captcha, when already solved", func(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/abc", nil) rec1 := httptest.NewRecorder() @@ -52,7 +60,7 @@ func Test_captchaMiddleware(t *testing.T) { rec2 := httptest.NewRecorder() - h.ServeHTTP(rec2, req) + app.d.ServeHTTP(rec2, req) res := rec2.Result() resBody, _ := io.ReadAll(res.Body) @@ -62,4 +70,129 @@ func Test_captchaMiddleware(t *testing.T) { assert.Equal(t, http.StatusOK, res.StatusCode) assert.Contains(t, resString, "ABC Test") }) + + t.Run("Captcha flow", func(t *testing.T) { + // Do original request + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/abc", strings.NewReader("test")) + + app.d.ServeHTTP(rec, req) + + // Check response + res := rec.Result() + assert.Equal(t, http.StatusUnauthorized, res.StatusCode) + + // Check cookie + cookies := res.Cookies() + require.Len(t, cookies, 1) + captchaCookie := cookies[0] + assert.Equal(t, "c", captchaCookie.Name) + captchaSessionId := captchaCookie.Value + assert.NotEmpty(t, captchaSessionId) + + // Check session + sr := httptest.NewRequest(http.MethodPost, "/abc", strings.NewReader("test")) + sr.AddCookie(captchaCookie) + session, err := app.captchaSessions.Get(sr, "c") + require.NoError(t, err) + assert.Equal(t, captchaSessionId, session.ID) + captchaId := session.Values["captchaid"].(string) + assert.NotEmpty(t, captchaId) + _, captchaSolved := session.Values["captcha"].(bool) + assert.False(t, captchaSolved) + + log.Println("Captcha ID:", captchaId) + + // Check form values + doc, err := goquery.NewDocumentFromReader(res.Body) + _ = res.Body.Close() + require.NoError(t, err) + form := doc.Find("form") + cm := form.Find("input[name=captchamethod]") + assert.Equal(t, "POST", cm.AttrOr("value", "")) + ch := form.Find("input[name=captchaheaders]") + assert.NotEmpty(t, ch.AttrOr("value", "")) + cb := form.Find("input[name=captchabody]") + assert.NotEmpty(t, cb.AttrOr("value", "")) + dcb, _ := base64.StdEncoding.DecodeString(cb.AttrOr("value", "")) + assert.Equal(t, "test", string(dcb)) + ci := doc.Find("img.captchaimg") + assert.Contains(t, ci.AttrOr("src", ""), captchaId) + + // Do second request with wrong captcha + rec = httptest.NewRecorder() + + formValues := &url.Values{} + formValues.Add("captchaaction", "captcha") + formValues.Add("captchamethod", cm.AttrOr("value", "")) + formValues.Add("captchaheaders", ch.AttrOr("value", "")) + formValues.Add("captchabody", cb.AttrOr("value", "")) + formValues.Add("digits", "123456") // Wrong captcha + + req = httptest.NewRequest(http.MethodPost, "/abc", strings.NewReader(formValues.Encode())) + req.Header.Set(contentType, contenttype.WWWForm) + req.AddCookie(captchaCookie) + + app.d.ServeHTTP(rec, req) + + // Check response + res = rec.Result() + assert.Equal(t, http.StatusUnauthorized, res.StatusCode) + + // Check cookie + require.Len(t, res.Cookies(), 1) + assert.Equal(t, captchaSessionId, res.Cookies()[0].Value) + + // Check session + sr = httptest.NewRequest(http.MethodPost, "/abc", strings.NewReader("test")) + sr.AddCookie(captchaCookie) + session, err = app.captchaSessions.Get(sr, "c") + require.NoError(t, err) + assert.Equal(t, captchaSessionId, session.ID) + captchaId = session.Values["captchaid"].(string) + assert.NotEmpty(t, captchaId) + _, captchaSolved = session.Values["captcha"].(bool) + assert.False(t, captchaSolved) + + log.Println("Captcha ID:", captchaId) + + // Check form values + doc, err = goquery.NewDocumentFromReader(res.Body) + _ = res.Body.Close() + require.NoError(t, err) + ci = doc.Find("img.captchaimg") + assert.Contains(t, ci.AttrOr("src", ""), captchaId) + + // Solve captcha + digits := captchaStore.Get(captchaId, false) + digitsString := "" + for _, digit := range digits { + digitsString += strconv.Itoa(int(digit)) + } + + // Do third request with solved captcha + rec = httptest.NewRecorder() + + formValues = &url.Values{} + formValues.Add("captchaaction", "captcha") + formValues.Add("captchamethod", cm.AttrOr("value", "")) + formValues.Add("captchaheaders", ch.AttrOr("value", "")) + formValues.Add("captchabody", cb.AttrOr("value", "")) + formValues.Add("digits", digitsString) // Correct captcha + + req = httptest.NewRequest(http.MethodPost, "/abc", strings.NewReader(formValues.Encode())) + req.Header.Set(contentType, contenttype.WWWForm) + req.AddCookie(captchaCookie) + + app.d.ServeHTTP(rec, req) + + // Check response + res = rec.Result() + resBody, _ := io.ReadAll(res.Body) + _ = res.Body.Close() + resString := string(resBody) + assert.Equal(t, http.StatusOK, res.StatusCode) + assert.Contains(t, resString, "ABC Test") + }) + } diff --git a/comments.go b/comments.go index cdd3a6e..58367ce 100644 --- a/comments.go +++ b/comments.go @@ -40,11 +40,10 @@ func (a *goBlog) serveComment(w http.ResponseWriter, r *http.Request) { a.serveError(w, r, err.Error(), http.StatusInternalServerError) return } - blog := r.Context().Value(blogKey).(string) + _, bc := a.getBlog(r) a.render(w, r, templateComment, &renderData{ - BlogString: blog, - Canonical: a.getFullAddress(a.getRelativePath(blog, path.Join(commentPath, strconv.Itoa(id)))), - Data: comment, + Canonical: a.getFullAddress(bc.getRelativePath(path.Join(commentPath, strconv.Itoa(id)))), + Data: comment, }) } @@ -72,8 +71,8 @@ func (a *goBlog) createComment(w http.ResponseWriter, r *http.Request) { // Serve error a.serveError(w, r, err.Error(), http.StatusInternalServerError) } else { - blog := r.Context().Value(blogKey).(string) - commentAddress := a.getRelativePath(blog, path.Join(commentPath, strconv.Itoa(int(commentID)))) + _, bc := a.getBlog(r) + commentAddress := bc.getRelativePath(path.Join(commentPath, strconv.Itoa(int(commentID)))) // Send webmention _ = a.createWebmention(a.getFullAddress(commentAddress), a.getFullAddress(target)) // Redirect to comment diff --git a/commentsAdmin.go b/commentsAdmin.go index 1a06027..bb6a612 100644 --- a/commentsAdmin.go +++ b/commentsAdmin.go @@ -35,7 +35,6 @@ func (p *commentsPaginationAdapter) Slice(offset, length int, data interface{}) } func (a *goBlog) commentsAdmin(w http.ResponseWriter, r *http.Request) { - blog := r.Context().Value(blogKey).(string) commentsPath := r.Context().Value(pathKey).(string) // Adapter pageNoString := chi.URLParam(r, "page") @@ -72,7 +71,6 @@ func (a *goBlog) commentsAdmin(w http.ResponseWriter, r *http.Request) { nextPath = fmt.Sprintf("%s/page/%d", commentsPath, nextPage) // Render a.render(w, r, templateCommentsAdmin, &renderData{ - BlogString: blog, Data: map[string]interface{}{ "Comments": comments, "HasPrev": hasPrev, diff --git a/config.go b/config.go index a6c5564..cc2607e 100644 --- a/config.go +++ b/config.go @@ -457,9 +457,11 @@ func (a *goBlog) getBlog(r *http.Request) (string, *configBlog) { if r == nil { return a.cfg.DefaultBlog, a.cfg.Blogs[a.cfg.DefaultBlog] } - blog := r.Context().Value(blogKey).(string) - if blog == "" { - return a.cfg.DefaultBlog, a.cfg.Blogs[a.cfg.DefaultBlog] + blog := a.cfg.DefaultBlog + if ctxBlog := r.Context().Value(blogKey); ctxBlog != nil { + if ctxBlogString, ok := ctxBlog.(string); ok { + blog = ctxBlogString + } } return blog, a.cfg.Blogs[blog] } diff --git a/contact.go b/contact.go index 50d5e84..ef96e20 100644 --- a/contact.go +++ b/contact.go @@ -15,10 +15,9 @@ import ( const defaultContactPath = "/contact" func (a *goBlog) serveContactForm(w http.ResponseWriter, r *http.Request) { - blog, bc := a.getBlog(r) + _, bc := a.getBlog(r) cc := bc.Contact a.render(w, r, templateContact, &renderData{ - BlogString: blog, Data: map[string]interface{}{ "title": cc.Title, "description": cc.Description, @@ -29,7 +28,7 @@ func (a *goBlog) serveContactForm(w http.ResponseWriter, r *http.Request) { func (a *goBlog) sendContactSubmission(w http.ResponseWriter, r *http.Request) { // Get blog - blog, bc := a.getBlog(r) + _, bc := a.getBlog(r) // Get form values and build message var message bytes.Buffer // Message @@ -65,7 +64,6 @@ func (a *goBlog) sendContactSubmission(w http.ResponseWriter, r *http.Request) { a.sendNotification(message.String()) // Give feedback a.render(w, r, templateContact, &renderData{ - BlogString: blog, Data: map[string]interface{}{ "sent": true, }, diff --git a/customPages.go b/customPages.go index 0209157..84401e8 100644 --- a/customPages.go +++ b/customPages.go @@ -7,8 +7,7 @@ const customPageContextKey = "custompage" func (a *goBlog) serveCustomPage(w http.ResponseWriter, r *http.Request) { page := r.Context().Value(customPageContextKey).(*configCustomPage) a.render(w, r, page.Template, &renderData{ - BlogString: r.Context().Value(blogKey).(string), - Canonical: a.getFullAddress(page.Path), - Data: page.Data, + Canonical: a.getFullAddress(page.Path), + Data: page.Data, }) } diff --git a/editor.go b/editor.go index 555d303..c1b5c1b 100644 --- a/editor.go +++ b/editor.go @@ -21,15 +21,13 @@ import ( const editorPath = "/editor" func (a *goBlog) serveEditor(w http.ResponseWriter, r *http.Request) { - blog := r.Context().Value(blogKey).(string) a.render(w, r, templateEditor, &renderData{ - BlogString: blog, - Data: map[string]interface{}{}, + Data: map[string]interface{}{}, }) } func (a *goBlog) serveEditorPreview(w http.ResponseWriter, r *http.Request) { - blog := r.Context().Value(blogKey).(string) + blog, _ := a.getBlog(r) c, err := ws.Accept(w, r, nil) if err != nil { return @@ -93,7 +91,6 @@ func (a *goBlog) createMarkdownPreview(blog string, markdown []byte) (rendered [ } func (a *goBlog) serveEditorPost(w http.ResponseWriter, r *http.Request) { - blog := r.Context().Value(blogKey).(string) if action := r.FormValue("editoraction"); action != "" { switch action { case "loadupdate": @@ -103,7 +100,6 @@ func (a *goBlog) serveEditorPost(w http.ResponseWriter, r *http.Request) { return } a.render(w, r, templateEditor, &renderData{ - BlogString: blog, Data: map[string]interface{}{ "UpdatePostURL": a.fullPostURL(post), "UpdatePostContent": a.postToMfItem(post).Properties.Content[0], diff --git a/editorFiles.go b/editorFiles.go index 53c6922..aefe875 100644 --- a/editorFiles.go +++ b/editorFiles.go @@ -8,7 +8,6 @@ import ( ) func (a *goBlog) serveEditorFiles(w http.ResponseWriter, r *http.Request) { - blog := r.Context().Value(blogKey).(string) // Get files files, err := a.mediaFiles() if err != nil { @@ -18,8 +17,7 @@ func (a *goBlog) serveEditorFiles(w http.ResponseWriter, r *http.Request) { // Check if files at all if len(files) == 0 { a.render(w, r, templateEditorFiles, &renderData{ - BlogString: blog, - Data: map[string]interface{}{}, + Data: map[string]interface{}{}, }) return } @@ -42,7 +40,6 @@ func (a *goBlog) serveEditorFiles(w http.ResponseWriter, r *http.Request) { } // Serve HTML a.render(w, r, templateEditorFiles, &renderData{ - BlogString: blog, Data: map[string]interface{}{ "Files": files, "Uses": uses, @@ -69,5 +66,6 @@ func (a *goBlog) serveEditorFilesDelete(w http.ResponseWriter, r *http.Request) a.serveError(w, r, err.Error(), http.StatusInternalServerError) return } - http.Redirect(w, r, a.getRelativePath(r.Context().Value(blogKey).(string), "/editor/files"), http.StatusFound) + _, bc := a.getBlog(r) + http.Redirect(w, r, bc.getRelativePath("/editor/files"), http.StatusFound) } diff --git a/feeds.go b/feeds.go index 71dda79..755f7da 100644 --- a/feeds.go +++ b/feeds.go @@ -44,7 +44,7 @@ func (a *goBlog) generateFeed(blog string, f feedType, w http.ResponseWriter, r } for _, p := range posts { var contentBuf bytes.Buffer - a.min.Write(&contentBuf, contenttype.HTML, []byte(a.feedHtml(p))) + _, _ = a.min.Write(&contentBuf, contenttype.HTML, []byte(a.feedHtml(p))) feed.Add(&feeds.Item{ Title: p.RenderedTitle, Link: &feeds.Link{Href: a.fullPostURL(p)}, diff --git a/geoMap.go b/geoMap.go index 7f6c16b..2eb4a86 100644 --- a/geoMap.go +++ b/geoMap.go @@ -8,8 +8,7 @@ import ( const defaultGeoMapPath = "/map" func (a *goBlog) serveGeoMap(w http.ResponseWriter, r *http.Request) { - blog := r.Context().Value(blogKey).(string) - bc := a.cfg.Blogs[blog] + blog, bc := a.getBlog(r) mapPath := bc.getRelativePath(defaultIfEmpty(bc.Map.Path, defaultGeoMapPath)) canonical := a.getFullAddress(mapPath) @@ -27,8 +26,7 @@ func (a *goBlog) serveGeoMap(w http.ResponseWriter, r *http.Request) { if len(allPostsWithLocation) == 0 { a.render(w, r, templateGeoMap, &renderData{ - BlogString: blog, - Canonical: canonical, + Canonical: canonical, Data: map[string]interface{}{ "nolocations": true, }, @@ -88,8 +86,7 @@ func (a *goBlog) serveGeoMap(w http.ResponseWriter, r *http.Request) { } a.render(w, r, templateGeoMap, &renderData{ - BlogString: blog, - Canonical: canonical, + Canonical: canonical, Data: map[string]interface{}{ "locations": locationsJson, "tracks": tracksJson, diff --git a/posts.go b/posts.go index 660adab..fe44fa4 100644 --- a/posts.go +++ b/posts.go @@ -79,7 +79,8 @@ func (a *goBlog) servePost(w http.ResponseWriter, r *http.Request) { } func (a *goBlog) redirectToRandomPost(rw http.ResponseWriter, r *http.Request) { - randomPath, err := a.getRandomPostPath(r.Context().Value(blogKey).(string)) + blog, _ := a.getBlog(r) + randomPath, err := a.getRandomPostPath(blog) if err != nil { a.serveError(rw, r, err.Error(), http.StatusInternalServerError) return @@ -112,7 +113,7 @@ func (p *postPaginationAdapter) Slice(offset, length int, data interface{}) erro } func (a *goBlog) serveHome(w http.ResponseWriter, r *http.Request) { - blog := r.Context().Value(blogKey).(string) + blog, _ := a.getBlog(r) if asRequest, ok := r.Context().Value(asRequestKey).(bool); ok && asRequest { a.serveActivityStreams(blog, w, r) return @@ -123,37 +124,37 @@ func (a *goBlog) serveHome(w http.ResponseWriter, r *http.Request) { } func (a *goBlog) serveDrafts(w http.ResponseWriter, r *http.Request) { - blog := r.Context().Value(blogKey).(string) + _, bc := a.getBlog(r) a.serveIndex(w, r.WithContext(context.WithValue(r.Context(), indexConfigKey, &indexConfig{ - path: a.getRelativePath(blog, "/editor/drafts"), - title: a.ts.GetTemplateStringVariant(a.cfg.Blogs[blog].Lang, "drafts"), + path: bc.getRelativePath("/editor/drafts"), + title: a.ts.GetTemplateStringVariant(bc.Lang, "drafts"), status: statusDraft, }))) } func (a *goBlog) servePrivate(w http.ResponseWriter, r *http.Request) { - blog := r.Context().Value(blogKey).(string) + _, bc := a.getBlog(r) a.serveIndex(w, r.WithContext(context.WithValue(r.Context(), indexConfigKey, &indexConfig{ - path: a.getRelativePath(blog, "/editor/private"), - title: a.ts.GetTemplateStringVariant(a.cfg.Blogs[blog].Lang, "privateposts"), + path: bc.getRelativePath("/editor/private"), + title: a.ts.GetTemplateStringVariant(bc.Lang, "privateposts"), status: statusPrivate, }))) } func (a *goBlog) serveUnlisted(w http.ResponseWriter, r *http.Request) { - blog := r.Context().Value(blogKey).(string) + _, bc := a.getBlog(r) a.serveIndex(w, r.WithContext(context.WithValue(r.Context(), indexConfigKey, &indexConfig{ - path: a.getRelativePath(blog, "/editor/unlisted"), - title: a.ts.GetTemplateStringVariant(a.cfg.Blogs[blog].Lang, "unlistedposts"), + path: bc.getRelativePath("/editor/unlisted"), + title: a.ts.GetTemplateStringVariant(bc.Lang, "unlistedposts"), status: statusUnlisted, }))) } func (a *goBlog) serveScheduled(w http.ResponseWriter, r *http.Request) { - blog := r.Context().Value(blogKey).(string) + _, bc := a.getBlog(r) a.serveIndex(w, r.WithContext(context.WithValue(r.Context(), indexConfigKey, &indexConfig{ - path: a.getRelativePath(blog, "/editor/scheduled"), - title: a.ts.GetTemplateStringVariant(a.cfg.Blogs[blog].Lang, "scheduledposts"), + path: bc.getRelativePath("/editor/scheduled"), + title: a.ts.GetTemplateStringVariant(bc.Lang, "scheduledposts"), status: statusScheduled, }))) } @@ -193,8 +194,9 @@ func (a *goBlog) serveDate(w http.ResponseWriter, r *http.Request) { title.WriteString(fmt.Sprintf("-%02d", day)) dPath.WriteString(fmt.Sprintf("/%02d", day)) } + _, bc := a.getBlog(r) a.serveIndex(w, r.WithContext(context.WithValue(r.Context(), indexConfigKey, &indexConfig{ - path: a.getRelativePath(r.Context().Value(blogKey).(string), dPath.String()), + path: bc.getRelativePath(dPath.String()), year: year, month: month, day: day, @@ -203,7 +205,6 @@ func (a *goBlog) serveDate(w http.ResponseWriter, r *http.Request) { } type indexConfig struct { - blog string path string section *configSection tax *configTaxonomy @@ -222,10 +223,7 @@ const indexConfigKey contextKey = "indexConfig" func (a *goBlog) serveIndex(w http.ResponseWriter, r *http.Request) { ic := r.Context().Value(indexConfigKey).(*indexConfig) - blog := ic.blog - if blog == "" { - blog, _ = r.Context().Value(blogKey).(string) - } + blog, bc := a.getBlog(r) search := chi.URLParam(r, "search") if search != "" { // Decode and sanitize search @@ -237,7 +235,7 @@ func (a *goBlog) serveIndex(w http.ResponseWriter, r *http.Request) { if ic.section != nil { sections = []string{ic.section.Name} } else { - for sectionKey := range a.cfg.Blogs[blog].Sections { + for sectionKey := range bc.Sections { sections = append(sections, sectionKey) } } @@ -257,7 +255,7 @@ func (a *goBlog) serveIndex(w http.ResponseWriter, r *http.Request) { publishedDay: ic.day, status: status, priorityOrder: true, - }, a: a}, a.cfg.Blogs[blog].Pagination) + }, a: a}, bc.Pagination) p.SetPage(pageNo) var posts []*post err := p.Results(&posts) @@ -274,7 +272,7 @@ func (a *goBlog) serveIndex(w http.ResponseWriter, r *http.Request) { title = ic.section.Title description = ic.section.Description } else if search != "" { - title = fmt.Sprintf("%s: %s", a.cfg.Blogs[blog].Search.Title, search) + title = fmt.Sprintf("%s: %s", bc.Search.Title, search) } // Check if feed if ft := feedType(chi.URLParam(r, "feed")); ft != noFeed { @@ -313,8 +311,7 @@ func (a *goBlog) serveIndex(w http.ResponseWriter, r *http.Request) { summaryTemplate = templateSummary } a.render(w, r, templateIndex, &renderData{ - BlogString: blog, - Canonical: a.getFullAddress(path), + Canonical: a.getFullAddress(path), Data: map[string]interface{}{ "Title": title, "Description": description, diff --git a/render.go b/render.go index 30efe1f..ad1308c 100644 --- a/render.go +++ b/render.go @@ -162,16 +162,14 @@ func (a *goBlog) checkRenderData(r *http.Request, data *renderData) { data.User = a.cfg.User } // Blog - if data.Blog == nil { - if data.BlogString == "" { - data.BlogString = a.cfg.DefaultBlog - } + if data.Blog == nil && data.BlogString == "" { + data.BlogString, data.Blog = a.getBlog(r) + } else if data.Blog == nil { data.Blog = a.cfg.Blogs[data.BlogString] - } - if data.BlogString == "" { - for s, b := range a.cfg.Blogs { - if b == data.Blog { - data.BlogString = s + } else if data.BlogString == "" { + for name, blog := range a.cfg.Blogs { + if blog == data.Blog { + data.BlogString = name break } } diff --git a/search.go b/search.go index cbca0fc..a577a3c 100644 --- a/search.go +++ b/search.go @@ -13,7 +13,6 @@ const defaultSearchPath = "/search" const searchPlaceholder = "{search}" func (a *goBlog) serveSearch(w http.ResponseWriter, r *http.Request) { - blog := r.Context().Value(blogKey).(string) servePath := r.Context().Value(pathKey).(string) err := r.ParseForm() if err != nil { @@ -28,8 +27,7 @@ func (a *goBlog) serveSearch(w http.ResponseWriter, r *http.Request) { return } a.render(w, r, templateSearch, &renderData{ - BlogString: blog, - Canonical: a.getFullAddress(servePath), + Canonical: a.getFullAddress(servePath), }) } diff --git a/sitemap.go b/sitemap.go index 598a64f..c78360a 100644 --- a/sitemap.go +++ b/sitemap.go @@ -39,18 +39,18 @@ func (a *goBlog) serveSitemapBlog(w http.ResponseWriter, r *http.Request) { // Create sitemap sm := sitemap.NewSitemapIndex() // Add blog sitemaps - b := r.Context().Value(blogKey).(string) + _, bc := a.getBlog(r) now := time.Now().UTC() sm.Add(&sitemap.URL{ - Loc: a.getFullAddress(a.getRelativePath(b, sitemapBlogFeaturesPath)), + Loc: a.getFullAddress(bc.getRelativePath(sitemapBlogFeaturesPath)), LastMod: &now, }) sm.Add(&sitemap.URL{ - Loc: a.getFullAddress(a.getRelativePath(b, sitemapBlogArchivesPath)), + Loc: a.getFullAddress(bc.getRelativePath(sitemapBlogArchivesPath)), LastMod: &now, }) sm.Add(&sitemap.URL{ - Loc: a.getFullAddress(a.getRelativePath(b, sitemapBlogPostsPath)), + Loc: a.getFullAddress(bc.getRelativePath(sitemapBlogPostsPath)), LastMod: &now, }) // Write sitemap @@ -61,7 +61,7 @@ func (a *goBlog) serveSitemapBlogFeatures(w http.ResponseWriter, r *http.Request // Create sitemap sm := sitemap.New() // Add features to sitemap - bc := a.cfg.Blogs[r.Context().Value(blogKey).(string)] + _, bc := a.getBlog(r) // Home sm.Add(&sitemap.URL{ Loc: a.getFullAddress(bc.getRelativePath("")), @@ -116,8 +116,7 @@ func (a *goBlog) serveSitemapBlogArchives(w http.ResponseWriter, r *http.Request // Create sitemap sm := sitemap.New() // Add archives to sitemap - b := r.Context().Value(blogKey).(string) - bc := a.cfg.Blogs[b] + b, bc := a.getBlog(r) // Sections for _, section := range bc.Sections { if section.Name != "" { @@ -160,9 +159,10 @@ func (a *goBlog) serveSitemapBlogPosts(w http.ResponseWriter, r *http.Request) { // Create sitemap sm := sitemap.New() // Request posts + blog, _ := a.getBlog(r) posts, _ := a.getPosts(&postsRequestConfig{ status: statusPublished, - blog: r.Context().Value(blogKey).(string), + blog: blog, withoutParameters: true, }) // Add posts to sitemap diff --git a/taxonomies.go b/taxonomies.go index 1e8f570..5002b7a 100644 --- a/taxonomies.go +++ b/taxonomies.go @@ -13,7 +13,7 @@ import ( const taxonomyContextKey = "taxonomy" func (a *goBlog) serveTaxonomy(w http.ResponseWriter, r *http.Request) { - blog := r.Context().Value(blogKey).(string) + blog, _ := a.getBlog(r) tax := r.Context().Value(taxonomyContextKey).(*configTaxonomy) allValues, err := a.db.allTaxonomyValues(blog, tax.Name) if err != nil { @@ -21,8 +21,7 @@ func (a *goBlog) serveTaxonomy(w http.ResponseWriter, r *http.Request) { return } a.render(w, r, templateTaxonomy, &renderData{ - BlogString: blog, - Canonical: a.getFullAddress(r.URL.Path), + Canonical: a.getFullAddress(r.URL.Path), Data: map[string]interface{}{ "Taxonomy": tax, "ValueGroups": groupStrings(allValues), @@ -31,7 +30,7 @@ func (a *goBlog) serveTaxonomy(w http.ResponseWriter, r *http.Request) { } func (a *goBlog) serveTaxonomyValue(w http.ResponseWriter, r *http.Request) { - blog := r.Context().Value(blogKey).(string) + _, bc := a.getBlog(r) tax := r.Context().Value(taxonomyContextKey).(*configTaxonomy) taxValueParam := chi.URLParam(r, "taxValue") if taxValueParam == "" { @@ -59,7 +58,7 @@ func (a *goBlog) serveTaxonomyValue(w http.ResponseWriter, r *http.Request) { } // Serve index a.serveIndex(w, r.WithContext(context.WithValue(r.Context(), indexConfigKey, &indexConfig{ - path: a.getRelativePath(blog, fmt.Sprintf("/%s/%s", tax.Name, taxValueParam)), + path: bc.getRelativePath(fmt.Sprintf("/%s/%s", tax.Name, taxValueParam)), tax: tax, taxValue: taxValue, }))) diff --git a/tts.go b/tts.go index 1c1619d..e76cd35 100644 --- a/tts.go +++ b/tts.go @@ -198,8 +198,7 @@ func (a *goBlog) createTTSAudio(lang, text, outputFile string) error { if encoded, ok := content["audioContent"]; ok { if encodedStr, ok := encoded.(string); ok { if audio, err := base64.StdEncoding.DecodeString(encodedStr); err == nil { - os.WriteFile(outputFile, audio, os.ModePerm) - return nil + return os.WriteFile(outputFile, audio, os.ModePerm) } else { return err }