diff --git a/authentication_test.go b/authentication_test.go
index 1425856..43416df 100644
--- a/authentication_test.go
+++ b/authentication_test.go
@@ -122,7 +122,7 @@ func Test_authMiddleware(t *testing.T) {
data.Add("password", "pass")
req := httptest.NewRequest(http.MethodPost, "/abc", strings.NewReader(data.Encode()))
- req.Header.Add("Content-Type", contenttype.WWWForm)
+ req.Header.Add(contentType, contenttype.WWWForm)
rec := httptest.NewRecorder()
diff --git a/comments.go b/comments.go
index dad5f2a..cdd3a6e 100644
--- a/comments.go
+++ b/comments.go
@@ -2,16 +2,17 @@ package main
import (
"database/sql"
- "fmt"
"net/http"
"net/url"
+ "path"
"strconv"
"strings"
"github.com/go-chi/chi/v5"
- "github.com/microcosm-cc/bluemonday"
)
+const commentPath = "/comment"
+
type comment struct {
ID int
Target string
@@ -42,7 +43,7 @@ func (a *goBlog) serveComment(w http.ResponseWriter, r *http.Request) {
blog := r.Context().Value(blogKey).(string)
a.render(w, r, templateComment, &renderData{
BlogString: blog,
- Canonical: a.getFullAddress(a.cfg.Blogs[blog].getRelativePath(fmt.Sprintf("/comment/%d", id))),
+ Canonical: a.getFullAddress(a.getRelativePath(blog, path.Join(commentPath, strconv.Itoa(id)))),
Data: comment,
})
}
@@ -54,17 +55,13 @@ func (a *goBlog) createComment(w http.ResponseWriter, r *http.Request) {
return
}
// Check and clean comment
- strict := bluemonday.StrictPolicy()
- comment := strings.TrimSpace(strict.Sanitize(r.FormValue("comment")))
+ comment := cleanHTMLText(r.FormValue("comment"))
if comment == "" {
a.serveError(w, r, "Comment is empty", http.StatusBadRequest)
return
}
- name := strings.TrimSpace(strict.Sanitize(r.FormValue("name")))
- if name == "" {
- name = "Anonymous"
- }
- website := strings.TrimSpace(strict.Sanitize(r.FormValue("website")))
+ name := defaultIfEmpty(cleanHTMLText(r.FormValue("name")), "Anonymous")
+ website := cleanHTMLText(r.FormValue("website"))
// Insert
result, err := a.db.exec("insert into comments (target, comment, name, website) values (@target, @comment, @name, @website)", sql.Named("target", target), sql.Named("comment", comment), sql.Named("name", name), sql.Named("website", website))
if err != nil {
@@ -75,7 +72,8 @@ func (a *goBlog) createComment(w http.ResponseWriter, r *http.Request) {
// Serve error
a.serveError(w, r, err.Error(), http.StatusInternalServerError)
} else {
- commentAddress := fmt.Sprintf("%s/%d", a.getRelativePath(r.Context().Value(blogKey).(string), "/comment"), commentID)
+ blog := r.Context().Value(blogKey).(string)
+ commentAddress := a.getRelativePath(blog, path.Join(commentPath, strconv.Itoa(int(commentID))))
// Send webmention
_ = a.createWebmention(a.getFullAddress(commentAddress), a.getFullAddress(target))
// Redirect to comment
diff --git a/comments_test.go b/comments_test.go
new file mode 100644
index 0000000..2108675
--- /dev/null
+++ b/comments_test.go
@@ -0,0 +1,175 @@
+package main
+
+import (
+ "context"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "path/filepath"
+ "strings"
+ "testing"
+
+ "github.com/go-chi/chi/v5"
+ "github.com/go-chi/chi/v5/middleware"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.goblog.app/app/pkgs/contenttype"
+)
+
+func Test_comments(t *testing.T) {
+ app := &goBlog{
+ cfg: &config{
+ Db: &configDb{
+ File: filepath.Join(t.TempDir(), "test.db"),
+ },
+ Server: &configServer{
+ PublicAddress: "https://example.com",
+ },
+ Blogs: map[string]*configBlog{
+ "en": {
+ Lang: "en",
+ },
+ },
+ DefaultBlog: "en",
+ User: &configUser{},
+ },
+ }
+
+ _ = app.initDatabase(false)
+ app.initComponents()
+
+ t.Run("Successful comment", func(t *testing.T) {
+
+ // Create comment
+
+ data := url.Values{}
+ data.Add("target", "https://example.com/test")
+ data.Add("comment", "This is just a test")
+ data.Add("name", "Test name")
+ data.Add("website", "https://goblog.app")
+
+ req := httptest.NewRequest(http.MethodPost, commentPath, strings.NewReader(data.Encode()))
+ req.Header.Add(contentType, contenttype.WWWForm)
+ rec := httptest.NewRecorder()
+
+ app.createComment(rec, req.WithContext(context.WithValue(req.Context(), blogKey, "en")))
+
+ res := rec.Result()
+
+ assert.Equal(t, http.StatusFound, res.StatusCode)
+ assert.Equal(t, "/comment/1", res.Header.Get("Location"))
+
+ // View comment
+
+ mux := chi.NewMux()
+ mux.Use(middleware.WithValue(blogKey, "en"))
+ mux.Get("/comment/{id}", app.serveComment)
+
+ req = httptest.NewRequest(http.MethodGet, "/comment/1", nil)
+ rec = httptest.NewRecorder()
+
+ mux.ServeHTTP(rec, req)
+
+ res = rec.Result()
+ resBody, _ := io.ReadAll(res.Body)
+ resBodyStr := string(resBody)
+
+ assert.Equal(t, http.StatusOK, res.StatusCode)
+ assert.Contains(t, resBodyStr, "https://goblog.app")
+ assert.Contains(t, resBodyStr, "Test name")
+ assert.Contains(t, resBodyStr, "This is just a test")
+ assert.Contains(t, resBodyStr, "/test")
+
+ // Count comments
+
+ cc, err := app.db.countComments(&commentsRequestConfig{
+ limit: 100,
+ offset: 0,
+ })
+ require.NoError(t, err)
+ assert.Equal(t, 1, cc)
+
+ // Get comment
+
+ comments, err := app.db.getComments(&commentsRequestConfig{})
+ require.NoError(t, err)
+ if assert.Len(t, comments, 1) {
+ comment := comments[0]
+ assert.Equal(t, "https://goblog.app", comment.Website)
+ assert.Equal(t, "Test name", comment.Name)
+ assert.Equal(t, "This is just a test", comment.Comment)
+ assert.Equal(t, "/test", comment.Target)
+ }
+
+ // Delete comment
+
+ err = app.db.deleteComment(1)
+ require.NoError(t, err)
+ cc, err = app.db.countComments(&commentsRequestConfig{})
+ require.NoError(t, err)
+ assert.Equal(t, 0, cc)
+
+ })
+
+ t.Run("Anonymous comment", func(t *testing.T) {
+
+ // Create comment
+
+ data := url.Values{}
+ data.Add("target", "https://example.com/test")
+ data.Add("comment", "This is just a test")
+
+ req := httptest.NewRequest(http.MethodPost, commentPath, strings.NewReader(data.Encode()))
+ req.Header.Add(contentType, contenttype.WWWForm)
+ rec := httptest.NewRecorder()
+
+ app.createComment(rec, req.WithContext(context.WithValue(req.Context(), blogKey, "en")))
+
+ res := rec.Result()
+
+ assert.Equal(t, http.StatusFound, res.StatusCode)
+ assert.Equal(t, "/comment/2", res.Header.Get("Location"))
+
+ // Get comment
+
+ comments, err := app.db.getComments(&commentsRequestConfig{})
+ require.NoError(t, err)
+ if assert.Len(t, comments, 1) {
+ comment := comments[0]
+ assert.Equal(t, "/test", comment.Target)
+ assert.Equal(t, "This is just a test", comment.Comment)
+ assert.Equal(t, "Anonymous", comment.Name)
+ assert.Equal(t, "", comment.Website)
+ }
+
+ // Delete comment
+
+ err = app.db.deleteComment(2)
+ require.NoError(t, err)
+
+ })
+
+ t.Run("Empty comment", func(t *testing.T) {
+
+ data := url.Values{}
+ data.Add("target", "https://example.com/test")
+ data.Add("comment", "")
+
+ req := httptest.NewRequest(http.MethodPost, commentPath, strings.NewReader(data.Encode()))
+ req.Header.Add(contentType, contenttype.WWWForm)
+ rec := httptest.NewRecorder()
+
+ app.createComment(rec, req.WithContext(context.WithValue(req.Context(), blogKey, "en")))
+
+ res := rec.Result()
+
+ assert.Equal(t, http.StatusBadRequest, res.StatusCode)
+
+ cc, err := app.db.countComments(&commentsRequestConfig{})
+ require.NoError(t, err)
+ assert.Equal(t, 0, cc)
+
+ })
+
+}
diff --git a/httpRouters.go b/httpRouters.go
index c49087c..ba3bb82 100644
--- a/httpRouters.go
+++ b/httpRouters.go
@@ -346,7 +346,7 @@ func (a *goBlog) blogEditorRouter(conf *configBlog) func(r chi.Router) {
func (a *goBlog) blogCommentsRouter(conf *configBlog) func(r chi.Router) {
return func(r chi.Router) {
if commentsConfig := conf.Comments; commentsConfig != nil && commentsConfig.Enabled {
- commentsPath := conf.getRelativePath("/comment")
+ commentsPath := conf.getRelativePath(commentPath)
r.Route(commentsPath, func(r chi.Router) {
r.Use(
a.privateModeHandler,
diff --git a/utils.go b/utils.go
index 8fd4261..5563534 100644
--- a/utils.go
+++ b/utils.go
@@ -22,15 +22,21 @@ import (
type contextKey string
func urlize(str string) string {
- var sb strings.Builder
- for _, c := range strings.ToLower(str) {
- if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c >= '0' && c <= '9' {
- _, _ = sb.WriteRune(c)
+ return strings.Map(func(c rune) rune {
+ if c >= 'a' && c <= 'z' || c >= '0' && c <= '9' {
+ // Is lower case ASCII or number, return unmodified
+ return c
+ } else if c >= 'A' && c <= 'Z' {
+ // Is upper case ASCII, make lower case
+ return c + 'a' - 'A'
} else if c == ' ' {
- _, _ = sb.WriteRune('-')
+ // Space, replace with '-'
+ return '-'
+ } else {
+ // Drop character
+ return -1
}
- }
- return sb.String()
+ }, str)
}
func sortedStrings(s []string) []string {
diff --git a/utils_test.go b/utils_test.go
index 4f27dda..b6dc9c4 100644
--- a/utils_test.go
+++ b/utils_test.go
@@ -1,63 +1,66 @@
package main
import (
- "reflect"
"testing"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
func Test_urlize(t *testing.T) {
- if res := urlize("äbc ef"); res != "bc-ef" {
- t.Errorf("Wrong result, got: %v", res)
+ assert.Equal(t, "bc-ef", urlize("äbc ef"))
+ assert.Equal(t, "this-is-a-test", urlize("This Is A Test"))
+}
+
+func Benchmark_urlize(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ urlize("äbc ef")
}
}
func Test_sortedStrings(t *testing.T) {
- input := []string{"a", "c", "b"}
- if res := sortedStrings(input); !reflect.DeepEqual(res, []string{"a", "b", "c"}) {
- t.Errorf("Wrong result, got: %v", res)
- }
+ assert.Equal(t, []string{"a", "b", "c"}, sortedStrings([]string{"a", "c", "b"}))
}
func Test_generateRandomString(t *testing.T) {
- if l := len(generateRandomString(30)); l != 30 {
- t.Errorf("Wrong length: %v", l)
- }
+ assert.Len(t, generateRandomString(30), 30)
+ assert.Len(t, generateRandomString(50), 50)
}
func Test_isAbsoluteURL(t *testing.T) {
- if isAbsoluteURL("http://example.com") != true {
- t.Error("Wrong result")
- }
-
- if isAbsoluteURL("https://example.com") != true {
- t.Error("Wrong result")
- }
-
- if isAbsoluteURL("/test") != false {
- t.Error("Wrong result")
- }
+ assert.True(t, isAbsoluteURL("http://example.com"))
+ assert.True(t, isAbsoluteURL("https://example.com"))
+ assert.False(t, isAbsoluteURL("/test"))
}
func Test_wordCount(t *testing.T) {
assert.Equal(t, 3, wordCount("abc def abc"))
}
+func Benchmark_wordCount(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ wordCount("abc def abc")
+ }
+}
+
func Test_charCount(t *testing.T) {
assert.Equal(t, 4, charCount(" t e\n s t €.☺️"))
}
+func Benchmark_charCount(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ charCount(" t e\n s t €.☺️")
+ }
+}
+
func Test_allLinksFromHTMLString(t *testing.T) {
baseUrl := "https://example.net/post/abc"
html := `TestTestTestTest`
expected := []string{"https://example.net/post/relative1", "https://example.net/relative2", "https://example.com"}
- if result, err := allLinksFromHTMLString(html, baseUrl); err != nil {
- t.Errorf("Got error: %v", err)
- } else if !reflect.DeepEqual(result, expected) {
- t.Errorf("Wrong result, got: %v", result)
- }
+ result, err := allLinksFromHTMLString(html, baseUrl)
+ require.NoError(t, err)
+ assert.Equal(t, expected, result)
}
func Test_urlHasExt(t *testing.T) {
@@ -77,3 +80,8 @@ func Test_cleanHTMLText(t *testing.T) {
assert.Equal(t, `"This is a 'test'" 😁`, cleanHTMLText(`"This is a 'test'" 😁`))
assert.Equal(t, `Test`, cleanHTMLText(`Test`))
}
+
+func Test_containsStrings(t *testing.T) {
+ assert.True(t, containsStrings("Test", "xx", "es", "st"))
+ assert.False(t, containsStrings("Test", "xx", "aa"))
+}