Set appAserAgent on all requests using RoundTripper

This commit is contained in:
Jan-Lukas Else 2022-12-14 10:21:32 +01:00
parent c8121e3a3a
commit 34ab1b1fb2
12 changed files with 43 additions and 14 deletions

View File

@ -78,7 +78,6 @@ func (a *goBlog) apSendSigned(blogIri, to string, activity []byte) error {
return err return err
} }
r.Header.Set("Accept-Charset", "utf-8") r.Header.Set("Accept-Charset", "utf-8")
r.Header.Set(userAgent, appUserAgent)
r.Header.Set("Accept", contenttype.ASUTF8) r.Header.Set("Accept", contenttype.ASUTF8)
r.Header.Set(contentType, contenttype.ASUTF8) r.Header.Set(contentType, contenttype.ASUTF8)
// Sign request // Sign request

View File

@ -43,7 +43,6 @@ func (a *goBlog) apRemoteFollow(w http.ResponseWriter, r *http.Request) {
webfinger := &webfingerType{} webfinger := &webfingerType{}
err := requests.URL(fmt.Sprintf("https://%s/.well-known/webfinger?resource=acct:%s@%s", instance, user, instance)). err := requests.URL(fmt.Sprintf("https://%s/.well-known/webfinger?resource=acct:%s@%s", instance, user, instance)).
Client(a.httpClient). Client(a.httpClient).
UserAgent(appUserAgent).
Handle(func(resp *http.Response) error { Handle(func(resp *http.Response) error {
defer resp.Body.Close() defer resp.Body.Close()
return json.NewDecoder(io.LimitReader(resp.Body, 1000*1000)).Decode(webfinger) return json.NewDecoder(io.LimitReader(resp.Body, 1000*1000)).Decode(webfinger)

View File

@ -71,7 +71,7 @@ func (a *goBlog) getBlogrollOutlines(blog string) ([]*opml.Outline, error) {
if cache := a.db.loadOutlineCache(blog); cache != nil { if cache := a.db.loadOutlineCache(blog); cache != nil {
return cache, nil return cache, nil
} }
rb := requests.URL(config.Opml).Client(a.httpClient).UserAgent(appUserAgent) rb := requests.URL(config.Opml).Client(a.httpClient)
if config.AuthHeader != "" && config.AuthValue != "" { if config.AuthHeader != "" && config.AuthValue != "" {
rb.Header(config.AuthHeader, config.AuthValue) rb.Header(config.AuthHeader, config.AuthValue)
} }

2
geo.go
View File

@ -47,7 +47,7 @@ func (a *goBlog) photonReverse(lat, lon float64, lang string) (*geojson.FeatureC
buf := bufferpool.Get() buf := bufferpool.Get()
defer bufferpool.Put(buf) defer bufferpool.Put(buf)
// Create request // Create request
rb := requests.URL("https://photon.komoot.io/reverse").Client(a.httpClient).UserAgent(appUserAgent).ToBytesBuffer(buf) rb := requests.URL("https://photon.komoot.io/reverse").Client(a.httpClient).ToBytesBuffer(buf)
// Set parameters // Set parameters
rb.Param("lat", fmt.Sprintf("%v", lat)).Param("lon", fmt.Sprintf("%v", lon)) rb.Param("lat", fmt.Sprintf("%v", lat)).Param("lon", fmt.Sprintf("%v", lon))
rb.Param("lang", lo.If(lang == "de" || lang == "fr" || lang == "it", lang).Else("en")) // Photon only supports en, de, fr, it rb.Param("lang", lo.If(lang == "de" || lang == "fr" || lang == "it", lang).Else("en")) // Photon only supports en, de, fr, it

View File

@ -21,7 +21,6 @@ func (a *goBlog) proxyTiles() http.HandlerFunc {
targetUrl = strings.ReplaceAll(targetUrl, "{x}", chi.URLParam(r, "x")) targetUrl = strings.ReplaceAll(targetUrl, "{x}", chi.URLParam(r, "x"))
targetUrl = strings.ReplaceAll(targetUrl, "{y}", chi.URLParam(r, "y")) targetUrl = strings.ReplaceAll(targetUrl, "{y}", chi.URLParam(r, "y"))
proxyRequest, _ := http.NewRequestWithContext(r.Context(), http.MethodGet, targetUrl, nil) proxyRequest, _ := http.NewRequestWithContext(r.Context(), http.MethodGet, targetUrl, nil)
proxyRequest.Header.Set(userAgent, appUserAgent)
// Copy request headers // Copy request headers
for _, k := range []string{ for _, k := range []string{
"Accept-Encoding", "Accept-Encoding",

View File

@ -10,8 +10,25 @@ import (
func newHttpClient() *http.Client { func newHttpClient() *http.Client {
return &http.Client{ return &http.Client{
Timeout: time.Minute, Timeout: time.Minute,
Transport: gzhttp.Transport(&http.Transport{ Transport: newAddUserAgentTransport(
DisableKeepAlives: true, gzhttp.Transport(
}), &http.Transport{
DisableKeepAlives: true,
},
),
),
} }
} }
type addUserAgentTransport struct {
t http.RoundTripper
}
func (t *addUserAgentTransport) RoundTrip(r *http.Request) (*http.Response, error) {
r.Header.Set(userAgent, appUserAgent)
return t.t.RoundTrip(r)
}
func newAddUserAgentTransport(t http.RoundTripper) *addUserAgentTransport {
return &addUserAgentTransport{t}
}

View File

@ -8,6 +8,8 @@ import (
"sync" "sync"
"testing" "testing"
"github.com/carlmjohnson/requests"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -75,3 +77,20 @@ func Test_fakeHttpClient(t *testing.T) {
require.Equal(t, http.StatusNotFound, resp.StatusCode) require.Equal(t, http.StatusNotFound, resp.StatusCode)
_ = resp.Body.Close() _ = resp.Body.Close()
} }
func Test_addUserAgent(t *testing.T) {
ua := "ABC"
client := &http.Client{
Transport: newAddUserAgentTransport(&handlerRoundTripper{
handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ua = r.Header.Get(userAgent)
}),
}),
}
err := requests.URL("http://example.com").UserAgent("WRONG").Client(client).Fetch(context.Background())
require.NoError(t, err)
assert.Equal(t, appUserAgent, ua)
}

View File

@ -55,7 +55,6 @@ func (a *goBlog) indexNow(url string) {
} }
err := requests.URL("https://api.indexnow.org/indexnow"). err := requests.URL("https://api.indexnow.org/indexnow").
Client(a.httpClient). Client(a.httpClient).
UserAgent(appUserAgent).
Param("url", url). Param("url", url).
Param("key", string(key)). Param("key", string(key)).
Fetch(context.Background()) Fetch(context.Background())

View File

@ -25,7 +25,6 @@ func (a *goBlog) sendNtfy(cfg *configNtfy, msg string) error {
builder := requests. builder := requests.
URL(server + "/" + topic). URL(server + "/" + topic).
Client(a.httpClient). Client(a.httpClient).
UserAgent(appUserAgent).
Method(http.MethodPost). Method(http.MethodPost).
BodyReader(strings.NewReader(msg)) BodyReader(strings.NewReader(msg))
if cfg.User != "" { if cfg.User != "" {

1
tts.go
View File

@ -210,7 +210,6 @@ func (a *goBlog) createTTSAudio(lang, ssml string, w io.Writer) error {
URL("https://texttospeech.googleapis.com/v1beta1/text:synthesize"). URL("https://texttospeech.googleapis.com/v1beta1/text:synthesize").
Param("key", gctts.GoogleAPIKey). Param("key", gctts.GoogleAPIKey).
Client(a.httpClient). Client(a.httpClient).
UserAgent(appUserAgent).
Method(http.MethodPost). Method(http.MethodPost).
BodyJSON(body). BodyJSON(body).
ToJSON(&response). ToJSON(&response).

View File

@ -77,7 +77,7 @@ func (a *goBlog) sendWebmentions(p *post) error {
func (a *goBlog) sendWebmention(endpoint, source, target string) error { func (a *goBlog) sendWebmention(endpoint, source, target string) error {
// TODO: Pass all tests from https://webmention.rocks/ // TODO: Pass all tests from https://webmention.rocks/
return requests.URL(endpoint).Client(a.httpClient).Method(http.MethodPost).UserAgent(appUserAgent). return requests.URL(endpoint).Client(a.httpClient).Method(http.MethodPost).
BodyForm(url.Values{ BodyForm(url.Values{
"source": []string{source}, "source": []string{source},
"target": []string{target}, "target": []string{target},
@ -94,7 +94,7 @@ func (a *goBlog) sendWebmention(endpoint, source, target string) error {
func (a *goBlog) discoverEndpoint(urlStr string) string { func (a *goBlog) discoverEndpoint(urlStr string) string {
doRequest := func(method, urlStr string) string { doRequest := func(method, urlStr string) string {
endpoint := "" endpoint := ""
if err := requests.URL(urlStr).Client(a.httpClient).Method(method).UserAgent(appUserAgent). if err := requests.URL(urlStr).Client(a.httpClient).Method(method).
AddValidator(func(r *http.Response) error { AddValidator(func(r *http.Response) error {
if r.StatusCode < 200 || 300 <= r.StatusCode { if r.StatusCode < 200 || 300 <= r.StatusCode {
return fmt.Errorf("HTTP %d", r.StatusCode) return fmt.Errorf("HTTP %d", r.StatusCode)

View File

@ -89,7 +89,6 @@ func (a *goBlog) verifyMention(m *mention) error {
} }
defer sourceResp.Body.Close() defer sourceResp.Body.Close()
} else { } else {
sourceReq.Header.Set(userAgent, appUserAgent)
sourceResp, err = a.httpClient.Do(sourceReq) sourceResp, err = a.httpClient.Do(sourceReq)
if err != nil { if err != nil {
return err return err