diff --git a/activityPub.go b/activityPub.go index c2546b7..1f26f62 100644 --- a/activityPub.go +++ b/activityPub.go @@ -114,7 +114,7 @@ func (a *goBlog) apHandleInbox(w http.ResponseWriter, r *http.Request) { } blogIri := a.apIri(blog) // Verify request - requestActor, requestKey, requestActorStatus, err := apVerifySignature(r) + requestActor, requestKey, requestActorStatus, err := a.apVerifySignature(r) if err != nil { // Send 401 because signature could not be verified a.serveError(w, r, err.Error(), http.StatusUnauthorized) @@ -217,14 +217,14 @@ func (a *goBlog) apHandleInbox(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } -func apVerifySignature(r *http.Request) (*asPerson, string, int, error) { +func (a *goBlog) apVerifySignature(r *http.Request) (*asPerson, string, int, error) { verifier, err := httpsig.NewVerifier(r) if err != nil { // Error with signature header etc. return nil, "", 0, err } keyID := verifier.KeyId() - actor, statusCode, err := apGetRemoteActor(keyID) + actor, statusCode, err := a.apGetRemoteActor(keyID) if err != nil || actor == nil || statusCode != 0 { // Actor not found or something else bad return nil, keyID, statusCode, err @@ -249,14 +249,14 @@ func handleWellKnownHostMeta(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte(``)) } -func apGetRemoteActor(iri string) (*asPerson, int, error) { +func (a *goBlog) apGetRemoteActor(iri string) (*asPerson, int, error) { req, err := http.NewRequest(http.MethodGet, iri, nil) if err != nil { return nil, 0, err } req.Header.Set("Accept", contenttype.AS) req.Header.Set(userAgent, appUserAgent) - resp, err := appHttpClient.Do(req) + resp, err := a.httpClient.Do(req) if err != nil { return nil, 0, err } @@ -367,7 +367,7 @@ func (a *goBlog) apAccept(blogName string, blog *configBlog, follow map[string]i // actor and object are equal return } - follower, status, err := apGetRemoteActor(newFollower) + follower, status, err := a.apGetRemoteActor(newFollower) if err != nil || status != 0 { // Couldn't retrieve remote actor info log.Println("Failed to retrieve remote actor info:", newFollower) diff --git a/activityPubSending.go b/activityPubSending.go index 3fe4663..30b2087 100644 --- a/activityPubSending.go +++ b/activityPubSending.go @@ -114,7 +114,7 @@ func (a *goBlog) apSendSigned(blogIri, to string, activity []byte) error { return err } // Do request - resp, err := appHttpClient.Do(r) + resp, err := a.httpClient.Do(r) if err != nil { return err } diff --git a/activityStreams.go b/activityStreams.go index 5a20e01..9680ced 100644 --- a/activityStreams.go +++ b/activityStreams.go @@ -17,17 +17,18 @@ const asContext = "https://www.w3.org/ns/activitystreams" const asRequestKey requestContextKey = "asRequest" -var asCheckMediaTypes = []ct.MediaType{ - ct.NewMediaType(contenttype.HTML), - ct.NewMediaType(contenttype.AS), - ct.NewMediaType(contenttype.LDJSON), -} - func (a *goBlog) checkActivityStreamsRequest(next http.Handler) http.Handler { + if len(a.asCheckMediaTypes) == 0 { + a.asCheckMediaTypes = []ct.MediaType{ + ct.NewMediaType(contenttype.HTML), + ct.NewMediaType(contenttype.AS), + ct.NewMediaType(contenttype.LDJSON), + } + } return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { if ap := a.cfg.ActivityPub; ap != nil && ap.Enabled { // Check if accepted media type is not HTML - if mt, _, err := ct.GetAcceptableMediaType(r, asCheckMediaTypes); err == nil && mt.String() != asCheckMediaTypes[0].String() { + if mt, _, err := ct.GetAcceptableMediaType(r, a.asCheckMediaTypes); err == nil && mt.String() != a.asCheckMediaTypes[0].String() { next.ServeHTTP(rw, r.WithContext(context.WithValue(r.Context(), asRequestKey, true))) return } diff --git a/app.go b/app.go index 3b1495e..59744d9 100644 --- a/app.go +++ b/app.go @@ -9,6 +9,7 @@ import ( "git.jlel.se/jlelse/GoBlog/pkgs/minify" shutdowner "git.jlel.se/jlelse/go-shutdowner" ts "git.jlel.se/jlelse/template-strings" + ct "github.com/elnormous/contenttype" "github.com/go-chi/chi/v5" "github.com/go-fed/httpsig" rotatelogs "github.com/lestrrat-go/file-rotatelogs" @@ -23,6 +24,8 @@ type goBlog struct { apPostSignMutex sync.Mutex webfingerResources map[string]*configBlog webfingerAccts map[string]string + // ActivityStreams + asCheckMediaTypes []ct.MediaType // Assets assetFileNames map[string]string assetFiles map[string]*assetFile @@ -36,6 +39,8 @@ type goBlog struct { cfg *config // Database db *database + // Errors + errorCheckMediaTypes []ct.MediaType // Hooks pPostHooks []postHookFunc pUpdateHooks []postHookFunc @@ -43,6 +48,8 @@ type goBlog struct { hourlyHooks []hourlyHookFunc // HTTP cspDomains string + // HTTP Client + httpClient httpClient // HTTP Routers d *dynamicHandler privateMode bool diff --git a/blogroll.go b/blogroll.go index cbdee07..bcdc78e 100644 --- a/blogroll.go +++ b/blogroll.go @@ -78,7 +78,7 @@ func (a *goBlog) getBlogrollOutlines(blog string) ([]*opml.Outline, error) { if config.AuthHeader != "" && config.AuthValue != "" { req.Header.Set(config.AuthHeader, config.AuthValue) } - res, err := appHttpClient.Do(req) + res, err := a.httpClient.Do(req) if err != nil { return nil, err } diff --git a/captcha.go b/captcha.go index f47ad1c..fe6c138 100644 --- a/captcha.go +++ b/captcha.go @@ -22,7 +22,6 @@ func (a *goBlog) captchaMiddleware(next http.Handler) http.Handler { } } // 2. Show Captcha - w.WriteHeader(http.StatusUnauthorized) h, _ := json.Marshal(r.Header.Clone()) b, _ := io.ReadAll(io.LimitReader(r.Body, 2000000)) // Only allow 20 Megabyte _ = r.Body.Close() @@ -31,7 +30,7 @@ func (a *goBlog) captchaMiddleware(next http.Handler) http.Handler { _ = r.ParseForm() b = []byte(r.PostForm.Encode()) } - a.render(w, r, templateCaptcha, &renderData{ + a.renderWithStatusCode(w, r, http.StatusUnauthorized, templateCaptcha, &renderData{ Data: map[string]string{ "captchamethod": r.Method, "captchaheaders": base64.StdEncoding.EncodeToString(h), diff --git a/captcha_test.go b/captcha_test.go new file mode 100644 index 0000000..abc3372 --- /dev/null +++ b/captcha_test.go @@ -0,0 +1,79 @@ +package main + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "git.jlel.se/jlelse/GoBlog/pkgs/contenttype" + "github.com/stretchr/testify/assert" +) + +func Test_captchaMiddleware(t *testing.T) { + app := &goBlog{ + cfg: &config{ + Server: &configServer{ + PublicAddress: "https://example.com", + }, + Blogs: map[string]*configBlog{ + "en": { + Lang: "en", + }, + }, + DefaultBlog: "en", + User: &configUser{}, + }, + } + + app.setInMemoryDatabase() + app.initSessions() + _ = app.initTemplateStrings() + _ = app.initRendering() + + h := app.captchaMiddleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.Write([]byte("ABC Test")) + })) + + t.Run("Default", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/abc", nil) + + rec := httptest.NewRecorder() + + h.ServeHTTP(rec, req) + + res := rec.Result() + resBody, _ := io.ReadAll(res.Body) + _ = res.Body.Close() + resString := string(resBody) + + assert.Equal(t, http.StatusUnauthorized, res.StatusCode) + assert.Contains(t, res.Header.Get("Content-Type"), contenttype.HTML) + assert.Contains(t, resString, "name=captchamethod value=POST") + }) + + t.Run("Captcha session", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/abc", nil) + rec1 := httptest.NewRecorder() + + session, _ := app.captchaSessions.Get(req, "c") + session.Values["captcha"] = true + session.Save(req, rec1) + + for _, cookie := range rec1.Result().Cookies() { + req.AddCookie(cookie) + } + + rec2 := httptest.NewRecorder() + + h.ServeHTTP(rec2, req) + + res := rec2.Result() + resBody, _ := io.ReadAll(res.Body) + _ = res.Body.Close() + resString := string(resBody) + + assert.Equal(t, http.StatusOK, res.StatusCode) + assert.Contains(t, resString, "ABC Test") + }) +} diff --git a/errors.go b/errors.go index 8c93f2a..c8b5f65 100644 --- a/errors.go +++ b/errors.go @@ -21,24 +21,24 @@ func (a *goBlog) serveNotAllowed(w http.ResponseWriter, r *http.Request) { a.serveError(w, r, "", http.StatusMethodNotAllowed) } -var errorCheckMediaTypes = []ct.MediaType{ - ct.NewMediaType(contenttype.HTML), -} - func (a *goBlog) serveError(w http.ResponseWriter, r *http.Request, message string, status int) { - if mt, _, err := ct.GetAcceptableMediaType(r, errorCheckMediaTypes); err != nil || mt.String() != errorCheckMediaTypes[0].String() { + // Init the first time + if len(a.errorCheckMediaTypes) == 0 { + a.errorCheckMediaTypes = append(a.errorCheckMediaTypes, ct.NewMediaType(contenttype.HTML)) + } + // Check message + if message == "" { + message = http.StatusText(status) + } + // Check if request accepts HTML + if mt, _, err := ct.GetAcceptableMediaType(r, a.errorCheckMediaTypes); err != nil || mt.String() != a.errorCheckMediaTypes[0].String() { // Request doesn't accept HTML http.Error(w, message, status) return } - title := fmt.Sprintf("%d %s", status, http.StatusText(status)) - if message == "" { - message = http.StatusText(status) - } - w.WriteHeader(status) - a.render(w, r, templateError, &renderData{ + a.renderWithStatusCode(w, r, status, templateError, &renderData{ Data: &errorData{ - Title: title, + Title: fmt.Sprintf("%d %s", status, http.StatusText(status)), Message: message, }, }) diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 0000000..33ffabe --- /dev/null +++ b/errors_test.go @@ -0,0 +1,112 @@ +package main + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "git.jlel.se/jlelse/GoBlog/pkgs/contenttype" + "github.com/stretchr/testify/assert" +) + +func Test_errors(t *testing.T) { + app := &goBlog{ + cfg: &config{ + Server: &configServer{ + PublicAddress: "https://example.com", + }, + Blogs: map[string]*configBlog{ + "en": { + Lang: "en", + }, + }, + DefaultBlog: "en", + User: &configUser{}, + }, + } + + app.initMarkdown() + _ = app.initTemplateStrings() + _ = app.initRendering() + + t.Run("Test 404, no HTML", func(t *testing.T) { + h := http.HandlerFunc(app.serve404) + + req := httptest.NewRequest(http.MethodGet, "/abc", nil) + req.Header.Set("Accept", contenttype.JSON) + + rec := httptest.NewRecorder() + + h(rec, req) + + res := rec.Result() + resBody, _ := io.ReadAll(res.Body) + _ = res.Body.Close() + resString := string(resBody) + + assert.Equal(t, http.StatusNotFound, res.StatusCode) + assert.Contains(t, resString, "not found") + assert.Contains(t, res.Header.Get("Content-Type"), "text/plain") + }) + + t.Run("Test 404, HTML", func(t *testing.T) { + h := http.HandlerFunc(app.serve404) + + req := httptest.NewRequest(http.MethodGet, "/abc", nil) + req.Header.Set("Accept", contenttype.HTML) + + rec := httptest.NewRecorder() + + h(rec, req) + + res := rec.Result() + resBody, _ := io.ReadAll(res.Body) + _ = res.Body.Close() + resString := string(resBody) + + assert.Equal(t, http.StatusNotFound, res.StatusCode) + assert.Contains(t, resString, "not found") + assert.Contains(t, res.Header.Get("Content-Type"), contenttype.HTML) + }) + + t.Run("Test Method Not Allowed, no HTML", func(t *testing.T) { + h := http.HandlerFunc(app.serveNotAllowed) + + req := httptest.NewRequest(http.MethodGet, "/abc", nil) + req.Header.Set("Accept", contenttype.JSON) + + rec := httptest.NewRecorder() + + h(rec, req) + + res := rec.Result() + resBody, _ := io.ReadAll(res.Body) + _ = res.Body.Close() + resString := string(resBody) + + assert.Equal(t, http.StatusMethodNotAllowed, res.StatusCode) + assert.Contains(t, resString, "Method Not Allowed") + assert.Contains(t, res.Header.Get("Content-Type"), "text/plain") + }) + + t.Run("Test Method Not Allowed", func(t *testing.T) { + h := http.HandlerFunc(app.serveNotAllowed) + + req := httptest.NewRequest(http.MethodGet, "/abc", nil) + req.Header.Set("Accept", contenttype.HTML) + + rec := httptest.NewRecorder() + + h(rec, req) + + res := rec.Result() + resBody, _ := io.ReadAll(res.Body) + _ = res.Body.Close() + resString := string(resBody) + + assert.Equal(t, http.StatusMethodNotAllowed, res.StatusCode) + assert.Contains(t, resString, "Method Not Allowed") + assert.Contains(t, res.Header.Get("Content-Type"), contenttype.HTML) + }) +} diff --git a/geo.go b/geo.go index 66405ea..9664975 100644 --- a/geo.go +++ b/geo.go @@ -12,11 +12,11 @@ import ( "github.com/thoas/go-funk" ) -func (db *database) geoTitle(g *gogeouri.Geo, lang string) string { +func (a *goBlog) geoTitle(g *gogeouri.Geo, lang string) string { if name, ok := g.Parameters["name"]; ok && len(name) > 0 && name[0] != "" { return name[0] } - ba, err := db.photonReverse(g.Latitude, g.Longitude, lang) + ba, err := a.photonReverse(g.Latitude, g.Longitude, lang) if err != nil { return "" } @@ -32,9 +32,9 @@ func (db *database) geoTitle(g *gogeouri.Geo, lang string) string { return strings.Join(funk.FilterString([]string{name, city, state, country}, func(s string) bool { return s != "" }), ", ") } -func (db *database) photonReverse(lat, lon float64, lang string) ([]byte, error) { +func (a *goBlog) photonReverse(lat, lon float64, lang string) ([]byte, error) { cacheKey := fmt.Sprintf("photon-%v-%v-%v", lat, lon, lang) - cache, _ := db.retrievePersistentCache(cacheKey) + cache, _ := a.db.retrievePersistentCache(cacheKey) if cache != nil { return cache, nil } @@ -51,7 +51,7 @@ func (db *database) photonReverse(lat, lon float64, lang string) ([]byte, error) return nil, err } req.Header.Set(userAgent, appUserAgent) - resp, err := appHttpClient.Do(req) + resp, err := a.httpClient.Do(req) if err != nil { return nil, err } @@ -66,7 +66,7 @@ func (db *database) photonReverse(lat, lon float64, lang string) ([]byte, error) if err != nil { return nil, err } - _ = db.cachePersistently(cacheKey, ba) + _ = a.db.cachePersistently(cacheKey, ba) return ba, nil } diff --git a/healthcheck.go b/healthcheck.go index 16f562c..00e56b3 100644 --- a/healthcheck.go +++ b/healthcheck.go @@ -12,7 +12,7 @@ func (a *goBlog) healthcheck() bool { fmt.Println(err.Error()) return false } - resp, err := appHttpClient.Do(req) + resp, err := a.httpClient.Do(req) if err != nil { fmt.Println(err.Error()) return false diff --git a/httpClient.go b/httpClient.go index d674dc5..b7b76d5 100644 --- a/httpClient.go +++ b/httpClient.go @@ -9,9 +9,15 @@ type httpClient interface { Do(req *http.Request) (*http.Response, error) } -var appHttpClient httpClient = &http.Client{ - Timeout: 5 * time.Minute, - Transport: &http.Transport{ - DisableKeepAlives: true, - }, +func getHTTPClient() httpClient { + return &http.Client{ + Timeout: 5 * time.Minute, + Transport: &http.Transport{ + DisableKeepAlives: true, + }, + } +} + +func (a *goBlog) initHTTPClient() { + a.httpClient = getHTTPClient() } diff --git a/httpClient_test.go b/httpClient_test.go index be1e497..fa0e7fb 100644 --- a/httpClient_test.go +++ b/httpClient_test.go @@ -4,32 +4,15 @@ import ( "io" "net/http" "strings" - "sync" ) type fakeHttpClient struct { - req *http.Request - res *http.Response - err error - enabled bool - // internal - alt httpClient - mx sync.Mutex -} - -var fakeAppHttpClient *fakeHttpClient - -func init() { - fakeAppHttpClient = &fakeHttpClient{ - alt: appHttpClient, - } - appHttpClient = fakeAppHttpClient + req *http.Request + res *http.Response + err error } func (c *fakeHttpClient) Do(req *http.Request) (*http.Response, error) { - if !c.enabled { - return c.alt.Do(req) - } c.req = req return c.res, c.err } @@ -49,14 +32,6 @@ func (c *fakeHttpClient) setFakeResponse(statusCode int, body string, err error) } } -func (c *fakeHttpClient) lock(enabled bool) { - c.mx.Lock() - c.clean() - c.enabled = enabled -} - -func (c *fakeHttpClient) unlock() { - c.enabled = false - c.clean() - c.mx.Unlock() +func getFakeHTTPClient() *fakeHttpClient { + return &fakeHttpClient{} } diff --git a/main.go b/main.go index 5e123ab..1e4e312 100644 --- a/main.go +++ b/main.go @@ -47,6 +47,7 @@ func main() { } app := &goBlog{} + app.initHTTPClient() // Initialize config if err = app.initConfig(); err != nil { diff --git a/mediaCompression.go b/mediaCompression.go index 8529d9d..81c2370 100644 --- a/mediaCompression.go +++ b/mediaCompression.go @@ -37,7 +37,7 @@ func (a *goBlog) tinify(url string, config *configMicropubMedia) (location strin } req.SetBasicAuth("api", config.TinifyKey) req.Header.Set(contentType, contenttype.JSON) - resp, err := appHttpClient.Do(req) + resp, err := a.httpClient.Do(req) if err != nil { return "", err } @@ -64,7 +64,7 @@ func (a *goBlog) tinify(url string, config *configMicropubMedia) (location strin } downloadReq.SetBasicAuth("api", config.TinifyKey) downloadReq.Header.Set(contentType, contenttype.JSON) - downloadResp, err := appHttpClient.Do(downloadReq) + downloadResp, err := a.httpClient.Do(downloadReq) if err != nil { return "", err } @@ -121,7 +121,7 @@ func (a *goBlog) shortPixel(url string, config *configMicropubMedia) (location s if err != nil { return "", err } - resp, err := appHttpClient.Do(req) + resp, err := a.httpClient.Do(req) if err != nil { return "", err } @@ -165,7 +165,7 @@ func (a *goBlog) cloudflare(url string) (location string, err error) { if err != nil { return "", err } - resp, err := appHttpClient.Do(req) + resp, err := a.httpClient.Do(req) if err != nil { return "", err } diff --git a/micropubMedia.go b/micropubMedia.go index a056a75..089d349 100644 --- a/micropubMedia.go +++ b/micropubMedia.go @@ -106,7 +106,7 @@ func (a *goBlog) serveMicropubMedia(w http.ResponseWriter, r *http.Request) { func (a *goBlog) uploadFile(filename string, f io.Reader) (string, error) { ms := a.cfg.Micropub.MediaStorage if ms != nil && ms.BunnyStorageKey != "" && ms.BunnyStorageName != "" { - return ms.uploadToBunny(filename, f) + return a.uploadToBunny(filename, f) } loc, err := saveMediaFile(filename, f) if err != nil { @@ -118,13 +118,14 @@ func (a *goBlog) uploadFile(filename string, f io.Reader) (string, error) { return a.getFullAddress(loc), nil } -func (config *configMicropubMedia) uploadToBunny(filename string, f io.Reader) (location string, err error) { +func (a *goBlog) uploadToBunny(filename string, f io.Reader) (location string, err error) { + config := a.cfg.Micropub.MediaStorage if config == nil || config.BunnyStorageName == "" || config.BunnyStorageKey == "" || config.MediaURL == "" { return "", errors.New("Bunny storage not completely configured") } req, _ := http.NewRequest(http.MethodPut, fmt.Sprintf("https://storage.bunnycdn.com/%s/%s", url.PathEscape(config.BunnyStorageName), url.PathEscape(filename)), f) req.Header.Add("AccessKey", config.BunnyStorageKey) - resp, err := appHttpClient.Do(req) + resp, err := a.httpClient.Do(req) if err != nil { return "", err } diff --git a/notifications.go b/notifications.go index 9cd8fe6..49aa393 100644 --- a/notifications.go +++ b/notifications.go @@ -30,7 +30,7 @@ func (a *goBlog) sendNotification(text string) { log.Println("Failed to save notification:", err.Error()) } if an := a.cfg.Notifications; an != nil { - err := an.Telegram.send(n.Text, "") + err := a.send(an.Telegram, n.Text, "") if err != nil { log.Println("Failed to send Telegram notification:", err.Error()) } diff --git a/render.go b/render.go index cd687a8..41f093a 100644 --- a/render.go +++ b/render.go @@ -65,7 +65,7 @@ func (a *goBlog) initRendering() error { "sort": sortedStrings, "absolute": a.getFullAddress, "mentions": a.db.getWebmentionsByAddress, - "geotitle": a.db.geoTitle, + "geotitle": a.geoTitle, "geolink": geoOSMLink, "opensearch": openSearchUrl, } @@ -106,6 +106,10 @@ type renderData struct { } func (a *goBlog) render(w http.ResponseWriter, r *http.Request, template string, data *renderData) { + a.renderWithStatusCode(w, r, http.StatusOK, template, data) +} + +func (a *goBlog) renderWithStatusCode(w http.ResponseWriter, r *http.Request, statusCode int, template string, data *renderData) { // Server timing t := servertiming.FromContext(r.Context()).NewMetric("r").Start() // Check render data @@ -153,6 +157,7 @@ func (a *goBlog) render(w http.ResponseWriter, r *http.Request, template string, http.Error(w, err.Error(), http.StatusInternalServerError) return } + w.WriteHeader(statusCode) _, err = a.min.Write(w, contenttype.HTML, tw.Bytes()) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) diff --git a/telegram.go b/telegram.go index 3d667db..53171d9 100644 --- a/telegram.go +++ b/telegram.go @@ -17,7 +17,7 @@ func (a *goBlog) initTelegram() { a.pPostHooks = append(a.pPostHooks, func(p *post) { if tg := a.cfg.Blogs[p.Blog].Telegram; tg.enabled() && p.isPublishedSectionPost() { if html := tg.generateHTML(p.Title(), a.fullPostURL(p), a.shortPostURL(p)); html != "" { - if err := tg.send(html, "HTML"); err != nil { + if err := a.send(tg, html, "HTML"); err != nil { log.Printf("Failed to send post to Telegram: %v", err) } } @@ -54,7 +54,7 @@ func (tg *configTelegram) generateHTML(title, fullURL, shortURL string) string { return message.String() } -func (tg *configTelegram) send(message, mode string) error { +func (a *goBlog) send(tg *configTelegram, message, mode string) error { if !tg.enabled() { return nil } @@ -70,7 +70,7 @@ func (tg *configTelegram) send(message, mode string) error { } tgURL.RawQuery = params.Encode() req, _ := http.NewRequest(http.MethodPost, tgURL.String(), nil) - resp, err := appHttpClient.Do(req) + resp, err := a.httpClient.Do(req) if err != nil { return err } diff --git a/telegram_test.go b/telegram_test.go index 8ef0750..d1b7dbc 100644 --- a/telegram_test.go +++ b/telegram_test.go @@ -4,6 +4,8 @@ import ( "net/http" "testing" "time" + + "github.com/stretchr/testify/assert" ) func Test_configTelegram_enabled(t *testing.T) { @@ -69,8 +71,7 @@ func Test_configTelegram_generateHTML(t *testing.T) { } func Test_configTelegram_send(t *testing.T) { - fakeAppHttpClient.lock(true) - defer fakeAppHttpClient.unlock() + fakeClient := getFakeHTTPClient() tg := &configTelegram{ Enabled: true, @@ -78,25 +79,19 @@ func Test_configTelegram_send(t *testing.T) { BotToken: "bottoken", } - fakeAppHttpClient.setFakeResponse(200, "", nil) - - err := tg.send("Message", "HTML") - if err != nil { - t.Fatalf("Error: %v", err) + app := &goBlog{ + httpClient: fakeClient, } - if fakeAppHttpClient.req == nil { - t.Error("Empty request") - } - if fakeAppHttpClient.err != nil { - t.Error("Error in request") - } - if fakeAppHttpClient.req.Method != http.MethodPost { - t.Error("Wrong method") - } - if u := fakeAppHttpClient.req.URL.String(); u != "https://api.telegram.org/botbottoken/sendMessage?chat_id=chatid&parse_mode=HTML&text=Message" { - t.Errorf("Wrong request URL, got: %v", u) - } + fakeClient.setFakeResponse(200, "", nil) + + err := app.send(tg, "Message", "HTML") + assert.Nil(t, err) + + assert.NotNil(t, fakeClient.req) + assert.Nil(t, fakeClient.err) + assert.Equal(t, http.MethodPost, fakeClient.req.Method) + assert.Equal(t, "https://api.telegram.org/botbottoken/sendMessage?chat_id=chatid&parse_mode=HTML&text=Message", fakeClient.req.URL.String()) } func Test_goBlog_initTelegram(t *testing.T) { @@ -113,10 +108,9 @@ func Test_goBlog_initTelegram(t *testing.T) { func Test_telegram(t *testing.T) { t.Run("Send post to Telegram", func(t *testing.T) { - fakeAppHttpClient.lock(true) - defer fakeAppHttpClient.unlock() + fakeClient := getFakeHTTPClient() - fakeAppHttpClient.setFakeResponse(200, "", nil) + fakeClient.setFakeResponse(200, "", nil) app := &goBlog{ pPostHooks: []postHookFunc{}, @@ -134,6 +128,7 @@ func Test_telegram(t *testing.T) { }, }, }, + httpClient: fakeClient, } app.setInMemoryDatabase() @@ -152,16 +147,17 @@ func Test_telegram(t *testing.T) { app.pPostHooks[0](p) - if u := fakeAppHttpClient.req.URL.String(); u != "https://api.telegram.org/botbottoken/sendMessage?chat_id=chatid&parse_mode=HTML&text=Title%0A%0A%3Ca+href%3D%22https%3A%2F%2Fexample.com%2Fs%2F1%22%3Ehttps%3A%2F%2Fexample.com%2Fs%2F1%3C%2Fa%3E" { - t.Errorf("Wrong request URL, got: %v", u) - } + assert.Equal( + t, + "https://api.telegram.org/botbottoken/sendMessage?chat_id=chatid&parse_mode=HTML&text=Title%0A%0A%3Ca+href%3D%22https%3A%2F%2Fexample.com%2Fs%2F1%22%3Ehttps%3A%2F%2Fexample.com%2Fs%2F1%3C%2Fa%3E", + fakeClient.req.URL.String(), + ) }) t.Run("Telegram disabled", func(t *testing.T) { - fakeAppHttpClient.lock(true) - defer fakeAppHttpClient.unlock() + fakeClient := getFakeHTTPClient() - fakeAppHttpClient.setFakeResponse(200, "", nil) + fakeClient.setFakeResponse(200, "", nil) app := &goBlog{ pPostHooks: []postHookFunc{}, @@ -173,6 +169,7 @@ func Test_telegram(t *testing.T) { "en": {}, }, }, + httpClient: fakeClient, } app.setInMemoryDatabase() @@ -191,8 +188,6 @@ func Test_telegram(t *testing.T) { app.pPostHooks[0](p) - if fakeAppHttpClient.req != nil { - t.Error("There should be no request") - } + assert.Nil(t, fakeClient.req) }) } diff --git a/webmentionSending.go b/webmentionSending.go index 9f978d4..92468e2 100644 --- a/webmentionSending.go +++ b/webmentionSending.go @@ -47,11 +47,11 @@ func (a *goBlog) sendWebmentions(p *post) error { // Just ignore the mention continue } - endpoint := discoverEndpoint(link) + endpoint := a.discoverEndpoint(link) if endpoint == "" { continue } - if err = sendWebmention(endpoint, a.fullPostURL(p), link); err != nil { + if err = a.sendWebmention(endpoint, a.fullPostURL(p), link); err != nil { log.Println("Sending webmention to " + link + " failed") continue } @@ -60,7 +60,7 @@ func (a *goBlog) sendWebmentions(p *post) error { return nil } -func sendWebmention(endpoint, source, target string) error { +func (a *goBlog) sendWebmention(endpoint, source, target string) error { req, err := http.NewRequest(http.MethodPost, endpoint, strings.NewReader(url.Values{ "source": []string{source}, "target": []string{target}, @@ -70,7 +70,7 @@ func sendWebmention(endpoint, source, target string) error { } req.Header.Set(contentType, contenttype.WWWForm) req.Header.Set(userAgent, appUserAgent) - res, err := appHttpClient.Do(req) + res, err := a.httpClient.Do(req) if err != nil { return err } @@ -82,14 +82,14 @@ func sendWebmention(endpoint, source, target string) error { return nil } -func discoverEndpoint(urlStr string) string { +func (a *goBlog) discoverEndpoint(urlStr string) string { doRequest := func(method, urlStr string) string { req, err := http.NewRequest(method, urlStr, nil) if err != nil { return "" } req.Header.Set(userAgent, appUserAgent) - resp, err := appHttpClient.Do(req) + resp, err := a.httpClient.Do(req) if err != nil { return "" } diff --git a/webmentionVerification.go b/webmentionVerification.go index 669e5be..415538d 100644 --- a/webmentionVerification.go +++ b/webmentionVerification.go @@ -78,7 +78,7 @@ func (a *goBlog) verifyMention(m *mention) error { resp = rec.Result() } else { req.Header.Set(userAgent, appUserAgent) - resp, err = appHttpClient.Do(req) + resp, err = a.httpClient.Do(req) if err != nil { return err }