GoBlog/indieauth.go

92 lines
2.3 KiB
Go

package main
import (
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"time"
)
// indieAuthTokenResponse holds the fields decoded from the JSON body of a
// token endpoint response, plus the HTTP status code of that response.
type indieAuthTokenResponse struct {
	Me               string `json:"me"`
	ClientID         string `json:"client_id"`
	Scope            string `json:"scope"`
	IssuedBy         string `json:"issued_by"`
	Error            string `json:"error"`
	ErrorDescription string `json:"error_description"`
	// StatusCode is the HTTP status of the token endpoint response. It is set
	// by the code performing the request and is excluded from JSON decoding so
	// that a key named "StatusCode" in the response body cannot overwrite it.
	StatusCode int `json:"-"`
}
// checkIndieAuth returns middleware that only forwards a request to next when
// it carries a token that verifyIndieAuthAccessToken accepts and whose "me"
// URL matches one of the hosts in appConfig.Micropub.AuthAllowed.
//
// The token is taken from the Authorization header; if that header is empty,
// a non-empty access_token query parameter is used instead, prefixed with
// "Bearer ". Every rejection is answered with HTTP 401.
func checkIndieAuth(next http.Handler) http.Handler {
	handler := func(w http.ResponseWriter, r *http.Request) {
		token := r.Header.Get("Authorization")
		if token == "" {
			if queryToken := r.URL.Query().Get("access_token"); queryToken != "" {
				token = "Bearer " + queryToken
			}
		}

		info, err := verifyIndieAuthAccessToken(token)
		// Cases are evaluated in order, so info is only inspected once err is nil.
		switch {
		case err != nil:
			http.Error(w, err.Error(), http.StatusUnauthorized)
			return
		case info.StatusCode != http.StatusOK:
			http.Error(w, "Failed to retrieve authentication information", http.StatusUnauthorized)
			return
		}

		// The first allowed host that matches the token's "me" URL grants access.
		for _, host := range appConfig.Micropub.AuthAllowed {
			if compareHostnames(info.Me, host) == nil {
				next.ServeHTTP(w, r)
				return
			}
		}
		http.Error(w, "Forbidden", http.StatusUnauthorized)
	}
	return http.HandlerFunc(handler)
}
// verifyIndieAuthAccessToken sends bearerToken as the Authorization header of
// a GET request to appConfig.Micropub.TokenEndpoint and decodes the JSON
// response body into an indieAuthTokenResponse whose StatusCode is
// initialised from the HTTP status of the response.
//
// It returns an error when bearerToken is empty, when the request cannot be
// built or sent (the client uses a 10 second timeout), or when the body cannot
// be decoded as JSON. The HTTP status itself is not checked here.
func verifyIndieAuthAccessToken(bearerToken string) (*indieAuthTokenResponse, error) {
	if bearerToken == "" {
		return nil, errors.New("no token")
	}

	request, err := http.NewRequest(http.MethodGet, appConfig.Micropub.TokenEndpoint, nil)
	if err != nil {
		return nil, err
	}
	request.Header.Add(contentType, contentTypeWWWForm)
	request.Header.Add("Authorization", bearerToken)
	request.Header.Add("Accept", contentTypeJSON)

	client := http.Client{Timeout: 10 * time.Second}
	response, err := client.Do(request)
	if err != nil {
		return nil, err
	}
	defer response.Body.Close()

	result := &indieAuthTokenResponse{StatusCode: response.StatusCode}
	if err := json.NewDecoder(response.Body).Decode(result); err != nil {
		return nil, err
	}
	return result, nil
}
// compareHostnames parses a as a URL and checks that its hostname (without
// port) equals allowed, ignoring case.
//
// It returns nil on a match. It returns an error when a cannot be parsed, when
// a contains no hostname, or when the hostname differs from allowed.
func compareHostnames(a string, allowed string) error {
	h1, err := url.Parse(a)
	if err != nil {
		return err
	}
	host := strings.ToLower(h1.Hostname())
	// A URL without a host (empty string, relative path, scheme-less value)
	// must never authorize, not even against an empty allowed entry.
	if host == "" {
		return fmt.Errorf("no hostname in %q", a)
	}
	if host != strings.ToLower(allowed) {
		return fmt.Errorf("hostnames do not match, %s is not %s", h1, allowed)
	}
	return nil
}