Some improvements and data race fixes

This commit is contained in:
Jan-Lukas Else 2022-02-12 12:37:13 +01:00
parent d8caf1e6f5
commit 51eaf24ff2
12 changed files with 86 additions and 37 deletions

View File

@ -25,13 +25,13 @@ func (a *goBlog) initAPSendQueue() {
for { for {
qi, err := a.db.peekQueue("ap") qi, err := a.db.peekQueue("ap")
if err != nil { if err != nil {
log.Println(err.Error()) log.Println("activitypub send queue:", err.Error())
continue continue
} else if qi != nil { } else if qi != nil {
var r apRequest var r apRequest
err = gob.NewDecoder(bytes.NewReader(qi.content)).Decode(&r) err = gob.NewDecoder(bytes.NewReader(qi.content)).Decode(&r)
if err != nil { if err != nil {
log.Println(err.Error()) log.Println("activitypub send queue:", err.Error())
_ = a.db.dequeue(qi) _ = a.db.dequeue(qi)
continue continue
} }
@ -49,7 +49,7 @@ func (a *goBlog) initAPSendQueue() {
} }
err = a.db.dequeue(qi) err = a.db.dequeue(qi)
if err != nil { if err != nil {
log.Println(err.Error()) log.Println("activitypub send queue:", err.Error())
} }
} else { } else {
// No item in the queue, wait a moment // No item in the queue, wait a moment

View File

@ -19,11 +19,13 @@ func Test_loadActivityPubPrivateKey(t *testing.T) {
}, },
}, },
} }
_ = app.initDatabase(false) err := app.initDatabase(false)
require.NoError(t, err)
defer app.db.close() defer app.db.close()
require.NotNil(t, app.db)
// Generate // Generate
err := app.loadActivityPubPrivateKey() err = app.loadActivityPubPrivateKey()
require.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, app.apPrivateKey) assert.NotNil(t, app.apPrivateKey)

View File

