diff --git a/utils.go b/utils.go index 7d9a5c9..e5cb232 100644 --- a/utils.go +++ b/utils.go @@ -5,7 +5,6 @@ import ( "fmt" "html/template" "io" - "net/http" "net/url" "path" "sort" @@ -46,25 +45,6 @@ func generateRandomString(chars int) string { return funk.RandomString(chars, []rune(randomLetters)) } -func isAllowedHost(r *http.Request, hosts ...string) bool { - if r.URL == nil { - return false - } - rh := r.URL.Host - switch r.URL.Scheme { - case "http": - rh = strings.TrimSuffix(rh, ":80") - case "https": - rh = strings.TrimSuffix(rh, ":443") - } - for _, host := range hosts { - if rh == host { - return true - } - } - return false -} - func isAbsoluteURL(s string) bool { if u, err := url.Parse(s); err != nil || !u.IsAbs() { return false diff --git a/utils_test.go b/utils_test.go index 8c559ca..b9d8e8e 100644 --- a/utils_test.go +++ b/utils_test.go @@ -1,8 +1,6 @@ package main import ( - "net/http" - "net/http/httptest" "reflect" "testing" @@ -28,28 +26,6 @@ func Test_generateRandomString(t *testing.T) { } } -func Test_isAllowedHost(t *testing.T) { - req1 := httptest.NewRequest(http.MethodGet, "https://example.com", nil) - req2 := httptest.NewRequest(http.MethodGet, "https://example.com:443", nil) - req3 := httptest.NewRequest(http.MethodGet, "http://example.com:80", nil) - - if isAllowedHost(req1, "example.com") != true { - t.Error("Wrong result") - } - - if isAllowedHost(req1, "example.net") != false { - t.Error("Wrong result") - } - - if isAllowedHost(req2, "example.com") != true { - t.Error("Wrong result") - } - - if isAllowedHost(req3, "example.com") != true { - t.Error("Wrong result") - } -} - func Test_isAbsoluteURL(t *testing.T) { if isAbsoluteURL("http://example.com") != true { t.Error("Wrong result") diff --git a/webmention.go b/webmention.go index ee84839..e9d920e 100644 --- a/webmention.go +++ b/webmention.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "net/http" - "net/http/httptest" "strings" "time" @@ -52,7 +51,7 @@ func (a *goBlog) handleWebmention(w http.ResponseWriter, r *http.Request) { a.serveError(w, r, err.Error(), http.StatusBadRequest) return } - if !isAllowedHost(httptest.NewRequest(http.MethodGet, m.Target, nil), a.cfg.Server.publicHostname) { + if !strings.HasPrefix(m.Target, a.cfg.Server.PublicAddress) { a.serveError(w, r, "target not allowed", http.StatusBadRequest) return } diff --git a/webmentionVerification.go b/webmentionVerification.go index eb4f581..6497921 100644 --- a/webmentionVerification.go +++ b/webmentionVerification.go @@ -63,6 +63,18 @@ func (a *goBlog) queueMention(m *mention) error { } func (a *goBlog) verifyMention(m *mention) error { + // Parse url -> string for source and target + u, err := url.Parse(m.Source) + if err != nil { + return err + } + m.Source = u.String() + u, err = url.Parse(m.Target) + if err != nil { + return err + } + m.Target = u.String() + // Do request req, err := http.NewRequest(http.MethodGet, m.Source, nil) if err != nil { return err @@ -156,7 +168,7 @@ func (m *mention) fill(mf *microformats.Microformat) bool { // Check URL if url, ok := mf.Properties["url"]; ok && len(url) > 0 { if url0, ok := url[0].(string); ok { - if strings.ToLower(url0) != strings.ToLower(m.Source) { + if !strings.EqualFold(url0, m.Source) { // Not correct URL return false }