diff --git a/captcha.go b/captcha.go
index 12f381a..5f7c8aa 100644
--- a/captcha.go
+++ b/captcha.go
@@ -22,16 +22,31 @@ func (a *goBlog) captchaMiddleware(next http.Handler) http.Handler {
next.ServeHTTP(w, r)
return
}
- // Check Cookie
+ // Check session
ses, err := a.captchaSessions.Get(r, "c")
- if err == nil && ses != nil {
- if captcha, ok := ses.Values["captcha"]; ok && captcha.(bool) {
- next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), captchaSolvedKey, true)))
- return
+ if err != nil {
+ a.serveError(w, r, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ if captcha, ok := ses.Values["captcha"]; ok && captcha == true {
+ // Captcha already solved
+ next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), captchaSolvedKey, true)))
+ return
+ }
+ // Get captcha ID
+ captchaId := ""
+ if sesCaptchaId, ok := ses.Values["captchaid"]; ok {
+ // Already has a captcha ID
+ ci := sesCaptchaId.(string)
+ if captcha.Reload(ci) {
+ captchaId = ci
}
}
- // Show Captcha
- w.Header().Set("Cache-Control", "no-store,max-age=0")
+ if captchaId == "" {
+ captchaId = captcha.New()
+ ses.Values["captchaid"] = captchaId
+ }
+ // Encode original request
h, _ := json.Marshal(r.Header)
b, _ := io.ReadAll(io.LimitReader(r.Body, 20*1000*1000)) // Only allow 20 MB
_ = r.Body.Close()
@@ -40,13 +55,16 @@ func (a *goBlog) captchaMiddleware(next http.Handler) http.Handler {
_ = r.ParseForm()
b = []byte(r.PostForm.Encode())
}
+ // Render captcha
+ ses.Save(r, w)
+ w.Header().Set("Cache-Control", "no-store,max-age=0")
a.renderWithStatusCode(w, r, http.StatusUnauthorized, templateCaptcha, &renderData{
BlogString: r.Context().Value(blogKey).(string),
Data: map[string]string{
"captchamethod": r.Method,
"captchaheaders": base64.StdEncoding.EncodeToString(h),
"captchabody": base64.StdEncoding.EncodeToString(b),
- "captchaid": captcha.New(),
+ "captchaid": captchaId,
},
})
})
@@ -70,19 +88,24 @@ func (a *goBlog) checkCaptcha(w http.ResponseWriter, r *http.Request) bool {
if r.FormValue("captchaaction") != "captcha" {
return false
}
- // Prepare original request
+ // Decode and prepare original request
captchabody, _ := base64.StdEncoding.DecodeString(r.FormValue("captchabody"))
- req, _ := http.NewRequest(r.FormValue("captchamethod"), r.RequestURI, bytes.NewReader(captchabody))
+ origReq, _ := http.NewRequest(r.FormValue("captchamethod"), r.RequestURI, bytes.NewReader(captchabody))
// Copy original headers
captchaheaders, _ := base64.StdEncoding.DecodeString(r.FormValue("captchaheaders"))
var headers http.Header
_ = json.Unmarshal(captchaheaders, &headers)
for k, v := range headers {
- req.Header[k] = v
+ origReq.Header[k] = v
}
- // Check captcha and create cookie
- if captcha.VerifyString(r.FormValue("captchaid"), r.FormValue("digits")) {
- ses, err := a.captchaSessions.Get(r, "c")
+ // Get session
+ ses, err := a.captchaSessions.Get(r, "c")
+ if err != nil {
+ a.serveError(w, r, err.Error(), http.StatusInternalServerError)
+ return true
+ }
+ // Check if session contains a captchaId and if captcha is solved
+ if sesCaptchaId, ok := ses.Values["captchaid"]; ok && captcha.VerifyString(sesCaptchaId.(string), r.FormValue("digits")) {
if err != nil {
a.serveError(w, r, err.Error(), http.StatusInternalServerError)
return true
@@ -93,9 +116,13 @@ func (a *goBlog) checkCaptcha(w http.ResponseWriter, r *http.Request) bool {
a.serveError(w, r, err.Error(), http.StatusInternalServerError)
return true
}
- req = req.WithContext(context.WithValue(req.Context(), captchaSolvedKey, true))
+ origReq = origReq.WithContext(context.WithValue(origReq.Context(), captchaSolvedKey, true))
+ }
+ // Copy captcha cookie to original request
+ if captchaCookie, err := r.Cookie("c"); err == nil {
+ origReq.AddCookie(captchaCookie)
}
// Serve original request
- a.d.ServeHTTP(w, req)
+ a.d.ServeHTTP(w, origReq)
return true
}
diff --git a/templates/captcha.gohtml b/templates/captcha.gohtml
index feb0311..50b1914 100644
--- a/templates/captcha.gohtml
+++ b/templates/captcha.gohtml
@@ -10,7 +10,6 @@
-