@ -129,6 +129,9 @@ func (a *goBlog) openDatabase(file string, logging bool) (*database, error) {
// Main features // Main features
func (db *database) dump(file string) { func (db *database) dump(file string) {
if db == nil || db.db == nil {
return
}
// Lock execution // Lock execution
db.em.Lock() db.em.Lock()
defer db.em.Unlock() defer db.em.Unlock()
@ -144,10 +147,16 @@ func (db *database) dump(file string) {
} }
func (db *database) close() error { func (db *database) close() error {
if db == nil || db.db == nil {
return nil
}
return db.db.Close() return db.db.Close()
} }
func (db *database) prepare(query string, args ...interface{}) (*sql.Stmt, []interface{}, error) { func (db *database) prepare(query string, args ...interface{}) (*sql.Stmt, []interface{}, error) {
if db == nil || db.db == nil {
return nil, nil, errors.New("database not initialized")
}
if len(args) > 0 && args[0] == dbNoCache { if len(args) > 0 && args[0] == dbNoCache {
return nil, args[1:], nil return nil, args[1:], nil
} }
@ -178,6 +187,9 @@ func (db *database) prepare(query string, args ...interface{}) (*sql.Stmt, []int
const dbNoCache = "nocache" const dbNoCache = "nocache"
func (db *database) exec(query string, args ...interface{}) (sql.Result, error) { func (db *database) exec(query string, args ...interface{}) (sql.Result, error) {
if db == nil || db.db == nil {
return nil, errors.New("database not initialized")
}
// Maybe prepare // Maybe prepare
st, args, _ := db.prepare(query, args...) st, args, _ := db.prepare(query, args...)
// Lock execution // Lock execution
@ -194,6 +206,9 @@ func (db *database) exec(query string, args ...interface{}) (sql.Result, error)
} }
func (db *database) query(query string, args ...interface{}) (*sql.Rows, error) { func (db *database) query(query string, args ...interface{}) (*sql.Rows, error) {
if db == nil || db.db == nil {
return nil, errors.New("database not initialized")
}
// Maybe prepare // Maybe prepare
st, args, _ := db.prepare(query, args...) st, args, _ := db.prepare(query, args...)
// Prepare context, call hook // Prepare context, call hook
@ -207,6 +222,9 @@ func (db *database) query(query string, args ...interface{}) (*sql.Rows, error)
} }
func (db *database) queryRow(query string, args ...interface{}) (*sql.Row, error) { func (db *database) queryRow(query string, args ...interface{}) (*sql.Row, error) {
if db == nil || db.db == nil {
return nil, errors.New("database not initialized")
}
// Maybe prepare // Maybe prepare
st, args, _ := db.prepare(query, args...) st, args, _ := db.prepare(query, args...)
// Prepare context, call hook // Prepare context, call hook

View File

@ -17,12 +17,12 @@ func (a *goBlog) healthcheck() bool {
defer cancelFunc() defer cancelFunc()
req, err := http.NewRequestWithContext(timeoutContext, http.MethodGet, a.getFullAddress("/ping"), nil) req, err := http.NewRequestWithContext(timeoutContext, http.MethodGet, a.getFullAddress("/ping"), nil)
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println("healthcheck:", err.Error())
return false return false
} }
resp, err := a.httpClient.Do(req) resp, err := a.httpClient.Do(req)
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println("healthcheck:", err.Error())
return false return false
} }
defer resp.Body.Close() defer resp.Body.Close()

View File

@ -4,6 +4,7 @@ import (
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"sync"
) )
type fakeHttpClient struct { type fakeHttpClient struct {
@ -11,11 +12,14 @@ type fakeHttpClient struct {
req *http.Request req *http.Request
res *http.Response res *http.Response
handler http.Handler handler http.Handler
mu sync.Mutex
} }
func newFakeHttpClient() *fakeHttpClient { func newFakeHttpClient() *fakeHttpClient {
fc := &fakeHttpClient{} fc := &fakeHttpClient{}
fc.Client = newHandlerClient(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { fc.Client = newHandlerClient(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
fc.mu.Lock()
defer fc.mu.Unlock()
fc.req = r fc.req = r
if fc.handler != nil { if fc.handler != nil {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
@ -34,14 +38,18 @@ func newFakeHttpClient() *fakeHttpClient {
} }
func (c *fakeHttpClient) clean() { func (c *fakeHttpClient) clean() {
c.mu.Lock()
c.req = nil c.req = nil
c.res = nil c.res = nil
c.handler = nil c.handler = nil
c.mu.Unlock()
} }
func (c *fakeHttpClient) setHandler(handler http.Handler) { func (c *fakeHttpClient) setHandler(handler http.Handler) {
c.clean() c.clean()
c.mu.Lock()
c.handler = handler c.handler = handler
c.mu.Unlock()
} }
func (c *fakeHttpClient) setFakeResponse(statusCode int, body string) { func (c *fakeHttpClient) setFakeResponse(statusCode int, body string) {

View File

@ -43,9 +43,14 @@ func Test_indexNow(t *testing.T) {
}) })
// Wait for hooks to run // Wait for hooks to run
time.Sleep(300 * time.Millisecond) fc.mu.Lock()
for fc.req == nil {
fc.mu.Unlock()
time.Sleep(10 * time.Millisecond)
fc.mu.Lock()
}
fc.mu.Unlock()
// Check fake http client // Check fake http client
require.NotNil(t, fc.req)
require.Equal(t, "https://api.indexnow.org/indexnow?key="+app.inKey+"&url=http%3A%2F%2Flocalhost%3A8080%2Ftestpost", fc.req.URL.String()) require.Equal(t, "https://api.indexnow.org/indexnow?key="+app.inKey+"&url=http%3A%2F%2Flocalhost%3A8080%2Ftestpost", fc.req.URL.String())
} }

View File

@ -89,13 +89,13 @@ func (a *goBlog) renderText(s string) string {
return "" return ""
} }
pipeReader, pipeWriter := io.Pipe() pipeReader, pipeWriter := io.Pipe()
var err error
go func() { go func() {
err = a.renderMarkdownToWriter(pipeWriter, s, false) writeErr := a.renderMarkdownToWriter(pipeWriter, s, false)
_ = pipeWriter.Close() _ = pipeWriter.CloseWithError(writeErr)
}() }()
text := htmlTextFromReader(pipeReader) text, readErr := htmlTextFromReader(pipeReader)
if err != nil { _ = pipeReader.CloseWithError(readErr)
if readErr != nil {
return "" return ""
} }
return text return text
@ -106,13 +106,13 @@ func (a *goBlog) renderMdTitle(s string) string {
return "" return ""
} }
pipeReader, pipeWriter := io.Pipe() pipeReader, pipeWriter := io.Pipe()
var err error
go func() { go func() {
err = a.titleMd.Convert([]byte(s), pipeWriter) writeErr := a.titleMd.Convert([]byte(s), pipeWriter)
_ = pipeWriter.Close() _ = pipeWriter.CloseWithError(writeErr)
}() }()
text := htmlTextFromReader(pipeReader) text, readErr := htmlTextFromReader(pipeReader)
if err != nil { _ = pipeReader.CloseWithError(readErr)
if readErr != nil {
return "" return ""
} }
return text return text

View File

@ -6,11 +6,17 @@ import (
) )
func (db *database) cachePersistently(key string, data []byte) error { func (db *database) cachePersistently(key string, data []byte) error {
if db == nil {
return errors.New("database is nil")
}
_, err := db.exec("insert or replace into persistent_cache(key, data, date) values(@key, @data, @date)", sql.Named("key", key), sql.Named("data", data), sql.Named("date", utcNowString())) _, err := db.exec("insert or replace into persistent_cache(key, data, date) values(@key, @data, @date)", sql.Named("key", key), sql.Named("data", data), sql.Named("date", utcNowString()))
return err return err
} }
func (db *database) retrievePersistentCache(key string) (data []byte, err error) { func (db *database) retrievePersistentCache(key string) (data []byte, err error) {
if db == nil {
return nil, errors.New("database is nil")
}
d, err, _ := db.pc.Do(key, func() (interface{}, error) { d, err, _ := db.pc.Do(key, func() (interface{}, error) {
if row, err := db.queryRow("select data from persistent_cache where key = @key", sql.Named("key", key)); err != nil { if row, err := db.queryRow("select data from persistent_cache where key = @key", sql.Named("key", key)); err != nil {
return nil, err return nil, err

View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"bufio"
"io" "io"
"net/http" "net/http"
@ -41,12 +42,15 @@ func (a *goBlog) renderWithStatusCode(w http.ResponseWriter, r *http.Request, st
// Render // Render
pipeReader, pipeWriter := io.Pipe() pipeReader, pipeWriter := io.Pipe()
go func() { go func() {
mw := a.min.Writer(contenttype.HTML, pipeWriter) bufferedPipeWriter := bufio.NewWriter(pipeWriter)
f(newHtmlBuilder(mw), data) minifyWriter := a.min.Writer(contenttype.HTML, bufferedPipeWriter)
_ = mw.Close() f(newHtmlBuilder(minifyWriter), data)
_ = minifyWriter.Close()
_ = bufferedPipeWriter.Flush()
_ = pipeWriter.Close() _ = pipeWriter.Close()
}() }()
_, _ = io.Copy(w, pipeReader) _, readErr := io.Copy(w, pipeReader)
_ = pipeReader.CloseWithError(readErr)
} }
func (a *goBlog) checkRenderData(r *http.Request, data *renderData) { func (a *goBlog) checkRenderData(r *http.Request, data *renderData) {

View File

@ -117,8 +117,10 @@ func (a *goBlog) initChromaCSS() error {
// Generate and minify CSS // Generate and minify CSS
pipeReader, pipeWriter := io.Pipe() pipeReader, pipeWriter := io.Pipe()
go func() { go func() {
_ = chromahtml.New(chromahtml.ClassPrefix("c-")).WriteCSS(pipeWriter, chromaStyle) writeErr := chromahtml.New(chromahtml.ClassPrefix("c-")).WriteCSS(pipeWriter, chromaStyle)
_ = pipeWriter.Close() _ = pipeWriter.CloseWithError(writeErr)
}() }()
return a.compileAsset(chromaPath, pipeReader) readErr := a.compileAsset(chromaPath, pipeReader)
_ = pipeReader.CloseWithError(readErr)
return readErr
} }

View File

@ -221,10 +221,11 @@ func mBytesString(size int64) string {
} }
func htmlText(s string) string { func htmlText(s string) string {
return htmlTextFromReader(strings.NewReader(s)) text, _ := htmlTextFromReader(strings.NewReader(s))
return text
} }
func htmlTextFromReader(r io.Reader) string { func htmlTextFromReader(r io.Reader) (string, error) {
// Build policy to only allow a subset of HTML tags // Build policy to only allow a subset of HTML tags
textPolicy := bluemonday.StrictPolicy() textPolicy := bluemonday.StrictPolicy()
textPolicy.AllowElements("h1", "h2", "h3", "h4", "h5", "h6") // Headers textPolicy.AllowElements("h1", "h2", "h3", "h4", "h5", "h6") // Headers
@ -232,7 +233,10 @@ func htmlTextFromReader(r io.Reader) string {
textPolicy.AllowElements("ol", "ul", "li") // Lists textPolicy.AllowElements("ol", "ul", "li") // Lists
textPolicy.AllowElements("blockquote") // Blockquotes textPolicy.AllowElements("blockquote") // Blockquotes
// Read filtered HTML into document // Read filtered HTML into document
doc, _ := goquery.NewDocumentFromReader(textPolicy.SanitizeReader(r)) doc, err := goquery.NewDocumentFromReader(textPolicy.SanitizeReader(r))
if err != nil {
return "", err
}
var text strings.Builder var text strings.Builder
if bodyChild := doc.Find("body").Children(); bodyChild.Length() > 0 { if bodyChild := doc.Find("body").Children(); bodyChild.Length() > 0 {
// Input was real HTML, so build the text from the body // Input was real HTML, so build the text from the body
@ -242,25 +246,25 @@ func htmlTextFromReader(r io.Reader) string {
childs.Each(func(i int, sel *goquery.Selection) { childs.Each(func(i int, sel *goquery.Selection) {
if i > 0 && // Not first child if i > 0 && // Not first child
sel.Is("h1, h2, h3, h4, h5, h6, p, ol, ul, li, blockquote") { // All elements that start a new paragraph sel.Is("h1, h2, h3, h4, h5, h6, p, ol, ul, li, blockquote") { // All elements that start a new paragraph
text.WriteString("\n\n") _, _ = text.WriteString("\n\n")
} }
if sel.Is("ol > li") { // List item in ordered list if sel.Is("ol > li") { // List item in ordered list
fmt.Fprintf(&text, "%d. ", i+1) // Add list item number _, _ = fmt.Fprintf(&text, "%d. ", i+1) // Add list item number
} }
if sel.Children().Length() > 0 { // Has children if sel.Children().Length() > 0 { // Has children
printChilds(sel.Children()) // Recursive call to print childs printChilds(sel.Children()) // Recursive call to print childs
} else { } else {
text.WriteString(sel.Text()) // Print text _, _ = text.WriteString(sel.Text()) // Print text
} }
}) })
} }
printChilds(bodyChild) printChilds(bodyChild)
} else { } else {
// Input was probably just text, so just use the text // Input was probably just text, so just use the text
text.WriteString(doc.Text()) _, _ = text.WriteString(doc.Text())
} }
// Trim whitespace and return // Trim whitespace and return
return strings.TrimSpace(text.String()) return strings.TrimSpace(text.String()), nil
} }
func cleanHTMLText(s string) string { func cleanHTMLText(s string) string {

View File

@ -23,13 +23,13 @@ func (a *goBlog) initWebmentionQueue() {
for { for {
qi, err := a.db.peekQueue("wm") qi, err := a.db.peekQueue("wm")
if err != nil { if err != nil {
log.Println(err.Error()) log.Println("webmention queue:", err.Error())
continue continue
} else if qi != nil { } else if qi != nil {
var m mention var m mention
err = gob.NewDecoder(bytes.NewReader(qi.content)).Decode(&m) err = gob.NewDecoder(bytes.NewReader(qi.content)).Decode(&m)
if err != nil { if err != nil {
log.Println(err.Error()) log.Println("webmention queue:", err.Error())
_ = a.db.dequeue(qi) _ = a.db.dequeue(qi)
continue continue
} }
@ -39,7 +39,7 @@ func (a *goBlog) initWebmentionQueue() {
} }
err = a.db.dequeue(qi) err = a.db.dequeue(qi)
if err != nil { if err != nil {
log.Println(err.Error()) log.Println("webmention queue:", err.Error())
} }
} else { } else {
// No item in the queue, wait a moment // No item in the queue, wait a moment