diff --git a/app.go b/app.go index 9ce8d9f..0cf5897 100644 --- a/app.go +++ b/app.go @@ -48,7 +48,7 @@ type goBlog struct { pDeleteHooks []postHookFunc hourlyHooks []hourlyHookFunc // HTTP Client - httpClient httpClient + httpClient *http.Client // HTTP Routers d http.Handler // IndieAuth diff --git a/blogroll_test.go b/blogroll_test.go index aa2172a..354fc58 100644 --- a/blogroll_test.go +++ b/blogroll_test.go @@ -12,10 +12,10 @@ import ( func Test_blogroll(t *testing.T) { - fc := &fakeHttpClient{} + fc := newFakeHttpClient() app := &goBlog{ - httpClient: fc, + httpClient: fc.Client, cfg: &config{ Db: &configDb{ File: filepath.Join(t.TempDir(), "test.db"), diff --git a/geoTiles_test.go b/geoTiles_test.go index 4885ee0..99cae16 100644 --- a/geoTiles_test.go +++ b/geoTiles_test.go @@ -14,12 +14,11 @@ func Test_proxyTiles(t *testing.T) { cfg: &config{}, } - hc := &fakeHttpClient{ - handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, _ = w.Write([]byte("Hello, World!")) - }), - } - app.httpClient = hc + hc := newFakeHttpClient() + hc.setHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("Hello, World!")) + })) + app.httpClient = hc.Client // Default tile source diff --git a/go.mod b/go.mod index f879956..630939b 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( github.com/emersion/go-smtp v0.15.0 github.com/go-chi/chi/v5 v5.0.7 github.com/go-fed/httpsig v1.1.0 + github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.4.0 github.com/google/uuid v1.3.0 github.com/gorilla/handlers v1.5.1 github.com/gorilla/securecookie v1.1.1 diff --git a/go.sum b/go.sum index 4136fb8..ce50958 100644 --- a/go.sum +++ b/go.sum @@ -169,6 +169,8 @@ github.com/go-playground/validator/v10 v10.2.0 h1:KgJ0snyC2R9VXYN2rneOtQcw5aHQB1 github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI= github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= +github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.4.0 h1:Mr3JcvBjQEhCN9wld6OHKHuHxWaoXTaQfYKmj7QwP18= +github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.4.0/go.mod h1:A2S0CWkNylc2phvKXWBBdD3K0iGnDBGbzRpISP2zBl8= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8= diff --git a/httpClient.go b/httpClient.go index e183855..1dea369 100644 --- a/httpClient.go +++ b/httpClient.go @@ -5,24 +5,11 @@ import ( "time" ) -type httpClient interface { - Do(req *http.Request) (*http.Response, error) -} - -type appHttpClient struct { - hc *http.Client -} - -var _ httpClient = &appHttpClient{} - -func (c *appHttpClient) Do(req *http.Request) (*http.Response, error) { - if c.hc == nil { - c.hc = &http.Client{ - Timeout: 5 * time.Minute, - Transport: &http.Transport{ - DisableKeepAlives: true, - }, - } +func newHttpClient() *http.Client { + return &http.Client{ + Timeout: 5 * time.Minute, + Transport: &http.Transport{ + DisableKeepAlives: true, + }, } - return c.hc.Do(req) } diff --git a/httpClient_test.go b/httpClient_test.go index b236f91..9373363 100644 --- a/httpClient_test.go +++ b/httpClient_test.go @@ -1,28 +1,40 @@ package main import ( + "io" "net/http" "net/http/httptest" ) type fakeHttpClient struct { - httpClient + *http.Client req *http.Request res *http.Response handler http.Handler } -var _ httpClient = &fakeHttpClient{} - -func (c *fakeHttpClient) Do(req *http.Request) (*http.Response, error) { - if c.handler == nil { - return nil, nil +func newFakeHttpClient() *fakeHttpClient { + fc := &fakeHttpClient{} + fc.Client = &http.Client{ + Transport: &handlerRoundTripper{ + handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + fc.req = r + if fc.handler != nil { + rec := httptest.NewRecorder() + fc.handler.ServeHTTP(rec, r) + fc.res = rec.Result() + // Copy the headers from the response recorder + for k, v := range rec.Header() { + rw.Header()[k] = v + } + // Copy result status code and body + rw.WriteHeader(fc.res.StatusCode) + io.Copy(rw, rec.Body) + } + }), + }, } - rec := httptest.NewRecorder() - c.handler.ServeHTTP(rec, req) - c.req = req - c.res = rec.Result() - return c.res, nil + return fc } func (c *fakeHttpClient) clean() { diff --git a/indieAuthServer_test.go b/indieAuthServer_test.go index 423b5fc..187f8b7 100644 --- a/indieAuthServer_test.go +++ b/indieAuthServer_test.go @@ -21,7 +21,7 @@ func Test_indieAuthServer(t *testing.T) { var err error app := &goBlog{ - httpClient: &fakeHttpClient{}, + httpClient: newFakeHttpClient().Client, cfg: &config{ Db: &configDb{ File: filepath.Join(t.TempDir(), "test.db"), diff --git a/indieAuth_test.go b/indieAuth_test.go index 4ac89d9..68c7879 100644 --- a/indieAuth_test.go +++ b/indieAuth_test.go @@ -14,7 +14,7 @@ import ( func Test_checkIndieAuth(t *testing.T) { app := &goBlog{ - httpClient: &fakeHttpClient{}, + httpClient: newFakeHttpClient().Client, cfg: &config{ Db: &configDb{ File: filepath.Join(t.TempDir(), "test.db"), diff --git a/main.go b/main.go index 14b6abe..e22308f 100644 --- a/main.go +++ b/main.go @@ -53,7 +53,7 @@ func main() { initGC() app := &goBlog{ - httpClient: &appHttpClient{}, + httpClient: newHttpClient(), } // Initialize config diff --git a/mediaCompression.go b/mediaCompression.go index 742ea44..ddba66f 100644 --- a/mediaCompression.go +++ b/mediaCompression.go @@ -16,7 +16,7 @@ const defaultCompressionWidth = 2000 const defaultCompressionHeight = 3000 type mediaCompression interface { - compress(url string, save mediaStorageSaveFunc, hc httpClient) (location string, err error) + compress(url string, save mediaStorageSaveFunc, hc *http.Client) (location string, err error) } func (a *goBlog) compressMediaFile(url string) (location string, err error) { @@ -55,7 +55,7 @@ type shortpixel struct { var _ mediaCompression = &shortpixel{} -func (sp *shortpixel) compress(url string, upload mediaStorageSaveFunc, hc httpClient) (location string, err error) { +func (sp *shortpixel) compress(url string, upload mediaStorageSaveFunc, hc *http.Client) (location string, err error) { // Check url fileExtension, allowed := urlHasExt(url, "jpg", "jpeg", "png") if !allowed { @@ -111,7 +111,7 @@ type tinify struct { var _ mediaCompression = &tinify{} -func (tf *tinify) compress(url string, upload mediaStorageSaveFunc, hc httpClient) (location string, err error) { +func (tf *tinify) compress(url string, upload mediaStorageSaveFunc, hc *http.Client) (location string, err error) { // Check url fileExtension, allowed := urlHasExt(url, "jpg", "jpeg", "png") if !allowed { @@ -188,7 +188,7 @@ type cloudflare struct { var _ mediaCompression = &cloudflare{} -func (cf *cloudflare) compress(url string, upload mediaStorageSaveFunc, hc httpClient) (location string, err error) { +func (cf *cloudflare) compress(url string, upload mediaStorageSaveFunc, hc *http.Client) (location string, err error) { // Check url _, allowed := urlHasExt(url, "jpg", "jpeg", "png") if !allowed { diff --git a/mediaCompression_test.go b/mediaCompression_test.go index 0182cdc..df7112f 100644 --- a/mediaCompression_test.go +++ b/mediaCompression_test.go @@ -27,7 +27,7 @@ func Test_compress(t *testing.T) { } t.Run("Cloudflare", func(t *testing.T) { - fakeClient := &fakeHttpClient{} + fakeClient := newFakeHttpClient() fakeClient.setHandler(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { assert.Equal(t, "https://www.cloudflare.com/cdn-cgi/image/f=jpeg,q=75,metadata=none,fit=scale-down,w=2000,h=3000/https://example.com/original.jpg", r.URL.String()) @@ -36,14 +36,14 @@ func Test_compress(t *testing.T) { })) cf := &cloudflare{} - res, err := cf.compress("https://example.com/original.jpg", uf, fakeClient) + res, err := cf.compress("https://example.com/original.jpg", uf, fakeClient.Client) assert.Nil(t, err) assert.Equal(t, "https://example.com/"+fakeSha256+".jpeg", res) }) t.Run("Shortpixel", func(t *testing.T) { - fakeClient := &fakeHttpClient{} + fakeClient := newFakeHttpClient() fakeClient.setHandler(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { assert.Equal(t, "https://api.shortpixel.com/v2/reducer-sync.php", r.URL.String()) @@ -63,7 +63,7 @@ func Test_compress(t *testing.T) { })) cf := &shortpixel{"testkey"} - res, err := cf.compress("https://example.com/original.jpg", uf, fakeClient) + res, err := cf.compress("https://example.com/original.jpg", uf, fakeClient.Client) assert.Nil(t, err) assert.Equal(t, "https://example.com/"+fakeSha256+".jpg", res) diff --git a/notifications.go b/notifications.go index 2a1ff3c..4f95afe 100644 --- a/notifications.go +++ b/notifications.go @@ -31,7 +31,7 @@ func (a *goBlog) sendNotification(text string) { log.Println("Failed to save notification:", err.Error()) } if an := a.cfg.Notifications; an != nil { - err := a.send(an.Telegram, n.Text, "") + _, err := a.send(an.Telegram, n.Text, "") if err != nil { log.Println("Failed to send Telegram notification:", err.Error()) } diff --git a/postsDb.go b/postsDb.go index 0dad166..1d768f3 100644 --- a/postsDb.go +++ b/postsDb.go @@ -503,22 +503,6 @@ func (d *database) countPosts(config *postsRequestConfig) (count int, err error) return } -func (d *database) getPostPaths(status postStatus) ([]string, error) { - var postPaths []string - rows, err := d.query("select path from posts where status = @status", sql.Named("status", status)) - if err != nil { - return nil, err - } - var path string - for rows.Next() { - _ = rows.Scan(&path) - if path != "" { - postPaths = append(postPaths, path) - } - } - return postPaths, nil -} - func (a *goBlog) getRandomPostPath(blog string) (path string, err error) { sections, ok := funk.Keys(a.cfg.Blogs[blog].Sections).([]string) if !ok { diff --git a/postsDb_test.go b/postsDb_test.go index 3ead607..d1df932 100644 --- a/postsDb_test.go +++ b/postsDb_test.go @@ -64,17 +64,6 @@ func Test_postsDb(t *testing.T) { is.Equal("Title", p.Title()) is.Equal([]string{"C", "A", "B"}, p.Parameters["tags"]) - // Check number of post paths - pp, err := app.db.getPostPaths(statusDraft) - must.NoError(err) - if is.Len(pp, 1) { - is.Equal("/test/abc", pp[0]) - } - - pp, err = app.db.getPostPaths(statusPublished) - must.NoError(err) - is.Len(pp, 0) - // Check drafts drafts, _ := app.getPosts(&postsRequestConfig{ blog: "en", diff --git a/telegram.go b/telegram.go index 87e07f0..e08af3c 100644 --- a/telegram.go +++ b/telegram.go @@ -2,21 +2,17 @@ package main import ( "bytes" - "errors" - "fmt" "log" - "net/http" "net/url" - "strings" -) -const telegramBaseURL = "https://api.telegram.org/bot" + tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5" +) func (a *goBlog) initTelegram() { a.pPostHooks = append(a.pPostHooks, func(p *post) { if tg := a.cfg.Blogs[p.Blog].Telegram; tg.enabled() && p.isPublishedSectionPost() { if html := tg.generateHTML(p.RenderedTitle, a.fullPostURL(p), a.shortPostURL(p)); html != "" { - if err := a.send(tg, html, "HTML"); err != nil { + if _, err := a.send(tg, html, "HTML"); err != nil { log.Printf("Failed to send post to Telegram: %v", err) } } @@ -35,47 +31,41 @@ func (tg *configTelegram) generateHTML(title, fullURL, shortURL string) string { if !tg.enabled() { return "" } - replacer := strings.NewReplacer("<", "<", ">", ">", "&", "&") var message bytes.Buffer if title != "" { - message.WriteString(replacer.Replace(title)) + message.WriteString(tgbotapi.EscapeText(tgbotapi.ModeHTML, title)) message.WriteString("\n\n") } if tg.InstantViewHash != "" { message.WriteString("") - message.WriteString(replacer.Replace(shortURL)) + message.WriteString(tgbotapi.EscapeText(tgbotapi.ModeHTML, shortURL)) message.WriteString("") } else { message.WriteString("") - message.WriteString(replacer.Replace(shortURL)) + message.WriteString(tgbotapi.EscapeText(tgbotapi.ModeHTML, shortURL)) message.WriteString("") } return message.String() } -func (a *goBlog) send(tg *configTelegram, message, mode string) error { +func (a *goBlog) send(tg *configTelegram, message, mode string) (int, error) { if !tg.enabled() { - return nil + return 0, nil } - params := url.Values{} - params.Add("chat_id", tg.ChatID) - params.Add("text", message) - if mode != "" { - params.Add("parse_mode", mode) - } - tgURL, err := url.Parse(telegramBaseURL + tg.BotToken + "/sendMessage") + bot, err := tgbotapi.NewBotAPIWithClient(tg.BotToken, tgbotapi.APIEndpoint, a.httpClient) if err != nil { - return errors.New("failed to create Telegram request") + return 0, err } - tgURL.RawQuery = params.Encode() - req, _ := http.NewRequest(http.MethodPost, tgURL.String(), nil) - resp, err := a.httpClient.Do(req) + msg := tgbotapi.MessageConfig{ + BaseChat: tgbotapi.BaseChat{ + ChannelUsername: tg.ChatID, + }, + Text: message, + ParseMode: mode, + } + res, err := bot.Send(msg) if err != nil { - return err + return 0, err } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("failed to send Telegram message, status code %d", resp.StatusCode) - } - return nil + return res.MessageID, nil } diff --git a/telegram_test.go b/telegram_test.go index 7797357..a723a19 100644 --- a/telegram_test.go +++ b/telegram_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_configTelegram_enabled(t *testing.T) { @@ -72,7 +73,17 @@ func Test_configTelegram_generateHTML(t *testing.T) { } func Test_configTelegram_send(t *testing.T) { - fakeClient := &fakeHttpClient{} + fakeClient := newFakeHttpClient() + + fakeClient.setHandler(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + if r.URL.String() == "https://api.telegram.org/botbottoken/getMe" { + rw.WriteHeader(http.StatusOK) + rw.Write([]byte(`{"ok":true,"result":{"id":123456789,"is_bot":true,"first_name":"Test","username":"testbot"}}`)) + return + } + rw.WriteHeader(http.StatusOK) + rw.Write([]byte(`{"ok":true,"result":{"message_id":123,"from":{"id":123456789,"is_bot":true,"first_name":"Test","username":"testbot"},"chat":{"id":123456789,"first_name":"Test","username":"testbot"},"date":1564181818,"text":"Message"}}`)) + })) tg := &configTelegram{ Enabled: true, @@ -81,17 +92,22 @@ func Test_configTelegram_send(t *testing.T) { } app := &goBlog{ - httpClient: fakeClient, + httpClient: fakeClient.Client, } - fakeClient.setFakeResponse(200, "") + msgId, err := app.send(tg, "Message", "HTML") + require.Nil(t, err) - err := app.send(tg, "Message", "HTML") - assert.Nil(t, err) + assert.Equal(t, 123, msgId) assert.NotNil(t, fakeClient.req) assert.Equal(t, http.MethodPost, fakeClient.req.Method) - assert.Equal(t, "https://api.telegram.org/botbottoken/sendMessage?chat_id=chatid&parse_mode=HTML&text=Message", fakeClient.req.URL.String()) + assert.Equal(t, "https://api.telegram.org/botbottoken/sendMessage", fakeClient.req.URL.String()) + + req := fakeClient.req + assert.Equal(t, "chatid", req.FormValue("chat_id")) + assert.Equal(t, "HTML", req.FormValue("parse_mode")) + assert.Equal(t, "Message", req.FormValue("text")) } func Test_goBlog_initTelegram(t *testing.T) { @@ -108,9 +124,17 @@ func Test_goBlog_initTelegram(t *testing.T) { func Test_telegram(t *testing.T) { t.Run("Send post to Telegram", func(t *testing.T) { - fakeClient := &fakeHttpClient{} + fakeClient := newFakeHttpClient() - fakeClient.setFakeResponse(200, "") + fakeClient.setHandler(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + if r.URL.String() == "https://api.telegram.org/botbottoken/getMe" { + rw.WriteHeader(http.StatusOK) + rw.Write([]byte(`{"ok":true,"result":{"id":123456789,"is_bot":true,"first_name":"Test","username":"testbot"}}`)) + return + } + rw.WriteHeader(http.StatusOK) + rw.Write([]byte(`{"ok":true,"result":{"message_id":123,"from":{"id":123456789,"is_bot":true,"first_name":"Test","username":"testbot"},"chat":{"id":123456789,"first_name":"Test","username":"testbot"},"date":1564181818,"text":"Message"}}`)) + })) app := &goBlog{ pPostHooks: []postHookFunc{}, @@ -131,7 +155,7 @@ func Test_telegram(t *testing.T) { }, }, }, - httpClient: fakeClient, + httpClient: fakeClient.Client, } _ = app.initDatabase(false) @@ -149,17 +173,16 @@ func Test_telegram(t *testing.T) { app.pPostHooks[0](p) - assert.Equal( - t, - "https://api.telegram.org/botbottoken/sendMessage?chat_id=chatid&parse_mode=HTML&text=Title%0A%0A%3Ca+href%3D%22https%3A%2F%2Fexample.com%2Fs%2F1%22%3Ehttps%3A%2F%2Fexample.com%2Fs%2F1%3C%2Fa%3E", - fakeClient.req.URL.String(), - ) + assert.Equal(t, "https://api.telegram.org/botbottoken/sendMessage", fakeClient.req.URL.String()) + + req := fakeClient.req + assert.Equal(t, "chatid", req.FormValue("chat_id")) + assert.Equal(t, "HTML", req.FormValue("parse_mode")) + assert.Equal(t, "Title\n\nhttps://example.com/s/1", req.FormValue("text")) }) t.Run("Telegram disabled", func(t *testing.T) { - fakeClient := &fakeHttpClient{} - - fakeClient.setFakeResponse(200, "") + fakeClient := newFakeHttpClient() app := &goBlog{ pPostHooks: []postHookFunc{}, @@ -174,7 +197,7 @@ func Test_telegram(t *testing.T) { "en": {}, }, }, - httpClient: fakeClient, + httpClient: fakeClient.Client, } _ = app.initDatabase(false) diff --git a/webmentionVerification_test.go b/webmentionVerification_test.go index 2ef8e78..1709f2b 100644 --- a/webmentionVerification_test.go +++ b/webmentionVerification_test.go @@ -16,11 +16,11 @@ func Test_verifyMention(t *testing.T) { require.NoError(t, err) testHtml := string(testHtmlBytes) - mockClient := &fakeHttpClient{} + mockClient := newFakeHttpClient() mockClient.setFakeResponse(http.StatusOK, testHtml) app := &goBlog{ - httpClient: mockClient, + httpClient: mockClient.Client, cfg: &config{ Db: &configDb{ File: filepath.Join(t.TempDir(), "test.db"), @@ -65,11 +65,11 @@ func Test_verifyMentionBidgy(t *testing.T) { require.NoError(t, err) testHtml := string(testHtmlBytes) - mockClient := &fakeHttpClient{} + mockClient := newFakeHttpClient() mockClient.setFakeResponse(http.StatusOK, testHtml) app := &goBlog{ - httpClient: mockClient, + httpClient: mockClient.Client, cfg: &config{ Db: &configDb{ File: filepath.Join(t.TempDir(), "test.db"), @@ -108,11 +108,11 @@ func Test_verifyMentionColin(t *testing.T) { require.NoError(t, err) testHtml := string(testHtmlBytes) - mockClient := &fakeHttpClient{} + mockClient := newFakeHttpClient() mockClient.setFakeResponse(http.StatusOK, testHtml) app := &goBlog{ - httpClient: mockClient, + httpClient: mockClient.Client, cfg: &config{ Db: &configDb{ File: filepath.Join(t.TempDir(), "test.db"),