diff --git a/app.go b/app.go index 0ea9651..8b2789f 100644 --- a/app.go +++ b/app.go @@ -7,6 +7,7 @@ import ( shutdowner "git.jlel.se/jlelse/go-shutdowner" ts "git.jlel.se/jlelse/template-strings" + "github.com/dgraph-io/ristretto" ct "github.com/elnormous/contenttype" "github.com/go-fed/httpsig" "github.com/hacdias/indieauth/v2" @@ -73,6 +74,10 @@ type goBlog struct { mediaStorage mediaStorage // Minify min minify.Minifier + // Reactions + reactionsInit sync.Once + reactionsCache *ristretto.Cache + reactionsSfg singleflight.Group // Regex Redirects regexRedirects []*regexRedirect // Sessions diff --git a/reactions.go b/reactions.go index c223469..ca1fb68 100644 --- a/reactions.go +++ b/reactions.go @@ -3,8 +3,10 @@ package main import ( "encoding/json" "errors" + "io" "net/http" + "github.com/dgraph-io/ristretto" "github.com/samber/lo" "go.goblog.app/app/pkgs/bufferpool" "go.goblog.app/app/pkgs/contenttype" @@ -29,6 +31,17 @@ func (a *goBlog) reactionsEnabledForPost(post *post) bool { return a.reactionsEnabled() && post != nil && post.firstParameter(reactionsPostParam) != "false" } +func (a *goBlog) initReactions() { + a.reactionsInit.Do(func() { + a.reactionsCache, _ = ristretto.NewCache(&ristretto.Config{ + NumCounters: 1000, + MaxCost: 100, // Cache reactions for 100 posts + BufferItems: 64, + IgnoreInternalCost: true, + }) + }) +} + func (a *goBlog) postReaction(w http.ResponseWriter, r *http.Request) { path := r.FormValue("path") reaction := r.FormValue("reaction") @@ -48,6 +61,11 @@ func (a *goBlog) saveReaction(reaction, path string) error { if !lo.Contains(allowedReactions, reaction) { return errors.New("reaction not allowed") } + // Init + a.initReactions() + // Delete from cache + defer a.reactionsSfg.Forget(path) + defer a.reactionsCache.Del(path) // Insert reaction _, err := a.db.exec("insert into reactions (path, reaction, count) values (?, ?, 1) on conflict (path, reaction) do update set count=count+1", path, reaction) return err @@ -68,38 +86,57 @@ func (a *goBlog) getReactions(w http.ResponseWriter, r *http.Request) { return } w.Header().Set(contentType, contenttype.JSONUTF8) - _ = a.min.Get().Minify(contenttype.JSON, w, buf) + _, _ = io.Copy(w, buf) } func (a *goBlog) getReactionsFromDatabase(path string) (map[string]int, error) { - sqlBuf := bufferpool.Get() - defer bufferpool.Put(sqlBuf) - sqlArgs := []any{} - sqlBuf.WriteString("select reaction, count from reactions where path=? and reaction in (") - sqlArgs = append(sqlArgs, path) - for i, reaction := range allowedReactions { - if i > 0 { - sqlBuf.WriteString(",") + // Init + a.initReactions() + // Check cache + if val, cached := a.reactionsCache.Get(path); cached { + // Return from cache + return val.(map[string]int), nil + } + // Get reactions + res, err, _ := a.reactionsSfg.Do(path, func() (interface{}, error) { + // Build query + sqlBuf := bufferpool.Get() + defer bufferpool.Put(sqlBuf) + sqlArgs := []any{} + sqlBuf.WriteString("select reaction, count from reactions where path=? and reaction in (") + sqlArgs = append(sqlArgs, path) + for i, reaction := range allowedReactions { + if i > 0 { + sqlBuf.WriteString(",") + } + sqlBuf.WriteString("?") + sqlArgs = append(sqlArgs, reaction) } - sqlBuf.WriteString("?") - sqlArgs = append(sqlArgs, reaction) - } - sqlBuf.WriteString(") and path not in (select path from post_parameters where parameter=? and value=?)") - sqlArgs = append(sqlArgs, reactionsPostParam, "false") - rows, err := a.db.query(sqlBuf.String(), sqlArgs...) - if err != nil { - return nil, err - } - defer rows.Close() - reactions := map[string]int{} - for rows.Next() { - var reaction string - var count int - err = rows.Scan(&reaction, &count) + sqlBuf.WriteString(") and path not in (select path from post_parameters where parameter=? and value=?)") + sqlArgs = append(sqlArgs, reactionsPostParam, "false") + // Execute query + rows, err := a.db.query(sqlBuf.String(), sqlArgs...) if err != nil { return nil, err } - reactions[reaction] = count + // Build result + defer rows.Close() + reactions := map[string]int{} + for rows.Next() { + var reaction string + var count int + err = rows.Scan(&reaction, &count) + if err != nil { + return nil, err + } + reactions[reaction] = count + } + // Cache result + a.reactionsCache.Set(path, reactions, 1) + return reactions, nil + }) + if err != nil || res == nil { + return nil, err } - return reactions, nil + return res.(map[string]int), nil } diff --git a/reactions_test.go b/reactions_test.go index 84fd3b3..3423977 100644 --- a/reactions_test.go +++ b/reactions_test.go @@ -136,13 +136,13 @@ func Test_reactionsHighLevel(t *testing.T) { rec = httptest.NewRecorder() app.getReactions(rec, req) assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, `{"❤️":1}`, rec.Body.String()) + assert.Equal(t, "{\"❤️\":1}\n", rec.Body.String()) // Get reactions for a non-existing post req = httptest.NewRequest(http.MethodGet, "/?path=/non-existing-post", nil) rec = httptest.NewRecorder() app.getReactions(rec, req) assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, `{}`, rec.Body.String()) + assert.Equal(t, "{}\n", rec.Body.String()) }