diff --git a/activityPubSending.go b/activityPubSending.go index fc4ec50..ea118cb 100644 --- a/activityPubSending.go +++ b/activityPubSending.go @@ -78,7 +78,6 @@ func (a *goBlog) apSendSigned(blogIri, to string, activity []byte) error { return err } r.Header.Set("Accept-Charset", "utf-8") - r.Header.Set(userAgent, appUserAgent) r.Header.Set("Accept", contenttype.ASUTF8) r.Header.Set(contentType, contenttype.ASUTF8) // Sign request diff --git a/activityPubTools.go b/activityPubTools.go index dd9b9fb..74c4303 100644 --- a/activityPubTools.go +++ b/activityPubTools.go @@ -43,7 +43,6 @@ func (a *goBlog) apRemoteFollow(w http.ResponseWriter, r *http.Request) { webfinger := &webfingerType{} err := requests.URL(fmt.Sprintf("https://%s/.well-known/webfinger?resource=acct:%s@%s", instance, user, instance)). Client(a.httpClient). - UserAgent(appUserAgent). Handle(func(resp *http.Response) error { defer resp.Body.Close() return json.NewDecoder(io.LimitReader(resp.Body, 1000*1000)).Decode(webfinger) diff --git a/blogroll.go b/blogroll.go index 0f53f18..098a2be 100644 --- a/blogroll.go +++ b/blogroll.go @@ -71,7 +71,7 @@ func (a *goBlog) getBlogrollOutlines(blog string) ([]*opml.Outline, error) { if cache := a.db.loadOutlineCache(blog); cache != nil { return cache, nil } - rb := requests.URL(config.Opml).Client(a.httpClient).UserAgent(appUserAgent) + rb := requests.URL(config.Opml).Client(a.httpClient) if config.AuthHeader != "" && config.AuthValue != "" { rb.Header(config.AuthHeader, config.AuthValue) } diff --git a/geo.go b/geo.go index c98e805..8e65792 100644 --- a/geo.go +++ b/geo.go @@ -47,7 +47,7 @@ func (a *goBlog) photonReverse(lat, lon float64, lang string) (*geojson.FeatureC buf := bufferpool.Get() defer bufferpool.Put(buf) // Create request - rb := requests.URL("https://photon.komoot.io/reverse").Client(a.httpClient).UserAgent(appUserAgent).ToBytesBuffer(buf) + rb := requests.URL("https://photon.komoot.io/reverse").Client(a.httpClient).ToBytesBuffer(buf) // Set parameters rb.Param("lat", fmt.Sprintf("%v", lat)).Param("lon", fmt.Sprintf("%v", lon)) rb.Param("lang", lo.If(lang == "de" || lang == "fr" || lang == "it", lang).Else("en")) // Photon only supports en, de, fr, it diff --git a/geoTiles.go b/geoTiles.go index f16d2ab..1713cc2 100644 --- a/geoTiles.go +++ b/geoTiles.go @@ -21,7 +21,6 @@ func (a *goBlog) proxyTiles() http.HandlerFunc { targetUrl = strings.ReplaceAll(targetUrl, "{x}", chi.URLParam(r, "x")) targetUrl = strings.ReplaceAll(targetUrl, "{y}", chi.URLParam(r, "y")) proxyRequest, _ := http.NewRequestWithContext(r.Context(), http.MethodGet, targetUrl, nil) - proxyRequest.Header.Set(userAgent, appUserAgent) // Copy request headers for _, k := range []string{ "Accept-Encoding", diff --git a/httpClient.go b/httpClient.go index 2664b54..a91e8ad 100644 --- a/httpClient.go +++ b/httpClient.go @@ -10,8 +10,25 @@ import ( func newHttpClient() *http.Client { return &http.Client{ Timeout: time.Minute, - Transport: gzhttp.Transport(&http.Transport{ - DisableKeepAlives: true, - }), + Transport: newAddUserAgentTransport( + gzhttp.Transport( + &http.Transport{ + DisableKeepAlives: true, + }, + ), + ), } } + +type addUserAgentTransport struct { + t http.RoundTripper +} + +func (t *addUserAgentTransport) RoundTrip(r *http.Request) (*http.Response, error) { + r.Header.Set(userAgent, appUserAgent) + return t.t.RoundTrip(r) +} + +func newAddUserAgentTransport(t http.RoundTripper) *addUserAgentTransport { + return &addUserAgentTransport{t} +} diff --git a/httpClient_test.go b/httpClient_test.go index 3d1a6f3..159503e 100644 --- a/httpClient_test.go +++ b/httpClient_test.go @@ -8,6 +8,8 @@ import ( "sync" "testing" + "github.com/carlmjohnson/requests" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -75,3 +77,20 @@ func Test_fakeHttpClient(t *testing.T) { require.Equal(t, http.StatusNotFound, resp.StatusCode) _ = resp.Body.Close() } + +func Test_addUserAgent(t *testing.T) { + ua := "ABC" + + client := &http.Client{ + Transport: newAddUserAgentTransport(&handlerRoundTripper{ + handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ua = r.Header.Get(userAgent) + }), + }), + } + + err := requests.URL("http://example.com").UserAgent("WRONG").Client(client).Fetch(context.Background()) + require.NoError(t, err) + + assert.Equal(t, appUserAgent, ua) +} diff --git a/indexnow.go b/indexnow.go index bd4d195..f102821 100644 --- a/indexnow.go +++ b/indexnow.go @@ -55,7 +55,6 @@ func (a *goBlog) indexNow(url string) { } err := requests.URL("https://api.indexnow.org/indexnow"). Client(a.httpClient). - UserAgent(appUserAgent). Param("url", url). Param("key", string(key)). Fetch(context.Background()) diff --git a/ntfy.go b/ntfy.go index 9eb93e6..5b367d3 100644 --- a/ntfy.go +++ b/ntfy.go @@ -25,7 +25,6 @@ func (a *goBlog) sendNtfy(cfg *configNtfy, msg string) error { builder := requests. URL(server + "/" + topic). Client(a.httpClient). - UserAgent(appUserAgent). Method(http.MethodPost). BodyReader(strings.NewReader(msg)) if cfg.User != "" { diff --git a/tts.go b/tts.go index 19bca49..cd221e7 100644 --- a/tts.go +++ b/tts.go @@ -210,7 +210,6 @@ func (a *goBlog) createTTSAudio(lang, ssml string, w io.Writer) error { URL("https://texttospeech.googleapis.com/v1beta1/text:synthesize"). Param("key", gctts.GoogleAPIKey). Client(a.httpClient). - UserAgent(appUserAgent). Method(http.MethodPost). BodyJSON(body). ToJSON(&response). diff --git a/webmentionSending.go b/webmentionSending.go index 1033042..9599ce0 100644 --- a/webmentionSending.go +++ b/webmentionSending.go @@ -77,7 +77,7 @@ func (a *goBlog) sendWebmentions(p *post) error { func (a *goBlog) sendWebmention(endpoint, source, target string) error { // TODO: Pass all tests from https://webmention.rocks/ - return requests.URL(endpoint).Client(a.httpClient).Method(http.MethodPost).UserAgent(appUserAgent). + return requests.URL(endpoint).Client(a.httpClient).Method(http.MethodPost). BodyForm(url.Values{ "source": []string{source}, "target": []string{target}, @@ -94,7 +94,7 @@ func (a *goBlog) sendWebmention(endpoint, source, target string) error { func (a *goBlog) discoverEndpoint(urlStr string) string { doRequest := func(method, urlStr string) string { endpoint := "" - if err := requests.URL(urlStr).Client(a.httpClient).Method(method).UserAgent(appUserAgent). + if err := requests.URL(urlStr).Client(a.httpClient).Method(method). AddValidator(func(r *http.Response) error { if r.StatusCode < 200 || 300 <= r.StatusCode { return fmt.Errorf("HTTP %d", r.StatusCode) diff --git a/webmentionVerification.go b/webmentionVerification.go index ab12420..0149936 100644 --- a/webmentionVerification.go +++ b/webmentionVerification.go @@ -89,7 +89,6 @@ func (a *goBlog) verifyMention(m *mention) error { } defer sourceResp.Body.Close() } else { - sourceReq.Header.Set(userAgent, appUserAgent) sourceResp, err = a.httpClient.Do(sourceReq) if err != nil { return err