From 9e423526bd532809e3abc8a76413bf42cf3639e8 Mon Sep 17 00:00:00 2001 From: Jan-Lukas Else Date: Wed, 1 Sep 2021 16:38:08 +0200 Subject: [PATCH] Various improvements & more tests --- authentication_test.go | 2 +- comments.go | 20 +++-- comments_test.go | 175 +++++++++++++++++++++++++++++++++++++++++ httpRouters.go | 2 +- utils.go | 20 +++-- utils_test.go | 60 ++++++++------ 6 files changed, 233 insertions(+), 46 deletions(-) create mode 100644 comments_test.go diff --git a/authentication_test.go b/authentication_test.go index 1425856..43416df 100644 --- a/authentication_test.go +++ b/authentication_test.go @@ -122,7 +122,7 @@ func Test_authMiddleware(t *testing.T) { data.Add("password", "pass") req := httptest.NewRequest(http.MethodPost, "/abc", strings.NewReader(data.Encode())) - req.Header.Add("Content-Type", contenttype.WWWForm) + req.Header.Add(contentType, contenttype.WWWForm) rec := httptest.NewRecorder() diff --git a/comments.go b/comments.go index dad5f2a..cdd3a6e 100644 --- a/comments.go +++ b/comments.go @@ -2,16 +2,17 @@ package main import ( "database/sql" - "fmt" "net/http" "net/url" + "path" "strconv" "strings" "github.com/go-chi/chi/v5" - "github.com/microcosm-cc/bluemonday" ) +const commentPath = "/comment" + type comment struct { ID int Target string @@ -42,7 +43,7 @@ func (a *goBlog) serveComment(w http.ResponseWriter, r *http.Request) { blog := r.Context().Value(blogKey).(string) a.render(w, r, templateComment, &renderData{ BlogString: blog, - Canonical: a.getFullAddress(a.cfg.Blogs[blog].getRelativePath(fmt.Sprintf("/comment/%d", id))), + Canonical: a.getFullAddress(a.getRelativePath(blog, path.Join(commentPath, strconv.Itoa(id)))), Data: comment, }) } @@ -54,17 +55,13 @@ func (a *goBlog) createComment(w http.ResponseWriter, r *http.Request) { return } // Check and clean comment - strict := bluemonday.StrictPolicy() - comment := strings.TrimSpace(strict.Sanitize(r.FormValue("comment"))) + comment := cleanHTMLText(r.FormValue("comment")) if comment == "" { a.serveError(w, r, "Comment is empty", http.StatusBadRequest) return } - name := strings.TrimSpace(strict.Sanitize(r.FormValue("name"))) - if name == "" { - name = "Anonymous" - } - website := strings.TrimSpace(strict.Sanitize(r.FormValue("website"))) + name := defaultIfEmpty(cleanHTMLText(r.FormValue("name")), "Anonymous") + website := cleanHTMLText(r.FormValue("website")) // Insert result, err := a.db.exec("insert into comments (target, comment, name, website) values (@target, @comment, @name, @website)", sql.Named("target", target), sql.Named("comment", comment), sql.Named("name", name), sql.Named("website", website)) if err != nil { @@ -75,7 +72,8 @@ func (a *goBlog) createComment(w http.ResponseWriter, r *http.Request) { // Serve error a.serveError(w, r, err.Error(), http.StatusInternalServerError) } else { - commentAddress := fmt.Sprintf("%s/%d", a.getRelativePath(r.Context().Value(blogKey).(string), "/comment"), commentID) + blog := r.Context().Value(blogKey).(string) + commentAddress := a.getRelativePath(blog, path.Join(commentPath, strconv.Itoa(int(commentID)))) // Send webmention _ = a.createWebmention(a.getFullAddress(commentAddress), a.getFullAddress(target)) // Redirect to comment diff --git a/comments_test.go b/comments_test.go new file mode 100644 index 0000000..2108675 --- /dev/null +++ b/comments_test.go @@ -0,0 +1,175 @@ +package main + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "net/url" + "path/filepath" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.goblog.app/app/pkgs/contenttype" +) + +func Test_comments(t *testing.T) { + app := &goBlog{ + cfg: &config{ + Db: &configDb{ + File: filepath.Join(t.TempDir(), "test.db"), + }, + Server: &configServer{ + PublicAddress: "https://example.com", + }, + Blogs: map[string]*configBlog{ + "en": { + Lang: "en", + }, + }, + DefaultBlog: "en", + User: &configUser{}, + }, + } + + _ = app.initDatabase(false) + app.initComponents() + + t.Run("Successful comment", func(t *testing.T) { + + // Create comment + + data := url.Values{} + data.Add("target", "https://example.com/test") + data.Add("comment", "This is just a test") + data.Add("name", "Test name") + data.Add("website", "https://goblog.app") + + req := httptest.NewRequest(http.MethodPost, commentPath, strings.NewReader(data.Encode())) + req.Header.Add(contentType, contenttype.WWWForm) + rec := httptest.NewRecorder() + + app.createComment(rec, req.WithContext(context.WithValue(req.Context(), blogKey, "en"))) + + res := rec.Result() + + assert.Equal(t, http.StatusFound, res.StatusCode) + assert.Equal(t, "/comment/1", res.Header.Get("Location")) + + // View comment + + mux := chi.NewMux() + mux.Use(middleware.WithValue(blogKey, "en")) + mux.Get("/comment/{id}", app.serveComment) + + req = httptest.NewRequest(http.MethodGet, "/comment/1", nil) + rec = httptest.NewRecorder() + + mux.ServeHTTP(rec, req) + + res = rec.Result() + resBody, _ := io.ReadAll(res.Body) + resBodyStr := string(resBody) + + assert.Equal(t, http.StatusOK, res.StatusCode) + assert.Contains(t, resBodyStr, "https://goblog.app") + assert.Contains(t, resBodyStr, "Test name") + assert.Contains(t, resBodyStr, "This is just a test") + assert.Contains(t, resBodyStr, "/test") + + // Count comments + + cc, err := app.db.countComments(&commentsRequestConfig{ + limit: 100, + offset: 0, + }) + require.NoError(t, err) + assert.Equal(t, 1, cc) + + // Get comment + + comments, err := app.db.getComments(&commentsRequestConfig{}) + require.NoError(t, err) + if assert.Len(t, comments, 1) { + comment := comments[0] + assert.Equal(t, "https://goblog.app", comment.Website) + assert.Equal(t, "Test name", comment.Name) + assert.Equal(t, "This is just a test", comment.Comment) + assert.Equal(t, "/test", comment.Target) + } + + // Delete comment + + err = app.db.deleteComment(1) + require.NoError(t, err) + cc, err = app.db.countComments(&commentsRequestConfig{}) + require.NoError(t, err) + assert.Equal(t, 0, cc) + + }) + + t.Run("Anonymous comment", func(t *testing.T) { + + // Create comment + + data := url.Values{} + data.Add("target", "https://example.com/test") + data.Add("comment", "This is just a test") + + req := httptest.NewRequest(http.MethodPost, commentPath, strings.NewReader(data.Encode())) + req.Header.Add(contentType, contenttype.WWWForm) + rec := httptest.NewRecorder() + + app.createComment(rec, req.WithContext(context.WithValue(req.Context(), blogKey, "en"))) + + res := rec.Result() + + assert.Equal(t, http.StatusFound, res.StatusCode) + assert.Equal(t, "/comment/2", res.Header.Get("Location")) + + // Get comment + + comments, err := app.db.getComments(&commentsRequestConfig{}) + require.NoError(t, err) + if assert.Len(t, comments, 1) { + comment := comments[0] + assert.Equal(t, "/test", comment.Target) + assert.Equal(t, "This is just a test", comment.Comment) + assert.Equal(t, "Anonymous", comment.Name) + assert.Equal(t, "", comment.Website) + } + + // Delete comment + + err = app.db.deleteComment(2) + require.NoError(t, err) + + }) + + t.Run("Empty comment", func(t *testing.T) { + + data := url.Values{} + data.Add("target", "https://example.com/test") + data.Add("comment", "") + + req := httptest.NewRequest(http.MethodPost, commentPath, strings.NewReader(data.Encode())) + req.Header.Add(contentType, contenttype.WWWForm) + rec := httptest.NewRecorder() + + app.createComment(rec, req.WithContext(context.WithValue(req.Context(), blogKey, "en"))) + + res := rec.Result() + + assert.Equal(t, http.StatusBadRequest, res.StatusCode) + + cc, err := app.db.countComments(&commentsRequestConfig{}) + require.NoError(t, err) + assert.Equal(t, 0, cc) + + }) + +} diff --git a/httpRouters.go b/httpRouters.go index c49087c..ba3bb82 100644 --- a/httpRouters.go +++ b/httpRouters.go @@ -346,7 +346,7 @@ func (a *goBlog) blogEditorRouter(conf *configBlog) func(r chi.Router) { func (a *goBlog) blogCommentsRouter(conf *configBlog) func(r chi.Router) { return func(r chi.Router) { if commentsConfig := conf.Comments; commentsConfig != nil && commentsConfig.Enabled { - commentsPath := conf.getRelativePath("/comment") + commentsPath := conf.getRelativePath(commentPath) r.Route(commentsPath, func(r chi.Router) { r.Use( a.privateModeHandler, diff --git a/utils.go b/utils.go index 8fd4261..5563534 100644 --- a/utils.go +++ b/utils.go @@ -22,15 +22,21 @@ import ( type contextKey string func urlize(str string) string { - var sb strings.Builder - for _, c := range strings.ToLower(str) { - if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c >= '0' && c <= '9' { - _, _ = sb.WriteRune(c) + return strings.Map(func(c rune) rune { + if c >= 'a' && c <= 'z' || c >= '0' && c <= '9' { + // Is lower case ASCII or number, return unmodified + return c + } else if c >= 'A' && c <= 'Z' { + // Is upper case ASCII, make lower case + return c + 'a' - 'A' } else if c == ' ' { - _, _ = sb.WriteRune('-') + // Space, replace with '-' + return '-' + } else { + // Drop character + return -1 } - } - return sb.String() + }, str) } func sortedStrings(s []string) []string { diff --git a/utils_test.go b/utils_test.go index 4f27dda..b6dc9c4 100644 --- a/utils_test.go +++ b/utils_test.go @@ -1,63 +1,66 @@ package main import ( - "reflect" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_urlize(t *testing.T) { - if res := urlize("äbc ef"); res != "bc-ef" { - t.Errorf("Wrong result, got: %v", res) + assert.Equal(t, "bc-ef", urlize("äbc ef")) + assert.Equal(t, "this-is-a-test", urlize("This Is A Test")) +} + +func Benchmark_urlize(b *testing.B) { + for i := 0; i < b.N; i++ { + urlize("äbc ef") } } func Test_sortedStrings(t *testing.T) { - input := []string{"a", "c", "b"} - if res := sortedStrings(input); !reflect.DeepEqual(res, []string{"a", "b", "c"}) { - t.Errorf("Wrong result, got: %v", res) - } + assert.Equal(t, []string{"a", "b", "c"}, sortedStrings([]string{"a", "c", "b"})) } func Test_generateRandomString(t *testing.T) { - if l := len(generateRandomString(30)); l != 30 { - t.Errorf("Wrong length: %v", l) - } + assert.Len(t, generateRandomString(30), 30) + assert.Len(t, generateRandomString(50), 50) } func Test_isAbsoluteURL(t *testing.T) { - if isAbsoluteURL("http://example.com") != true { - t.Error("Wrong result") - } - - if isAbsoluteURL("https://example.com") != true { - t.Error("Wrong result") - } - - if isAbsoluteURL("/test") != false { - t.Error("Wrong result") - } + assert.True(t, isAbsoluteURL("http://example.com")) + assert.True(t, isAbsoluteURL("https://example.com")) + assert.False(t, isAbsoluteURL("/test")) } func Test_wordCount(t *testing.T) { assert.Equal(t, 3, wordCount("abc def abc")) } +func Benchmark_wordCount(b *testing.B) { + for i := 0; i < b.N; i++ { + wordCount("abc def abc") + } +} + func Test_charCount(t *testing.T) { assert.Equal(t, 4, charCount(" t e\n s t €.☺️")) } +func Benchmark_charCount(b *testing.B) { + for i := 0; i < b.N; i++ { + charCount(" t e\n s t €.☺️") + } +} + func Test_allLinksFromHTMLString(t *testing.T) { baseUrl := "https://example.net/post/abc" html := `TestTestTestTest` expected := []string{"https://example.net/post/relative1", "https://example.net/relative2", "https://example.com"} - if result, err := allLinksFromHTMLString(html, baseUrl); err != nil { - t.Errorf("Got error: %v", err) - } else if !reflect.DeepEqual(result, expected) { - t.Errorf("Wrong result, got: %v", result) - } + result, err := allLinksFromHTMLString(html, baseUrl) + require.NoError(t, err) + assert.Equal(t, expected, result) } func Test_urlHasExt(t *testing.T) { @@ -77,3 +80,8 @@ func Test_cleanHTMLText(t *testing.T) { assert.Equal(t, `"This is a 'test'" 😁`, cleanHTMLText(`"This is a 'test'" 😁`)) assert.Equal(t, `Test`, cleanHTMLText(`Test`)) } + +func Test_containsStrings(t *testing.T) { + assert.True(t, containsStrings("Test", "xx", "es", "st")) + assert.False(t, containsStrings("Test", "xx", "aa")) +}