Small fixes and higher test coverage

This commit is contained in:
Jan-Lukas Else 2021-06-14 16:29:22 +02:00
parent b83c09d5ac
commit ca99b726a8
10 changed files with 244 additions and 29 deletions

View File

@ -174,7 +174,7 @@ func (a *goBlog) apHandleInbox(w http.ResponseWriter, r *http.Request) {
_ = a.createWebmention(id, inReplyTo)
} else if content, hasContent := object["content"].(string); hasContent && hasID && len(id) > 0 {
// May be a mention; find links to blog and save them as webmentions
if links, err := allLinksFromHTML(strings.NewReader(content), id); err == nil {
if links, err := allLinksFromHTMLString(content, id); err == nil {
for _, link := range links {
if strings.Contains(link, blogIri) {
_ = a.createWebmention(id, link)

View File

@ -90,7 +90,7 @@ func (a *goBlog) getExternalLinks(posts []*post, linkChan chan<- stringPair) err
wg.Add(1)
go func(p *post) {
defer wg.Done()
links, _ := allLinksFromHTML(strings.NewReader(string(a.absoluteHTML(p))), a.fullPostURL(p))
links, _ := allLinksFromHTMLString(string(a.absoluteHTML(p)), a.fullPostURL(p))
for _, link := range links {
linkChan <- stringPair{a.fullPostURL(p), link}
}

View File

@ -1,6 +1,7 @@
package main
import (
"context"
"database/sql"
"errors"
"log"
@ -12,21 +13,22 @@ import (
)
type database struct {
db *sql.DB
stmts map[string]*sql.Stmt
g singleflight.Group
persistentCacheGroup singleflight.Group
db *sql.DB
c context.Context
cf context.CancelFunc
stmts map[string]*sql.Stmt
g singleflight.Group
pc singleflight.Group
}
func (a *goBlog) initDatabase() (err error) {
// Setup db
db, err := a.openDatabase(a.cfg.Db.File)
db, err := a.openDatabase(a.cfg.Db.File, true)
if err != nil {
return err
}
// Create appDB
a.db = db
db.vacuum()
addShutdownFunc(func() {
_ = db.close()
log.Println("Closed database")
@ -40,7 +42,7 @@ func (a *goBlog) initDatabase() (err error) {
return nil
}
func (a *goBlog) openDatabase(file string) (*database, error) {
func (a *goBlog) openDatabase(file string, logging bool) (*database, error) {
// Register driver
dbDriverName := generateRandomString(15)
sql.Register("goblog_db_"+dbDriverName, &sqlite.SQLiteDriver{
@ -85,13 +87,16 @@ func (a *goBlog) openDatabase(file string) (*database, error) {
return nil, errors.New("sqlite not compiled with FTS5")
}
// Migrate DB
err = migrateDb(db)
err = migrateDb(db, logging)
if err != nil {
return nil, err
}
c, cf := context.WithCancel(context.Background())
return &database{
db: db,
stmts: map[string]*sql.Stmt{},
c: c,
cf: cf,
}, nil
}
@ -109,21 +114,17 @@ func (db *database) dump(file string) {
}
func (db *database) close() error {
db.vacuum()
db.cf()
return db.db.Close()
}
func (db *database) vacuum() {
_, _ = db.exec("VACUUM")
}
func (db *database) prepare(query string) (*sql.Stmt, error) {
stmt, err, _ := db.g.Do(query, func() (interface{}, error) {
stmt, ok := db.stmts[query]
if ok && stmt != nil {
return stmt, nil
}
stmt, err := db.db.Prepare(query)
stmt, err := db.db.PrepareContext(db.c, query)
if err != nil {
return nil, err
}
@ -141,12 +142,12 @@ func (db *database) exec(query string, args ...interface{}) (sql.Result, error)
if err != nil {
return nil, err
}
return stmt.Exec(args...)
return stmt.ExecContext(db.c, args...)
}
func (db *database) execMulti(query string, args ...interface{}) (sql.Result, error) {
// Can't prepare the statement
return db.db.Exec(query, args...)
return db.db.ExecContext(db.c, query, args...)
}
func (db *database) query(query string, args ...interface{}) (*sql.Rows, error) {
@ -154,7 +155,7 @@ func (db *database) query(query string, args ...interface{}) (*sql.Rows, error)
if err != nil {
return nil, err
}
return stmt.Query(args...)
return stmt.QueryContext(db.c, args...)
}
func (db *database) queryRow(query string, args ...interface{}) (*sql.Row, error) {
@ -162,7 +163,7 @@ func (db *database) queryRow(query string, args ...interface{}) (*sql.Row, error
if err != nil {
return nil, err
}
return stmt.QueryRow(args...), nil
return stmt.QueryRowContext(db.c, args...), nil
}
// Other things

View File

@ -7,10 +7,12 @@ import (
"github.com/lopezator/migrator"
)
func migrateDb(db *sql.DB) error {
func migrateDb(db *sql.DB, logging bool) error {
m, err := migrator.New(
migrator.WithLogger(migrator.LoggerFunc(func(s string, i ...interface{}) {
log.Printf(s, i)
if logging {
log.Printf(s, i)
}
})),
migrator.Migrations(
&migrator.Migration{

61
database_test.go Normal file
View File

@ -0,0 +1,61 @@
package main
import (
"testing"
)
func Test_database(t *testing.T) {
t.Run("Basic Database Test", func(t *testing.T) {
app := &goBlog{}
db, err := app.openDatabase(":memory:", false)
if err != nil {
t.Errorf("Error: %v", err)
}
_, err = db.execMulti("create table test(test text);")
if err != nil {
t.Errorf("Error: %v", err)
}
_, err = db.exec("insert into test (test) values ('Test')")
if err != nil {
t.Errorf("Error: %v", err)
}
row, err := db.queryRow("select count(test) from test")
if err != nil {
t.Errorf("Error: %v", err)
}
var test1 int
err = row.Scan(&test1)
if err != nil {
t.Errorf("Error: %v", err)
}
if test1 != 1 {
t.Error("Wrong result")
}
rows, err := db.query("select count(test), test from test")
if err != nil {
t.Errorf("Error: %v", err)
}
var test2 int
var testStr string
if !rows.Next() {
t.Error("No result row")
}
err = rows.Scan(&test2, &testStr)
if err != nil {
t.Errorf("Error: %v", err)
}
if test2 != 1 || testStr != "Test" {
t.Error("Wrong result")
}
err = db.close()
if err != nil {
t.Errorf("Error: %v", err)
}
})
}

View File

@ -2,6 +2,7 @@ package main
import (
"bytes"
"strings"
marktag "git.jlel.se/jlelse/goldmark-mark"
"github.com/PuerkitoBio/goquery"
@ -62,7 +63,7 @@ func (a *goBlog) renderText(s string) string {
if err != nil {
return ""
}
return d.Text()
return strings.TrimSpace(d.Text())
}
// Extensions etc...
@ -104,7 +105,7 @@ func (c *customRenderer) renderLink(w util.BufWriter, _ []byte, node ast.Node, e
_, _ = w.Write(util.EscapeHTML(newDestination))
_, _ = w.WriteRune('"')
// Open external links (links that start with "http") in new tab
if bytes.HasPrefix(n.Destination, []byte("http")) {
if isAbsoluteURL(string(n.Destination)) {
_, _ = w.WriteString(` target="_blank" rel="noopener"`)
}
// Title

68
markdown_test.go Normal file
View File

@ -0,0 +1,68 @@
package main
import (
"strings"
"testing"
)
func Test_markdown(t *testing.T) {
t.Run("Basic Markdown tests", func(t *testing.T) {
app := &goBlog{
cfg: &config{
Server: &configServer{
PublicAddress: "https://example.com",
},
},
}
app.initMarkdown()
// Relative / absolute links
rendered, err := app.renderMarkdown("[Relative](/relative)", false)
if err != nil {
t.Errorf("Error: %v", err)
}
if !strings.Contains(string(rendered), `href="/relative"`) {
t.Errorf("Wrong result, got %v", string(rendered))
}
rendered, err = app.renderMarkdown("[Relative](/relative)", true)
if err != nil {
t.Errorf("Error: %v", err)
}
if !strings.Contains(string(rendered), `href="https://example.com/relative"`) {
t.Errorf("Wrong result, got %v", string(rendered))
}
if strings.Contains(string(rendered), `target="_blank"`) {
t.Errorf("Wrong result, got %v", string(rendered))
}
// External links
rendered, err = app.renderMarkdown("[External](https://example.com)", true)
if err != nil {
t.Errorf("Error: %v", err)
}
if !strings.Contains(string(rendered), `target="_blank"`) {
t.Errorf("Wrong result, got %v", string(rendered))
}
// Link title
rendered, err = app.renderMarkdown(`[With title](https://example.com "Test-Title")`, true)
if err != nil {
t.Errorf("Error: %v", err)
}
if !strings.Contains(string(rendered), `title="Test-Title"`) {
t.Errorf("Wrong result, got %v", string(rendered))
}
// Text
renderedText := app.renderText("**This** *is* [text](/)")
if renderedText != "This is text" {
t.Errorf("Wrong result, got \"%v\"", renderedText)
}
})
}

View File

@ -12,7 +12,7 @@ func (db *database) cachePersistently(key string, data []byte) error {
}
func (db *database) retrievePersistentCache(key string) (data []byte, err error) {
d, err, _ := db.persistentCacheGroup.Do(key, func() (interface{}, error) {
d, err, _ := db.pc.Do(key, func() (interface{}, error) {
if row, err := db.queryRow("select data from persistent_cache where key = @key", sql.Named("key", key)); err == sql.ErrNoRows {
return nil, nil
} else if err != nil {

View File

@ -59,15 +59,16 @@ func isAllowedHost(r *http.Request, hosts ...string) bool {
}
func isAbsoluteURL(s string) bool {
if !strings.HasPrefix(s, "https://") && !strings.HasPrefix(s, "http://") {
return false
}
if _, err := url.Parse(s); err != nil {
if u, err := url.Parse(s); err != nil || !u.IsAbs() {
return false
}
return true
}
func allLinksFromHTMLString(html, baseURL string) ([]string, error) {
return allLinksFromHTML(strings.NewReader(html), baseURL)
}
func allLinksFromHTML(r io.Reader, baseURL string) ([]string, error) {
doc, err := goquery.NewDocumentFromReader(r)
if err != nil {

81
utils_test.go Normal file
View File

@ -0,0 +1,81 @@
package main
import (
"net/http"
"net/http/httptest"
"reflect"
"testing"
)
func Test_urlize(t *testing.T) {
if res := urlize("äbc ef"); res != "bc-ef" {
t.Errorf("Wrong result, got: %v", res)
}
}
func Test_sortedStrings(t *testing.T) {
input := []string{"a", "c", "b"}
if res := sortedStrings(input); !reflect.DeepEqual(res, []string{"a", "b", "c"}) {
t.Errorf("Wrong result, got: %v", res)
}
}
func Test_generateRandomString(t *testing.T) {
if l := len(generateRandomString(30)); l != 30 {
t.Errorf("Wrong length: %v", l)
}
}
func Test_isAllowedHost(t *testing.T) {
req1 := httptest.NewRequest(http.MethodGet, "https://example.com", nil)
req2 := httptest.NewRequest(http.MethodGet, "https://example.com:443", nil)
req3 := httptest.NewRequest(http.MethodGet, "http://example.com:80", nil)
if isAllowedHost(req1, "example.com") != true {
t.Error("Wrong result")
}
if isAllowedHost(req1, "example.net") != false {
t.Error("Wrong result")
}
if isAllowedHost(req2, "example.com") != true {
t.Error("Wrong result")
}
if isAllowedHost(req3, "example.com") != true {
t.Error("Wrong result")
}
}
func Test_isAbsoluteURL(t *testing.T) {
if isAbsoluteURL("http://example.com") != true {
t.Error("Wrong result")
}
if isAbsoluteURL("https://example.com") != true {
t.Error("Wrong result")
}
if isAbsoluteURL("/test") != false {
t.Error("Wrong result")
}
}
func Test_wordCount(t *testing.T) {
if wordCount("abc def abc") != 3 {
t.Error("Wrong result")
}
}
func Test_allLinksFromHTMLString(t *testing.T) {
baseUrl := "https://example.net/post/abc"
html := `<a href="relative1">Test</a><a href="relative1">Test</a><a href="/relative2">Test</a><a href="https://example.com">Test</a>`
expected := []string{"https://example.net/post/relative1", "https://example.net/relative2", "https://example.com"}
if result, err := allLinksFromHTMLString(html, baseUrl); err != nil {
t.Errorf("Got error: %v", err)
} else if !reflect.DeepEqual(result, expected) {
t.Errorf("Wrong result, got: %v", result)
}
}