From 51eaf24ff2907ba3404e7a3dbcc76ae936533d40 Mon Sep 17 00:00:00 2001 From: Jan-Lukas Else Date: Sat, 12 Feb 2022 12:37:13 +0100 Subject: [PATCH] Some improvements and data race fixes --- activityPubSending.go | 6 +++--- activityPub_test.go | 6 ++++-- database.go | 18 ++++++++++++++++++ healthcheck.go | 4 ++-- httpClient_test.go | 8 ++++++++ indexnow_test.go | 9 +++++++-- markdown.go | 20 ++++++++++---------- persistentCache.go | 6 ++++++ render.go | 12 ++++++++---- templateAssets.go | 8 +++++--- utils.go | 20 ++++++++++++-------- webmentionVerification.go | 6 +++--- 12 files changed, 86 insertions(+), 37 deletions(-) diff --git a/activityPubSending.go b/activityPubSending.go index bddbb0e..411cdc0 100644 --- a/activityPubSending.go +++ b/activityPubSending.go @@ -25,13 +25,13 @@ func (a *goBlog) initAPSendQueue() { for { qi, err := a.db.peekQueue("ap") if err != nil { - log.Println(err.Error()) + log.Println("activitypub send queue:", err.Error()) continue } else if qi != nil { var r apRequest err = gob.NewDecoder(bytes.NewReader(qi.content)).Decode(&r) if err != nil { - log.Println(err.Error()) + log.Println("activitypub send queue:", err.Error()) _ = a.db.dequeue(qi) continue } @@ -49,7 +49,7 @@ func (a *goBlog) initAPSendQueue() { } err = a.db.dequeue(qi) if err != nil { - log.Println(err.Error()) + log.Println("activitypub send queue:", err.Error()) } } else { // No item in the queue, wait a moment diff --git a/activityPub_test.go b/activityPub_test.go index 91d44c2..f8621f5 100644 --- a/activityPub_test.go +++ b/activityPub_test.go @@ -19,11 +19,13 @@ func Test_loadActivityPubPrivateKey(t *testing.T) { }, }, } - _ = app.initDatabase(false) + err := app.initDatabase(false) + require.NoError(t, err) defer app.db.close() + require.NotNil(t, app.db) // Generate - err := app.loadActivityPubPrivateKey() + err = app.loadActivityPubPrivateKey() require.NoError(t, err) assert.NotNil(t, app.apPrivateKey) diff --git a/database.go b/database.go index 545d8f3..6eef4fa 100644 --- a/database.go +++ b/database.go @@ -129,6 +129,9 @@ func (a *goBlog) openDatabase(file string, logging bool) (*database, error) { // Main features func (db *database) dump(file string) { + if db == nil || db.db == nil { + return + } // Lock execution db.em.Lock() defer db.em.Unlock() @@ -144,10 +147,16 @@ func (db *database) dump(file string) { } func (db *database) close() error { + if db == nil || db.db == nil { + return nil + } return db.db.Close() } func (db *database) prepare(query string, args ...interface{}) (*sql.Stmt, []interface{}, error) { + if db == nil || db.db == nil { + return nil, nil, errors.New("database not initialized") + } if len(args) > 0 && args[0] == dbNoCache { return nil, args[1:], nil } @@ -178,6 +187,9 @@ func (db *database) prepare(query string, args ...interface{}) (*sql.Stmt, []int const dbNoCache = "nocache" func (db *database) exec(query string, args ...interface{}) (sql.Result, error) { + if db == nil || db.db == nil { + return nil, errors.New("database not initialized") + } // Maybe prepare st, args, _ := db.prepare(query, args...) // Lock execution @@ -194,6 +206,9 @@ func (db *database) exec(query string, args ...interface{}) (sql.Result, error) } func (db *database) query(query string, args ...interface{}) (*sql.Rows, error) { + if db == nil || db.db == nil { + return nil, errors.New("database not initialized") + } // Maybe prepare st, args, _ := db.prepare(query, args...) // Prepare context, call hook @@ -207,6 +222,9 @@ func (db *database) query(query string, args ...interface{}) (*sql.Rows, error) } func (db *database) queryRow(query string, args ...interface{}) (*sql.Row, error) { + if db == nil || db.db == nil { + return nil, errors.New("database not initialized") + } // Maybe prepare st, args, _ := db.prepare(query, args...) // Prepare context, call hook diff --git a/healthcheck.go b/healthcheck.go index d229314..f8ab31f 100644 --- a/healthcheck.go +++ b/healthcheck.go @@ -17,12 +17,12 @@ func (a *goBlog) healthcheck() bool { defer cancelFunc() req, err := http.NewRequestWithContext(timeoutContext, http.MethodGet, a.getFullAddress("/ping"), nil) if err != nil { - fmt.Println(err.Error()) + fmt.Println("healthcheck:", err.Error()) return false } resp, err := a.httpClient.Do(req) if err != nil { - fmt.Println(err.Error()) + fmt.Println("healthcheck:", err.Error()) return false } defer resp.Body.Close() diff --git a/httpClient_test.go b/httpClient_test.go index 2d4a001..1fd424e 100644 --- a/httpClient_test.go +++ b/httpClient_test.go @@ -4,6 +4,7 @@ import ( "io" "net/http" "net/http/httptest" + "sync" ) type fakeHttpClient struct { @@ -11,11 +12,14 @@ type fakeHttpClient struct { req *http.Request res *http.Response handler http.Handler + mu sync.Mutex } func newFakeHttpClient() *fakeHttpClient { fc := &fakeHttpClient{} fc.Client = newHandlerClient(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + fc.mu.Lock() + defer fc.mu.Unlock() fc.req = r if fc.handler != nil { rec := httptest.NewRecorder() @@ -34,14 +38,18 @@ func newFakeHttpClient() *fakeHttpClient { } func (c *fakeHttpClient) clean() { + c.mu.Lock() c.req = nil c.res = nil c.handler = nil + c.mu.Unlock() } func (c *fakeHttpClient) setHandler(handler http.Handler) { c.clean() + c.mu.Lock() c.handler = handler + c.mu.Unlock() } func (c *fakeHttpClient) setFakeResponse(statusCode int, body string) { diff --git a/indexnow_test.go b/indexnow_test.go index f88f49b..2fe5f82 100644 --- a/indexnow_test.go +++ b/indexnow_test.go @@ -43,9 +43,14 @@ func Test_indexNow(t *testing.T) { }) // Wait for hooks to run - time.Sleep(300 * time.Millisecond) + fc.mu.Lock() + for fc.req == nil { + fc.mu.Unlock() + time.Sleep(10 * time.Millisecond) + fc.mu.Lock() + } + fc.mu.Unlock() // Check fake http client - require.NotNil(t, fc.req) require.Equal(t, "https://api.indexnow.org/indexnow?key="+app.inKey+"&url=http%3A%2F%2Flocalhost%3A8080%2Ftestpost", fc.req.URL.String()) } diff --git a/markdown.go b/markdown.go index c1d0785..1003926 100644 --- a/markdown.go +++ b/markdown.go @@ -89,13 +89,13 @@ func (a *goBlog) renderText(s string) string { return "" } pipeReader, pipeWriter := io.Pipe() - var err error go func() { - err = a.renderMarkdownToWriter(pipeWriter, s, false) - _ = pipeWriter.Close() + writeErr := a.renderMarkdownToWriter(pipeWriter, s, false) + _ = pipeWriter.CloseWithError(writeErr) }() - text := htmlTextFromReader(pipeReader) - if err != nil { + text, readErr := htmlTextFromReader(pipeReader) + _ = pipeReader.CloseWithError(readErr) + if readErr != nil { return "" } return text @@ -106,13 +106,13 @@ func (a *goBlog) renderMdTitle(s string) string { return "" } pipeReader, pipeWriter := io.Pipe() - var err error go func() { - err = a.titleMd.Convert([]byte(s), pipeWriter) - _ = pipeWriter.Close() + writeErr := a.titleMd.Convert([]byte(s), pipeWriter) + _ = pipeWriter.CloseWithError(writeErr) }() - text := htmlTextFromReader(pipeReader) - if err != nil { + text, readErr := htmlTextFromReader(pipeReader) + _ = pipeReader.CloseWithError(readErr) + if readErr != nil { return "" } return text diff --git a/persistentCache.go b/persistentCache.go index e32b221..1699fa7 100644 --- a/persistentCache.go +++ b/persistentCache.go @@ -6,11 +6,17 @@ import ( ) func (db *database) cachePersistently(key string, data []byte) error { + if db == nil { + return errors.New("database is nil") + } _, err := db.exec("insert or replace into persistent_cache(key, data, date) values(@key, @data, @date)", sql.Named("key", key), sql.Named("data", data), sql.Named("date", utcNowString())) return err } func (db *database) retrievePersistentCache(key string) (data []byte, err error) { + if db == nil { + return nil, errors.New("database is nil") + } d, err, _ := db.pc.Do(key, func() (interface{}, error) { if row, err := db.queryRow("select data from persistent_cache where key = @key", sql.Named("key", key)); err != nil { return nil, err diff --git a/render.go b/render.go index b793ef5..cc477bf 100644 --- a/render.go +++ b/render.go @@ -1,6 +1,7 @@ package main import ( + "bufio" "io" "net/http" @@ -41,12 +42,15 @@ func (a *goBlog) renderWithStatusCode(w http.ResponseWriter, r *http.Request, st // Render pipeReader, pipeWriter := io.Pipe() go func() { - mw := a.min.Writer(contenttype.HTML, pipeWriter) - f(newHtmlBuilder(mw), data) - _ = mw.Close() + bufferedPipeWriter := bufio.NewWriter(pipeWriter) + minifyWriter := a.min.Writer(contenttype.HTML, bufferedPipeWriter) + f(newHtmlBuilder(minifyWriter), data) + _ = minifyWriter.Close() + _ = bufferedPipeWriter.Flush() _ = pipeWriter.Close() }() - _, _ = io.Copy(w, pipeReader) + _, readErr := io.Copy(w, pipeReader) + _ = pipeReader.CloseWithError(readErr) } func (a *goBlog) checkRenderData(r *http.Request, data *renderData) { diff --git a/templateAssets.go b/templateAssets.go index 0143a21..63c937b 100644 --- a/templateAssets.go +++ b/templateAssets.go @@ -117,8 +117,10 @@ func (a *goBlog) initChromaCSS() error { // Generate and minify CSS pipeReader, pipeWriter := io.Pipe() go func() { - _ = chromahtml.New(chromahtml.ClassPrefix("c-")).WriteCSS(pipeWriter, chromaStyle) - _ = pipeWriter.Close() + writeErr := chromahtml.New(chromahtml.ClassPrefix("c-")).WriteCSS(pipeWriter, chromaStyle) + _ = pipeWriter.CloseWithError(writeErr) }() - return a.compileAsset(chromaPath, pipeReader) + readErr := a.compileAsset(chromaPath, pipeReader) + _ = pipeReader.CloseWithError(readErr) + return readErr } diff --git a/utils.go b/utils.go index 3fa809a..cd5e273 100644 --- a/utils.go +++ b/utils.go @@ -221,10 +221,11 @@ func mBytesString(size int64) string { } func htmlText(s string) string { - return htmlTextFromReader(strings.NewReader(s)) + text, _ := htmlTextFromReader(strings.NewReader(s)) + return text } -func htmlTextFromReader(r io.Reader) string { +func htmlTextFromReader(r io.Reader) (string, error) { // Build policy to only allow a subset of HTML tags textPolicy := bluemonday.StrictPolicy() textPolicy.AllowElements("h1", "h2", "h3", "h4", "h5", "h6") // Headers @@ -232,7 +233,10 @@ func htmlTextFromReader(r io.Reader) string { textPolicy.AllowElements("ol", "ul", "li") // Lists textPolicy.AllowElements("blockquote") // Blockquotes // Read filtered HTML into document - doc, _ := goquery.NewDocumentFromReader(textPolicy.SanitizeReader(r)) + doc, err := goquery.NewDocumentFromReader(textPolicy.SanitizeReader(r)) + if err != nil { + return "", err + } var text strings.Builder if bodyChild := doc.Find("body").Children(); bodyChild.Length() > 0 { // Input was real HTML, so build the text from the body @@ -242,25 +246,25 @@ func htmlTextFromReader(r io.Reader) string { childs.Each(func(i int, sel *goquery.Selection) { if i > 0 && // Not first child sel.Is("h1, h2, h3, h4, h5, h6, p, ol, ul, li, blockquote") { // All elements that start a new paragraph - text.WriteString("\n\n") + _, _ = text.WriteString("\n\n") } if sel.Is("ol > li") { // List item in ordered list - fmt.Fprintf(&text, "%d. ", i+1) // Add list item number + _, _ = fmt.Fprintf(&text, "%d. ", i+1) // Add list item number } if sel.Children().Length() > 0 { // Has children printChilds(sel.Children()) // Recursive call to print childs } else { - text.WriteString(sel.Text()) // Print text + _, _ = text.WriteString(sel.Text()) // Print text } }) } printChilds(bodyChild) } else { // Input was probably just text, so just use the text - text.WriteString(doc.Text()) + _, _ = text.WriteString(doc.Text()) } // Trim whitespace and return - return strings.TrimSpace(text.String()) + return strings.TrimSpace(text.String()), nil } func cleanHTMLText(s string) string { diff --git a/webmentionVerification.go b/webmentionVerification.go index 9142269..7819534 100644 --- a/webmentionVerification.go +++ b/webmentionVerification.go @@ -23,13 +23,13 @@ func (a *goBlog) initWebmentionQueue() { for { qi, err := a.db.peekQueue("wm") if err != nil { - log.Println(err.Error()) + log.Println("webmention queue:", err.Error()) continue } else if qi != nil { var m mention err = gob.NewDecoder(bytes.NewReader(qi.content)).Decode(&m) if err != nil { - log.Println(err.Error()) + log.Println("webmention queue:", err.Error()) _ = a.db.dequeue(qi) continue } @@ -39,7 +39,7 @@ func (a *goBlog) initWebmentionQueue() { } err = a.db.dequeue(qi) if err != nil { - log.Println(err.Error()) + log.Println("webmention queue:", err.Error()) } } else { // No item in the queue, wait a moment