@ -1,14 +1,16 @@
package main
import (
"bytes"
"database/sql"
"fmt "
"encoding/gob "
"log"
"net/http"
"strings"
"time"
"github.com/araddon/dateparse"
"github.com/gorilla/securecookie "
"github.com/google/uuid "
"github.com/gorilla/sessions"
)
@ -30,7 +32,6 @@ func (a *goBlog) initSessions() {
deleteExpiredSessions ( )
a . hourlyHooks = append ( a . hourlyHooks , deleteExpiredSessions )
a . loginSessions = & dbSessionStore {
codecs : securecookie . CodecsFromPairs ( [ ] byte ( a . cfg . Server . JWTSecret ) ) ,
options : & sessions . Options {
Secure : a . httpsConfigured ( true ) ,
HttpOnly : true ,
@ -41,7 +42,6 @@ func (a *goBlog) initSessions() {
db : a . db ,
}
a . captchaSessions = & dbSessionStore {
codecs : securecookie . CodecsFromPairs ( [ ] byte ( a . cfg . Server . JWTSecret ) ) ,
options : & sessions . Options {
Secure : a . httpsConfigured ( true ) ,
HttpOnly : true ,
@ -55,7 +55,6 @@ func (a *goBlog) initSessions() {
type dbSessionStore struct {
options * sessions . Options
codecs [ ] securecookie . Codec
db * database
}
@ -67,29 +66,33 @@ func (s *dbSessionStore) New(r *http.Request, name string) (session *sessions.Se
session = sessions . NewSession ( s , name )
opts := * s . options
session . Options = & opts
session . IsNew = true
if cook , errCookie := r . Cookie ( name ) ; errCookie == nil {
if err = securecookie . DecodeMulti ( name , cook . Value , & session . ID , s . codecs ... ) ; err == nil {
session . IsNew = s . load ( session ) == nil
if c , cErr := r . Cookie ( name ) ; cErr == nil && strings . HasPrefix ( c . Value , session . Name ( ) + "-" ) {
// Has cookie, load from database
session . ID = c . Value
if s . load ( session ) != nil {
// Failed to load session from database, delete the ID (= new session)
session . ID = ""
}
}
// If no ID, the session is new
session . IsNew = session . ID == ""
return session , err
}
func ( s * dbSessionStore ) Save ( r * http . Request , w http . ResponseWriter , ss * sessions . Session ) ( err error ) {
if ss . ID == "" {
// Is new session, save it to database
if err = s . insert ( ss ) ; err != nil {
return err
}
} else if err = s . save ( ss ) ; err != nil {
return err
}
if encoded , err := securecookie . EncodeMulti ( ss . Name ( ) , ss . ID , s . codecs ... ) ; err != nil {
return err
} else {
http . SetCookie ( w , sessions . NewCookie ( ss . Name ( ) , encoded , ss . Options ) )
return nil
// Update existing session
if err = s . save ( ss ) ; err != nil {
return err
}
}
http . SetCookie ( w , sessions . NewCookie ( ss . Name ( ) , ss . ID , ss . Options ) )
return nil
}
func ( s * dbSessionStore ) Delete ( r * http . Request , w http . ResponseWriter , session * sessions . Session ) error {
@ -106,15 +109,20 @@ func (s *dbSessionStore) Delete(r *http.Request, w http.ResponseWriter, session
}
func ( s * dbSessionStore ) load ( session * sessions . Session ) ( err error ) {
row , err := s . db . queryRow ( "select data, created, modified, expires from sessions where id = @id and expires > @now" , sql . Named ( "id" , session . ID ) , sql . Named ( "now" , utcNowString ( ) ) )
row , err := s . db . queryRow (
"select data, created, modified, expires from sessions where id = @id and expires > @now" ,
sql . Named ( "id" , session . ID ) ,
sql . Named ( "now" , utcNowString ( ) ) ,
)
if err != nil {
return err
}
var data , createdStr , modifiedStr , expiresStr string
var createdStr , modifiedStr , expiresStr string
var data [ ] byte
if err = row . Scan ( & data , & createdStr , & modifiedStr , & expiresStr ) ; err != nil {
return err
}
if err = securecookie . DecodeMulti ( session . Name ( ) , data , & session . Values , s . codecs ... ) ; err != nil {
if err = gob . NewDecoder ( bytes . NewReader ( data ) ) . Decode ( & session . Values ) ; err != nil {
return err
}
session . Values [ sessionCreatedOn ] = timeNoErr ( dateparse . ParseLocal ( createdStr ) )
@ -124,44 +132,44 @@ func (s *dbSessionStore) load(session *sessions.Session) (err error) {
}
func ( s * dbSessionStore ) insert ( session * sessions . Session ) ( err error ) {
created := time . Now ( ) . UTC ( )
modified := time . Now ( ) . UTC ( )
expires := time . Now ( ) . UTC ( ) . Add ( time . Second * time . Duration ( session . Options . MaxAge ) )
delete ( session . Values , sessionCreatedOn )
delete ( session . Values , sessionExpiresOn )
delete ( session . Values , sessionModifiedOn )
encoded , err := securecookie . EncodeMulti ( session . Name ( ) , session . Values , s . codecs ... )
if err != nil {
return err
}
res , err := s . db . exec ( "insert or replace into sessions(data, created, modified, expires) values(@data, @created, @modified, @expires)" ,
sql . Named ( "data" , encoded ) , sql . Named ( "created" , created . Format ( time . RFC3339 ) ) , sql . Named ( "modified" , modified . Format ( time . RFC3339 ) ) , sql . Named ( "expires" , expires . Format ( time . RFC3339 ) ) )
if err != nil {
deleteSessionValuesNotNeededForDb ( session )
var encoded bytes . Buffer
if err := gob . NewEncoder ( & encoded ) . Encode ( session . Values ) ; err != nil {
return err
}
lastInserted , err := res . LastInsertId ( )
if err != nil {
return err
}
session . ID = fmt . Sprintf ( "%d" , lastInserted )
return nil
session . ID = session . Name ( ) + "-" + uuid . NewString ( )
created , modified := utcNowString ( ) , utcNowString ( )
expires := time . Now ( ) . UTC ( ) . Add ( time . Second * time . Duration ( session . Options . MaxAge ) ) . Format ( time . RFC3339 )
_ , err = s . db . exec (
"insert or replace into sessions(id, data, created, modified, expires) values(@id, @data, @created, @modified, @expires)" ,
sql . Named ( "id" , session . ID ) ,
sql . Named ( "data" , encoded . Bytes ( ) ) ,
sql . Named ( "created" , created ) ,
sql . Named ( "modified" , modified ) ,
sql . Named ( "expires" , expires ) ,
)
return err
}
func ( s * dbSessionStore ) save ( session * sessions . Session ) ( err error ) {
if session . IsNew {
return s . insert ( session )
}
delete ( session . Values , sessionCreatedOn )
delete ( session . Values , sessionExpiresOn )
delete ( session . Values , sessionModifiedOn )
encoded , err := securecookie . EncodeMulti ( session . Name ( ) , session . Values , s . codecs ... )
if err != nil {
deleteSessionValuesNotNeededForDb ( session )
var encoded bytes . Buffer
if err = gob . NewEncoder ( & encoded ) . Encode ( session . Values ) ; err != nil {
return err
}
_ , err = s . db . exec ( "update sessions set data = @data, modified = @modified where id = @id" ,
sql . Named ( "data" , encoded ) , sql . Named ( "modified" , utcNowString ( ) ) , sql . Named ( "id" , session . ID ) )
sql . Named ( "data" , encoded . Bytes ( ) ) , sql . Named ( "modified" , utcNowString ( ) ) , sql . Named ( "id" , session . ID ) )
if err != nil {
return err
}
return nil
}
func deleteSessionValuesNotNeededForDb ( session * sessions . Session ) {
delete ( session . Values , sessionCreatedOn )
delete ( session . Values , sessionExpiresOn )
delete ( session . Values , sessionModifiedOn )
}