diff --git a/authentication.go b/authentication.go index b42efd3..bc7af9e 100644 --- a/authentication.go +++ b/authentication.go @@ -16,22 +16,32 @@ func checkCredentials(username, password string) bool { return username == appConfig.User.Nick && password == appConfig.User.Password } +func checkUsername(username string) bool { + return username == appConfig.User.Nick +} + func jwtKey() []byte { return []byte(appConfig.Server.JWTSecret) } func authMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // 1. Check JWT + // 1. Check BasicAuth + if username, password, ok := r.BasicAuth(); ok && checkCredentials(username, password) { + next.ServeHTTP(w, r) + return + } + // 2. Check JWT if tokenCookie, err := r.Cookie("token"); err == nil { - if tkn, err := jwt.Parse(tokenCookie.Value, func(t *jwt.Token) (interface{}, error) { + claims := &authClaims{} + if tkn, err := jwt.ParseWithClaims(tokenCookie.Value, claims, func(t *jwt.Token) (interface{}, error) { return jwtKey(), nil - }); err == nil && tkn.Valid { + }); err == nil && tkn.Valid && claims.TokenType == "login" && checkUsername(claims.Username) { next.ServeHTTP(w, r) return } } - // 2. Show login form + // 3. Show login form w.WriteHeader(http.StatusUnauthorized) h, _ := json.Marshal(r.Header.Clone()) b, _ := ioutil.ReadAll(io.LimitReader(r.Body, 2000000)) // Only allow 20 Megabyte @@ -75,7 +85,7 @@ func checkLogin(w http.ResponseWriter, r *http.Request) bool { } // Check credential if checkCredentials(r.FormValue("username"), r.FormValue("password")) { - tokenCookie, err := createTokenCookie() + tokenCookie, err := createTokenCookie(r.FormValue("username")) if err != nil { serveError(w, r, err.Error(), http.StatusInternalServerError) return true @@ -92,9 +102,19 @@ func checkLogin(w http.ResponseWriter, r *http.Request) bool { return false } -func createTokenCookie() (*http.Cookie, error) { +type authClaims struct { + *jwt.StandardClaims + TokenType string + Username string +} + +func createTokenCookie(username string) (*http.Cookie, error) { expiration := time.Now().Add(7 * 24 * time.Hour) - tokenString, err := jwt.NewWithClaims(jwt.SigningMethodHS256, &jwt.StandardClaims{ExpiresAt: expiration.Unix()}).SignedString(jwtKey()) + tokenString, err := jwt.NewWithClaims(jwt.SigningMethodHS256, &authClaims{ + &jwt.StandardClaims{ExpiresAt: expiration.Unix()}, + "login", + username, + }).SignedString(jwtKey()) if err != nil { return nil, err } diff --git a/captcha.go b/captcha.go index 02a752c..e3dac89 100644 --- a/captcha.go +++ b/captcha.go @@ -19,10 +19,11 @@ func initCaptcha() { func captchaMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // 1. Check JWT + claims := &captchaClaims{} if captchaCookie, err := r.Cookie("captcha"); err == nil { - if tkn, err := jwt.Parse(captchaCookie.Value, func(t *jwt.Token) (interface{}, error) { + if tkn, err := jwt.ParseWithClaims(captchaCookie.Value, claims, func(t *jwt.Token) (interface{}, error) { return jwtKey(), nil - }); err == nil && tkn.Valid { + }); err == nil && tkn.Valid && claims.TokenType == "captcha" { next.ServeHTTP(w, r) return } @@ -90,9 +91,17 @@ func checkCaptcha(w http.ResponseWriter, r *http.Request) bool { return false } +type captchaClaims struct { + *jwt.StandardClaims + TokenType string +} + func createCaptchaCookie() (*http.Cookie, error) { expiration := time.Now().Add(24 * time.Hour) - tokenString, err := jwt.NewWithClaims(jwt.SigningMethodHS256, &jwt.StandardClaims{ExpiresAt: expiration.Unix()}).SignedString(jwtKey()) + tokenString, err := jwt.NewWithClaims(jwt.SigningMethodHS256, &captchaClaims{ + &jwt.StandardClaims{ExpiresAt: expiration.Unix()}, + "captcha", + }).SignedString(jwtKey()) if err != nil { return nil, err }