Rework sessions, config and some tests

This commit is contained in:
Jan-Lukas Else 2021-12-14 17:38:36 +01:00
parent 5c9cd77694
commit 893caf8ec4
12 changed files with 237 additions and 230 deletions

View File

@ -5,7 +5,6 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"path/filepath"
"strings" "strings"
"testing" "testing"
@ -17,32 +16,20 @@ import (
func Test_authMiddleware(t *testing.T) { func Test_authMiddleware(t *testing.T) {
app := &goBlog{ app := &goBlog{
cfg: &config{ cfg: createDefaultTestConfig(t),
Db: &configDb{ }
File: filepath.Join(t.TempDir(), "test.db"), app.cfg.User = &configUser{
}, Nick: "test",
Server: &configServer{ Password: "pass",
PublicAddress: "https://example.com", AppPasswords: []*configAppPassword{
}, {
Blogs: map[string]*configBlog{ Username: "app1",
"en": { Password: "pass1",
Lang: "en",
},
},
DefaultBlog: "en",
User: &configUser{
Nick: "test",
Password: "pass",
AppPasswords: []*configAppPassword{
{
Username: "app1",
Password: "pass1",
},
},
}, },
}, },
} }
_ = app.initConfig()
_ = app.initDatabase(false) _ = app.initDatabase(false)
app.initComponents(false) app.initComponents(false)

View File

@ -1,11 +1,9 @@
package main package main
import ( import (
"context"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"path/filepath"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -14,23 +12,10 @@ import (
func Test_captchaMiddleware(t *testing.T) { func Test_captchaMiddleware(t *testing.T) {
app := &goBlog{ app := &goBlog{
cfg: &config{ cfg: createDefaultTestConfig(t),
Db: &configDb{
File: filepath.Join(t.TempDir(), "test.db"),
},
Server: &configServer{
PublicAddress: "https://example.com",
},
Blogs: map[string]*configBlog{
"en": {
Lang: "en",
},
},
DefaultBlog: "en",
User: &configUser{},
},
} }
_ = app.initConfig()
_ = app.initDatabase(false) _ = app.initDatabase(false)
app.initComponents(false) app.initComponents(false)
@ -39,11 +24,9 @@ func Test_captchaMiddleware(t *testing.T) {
})) }))
t.Run("Default", func(t *testing.T) { t.Run("Default", func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/abc", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
h.ServeHTTP(rec, req.WithContext(context.WithValue(req.Context(), blogKey, "en"))) h.ServeHTTP(rec, reqWithDefaultBlog(httptest.NewRequest(http.MethodPost, "/abc", nil)))
res := rec.Result() res := rec.Result()
resBody, _ := io.ReadAll(res.Body) resBody, _ := io.ReadAll(res.Body)

160
config.go
View File

@ -27,6 +27,7 @@ type config struct {
EasterEgg *configEasterEgg `mapstructure:"easterEgg"` EasterEgg *configEasterEgg `mapstructure:"easterEgg"`
Debug bool `mapstructure:"debug"` Debug bool `mapstructure:"debug"`
MapTiles *configMapTiles `mapstructure:"mapTiles"` MapTiles *configMapTiles `mapstructure:"mapTiles"`
initialized bool
} }
type configServer struct { type configServer struct {
@ -42,7 +43,6 @@ type configServer struct {
Tor bool `mapstructure:"tor"` Tor bool `mapstructure:"tor"`
SecurityHeaders bool `mapstructure:"securityHeaders"` SecurityHeaders bool `mapstructure:"securityHeaders"`
CSPDomains []string `mapstructure:"cspDomains"` CSPDomains []string `mapstructure:"cspDomains"`
JWTSecret string `mapstructure:"jwtSecret"`
publicHostname string publicHostname string
shortPublicHostname string shortPublicHostname string
mediaHostname string mediaHostname string
@ -287,86 +287,83 @@ type configMapTiles struct {
MaxZoom int `mapstructure:"maxZoom"` MaxZoom int `mapstructure:"maxZoom"`
} }
func (a *goBlog) initConfig(file string) error { func (a *goBlog) loadConfigFile(file string) error {
log.Println("Initialize configuration...") // Use viper to load the config file
v := viper.New()
if file != "" { if file != "" {
// Use config file from the flag // Use config file from the flag
viper.SetConfigFile(file) v.SetConfigFile(file)
} else { } else {
viper.SetConfigName("config") // Search in default locations
viper.AddConfigPath("./config/") v.SetConfigName("config")
v.AddConfigPath("./config/")
} }
err := viper.ReadInConfig() // Read config
if err != nil { if err := v.ReadInConfig(); err != nil {
return err return err
} }
// Defaults
viper.SetDefault("server.logging", false)
viper.SetDefault("server.logFile", "data/access.log")
viper.SetDefault("server.port", 8080)
viper.SetDefault("server.publicAddress", "http://localhost:8080")
viper.SetDefault("server.publicHttps", false)
viper.SetDefault("database.file", "data/db.sqlite")
viper.SetDefault("cache.enable", true)
viper.SetDefault("cache.expiration", 600)
viper.SetDefault("user.nick", "admin")
viper.SetDefault("user.password", "secret")
viper.SetDefault("hooks.shell", "/bin/bash")
viper.SetDefault("micropub.categoryParam", "tags")
viper.SetDefault("micropub.replyParam", "replylink")
viper.SetDefault("micropub.replyTitleParam", "replytitle")
viper.SetDefault("micropub.likeParam", "likelink")
viper.SetDefault("micropub.likeTitleParam", "liketitle")
viper.SetDefault("micropub.bookmarkParam", "link")
viper.SetDefault("micropub.audioParam", "audio")
viper.SetDefault("micropub.photoParam", "images")
viper.SetDefault("micropub.photoDescriptionParam", "imagealts")
viper.SetDefault("micropub.locationParam", "location")
viper.SetDefault("activityPub.tagsTaxonomies", []string{"tags"})
// Unmarshal config // Unmarshal config
a.cfg = &config{} a.cfg = createDefaultConfig()
err = viper.Unmarshal(a.cfg) return v.Unmarshal(a.cfg)
if err != nil { }
return err
func (a *goBlog) initConfig() error {
if a.cfg == nil {
a.cfg = createDefaultConfig()
}
if a.cfg.initialized {
return nil
} }
// Check config // Check config
// Parse addresses and hostnames
if a.cfg.Server.PublicAddress == "" {
return errors.New("no public address configured")
}
publicURL, err := url.Parse(a.cfg.Server.PublicAddress) publicURL, err := url.Parse(a.cfg.Server.PublicAddress)
if err != nil { if err != nil {
return err return errors.New("Invalid public address: " + err.Error())
} }
a.cfg.Server.publicHostname = publicURL.Hostname() a.cfg.Server.publicHostname = publicURL.Hostname()
if sa := a.cfg.Server.ShortPublicAddress; sa != "" { if sa := a.cfg.Server.ShortPublicAddress; sa != "" {
shortPublicURL, err := url.Parse(sa) shortPublicURL, err := url.Parse(sa)
if err != nil { if err != nil {
return err return errors.New("Invalid short public address: " + err.Error())
} }
a.cfg.Server.shortPublicHostname = shortPublicURL.Hostname() a.cfg.Server.shortPublicHostname = shortPublicURL.Hostname()
} }
if ma := a.cfg.Server.MediaAddress; ma != "" { if ma := a.cfg.Server.MediaAddress; ma != "" {
mediaUrl, err := url.Parse(ma) mediaUrl, err := url.Parse(ma)
if err != nil { if err != nil {
return err return errors.New("Invalid media address: " + err.Error())
} }
a.cfg.Server.mediaHostname = mediaUrl.Hostname() a.cfg.Server.mediaHostname = mediaUrl.Hostname()
} }
if a.cfg.Server.JWTSecret == "" { // Check if any blog is configured
return errors.New("no JWT secret configured") if a.cfg.Blogs == nil || len(a.cfg.Blogs) == 0 {
} a.cfg.Blogs = map[string]*configBlog{
if len(a.cfg.Blogs) == 0 { "default": createDefaultBlog(),
return errors.New("no blog configured")
}
if len(a.cfg.DefaultBlog) == 0 || a.cfg.Blogs[a.cfg.DefaultBlog] == nil {
return errors.New("no default blog or default blog not present")
}
if a.cfg.Micropub.MediaStorage != nil {
if a.cfg.Micropub.MediaStorage.MediaURL == "" ||
a.cfg.Micropub.MediaStorage.BunnyStorageKey == "" ||
a.cfg.Micropub.MediaStorage.BunnyStorageName == "" {
a.cfg.Micropub.MediaStorage.BunnyStorageKey = ""
a.cfg.Micropub.MediaStorage.BunnyStorageName = ""
} }
a.cfg.Micropub.MediaStorage.MediaURL = strings.TrimSuffix(a.cfg.Micropub.MediaStorage.MediaURL, "/")
} }
// Check if default blog is set
if a.cfg.DefaultBlog == "" {
if len(a.cfg.Blogs) == 1 {
// Set default blog to the only blog that is configured
for k := range a.cfg.Blogs {
a.cfg.DefaultBlog = k
}
} else {
return errors.New("no default blog configured")
}
}
// Check if default blog exists
if a.cfg.Blogs[a.cfg.DefaultBlog] == nil {
return errors.New("default blog does not exist")
}
// Check media storage config
if ms := a.cfg.Micropub.MediaStorage; ms != nil && ms.MediaURL != "" {
ms.MediaURL = strings.TrimSuffix(ms.MediaURL, "/")
}
// Check if webmention receiving is disabled
if wm := a.cfg.Webmention; wm != nil && wm.DisableReceiving { if wm := a.cfg.Webmention; wm != nil && wm.DisableReceiving {
// Disable comments for all blogs // Disable comments for all blogs
for _, b := range a.cfg.Blogs { for _, b := range a.cfg.Blogs {
@ -380,10 +377,65 @@ func (a *goBlog) initConfig(file string) error {
br.Enabled = false br.Enabled = false
} }
} }
// Log success
a.cfg.initialized = true
log.Println("Initialized configuration") log.Println("Initialized configuration")
return nil return nil
} }
func createDefaultConfig() *config {
return &config{
Server: &configServer{
Port: 8080,
PublicAddress: "http://localhost:8080",
},
Db: &configDb{
File: "data/db.sqlite",
},
Cache: &configCache{
Enable: true,
Expiration: 600,
},
User: &configUser{
Nick: "admin",
Password: "secret",
},
Hooks: &configHooks{
Shell: "/bin/bash",
},
Micropub: &configMicropub{
CategoryParam: "tags",
ReplyParam: "replylink",
ReplyTitleParam: "replytitle",
LikeParam: "likelink",
LikeTitleParam: "liketitle",
BookmarkParam: "link",
AudioParam: "audio",
PhotoParam: "images",
PhotoDescriptionParam: "imagealts",
LocationParam: "location",
},
ActivityPub: &configActivityPub{
TagsTaxonomies: []string{"tags"},
},
}
}
func createDefaultBlog() *configBlog {
return &configBlog{
Path: "/",
Lang: "en",
Title: "My Blog",
Description: "Welcome to my blog.",
Sections: map[string]*configSection{
"posts": {
Title: "Posts",
},
},
DefaultSection: "posts",
}
}
func (a *goBlog) httpsConfigured(checkAddress bool) bool { func (a *goBlog) httpsConfigured(checkAddress bool) bool {
return a.cfg.Server.PublicHTTPS || return a.cfg.Server.PublicHTTPS ||
a.cfg.Server.TailscaleHTTPS || a.cfg.Server.TailscaleHTTPS ||

18
config_test.go Normal file
View File

@ -0,0 +1,18 @@
package main
import (
"context"
"net/http"
"path/filepath"
"testing"
)
func createDefaultTestConfig(t *testing.T) *config {
c := createDefaultConfig()
c.Db.File = filepath.Join(t.TempDir(), "blog.db")
return c
}
func reqWithDefaultBlog(req *http.Request) *http.Request {
return req.WithContext(context.WithValue(req.Context(), blogKey, "default"))
}

3
dbmigrations/00026.sql Normal file
View File

@ -0,0 +1,3 @@
drop table sessions;
create table sessions (id text primary key, data blob, created text default '', modified text default '', expires text default '');
create index sessions_exp on sessions (expires);

View File

@ -23,8 +23,6 @@ server:
securityHeaders: true # Set security HTTP headers (to always use HTTPS etc.) securityHeaders: true # Set security HTTP headers (to always use HTTPS etc.)
cspDomains: # Specify additional domains to allow embedded content with enabled securityHeaders cspDomains: # Specify additional domains to allow embedded content with enabled securityHeaders
- media.example.com - media.example.com
# Cookies
jwtSecret: changeThisWeakSecret # secret to use for cookies (login and captcha)
# Tor # Tor
tor: true # Publish onion service, requires Tor to be installed and available in path tor: true # Publish onion service, requires Tor to be installed and available in path
# Tailscale (see https://tailscale.com) # Tailscale (see https://tailscale.com)

8
go.mod
View File

@ -20,10 +20,9 @@ require (
github.com/emersion/go-smtp v0.15.0 github.com/emersion/go-smtp v0.15.0
github.com/go-chi/chi/v5 v5.0.7 github.com/go-chi/chi/v5 v5.0.7
github.com/go-fed/httpsig v1.1.0 github.com/go-fed/httpsig v1.1.0
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.0 github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
github.com/gorilla/handlers v1.5.1 github.com/gorilla/handlers v1.5.1
github.com/gorilla/securecookie v1.1.1
github.com/gorilla/sessions v1.2.1 github.com/gorilla/sessions v1.2.1
github.com/gorilla/websocket v1.4.2 github.com/gorilla/websocket v1.4.2
github.com/hacdias/indieauth v1.7.1 github.com/hacdias/indieauth v1.7.1
@ -44,7 +43,7 @@ require (
github.com/spf13/cast v1.4.1 github.com/spf13/cast v1.4.1
github.com/spf13/viper v1.10.0 github.com/spf13/viper v1.10.0
github.com/stretchr/testify v1.7.0 github.com/stretchr/testify v1.7.0
github.com/tdewolff/minify/v2 v2.9.22 github.com/tdewolff/minify/v2 v2.9.23
github.com/thoas/go-funk v0.9.1 github.com/thoas/go-funk v0.9.1
github.com/tkrajina/gpxgo v1.1.2 github.com/tkrajina/gpxgo v1.1.2
github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80 github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80
@ -87,6 +86,7 @@ require (
github.com/google/btree v1.0.1 // indirect github.com/google/btree v1.0.1 // indirect
github.com/google/go-cmp v0.5.6 // indirect github.com/google/go-cmp v0.5.6 // indirect
github.com/gorilla/css v1.0.0 // indirect github.com/gorilla/css v1.0.0 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect
github.com/insomniacslk/dhcp v0.0.0-20210621130208-1cac67f12b1e // indirect github.com/insomniacslk/dhcp v0.0.0-20210621130208-1cac67f12b1e // indirect
github.com/jonboulle/clockwork v0.2.2 // indirect github.com/jonboulle/clockwork v0.2.2 // indirect
@ -112,7 +112,7 @@ require (
github.com/tailscale/goupnp v1.0.1-0.20210804011211-c64d0f06ea05 // indirect github.com/tailscale/goupnp v1.0.1-0.20210804011211-c64d0f06ea05 // indirect
github.com/tailscale/netlink v1.1.1-0.20211101221916-cabfb018fe85 // indirect github.com/tailscale/netlink v1.1.1-0.20211101221916-cabfb018fe85 // indirect
github.com/tcnksm/go-httpstat v0.2.0 // indirect github.com/tcnksm/go-httpstat v0.2.0 // indirect
github.com/tdewolff/parse/v2 v2.5.21 // indirect github.com/tdewolff/parse/v2 v2.5.24 // indirect
github.com/u-root/uio v0.0.0-20210528114334-82958018845c // indirect github.com/u-root/uio v0.0.0-20210528114334-82958018845c // indirect
github.com/vishvananda/netlink v1.1.1-0.20211101163509-b10eb8fe5cf6 // indirect github.com/vishvananda/netlink v1.1.1-0.20211101163509-b10eb8fe5cf6 // indirect
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae // indirect github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae // indirect

12
go.sum
View File

@ -139,8 +139,8 @@ github.com/go-playground/validator/v10 v10.2.0 h1:KgJ0snyC2R9VXYN2rneOtQcw5aHQB1
github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI= github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI=
github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA=
github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.0 h1:BtndtqqCQfPsL2uMkYmduOip1+dPcSmh40l82mBUPKk= github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1 h1:wG8n/XJQ07TmjbITcGiUaOtXxdrINDz1b0J1w0SzqDc=
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.0/go.mod h1:A2S0CWkNylc2phvKXWBBdD3K0iGnDBGbzRpISP2zBl8= github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1/go.mod h1:A2S0CWkNylc2phvKXWBBdD3K0iGnDBGbzRpISP2zBl8=
github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0=
github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8= github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8=
@ -387,10 +387,10 @@ github.com/tailscale/netlink v1.1.1-0.20211101221916-cabfb018fe85 h1:zrsUcqrG2uQ
github.com/tailscale/netlink v1.1.1-0.20211101221916-cabfb018fe85/go.mod h1:NzVQi3Mleb+qzq8VmcWpSkcSYxXIg0DkI6XDzpVkhJ0= github.com/tailscale/netlink v1.1.1-0.20211101221916-cabfb018fe85/go.mod h1:NzVQi3Mleb+qzq8VmcWpSkcSYxXIg0DkI6XDzpVkhJ0=
github.com/tcnksm/go-httpstat v0.2.0 h1:rP7T5e5U2HfmOBmZzGgGZjBQ5/GluWUylujl0tJ04I0= github.com/tcnksm/go-httpstat v0.2.0 h1:rP7T5e5U2HfmOBmZzGgGZjBQ5/GluWUylujl0tJ04I0=
github.com/tcnksm/go-httpstat v0.2.0/go.mod h1:s3JVJFtQxtBEBC9dwcdTTXS9xFnM3SXAZwPG41aurT8= github.com/tcnksm/go-httpstat v0.2.0/go.mod h1:s3JVJFtQxtBEBC9dwcdTTXS9xFnM3SXAZwPG41aurT8=
github.com/tdewolff/minify/v2 v2.9.22 h1:PlmaAakaJHdMMdTTwjjsuSwIxKqWPTlvjTj6a/g/ILU= github.com/tdewolff/minify/v2 v2.9.23 h1:UrLltJpnJPm7/fYFP3Ue/GD5tHufx2z7ERQihACLkmg=
github.com/tdewolff/minify/v2 v2.9.22/go.mod h1:dNlaFdXaIxgSXh3UFASqjTY0/xjpDkkCsYHA1NCGnmQ= github.com/tdewolff/minify/v2 v2.9.23/go.mod h1:4o1Mw4T3RLV0CHUny7OEnntezuwoj/FNst4QzrNxIts=
github.com/tdewolff/parse/v2 v2.5.21 h1:s/OLsVxxmQUlbFtPODDVHA836qchgmoxjEsk/cUZl48= github.com/tdewolff/parse/v2 v2.5.24 h1:sJPG5Viy2lq9NBbnK4KpWEA+17RNZz8EOXVqErHKHgs=
github.com/tdewolff/parse/v2 v2.5.21/go.mod h1:WzaJpRSbwq++EIQHYIRTpbYKNA3gn9it1Ik++q4zyho= github.com/tdewolff/parse/v2 v2.5.24/go.mod h1:WzaJpRSbwq++EIQHYIRTpbYKNA3gn9it1Ik++q4zyho=
github.com/tdewolff/test v1.0.6 h1:76mzYJQ83Op284kMT+63iCNCI7NEERsIN8dLM+RiKr4= github.com/tdewolff/test v1.0.6 h1:76mzYJQ83Op284kMT+63iCNCI7NEERsIN8dLM+RiKr4=
github.com/tdewolff/test v1.0.6/go.mod h1:6DAvZliBAAnD7rhVgwaM7DE5/d9NMOAJ09SqYqeK4QE= github.com/tdewolff/test v1.0.6/go.mod h1:6DAvZliBAAnD7rhVgwaM7DE5/d9NMOAJ09SqYqeK4QE=
github.com/thoas/go-funk v0.9.1 h1:O549iLZqPpTUQ10ykd26sZhzD+rmR5pWhuElrhbC20M= github.com/thoas/go-funk v0.9.1 h1:O549iLZqPpTUQ10ykd26sZhzD+rmR5pWhuElrhbC20M=

View File

@ -29,7 +29,7 @@ func newFakeHttpClient() *fakeHttpClient {
} }
// Copy result status code and body // Copy result status code and body
rw.WriteHeader(fc.res.StatusCode) rw.WriteHeader(fc.res.StatusCode)
io.Copy(rw, rec.Body) _, _ = io.Copy(rw, rec.Body)
} }
}), }),
}, },

View File

@ -57,7 +57,11 @@ func main() {
} }
// Initialize config // Initialize config
if err = app.initConfig(*configfile); err != nil { if err = app.loadConfigFile(*configfile); err != nil {
app.logErrAndQuit("Failed to load config file:", err.Error())
return
}
if err = app.initConfig(); err != nil {
app.logErrAndQuit("Failed to init config:", err.Error()) app.logErrAndQuit("Failed to init config:", err.Error())
return return
} }

View File

@ -1,14 +1,16 @@
package main package main
import ( import (
"bytes"
"database/sql" "database/sql"
"fmt" "encoding/gob"
"log" "log"
"net/http" "net/http"
"strings"
"time" "time"
"github.com/araddon/dateparse" "github.com/araddon/dateparse"
"github.com/gorilla/securecookie" "github.com/google/uuid"
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
) )
@ -30,7 +32,6 @@ func (a *goBlog) initSessions() {
deleteExpiredSessions() deleteExpiredSessions()
a.hourlyHooks = append(a.hourlyHooks, deleteExpiredSessions) a.hourlyHooks = append(a.hourlyHooks, deleteExpiredSessions)
a.loginSessions = &dbSessionStore{ a.loginSessions = &dbSessionStore{
codecs: securecookie.CodecsFromPairs([]byte(a.cfg.Server.JWTSecret)),
options: &sessions.Options{ options: &sessions.Options{
Secure: a.httpsConfigured(true), Secure: a.httpsConfigured(true),
HttpOnly: true, HttpOnly: true,
@ -41,7 +42,6 @@ func (a *goBlog) initSessions() {
db: a.db, db: a.db,
} }
a.captchaSessions = &dbSessionStore{ a.captchaSessions = &dbSessionStore{
codecs: securecookie.CodecsFromPairs([]byte(a.cfg.Server.JWTSecret)),
options: &sessions.Options{ options: &sessions.Options{
Secure: a.httpsConfigured(true), Secure: a.httpsConfigured(true),
HttpOnly: true, HttpOnly: true,
@ -55,7 +55,6 @@ func (a *goBlog) initSessions() {
type dbSessionStore struct { type dbSessionStore struct {
options *sessions.Options options *sessions.Options
codecs []securecookie.Codec
db *database db *database
} }
@ -67,29 +66,33 @@ func (s *dbSessionStore) New(r *http.Request, name string) (session *sessions.Se
session = sessions.NewSession(s, name) session = sessions.NewSession(s, name)
opts := *s.options opts := *s.options
session.Options = &opts session.Options = &opts
session.IsNew = true if c, cErr := r.Cookie(name); cErr == nil && strings.HasPrefix(c.Value, session.Name()+"-") {
if cook, errCookie := r.Cookie(name); errCookie == nil { // Has cookie, load from database
if err = securecookie.DecodeMulti(name, cook.Value, &session.ID, s.codecs...); err == nil { session.ID = c.Value
session.IsNew = s.load(session) == nil if s.load(session) != nil {
// Failed to load session from database, delete the ID (= new session)
session.ID = ""
} }
} }
// If no ID, the session is new
session.IsNew = session.ID == ""
return session, err return session, err
} }
func (s *dbSessionStore) Save(r *http.Request, w http.ResponseWriter, ss *sessions.Session) (err error) { func (s *dbSessionStore) Save(r *http.Request, w http.ResponseWriter, ss *sessions.Session) (err error) {
if ss.ID == "" { if ss.ID == "" {
// Is new session, save it to database
if err = s.insert(ss); err != nil { if err = s.insert(ss); err != nil {
return err return err
} }
} else if err = s.save(ss); err != nil {
return err
}
if encoded, err := securecookie.EncodeMulti(ss.Name(), ss.ID, s.codecs...); err != nil {
return err
} else { } else {
http.SetCookie(w, sessions.NewCookie(ss.Name(), encoded, ss.Options)) // Update existing session
return nil if err = s.save(ss); err != nil {
return err
}
} }
http.SetCookie(w, sessions.NewCookie(ss.Name(), ss.ID, ss.Options))
return nil
} }
func (s *dbSessionStore) Delete(r *http.Request, w http.ResponseWriter, session *sessions.Session) error { func (s *dbSessionStore) Delete(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
@ -106,15 +109,20 @@ func (s *dbSessionStore) Delete(r *http.Request, w http.ResponseWriter, session
} }
func (s *dbSessionStore) load(session *sessions.Session) (err error) { func (s *dbSessionStore) load(session *sessions.Session) (err error) {
row, err := s.db.queryRow("select data, created, modified, expires from sessions where id = @id and expires > @now", sql.Named("id", session.ID), sql.Named("now", utcNowString())) row, err := s.db.queryRow(
"select data, created, modified, expires from sessions where id = @id and expires > @now",
sql.Named("id", session.ID),
sql.Named("now", utcNowString()),
)
if err != nil { if err != nil {
return err return err
} }
var data, createdStr, modifiedStr, expiresStr string var createdStr, modifiedStr, expiresStr string
var data []byte
if err = row.Scan(&data, &createdStr, &modifiedStr, &expiresStr); err != nil { if err = row.Scan(&data, &createdStr, &modifiedStr, &expiresStr); err != nil {
return err return err
} }
if err = securecookie.DecodeMulti(session.Name(), data, &session.Values, s.codecs...); err != nil { if err = gob.NewDecoder(bytes.NewReader(data)).Decode(&session.Values); err != nil {
return err return err
} }
session.Values[sessionCreatedOn] = timeNoErr(dateparse.ParseLocal(createdStr)) session.Values[sessionCreatedOn] = timeNoErr(dateparse.ParseLocal(createdStr))
@ -124,44 +132,44 @@ func (s *dbSessionStore) load(session *sessions.Session) (err error) {
} }
func (s *dbSessionStore) insert(session *sessions.Session) (err error) { func (s *dbSessionStore) insert(session *sessions.Session) (err error) {
created := time.Now().UTC() deleteSessionValuesNotNeededForDb(session)
modified := time.Now().UTC() var encoded bytes.Buffer
expires := time.Now().UTC().Add(time.Second * time.Duration(session.Options.MaxAge)) if err := gob.NewEncoder(&encoded).Encode(session.Values); err != nil {
delete(session.Values, sessionCreatedOn)
delete(session.Values, sessionExpiresOn)
delete(session.Values, sessionModifiedOn)
encoded, err := securecookie.EncodeMulti(session.Name(), session.Values, s.codecs...)
if err != nil {
return err return err
} }
res, err := s.db.exec("insert or replace into sessions(data, created, modified, expires) values(@data, @created, @modified, @expires)", session.ID = session.Name() + "-" + uuid.NewString()
sql.Named("data", encoded), sql.Named("created", created.Format(time.RFC3339)), sql.Named("modified", modified.Format(time.RFC3339)), sql.Named("expires", expires.Format(time.RFC3339))) created, modified := utcNowString(), utcNowString()
if err != nil { expires := time.Now().UTC().Add(time.Second * time.Duration(session.Options.MaxAge)).Format(time.RFC3339)
return err _, err = s.db.exec(
} "insert or replace into sessions(id, data, created, modified, expires) values(@id, @data, @created, @modified, @expires)",
lastInserted, err := res.LastInsertId() sql.Named("id", session.ID),
if err != nil { sql.Named("data", encoded.Bytes()),
return err sql.Named("created", created),
} sql.Named("modified", modified),
session.ID = fmt.Sprintf("%d", lastInserted) sql.Named("expires", expires),
return nil )
return err
} }
func (s *dbSessionStore) save(session *sessions.Session) (err error) { func (s *dbSessionStore) save(session *sessions.Session) (err error) {
if session.IsNew { if session.IsNew {
return s.insert(session) return s.insert(session)
} }
delete(session.Values, sessionCreatedOn) deleteSessionValuesNotNeededForDb(session)
delete(session.Values, sessionExpiresOn) var encoded bytes.Buffer
delete(session.Values, sessionModifiedOn) if err = gob.NewEncoder(&encoded).Encode(session.Values); err != nil {
encoded, err := securecookie.EncodeMulti(session.Name(), session.Values, s.codecs...)
if err != nil {
return err return err
} }
_, err = s.db.exec("update sessions set data = @data, modified = @modified where id = @id", _, err = s.db.exec("update sessions set data = @data, modified = @modified where id = @id",
sql.Named("data", encoded), sql.Named("modified", utcNowString()), sql.Named("id", session.ID)) sql.Named("data", encoded.Bytes()), sql.Named("modified", utcNowString()), sql.Named("id", session.ID))
if err != nil { if err != nil {
return err return err
} }
return nil return nil
} }
func deleteSessionValuesNotNeededForDb(session *sessions.Session) {
delete(session.Values, sessionCreatedOn)
delete(session.Values, sessionExpiresOn)
delete(session.Values, sessionModifiedOn)
}

View File

@ -2,7 +2,6 @@ package main
import ( import (
"net/http" "net/http"
"path/filepath"
"testing" "testing"
"time" "time"
@ -11,42 +10,15 @@ import (
) )
func Test_configTelegram_enabled(t *testing.T) { func Test_configTelegram_enabled(t *testing.T) {
if (&configTelegram{}).enabled() == true { assert.False(t, (&configTelegram{}).enabled())
t.Error("Telegram shouldn't be enabled")
}
var tg *configTelegram var tg *configTelegram
if tg.enabled() == true { assert.False(t, tg.enabled())
t.Error("Telegram shouldn't be enabled")
}
if (&configTelegram{ assert.False(t, (&configTelegram{Enabled: true}).enabled())
Enabled: true, assert.False(t, (&configTelegram{Enabled: true, ChatID: "abc"}).enabled())
}).enabled() == true { assert.False(t, (&configTelegram{Enabled: true, BotToken: "abc"}).enabled())
t.Error("Telegram shouldn't be enabled")
}
if (&configTelegram{ assert.True(t, (&configTelegram{Enabled: true, ChatID: "abc", BotToken: "abc"}).enabled())
Enabled: true,
ChatID: "abc",
}).enabled() == true {
t.Error("Telegram shouldn't be enabled")
}
if (&configTelegram{
Enabled: true,
BotToken: "abc",
}).enabled() == true {
t.Error("Telegram shouldn't be enabled")
}
if (&configTelegram{
Enabled: true,
BotToken: "abc",
ChatID: "abc",
}).enabled() != true {
t.Error("Telegram should be enabled")
}
} }
func Test_configTelegram_generateHTML(t *testing.T) { func Test_configTelegram_generateHTML(t *testing.T) {
@ -78,11 +50,11 @@ func Test_configTelegram_send(t *testing.T) {
fakeClient.setHandler(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { fakeClient.setHandler(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
if r.URL.String() == "https://api.telegram.org/botbottoken/getMe" { if r.URL.String() == "https://api.telegram.org/botbottoken/getMe" {
rw.WriteHeader(http.StatusOK) rw.WriteHeader(http.StatusOK)
rw.Write([]byte(`{"ok":true,"result":{"id":123456789,"is_bot":true,"first_name":"Test","username":"testbot"}}`)) _, _ = rw.Write([]byte(`{"ok":true,"result":{"id":123456789,"is_bot":true,"first_name":"Test","username":"testbot"}}`))
return return
} }
rw.WriteHeader(http.StatusOK) rw.WriteHeader(http.StatusOK)
rw.Write([]byte(`{"ok":true,"result":{"message_id":123,"from":{"id":123456789,"is_bot":true,"first_name":"Test","username":"testbot"},"chat":{"id":789,"first_name":"Test","username":"testbot"},"date":1564181818,"text":"Message"}}`)) _, _ = rw.Write([]byte(`{"ok":true,"result":{"message_id":123,"from":{"id":123456789,"is_bot":true,"first_name":"Test","username":"testbot"},"chat":{"id":789,"first_name":"Test","username":"testbot"},"date":1564181818,"text":"Message"}}`))
})) }))
tg := &configTelegram{ tg := &configTelegram{
@ -126,38 +98,31 @@ func Test_goBlog_initTelegram(t *testing.T) {
func Test_telegram(t *testing.T) { func Test_telegram(t *testing.T) {
t.Run("Send post to Telegram", func(t *testing.T) { t.Run("Send post to Telegram", func(t *testing.T) {
fakeClient := newFakeHttpClient() fakeClient := newFakeHttpClient()
fakeClient.setHandler(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { fakeClient.setHandler(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
if r.URL.String() == "https://api.telegram.org/botbottoken/getMe" { if r.URL.String() == "https://api.telegram.org/botbottoken/getMe" {
rw.WriteHeader(http.StatusOK) rw.WriteHeader(http.StatusOK)
rw.Write([]byte(`{"ok":true,"result":{"id":123456789,"is_bot":true,"first_name":"Test","username":"testbot"}}`)) _, _ = rw.Write([]byte(`{"ok":true,"result":{"id":123456789,"is_bot":true,"first_name":"Test","username":"testbot"}}`))
return return
} }
rw.WriteHeader(http.StatusOK) rw.WriteHeader(http.StatusOK)
rw.Write([]byte(`{"ok":true,"result":{"message_id":123,"from":{"id":123456789,"is_bot":true,"first_name":"Test","username":"testbot"},"chat":{"id":123456789,"first_name":"Test","username":"testbot"},"date":1564181818,"text":"Message"}}`)) _, _ = rw.Write([]byte(`{"ok":true,"result":{"message_id":123,"from":{"id":123456789,"is_bot":true,"first_name":"Test","username":"testbot"},"chat":{"id":123456789,"first_name":"Test","username":"testbot"},"date":1564181818,"text":"Message"}}`))
})) }))
cfg := createDefaultTestConfig(t)
cfg.Blogs = map[string]*configBlog{
"en": createDefaultBlog(),
}
cfg.Blogs["en"].Telegram = &configTelegram{
Enabled: true,
ChatID: "chatid",
BotToken: "bottoken",
}
app := &goBlog{ app := &goBlog{
pPostHooks: []postHookFunc{}, cfg: cfg,
cfg: &config{
Db: &configDb{
File: filepath.Join(t.TempDir(), "test.db"),
},
Server: &configServer{
PublicAddress: "https://example.com",
},
Blogs: map[string]*configBlog{
"en": {
Telegram: &configTelegram{
Enabled: true,
ChatID: "chatid",
BotToken: "bottoken",
},
},
},
},
httpClient: fakeClient.Client, httpClient: fakeClient.Client,
} }
_ = app.initConfig()
_ = app.initDatabase(false) _ = app.initDatabase(false)
app.initMarkdown() app.initMarkdown()
@ -179,43 +144,32 @@ func Test_telegram(t *testing.T) {
req := fakeClient.req req := fakeClient.req
assert.Equal(t, "chatid", req.FormValue("chat_id")) assert.Equal(t, "chatid", req.FormValue("chat_id"))
assert.Equal(t, "HTML", req.FormValue("parse_mode")) assert.Equal(t, "HTML", req.FormValue("parse_mode"))
assert.Equal(t, "Title\n\n<a href=\"https://example.com/s/1\">https://example.com/s/1</a>", req.FormValue("text")) assert.Equal(t, "Title\n\n<a href=\"http://localhost:8080/s/1\">http://localhost:8080/s/1</a>", req.FormValue("text"))
}) })
t.Run("Telegram disabled", func(t *testing.T) { t.Run("Telegram disabled", func(t *testing.T) {
fakeClient := newFakeHttpClient() fakeClient := newFakeHttpClient()
app := &goBlog{ app := &goBlog{
pPostHooks: []postHookFunc{}, cfg: createDefaultTestConfig(t),
cfg: &config{
Db: &configDb{
File: filepath.Join(t.TempDir(), "test.db"),
},
Server: &configServer{
PublicAddress: "https://example.com",
},
Blogs: map[string]*configBlog{
"en": {},
},
},
httpClient: fakeClient.Client, httpClient: fakeClient.Client,
} }
_ = app.initConfig()
_ = app.initDatabase(false) _ = app.initDatabase(false)
app.initTelegram() app.initTelegram()
p := &post{ app.postPostHooks(&post{
Path: "/test", Path: "/test",
Parameters: map[string][]string{ Parameters: map[string][]string{
"title": {"Title"}, "title": {"Title"},
}, },
Published: time.Now().String(), Published: time.Now().String(),
Section: "test", Section: "test",
Blog: "en", Blog: "default",
Status: statusPublished, Status: statusPublished,
} })
app.pPostHooks[0](p)
assert.Nil(t, fakeClient.req) assert.Nil(t, fakeClient.req)
}) })