Small fixes and higher test coverage

This commit is contained in:
Jan-Lukas Else 2021-06-14 16:29:22 +02:00
parent b83c09d5ac
commit ca99b726a8
10 changed files with 244 additions and 29 deletions

View File

@ -174,7 +174,7 @@ func (a *goBlog) apHandleInbox(w http.ResponseWriter, r *http.Request) {
_ = a.createWebmention(id, inReplyTo) _ = a.createWebmention(id, inReplyTo)
} else if content, hasContent := object["content"].(string); hasContent && hasID && len(id) > 0 { } else if content, hasContent := object["content"].(string); hasContent && hasID && len(id) > 0 {
// May be a mention; find links to blog and save them as webmentions // May be a mention; find links to blog and save them as webmentions
if links, err := allLinksFromHTML(strings.NewReader(content), id); err == nil { if links, err := allLinksFromHTMLString(content, id); err == nil {
for _, link := range links { for _, link := range links {
if strings.Contains(link, blogIri) { if strings.Contains(link, blogIri) {
_ = a.createWebmention(id, link) _ = a.createWebmention(id, link)

View File

@ -90,7 +90,7 @@ func (a *goBlog) getExternalLinks(posts []*post, linkChan chan<- stringPair) err
wg.Add(1) wg.Add(1)
go func(p *post) { go func(p *post) {
defer wg.Done() defer wg.Done()
links, _ := allLinksFromHTML(strings.NewReader(string(a.absoluteHTML(p))), a.fullPostURL(p)) links, _ := allLinksFromHTMLString(string(a.absoluteHTML(p)), a.fullPostURL(p))
for _, link := range links { for _, link := range links {
linkChan <- stringPair{a.fullPostURL(p), link} linkChan <- stringPair{a.fullPostURL(p), link}
} }

View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"log" "log"
@ -12,21 +13,22 @@ import (
) )
type database struct { type database struct {
db *sql.DB db *sql.DB
stmts map[string]*sql.Stmt c context.Context
g singleflight.Group cf context.CancelFunc
persistentCacheGroup singleflight.Group stmts map[string]*sql.Stmt
g singleflight.Group
pc singleflight.Group
} }
func (a *goBlog) initDatabase() (err error) { func (a *goBlog) initDatabase() (err error) {
// Setup db // Setup db
db, err := a.openDatabase(a.cfg.Db.File) db, err := a.openDatabase(a.cfg.Db.File, true)
if err != nil { if err != nil {
return err return err
} }
// Create appDB // Create appDB
a.db = db a.db = db
db.vacuum()
addShutdownFunc(func() { addShutdownFunc(func() {
_ = db.close() _ = db.close()
log.Println("Closed database") log.Println("Closed database")
@ -40,7 +42,7 @@ func (a *goBlog) initDatabase() (err error) {
return nil return nil
} }
func (a *goBlog) openDatabase(file string) (*database, error) { func (a *goBlog) openDatabase(file string, logging bool) (*database, error) {
// Register driver // Register driver
dbDriverName := generateRandomString(15) dbDriverName := generateRandomString(15)
sql.Register("goblog_db_"+dbDriverName, &sqlite.SQLiteDriver{ sql.Register("goblog_db_"+dbDriverName, &sqlite.SQLiteDriver{
@ -85,13 +87,16 @@ func (a *goBlog) openDatabase(file string) (*database, error) {
return nil, errors.New("sqlite not compiled with FTS5") return nil, errors.New("sqlite not compiled with FTS5")
} }
// Migrate DB // Migrate DB
err = migrateDb(db) err = migrateDb(db, logging)
if err != nil { if err != nil {
return nil, err return nil, err
} }
c, cf := context.WithCancel(context.Background())
return &database{ return &database{
db: db, db: db,
stmts: map[string]*sql.Stmt{}, stmts: map[string]*sql.Stmt{},
c: c,
cf: cf,
}, nil }, nil
} }
@ -109,21 +114,17 @@ func (db *database) dump(file string) {
} }
func (db *database) close() error { func (db *database) close() error {
db.vacuum() db.cf()
return db.db.Close() return db.db.Close()
} }
func (db *database) vacuum() {
_, _ = db.exec("VACUUM")
}
func (db *database) prepare(query string) (*sql.Stmt, error) { func (db *database) prepare(query string) (*sql.Stmt, error) {
stmt, err, _ := db.g.Do(query, func() (interface{}, error) { stmt, err, _ := db.g.Do(query, func() (interface{}, error) {
stmt, ok := db.stmts[query] stmt, ok := db.stmts[query]
if ok && stmt != nil { if ok && stmt != nil {
return stmt, nil return stmt, nil
} }
stmt, err := db.db.Prepare(query) stmt, err := db.db.PrepareContext(db.c, query)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -141,12 +142,12 @@ func (db *database) exec(query string, args ...interface{}) (sql.Result, error)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return stmt.Exec(args...) return stmt.ExecContext(db.c, args...)
} }
func (db *database) execMulti(query string, args ...interface{}) (sql.Result, error) { func (db *database) execMulti(query string, args ...interface{}) (sql.Result, error) {
// Can't prepare the statement // Can't prepare the statement
return db.db.Exec(query, args...) return db.db.ExecContext(db.c, query, args...)
} }
func (db *database) query(query string, args ...interface{}) (*sql.Rows, error) { func (db *database) query(query string, args ...interface{}) (*sql.Rows, error) {
@ -154,7 +155,7 @@ func (db *database) query(query string, args ...interface{}) (*sql.Rows, error)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return stmt.Query(args...) return stmt.QueryContext(db.c, args...)
} }
func (db *database) queryRow(query string, args ...interface{}) (*sql.Row, error) { func (db *database) queryRow(query string, args ...interface{}) (*sql.Row, error) {
@ -162,7 +163,7 @@ func (db *database) queryRow(query string, args ...interface{}) (*sql.Row, error
if err != nil { if err != nil {
return nil, err return nil, err
} }
return stmt.QueryRow(args...), nil return stmt.QueryRowContext(db.c, args...), nil
} }
// Other things // Other things

View File

@ -7,10 +7,12 @@ import (
"github.com/lopezator/migrator" "github.com/lopezator/migrator"
) )
func migrateDb(db *sql.DB) error { func migrateDb(db *sql.DB, logging bool) error {
m, err := migrator.New( m, err := migrator.New(
migrator.WithLogger(migrator.LoggerFunc(func(s string, i ...interface{}) { migrator.WithLogger(migrator.LoggerFunc(func(s string, i ...interface{}) {
log.Printf(s, i) if logging {
log.Printf(s, i)
}
})), })),
migrator.Migrations( migrator.Migrations(
&migrator.Migration{ &migrator.Migration{

61
database_test.go Normal file
View File

@ -0,0 +1,61 @@
package main
import (
"testing"
)
func Test_database(t *testing.T) {
t.Run("Basic Database Test", func(t *testing.T) {
app := &goBlog{}
db, err := app.openDatabase(":memory:", false)
if err != nil {
t.Errorf("Error: %v", err)
}
_, err = db.execMulti("create table test(test text);")
if err != nil {
t.Errorf("Error: %v", err)
}
_, err = db.exec("insert into test (test) values ('Test')")
if err != nil {
t.Errorf("Error: %v", err)
}
row, err := db.queryRow("select count(test) from test")
if err != nil {
t.Errorf("Error: %v", err)
}
var test1 int
err = row.Scan(&test1)
if err != nil {
t.Errorf("Error: %v", err)
}
if test1 != 1 {
t.Error("Wrong result")
}
rows, err := db.query("select count(test), test from test")
if err != nil {
t.Errorf("Error: %v", err)
}
var test2 int
var testStr string
if !rows.Next() {
t.Error("No result row")
}
err = rows.Scan(&test2, &testStr)
if err != nil {
t.Errorf("Error: %v", err)
}
if test2 != 1 || testStr != "Test" {
t.Error("Wrong result")
}
err = db.close()
if err != nil {
t.Errorf("Error: %v", err)
}
})
}

View File

@ -2,6 +2,7 @@ package main
import ( import (
"bytes" "bytes"
"strings"
marktag "git.jlel.se/jlelse/goldmark-mark" marktag "git.jlel.se/jlelse/goldmark-mark"
"github.com/PuerkitoBio/goquery" "github.com/PuerkitoBio/goquery"
@ -62,7 +63,7 @@ func (a *goBlog) renderText(s string) string {
if err != nil { if err != nil {
return "" return ""
} }
return d.Text() return strings.TrimSpace(d.Text())
} }
// Extensions etc... // Extensions etc...
@ -104,7 +105,7 @@ func (c *customRenderer) renderLink(w util.BufWriter, _ []byte, node ast.Node, e
_, _ = w.Write(util.EscapeHTML(newDestination)) _, _ = w.Write(util.EscapeHTML(newDestination))
_, _ = w.WriteRune('"') _, _ = w.WriteRune('"')
// Open external links (links that start with "http") in new tab // Open external links (links that start with "http") in new tab
if bytes.HasPrefix(n.Destination, []byte("http")) { if isAbsoluteURL(string(n.Destination)) {
_, _ = w.WriteString(` target="_blank" rel="noopener"`) _, _ = w.WriteString(` target="_blank" rel="noopener"`)
} }
// Title // Title

68
markdown_test.go Normal file
View File

@ -0,0 +1,68 @@
package main
import (
"strings"
"testing"
)
func Test_markdown(t *testing.T) {
t.Run("Basic Markdown tests", func(t *testing.T) {
app := &goBlog{
cfg: &config{
Server: &configServer{
PublicAddress: "https://example.com",
},
},
}
app.initMarkdown()
// Relative / absolute links
rendered, err := app.renderMarkdown("[Relative](/relative)", false)
if err != nil {
t.Errorf("Error: %v", err)
}
if !strings.Contains(string(rendered), `href="/relative"`) {
t.Errorf("Wrong result, got %v", string(rendered))
}
rendered, err = app.renderMarkdown("[Relative](/relative)", true)
if err != nil {
t.Errorf("Error: %v", err)
}
if !strings.Contains(string(rendered), `href="https://example.com/relative"`) {
t.Errorf("Wrong result, got %v", string(rendered))
}
if strings.Contains(string(rendered), `target="_blank"`) {
t.Errorf("Wrong result, got %v", string(rendered))
}
// External links
rendered, err = app.renderMarkdown("[External](https://example.com)", true)
if err != nil {
t.Errorf("Error: %v", err)
}
if !strings.Contains(string(rendered), `target="_blank"`) {
t.Errorf("Wrong result, got %v", string(rendered))
}
// Link title
rendered, err = app.renderMarkdown(`[With title](https://example.com "Test-Title")`, true)
if err != nil {
t.Errorf("Error: %v", err)
}
if !strings.Contains(string(rendered), `title="Test-Title"`) {
t.Errorf("Wrong result, got %v", string(rendered))
}
// Text
renderedText := app.renderText("**This** *is* [text](/)")
if renderedText != "This is text" {
t.Errorf("Wrong result, got \"%v\"", renderedText)
}
})
}

View File

@ -12,7 +12,7 @@ func (db *database) cachePersistently(key string, data []byte) error {
} }
func (db *database) retrievePersistentCache(key string) (data []byte, err error) { func (db *database) retrievePersistentCache(key string) (data []byte, err error) {
d, err, _ := db.persistentCacheGroup.Do(key, func() (interface{}, error) { d, err, _ := db.pc.Do(key, func() (interface{}, error) {
if row, err := db.queryRow("select data from persistent_cache where key = @key", sql.Named("key", key)); err == sql.ErrNoRows { if row, err := db.queryRow("select data from persistent_cache where key = @key", sql.Named("key", key)); err == sql.ErrNoRows {
return nil, nil return nil, nil
} else if err != nil { } else if err != nil {

View File

@ -59,15 +59,16 @@ func isAllowedHost(r *http.Request, hosts ...string) bool {
} }
func isAbsoluteURL(s string) bool { func isAbsoluteURL(s string) bool {
if !strings.HasPrefix(s, "https://") && !strings.HasPrefix(s, "http://") { if u, err := url.Parse(s); err != nil || !u.IsAbs() {
return false
}
if _, err := url.Parse(s); err != nil {
return false return false
} }
return true return true
} }
func allLinksFromHTMLString(html, baseURL string) ([]string, error) {
return allLinksFromHTML(strings.NewReader(html), baseURL)
}
func allLinksFromHTML(r io.Reader, baseURL string) ([]string, error) { func allLinksFromHTML(r io.Reader, baseURL string) ([]string, error) {
doc, err := goquery.NewDocumentFromReader(r) doc, err := goquery.NewDocumentFromReader(r)
if err != nil { if err != nil {

81
utils_test.go Normal file
View File

@ -0,0 +1,81 @@
package main
import (
"net/http"
"net/http/httptest"
"reflect"
"testing"
)
func Test_urlize(t *testing.T) {
if res := urlize("äbc ef"); res != "bc-ef" {
t.Errorf("Wrong result, got: %v", res)
}
}
func Test_sortedStrings(t *testing.T) {
input := []string{"a", "c", "b"}
if res := sortedStrings(input); !reflect.DeepEqual(res, []string{"a", "b", "c"}) {
t.Errorf("Wrong result, got: %v", res)
}
}
func Test_generateRandomString(t *testing.T) {
if l := len(generateRandomString(30)); l != 30 {
t.Errorf("Wrong length: %v", l)
}
}
func Test_isAllowedHost(t *testing.T) {
req1 := httptest.NewRequest(http.MethodGet, "https://example.com", nil)
req2 := httptest.NewRequest(http.MethodGet, "https://example.com:443", nil)
req3 := httptest.NewRequest(http.MethodGet, "http://example.com:80", nil)
if isAllowedHost(req1, "example.com") != true {
t.Error("Wrong result")
}
if isAllowedHost(req1, "example.net") != false {
t.Error("Wrong result")
}
if isAllowedHost(req2, "example.com") != true {
t.Error("Wrong result")
}
if isAllowedHost(req3, "example.com") != true {
t.Error("Wrong result")
}
}
func Test_isAbsoluteURL(t *testing.T) {
if isAbsoluteURL("http://example.com") != true {
t.Error("Wrong result")
}
if isAbsoluteURL("https://example.com") != true {
t.Error("Wrong result")
}
if isAbsoluteURL("/test") != false {
t.Error("Wrong result")
}
}
func Test_wordCount(t *testing.T) {
if wordCount("abc def abc") != 3 {
t.Error("Wrong result")
}
}
func Test_allLinksFromHTMLString(t *testing.T) {
baseUrl := "https://example.net/post/abc"
html := `<a href="relative1">Test</a><a href="relative1">Test</a><a href="/relative2">Test</a><a href="https://example.com">Test</a>`
expected := []string{"https://example.net/post/relative1", "https://example.net/relative2", "https://example.com"}
if result, err := allLinksFromHTMLString(html, baseUrl); err != nil {
t.Errorf("Got error: %v", err)
} else if !reflect.DeepEqual(result, expected) {
t.Errorf("Wrong result, got: %v", result)
}
}