diff --git a/blogstats.go b/blogstats.go index 3acfebe..eeb2d25 100644 --- a/blogstats.go +++ b/blogstats.go @@ -45,70 +45,134 @@ func (a *goBlog) serveBlogStatsTable(w http.ResponseWriter, r *http.Request) { }) } +const blogStatsSql = ` +with filtered as ( + select + path, + coalesce(published, '') as pub, + substr(published, 1, 4) as year, + substr(published, 6, 2) as month, + wordcount(coalesce(content, '')) as words, + charcount(coalesce(content, '')) as chars + from posts + where status = @status and blog = @blog +) +select * +from ( + select * + from ( + select + year, + 'A', + coalesce(count(path), 0) as pc, + coalesce(sum(words), 0) as wc, + coalesce(sum(chars), 0) as cc, + coalesce(round(sum(words)/count(path), 0), 0) as wpp + from filtered + where pub != '' + group by year + order by year desc + ) + union all + select * + from ( + select + year, + month, + coalesce(count(path), 0) as pc, + coalesce(sum(words), 0) as wc, + coalesce(sum(chars), 0) as cc, + coalesce(round(sum(words)/count(path), 0), 0) as wpp + from filtered + where pub != '' + group by year, month + order by year desc, month desc + ) + union all + select * + from ( + select + 'N', + 'N', + coalesce(count(path), 0) as pc, + coalesce(sum(words), 0) as wc, + coalesce(sum(chars), 0) as cc, + coalesce(round(sum(words)/count(path), 0), 0) as wpp + from filtered + where pub == '' + ) + union all + select * + from ( + select + 'A', + 'A', + coalesce(count(path), 0) as pc, + coalesce(sum(words), 0) as wc, + coalesce(sum(chars), 0) as cc, + coalesce(round(sum(words)/count(path), 0), 0) as wpp + from filtered + ) +); +` + func (db *database) getBlogStats(blog string) (data map[string]interface{}, err error) { + // Check cache if stats := db.loadBlogStatsCache(blog); stats != nil { return stats, nil } - // Build query - prq := &postsRequestConfig{ - blog: blog, - status: statusPublished, - } - query, params := buildPostsQuery(prq) - query = "select path, mdtext(content) as content, published, substr(published, 1, 4) as year, substr(published, 6, 2) as month from (" + query + ")" - postCount := "coalesce(count(distinct path), 0) as postcount" - charCount := "coalesce(sum(coalesce(charcount(distinct content), 0)), 0)" - wordCount := "coalesce(sum(wordcount(distinct content)), 0) as wordcount" - wordsPerPost := "coalesce(round(wordcount/postcount,0), 0)" + // Prevent creating posts while getting stats + db.pcm.Lock() + defer db.pcm.Unlock() + // Stats type to hold the stats data for a single row type statsTableType struct { Name, Posts, Chars, Words, WordsPerPost string } - // Count total posts - row, err := db.queryRow("select *, "+wordsPerPost+" from (select "+postCount+", "+charCount+", "+wordCount+" from ("+query+"))", params...) - if err != nil { - return nil, err - } - total := statsTableType{} - if err = row.Scan(&total.Posts, &total.Chars, &total.Words, &total.WordsPerPost); err != nil { - return nil, err - } - // Count posts per year - rows, err := db.query("select *, "+wordsPerPost+" from (select year, "+postCount+", "+charCount+", "+wordCount+" from ("+query+") where published != '' group by year order by year desc)", params...) - if err != nil { - return nil, err - } + // Scan objects + currentStats := statsTableType{} + var currentMonth, currentYear string + // Data to later return + var total statsTableType + var noDate statsTableType var years []statsTableType - year := statsTableType{} - for rows.Next() { - if err = rows.Scan(&year.Name, &year.Posts, &year.Chars, &year.Words, &year.WordsPerPost); err == nil { - years = append(years, year) - } else { - return nil, err - } - } - // Count posts without date - row, err = db.queryRow("select *, "+wordsPerPost+" from (select "+postCount+", "+charCount+", "+wordCount+" from ("+query+") where published = '')", params...) + months := map[string][]statsTableType{} + // Query and scan + rows, err := db.query(blogStatsSql, sql.Named("status", statusPublished), sql.Named("blog", blog)) if err != nil { return nil, err } - noDate := statsTableType{} - if err = row.Scan(&noDate.Posts, &noDate.Chars, &noDate.Words, &noDate.WordsPerPost); err != nil { - return nil, err - } - // Count posts per month per year - months := map[string][]statsTableType{} - month := statsTableType{} - for _, year := range years { - rows, err = db.query("select *, "+wordsPerPost+" from (select month, "+postCount+", "+charCount+", "+wordCount+" from ("+query+") where published != '' and year = @year group by month order by month desc)", append(params, sql.Named("year", year.Name))...) - if err != nil { - return nil, err - } - for rows.Next() { - if err = rows.Scan(&month.Name, &month.Posts, &month.Chars, &month.Words, &month.WordsPerPost); err == nil { - months[year.Name] = append(months[year.Name], month) - } else { - return nil, err + for rows.Next() { + err = rows.Scan(¤tYear, ¤tMonth, ¤tStats.Posts, ¤tStats.Words, ¤tStats.Chars, ¤tStats.WordsPerPost) + if currentYear == "A" && currentMonth == "A" { + total = statsTableType{ + Posts: currentStats.Posts, + Words: currentStats.Words, + Chars: currentStats.Chars, + WordsPerPost: currentStats.WordsPerPost, } + } else if currentYear == "N" && currentMonth == "N" { + noDate = statsTableType{ + Posts: currentStats.Posts, + Words: currentStats.Words, + Chars: currentStats.Chars, + WordsPerPost: currentStats.WordsPerPost, + } + } else if currentMonth == "A" { + years = append(years, statsTableType{ + Name: currentYear, + Posts: currentStats.Posts, + Words: currentStats.Words, + Chars: currentStats.Chars, + WordsPerPost: currentStats.WordsPerPost, + }) + } else { + months[currentYear] = append(months[currentYear], statsTableType{ + Name: currentMonth, + Posts: currentStats.Posts, + Words: currentStats.Words, + Chars: currentStats.Chars, + WordsPerPost: currentStats.WordsPerPost, + }) } } data = map[string]interface{}{ diff --git a/database.go b/database.go index 98e444b..c91c067 100644 --- a/database.go +++ b/database.go @@ -15,11 +15,14 @@ import ( ) type database struct { - db *sql.DB - stmts map[string]*sql.Stmt - g singleflight.Group - pc singleflight.Group - pcm sync.Mutex + // Basic things + db *sql.DB // database + em sync.Mutex // command execution (insert, update, delete ...) + sg singleflight.Group // singleflight group for prepared statements + ps sync.Map // map with prepared statements + // Other things + pc singleflight.Group // persistant cache + pcm sync.Mutex // post creation } func (a *goBlog) initDatabase() (err error) { @@ -75,11 +78,13 @@ func (a *goBlog) openDatabase(file string, logging bool) (*database, error) { } sql.Register("goblog_db_"+dbDriverName, dr) // Open db - db, err := sql.Open("goblog_db_"+dbDriverName, file+"?cache=shared&mode=rwc&_journal_mode=WAL") + db, err := sql.Open("goblog_db_"+dbDriverName, file+"?mode=rwc&_journal_mode=WAL&_busy_timeout=100&cache=shared") if err != nil { return nil, err } - db.SetMaxOpenConns(1) + numConns := 5 + db.SetMaxOpenConns(numConns) + db.SetMaxIdleConns(numConns) err = db.Ping() if err != nil { return nil, err @@ -107,14 +112,17 @@ func (a *goBlog) openDatabase(file string, logging bool) (*database, error) { return nil, err } return &database{ - db: db, - stmts: map[string]*sql.Stmt{}, + db: db, }, nil } // Main features func (db *database) dump(file string) { + // Lock execution + db.em.Lock() + defer db.em.Unlock() + // Dump database f, err := os.Create(file) if err != nil { log.Println("Error while dump db:", err.Error()) @@ -130,17 +138,20 @@ func (db *database) close() error { } func (db *database) prepare(query string) (*sql.Stmt, error) { - stmt, err, _ := db.g.Do(query, func() (interface{}, error) { - stmt, ok := db.stmts[query] - if ok && stmt != nil { - return stmt, nil + stmt, err, _ := db.sg.Do(query, func() (interface{}, error) { + // Look if statement already exists + st, ok := db.ps.Load(query) + if ok { + return st, nil } - stmt, err := db.db.Prepare(query) + // ... otherwise prepare ... + st, err := db.db.Prepare(query) if err != nil { return nil, err } - db.stmts[query] = stmt - return stmt, nil + // ... and store it + db.ps.Store(query, st) + return st, nil }) if err != nil { return nil, err @@ -148,33 +159,43 @@ func (db *database) prepare(query string) (*sql.Stmt, error) { return stmt.(*sql.Stmt), nil } -func (db *database) exec(query string, args ...interface{}) (sql.Result, error) { - stmt, err := db.prepare(query) - if err != nil { - return nil, err - } - return stmt.Exec(args...) -} +const dbNoCache = "nocache" -func (db *database) execMulti(query string, args ...interface{}) (sql.Result, error) { - // Can't prepare the statement +func (db *database) exec(query string, args ...interface{}) (sql.Result, error) { + // Lock execution + db.em.Lock() + defer db.em.Unlock() + // Check if prepared cache should be skipped + if len(args) > 0 && args[0] == dbNoCache { + return db.db.Exec(query, args[1:]...) + } + // Use prepared statement + st, _ := db.prepare(query) + if st != nil { + return st.Exec(args...) + } + // Or execute directly return db.db.Exec(query, args...) } func (db *database) query(query string, args ...interface{}) (*sql.Rows, error) { - stmt, err := db.prepare(query) - if err != nil { - return nil, err + // Use prepared statement + st, _ := db.prepare(query) + if st != nil { + return st.Query(args...) } - return stmt.Query(args...) + // Or query directly + return db.db.Query(query, args...) } func (db *database) queryRow(query string, args ...interface{}) (*sql.Row, error) { - stmt, err := db.prepare(query) - if err != nil { - return nil, err + // Use prepared statement + st, _ := db.prepare(query) + if st != nil { + return st.QueryRow(args...), nil } - return stmt.QueryRow(args...), nil + // Or query directly + return db.db.QueryRow(query, args...), nil } // Other things diff --git a/database_test.go b/database_test.go index 570355c..037d723 100644 --- a/database_test.go +++ b/database_test.go @@ -2,6 +2,8 @@ package main import ( "testing" + + "github.com/stretchr/testify/require" ) func (a *goBlog) setInMemoryDatabase() { @@ -19,7 +21,7 @@ func Test_database(t *testing.T) { t.Fatalf("Error: %v", err) } - _, err = db.execMulti("create table test(test text);") + _, err = db.exec("create table test(test text);") if err != nil { t.Fatalf("Error: %v", err) } @@ -65,3 +67,40 @@ func Test_database(t *testing.T) { } }) } + +func Test_parallelDatabase(t *testing.T) { + t.Run("Test parallel db access", func(t *testing.T) { + // Test that parallel database access works without problems + + t.Parallel() + + app := &goBlog{ + cfg: &config{}, + } + app.setInMemoryDatabase() + + _, err := app.db.exec("create table test(test text);") + require.NoError(t, err) + + t.Run("1", func(t *testing.T) { + for i := 0; i < 10000; i++ { + _, e := app.db.exec("insert into test (test) values ('Test')") + require.NoError(t, e) + } + }) + + t.Run("2", func(t *testing.T) { + for i := 0; i < 10000; i++ { + _, e := app.db.exec("insert into test (test) values ('Test')") + require.NoError(t, e) + } + }) + + t.Run("3", func(t *testing.T) { + for i := 0; i < 10000; i++ { + _, e := app.db.queryRow("select count(test) from test") + require.NoError(t, e) + } + }) + }) +} diff --git a/editor.go b/editor.go index ff74efd..ff55a71 100644 --- a/editor.go +++ b/editor.go @@ -105,7 +105,7 @@ func (a *goBlog) editorMicropubPost(w http.ResponseWriter, r *http.Request, medi http.Redirect(w, r, location, http.StatusFound) return } - if result.StatusCode >= 200 && result.StatusCode <= 400 { + if result.StatusCode >= 200 && result.StatusCode < 400 { http.Redirect(w, r, editorPath, http.StatusFound) return } diff --git a/postsDb.go b/postsDb.go index 4269610..9113cf6 100644 --- a/postsDb.go +++ b/postsDb.go @@ -144,29 +144,18 @@ func (a *goBlog) createOrReplacePost(p *post, o *postCreationOptions) error { // Save check post to database func (db *database) savePost(p *post, o *postCreationOptions) error { - // Prevent bad things + // Check + if !o.new && o.oldPath == "" { + return errors.New("old path required") + } + // Lock post creation db.pcm.Lock() defer db.pcm.Unlock() - // Check if path is already in use - if o.new || (p.Path != o.oldPath) { - // Post is new or post path was changed - newPathExists := false - row, err := db.queryRow("select exists(select 1 from posts where path = @path)", sql.Named("path", p.Path)) - if err != nil { - return err - } - err = row.Scan(&newPathExists) - if err != nil { - return err - } - if newPathExists { - // New path already exists - return errors.New("post already exists at given path") - } - } // Build SQL var sqlBuilder strings.Builder - var sqlArgs []interface{} + var sqlArgs = []interface{}{dbNoCache} + // Start transaction + sqlBuilder.WriteString("begin;") // Delete old post if !o.new { sqlBuilder.WriteString("delete from posts where path = ?;") @@ -184,8 +173,13 @@ func (db *database) savePost(p *post, o *postCreationOptions) error { } } } + // Commit transaction + sqlBuilder.WriteString("commit;") // Execute - if _, err := db.execMulti(sqlBuilder.String(), sqlArgs...); err != nil { + if _, err := db.exec(sqlBuilder.String(), sqlArgs...); err != nil { + if strings.Contains(err.Error(), "UNIQUE constraint failed: posts.path") { + return errors.New("post already exists at given path") + } return err } // Update FTS index diff --git a/postsDb_test.go b/postsDb_test.go index b9490c9..a128f9f 100644 --- a/postsDb_test.go +++ b/postsDb_test.go @@ -150,6 +150,18 @@ func Test_postsDb(t *testing.T) { if is.NoError(err) { is.Equal(1, count) } + + // Check that post is already present + err = app.db.savePost(&post{ + Path: "/test/abc", + Content: "ABCD", + Published: "2021-06-10 10:00:00", + Updated: "2021-06-15 10:00:00", + Blog: "en", + Section: "test", + Status: statusPublished, + }, &postCreationOptions{new: true}) + must.Error(err) } func Test_ftsWithoutTitle(t *testing.T) {