Improve posts db

This commit is contained in:
Jan-Lukas Else 2021-07-03 12:11:57 +02:00
parent 9b0b20bd90
commit 85bf7ab711
11 changed files with 232 additions and 158 deletions

View File

@ -4,6 +4,7 @@ import (
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"path/filepath"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -13,6 +14,9 @@ import (
func Test_captchaMiddleware(t *testing.T) { func Test_captchaMiddleware(t *testing.T) {
app := &goBlog{ app := &goBlog{
cfg: &config{ cfg: &config{
Db: &configDb{
File: filepath.Join(t.TempDir(), "test.db"),
},
Server: &configServer{ Server: &configServer{
PublicAddress: "https://example.com", PublicAddress: "https://example.com",
}, },
@ -26,7 +30,7 @@ func Test_captchaMiddleware(t *testing.T) {
}, },
} }
app.setInMemoryDatabase() app.initDatabase(false)
app.initSessions() app.initSessions()
_ = app.initTemplateStrings() _ = app.initTemplateStrings()
_ = app.initRendering() _ = app.initRendering()

View File

@ -12,7 +12,7 @@ import (
) )
func (a *goBlog) checkAllExternalLinks() { func (a *goBlog) checkAllExternalLinks() {
allPosts, err := a.db.getPosts(&postsRequestConfig{status: statusPublished}) allPosts, err := a.db.getPosts(&postsRequestConfig{status: statusPublished, withoutParameters: true})
if err != nil { if err != nil {
log.Println(err.Error()) log.Println(err.Error())
return return

View File

@ -23,12 +23,16 @@ type database struct {
// Other things // Other things
pc singleflight.Group // persistant cache pc singleflight.Group // persistant cache
pcm sync.Mutex // post creation pcm sync.Mutex // post creation
sp singleflight.Group // singleflight group for short path requests
spc sync.Map // shortpath cache
} }
func (a *goBlog) initDatabase() (err error) { func (a *goBlog) initDatabase(logging bool) (err error) {
log.Println("Initialize database...") if logging {
log.Println("Initialize database...")
}
// Setup db // Setup db
db, err := a.openDatabase(a.cfg.Db.File, true) db, err := a.openDatabase(a.cfg.Db.File, logging)
if err != nil { if err != nil {
return err return err
} }
@ -47,7 +51,9 @@ func (a *goBlog) initDatabase() (err error) {
}) })
db.dump(a.cfg.Db.DumpFile) db.dump(a.cfg.Db.DumpFile)
} }
log.Println("Initialized database") if logging {
log.Println("Initialized database")
}
return nil return nil
} }

View File

@ -191,6 +191,34 @@ func migrateDb(db *sql.DB, logging bool) error {
return err return err
}, },
}, },
&migrator.Migration{
Name: "00017",
Func: func(tx *sql.Tx) error {
_, err := tx.Exec(`
create index index_post_parameters on post_parameters (path, parameter, value);
create index index_queue on queue (name, schedule);
drop index index_pp_path;
drop index index_queue_name;
drop index index_queue_schedule;
drop view view_posts_with_title;
create table posts_new (path text not null primary key, content text, published text, updated text, blog text not null, section text, status text not null, priority integer not null default 0);
insert into posts_new select *, 0 from posts;
drop table posts;
alter table posts_new rename to posts;
create view view_posts_with_title as select p.rowid as id, p.path as path, coalesce(pp.value, '') as title, content, published, updated, blog, section, status, priority from posts p left outer join (select * from post_parameters pp where pp.parameter = 'title') pp on p.path = pp.path;
drop table posts_fts;
create virtual table posts_fts using fts5(path unindexed, title, content, published unindexed, updated unindexed, blog unindexed, section unindexed, status unindexed, priority unindexed, content=view_posts_with_title, content_rowid=id);
insert into posts_fts(posts_fts) values ('rebuild');
create index index_posts_status on posts (status);
create index index_posts_blog on posts (blog);
create index index_posts_section on posts (section);
create index index_posts_published on posts (published);
create index index_posts_priority on posts (published);
drop trigger if exists trigger_posts_delete_pp;
`)
return err
},
},
), ),
) )
if err != nil { if err != nil {

View File

@ -2,14 +2,8 @@ package main
import ( import (
"testing" "testing"
"github.com/stretchr/testify/require"
) )
func (a *goBlog) setInMemoryDatabase() {
a.db, _ = a.openDatabase(":memory:", false)
}
func Test_database(t *testing.T) { func Test_database(t *testing.T) {
t.Run("Basic Database Test", func(t *testing.T) { t.Run("Basic Database Test", func(t *testing.T) {
app := &goBlog{ app := &goBlog{
@ -67,40 +61,3 @@ func Test_database(t *testing.T) {
} }
}) })
} }
func Test_parallelDatabase(t *testing.T) {
t.Run("Test parallel db access", func(t *testing.T) {
// Test that parallel database access works without problems
t.Parallel()
app := &goBlog{
cfg: &config{},
}
app.setInMemoryDatabase()
_, err := app.db.exec("create table test(test text);")
require.NoError(t, err)
t.Run("1", func(t *testing.T) {
for i := 0; i < 10000; i++ {
_, e := app.db.exec("insert into test (test) values ('Test')")
require.NoError(t, e)
}
})
t.Run("2", func(t *testing.T) {
for i := 0; i < 10000; i++ {
_, e := app.db.exec("insert into test (test) values ('Test')")
require.NoError(t, e)
}
})
t.Run("3", func(t *testing.T) {
for i := 0; i < 10000; i++ {
_, e := app.db.queryRow("select count(test) from test")
require.NoError(t, e)
}
})
})
}

View File

@ -86,7 +86,7 @@ func main() {
app.preStartHooks() app.preStartHooks()
// Initialize database and markdown // Initialize database and markdown
if err = app.initDatabase(); err != nil { if err = app.initDatabase(true); err != nil {
app.logErrAndQuit("Failed to init database:", err.Error()) app.logErrAndQuit("Failed to init database:", err.Error())
return return
} }

View File

@ -158,8 +158,8 @@ func (db *database) savePost(p *post, o *postCreationOptions) error {
sqlBuilder.WriteString("begin;") sqlBuilder.WriteString("begin;")
// Delete old post // Delete old post
if !o.new { if !o.new {
sqlBuilder.WriteString("delete from posts where path = ?;") sqlBuilder.WriteString("delete from posts where path = ?;delete from post_parameters where path = ?;")
sqlArgs = append(sqlArgs, o.oldPath) sqlArgs = append(sqlArgs, o.oldPath, o.oldPath)
} }
// Insert new post // Insert new post
sqlBuilder.WriteString("insert into posts (path, content, published, updated, blog, section, status) values (?, ?, ?, ?, ?, ?, ?);") sqlBuilder.WriteString("insert into posts (path, content, published, updated, blog, section, status) values (?, ?, ?, ?, ?, ?, ?);")
@ -204,7 +204,7 @@ func (db *database) deletePost(path string) (*post, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
_, err = db.exec("delete from posts where path = @path", sql.Named("path", p.Path)) _, err = db.exec("begin;delete from posts where path = ?;delete from post_parameters where path = ?;commit;", dbNoCache, p.Path, p.Path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -226,134 +226,134 @@ type postsRequestConfig struct {
parameterValue string parameterValue string
publishedYear, publishedMonth, publishedDay int publishedYear, publishedMonth, publishedDay int
randomOrder bool randomOrder bool
withoutParameters bool
} }
func buildPostsQuery(c *postsRequestConfig) (query string, args []interface{}) { func buildPostsQuery(c *postsRequestConfig, selection string) (query string, args []interface{}) {
args = []interface{}{} args = []interface{}{}
selection := "select p.path as path, coalesce(content, '') as content, coalesce(published, '') as published, coalesce(updated, '') as updated, coalesce(blog, '') as blog, coalesce(section, '') as section, coalesce(status, '') as status, coalesce(parameter, '') as parameter, coalesce(value, '') as value "
table := "posts" table := "posts"
if c.search != "" { if c.search != "" {
table = "posts_fts(@search)" table = "posts_fts(@search)"
args = append(args, sql.Named("search", c.search)) args = append(args, sql.Named("search", c.search))
} }
var wheres []string
if c.path != "" {
wheres = append(wheres, "path = @path")
args = append(args, sql.Named("path", c.path))
}
if c.status != "" && c.status != statusNil { if c.status != "" && c.status != statusNil {
table = "(select * from " + table + " where status = @status)" wheres = append(wheres, "status = @status")
args = append(args, sql.Named("status", c.status)) args = append(args, sql.Named("status", c.status))
} }
if c.blog != "" { if c.blog != "" {
table = "(select * from " + table + " where blog = @blog)" wheres = append(wheres, "blog = @blog")
args = append(args, sql.Named("blog", c.blog)) args = append(args, sql.Named("blog", c.blog))
} }
if c.parameter != "" { if c.parameter != "" {
table = "(select distinct p.* from " + table + " p left outer join post_parameters pp on p.path = pp.path where pp.parameter = @param "
args = append(args, sql.Named("param", c.parameter))
if c.parameterValue != "" { if c.parameterValue != "" {
table += "and pp.value = @paramval)" wheres = append(wheres, "path in (select path from post_parameters where parameter = @param and value = @paramval)")
args = append(args, sql.Named("paramval", c.parameterValue)) args = append(args, sql.Named("param", c.parameter), sql.Named("paramval", c.parameterValue))
} else { } else {
table += "and length(coalesce(pp.value, '')) > 1)" wheres = append(wheres, "path in (select path from post_parameters where parameter = @param and length(coalesce(value, '')) > 0)")
args = append(args, sql.Named("param", c.parameter))
} }
} }
if c.taxonomy != nil && len(c.taxonomyValue) > 0 { if c.taxonomy != nil && len(c.taxonomyValue) > 0 {
table = "(select distinct p.* from " + table + " p left outer join post_parameters pp on p.path = pp.path where pp.parameter = @taxname and lower(pp.value) = lower(@taxval))" wheres = append(wheres, "path in (select path from post_parameters where parameter = @taxname and lower(value) = lower(@taxval))")
args = append(args, sql.Named("taxname", c.taxonomy.Name), sql.Named("taxval", c.taxonomyValue)) args = append(args, sql.Named("taxname", c.taxonomy.Name), sql.Named("taxval", c.taxonomyValue))
} }
if len(c.sections) > 0 { if len(c.sections) > 0 {
table = "(select * from " + table + " where section in (" ws := "section in ("
for i, section := range c.sections { for i, section := range c.sections {
if i > 0 { if i > 0 {
table += ", " ws += ", "
} }
named := fmt.Sprintf("section%v", i) named := fmt.Sprintf("section%v", i)
table += "@" + named ws += "@" + named
args = append(args, sql.Named(named, section)) args = append(args, sql.Named(named, section))
} }
table += "))" ws += ")"
wheres = append(wheres, ws)
} }
if c.publishedYear != 0 { if c.publishedYear != 0 {
table = "(select * from " + table + " p where substr(p.published, 1, 4) = @publishedyear)" wheres = append(wheres, "substr(published, 1, 4) = @publishedyear")
args = append(args, sql.Named("publishedyear", fmt.Sprintf("%0004d", c.publishedYear))) args = append(args, sql.Named("publishedyear", fmt.Sprintf("%0004d", c.publishedYear)))
} }
if c.publishedMonth != 0 { if c.publishedMonth != 0 {
table = "(select * from " + table + " p where substr(p.published, 6, 2) = @publishedmonth)" wheres = append(wheres, "substr(published, 6, 2) = @publishedmonth")
args = append(args, sql.Named("publishedmonth", fmt.Sprintf("%02d", c.publishedMonth))) args = append(args, sql.Named("publishedmonth", fmt.Sprintf("%02d", c.publishedMonth)))
} }
if c.publishedDay != 0 { if c.publishedDay != 0 {
table = "(select * from " + table + " p where substr(p.published, 9, 2) = @publishedday)" wheres = append(wheres, "substr(published, 9, 2) = @publishedday")
args = append(args, sql.Named("publishedday", fmt.Sprintf("%02d", c.publishedDay))) args = append(args, sql.Named("publishedday", fmt.Sprintf("%02d", c.publishedDay)))
} }
tables := " from " + table + " p left outer join post_parameters pp on p.path = pp.path " if len(wheres) > 0 {
sorting := " order by p.published desc " table += " where " + strings.Join(wheres, " and ")
}
sorting := " order by published desc"
if c.randomOrder { if c.randomOrder {
sorting = " order by random() " sorting = " order by random()"
} }
if c.path != "" { table += sorting
query = selection + tables + " where p.path = @path" + sorting if c.limit != 0 || c.offset != 0 {
args = append(args, sql.Named("path", c.path)) table += " limit @limit offset @offset"
} else if c.limit != 0 || c.offset != 0 {
query = selection + " from (select * from " + table + " p " + sorting + " limit @limit offset @offset) p left outer join post_parameters pp on p.path = pp.path "
args = append(args, sql.Named("limit", c.limit), sql.Named("offset", c.offset)) args = append(args, sql.Named("limit", c.limit), sql.Named("offset", c.offset))
} else {
query = selection + tables + sorting
} }
return query = "select " + selection + " from " + table
return query, args
}
func (d *database) getPostParameters(path string) (params map[string][]string, err error) {
rows, err := d.query("select parameter, value from post_parameters where path = @path order by id", sql.Named("path", path))
if err != nil {
return nil, err
}
var name, value string
params = map[string][]string{}
for rows.Next() {
if err = rows.Scan(&name, &value); err != nil {
return nil, err
}
params[name] = append(params[name], value)
}
return params, nil
} }
func (d *database) getPosts(config *postsRequestConfig) (posts []*post, err error) { func (d *database) getPosts(config *postsRequestConfig) (posts []*post, err error) {
// Query posts // Query posts
query, queryParams := buildPostsQuery(config) query, queryParams := buildPostsQuery(config, "path, coalesce(content, ''), coalesce(published, ''), coalesce(updated, ''), blog, coalesce(section, ''), status")
rows, err := d.query(query, queryParams...) rows, err := d.query(query, queryParams...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Prepare row scanning (this is a bit dirty, but it's much faster) // Prepare row scanning
postsMap := map[string]*post{} var path, content, published, updated, blog, section, status string
var postsOrder []string
var path, parameterName, parameterValue string
columns, _ := rows.Columns()
rawBuffer := make([]sql.RawBytes, len(columns))
scanArgs := make([]interface{}, len(columns))
for i := range rawBuffer {
scanArgs[i] = &rawBuffer[i]
}
for rows.Next() { for rows.Next() {
if err = rows.Scan(scanArgs...); err != nil { if err = rows.Scan(&path, &content, &published, &updated, &blog, &section, &status); err != nil {
return nil, err return nil, err
} }
path = string(rawBuffer[0]) // Create new post, fill and add to list
parameterName = string(rawBuffer[7]) p := &post{
parameterValue = string(rawBuffer[8]) Path: path,
if p, ok := postsMap[path]; ok { Content: content,
// Post already exists, add parameter Published: toLocalSafe(published),
p.Parameters[parameterName] = append(p.Parameters[parameterName], parameterValue) Updated: toLocalSafe(updated),
} else { Blog: blog,
// Create new post, fill and add to map Section: section,
p := &post{ Status: postStatus(status),
Path: path,
Content: string(rawBuffer[1]),
Published: toLocalSafe(string(rawBuffer[2])),
Updated: toLocalSafe(string(rawBuffer[3])),
Blog: string(rawBuffer[4]),
Section: string(rawBuffer[5]),
Status: postStatus(string(rawBuffer[6])),
Parameters: map[string][]string{},
}
if parameterName != "" {
p.Parameters[parameterName] = append(p.Parameters[parameterName], parameterValue)
}
postsMap[path] = p
postsOrder = append(postsOrder, path)
} }
} if !config.withoutParameters {
// Copy map items to list, because map has a random order if p.Parameters, err = d.getPostParameters(path); err != nil {
for _, path = range postsOrder { return nil, err
posts = append(posts, postsMap[path]) }
}
posts = append(posts, p)
} }
return posts, nil return posts, nil
} }
func (d *database) getPost(path string) (*post, error) { func (d *database) getPost(path string) (*post, error) {
posts, err := d.getPosts(&postsRequestConfig{path: path}) posts, err := d.getPosts(&postsRequestConfig{path: path, limit: 1})
if err != nil { if err != nil {
return nil, err return nil, err
} else if len(posts) == 0 { } else if len(posts) == 0 {
@ -368,9 +368,8 @@ func (d *database) getDrafts(blog string) []*post {
} }
func (d *database) countPosts(config *postsRequestConfig) (count int, err error) { func (d *database) countPosts(config *postsRequestConfig) (count int, err error) {
query, params := buildPostsQuery(config) query, params := buildPostsQuery(config, "path")
query = "select count(distinct path) from (" + query + ")" row, err := d.queryRow("select count(distinct path) from ("+query+")", params...)
row, err := d.queryRow(query, params...)
if err != nil { if err != nil {
return return
} }
@ -394,29 +393,36 @@ func (d *database) allPostPaths(status postStatus) ([]string, error) {
return postPaths, nil return postPaths, nil
} }
func (a *goBlog) getRandomPostPath(blog string) (string, error) { func (a *goBlog) getRandomPostPath(blog string) (path string, err error) {
sections, ok := funk.Keys(a.cfg.Blogs[blog].Sections).([]string) sections, ok := funk.Keys(a.cfg.Blogs[blog].Sections).([]string)
if !ok { if !ok {
return "", errors.New("no sections") return "", errors.New("no sections")
} }
posts, err := a.db.getPosts(&postsRequestConfig{randomOrder: true, limit: 1, blog: blog, sections: sections}) query, params := buildPostsQuery(&postsRequestConfig{randomOrder: true, limit: 1, blog: blog, sections: sections}, "path")
row, err := a.db.queryRow(query, params...)
if err != nil { if err != nil {
return "", err return
} else if len(posts) == 0 {
return "", errPostNotFound
} }
return posts[0].Path, nil err = row.Scan(&path)
if errors.Is(err, sql.ErrNoRows) {
return "", errPostNotFound
} else if err != nil {
return "", err
}
return path, nil
} }
func (d *database) allTaxonomyValues(blog string, taxonomy string) ([]string, error) { func (d *database) allTaxonomyValues(blog string, taxonomy string) ([]string, error) {
var values []string var values []string
rows, err := d.query("select distinct pp.value from posts p left outer join post_parameters pp on p.path = pp.path where pp.parameter = @tax and length(coalesce(pp.value, '')) > 1 and blog = @blog and status = @status", sql.Named("tax", taxonomy), sql.Named("blog", blog), sql.Named("status", statusPublished)) rows, err := d.query("select distinct value from post_parameters where parameter = @tax and length(coalesce(value, '')) > 0 and path in (select path from posts where blog = @blog and status = @status) order by value", sql.Named("tax", taxonomy), sql.Named("blog", blog), sql.Named("status", statusPublished))
if err != nil { if err != nil {
return nil, err return nil, err
} }
var value string
for rows.Next() { for rows.Next() {
var value string if err = rows.Scan(&value); err != nil {
_ = rows.Scan(&value) return nil, err
}
values = append(values, value) values = append(values, value)
} }
return values, nil return values, nil

View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"path/filepath"
"testing" "testing"
"time" "time"
@ -14,6 +15,9 @@ func Test_postsDb(t *testing.T) {
app := &goBlog{ app := &goBlog{
cfg: &config{ cfg: &config{
Db: &configDb{
File: filepath.Join(t.TempDir(), "test.db"),
},
Blogs: map[string]*configBlog{ Blogs: map[string]*configBlog{
"en": { "en": {
Sections: map[string]*section{ Sections: map[string]*section{
@ -23,7 +27,7 @@ func Test_postsDb(t *testing.T) {
}, },
}, },
} }
app.setInMemoryDatabase() app.initDatabase(false)
now := toLocalSafe(time.Now().String()) now := toLocalSafe(time.Now().String())
nowPlus1Hour := toLocalSafe(time.Now().Add(1 * time.Hour).String()) nowPlus1Hour := toLocalSafe(time.Now().Add(1 * time.Hour).String())
@ -39,13 +43,15 @@ func Test_postsDb(t *testing.T) {
Status: statusDraft, Status: statusDraft,
Parameters: map[string][]string{ Parameters: map[string][]string{
"title": {"Title"}, "title": {"Title"},
"tags": {"C", "A", "B"},
"empty": {},
}, },
}, &postCreationOptions{new: true}) }, &postCreationOptions{new: true})
must.NoError(err) must.NoError(err)
// Check post // Check post
p, err := app.db.getPost("/test/abc") p, err := app.db.getPost("/test/abc")
is.NoError(err) must.NoError(err)
is.Equal("/test/abc", p.Path) is.Equal("/test/abc", p.Path)
is.Equal("ABC", p.Content) is.Equal("ABC", p.Content)
is.Equal(now, p.Published) is.Equal(now, p.Published)
@ -54,29 +60,38 @@ func Test_postsDb(t *testing.T) {
is.Equal("test", p.Section) is.Equal("test", p.Section)
is.Equal(statusDraft, p.Status) is.Equal(statusDraft, p.Status)
is.Equal("Title", p.Title()) is.Equal("Title", p.Title())
is.Equal([]string{"C", "A", "B"}, p.Parameters["tags"])
// Check number of post paths // Check number of post paths
pp, err := app.db.allPostPaths(statusDraft) pp, err := app.db.allPostPaths(statusDraft)
is.NoError(err) must.NoError(err)
if is.Len(pp, 1) { if is.Len(pp, 1) {
is.Equal("/test/abc", pp[0]) is.Equal("/test/abc", pp[0])
} }
pp, err = app.db.allPostPaths(statusPublished) pp, err = app.db.allPostPaths(statusPublished)
is.NoError(err) must.NoError(err)
is.Len(pp, 0) is.Len(pp, 0)
// Check drafts // Check drafts
drafts := app.db.getDrafts("en") drafts := app.db.getDrafts("en")
is.Len(drafts, 1) is.Len(drafts, 1)
// Check by parameter
count, err := app.db.countPosts(&postsRequestConfig{parameter: "tags"})
must.NoError(err)
is.Equal(1, count)
count, err = app.db.countPosts(&postsRequestConfig{parameter: "empty"})
must.NoError(err)
is.Equal(0, count)
// Delete post // Delete post
_, err = app.db.deletePost("/test/abc") _, err = app.db.deletePost("/test/abc")
must.NoError(err) must.NoError(err)
// Check that there is no post // Check that there is no post
count, err := app.db.countPosts(&postsRequestConfig{}) count, err = app.db.countPosts(&postsRequestConfig{})
is.NoError(err) must.NoError(err)
is.Equal(0, count) is.Equal(0, count)
// Save published post // Save published post
@ -89,16 +104,20 @@ func Test_postsDb(t *testing.T) {
Section: "test", Section: "test",
Status: statusPublished, Status: statusPublished,
Parameters: map[string][]string{ Parameters: map[string][]string{
"tags": {"Test", "Blog"}, "tags": {"Test", "Blog", "A"},
}, },
}, &postCreationOptions{new: true}) }, &postCreationOptions{new: true})
must.NoError(err) must.NoError(err)
// Check that there is a new post // Check that there is a new post
count, err = app.db.countPosts(&postsRequestConfig{}) count, err = app.db.countPosts(&postsRequestConfig{})
if is.NoError(err) { must.NoError(err)
is.Equal(1, count) is.Equal(1, count)
}
// Check based on offset
count, err = app.db.countPosts(&postsRequestConfig{limit: 10, offset: 1})
must.NoError(err)
is.Equal(0, count)
// Check random post path // Check random post path
rp, err := app.getRandomPostPath("en") rp, err := app.getRandomPostPath("en")
@ -109,8 +128,8 @@ func Test_postsDb(t *testing.T) {
// Check taxonomies // Check taxonomies
tags, err := app.db.allTaxonomyValues("en", "tags") tags, err := app.db.allTaxonomyValues("en", "tags")
if is.NoError(err) { if is.NoError(err) {
is.Len(tags, 2) is.Len(tags, 3)
is.Equal([]string{"Test", "Blog"}, tags) is.Equal([]string{"A", "Blog", "Test"}, tags)
} }
// Check based on date // Check based on date
@ -128,6 +147,34 @@ func Test_postsDb(t *testing.T) {
is.Equal(1, count) is.Equal(1, count)
} }
count, err = app.db.countPosts(&postsRequestConfig{
publishedMonth: 5,
})
if is.NoError(err) {
is.Equal(0, count)
}
count, err = app.db.countPosts(&postsRequestConfig{
publishedMonth: 6,
})
if is.NoError(err) {
is.Equal(1, count)
}
count, err = app.db.countPosts(&postsRequestConfig{
publishedDay: 15,
})
if is.NoError(err) {
is.Equal(0, count)
}
count, err = app.db.countPosts(&postsRequestConfig{
publishedDay: 10,
})
if is.NoError(err) {
is.Equal(1, count)
}
// Check dates // Check dates
dates, err := app.db.allPublishedDates("en") dates, err := app.db.allPublishedDates("en")
if is.NoError(err) && is.NotEmpty(dates) { if is.NoError(err) && is.NotEmpty(dates) {
@ -168,9 +215,13 @@ func Test_ftsWithoutTitle(t *testing.T) {
// Added because there was a bug where there were no search results without title // Added because there was a bug where there were no search results without title
app := &goBlog{ app := &goBlog{
cfg: &config{}, cfg: &config{
Db: &configDb{
File: filepath.Join(t.TempDir(), "test.db"),
},
},
} }
app.setInMemoryDatabase() app.initDatabase(false)
err := app.db.savePost(&post{ err := app.db.savePost(&post{
Path: "/test/abc", Path: "/test/abc",
@ -192,9 +243,13 @@ func Test_ftsWithoutTitle(t *testing.T) {
func Test_usesOfMediaFile(t *testing.T) { func Test_usesOfMediaFile(t *testing.T) {
app := &goBlog{ app := &goBlog{
cfg: &config{}, cfg: &config{
Db: &configDb{
File: filepath.Join(t.TempDir(), "test.db"),
},
},
} }
app.setInMemoryDatabase() app.initDatabase(false)
err := app.db.savePost(&post{ err := app.db.savePost(&post{
Path: "/test/abc", Path: "/test/abc",

View File

@ -14,14 +14,21 @@ func (db *database) shortenPath(p string) (string, error) {
if p == "" { if p == "" {
return "", errors.New("empty path") return "", errors.New("empty path")
} }
id := db.getShortPathID(p) idi, err, _ := db.sp.Do(p, func() (interface{}, error) {
if id == -1 { id := db.getShortPathID(p)
_, err := db.exec("insert or ignore into shortpath (path) values (@path)", sql.Named("path", p)) if id == -1 {
if err != nil { _, err := db.exec("insert or ignore into shortpath (path) values (@path)", sql.Named("path", p))
return "", err if err != nil {
return nil, err
}
id = db.getShortPathID(p)
} }
id = db.getShortPathID(p) return id, nil
})
if err != nil {
return "", err
} }
id := idi.(int)
if id == -1 { if id == -1 {
return "", errors.New("failed to retrieve short path for " + p) return "", errors.New("failed to retrieve short path for " + p)
} }
@ -32,6 +39,9 @@ func (db *database) getShortPathID(p string) (id int) {
if p == "" { if p == "" {
return -1 return -1
} }
if idi, ok := db.spc.Load(p); ok {
return idi.(int)
}
row, err := db.queryRow("select id from shortpath where path = @path", sql.Named("path", p)) row, err := db.queryRow("select id from shortpath where path = @path", sql.Named("path", p))
if err != nil { if err != nil {
return -1 return -1
@ -40,6 +50,7 @@ func (db *database) getShortPathID(p string) (id int) {
if err != nil { if err != nil {
return -1 return -1
} }
db.spc.Store(p, id)
return id return id
} }

View File

@ -132,7 +132,7 @@ func (a *goBlog) serveSitemap(w http.ResponseWriter, r *http.Request) {
} }
} }
// Posts // Posts
if posts, err := a.db.getPosts(&postsRequestConfig{status: statusPublished}); err == nil { if posts, err := a.db.getPosts(&postsRequestConfig{status: statusPublished, withoutParameters: true}); err == nil {
for _, p := range posts { for _, p := range posts {
item := &sitemap.URL{Loc: a.fullPostURL(p)} item := &sitemap.URL{Loc: a.fullPostURL(p)}
var lastMod time.Time var lastMod time.Time

View File

@ -2,6 +2,7 @@ package main
import ( import (
"net/http" "net/http"
"path/filepath"
"testing" "testing"
"time" "time"
@ -114,6 +115,9 @@ func Test_telegram(t *testing.T) {
app := &goBlog{ app := &goBlog{
pPostHooks: []postHookFunc{}, pPostHooks: []postHookFunc{},
cfg: &config{ cfg: &config{
Db: &configDb{
File: filepath.Join(t.TempDir(), "test.db"),
},
Server: &configServer{ Server: &configServer{
PublicAddress: "https://example.com", PublicAddress: "https://example.com",
}, },
@ -129,7 +133,7 @@ func Test_telegram(t *testing.T) {
}, },
httpClient: fakeClient, httpClient: fakeClient,
} }
app.setInMemoryDatabase() app.initDatabase(false)
app.initTelegram() app.initTelegram()
@ -161,6 +165,9 @@ func Test_telegram(t *testing.T) {
app := &goBlog{ app := &goBlog{
pPostHooks: []postHookFunc{}, pPostHooks: []postHookFunc{},
cfg: &config{ cfg: &config{
Db: &configDb{
File: filepath.Join(t.TempDir(), "test.db"),
},
Server: &configServer{ Server: &configServer{
PublicAddress: "https://example.com", PublicAddress: "https://example.com",
}, },
@ -170,7 +177,7 @@ func Test_telegram(t *testing.T) {
}, },
httpClient: fakeClient, httpClient: fakeClient,
} }
app.setInMemoryDatabase() app.initDatabase(false)
app.initTelegram() app.initTelegram()