mirror of https://github.com/jlelse/GoBlog
Small fixes and higher test coverage
This commit is contained in:
parent
b83c09d5ac
commit
ca99b726a8
|
@ -174,7 +174,7 @@ func (a *goBlog) apHandleInbox(w http.ResponseWriter, r *http.Request) {
|
||||||
_ = a.createWebmention(id, inReplyTo)
|
_ = a.createWebmention(id, inReplyTo)
|
||||||
} else if content, hasContent := object["content"].(string); hasContent && hasID && len(id) > 0 {
|
} else if content, hasContent := object["content"].(string); hasContent && hasID && len(id) > 0 {
|
||||||
// May be a mention; find links to blog and save them as webmentions
|
// May be a mention; find links to blog and save them as webmentions
|
||||||
if links, err := allLinksFromHTML(strings.NewReader(content), id); err == nil {
|
if links, err := allLinksFromHTMLString(content, id); err == nil {
|
||||||
for _, link := range links {
|
for _, link := range links {
|
||||||
if strings.Contains(link, blogIri) {
|
if strings.Contains(link, blogIri) {
|
||||||
_ = a.createWebmention(id, link)
|
_ = a.createWebmention(id, link)
|
||||||
|
|
2
check.go
2
check.go
|
@ -90,7 +90,7 @@ func (a *goBlog) getExternalLinks(posts []*post, linkChan chan<- stringPair) err
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(p *post) {
|
go func(p *post) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
links, _ := allLinksFromHTML(strings.NewReader(string(a.absoluteHTML(p))), a.fullPostURL(p))
|
links, _ := allLinksFromHTMLString(string(a.absoluteHTML(p)), a.fullPostURL(p))
|
||||||
for _, link := range links {
|
for _, link := range links {
|
||||||
linkChan <- stringPair{a.fullPostURL(p), link}
|
linkChan <- stringPair{a.fullPostURL(p), link}
|
||||||
}
|
}
|
||||||
|
|
37
database.go
37
database.go
|
@ -1,6 +1,7 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"log"
|
"log"
|
||||||
|
@ -12,21 +13,22 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type database struct {
|
type database struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
stmts map[string]*sql.Stmt
|
c context.Context
|
||||||
g singleflight.Group
|
cf context.CancelFunc
|
||||||
persistentCacheGroup singleflight.Group
|
stmts map[string]*sql.Stmt
|
||||||
|
g singleflight.Group
|
||||||
|
pc singleflight.Group
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *goBlog) initDatabase() (err error) {
|
func (a *goBlog) initDatabase() (err error) {
|
||||||
// Setup db
|
// Setup db
|
||||||
db, err := a.openDatabase(a.cfg.Db.File)
|
db, err := a.openDatabase(a.cfg.Db.File, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// Create appDB
|
// Create appDB
|
||||||
a.db = db
|
a.db = db
|
||||||
db.vacuum()
|
|
||||||
addShutdownFunc(func() {
|
addShutdownFunc(func() {
|
||||||
_ = db.close()
|
_ = db.close()
|
||||||
log.Println("Closed database")
|
log.Println("Closed database")
|
||||||
|
@ -40,7 +42,7 @@ func (a *goBlog) initDatabase() (err error) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *goBlog) openDatabase(file string) (*database, error) {
|
func (a *goBlog) openDatabase(file string, logging bool) (*database, error) {
|
||||||
// Register driver
|
// Register driver
|
||||||
dbDriverName := generateRandomString(15)
|
dbDriverName := generateRandomString(15)
|
||||||
sql.Register("goblog_db_"+dbDriverName, &sqlite.SQLiteDriver{
|
sql.Register("goblog_db_"+dbDriverName, &sqlite.SQLiteDriver{
|
||||||
|
@ -85,13 +87,16 @@ func (a *goBlog) openDatabase(file string) (*database, error) {
|
||||||
return nil, errors.New("sqlite not compiled with FTS5")
|
return nil, errors.New("sqlite not compiled with FTS5")
|
||||||
}
|
}
|
||||||
// Migrate DB
|
// Migrate DB
|
||||||
err = migrateDb(db)
|
err = migrateDb(db, logging)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
c, cf := context.WithCancel(context.Background())
|
||||||
return &database{
|
return &database{
|
||||||
db: db,
|
db: db,
|
||||||
stmts: map[string]*sql.Stmt{},
|
stmts: map[string]*sql.Stmt{},
|
||||||
|
c: c,
|
||||||
|
cf: cf,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -109,21 +114,17 @@ func (db *database) dump(file string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *database) close() error {
|
func (db *database) close() error {
|
||||||
db.vacuum()
|
db.cf()
|
||||||
return db.db.Close()
|
return db.db.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *database) vacuum() {
|
|
||||||
_, _ = db.exec("VACUUM")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *database) prepare(query string) (*sql.Stmt, error) {
|
func (db *database) prepare(query string) (*sql.Stmt, error) {
|
||||||
stmt, err, _ := db.g.Do(query, func() (interface{}, error) {
|
stmt, err, _ := db.g.Do(query, func() (interface{}, error) {
|
||||||
stmt, ok := db.stmts[query]
|
stmt, ok := db.stmts[query]
|
||||||
if ok && stmt != nil {
|
if ok && stmt != nil {
|
||||||
return stmt, nil
|
return stmt, nil
|
||||||
}
|
}
|
||||||
stmt, err := db.db.Prepare(query)
|
stmt, err := db.db.PrepareContext(db.c, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -141,12 +142,12 @@ func (db *database) exec(query string, args ...interface{}) (sql.Result, error)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return stmt.Exec(args...)
|
return stmt.ExecContext(db.c, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *database) execMulti(query string, args ...interface{}) (sql.Result, error) {
|
func (db *database) execMulti(query string, args ...interface{}) (sql.Result, error) {
|
||||||
// Can't prepare the statement
|
// Can't prepare the statement
|
||||||
return db.db.Exec(query, args...)
|
return db.db.ExecContext(db.c, query, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *database) query(query string, args ...interface{}) (*sql.Rows, error) {
|
func (db *database) query(query string, args ...interface{}) (*sql.Rows, error) {
|
||||||
|
@ -154,7 +155,7 @@ func (db *database) query(query string, args ...interface{}) (*sql.Rows, error)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return stmt.Query(args...)
|
return stmt.QueryContext(db.c, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *database) queryRow(query string, args ...interface{}) (*sql.Row, error) {
|
func (db *database) queryRow(query string, args ...interface{}) (*sql.Row, error) {
|
||||||
|
@ -162,7 +163,7 @@ func (db *database) queryRow(query string, args ...interface{}) (*sql.Row, error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return stmt.QueryRow(args...), nil
|
return stmt.QueryRowContext(db.c, args...), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Other things
|
// Other things
|
||||||
|
|
|
@ -7,10 +7,12 @@ import (
|
||||||
"github.com/lopezator/migrator"
|
"github.com/lopezator/migrator"
|
||||||
)
|
)
|
||||||
|
|
||||||
func migrateDb(db *sql.DB) error {
|
func migrateDb(db *sql.DB, logging bool) error {
|
||||||
m, err := migrator.New(
|
m, err := migrator.New(
|
||||||
migrator.WithLogger(migrator.LoggerFunc(func(s string, i ...interface{}) {
|
migrator.WithLogger(migrator.LoggerFunc(func(s string, i ...interface{}) {
|
||||||
log.Printf(s, i)
|
if logging {
|
||||||
|
log.Printf(s, i)
|
||||||
|
}
|
||||||
})),
|
})),
|
||||||
migrator.Migrations(
|
migrator.Migrations(
|
||||||
&migrator.Migration{
|
&migrator.Migration{
|
||||||
|
|
|
@ -0,0 +1,61 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_database(t *testing.T) {
|
||||||
|
t.Run("Basic Database Test", func(t *testing.T) {
|
||||||
|
app := &goBlog{}
|
||||||
|
|
||||||
|
db, err := app.openDatabase(":memory:", false)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.execMulti("create table test(test text);")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.exec("insert into test (test) values ('Test')")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
row, err := db.queryRow("select count(test) from test")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error: %v", err)
|
||||||
|
}
|
||||||
|
var test1 int
|
||||||
|
err = row.Scan(&test1)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error: %v", err)
|
||||||
|
}
|
||||||
|
if test1 != 1 {
|
||||||
|
t.Error("Wrong result")
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := db.query("select count(test), test from test")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error: %v", err)
|
||||||
|
}
|
||||||
|
var test2 int
|
||||||
|
var testStr string
|
||||||
|
if !rows.Next() {
|
||||||
|
t.Error("No result row")
|
||||||
|
}
|
||||||
|
err = rows.Scan(&test2, &testStr)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error: %v", err)
|
||||||
|
}
|
||||||
|
if test2 != 1 || testStr != "Test" {
|
||||||
|
t.Error("Wrong result")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = db.close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -2,6 +2,7 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"strings"
|
||||||
|
|
||||||
marktag "git.jlel.se/jlelse/goldmark-mark"
|
marktag "git.jlel.se/jlelse/goldmark-mark"
|
||||||
"github.com/PuerkitoBio/goquery"
|
"github.com/PuerkitoBio/goquery"
|
||||||
|
@ -62,7 +63,7 @@ func (a *goBlog) renderText(s string) string {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
return d.Text()
|
return strings.TrimSpace(d.Text())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extensions etc...
|
// Extensions etc...
|
||||||
|
@ -104,7 +105,7 @@ func (c *customRenderer) renderLink(w util.BufWriter, _ []byte, node ast.Node, e
|
||||||
_, _ = w.Write(util.EscapeHTML(newDestination))
|
_, _ = w.Write(util.EscapeHTML(newDestination))
|
||||||
_, _ = w.WriteRune('"')
|
_, _ = w.WriteRune('"')
|
||||||
// Open external links (links that start with "http") in new tab
|
// Open external links (links that start with "http") in new tab
|
||||||
if bytes.HasPrefix(n.Destination, []byte("http")) {
|
if isAbsoluteURL(string(n.Destination)) {
|
||||||
_, _ = w.WriteString(` target="_blank" rel="noopener"`)
|
_, _ = w.WriteString(` target="_blank" rel="noopener"`)
|
||||||
}
|
}
|
||||||
// Title
|
// Title
|
||||||
|
|
|
@ -0,0 +1,68 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_markdown(t *testing.T) {
|
||||||
|
t.Run("Basic Markdown tests", func(t *testing.T) {
|
||||||
|
app := &goBlog{
|
||||||
|
cfg: &config{
|
||||||
|
Server: &configServer{
|
||||||
|
PublicAddress: "https://example.com",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
app.initMarkdown()
|
||||||
|
|
||||||
|
// Relative / absolute links
|
||||||
|
|
||||||
|
rendered, err := app.renderMarkdown("[Relative](/relative)", false)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(string(rendered), `href="/relative"`) {
|
||||||
|
t.Errorf("Wrong result, got %v", string(rendered))
|
||||||
|
}
|
||||||
|
|
||||||
|
rendered, err = app.renderMarkdown("[Relative](/relative)", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(string(rendered), `href="https://example.com/relative"`) {
|
||||||
|
t.Errorf("Wrong result, got %v", string(rendered))
|
||||||
|
}
|
||||||
|
if strings.Contains(string(rendered), `target="_blank"`) {
|
||||||
|
t.Errorf("Wrong result, got %v", string(rendered))
|
||||||
|
}
|
||||||
|
|
||||||
|
// External links
|
||||||
|
|
||||||
|
rendered, err = app.renderMarkdown("[External](https://example.com)", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(string(rendered), `target="_blank"`) {
|
||||||
|
t.Errorf("Wrong result, got %v", string(rendered))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Link title
|
||||||
|
|
||||||
|
rendered, err = app.renderMarkdown(`[With title](https://example.com "Test-Title")`, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(string(rendered), `title="Test-Title"`) {
|
||||||
|
t.Errorf("Wrong result, got %v", string(rendered))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Text
|
||||||
|
|
||||||
|
renderedText := app.renderText("**This** *is* [text](/)")
|
||||||
|
if renderedText != "This is text" {
|
||||||
|
t.Errorf("Wrong result, got \"%v\"", renderedText)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -12,7 +12,7 @@ func (db *database) cachePersistently(key string, data []byte) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *database) retrievePersistentCache(key string) (data []byte, err error) {
|
func (db *database) retrievePersistentCache(key string) (data []byte, err error) {
|
||||||
d, err, _ := db.persistentCacheGroup.Do(key, func() (interface{}, error) {
|
d, err, _ := db.pc.Do(key, func() (interface{}, error) {
|
||||||
if row, err := db.queryRow("select data from persistent_cache where key = @key", sql.Named("key", key)); err == sql.ErrNoRows {
|
if row, err := db.queryRow("select data from persistent_cache where key = @key", sql.Named("key", key)); err == sql.ErrNoRows {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
|
|
9
utils.go
9
utils.go
|
@ -59,15 +59,16 @@ func isAllowedHost(r *http.Request, hosts ...string) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func isAbsoluteURL(s string) bool {
|
func isAbsoluteURL(s string) bool {
|
||||||
if !strings.HasPrefix(s, "https://") && !strings.HasPrefix(s, "http://") {
|
if u, err := url.Parse(s); err != nil || !u.IsAbs() {
|
||||||
return false
|
|
||||||
}
|
|
||||||
if _, err := url.Parse(s); err != nil {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func allLinksFromHTMLString(html, baseURL string) ([]string, error) {
|
||||||
|
return allLinksFromHTML(strings.NewReader(html), baseURL)
|
||||||
|
}
|
||||||
|
|
||||||
func allLinksFromHTML(r io.Reader, baseURL string) ([]string, error) {
|
func allLinksFromHTML(r io.Reader, baseURL string) ([]string, error) {
|
||||||
doc, err := goquery.NewDocumentFromReader(r)
|
doc, err := goquery.NewDocumentFromReader(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -0,0 +1,81 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_urlize(t *testing.T) {
|
||||||
|
if res := urlize("äbc ef"); res != "bc-ef" {
|
||||||
|
t.Errorf("Wrong result, got: %v", res)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_sortedStrings(t *testing.T) {
|
||||||
|
input := []string{"a", "c", "b"}
|
||||||
|
if res := sortedStrings(input); !reflect.DeepEqual(res, []string{"a", "b", "c"}) {
|
||||||
|
t.Errorf("Wrong result, got: %v", res)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_generateRandomString(t *testing.T) {
|
||||||
|
if l := len(generateRandomString(30)); l != 30 {
|
||||||
|
t.Errorf("Wrong length: %v", l)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_isAllowedHost(t *testing.T) {
|
||||||
|
req1 := httptest.NewRequest(http.MethodGet, "https://example.com", nil)
|
||||||
|
req2 := httptest.NewRequest(http.MethodGet, "https://example.com:443", nil)
|
||||||
|
req3 := httptest.NewRequest(http.MethodGet, "http://example.com:80", nil)
|
||||||
|
|
||||||
|
if isAllowedHost(req1, "example.com") != true {
|
||||||
|
t.Error("Wrong result")
|
||||||
|
}
|
||||||
|
|
||||||
|
if isAllowedHost(req1, "example.net") != false {
|
||||||
|
t.Error("Wrong result")
|
||||||
|
}
|
||||||
|
|
||||||
|
if isAllowedHost(req2, "example.com") != true {
|
||||||
|
t.Error("Wrong result")
|
||||||
|
}
|
||||||
|
|
||||||
|
if isAllowedHost(req3, "example.com") != true {
|
||||||
|
t.Error("Wrong result")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_isAbsoluteURL(t *testing.T) {
|
||||||
|
if isAbsoluteURL("http://example.com") != true {
|
||||||
|
t.Error("Wrong result")
|
||||||
|
}
|
||||||
|
|
||||||
|
if isAbsoluteURL("https://example.com") != true {
|
||||||
|
t.Error("Wrong result")
|
||||||
|
}
|
||||||
|
|
||||||
|
if isAbsoluteURL("/test") != false {
|
||||||
|
t.Error("Wrong result")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_wordCount(t *testing.T) {
|
||||||
|
if wordCount("abc def abc") != 3 {
|
||||||
|
t.Error("Wrong result")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_allLinksFromHTMLString(t *testing.T) {
|
||||||
|
baseUrl := "https://example.net/post/abc"
|
||||||
|
html := `<a href="relative1">Test</a><a href="relative1">Test</a><a href="/relative2">Test</a><a href="https://example.com">Test</a>`
|
||||||
|
expected := []string{"https://example.net/post/relative1", "https://example.net/relative2", "https://example.com"}
|
||||||
|
|
||||||
|
if result, err := allLinksFromHTMLString(html, baseUrl); err != nil {
|
||||||
|
t.Errorf("Got error: %v", err)
|
||||||
|
} else if !reflect.DeepEqual(result, expected) {
|
||||||
|
t.Errorf("Wrong result, got: %v", result)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue