diff --git a/activityPub.go b/activityPub.go
index a475a95..10dee17 100644
--- a/activityPub.go
+++ b/activityPub.go
@@ -174,7 +174,7 @@ func (a *goBlog) apHandleInbox(w http.ResponseWriter, r *http.Request) {
_ = a.createWebmention(id, inReplyTo)
} else if content, hasContent := object["content"].(string); hasContent && hasID && len(id) > 0 {
// May be a mention; find links to blog and save them as webmentions
- if links, err := allLinksFromHTML(strings.NewReader(content), id); err == nil {
+ if links, err := allLinksFromHTMLString(content, id); err == nil {
for _, link := range links {
if strings.Contains(link, blogIri) {
_ = a.createWebmention(id, link)
diff --git a/check.go b/check.go
index 919b306..14f042e 100644
--- a/check.go
+++ b/check.go
@@ -90,7 +90,7 @@ func (a *goBlog) getExternalLinks(posts []*post, linkChan chan<- stringPair) err
wg.Add(1)
go func(p *post) {
defer wg.Done()
- links, _ := allLinksFromHTML(strings.NewReader(string(a.absoluteHTML(p))), a.fullPostURL(p))
+ links, _ := allLinksFromHTMLString(string(a.absoluteHTML(p)), a.fullPostURL(p))
for _, link := range links {
linkChan <- stringPair{a.fullPostURL(p), link}
}
diff --git a/database.go b/database.go
index 91b451d..5cac4fb 100644
--- a/database.go
+++ b/database.go
@@ -1,6 +1,7 @@
package main
import (
+ "context"
"database/sql"
"errors"
"log"
@@ -12,21 +13,22 @@ import (
)
type database struct {
- db *sql.DB
- stmts map[string]*sql.Stmt
- g singleflight.Group
- persistentCacheGroup singleflight.Group
+ db *sql.DB
+ c context.Context
+ cf context.CancelFunc
+ stmts map[string]*sql.Stmt
+ g singleflight.Group
+ pc singleflight.Group
}
func (a *goBlog) initDatabase() (err error) {
// Setup db
- db, err := a.openDatabase(a.cfg.Db.File)
+ db, err := a.openDatabase(a.cfg.Db.File, true)
if err != nil {
return err
}
// Create appDB
a.db = db
- db.vacuum()
addShutdownFunc(func() {
_ = db.close()
log.Println("Closed database")
@@ -40,7 +42,7 @@ func (a *goBlog) initDatabase() (err error) {
return nil
}
-func (a *goBlog) openDatabase(file string) (*database, error) {
+func (a *goBlog) openDatabase(file string, logging bool) (*database, error) {
// Register driver
dbDriverName := generateRandomString(15)
sql.Register("goblog_db_"+dbDriverName, &sqlite.SQLiteDriver{
@@ -85,13 +87,16 @@ func (a *goBlog) openDatabase(file string) (*database, error) {
return nil, errors.New("sqlite not compiled with FTS5")
}
// Migrate DB
- err = migrateDb(db)
+ err = migrateDb(db, logging)
if err != nil {
return nil, err
}
+ c, cf := context.WithCancel(context.Background())
return &database{
db: db,
stmts: map[string]*sql.Stmt{},
+ c: c,
+ cf: cf,
}, nil
}
@@ -109,21 +114,17 @@ func (db *database) dump(file string) {
}
func (db *database) close() error {
- db.vacuum()
+ db.cf()
return db.db.Close()
}
-func (db *database) vacuum() {
- _, _ = db.exec("VACUUM")
-}
-
func (db *database) prepare(query string) (*sql.Stmt, error) {
stmt, err, _ := db.g.Do(query, func() (interface{}, error) {
stmt, ok := db.stmts[query]
if ok && stmt != nil {
return stmt, nil
}
- stmt, err := db.db.Prepare(query)
+ stmt, err := db.db.PrepareContext(db.c, query)
if err != nil {
return nil, err
}
@@ -141,12 +142,12 @@ func (db *database) exec(query string, args ...interface{}) (sql.Result, error)
if err != nil {
return nil, err
}
- return stmt.Exec(args...)
+ return stmt.ExecContext(db.c, args...)
}
func (db *database) execMulti(query string, args ...interface{}) (sql.Result, error) {
// Can't prepare the statement
- return db.db.Exec(query, args...)
+ return db.db.ExecContext(db.c, query, args...)
}
func (db *database) query(query string, args ...interface{}) (*sql.Rows, error) {
@@ -154,7 +155,7 @@ func (db *database) query(query string, args ...interface{}) (*sql.Rows, error)
if err != nil {
return nil, err
}
- return stmt.Query(args...)
+ return stmt.QueryContext(db.c, args...)
}
func (db *database) queryRow(query string, args ...interface{}) (*sql.Row, error) {
@@ -162,7 +163,7 @@ func (db *database) queryRow(query string, args ...interface{}) (*sql.Row, error
if err != nil {
return nil, err
}
- return stmt.QueryRow(args...), nil
+ return stmt.QueryRowContext(db.c, args...), nil
}
// Other things
diff --git a/databaseMigrations.go b/databaseMigrations.go
index aa06be3..8b6e2c1 100644
--- a/databaseMigrations.go
+++ b/databaseMigrations.go
@@ -7,10 +7,12 @@ import (
"github.com/lopezator/migrator"
)
-func migrateDb(db *sql.DB) error {
+func migrateDb(db *sql.DB, logging bool) error {
m, err := migrator.New(
migrator.WithLogger(migrator.LoggerFunc(func(s string, i ...interface{}) {
- log.Printf(s, i)
+ if logging {
+ log.Printf(s, i)
+ }
})),
migrator.Migrations(
&migrator.Migration{
diff --git a/database_test.go b/database_test.go
new file mode 100644
index 0000000..c431848
--- /dev/null
+++ b/database_test.go
@@ -0,0 +1,61 @@
+package main
+
+import (
+ "testing"
+)
+
+func Test_database(t *testing.T) {
+ t.Run("Basic Database Test", func(t *testing.T) {
+ app := &goBlog{}
+
+ db, err := app.openDatabase(":memory:", false)
+ if err != nil {
+ t.Errorf("Error: %v", err)
+ }
+
+ _, err = db.execMulti("create table test(test text);")
+ if err != nil {
+ t.Errorf("Error: %v", err)
+ }
+
+ _, err = db.exec("insert into test (test) values ('Test')")
+ if err != nil {
+ t.Errorf("Error: %v", err)
+ }
+
+ row, err := db.queryRow("select count(test) from test")
+ if err != nil {
+ t.Errorf("Error: %v", err)
+ }
+ var test1 int
+ err = row.Scan(&test1)
+ if err != nil {
+ t.Errorf("Error: %v", err)
+ }
+ if test1 != 1 {
+ t.Error("Wrong result")
+ }
+
+ rows, err := db.query("select count(test), test from test")
+ if err != nil {
+ t.Errorf("Error: %v", err)
+ }
+ var test2 int
+ var testStr string
+ if !rows.Next() {
+ t.Error("No result row")
+ }
+ err = rows.Scan(&test2, &testStr)
+ if err != nil {
+ t.Errorf("Error: %v", err)
+ }
+ if test2 != 1 || testStr != "Test" {
+ t.Error("Wrong result")
+ }
+
+ err = db.close()
+ if err != nil {
+ t.Errorf("Error: %v", err)
+ }
+ })
+}
diff --git a/markdown.go b/markdown.go
index 9ad8385..beb0669 100644
--- a/markdown.go
+++ b/markdown.go
@@ -2,6 +2,7 @@ package main
import (
"bytes"
+ "strings"
marktag "git.jlel.se/jlelse/goldmark-mark"
"github.com/PuerkitoBio/goquery"
@@ -62,7 +63,7 @@ func (a *goBlog) renderText(s string) string {
if err != nil {
return ""
}
- return d.Text()
+ return strings.TrimSpace(d.Text())
}
// Extensions etc...
@@ -104,7 +105,7 @@ func (c *customRenderer) renderLink(w util.BufWriter, _ []byte, node ast.Node, e
_, _ = w.Write(util.EscapeHTML(newDestination))
_, _ = w.WriteRune('"')
// Open external links (links that start with "http") in new tab
- if bytes.HasPrefix(n.Destination, []byte("http")) {
+ if isAbsoluteURL(string(n.Destination)) {
_, _ = w.WriteString(` target="_blank" rel="noopener"`)
}
// Title
diff --git a/markdown_test.go b/markdown_test.go
new file mode 100644
index 0000000..b194ab8
--- /dev/null
+++ b/markdown_test.go
@@ -0,0 +1,68 @@
+package main
+
+import (
+ "strings"
+ "testing"
+)
+
+func Test_markdown(t *testing.T) {
+ t.Run("Basic Markdown tests", func(t *testing.T) {
+ app := &goBlog{
+ cfg: &config{
+ Server: &configServer{
+ PublicAddress: "https://example.com",
+ },
+ },
+ }
+
+ app.initMarkdown()
+
+ // Relative / absolute links
+
+ rendered, err := app.renderMarkdown("[Relative](/relative)", false)
+ if err != nil {
+ t.Errorf("Error: %v", err)
+ }
+ if !strings.Contains(string(rendered), `href="/relative"`) {
+ t.Errorf("Wrong result, got %v", string(rendered))
+ }
+
+ rendered, err = app.renderMarkdown("[Relative](/relative)", true)
+ if err != nil {
+ t.Errorf("Error: %v", err)
+ }
+ if !strings.Contains(string(rendered), `href="https://example.com/relative"`) {
+ t.Errorf("Wrong result, got %v", string(rendered))
+ }
+ if strings.Contains(string(rendered), `target="_blank"`) {
+ t.Errorf("Wrong result, got %v", string(rendered))
+ }
+
+ // External links
+
+ rendered, err = app.renderMarkdown("[External](https://example.com)", true)
+ if err != nil {
+ t.Errorf("Error: %v", err)
+ }
+ if !strings.Contains(string(rendered), `target="_blank"`) {
+ t.Errorf("Wrong result, got %v", string(rendered))
+ }
+
+ // Link title
+
+ rendered, err = app.renderMarkdown(`[With title](https://example.com "Test-Title")`, true)
+ if err != nil {
+ t.Errorf("Error: %v", err)
+ }
+ if !strings.Contains(string(rendered), `title="Test-Title"`) {
+ t.Errorf("Wrong result, got %v", string(rendered))
+ }
+
+ // Text
+
+ renderedText := app.renderText("**This** *is* [text](/)")
+ if renderedText != "This is text" {
+ t.Errorf("Wrong result, got \"%v\"", renderedText)
+ }
+ })
+}
diff --git a/persistentCache.go b/persistentCache.go
index 07bf090..c10e643 100644
--- a/persistentCache.go
+++ b/persistentCache.go
@@ -12,7 +12,7 @@ func (db *database) cachePersistently(key string, data []byte) error {
}
func (db *database) retrievePersistentCache(key string) (data []byte, err error) {
- d, err, _ := db.persistentCacheGroup.Do(key, func() (interface{}, error) {
+ d, err, _ := db.pc.Do(key, func() (interface{}, error) {
if row, err := db.queryRow("select data from persistent_cache where key = @key", sql.Named("key", key)); err == sql.ErrNoRows {
return nil, nil
} else if err != nil {
diff --git a/utils.go b/utils.go
index f3ca3f8..77f78da 100644
--- a/utils.go
+++ b/utils.go
@@ -59,15 +59,16 @@ func isAllowedHost(r *http.Request, hosts ...string) bool {
}
func isAbsoluteURL(s string) bool {
- if !strings.HasPrefix(s, "https://") && !strings.HasPrefix(s, "http://") {
- return false
- }
- if _, err := url.Parse(s); err != nil {
+ if u, err := url.Parse(s); err != nil || !u.IsAbs() {
return false
}
return true
}
+func allLinksFromHTMLString(html, baseURL string) ([]string, error) {
+ return allLinksFromHTML(strings.NewReader(html), baseURL)
+}
+
func allLinksFromHTML(r io.Reader, baseURL string) ([]string, error) {
doc, err := goquery.NewDocumentFromReader(r)
if err != nil {
diff --git a/utils_test.go b/utils_test.go
new file mode 100644
index 0000000..eb162e1
--- /dev/null
+++ b/utils_test.go
@@ -0,0 +1,81 @@
+package main
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "reflect"
+ "testing"
+)
+
+func Test_urlize(t *testing.T) {
+ if res := urlize("äbc ef"); res != "bc-ef" {
+ t.Errorf("Wrong result, got: %v", res)
+ }
+}
+
+func Test_sortedStrings(t *testing.T) {
+ input := []string{"a", "c", "b"}
+ if res := sortedStrings(input); !reflect.DeepEqual(res, []string{"a", "b", "c"}) {
+ t.Errorf("Wrong result, got: %v", res)
+ }
+}
+
+func Test_generateRandomString(t *testing.T) {
+ if l := len(generateRandomString(30)); l != 30 {
+ t.Errorf("Wrong length: %v", l)
+ }
+}
+
+func Test_isAllowedHost(t *testing.T) {
+ req1 := httptest.NewRequest(http.MethodGet, "https://example.com", nil)
+ req2 := httptest.NewRequest(http.MethodGet, "https://example.com:443", nil)
+ req3 := httptest.NewRequest(http.MethodGet, "http://example.com:80", nil)
+
+ if isAllowedHost(req1, "example.com") != true {
+ t.Error("Wrong result")
+ }
+
+ if isAllowedHost(req1, "example.net") != false {
+ t.Error("Wrong result")
+ }
+
+ if isAllowedHost(req2, "example.com") != true {
+ t.Error("Wrong result")
+ }
+
+ if isAllowedHost(req3, "example.com") != true {
+ t.Error("Wrong result")
+ }
+}
+
+func Test_isAbsoluteURL(t *testing.T) {
+ if isAbsoluteURL("http://example.com") != true {
+ t.Error("Wrong result")
+ }
+
+ if isAbsoluteURL("https://example.com") != true {
+ t.Error("Wrong result")
+ }
+
+ if isAbsoluteURL("/test") != false {
+ t.Error("Wrong result")
+ }
+}
+
+func Test_wordCount(t *testing.T) {
+ if wordCount("abc def abc") != 3 {
+ t.Error("Wrong result")
+ }
+}
+
+func Test_allLinksFromHTMLString(t *testing.T) {
+ baseUrl := "https://example.net/post/abc"
+ html := `TestTestTestTest`
+ expected := []string{"https://example.net/post/relative1", "https://example.net/relative2", "https://example.com"}
+
+ if result, err := allLinksFromHTMLString(html, baseUrl); err != nil {
+ t.Errorf("Got error: %v", err)
+ } else if !reflect.DeepEqual(result, expected) {
+ t.Errorf("Wrong result, got: %v", result)
+ }
+}