mirror of https://github.com/jlelse/GoBlog
Small fixes and higher test coverage
This commit is contained in:
parent
b83c09d5ac
commit
ca99b726a8
|
@ -174,7 +174,7 @@ func (a *goBlog) apHandleInbox(w http.ResponseWriter, r *http.Request) {
|
|||
_ = a.createWebmention(id, inReplyTo)
|
||||
} else if content, hasContent := object["content"].(string); hasContent && hasID && len(id) > 0 {
|
||||
// May be a mention; find links to blog and save them as webmentions
|
||||
if links, err := allLinksFromHTML(strings.NewReader(content), id); err == nil {
|
||||
if links, err := allLinksFromHTMLString(content, id); err == nil {
|
||||
for _, link := range links {
|
||||
if strings.Contains(link, blogIri) {
|
||||
_ = a.createWebmention(id, link)
|
||||
|
|
2
check.go
2
check.go
|
@ -90,7 +90,7 @@ func (a *goBlog) getExternalLinks(posts []*post, linkChan chan<- stringPair) err
|
|||
wg.Add(1)
|
||||
go func(p *post) {
|
||||
defer wg.Done()
|
||||
links, _ := allLinksFromHTML(strings.NewReader(string(a.absoluteHTML(p))), a.fullPostURL(p))
|
||||
links, _ := allLinksFromHTMLString(string(a.absoluteHTML(p)), a.fullPostURL(p))
|
||||
for _, link := range links {
|
||||
linkChan <- stringPair{a.fullPostURL(p), link}
|
||||
}
|
||||
|
|
31
database.go
31
database.go
|
@ -1,6 +1,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"log"
|
||||
|
@ -13,20 +14,21 @@ import (
|
|||
|
||||
type database struct {
|
||||
db *sql.DB
|
||||
c context.Context
|
||||
cf context.CancelFunc
|
||||
stmts map[string]*sql.Stmt
|
||||
g singleflight.Group
|
||||
persistentCacheGroup singleflight.Group
|
||||
pc singleflight.Group
|
||||
}
|
||||
|
||||
func (a *goBlog) initDatabase() (err error) {
|
||||
// Setup db
|
||||
db, err := a.openDatabase(a.cfg.Db.File)
|
||||
db, err := a.openDatabase(a.cfg.Db.File, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Create appDB
|
||||
a.db = db
|
||||
db.vacuum()
|
||||
addShutdownFunc(func() {
|
||||
_ = db.close()
|
||||
log.Println("Closed database")
|
||||
|
@ -40,7 +42,7 @@ func (a *goBlog) initDatabase() (err error) {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (a *goBlog) openDatabase(file string) (*database, error) {
|
||||
func (a *goBlog) openDatabase(file string, logging bool) (*database, error) {
|
||||
// Register driver
|
||||
dbDriverName := generateRandomString(15)
|
||||
sql.Register("goblog_db_"+dbDriverName, &sqlite.SQLiteDriver{
|
||||
|
@ -85,13 +87,16 @@ func (a *goBlog) openDatabase(file string) (*database, error) {
|
|||
return nil, errors.New("sqlite not compiled with FTS5")
|
||||
}
|
||||
// Migrate DB
|
||||
err = migrateDb(db)
|
||||
err = migrateDb(db, logging)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, cf := context.WithCancel(context.Background())
|
||||
return &database{
|
||||
db: db,
|
||||
stmts: map[string]*sql.Stmt{},
|
||||
c: c,
|
||||
cf: cf,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -109,21 +114,17 @@ func (db *database) dump(file string) {
|
|||
}
|
||||
|
||||
func (db *database) close() error {
|
||||
db.vacuum()
|
||||
db.cf()
|
||||
return db.db.Close()
|
||||
}
|
||||
|
||||
func (db *database) vacuum() {
|
||||
_, _ = db.exec("VACUUM")
|
||||
}
|
||||
|
||||
func (db *database) prepare(query string) (*sql.Stmt, error) {
|
||||
stmt, err, _ := db.g.Do(query, func() (interface{}, error) {
|
||||
stmt, ok := db.stmts[query]
|
||||
if ok && stmt != nil {
|
||||
return stmt, nil
|
||||
}
|
||||
stmt, err := db.db.Prepare(query)
|
||||
stmt, err := db.db.PrepareContext(db.c, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -141,12 +142,12 @@ func (db *database) exec(query string, args ...interface{}) (sql.Result, error)
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return stmt.Exec(args...)
|
||||
return stmt.ExecContext(db.c, args...)
|
||||
}
|
||||
|
||||
func (db *database) execMulti(query string, args ...interface{}) (sql.Result, error) {
|
||||
// Can't prepare the statement
|
||||
return db.db.Exec(query, args...)
|
||||
return db.db.ExecContext(db.c, query, args...)
|
||||
}
|
||||
|
||||
func (db *database) query(query string, args ...interface{}) (*sql.Rows, error) {
|
||||
|
@ -154,7 +155,7 @@ func (db *database) query(query string, args ...interface{}) (*sql.Rows, error)
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return stmt.Query(args...)
|
||||
return stmt.QueryContext(db.c, args...)
|
||||
}
|
||||
|
||||
func (db *database) queryRow(query string, args ...interface{}) (*sql.Row, error) {
|
||||
|
@ -162,7 +163,7 @@ func (db *database) queryRow(query string, args ...interface{}) (*sql.Row, error
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return stmt.QueryRow(args...), nil
|
||||
return stmt.QueryRowContext(db.c, args...), nil
|
||||
}
|
||||
|
||||
// Other things
|
||||
|
|
|
@ -7,10 +7,12 @@ import (
|
|||
"github.com/lopezator/migrator"
|
||||
)
|
||||
|
||||
func migrateDb(db *sql.DB) error {
|
||||
func migrateDb(db *sql.DB, logging bool) error {
|
||||
m, err := migrator.New(
|
||||
migrator.WithLogger(migrator.LoggerFunc(func(s string, i ...interface{}) {
|
||||
if logging {
|
||||
log.Printf(s, i)
|
||||
}
|
||||
})),
|
||||
migrator.Migrations(
|
||||
&migrator.Migration{
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_database(t *testing.T) {
|
||||
t.Run("Basic Database Test", func(t *testing.T) {
|
||||
app := &goBlog{}
|
||||
|
||||
db, err := app.openDatabase(":memory:", false)
|
||||
if err != nil {
|
||||
t.Errorf("Error: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.execMulti("create table test(test text);")
|
||||
if err != nil {
|
||||
t.Errorf("Error: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.exec("insert into test (test) values ('Test')")
|
||||
if err != nil {
|
||||
t.Errorf("Error: %v", err)
|
||||
}
|
||||
|
||||
row, err := db.queryRow("select count(test) from test")
|
||||
if err != nil {
|
||||
t.Errorf("Error: %v", err)
|
||||
}
|
||||
var test1 int
|
||||
err = row.Scan(&test1)
|
||||
if err != nil {
|
||||
t.Errorf("Error: %v", err)
|
||||
}
|
||||
if test1 != 1 {
|
||||
t.Error("Wrong result")
|
||||
}
|
||||
|
||||
rows, err := db.query("select count(test), test from test")
|
||||
if err != nil {
|
||||
t.Errorf("Error: %v", err)
|
||||
}
|
||||
var test2 int
|
||||
var testStr string
|
||||
if !rows.Next() {
|
||||
t.Error("No result row")
|
||||
}
|
||||
err = rows.Scan(&test2, &testStr)
|
||||
if err != nil {
|
||||
t.Errorf("Error: %v", err)
|
||||
}
|
||||
if test2 != 1 || testStr != "Test" {
|
||||
t.Error("Wrong result")
|
||||
}
|
||||
|
||||
err = db.close()
|
||||
if err != nil {
|
||||
t.Errorf("Error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
|
@ -2,6 +2,7 @@ package main
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
marktag "git.jlel.se/jlelse/goldmark-mark"
|
||||
"github.com/PuerkitoBio/goquery"
|
||||
|
@ -62,7 +63,7 @@ func (a *goBlog) renderText(s string) string {
|
|||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return d.Text()
|
||||
return strings.TrimSpace(d.Text())
|
||||
}
|
||||
|
||||
// Extensions etc...
|
||||
|
@ -104,7 +105,7 @@ func (c *customRenderer) renderLink(w util.BufWriter, _ []byte, node ast.Node, e
|
|||
_, _ = w.Write(util.EscapeHTML(newDestination))
|
||||
_, _ = w.WriteRune('"')
|
||||
// Open external links (links that start with "http") in new tab
|
||||
if bytes.HasPrefix(n.Destination, []byte("http")) {
|
||||
if isAbsoluteURL(string(n.Destination)) {
|
||||
_, _ = w.WriteString(` target="_blank" rel="noopener"`)
|
||||
}
|
||||
// Title
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_markdown(t *testing.T) {
|
||||
t.Run("Basic Markdown tests", func(t *testing.T) {
|
||||
app := &goBlog{
|
||||
cfg: &config{
|
||||
Server: &configServer{
|
||||
PublicAddress: "https://example.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
app.initMarkdown()
|
||||
|
||||
// Relative / absolute links
|
||||
|
||||
rendered, err := app.renderMarkdown("[Relative](/relative)", false)
|
||||
if err != nil {
|
||||
t.Errorf("Error: %v", err)
|
||||
}
|
||||
if !strings.Contains(string(rendered), `href="/relative"`) {
|
||||
t.Errorf("Wrong result, got %v", string(rendered))
|
||||
}
|
||||
|
||||
rendered, err = app.renderMarkdown("[Relative](/relative)", true)
|
||||
if err != nil {
|
||||
t.Errorf("Error: %v", err)
|
||||
}
|
||||
if !strings.Contains(string(rendered), `href="https://example.com/relative"`) {
|
||||
t.Errorf("Wrong result, got %v", string(rendered))
|
||||
}
|
||||
if strings.Contains(string(rendered), `target="_blank"`) {
|
||||
t.Errorf("Wrong result, got %v", string(rendered))
|
||||
}
|
||||
|
||||
// External links
|
||||
|
||||
rendered, err = app.renderMarkdown("[External](https://example.com)", true)
|
||||
if err != nil {
|
||||
t.Errorf("Error: %v", err)
|
||||
}
|
||||
if !strings.Contains(string(rendered), `target="_blank"`) {
|
||||
t.Errorf("Wrong result, got %v", string(rendered))
|
||||
}
|
||||
|
||||
// Link title
|
||||
|
||||
rendered, err = app.renderMarkdown(`[With title](https://example.com "Test-Title")`, true)
|
||||
if err != nil {
|
||||
t.Errorf("Error: %v", err)
|
||||
}
|
||||
if !strings.Contains(string(rendered), `title="Test-Title"`) {
|
||||
t.Errorf("Wrong result, got %v", string(rendered))
|
||||
}
|
||||
|
||||
// Text
|
||||
|
||||
renderedText := app.renderText("**This** *is* [text](/)")
|
||||
if renderedText != "This is text" {
|
||||
t.Errorf("Wrong result, got \"%v\"", renderedText)
|
||||
}
|
||||
})
|
||||
}
|
|
@ -12,7 +12,7 @@ func (db *database) cachePersistently(key string, data []byte) error {
|
|||
}
|
||||
|
||||
func (db *database) retrievePersistentCache(key string) (data []byte, err error) {
|
||||
d, err, _ := db.persistentCacheGroup.Do(key, func() (interface{}, error) {
|
||||
d, err, _ := db.pc.Do(key, func() (interface{}, error) {
|
||||
if row, err := db.queryRow("select data from persistent_cache where key = @key", sql.Named("key", key)); err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
|
|
9
utils.go
9
utils.go
|
@ -59,15 +59,16 @@ func isAllowedHost(r *http.Request, hosts ...string) bool {
|
|||
}
|
||||
|
||||
func isAbsoluteURL(s string) bool {
|
||||
if !strings.HasPrefix(s, "https://") && !strings.HasPrefix(s, "http://") {
|
||||
return false
|
||||
}
|
||||
if _, err := url.Parse(s); err != nil {
|
||||
if u, err := url.Parse(s); err != nil || !u.IsAbs() {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func allLinksFromHTMLString(html, baseURL string) ([]string, error) {
|
||||
return allLinksFromHTML(strings.NewReader(html), baseURL)
|
||||
}
|
||||
|
||||
func allLinksFromHTML(r io.Reader, baseURL string) ([]string, error) {
|
||||
doc, err := goquery.NewDocumentFromReader(r)
|
||||
if err != nil {
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_urlize(t *testing.T) {
|
||||
if res := urlize("äbc ef"); res != "bc-ef" {
|
||||
t.Errorf("Wrong result, got: %v", res)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sortedStrings(t *testing.T) {
|
||||
input := []string{"a", "c", "b"}
|
||||
if res := sortedStrings(input); !reflect.DeepEqual(res, []string{"a", "b", "c"}) {
|
||||
t.Errorf("Wrong result, got: %v", res)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_generateRandomString(t *testing.T) {
|
||||
if l := len(generateRandomString(30)); l != 30 {
|
||||
t.Errorf("Wrong length: %v", l)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_isAllowedHost(t *testing.T) {
|
||||
req1 := httptest.NewRequest(http.MethodGet, "https://example.com", nil)
|
||||
req2 := httptest.NewRequest(http.MethodGet, "https://example.com:443", nil)
|
||||
req3 := httptest.NewRequest(http.MethodGet, "http://example.com:80", nil)
|
||||
|
||||
if isAllowedHost(req1, "example.com") != true {
|
||||
t.Error("Wrong result")
|
||||
}
|
||||
|
||||
if isAllowedHost(req1, "example.net") != false {
|
||||
t.Error("Wrong result")
|
||||
}
|
||||
|
||||
if isAllowedHost(req2, "example.com") != true {
|
||||
t.Error("Wrong result")
|
||||
}
|
||||
|
||||
if isAllowedHost(req3, "example.com") != true {
|
||||
t.Error("Wrong result")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_isAbsoluteURL(t *testing.T) {
|
||||
if isAbsoluteURL("http://example.com") != true {
|
||||
t.Error("Wrong result")
|
||||
}
|
||||
|
||||
if isAbsoluteURL("https://example.com") != true {
|
||||
t.Error("Wrong result")
|
||||
}
|
||||
|
||||
if isAbsoluteURL("/test") != false {
|
||||
t.Error("Wrong result")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_wordCount(t *testing.T) {
|
||||
if wordCount("abc def abc") != 3 {
|
||||
t.Error("Wrong result")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_allLinksFromHTMLString(t *testing.T) {
|
||||
baseUrl := "https://example.net/post/abc"
|
||||
html := `<a href="relative1">Test</a><a href="relative1">Test</a><a href="/relative2">Test</a><a href="https://example.com">Test</a>`
|
||||
expected := []string{"https://example.net/post/relative1", "https://example.net/relative2", "https://example.com"}
|
||||
|
||||
if result, err := allLinksFromHTMLString(html, baseUrl); err != nil {
|
||||
t.Errorf("Got error: %v", err)
|
||||
} else if !reflect.DeepEqual(result, expected) {
|
||||
t.Errorf("Wrong result, got: %v", result)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue