diff --git a/captcha.go b/captcha.go index 12f381a..5f7c8aa 100644 --- a/captcha.go +++ b/captcha.go @@ -22,16 +22,31 @@ func (a *goBlog) captchaMiddleware(next http.Handler) http.Handler { next.ServeHTTP(w, r) return } - // Check Cookie + // Check session ses, err := a.captchaSessions.Get(r, "c") - if err == nil && ses != nil { - if captcha, ok := ses.Values["captcha"]; ok && captcha.(bool) { - next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), captchaSolvedKey, true))) - return + if err != nil { + a.serveError(w, r, err.Error(), http.StatusInternalServerError) + return + } + if captcha, ok := ses.Values["captcha"]; ok && captcha == true { + // Captcha already solved + next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), captchaSolvedKey, true))) + return + } + // Get captcha ID + captchaId := "" + if sesCaptchaId, ok := ses.Values["captchaid"]; ok { + // Already has a captcha ID + ci := sesCaptchaId.(string) + if captcha.Reload(ci) { + captchaId = ci } } - // Show Captcha - w.Header().Set("Cache-Control", "no-store,max-age=0") + if captchaId == "" { + captchaId = captcha.New() + ses.Values["captchaid"] = captchaId + } + // Encode original request h, _ := json.Marshal(r.Header) b, _ := io.ReadAll(io.LimitReader(r.Body, 20*1000*1000)) // Only allow 20 MB _ = r.Body.Close() @@ -40,13 +55,16 @@ func (a *goBlog) captchaMiddleware(next http.Handler) http.Handler { _ = r.ParseForm() b = []byte(r.PostForm.Encode()) } + // Render captcha + ses.Save(r, w) + w.Header().Set("Cache-Control", "no-store,max-age=0") a.renderWithStatusCode(w, r, http.StatusUnauthorized, templateCaptcha, &renderData{ BlogString: r.Context().Value(blogKey).(string), Data: map[string]string{ "captchamethod": r.Method, "captchaheaders": base64.StdEncoding.EncodeToString(h), "captchabody": base64.StdEncoding.EncodeToString(b), - "captchaid": captcha.New(), + "captchaid": captchaId, }, }) }) @@ -70,19 +88,24 @@ func (a *goBlog) checkCaptcha(w http.ResponseWriter, r *http.Request) bool { if r.FormValue("captchaaction") != "captcha" { return false } - // Prepare original request + // Decode and prepare original request captchabody, _ := base64.StdEncoding.DecodeString(r.FormValue("captchabody")) - req, _ := http.NewRequest(r.FormValue("captchamethod"), r.RequestURI, bytes.NewReader(captchabody)) + origReq, _ := http.NewRequest(r.FormValue("captchamethod"), r.RequestURI, bytes.NewReader(captchabody)) // Copy original headers captchaheaders, _ := base64.StdEncoding.DecodeString(r.FormValue("captchaheaders")) var headers http.Header _ = json.Unmarshal(captchaheaders, &headers) for k, v := range headers { - req.Header[k] = v + origReq.Header[k] = v } - // Check captcha and create cookie - if captcha.VerifyString(r.FormValue("captchaid"), r.FormValue("digits")) { - ses, err := a.captchaSessions.Get(r, "c") + // Get session + ses, err := a.captchaSessions.Get(r, "c") + if err != nil { + a.serveError(w, r, err.Error(), http.StatusInternalServerError) + return true + } + // Check if session contains a captchaId and if captcha is solved + if sesCaptchaId, ok := ses.Values["captchaid"]; ok && captcha.VerifyString(sesCaptchaId.(string), r.FormValue("digits")) { if err != nil { a.serveError(w, r, err.Error(), http.StatusInternalServerError) return true @@ -93,9 +116,13 @@ func (a *goBlog) checkCaptcha(w http.ResponseWriter, r *http.Request) bool { a.serveError(w, r, err.Error(), http.StatusInternalServerError) return true } - req = req.WithContext(context.WithValue(req.Context(), captchaSolvedKey, true)) + origReq = origReq.WithContext(context.WithValue(origReq.Context(), captchaSolvedKey, true)) + } + // Copy captcha cookie to original request + if captchaCookie, err := r.Cookie("c"); err == nil { + origReq.AddCookie(captchaCookie) } // Serve original request - a.d.ServeHTTP(w, req) + a.d.ServeHTTP(w, origReq) return true } diff --git a/templates/captcha.gohtml b/templates/captcha.gohtml index feb0311..50b1914 100644 --- a/templates/captcha.gohtml +++ b/templates/captcha.gohtml @@ -10,7 +10,6 @@ -