From ca99b726a8d9eec34945039964e46da6c8d40743 Mon Sep 17 00:00:00 2001 From: Jan-Lukas Else Date: Mon, 14 Jun 2021 16:29:22 +0200 Subject: [PATCH] Small fixes and higher test coverage --- activityPub.go | 2 +- check.go | 2 +- database.go | 37 ++++++++++---------- databaseMigrations.go | 6 ++-- database_test.go | 61 ++++++++++++++++++++++++++++++++ markdown.go | 5 +-- markdown_test.go | 68 ++++++++++++++++++++++++++++++++++++ persistentCache.go | 2 +- utils.go | 9 ++--- utils_test.go | 81 +++++++++++++++++++++++++++++++++++++++++++ 10 files changed, 244 insertions(+), 29 deletions(-) create mode 100644 database_test.go create mode 100644 markdown_test.go create mode 100644 utils_test.go diff --git a/activityPub.go b/activityPub.go index a475a95..10dee17 100644 --- a/activityPub.go +++ b/activityPub.go @@ -174,7 +174,7 @@ func (a *goBlog) apHandleInbox(w http.ResponseWriter, r *http.Request) { _ = a.createWebmention(id, inReplyTo) } else if content, hasContent := object["content"].(string); hasContent && hasID && len(id) > 0 { // May be a mention; find links to blog and save them as webmentions - if links, err := allLinksFromHTML(strings.NewReader(content), id); err == nil { + if links, err := allLinksFromHTMLString(content, id); err == nil { for _, link := range links { if strings.Contains(link, blogIri) { _ = a.createWebmention(id, link) diff --git a/check.go b/check.go index 919b306..14f042e 100644 --- a/check.go +++ b/check.go @@ -90,7 +90,7 @@ func (a *goBlog) getExternalLinks(posts []*post, linkChan chan<- stringPair) err wg.Add(1) go func(p *post) { defer wg.Done() - links, _ := allLinksFromHTML(strings.NewReader(string(a.absoluteHTML(p))), a.fullPostURL(p)) + links, _ := allLinksFromHTMLString(string(a.absoluteHTML(p)), a.fullPostURL(p)) for _, link := range links { linkChan <- stringPair{a.fullPostURL(p), link} } diff --git a/database.go b/database.go index 91b451d..5cac4fb 100644 --- a/database.go +++ b/database.go @@ -1,6 +1,7 @@ package main import ( + "context" "database/sql" "errors" "log" @@ -12,21 +13,22 @@ import ( ) type database struct { - db *sql.DB - stmts map[string]*sql.Stmt - g singleflight.Group - persistentCacheGroup singleflight.Group + db *sql.DB + c context.Context + cf context.CancelFunc + stmts map[string]*sql.Stmt + g singleflight.Group + pc singleflight.Group } func (a *goBlog) initDatabase() (err error) { // Setup db - db, err := a.openDatabase(a.cfg.Db.File) + db, err := a.openDatabase(a.cfg.Db.File, true) if err != nil { return err } // Create appDB a.db = db - db.vacuum() addShutdownFunc(func() { _ = db.close() log.Println("Closed database") @@ -40,7 +42,7 @@ func (a *goBlog) initDatabase() (err error) { return nil } -func (a *goBlog) openDatabase(file string) (*database, error) { +func (a *goBlog) openDatabase(file string, logging bool) (*database, error) { // Register driver dbDriverName := generateRandomString(15) sql.Register("goblog_db_"+dbDriverName, &sqlite.SQLiteDriver{ @@ -85,13 +87,16 @@ func (a *goBlog) openDatabase(file string) (*database, error) { return nil, errors.New("sqlite not compiled with FTS5") } // Migrate DB - err = migrateDb(db) + err = migrateDb(db, logging) if err != nil { return nil, err } + c, cf := context.WithCancel(context.Background()) return &database{ db: db, stmts: map[string]*sql.Stmt{}, + c: c, + cf: cf, }, nil } @@ -109,21 +114,17 @@ func (db *database) dump(file string) { } func (db *database) close() error { - db.vacuum() + db.cf() return db.db.Close() } -func (db *database) vacuum() { - _, _ = db.exec("VACUUM") -} - func (db *database) prepare(query string) (*sql.Stmt, error) { stmt, err, _ := db.g.Do(query, func() (interface{}, error) { stmt, ok := db.stmts[query] if ok && stmt != nil { return stmt, nil } - stmt, err := db.db.Prepare(query) + stmt, err := db.db.PrepareContext(db.c, query) if err != nil { return nil, err } @@ -141,12 +142,12 @@ func (db *database) exec(query string, args ...interface{}) (sql.Result, error) if err != nil { return nil, err } - return stmt.Exec(args...) + return stmt.ExecContext(db.c, args...) } func (db *database) execMulti(query string, args ...interface{}) (sql.Result, error) { // Can't prepare the statement - return db.db.Exec(query, args...) + return db.db.ExecContext(db.c, query, args...) } func (db *database) query(query string, args ...interface{}) (*sql.Rows, error) { @@ -154,7 +155,7 @@ func (db *database) query(query string, args ...interface{}) (*sql.Rows, error) if err != nil { return nil, err } - return stmt.Query(args...) + return stmt.QueryContext(db.c, args...) } func (db *database) queryRow(query string, args ...interface{}) (*sql.Row, error) { @@ -162,7 +163,7 @@ func (db *database) queryRow(query string, args ...interface{}) (*sql.Row, error if err != nil { return nil, err } - return stmt.QueryRow(args...), nil + return stmt.QueryRowContext(db.c, args...), nil } // Other things diff --git a/databaseMigrations.go b/databaseMigrations.go index aa06be3..8b6e2c1 100644 --- a/databaseMigrations.go +++ b/databaseMigrations.go @@ -7,10 +7,12 @@ import ( "github.com/lopezator/migrator" ) -func migrateDb(db *sql.DB) error { +func migrateDb(db *sql.DB, logging bool) error { m, err := migrator.New( migrator.WithLogger(migrator.LoggerFunc(func(s string, i ...interface{}) { - log.Printf(s, i) + if logging { + log.Printf(s, i) + } })), migrator.Migrations( &migrator.Migration{ diff --git a/database_test.go b/database_test.go new file mode 100644 index 0000000..c431848 --- /dev/null +++ b/database_test.go @@ -0,0 +1,61 @@ +package main + +import ( + "testing" +) + +func Test_database(t *testing.T) { + t.Run("Basic Database Test", func(t *testing.T) { + app := &goBlog{} + + db, err := app.openDatabase(":memory:", false) + if err != nil { + t.Errorf("Error: %v", err) + } + + _, err = db.execMulti("create table test(test text);") + if err != nil { + t.Errorf("Error: %v", err) + } + + _, err = db.exec("insert into test (test) values ('Test')") + if err != nil { + t.Errorf("Error: %v", err) + } + + row, err := db.queryRow("select count(test) from test") + if err != nil { + t.Errorf("Error: %v", err) + } + var test1 int + err = row.Scan(&test1) + if err != nil { + t.Errorf("Error: %v", err) + } + if test1 != 1 { + t.Error("Wrong result") + } + + rows, err := db.query("select count(test), test from test") + if err != nil { + t.Errorf("Error: %v", err) + } + var test2 int + var testStr string + if !rows.Next() { + t.Error("No result row") + } + err = rows.Scan(&test2, &testStr) + if err != nil { + t.Errorf("Error: %v", err) + } + if test2 != 1 || testStr != "Test" { + t.Error("Wrong result") + } + + err = db.close() + if err != nil { + t.Errorf("Error: %v", err) + } + }) +} diff --git a/markdown.go b/markdown.go index 9ad8385..beb0669 100644 --- a/markdown.go +++ b/markdown.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "strings" marktag "git.jlel.se/jlelse/goldmark-mark" "github.com/PuerkitoBio/goquery" @@ -62,7 +63,7 @@ func (a *goBlog) renderText(s string) string { if err != nil { return "" } - return d.Text() + return strings.TrimSpace(d.Text()) } // Extensions etc... @@ -104,7 +105,7 @@ func (c *customRenderer) renderLink(w util.BufWriter, _ []byte, node ast.Node, e _, _ = w.Write(util.EscapeHTML(newDestination)) _, _ = w.WriteRune('"') // Open external links (links that start with "http") in new tab - if bytes.HasPrefix(n.Destination, []byte("http")) { + if isAbsoluteURL(string(n.Destination)) { _, _ = w.WriteString(` target="_blank" rel="noopener"`) } // Title diff --git a/markdown_test.go b/markdown_test.go new file mode 100644 index 0000000..b194ab8 --- /dev/null +++ b/markdown_test.go @@ -0,0 +1,68 @@ +package main + +import ( + "strings" + "testing" +) + +func Test_markdown(t *testing.T) { + t.Run("Basic Markdown tests", func(t *testing.T) { + app := &goBlog{ + cfg: &config{ + Server: &configServer{ + PublicAddress: "https://example.com", + }, + }, + } + + app.initMarkdown() + + // Relative / absolute links + + rendered, err := app.renderMarkdown("[Relative](/relative)", false) + if err != nil { + t.Errorf("Error: %v", err) + } + if !strings.Contains(string(rendered), `href="/relative"`) { + t.Errorf("Wrong result, got %v", string(rendered)) + } + + rendered, err = app.renderMarkdown("[Relative](/relative)", true) + if err != nil { + t.Errorf("Error: %v", err) + } + if !strings.Contains(string(rendered), `href="https://example.com/relative"`) { + t.Errorf("Wrong result, got %v", string(rendered)) + } + if strings.Contains(string(rendered), `target="_blank"`) { + t.Errorf("Wrong result, got %v", string(rendered)) + } + + // External links + + rendered, err = app.renderMarkdown("[External](https://example.com)", true) + if err != nil { + t.Errorf("Error: %v", err) + } + if !strings.Contains(string(rendered), `target="_blank"`) { + t.Errorf("Wrong result, got %v", string(rendered)) + } + + // Link title + + rendered, err = app.renderMarkdown(`[With title](https://example.com "Test-Title")`, true) + if err != nil { + t.Errorf("Error: %v", err) + } + if !strings.Contains(string(rendered), `title="Test-Title"`) { + t.Errorf("Wrong result, got %v", string(rendered)) + } + + // Text + + renderedText := app.renderText("**This** *is* [text](/)") + if renderedText != "This is text" { + t.Errorf("Wrong result, got \"%v\"", renderedText) + } + }) +} diff --git a/persistentCache.go b/persistentCache.go index 07bf090..c10e643 100644 --- a/persistentCache.go +++ b/persistentCache.go @@ -12,7 +12,7 @@ func (db *database) cachePersistently(key string, data []byte) error { } func (db *database) retrievePersistentCache(key string) (data []byte, err error) { - d, err, _ := db.persistentCacheGroup.Do(key, func() (interface{}, error) { + d, err, _ := db.pc.Do(key, func() (interface{}, error) { if row, err := db.queryRow("select data from persistent_cache where key = @key", sql.Named("key", key)); err == sql.ErrNoRows { return nil, nil } else if err != nil { diff --git a/utils.go b/utils.go index f3ca3f8..77f78da 100644 --- a/utils.go +++ b/utils.go @@ -59,15 +59,16 @@ func isAllowedHost(r *http.Request, hosts ...string) bool { } func isAbsoluteURL(s string) bool { - if !strings.HasPrefix(s, "https://") && !strings.HasPrefix(s, "http://") { - return false - } - if _, err := url.Parse(s); err != nil { + if u, err := url.Parse(s); err != nil || !u.IsAbs() { return false } return true } +func allLinksFromHTMLString(html, baseURL string) ([]string, error) { + return allLinksFromHTML(strings.NewReader(html), baseURL) +} + func allLinksFromHTML(r io.Reader, baseURL string) ([]string, error) { doc, err := goquery.NewDocumentFromReader(r) if err != nil { diff --git a/utils_test.go b/utils_test.go new file mode 100644 index 0000000..eb162e1 --- /dev/null +++ b/utils_test.go @@ -0,0 +1,81 @@ +package main + +import ( + "net/http" + "net/http/httptest" + "reflect" + "testing" +) + +func Test_urlize(t *testing.T) { + if res := urlize("äbc ef"); res != "bc-ef" { + t.Errorf("Wrong result, got: %v", res) + } +} + +func Test_sortedStrings(t *testing.T) { + input := []string{"a", "c", "b"} + if res := sortedStrings(input); !reflect.DeepEqual(res, []string{"a", "b", "c"}) { + t.Errorf("Wrong result, got: %v", res) + } +} + +func Test_generateRandomString(t *testing.T) { + if l := len(generateRandomString(30)); l != 30 { + t.Errorf("Wrong length: %v", l) + } +} + +func Test_isAllowedHost(t *testing.T) { + req1 := httptest.NewRequest(http.MethodGet, "https://example.com", nil) + req2 := httptest.NewRequest(http.MethodGet, "https://example.com:443", nil) + req3 := httptest.NewRequest(http.MethodGet, "http://example.com:80", nil) + + if isAllowedHost(req1, "example.com") != true { + t.Error("Wrong result") + } + + if isAllowedHost(req1, "example.net") != false { + t.Error("Wrong result") + } + + if isAllowedHost(req2, "example.com") != true { + t.Error("Wrong result") + } + + if isAllowedHost(req3, "example.com") != true { + t.Error("Wrong result") + } +} + +func Test_isAbsoluteURL(t *testing.T) { + if isAbsoluteURL("http://example.com") != true { + t.Error("Wrong result") + } + + if isAbsoluteURL("https://example.com") != true { + t.Error("Wrong result") + } + + if isAbsoluteURL("/test") != false { + t.Error("Wrong result") + } +} + +func Test_wordCount(t *testing.T) { + if wordCount("abc def abc") != 3 { + t.Error("Wrong result") + } +} + +func Test_allLinksFromHTMLString(t *testing.T) { + baseUrl := "https://example.net/post/abc" + html := `TestTestTestTest` + expected := []string{"https://example.net/post/relative1", "https://example.net/relative2", "https://example.com"} + + if result, err := allLinksFromHTMLString(html, baseUrl); err != nil { + t.Errorf("Got error: %v", err) + } else if !reflect.DeepEqual(result, expected) { + t.Errorf("Wrong result, got: %v", result) + } +}