diff --git a/.vscode/tasks.json b/.vscode/tasks.json index a2a0a4b..ea6c6ac 100644 --- a/.vscode/tasks.json +++ b/.vscode/tasks.json @@ -1,15 +1,33 @@ { - "version": "2.0.0", - "tasks": [ - { - "label": "Build", - "type": "shell", - "command": "go build --tags \"libsqlite3 linux sqlite_fts5\"", - "problemMatcher": [], - "group": { - "kind": "build", - "isDefault": true + "version": "2.0.0", + "tasks": [ + { + "label": "Build", + "type": "shell", + "command": "go build", + "options": { + "env": { + "GOFLAGS": "-tags=linux,libsqlite3,sqlite_fts5" } + }, + "group": { + "kind": "build", + "isDefault": true } - ] - } \ No newline at end of file + }, + { + "label": "Test", + "type": "shell", + "command": "go test", + "options": { + "env": { + "GOFLAGS": "-tags=linux,libsqlite3,sqlite_fts5" + } + }, + "group": { + "kind": "test", + "isDefault": true + } + } + ] +} \ No newline at end of file diff --git a/activityPub.go b/activityPub.go index 1612be5..8d97b57 100644 --- a/activityPub.go +++ b/activityPub.go @@ -1,7 +1,6 @@ package main import ( - "crypto/rsa" "crypto/x509" "database/sql" "encoding/json" @@ -14,50 +13,41 @@ import ( "net/url" "os" "strings" - "sync" "time" "github.com/go-chi/chi/v5" "github.com/go-fed/httpsig" ) -var ( - apPrivateKey *rsa.PrivateKey - apPostSigner httpsig.Signer - apPostSignMutex *sync.Mutex = &sync.Mutex{} - webfingerResources map[string]*configBlog - webfingerAccts map[string]string -) - -func initActivityPub() error { - if !appConfig.ActivityPub.Enabled { +func (a *goBlog) initActivityPub() error { + if !a.cfg.ActivityPub.Enabled { return nil } // Add hooks - postPostHooks = append(postPostHooks, func(p *post) { + a.pPostHooks = append(a.pPostHooks, func(p *post) { if p.isPublishedSectionPost() { - p.apPost() + a.apPost(p) } }) - postUpdateHooks = append(postUpdateHooks, func(p *post) { + a.pUpdateHooks = append(a.pUpdateHooks, func(p *post) { if p.isPublishedSectionPost() { - p.apUpdate() + a.apUpdate(p) } }) - postDeleteHooks = append(postDeleteHooks, func(p *post) { - p.apDelete() + a.pDeleteHooks = append(a.pDeleteHooks, func(p *post) { + a.apDelete(p) }) // Prepare webfinger - webfingerResources = map[string]*configBlog{} - webfingerAccts = map[string]string{} - for name, blog := range appConfig.Blogs { - acct := "acct:" + name + "@" + appConfig.Server.publicHostname - webfingerResources[acct] = blog - webfingerResources[blog.apIri()] = blog - webfingerAccts[blog.apIri()] = acct + a.webfingerResources = map[string]*configBlog{} + a.webfingerAccts = map[string]string{} + for name, blog := range a.cfg.Blogs { + acct := "acct:" + name + "@" + a.cfg.Server.publicHostname + a.webfingerResources[acct] = blog + a.webfingerResources[a.apIri(blog)] = blog + a.webfingerAccts[a.apIri(blog)] = acct } // Read key and prepare signing - pkfile, err := os.ReadFile(appConfig.ActivityPub.KeyPath) + pkfile, err := os.ReadFile(a.cfg.ActivityPub.KeyPath) if err != nil { return err } @@ -65,11 +55,11 @@ func initActivityPub() error { if privateKeyDecoded == nil { return errors.New("failed to decode private key") } - apPrivateKey, err = x509.ParsePKCS1PrivateKey(privateKeyDecoded.Bytes) + a.apPrivateKey, err = x509.ParsePKCS1PrivateKey(privateKeyDecoded.Bytes) if err != nil { return err } - apPostSigner, _, err = httpsig.NewSigner( + a.apPostSigner, _, err = httpsig.NewSigner( []httpsig.Algorithm{httpsig.RSA_SHA256}, httpsig.DigestSha256, []string{httpsig.RequestTarget, "date", "host", "digest"}, @@ -80,32 +70,32 @@ func initActivityPub() error { return err } // Init send queue - initAPSendQueue() + a.initAPSendQueue() return nil } -func apHandleWebfinger(w http.ResponseWriter, r *http.Request) { - blog, ok := webfingerResources[r.URL.Query().Get("resource")] +func (a *goBlog) apHandleWebfinger(w http.ResponseWriter, r *http.Request) { + blog, ok := a.webfingerResources[r.URL.Query().Get("resource")] if !ok { - serveError(w, r, "Resource not found", http.StatusNotFound) + a.serveError(w, r, "Resource not found", http.StatusNotFound) return } b, _ := json.Marshal(map[string]interface{}{ - "subject": webfingerAccts[blog.apIri()], + "subject": a.webfingerAccts[a.apIri(blog)], "aliases": []string{ - webfingerAccts[blog.apIri()], - blog.apIri(), + a.webfingerAccts[a.apIri(blog)], + a.apIri(blog), }, "links": []map[string]string{ { "rel": "self", "type": contentTypeAS, - "href": blog.apIri(), + "href": a.apIri(blog), }, { "rel": "http://webfinger.net/rel/profile-page", "type": "text/html", - "href": blog.apIri(), + "href": a.apIri(blog), }, }, }) @@ -113,19 +103,19 @@ func apHandleWebfinger(w http.ResponseWriter, r *http.Request) { _, _ = writeMinified(w, contentTypeJSON, b) } -func apHandleInbox(w http.ResponseWriter, r *http.Request) { +func (a *goBlog) apHandleInbox(w http.ResponseWriter, r *http.Request) { blogName := chi.URLParam(r, "blog") - blog := appConfig.Blogs[blogName] + blog := a.cfg.Blogs[blogName] if blog == nil { - serveError(w, r, "Inbox not found", http.StatusNotFound) + a.serveError(w, r, "Inbox not found", http.StatusNotFound) return } - blogIri := blog.apIri() + blogIri := a.apIri(blog) // Verify request requestActor, requestKey, requestActorStatus, err := apVerifySignature(r) if err != nil { // Send 401 because signature could not be verified - serveError(w, r, err.Error(), http.StatusUnauthorized) + a.serveError(w, r, err.Error(), http.StatusUnauthorized) return } if requestActorStatus != 0 { @@ -134,12 +124,12 @@ func apHandleInbox(w http.ResponseWriter, r *http.Request) { if err == nil { u.Fragment = "" u.RawFragment = "" - _ = apRemoveFollower(blogName, u.String()) + _ = a.db.apRemoveFollower(blogName, u.String()) w.WriteHeader(http.StatusOK) return } } - serveError(w, r, "Error when trying to get request actor", http.StatusBadRequest) + a.serveError(w, r, "Error when trying to get request actor", http.StatusBadRequest) return } // Parse activity @@ -147,29 +137,29 @@ func apHandleInbox(w http.ResponseWriter, r *http.Request) { err = json.NewDecoder(r.Body).Decode(&activity) _ = r.Body.Close() if err != nil { - serveError(w, r, "Failed to decode body", http.StatusBadRequest) + a.serveError(w, r, "Failed to decode body", http.StatusBadRequest) return } // Get and check activity actor activityActor, ok := activity["actor"].(string) if !ok { - serveError(w, r, "actor in activity is no string", http.StatusBadRequest) + a.serveError(w, r, "actor in activity is no string", http.StatusBadRequest) return } if activityActor != requestActor.ID { - serveError(w, r, "Request actor isn't activity actor", http.StatusForbidden) + a.serveError(w, r, "Request actor isn't activity actor", http.StatusForbidden) return } // Do switch activity["type"] { case "Follow": - apAccept(blogName, blog, activity) + a.apAccept(blogName, blog, activity) case "Undo": { if object, ok := activity["object"].(map[string]interface{}); ok { if objectType, ok := object["type"].(string); ok && objectType == "Follow" { if iri, ok := object["actor"].(string); ok && iri == activityActor { - _ = apRemoveFollower(blogName, activityActor) + _ = a.db.apRemoveFollower(blogName, activityActor) } } } @@ -181,13 +171,13 @@ func apHandleInbox(w http.ResponseWriter, r *http.Request) { id, hasID := object["id"].(string) if hasReplyToString && hasID && len(inReplyTo) > 0 && len(id) > 0 && strings.Contains(inReplyTo, blogIri) { // It's an ActivityPub reply; save reply as webmention - _ = createWebmention(id, inReplyTo) + _ = a.createWebmention(id, inReplyTo) } else if content, hasContent := object["content"].(string); hasContent && hasID && len(id) > 0 { // May be a mention; find links to blog and save them as webmentions if links, err := allLinksFromHTML(strings.NewReader(content), id); err == nil { for _, link := range links { if strings.Contains(link, blogIri) { - _ = createWebmention(id, link) + _ = a.createWebmention(id, link) } } } @@ -198,21 +188,21 @@ func apHandleInbox(w http.ResponseWriter, r *http.Request) { case "Block": { if object, ok := activity["object"].(string); ok && len(object) > 0 && object == activityActor { - _ = apRemoveFollower(blogName, activityActor) + _ = a.db.apRemoveFollower(blogName, activityActor) } } case "Like": { likeObject, likeObjectOk := activity["object"].(string) if likeObjectOk && len(likeObject) > 0 && strings.Contains(likeObject, blogIri) { - sendNotification(fmt.Sprintf("%s liked %s", activityActor, likeObject)) + a.sendNotification(fmt.Sprintf("%s liked %s", activityActor, likeObject)) } } case "Announce": { announceObject, announceObjectOk := activity["object"].(string) if announceObjectOk && len(announceObject) > 0 && strings.Contains(announceObject, blogIri) { - sendNotification(fmt.Sprintf("%s announced %s", activityActor, announceObject)) + a.sendNotification(fmt.Sprintf("%s announced %s", activityActor, announceObject)) } } } @@ -277,8 +267,8 @@ func apGetRemoteActor(iri string) (*asPerson, int, error) { return actor, 0, nil } -func apGetAllInboxes(blog string) ([]string, error) { - rows, err := appDb.query("select distinct inbox from activitypub_followers where blog = @blog", sql.Named("blog", blog)) +func (db *database) apGetAllInboxes(blog string) ([]string, error) { + rows, err := db.query("select distinct inbox from activitypub_followers where blog = @blog", sql.Named("blog", blog)) if err != nil { return nil, err } @@ -294,27 +284,27 @@ func apGetAllInboxes(blog string) ([]string, error) { return inboxes, nil } -func apAddFollower(blog, follower, inbox string) error { - _, err := appDb.exec("insert or replace into activitypub_followers (blog, follower, inbox) values (@blog, @follower, @inbox)", sql.Named("blog", blog), sql.Named("follower", follower), sql.Named("inbox", inbox)) +func (db *database) apAddFollower(blog, follower, inbox string) error { + _, err := db.exec("insert or replace into activitypub_followers (blog, follower, inbox) values (@blog, @follower, @inbox)", sql.Named("blog", blog), sql.Named("follower", follower), sql.Named("inbox", inbox)) return err } -func apRemoveFollower(blog, follower string) error { - _, err := appDb.exec("delete from activitypub_followers where blog = @blog and follower = @follower", sql.Named("blog", blog), sql.Named("follower", follower)) +func (db *database) apRemoveFollower(blog, follower string) error { + _, err := db.exec("delete from activitypub_followers where blog = @blog and follower = @follower", sql.Named("blog", blog), sql.Named("follower", follower)) return err } -func apRemoveInbox(inbox string) error { - _, err := appDb.exec("delete from activitypub_followers where inbox = @inbox", sql.Named("inbox", inbox)) +func (db *database) apRemoveInbox(inbox string) error { + _, err := db.exec("delete from activitypub_followers where inbox = @inbox", sql.Named("inbox", inbox)) return err } -func (p *post) apPost() { - n := p.toASNote() - apSendToAllFollowers(p.Blog, map[string]interface{}{ +func (a *goBlog) apPost(p *post) { + n := a.toASNote(p) + a.apSendToAllFollowers(p.Blog, map[string]interface{}{ "@context": asContext, - "actor": appConfig.Blogs[p.Blog].apIri(), - "id": p.fullURL(), + "actor": a.apIri(a.cfg.Blogs[p.Blog]), + "id": a.fullPostURL(p), "published": n.Published, "type": "Create", "object": n, @@ -322,46 +312,46 @@ func (p *post) apPost() { if n.InReplyTo != "" { // Is reply, so announce it time.Sleep(30 * time.Second) - p.apAnnounce() + a.apAnnounce(p) } } -func (p *post) apUpdate() { - apSendToAllFollowers(p.Blog, map[string]interface{}{ +func (a *goBlog) apUpdate(p *post) { + a.apSendToAllFollowers(p.Blog, map[string]interface{}{ "@context": asContext, - "actor": appConfig.Blogs[p.Blog].apIri(), - "id": p.fullURL(), + "actor": a.apIri(a.cfg.Blogs[p.Blog]), + "id": a.fullPostURL(p), "published": time.Now().Format("2006-01-02T15:04:05-07:00"), "type": "Update", - "object": p.toASNote(), + "object": a.toASNote(p), }) } -func (p *post) apAnnounce() { - apSendToAllFollowers(p.Blog, map[string]interface{}{ +func (a *goBlog) apAnnounce(p *post) { + a.apSendToAllFollowers(p.Blog, map[string]interface{}{ "@context": asContext, - "actor": appConfig.Blogs[p.Blog].apIri(), - "id": p.fullURL() + "#announce", - "published": p.toASNote().Published, + "actor": a.apIri(a.cfg.Blogs[p.Blog]), + "id": a.fullPostURL(p) + "#announce", + "published": a.toASNote(p).Published, "type": "Announce", - "object": p.fullURL(), + "object": a.fullPostURL(p), }) } -func (p *post) apDelete() { - apSendToAllFollowers(p.Blog, map[string]interface{}{ +func (a *goBlog) apDelete(p *post) { + a.apSendToAllFollowers(p.Blog, map[string]interface{}{ "@context": asContext, - "actor": appConfig.Blogs[p.Blog].apIri(), - "id": p.fullURL() + "#delete", + "actor": a.apIri(a.cfg.Blogs[p.Blog]), + "id": a.fullPostURL(p) + "#delete", "type": "Delete", "object": map[string]string{ - "id": p.fullURL(), + "id": a.fullPostURL(p), "type": "Tombstone", }, }) } -func apAccept(blogName string, blog *configBlog, follow map[string]interface{}) { +func (a *goBlog) apAccept(blogName string, blog *configBlog, follow map[string]interface{}) { // it's a follow, write it down newFollower := follow["actor"].(string) log.Println("New follow request:", newFollower) @@ -381,7 +371,7 @@ func apAccept(blogName string, blog *configBlog, follow map[string]interface{}) if endpoints := follower.Endpoints; endpoints != nil && endpoints.SharedInbox != "" { inbox = endpoints.SharedInbox } - if err = apAddFollower(blogName, follower.ID, inbox); err != nil { + if err = a.db.apAddFollower(blogName, follower.ID, inbox); err != nil { return } // remove @context from the inner activity @@ -389,37 +379,37 @@ func apAccept(blogName string, blog *configBlog, follow map[string]interface{}) accept := map[string]interface{}{ "@context": asContext, "to": follow["actor"], - "actor": blog.apIri(), + "actor": a.apIri(blog), "object": follow, "type": "Accept", } - _, accept["id"] = apNewID(blog) - _ = apQueueSendSigned(blog.apIri(), follower.Inbox, accept) + _, accept["id"] = a.apNewID(blog) + _ = a.db.apQueueSendSigned(a.apIri(blog), follower.Inbox, accept) } -func apSendToAllFollowers(blog string, activity interface{}) { - inboxes, err := apGetAllInboxes(blog) +func (a *goBlog) apSendToAllFollowers(blog string, activity interface{}) { + inboxes, err := a.db.apGetAllInboxes(blog) if err != nil { log.Println("Failed to retrieve inboxes:", err.Error()) return } - apSendTo(appConfig.Blogs[blog].apIri(), activity, inboxes) + a.db.apSendTo(a.apIri(a.cfg.Blogs[blog]), activity, inboxes) } -func apSendTo(blogIri string, activity interface{}, inboxes []string) { +func (db *database) apSendTo(blogIri string, activity interface{}, inboxes []string) { for _, i := range inboxes { go func(inbox string) { - _ = apQueueSendSigned(blogIri, inbox, activity) + _ = db.apQueueSendSigned(blogIri, inbox, activity) }(i) } } -func apNewID(blog *configBlog) (hash string, url string) { - return hash, blog.apIri() + generateRandomString(16) +func (a *goBlog) apNewID(blog *configBlog) (hash string, url string) { + return hash, a.apIri(blog) + generateRandomString(16) } -func (b *configBlog) apIri() string { - return appConfig.Server.PublicAddress + b.Path +func (a *goBlog) apIri(b *configBlog) string { + return a.cfg.Server.PublicAddress + b.Path } func apRequestIsSuccess(code int) bool { diff --git a/activityPubSending.go b/activityPubSending.go index 513d73a..6a146e4 100644 --- a/activityPubSending.go +++ b/activityPubSending.go @@ -19,10 +19,10 @@ type apRequest struct { Try int } -func initAPSendQueue() { +func (a *goBlog) initAPSendQueue() { go func() { for { - qi, err := peekQueue("ap") + qi, err := a.db.peekQueue("ap") if err != nil { log.Println(err.Error()) continue @@ -31,22 +31,22 @@ func initAPSendQueue() { err = gob.NewDecoder(bytes.NewReader(qi.content)).Decode(&r) if err != nil { log.Println(err.Error()) - _ = qi.dequeue() + _ = a.db.dequeue(qi) continue } - if err := apSendSigned(r.BlogIri, r.To, r.Activity); err != nil { + if err := a.apSendSigned(r.BlogIri, r.To, r.Activity); err != nil { if r.Try++; r.Try < 20 { // Try it again qi.content, _ = r.encode() - _ = qi.reschedule(time.Duration(r.Try) * 10 * time.Minute) + _ = a.db.reschedule(qi, time.Duration(r.Try)*10*time.Minute) continue } else { log.Printf("Request to %s failed for the 20th time", r.To) log.Println() - _ = apRemoveInbox(r.To) + _ = a.db.apRemoveInbox(r.To) } } - err = qi.dequeue() + err = a.db.dequeue(qi) if err != nil { log.Println(err.Error()) } @@ -58,7 +58,7 @@ func initAPSendQueue() { }() } -func apQueueSendSigned(blogIri, to string, activity interface{}) error { +func (db *database) apQueueSendSigned(blogIri, to string, activity interface{}) error { body, err := json.Marshal(activity) if err != nil { return err @@ -71,7 +71,7 @@ func apQueueSendSigned(blogIri, to string, activity interface{}) error { if err != nil { return err } - return enqueue("ap", b, time.Now()) + return db.enqueue("ap", b, time.Now()) } func (r *apRequest) encode() ([]byte, error) { @@ -83,7 +83,7 @@ func (r *apRequest) encode() ([]byte, error) { return buf.Bytes(), nil } -func apSendSigned(blogIri, to string, activity []byte) error { +func (a *goBlog) apSendSigned(blogIri, to string, activity []byte) error { // Create request context with timeout ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() @@ -105,9 +105,9 @@ func apSendSigned(blogIri, to string, activity []byte) error { r.Header.Set(contentType, contentTypeASUTF8) r.Header.Set("Host", iri.Host) // Sign request - apPostSignMutex.Lock() - err = apPostSigner.SignRequest(apPrivateKey, blogIri+"#main-key", r, activity) - apPostSignMutex.Unlock() + a.apPostSignMutex.Lock() + err = a.apPostSigner.SignRequest(a.apPrivateKey, blogIri+"#main-key", r, activity) + a.apPostSignMutex.Unlock() if err != nil { return err } diff --git a/activityStreams.go b/activityStreams.go index 3134126..65e559b 100644 --- a/activityStreams.go +++ b/activityStreams.go @@ -22,9 +22,9 @@ var asCheckMediaTypes = []contenttype.MediaType{ const asRequestKey requestContextKey = "asRequest" -func checkActivityStreamsRequest(next http.Handler) http.Handler { +func (a *goBlog) checkActivityStreamsRequest(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - if ap := appConfig.ActivityPub; ap != nil && ap.Enabled { + if ap := a.cfg.ActivityPub; ap != nil && ap.Enabled { // Check if accepted media type is not HTML if mt, _, err := contenttype.GetAcceptableMediaType(r, asCheckMediaTypes); err == nil && mt.String() != asCheckMediaTypes[0].String() { next.ServeHTTP(rw, r.WithContext(context.WithValue(r.Context(), asRequestKey, true))) @@ -87,21 +87,21 @@ type asEndpoints struct { SharedInbox string `json:"sharedInbox,omitempty"` } -func (p *post) serveActivityStreams(w http.ResponseWriter) { - b, _ := json.Marshal(p.toASNote()) +func (a *goBlog) serveActivityStreamsPost(p *post, w http.ResponseWriter) { + b, _ := json.Marshal(a.toASNote(p)) w.Header().Set(contentType, contentTypeASUTF8) _, _ = writeMinified(w, contentTypeAS, b) } -func (p *post) toASNote() *asNote { +func (a *goBlog) toASNote(p *post) *asNote { // Create a Note object as := &asNote{ Context: asContext, To: []string{"https://www.w3.org/ns/activitystreams#Public"}, MediaType: contentTypeHTML, - ID: p.fullURL(), - URL: p.fullURL(), - AttributedTo: appConfig.Blogs[p.Blog].apIri(), + ID: a.fullPostURL(p), + URL: a.fullPostURL(p), + AttributedTo: a.apIri(a.cfg.Blogs[p.Blog]), } // Name and Type if title := p.title(); title != "" { @@ -111,9 +111,9 @@ func (p *post) toASNote() *asNote { as.Type = "Note" } // Content - as.Content = string(p.absoluteHTML()) + as.Content = string(a.absoluteHTML(p)) // Attachments - if images := p.Parameters[appConfig.Micropub.PhotoParam]; len(images) > 0 { + if images := p.Parameters[a.cfg.Micropub.PhotoParam]; len(images) > 0 { for _, image := range images { as.Attachment = append(as.Attachment, &asAttachment{ Type: "Image", @@ -122,12 +122,12 @@ func (p *post) toASNote() *asNote { } } // Tags - for _, tagTax := range appConfig.ActivityPub.TagsTaxonomies { + for _, tagTax := range a.cfg.ActivityPub.TagsTaxonomies { for _, tag := range p.Parameters[tagTax] { as.Tag = append(as.Tag, &asTag{ Type: "Hashtag", Name: tag, - Href: appConfig.Server.PublicAddress + appConfig.Blogs[p.Blog].getRelativePath(fmt.Sprintf("/%s/%s", tagTax, urlize(tag))), + Href: a.cfg.Server.PublicAddress + a.cfg.Blogs[p.Blog].getRelativePath(fmt.Sprintf("/%s/%s", tagTax, urlize(tag))), }) } } @@ -144,30 +144,31 @@ func (p *post) toASNote() *asNote { } } // Reply - if replyLink := p.firstParameter(appConfig.Micropub.ReplyParam); replyLink != "" { + if replyLink := p.firstParameter(a.cfg.Micropub.ReplyParam); replyLink != "" { as.InReplyTo = replyLink } return as } -func (b *configBlog) serveActivityStreams(blog string, w http.ResponseWriter, r *http.Request) { - publicKeyDer, err := x509.MarshalPKIXPublicKey(&apPrivateKey.PublicKey) +func (a *goBlog) serveActivityStreams(blog string, w http.ResponseWriter, r *http.Request) { + b := a.cfg.Blogs[blog] + publicKeyDer, err := x509.MarshalPKIXPublicKey(&(a.apPrivateKey.PublicKey)) if err != nil { - serveError(w, r, "Failed to marshal public key", http.StatusInternalServerError) + a.serveError(w, r, "Failed to marshal public key", http.StatusInternalServerError) return } asBlog := &asPerson{ Context: asContext, Type: "Person", - ID: b.apIri(), - URL: b.apIri(), + ID: a.apIri(b), + URL: a.apIri(b), Name: b.Title, Summary: b.Description, PreferredUsername: blog, - Inbox: appConfig.Server.PublicAddress + "/activitypub/inbox/" + blog, + Inbox: a.cfg.Server.PublicAddress + "/activitypub/inbox/" + blog, PublicKey: &asPublicKey{ - Owner: b.apIri(), - ID: b.apIri() + "#main-key", + Owner: a.apIri(b), + ID: a.apIri(b) + "#main-key", PublicKeyPem: string(pem.EncodeToMemory(&pem.Block{ Type: "PUBLIC KEY", Headers: nil, @@ -176,10 +177,10 @@ func (b *configBlog) serveActivityStreams(blog string, w http.ResponseWriter, r }, } // Add profile picture - if appConfig.User.Picture != "" { + if a.cfg.User.Picture != "" { asBlog.Icon = &asAttachment{ Type: "Image", - URL: appConfig.User.Picture, + URL: a.cfg.User.Picture, } } jb, _ := json.Marshal(asBlog) diff --git a/app.go b/app.go new file mode 100644 index 0000000..fe964c1 --- /dev/null +++ b/app.go @@ -0,0 +1,73 @@ +package main + +import ( + "crypto/rsa" + "html/template" + "net/http" + "sync" + + ts "git.jlel.se/jlelse/template-strings" + "github.com/go-chi/chi/v5" + "github.com/go-fed/httpsig" + rotatelogs "github.com/lestrrat-go/file-rotatelogs" + "github.com/yuin/goldmark" + "golang.org/x/sync/singleflight" +) + +type goBlog struct { + // ActivityPub + apPrivateKey *rsa.PrivateKey + apPostSigner httpsig.Signer + apPostSignMutex sync.Mutex + webfingerResources map[string]*configBlog + webfingerAccts map[string]string + // Assets + assetFileNames map[string]string + assetFiles map[string]*assetFile + // Blogroll + blogrollCacheGroup singleflight.Group + // Cache + cache *cache + // Config + cfg *config + // Database + db *database + // Hooks + pPostHooks []postHookFunc + pUpdateHooks []postHookFunc + pDeleteHooks []postHookFunc + // HTTP + d *dynamicHandler + privateMode bool + privateModeHandler []func(http.Handler) http.Handler + captchaHandler http.Handler + micropubRouter *chi.Mux + indieAuthRouter *chi.Mux + webmentionsRouter *chi.Mux + notificationsRouter *chi.Mux + activitypubRouter *chi.Mux + editorRouter *chi.Mux + commentsRouter *chi.Mux + searchRouter *chi.Mux + setBlogMiddlewares map[string]func(http.Handler) http.Handler + sectionMiddlewares map[string]func(http.Handler) http.Handler + taxonomyMiddlewares map[string]func(http.Handler) http.Handler + photosMiddlewares map[string]func(http.Handler) http.Handler + searchMiddlewares map[string]func(http.Handler) http.Handler + customPagesMiddlewares map[string]func(http.Handler) http.Handler + commentsMiddlewares map[string]func(http.Handler) http.Handler + // Logs + logf *rotatelogs.RotateLogs + // Markdown + md, absoluteMd goldmark.Markdown + // Regex Redirects + regexRedirects []*regexRedirect + // Rendering + templates map[string]*template.Template + // Sessions + loginSessions, captchaSessions *dbSessionStore + // Template strings + ts *ts.TemplateStrings + // Tor + torAddress string +} diff --git a/authentication.go b/authentication.go index cf5a3f5..7e132d3 100644 --- a/authentication.go +++ b/authentication.go @@ -11,14 +11,14 @@ import ( "github.com/pquerna/otp/totp" ) -func checkCredentials(username, password, totpPasscode string) bool { - return username == appConfig.User.Nick && - password == appConfig.User.Password && - (appConfig.User.TOTP == "" || totp.Validate(totpPasscode, appConfig.User.TOTP)) +func (a *goBlog) checkCredentials(username, password, totpPasscode string) bool { + return username == a.cfg.User.Nick && + password == a.cfg.User.Password && + (a.cfg.User.TOTP == "" || totp.Validate(totpPasscode, a.cfg.User.TOTP)) } -func checkAppPasswords(username, password string) bool { - for _, apw := range appConfig.User.AppPasswords { +func (a *goBlog) checkAppPasswords(username, password string) bool { + for _, apw := range a.cfg.User.AppPasswords { if apw.Username == username && apw.Password == password { return true } @@ -26,11 +26,11 @@ func checkAppPasswords(username, password string) bool { return false } -func jwtKey() []byte { - return []byte(appConfig.Server.JWTSecret) +func (a *goBlog) jwtKey() []byte { + return []byte(a.cfg.Server.JWTSecret) } -func authMiddleware(next http.Handler) http.Handler { +func (a *goBlog) authMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // 1. Check if already logged in if loggedIn, ok := r.Context().Value(loggedInKey).(bool); ok && loggedIn { @@ -38,12 +38,12 @@ func authMiddleware(next http.Handler) http.Handler { return } // 2. Check BasicAuth (just for app passwords) - if username, password, ok := r.BasicAuth(); ok && checkAppPasswords(username, password) { + if username, password, ok := r.BasicAuth(); ok && a.checkAppPasswords(username, password) { next.ServeHTTP(w, r) return } // 3. Check login cookie - if checkLoginCookie(r) { + if a.checkLoginCookie(r) { next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), loggedInKey, true))) return } @@ -57,12 +57,12 @@ func authMiddleware(next http.Handler) http.Handler { _ = r.ParseForm() b = []byte(r.PostForm.Encode()) } - render(w, r, templateLogin, &renderData{ + a.render(w, r, templateLogin, &renderData{ Data: map[string]interface{}{ "loginmethod": r.Method, "loginheaders": base64.StdEncoding.EncodeToString(h), "loginbody": base64.StdEncoding.EncodeToString(b), - "totp": appConfig.User.TOTP != "", + "totp": a.cfg.User.TOTP != "", }, }) }) @@ -70,9 +70,9 @@ func authMiddleware(next http.Handler) http.Handler { const loggedInKey requestContextKey = "loggedIn" -func checkLoggedIn(next http.Handler) http.Handler { +func (a *goBlog) checkLoggedIn(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - if checkLoginCookie(r) { + if a.checkLoginCookie(r) { next.ServeHTTP(rw, r.WithContext(context.WithValue(r.Context(), loggedInKey, true))) return } @@ -80,8 +80,8 @@ func checkLoggedIn(next http.Handler) http.Handler { }) } -func checkLoginCookie(r *http.Request) bool { - ses, err := loginSessionsStore.Get(r, "l") +func (a *goBlog) checkLoginCookie(r *http.Request) bool { + ses, err := a.loginSessions.Get(r, "l") if err == nil && ses != nil { if login, ok := ses.Values["login"]; ok && login.(bool) { return true @@ -90,15 +90,15 @@ func checkLoginCookie(r *http.Request) bool { return false } -func checkIsLogin(next http.Handler) http.Handler { +func (a *goBlog) checkIsLogin(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - if !checkLogin(rw, r) { + if !a.checkLogin(rw, r) { next.ServeHTTP(rw, r) } }) } -func checkLogin(w http.ResponseWriter, r *http.Request) bool { +func (a *goBlog) checkLogin(w http.ResponseWriter, r *http.Request) bool { if r.Method != http.MethodPost { return false } @@ -109,8 +109,8 @@ func checkLogin(w http.ResponseWriter, r *http.Request) bool { return false } // Check credential - if !checkCredentials(r.FormValue("username"), r.FormValue("password"), r.FormValue("token")) { - serveError(w, r, "Incorrect credentials", http.StatusUnauthorized) + if !a.checkCredentials(r.FormValue("username"), r.FormValue("password"), r.FormValue("token")) { + a.serveError(w, r, "Incorrect credentials", http.StatusUnauthorized) return true } // Prepare original request @@ -124,20 +124,20 @@ func checkLogin(w http.ResponseWriter, r *http.Request) bool { req.Header[k] = v } // Cookie - ses, err := loginSessionsStore.Get(r, "l") + ses, err := a.loginSessions.Get(r, "l") if err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return true } ses.Values["login"] = true - cookie, err := loginSessionsStore.SaveGetCookie(r, w, ses) + cookie, err := a.loginSessions.SaveGetCookie(r, w, ses) if err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return true } req.AddCookie(cookie) // Serve original request - d.ServeHTTP(w, req) + a.d.ServeHTTP(w, req) return true } @@ -146,9 +146,9 @@ func serveLogin(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/", http.StatusFound) } -func serveLogout(w http.ResponseWriter, r *http.Request) { - if ses, err := loginSessionsStore.Get(r, "l"); err == nil && ses != nil { - _ = loginSessionsStore.Delete(r, w, ses) +func (a *goBlog) serveLogout(w http.ResponseWriter, r *http.Request) { + if ses, err := a.loginSessions.Get(r, "l"); err == nil && ses != nil { + _ = a.loginSessions.Delete(r, w, ses) } http.Redirect(w, r, "/", http.StatusFound) } diff --git a/blogroll.go b/blogroll.go index c1bcfcb..eefccb8 100644 --- a/blogroll.go +++ b/blogroll.go @@ -13,28 +13,25 @@ import ( "github.com/kaorimatz/go-opml" servertiming "github.com/mitchellh/go-server-timing" "github.com/thoas/go-funk" - "golang.org/x/sync/singleflight" ) -var blogrollCacheGroup singleflight.Group - -func serveBlogroll(w http.ResponseWriter, r *http.Request) { +func (a *goBlog) serveBlogroll(w http.ResponseWriter, r *http.Request) { blog := r.Context().Value(blogContextKey).(string) t := servertiming.FromContext(r.Context()).NewMetric("bg").Start() - outlines, err, _ := blogrollCacheGroup.Do(blog, func() (interface{}, error) { - return getBlogrollOutlines(blog) + outlines, err, _ := a.blogrollCacheGroup.Do(blog, func() (interface{}, error) { + return a.getBlogrollOutlines(blog) }) t.Stop() if err != nil { - log.Println("Failed to get outlines:", err.Error()) - serveError(w, r, "", http.StatusInternalServerError) + log.Printf("Failed to get outlines: %v", err) + a.serveError(w, r, "", http.StatusInternalServerError) return } - if appConfig.Cache != nil && appConfig.Cache.Enable { - setInternalCacheExpirationHeader(w, r, int(appConfig.Cache.Expiration)) + if a.cfg.Cache != nil && a.cfg.Cache.Enable { + setInternalCacheExpirationHeader(w, r, int(a.cfg.Cache.Expiration)) } - c := appConfig.Blogs[blog].Blogroll - render(w, r, templateBlogroll, &renderData{ + c := a.cfg.Blogs[blog].Blogroll + a.render(w, r, templateBlogroll, &renderData{ BlogString: blog, Data: map[string]interface{}{ "Title": c.Title, @@ -45,34 +42,32 @@ func serveBlogroll(w http.ResponseWriter, r *http.Request) { }) } -func serveBlogrollExport(w http.ResponseWriter, r *http.Request) { +func (a *goBlog) serveBlogrollExport(w http.ResponseWriter, r *http.Request) { blog := r.Context().Value(blogContextKey).(string) - outlines, err, _ := blogrollCacheGroup.Do(blog, func() (interface{}, error) { - return getBlogrollOutlines(blog) + outlines, err, _ := a.blogrollCacheGroup.Do(blog, func() (interface{}, error) { + return a.getBlogrollOutlines(blog) }) if err != nil { - log.Println("Failed to get outlines:", err.Error()) - serveError(w, r, "", http.StatusInternalServerError) + log.Printf("Failed to get outlines: %v", err) + a.serveError(w, r, "", http.StatusInternalServerError) return } - if appConfig.Cache != nil && appConfig.Cache.Enable { - setInternalCacheExpirationHeader(w, r, int(appConfig.Cache.Expiration)) + if a.cfg.Cache != nil && a.cfg.Cache.Enable { + setInternalCacheExpirationHeader(w, r, int(a.cfg.Cache.Expiration)) } w.Header().Set(contentType, contentTypeXMLUTF8) - mw := minifier.Writer(contentTypeXML, w) - defer func() { - _ = mw.Close() - }() - _ = opml.Render(mw, &opml.OPML{ + var opmlBytes bytes.Buffer + _ = opml.Render(&opmlBytes, &opml.OPML{ Version: "2.0", DateCreated: time.Now().UTC(), Outlines: outlines.([]*opml.Outline), }) + _, _ = writeMinified(w, contentTypeXML, opmlBytes.Bytes()) } -func getBlogrollOutlines(blog string) ([]*opml.Outline, error) { - config := appConfig.Blogs[blog].Blogroll - if cache := loadOutlineCache(blog); cache != nil { +func (a *goBlog) getBlogrollOutlines(blog string) ([]*opml.Outline, error) { + config := a.cfg.Blogs[blog].Blogroll + if cache := a.db.loadOutlineCache(blog); cache != nil { return cache, nil } req, err := http.NewRequest(http.MethodGet, config.Opml, nil) @@ -112,22 +107,22 @@ func getBlogrollOutlines(blog string) ([]*opml.Outline, error) { } else { outlines = sortOutlines(outlines) } - cacheOutlines(blog, outlines) + a.db.cacheOutlines(blog, outlines) return outlines, nil } -func cacheOutlines(blog string, outlines []*opml.Outline) { +func (db *database) cacheOutlines(blog string, outlines []*opml.Outline) { var opmlBuffer bytes.Buffer _ = opml.Render(&opmlBuffer, &opml.OPML{ Version: "2.0", DateCreated: time.Now().UTC(), Outlines: outlines, }) - _ = cachePersistently("blogroll_"+blog, opmlBuffer.Bytes()) + _ = db.cachePersistently("blogroll_"+blog, opmlBuffer.Bytes()) } -func loadOutlineCache(blog string) []*opml.Outline { - data, err := retrievePersistentCache("blogroll_" + blog) +func (db *database) loadOutlineCache(blog string) []*opml.Outline { + data, err := db.retrievePersistentCache("blogroll_" + blog) if err != nil || data == nil { return nil } diff --git a/blogstats.go b/blogstats.go index c9ad71a..516441b 100644 --- a/blogstats.go +++ b/blogstats.go @@ -9,19 +9,19 @@ import ( "golang.org/x/sync/singleflight" ) -func initBlogStats() { +func (a *goBlog) initBlogStats() { f := func(p *post) { - resetBlogStats(p.Blog) + a.db.resetBlogStats(p.Blog) } - postPostHooks = append(postPostHooks, f) - postUpdateHooks = append(postUpdateHooks, f) - postDeleteHooks = append(postDeleteHooks, f) + a.pPostHooks = append(a.pPostHooks, f) + a.pUpdateHooks = append(a.pUpdateHooks, f) + a.pDeleteHooks = append(a.pDeleteHooks, f) } -func serveBlogStats(w http.ResponseWriter, r *http.Request) { +func (a *goBlog) serveBlogStats(w http.ResponseWriter, r *http.Request) { blog := r.Context().Value(blogContextKey).(string) - canonical := blogPath(blog) + appConfig.Blogs[blog].BlogStats.Path - render(w, r, templateBlogStats, &renderData{ + canonical := a.blogPath(blog) + a.cfg.Blogs[blog].BlogStats.Path + a.render(w, r, templateBlogStats, &renderData{ BlogString: blog, Canonical: canonical, Data: map[string]interface{}{ @@ -32,24 +32,24 @@ func serveBlogStats(w http.ResponseWriter, r *http.Request) { var blogStatsCacheGroup singleflight.Group -func serveBlogStatsTable(w http.ResponseWriter, r *http.Request) { +func (a *goBlog) serveBlogStatsTable(w http.ResponseWriter, r *http.Request) { blog := r.Context().Value(blogContextKey).(string) data, err, _ := blogStatsCacheGroup.Do(blog, func() (interface{}, error) { - return getBlogStats(blog) + return a.db.getBlogStats(blog) }) if err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return } // Render - render(w, r, templateBlogStatsTable, &renderData{ + a.render(w, r, templateBlogStatsTable, &renderData{ BlogString: blog, Data: data, }) } -func getBlogStats(blog string) (data map[string]interface{}, err error) { - if stats := loadBlogStatsCache(blog); stats != nil { +func (db *database) getBlogStats(blog string) (data map[string]interface{}, err error) { + if stats := db.loadBlogStatsCache(blog); stats != nil { return stats, nil } // Build query @@ -67,7 +67,7 @@ func getBlogStats(blog string) (data map[string]interface{}, err error) { Name, Posts, Chars, Words, WordsPerPost string } // Count total posts - row, err := appDb.queryRow("select *, "+wordsPerPost+" from (select "+postCount+", "+charCount+", "+wordCount+" from ("+query+"))", params...) + row, err := db.queryRow("select *, "+wordsPerPost+" from (select "+postCount+", "+charCount+", "+wordCount+" from ("+query+"))", params...) if err != nil { return nil, err } @@ -76,7 +76,7 @@ func getBlogStats(blog string) (data map[string]interface{}, err error) { return nil, err } // Count posts per year - rows, err := appDb.query("select *, "+wordsPerPost+" from (select year, "+postCount+", "+charCount+", "+wordCount+" from ("+query+") where published != '' group by year order by year desc)", params...) + rows, err := db.query("select *, "+wordsPerPost+" from (select year, "+postCount+", "+charCount+", "+wordCount+" from ("+query+") where published != '' group by year order by year desc)", params...) if err != nil { return nil, err } @@ -90,7 +90,7 @@ func getBlogStats(blog string) (data map[string]interface{}, err error) { } } // Count posts without date - row, err = appDb.queryRow("select *, "+wordsPerPost+" from (select "+postCount+", "+charCount+", "+wordCount+" from ("+query+") where published = '')", params...) + row, err = db.queryRow("select *, "+wordsPerPost+" from (select "+postCount+", "+charCount+", "+wordCount+" from ("+query+") where published = '')", params...) if err != nil { return nil, err } @@ -102,7 +102,7 @@ func getBlogStats(blog string) (data map[string]interface{}, err error) { months := map[string][]statsTableType{} month := statsTableType{} for _, year := range years { - rows, err = appDb.query("select *, "+wordsPerPost+" from (select month, "+postCount+", "+charCount+", "+wordCount+" from ("+query+") where published != '' and year = @year group by month order by month desc)", append(params, sql.Named("year", year.Name))...) + rows, err = db.query("select *, "+wordsPerPost+" from (select month, "+postCount+", "+charCount+", "+wordCount+" from ("+query+") where published != '' and year = @year group by month order by month desc)", append(params, sql.Named("year", year.Name))...) if err != nil { return nil, err } @@ -120,17 +120,17 @@ func getBlogStats(blog string) (data map[string]interface{}, err error) { "withoutdate": noDate, "months": months, } - cacheBlogStats(blog, data) + db.cacheBlogStats(blog, data) return data, nil } -func cacheBlogStats(blog string, stats map[string]interface{}) { +func (db *database) cacheBlogStats(blog string, stats map[string]interface{}) { jb, _ := json.Marshal(stats) - _ = cachePersistently("blogstats_"+blog, jb) + _ = db.cachePersistently("blogstats_"+blog, jb) } -func loadBlogStatsCache(blog string) (stats map[string]interface{}) { - data, err := retrievePersistentCache("blogstats_" + blog) +func (db *database) loadBlogStatsCache(blog string) (stats map[string]interface{}) { + data, err := db.retrievePersistentCache("blogstats_" + blog) if err != nil || data == nil { return nil } @@ -141,6 +141,6 @@ func loadBlogStatsCache(blog string) (stats map[string]interface{}) { return stats } -func resetBlogStats(blog string) { - _ = clearPersistentCache("blogstats_" + blog) +func (db *database) resetBlogStats(blog string) { + _ = db.clearPersistentCache("blogstats_" + blog) } diff --git a/cache.go b/cache.go index 5fba71f..276633c 100644 --- a/cache.go +++ b/cache.go @@ -20,17 +20,22 @@ import ( "golang.org/x/sync/singleflight" ) -const ( - cacheInternalExpirationHeader = "Goblog-Expire" -) +const cacheInternalExpirationHeader = "Goblog-Expire" -var ( - cacheGroup singleflight.Group - cacheR *ristretto.Cache -) +type cache struct { + g singleflight.Group + c *ristretto.Cache + cfg *configCache +} -func initCache() (err error) { - cacheR, err = ristretto.NewCache(&ristretto.Config{ +func (a *goBlog) initCache() (err error) { + a.cache = &cache{ + cfg: a.cfg.Cache, + } + if a.cache.cfg != nil && !a.cache.cfg.Enable { + return nil + } + a.cache.c, err = ristretto.NewCache(&ristretto.Config{ NumCounters: 5000, MaxCost: 20000000, // 20 MB BufferItems: 16, @@ -52,13 +57,14 @@ func initCache() (err error) { return } -func cacheMiddleware(next http.Handler) http.Handler { +func (c *cache) cacheMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Do checks - if !appConfig.Cache.Enable { + if c.c == nil { + // No cache configured next.ServeHTTP(w, r) return } + // Do checks if !(r.Method == http.MethodGet || r.Method == http.MethodHead) { next.ServeHTTP(w, r) return @@ -74,32 +80,32 @@ func cacheMiddleware(next http.Handler) http.Handler { // Search and serve cache key := cacheKey(r) // Get cache or render it - cacheInterface, _, _ := cacheGroup.Do(key, func() (interface{}, error) { - return getCache(key, next, r), nil + cacheInterface, _, _ := c.g.Do(key, func() (interface{}, error) { + return c.getCache(key, next, r), nil }) - cache := cacheInterface.(*cacheItem) + ci := cacheInterface.(*cacheItem) // copy cached headers - for k, v := range cache.header { + for k, v := range ci.header { w.Header()[k] = v } - setCacheHeaders(w, cache) + c.setCacheHeaders(w, ci) // check conditional request - if ifNoneMatchHeader := r.Header.Get("If-None-Match"); ifNoneMatchHeader != "" && ifNoneMatchHeader == cache.eTag { + if ifNoneMatchHeader := r.Header.Get("If-None-Match"); ifNoneMatchHeader != "" && ifNoneMatchHeader == ci.eTag { // send 304 w.WriteHeader(http.StatusNotModified) return } if ifModifiedSinceHeader := r.Header.Get("If-Modified-Since"); ifModifiedSinceHeader != "" { - if t, err := dateparse.ParseAny(ifModifiedSinceHeader); err == nil && t.After(cache.creationTime) { + if t, err := dateparse.ParseAny(ifModifiedSinceHeader); err == nil && t.After(ci.creationTime) { // send 304 w.WriteHeader(http.StatusNotModified) return } } // set status code - w.WriteHeader(cache.code) + w.WriteHeader(ci.code) // write cached body - _, _ = w.Write(cache.body) + _, _ = w.Write(ci.body) }) } @@ -125,14 +131,14 @@ func cacheURLString(u *url.URL) string { return buf.String() } -func setCacheHeaders(w http.ResponseWriter, cache *cacheItem) { +func (c *cache) setCacheHeaders(w http.ResponseWriter, cache *cacheItem) { w.Header().Set("ETag", cache.eTag) w.Header().Set("Last-Modified", cache.creationTime.UTC().Format(http.TimeFormat)) if w.Header().Get("Cache-Control") == "" { if cache.expiration != 0 { w.Header().Set("Cache-Control", fmt.Sprintf("public,max-age=%d,stale-while-revalidate=%d", cache.expiration, cache.expiration)) } else { - w.Header().Set("Cache-Control", fmt.Sprintf("public,max-age=%d,s-max-age=%d,stale-while-revalidate=%d", appConfig.Cache.Expiration, appConfig.Cache.Expiration/3, appConfig.Cache.Expiration)) + w.Header().Set("Cache-Control", fmt.Sprintf("public,max-age=%d,s-max-age=%d,stale-while-revalidate=%d", c.cfg.Expiration, c.cfg.Expiration/3, c.cfg.Expiration)) } } } @@ -146,8 +152,8 @@ type cacheItem struct { body []byte } -func getCache(key string, next http.Handler, r *http.Request) (item *cacheItem) { - if rItem, ok := cacheR.Get(key); ok { +func (c *cache) getCache(key string, next http.Handler, r *http.Request) (item *cacheItem) { + if rItem, ok := c.c.Get(key); ok { item = rItem.(*cacheItem) } if item == nil { @@ -198,10 +204,10 @@ func getCache(key string, next http.Handler, r *http.Request) (item *cacheItem) // Save cache if cch := item.header.Get("Cache-Control"); !strings.Contains(cch, "no-store") && !strings.Contains(cch, "private") && !strings.Contains(cch, "no-cache") { if exp == 0 { - cacheR.Set(key, item, 0) + c.c.Set(key, item, 0) } else { ttl := time.Duration(exp) * time.Second - cacheR.SetWithTTL(key, item, 0, ttl) + c.c.SetWithTTL(key, item, 0, ttl) } } } else { @@ -210,8 +216,8 @@ func getCache(key string, next http.Handler, r *http.Request) (item *cacheItem) return item } -func purgeCache() { - cacheR.Clear() +func (c *cache) purge() { + c.c.Clear() } func setInternalCacheExpirationHeader(w http.ResponseWriter, r *http.Request, expiration int) { diff --git a/captcha.go b/captcha.go index 8a57326..0c82d21 100644 --- a/captcha.go +++ b/captcha.go @@ -10,10 +10,10 @@ import ( "github.com/dchest/captcha" ) -func captchaMiddleware(next http.Handler) http.Handler { +func (a *goBlog) captchaMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // 1. Check Cookie - ses, err := captchaSessionsStore.Get(r, "c") + ses, err := a.captchaSessions.Get(r, "c") if err == nil && ses != nil { if captcha, ok := ses.Values["captcha"]; ok && captcha.(bool) { next.ServeHTTP(w, r) @@ -30,7 +30,7 @@ func captchaMiddleware(next http.Handler) http.Handler { _ = r.ParseForm() b = []byte(r.PostForm.Encode()) } - render(w, r, templateCaptcha, &renderData{ + a.render(w, r, templateCaptcha, &renderData{ Data: map[string]string{ "captchamethod": r.Method, "captchaheaders": base64.StdEncoding.EncodeToString(h), @@ -41,15 +41,15 @@ func captchaMiddleware(next http.Handler) http.Handler { }) } -func checkIsCaptcha(next http.Handler) http.Handler { +func (a *goBlog) checkIsCaptcha(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - if !checkCaptcha(rw, r) { + if !a.checkCaptcha(rw, r) { next.ServeHTTP(rw, r) } }) } -func checkCaptcha(w http.ResponseWriter, r *http.Request) bool { +func (a *goBlog) checkCaptcha(w http.ResponseWriter, r *http.Request) bool { if r.Method != http.MethodPost { return false } @@ -71,20 +71,20 @@ func checkCaptcha(w http.ResponseWriter, r *http.Request) bool { } // Check captcha and create cookie if captcha.VerifyString(r.FormValue("captchaid"), r.FormValue("digits")) { - ses, err := captchaSessionsStore.Get(r, "c") + ses, err := a.captchaSessions.Get(r, "c") if err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return true } ses.Values["captcha"] = true - cookie, err := captchaSessionsStore.SaveGetCookie(r, w, ses) + cookie, err := a.captchaSessions.SaveGetCookie(r, w, ses) if err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return true } req.AddCookie(cookie) } // Serve original request - d.ServeHTTP(w, req) + a.d.ServeHTTP(w, req) return true } diff --git a/check.go b/check.go index 36a5c16..919b306 100644 --- a/check.go +++ b/check.go @@ -11,8 +11,8 @@ import ( "time" ) -func checkAllExternalLinks() { - allPosts, err := getPosts(&postsRequestConfig{status: statusPublished}) +func (a *goBlog) checkAllExternalLinks() { + allPosts, err := a.db.getPosts(&postsRequestConfig{status: statusPublished}) if err != nil { log.Println(err.Error()) return @@ -30,45 +30,49 @@ func checkAllExternalLinks() { } responses := map[string]int{} rm := sync.RWMutex{} - for i := 0; i < 20; i++ { - go func() { - defer wg.Done() - wg.Add(1) - for postLinkPair := range linkChan { - rm.RLock() - _, ok := responses[postLinkPair.Second] - rm.RUnlock() - if !ok { - req, err := http.NewRequest(http.MethodGet, postLinkPair.Second, nil) - if err != nil { - fmt.Println(err.Error()) - continue - } - // User-Agent from Tor - req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 6.1; rv:60.0) Gecko/20100101 Firefox/60.0") - req.Header.Set("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8") - req.Header.Set("Accept-Language", "en-US,en;q=0.5") - resp, err := client.Do(req) - if err != nil { - fmt.Println(postLinkPair.Second+" ("+postLinkPair.First+"):", err.Error()) - continue - } - status := resp.StatusCode - _, _ = io.Copy(io.Discard, resp.Body) - resp.Body.Close() - rm.Lock() - responses[postLinkPair.Second] = status - rm.Unlock() - } - rm.RLock() - if response, ok := responses[postLinkPair.Second]; ok && !checkSuccessStatus(response) { - fmt.Println(postLinkPair.Second+" ("+postLinkPair.First+"):", response) - } - rm.RUnlock() + processFunc := func() { + defer wg.Done() + wg.Add(1) + for postLinkPair := range linkChan { + if strings.HasPrefix(postLinkPair.Second, a.cfg.Server.PublicAddress) { + continue } - }() + rm.RLock() + _, ok := responses[postLinkPair.Second] + rm.RUnlock() + if !ok { + req, err := http.NewRequest(http.MethodGet, postLinkPair.Second, nil) + if err != nil { + fmt.Println(err.Error()) + continue + } + // User-Agent from Tor + req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 6.1; rv:60.0) Gecko/20100101 Firefox/60.0") + req.Header.Set("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8") + req.Header.Set("Accept-Language", "en-US,en;q=0.5") + resp, err := client.Do(req) + if err != nil { + fmt.Println(postLinkPair.Second+" ("+postLinkPair.First+"):", err.Error()) + continue + } + status := resp.StatusCode + _, _ = io.Copy(io.Discard, resp.Body) + resp.Body.Close() + rm.Lock() + responses[postLinkPair.Second] = status + rm.Unlock() + } + rm.RLock() + if response, ok := responses[postLinkPair.Second]; ok && !checkSuccessStatus(response) { + fmt.Println(postLinkPair.Second+" ("+postLinkPair.First+"):", response) + } + rm.RUnlock() + } } - err = getExternalLinks(allPosts, linkChan) + for i := 0; i < 20; i++ { + go processFunc() + } + err = a.getExternalLinks(allPosts, linkChan) if err != nil { log.Println(err.Error()) return @@ -80,17 +84,15 @@ func checkSuccessStatus(status int) bool { return status >= 200 && status < 400 } -func getExternalLinks(posts []*post, linkChan chan<- stringPair) error { +func (a *goBlog) getExternalLinks(posts []*post, linkChan chan<- stringPair) error { wg := new(sync.WaitGroup) for _, p := range posts { wg.Add(1) go func(p *post) { defer wg.Done() - links, _ := allLinksFromHTML(strings.NewReader(string(p.absoluteHTML())), p.fullURL()) + links, _ := allLinksFromHTML(strings.NewReader(string(a.absoluteHTML(p))), a.fullPostURL(p)) for _, link := range links { - if !strings.HasPrefix(link, appConfig.Server.PublicAddress) { - linkChan <- stringPair{p.fullURL(), link} - } + linkChan <- stringPair{a.fullPostURL(p), link} } }(p) } diff --git a/comments.go b/comments.go index 1266bb7..7a7da3c 100644 --- a/comments.go +++ b/comments.go @@ -20,36 +20,36 @@ type comment struct { Comment string } -func serveComment(w http.ResponseWriter, r *http.Request) { +func (a *goBlog) serveComment(w http.ResponseWriter, r *http.Request) { id, err := strconv.Atoi(chi.URLParam(r, "id")) if err != nil { - serveError(w, r, err.Error(), http.StatusBadRequest) + a.serveError(w, r, err.Error(), http.StatusBadRequest) return } - row, err := appDb.queryRow("select id, target, name, website, comment from comments where id = @id", sql.Named("id", id)) + row, err := a.db.queryRow("select id, target, name, website, comment from comments where id = @id", sql.Named("id", id)) if err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return } comment := &comment{} if err = row.Scan(&comment.ID, &comment.Target, &comment.Name, &comment.Website, &comment.Comment); err == sql.ErrNoRows { - serve404(w, r) + a.serve404(w, r) return } else if err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return } blog := r.Context().Value(blogContextKey).(string) - render(w, r, templateComment, &renderData{ + a.render(w, r, templateComment, &renderData{ BlogString: blog, - Canonical: appConfig.Server.PublicAddress + appConfig.Blogs[blog].getRelativePath(fmt.Sprintf("/comment/%d", id)), + Canonical: a.cfg.Server.PublicAddress + a.cfg.Blogs[blog].getRelativePath(fmt.Sprintf("/comment/%d", id)), Data: comment, }) } -func createComment(w http.ResponseWriter, r *http.Request) { +func (a *goBlog) createComment(w http.ResponseWriter, r *http.Request) { // Check target - target := checkCommentTarget(w, r) + target := a.checkCommentTarget(w, r) if target == "" { return } @@ -57,7 +57,7 @@ func createComment(w http.ResponseWriter, r *http.Request) { strict := bluemonday.StrictPolicy() comment := strings.TrimSpace(strict.Sanitize(r.FormValue("comment"))) if comment == "" { - serveError(w, r, "Comment is empty", http.StatusBadRequest) + a.serveError(w, r, "Comment is empty", http.StatusBadRequest) return } name := strings.TrimSpace(strict.Sanitize(r.FormValue("name"))) @@ -66,35 +66,35 @@ func createComment(w http.ResponseWriter, r *http.Request) { } website := strings.TrimSpace(strict.Sanitize(r.FormValue("website"))) // Insert - result, err := appDb.exec("insert into comments (target, comment, name, website) values (@target, @comment, @name, @website)", sql.Named("target", target), sql.Named("comment", comment), sql.Named("name", name), sql.Named("website", website)) + result, err := a.db.exec("insert into comments (target, comment, name, website) values (@target, @comment, @name, @website)", sql.Named("target", target), sql.Named("comment", comment), sql.Named("name", name), sql.Named("website", website)) if err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return } if commentID, err := result.LastInsertId(); err != nil { // Serve error - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) } else { - commentAddress := fmt.Sprintf("%s/%d", blogPath(r.Context().Value(blogContextKey).(string))+"/comment", commentID) + commentAddress := fmt.Sprintf("%s/%d", a.blogPath(r.Context().Value(blogContextKey).(string))+"/comment", commentID) // Send webmention - _ = createWebmention(appConfig.Server.PublicAddress+commentAddress, appConfig.Server.PublicAddress+target) + _ = a.createWebmention(a.cfg.Server.PublicAddress+commentAddress, a.cfg.Server.PublicAddress+target) // Redirect to comment http.Redirect(w, r, commentAddress, http.StatusFound) } } -func checkCommentTarget(w http.ResponseWriter, r *http.Request) string { +func (a *goBlog) checkCommentTarget(w http.ResponseWriter, r *http.Request) string { target := r.FormValue("target") if target == "" { - serveError(w, r, "No target specified", http.StatusBadRequest) + a.serveError(w, r, "No target specified", http.StatusBadRequest) return "" - } else if !strings.HasPrefix(target, appConfig.Server.PublicAddress) { - serveError(w, r, "Bad target", http.StatusBadRequest) + } else if !strings.HasPrefix(target, a.cfg.Server.PublicAddress) { + a.serveError(w, r, "Bad target", http.StatusBadRequest) return "" } targetURL, err := url.Parse(target) if err != nil { - serveError(w, r, err.Error(), http.StatusBadRequest) + a.serveError(w, r, err.Error(), http.StatusBadRequest) return "" } return targetURL.Path @@ -114,10 +114,10 @@ func buildCommentsQuery(config *commentsRequestConfig) (query string, args []int return } -func getComments(config *commentsRequestConfig) ([]*comment, error) { +func (db *database) getComments(config *commentsRequestConfig) ([]*comment, error) { comments := []*comment{} query, args := buildCommentsQuery(config) - rows, err := appDb.query(query, args...) + rows, err := db.query(query, args...) if err != nil { return nil, err } @@ -132,10 +132,10 @@ func getComments(config *commentsRequestConfig) ([]*comment, error) { return comments, nil } -func countComments(config *commentsRequestConfig) (count int, err error) { +func (db *database) countComments(config *commentsRequestConfig) (count int, err error) { query, params := buildCommentsQuery(config) query = "select count(*) from (" + query + ")" - row, err := appDb.queryRow(query, params...) + row, err := db.queryRow(query, params...) if err != nil { return } @@ -143,7 +143,7 @@ func countComments(config *commentsRequestConfig) (count int, err error) { return } -func deleteComment(id int) error { - _, err := appDb.exec("delete from comments where id = @id", sql.Named("id", id)) +func (db *database) deleteComment(id int) error { + _, err := db.exec("delete from comments where id = @id", sql.Named("id", id)) return err } diff --git a/commentsAdmin.go b/commentsAdmin.go index 7b18a7c..f5c0f62 100644 --- a/commentsAdmin.go +++ b/commentsAdmin.go @@ -13,11 +13,12 @@ import ( type commentsPaginationAdapter struct { config *commentsRequestConfig nums int64 + db *database } func (p *commentsPaginationAdapter) Nums() (int64, error) { if p.nums == 0 { - nums, _ := countComments(p.config) + nums, _ := p.db.countComments(p.config) p.nums = int64(nums) } return p.nums, nil @@ -28,23 +29,23 @@ func (p *commentsPaginationAdapter) Slice(offset, length int, data interface{}) modifiedConfig.offset = offset modifiedConfig.limit = length - comments, err := getComments(&modifiedConfig) + comments, err := p.db.getComments(&modifiedConfig) reflect.ValueOf(data).Elem().Set(reflect.ValueOf(&comments).Elem()) return err } -func commentsAdmin(w http.ResponseWriter, r *http.Request) { +func (a *goBlog) commentsAdmin(w http.ResponseWriter, r *http.Request) { blog := r.Context().Value(blogContextKey).(string) commentsPath := r.Context().Value(pathContextKey).(string) // Adapter pageNoString := chi.URLParam(r, "page") pageNo, _ := strconv.Atoi(pageNoString) - p := paginator.New(&commentsPaginationAdapter{config: &commentsRequestConfig{}}, 5) + p := paginator.New(&commentsPaginationAdapter{config: &commentsRequestConfig{}, db: a.db}, 5) p.SetPage(pageNo) var comments []*comment err := p.Results(&comments) if err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return } // Navigation @@ -70,7 +71,7 @@ func commentsAdmin(w http.ResponseWriter, r *http.Request) { } nextPath = fmt.Sprintf("%s/page/%d", commentsPath, nextPage) // Render - render(w, r, templateCommentsAdmin, &renderData{ + a.render(w, r, templateCommentsAdmin, &renderData{ BlogString: blog, Data: map[string]interface{}{ "Comments": comments, @@ -82,17 +83,17 @@ func commentsAdmin(w http.ResponseWriter, r *http.Request) { }) } -func commentsAdminDelete(w http.ResponseWriter, r *http.Request) { +func (a *goBlog) commentsAdminDelete(w http.ResponseWriter, r *http.Request) { id, err := strconv.Atoi(r.FormValue("commentid")) if err != nil { - serveError(w, r, err.Error(), http.StatusBadRequest) + a.serveError(w, r, err.Error(), http.StatusBadRequest) return } - err = deleteComment(id) + err = a.db.deleteComment(id) if err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return } - purgeCache() + a.cache.purge() http.Redirect(w, r, ".", http.StatusFound) } diff --git a/config.go b/config.go index f95437b..8176503 100644 --- a/config.go +++ b/config.go @@ -224,9 +224,7 @@ type configWebmention struct { DisableReceiving bool `mapstructure:"disableReceiving"` } -var appConfig = &config{} - -func initConfig() error { +func (a *goBlog) initConfig() error { viper.SetConfigName("config") viper.AddConfigPath("./config/") err := viper.ReadInConfig() @@ -258,52 +256,53 @@ func initConfig() error { viper.SetDefault("webmention.disableSending", false) viper.SetDefault("webmention.disableReceiving", false) // Unmarshal config - err = viper.Unmarshal(appConfig) + a.cfg = &config{} + err = viper.Unmarshal(a.cfg) if err != nil { return err } // Check config - publicURL, err := url.Parse(appConfig.Server.PublicAddress) + publicURL, err := url.Parse(a.cfg.Server.PublicAddress) if err != nil { return err } - appConfig.Server.publicHostname = publicURL.Hostname() - if appConfig.Server.ShortPublicAddress != "" { - shortPublicURL, err := url.Parse(appConfig.Server.ShortPublicAddress) + a.cfg.Server.publicHostname = publicURL.Hostname() + if a.cfg.Server.ShortPublicAddress != "" { + shortPublicURL, err := url.Parse(a.cfg.Server.ShortPublicAddress) if err != nil { return err } - appConfig.Server.shortPublicHostname = shortPublicURL.Hostname() + a.cfg.Server.shortPublicHostname = shortPublicURL.Hostname() } - if appConfig.Server.JWTSecret == "" { + if a.cfg.Server.JWTSecret == "" { return errors.New("no JWT secret configured") } - if len(appConfig.Blogs) == 0 { + if len(a.cfg.Blogs) == 0 { return errors.New("no blog configured") } - if len(appConfig.DefaultBlog) == 0 || appConfig.Blogs[appConfig.DefaultBlog] == nil { + if len(a.cfg.DefaultBlog) == 0 || a.cfg.Blogs[a.cfg.DefaultBlog] == nil { return errors.New("no default blog or default blog not present") } - if appConfig.Micropub.MediaStorage != nil { - if appConfig.Micropub.MediaStorage.MediaURL == "" || - appConfig.Micropub.MediaStorage.BunnyStorageKey == "" || - appConfig.Micropub.MediaStorage.BunnyStorageName == "" { - appConfig.Micropub.MediaStorage.BunnyStorageKey = "" - appConfig.Micropub.MediaStorage.BunnyStorageName = "" + if a.cfg.Micropub.MediaStorage != nil { + if a.cfg.Micropub.MediaStorage.MediaURL == "" || + a.cfg.Micropub.MediaStorage.BunnyStorageKey == "" || + a.cfg.Micropub.MediaStorage.BunnyStorageName == "" { + a.cfg.Micropub.MediaStorage.BunnyStorageKey = "" + a.cfg.Micropub.MediaStorage.BunnyStorageName = "" } - appConfig.Micropub.MediaStorage.MediaURL = strings.TrimSuffix(appConfig.Micropub.MediaStorage.MediaURL, "/") + a.cfg.Micropub.MediaStorage.MediaURL = strings.TrimSuffix(a.cfg.Micropub.MediaStorage.MediaURL, "/") } - if pm := appConfig.PrivateMode; pm != nil && pm.Enabled { - appConfig.ActivityPub = &configActivityPub{Enabled: false} + if pm := a.cfg.PrivateMode; pm != nil && pm.Enabled { + a.cfg.ActivityPub = &configActivityPub{Enabled: false} } - if wm := appConfig.Webmention; wm != nil && wm.DisableReceiving { + if wm := a.cfg.Webmention; wm != nil && wm.DisableReceiving { // Disable comments for all blogs - for _, b := range appConfig.Blogs { + for _, b := range a.cfg.Blogs { b.Comments = &comments{Enabled: false} } } // Check config for each blog - for _, blog := range appConfig.Blogs { + for _, blog := range a.cfg.Blogs { if br := blog.Blogroll; br != nil && br.Enabled && br.Opml == "" { br.Enabled = false } @@ -311,6 +310,6 @@ func initConfig() error { return nil } -func httpsConfigured() bool { - return appConfig.Server.PublicHTTPS || appConfig.Server.SecurityHeaders || strings.HasPrefix(appConfig.Server.PublicAddress, "https") +func (a *goBlog) httpsConfigured() bool { + return a.cfg.Server.PublicHTTPS || a.cfg.Server.SecurityHeaders || strings.HasPrefix(a.cfg.Server.PublicAddress, "https") } diff --git a/customPages.go b/customPages.go index 64beb42..f526668 100644 --- a/customPages.go +++ b/customPages.go @@ -4,18 +4,18 @@ import "net/http" const customPageContextKey = "custompage" -func serveCustomPage(w http.ResponseWriter, r *http.Request) { +func (a *goBlog) serveCustomPage(w http.ResponseWriter, r *http.Request) { page := r.Context().Value(customPageContextKey).(*customPage) - if appConfig.Cache != nil && appConfig.Cache.Enable && page.Cache { + if a.cfg.Cache != nil && a.cfg.Cache.Enable && page.Cache { if page.CacheExpiration != 0 { setInternalCacheExpirationHeader(w, r, page.CacheExpiration) } else { - setInternalCacheExpirationHeader(w, r, int(appConfig.Cache.Expiration)) + setInternalCacheExpirationHeader(w, r, int(a.cfg.Cache.Expiration)) } } - render(w, r, page.Template, &renderData{ + a.render(w, r, page.Template, &renderData{ BlogString: r.Context().Value(blogContextKey).(string), - Canonical: appConfig.Server.PublicAddress + page.Path, + Canonical: a.cfg.Server.PublicAddress + page.Path, Data: page.Data, }) } diff --git a/database.go b/database.go index f37c0fc..91b451d 100644 --- a/database.go +++ b/database.go @@ -8,18 +8,42 @@ import ( sqlite "github.com/mattn/go-sqlite3" "github.com/schollz/sqlite3dump" + "golang.org/x/sync/singleflight" ) -var appDb *goblogDb - -type goblogDb struct { - db *sql.DB - statementCache map[string]*sql.Stmt +type database struct { + db *sql.DB + stmts map[string]*sql.Stmt + g singleflight.Group + persistentCacheGroup singleflight.Group } -func initDatabase() (err error) { +func (a *goBlog) initDatabase() (err error) { // Setup db - sql.Register("goblog_db", &sqlite.SQLiteDriver{ + db, err := a.openDatabase(a.cfg.Db.File) + if err != nil { + return err + } + // Create appDB + a.db = db + db.vacuum() + addShutdownFunc(func() { + _ = db.close() + log.Println("Closed database") + }) + if a.cfg.Db.DumpFile != "" { + hourlyHooks = append(hourlyHooks, func() { + db.dump(a.cfg.Db.DumpFile) + }) + db.dump(a.cfg.Db.DumpFile) + } + return nil +} + +func (a *goBlog) openDatabase(file string) (*database, error) { + // Register driver + dbDriverName := generateRandomString(15) + sql.Register("goblog_db_"+dbDriverName, &sqlite.SQLiteDriver{ ConnectHook: func(c *sqlite.SQLiteConn) error { if err := c.RegisterFunc("tolocal", toLocalSafe, true); err != nil { return err @@ -27,64 +51,54 @@ func initDatabase() (err error) { if err := c.RegisterFunc("wordcount", wordCount, true); err != nil { return err } - if err := c.RegisterFunc("mdtext", renderText, true); err != nil { + if err := c.RegisterFunc("mdtext", a.renderText, true); err != nil { return err } return nil }, }) - db, err := sql.Open("goblog_db", appConfig.Db.File+"?cache=shared&mode=rwc&_journal_mode=WAL") + // Open db + db, err := sql.Open("goblog_db_"+dbDriverName, file+"?cache=shared&mode=rwc&_journal_mode=WAL") if err != nil { - return err + return nil, err } db.SetMaxOpenConns(1) err = db.Ping() if err != nil { - return err + return nil, err } // Check available SQLite features rows, err := db.Query("pragma compile_options") if err != nil { - return err + return nil, err } cos := map[string]bool{} var co string for rows.Next() { err = rows.Scan(&co) if err != nil { - return err + return nil, err } cos[co] = true } if _, ok := cos["ENABLE_FTS5"]; !ok { - return errors.New("sqlite not compiled with FTS5") + return nil, errors.New("sqlite not compiled with FTS5") } // Migrate DB err = migrateDb(db) if err != nil { - return err + return nil, err } - // Create appDB - appDb = &goblogDb{ - db: db, - statementCache: map[string]*sql.Stmt{}, - } - appDb.vacuum() - addShutdownFunc(func() { - _ = appDb.close() - log.Println("Closed database") - }) - if appConfig.Db.DumpFile != "" { - hourlyHooks = append(hourlyHooks, func() { - appDb.dump() - }) - appDb.dump() - } - return nil + return &database{ + db: db, + stmts: map[string]*sql.Stmt{}, + }, nil } -func (db *goblogDb) dump() { - f, err := os.Create(appConfig.Db.DumpFile) +// Main features + +func (db *database) dump(file string) { + f, err := os.Create(file) if err != nil { log.Println("Error while dump db:", err.Error()) return @@ -94,18 +108,18 @@ func (db *goblogDb) dump() { } } -func (db *goblogDb) close() error { +func (db *database) close() error { db.vacuum() return db.db.Close() } -func (db *goblogDb) vacuum() { +func (db *database) vacuum() { _, _ = db.exec("VACUUM") } -func (db *goblogDb) prepare(query string) (*sql.Stmt, error) { - stmt, err, _ := cacheGroup.Do(query, func() (interface{}, error) { - stmt, ok := db.statementCache[query] +func (db *database) prepare(query string) (*sql.Stmt, error) { + stmt, err, _ := db.g.Do(query, func() (interface{}, error) { + stmt, ok := db.stmts[query] if ok && stmt != nil { return stmt, nil } @@ -113,7 +127,7 @@ func (db *goblogDb) prepare(query string) (*sql.Stmt, error) { if err != nil { return nil, err } - db.statementCache[query] = stmt + db.stmts[query] = stmt return stmt, nil }) if err != nil { @@ -122,7 +136,7 @@ func (db *goblogDb) prepare(query string) (*sql.Stmt, error) { return stmt.(*sql.Stmt), nil } -func (db *goblogDb) exec(query string, args ...interface{}) (sql.Result, error) { +func (db *database) exec(query string, args ...interface{}) (sql.Result, error) { stmt, err := db.prepare(query) if err != nil { return nil, err @@ -130,12 +144,12 @@ func (db *goblogDb) exec(query string, args ...interface{}) (sql.Result, error) return stmt.Exec(args...) } -func (db *goblogDb) execMulti(query string, args ...interface{}) (sql.Result, error) { +func (db *database) execMulti(query string, args ...interface{}) (sql.Result, error) { // Can't prepare the statement return db.db.Exec(query, args...) } -func (db *goblogDb) query(query string, args ...interface{}) (*sql.Rows, error) { +func (db *database) query(query string, args ...interface{}) (*sql.Rows, error) { stmt, err := db.prepare(query) if err != nil { return nil, err @@ -143,10 +157,16 @@ func (db *goblogDb) query(query string, args ...interface{}) (*sql.Rows, error) return stmt.Query(args...) } -func (db *goblogDb) queryRow(query string, args ...interface{}) (*sql.Row, error) { +func (db *database) queryRow(query string, args ...interface{}) (*sql.Row, error) { stmt, err := db.prepare(query) if err != nil { return nil, err } return stmt.QueryRow(args...), nil } + +// Other things + +func (d *database) rebuildFTSIndex() { + _, _ = d.exec("insert into posts_fts(posts_fts) values ('rebuild')") +} diff --git a/databaseMigrations.go b/databaseMigrations.go index c4885eb..aa06be3 100644 --- a/databaseMigrations.go +++ b/databaseMigrations.go @@ -2,12 +2,16 @@ package main import ( "database/sql" + "log" "github.com/lopezator/migrator" ) func migrateDb(db *sql.DB) error { m, err := migrator.New( + migrator.WithLogger(migrator.LoggerFunc(func(s string, i ...interface{}) { + log.Printf(s, i) + })), migrator.Migrations( &migrator.Migration{ Name: "00001", diff --git a/editor.go b/editor.go index bcf8d2a..bf78f95 100644 --- a/editor.go +++ b/editor.go @@ -11,45 +11,45 @@ import ( const editorPath = "/editor" -func serveEditor(w http.ResponseWriter, r *http.Request) { +func (a *goBlog) serveEditor(w http.ResponseWriter, r *http.Request) { blog := r.Context().Value(blogContextKey).(string) - render(w, r, templateEditor, &renderData{ + a.render(w, r, templateEditor, &renderData{ BlogString: blog, Data: map[string]interface{}{ - "Drafts": loadDrafts(blog), + "Drafts": a.db.getDrafts(blog), }, }) } -func serveEditorPost(w http.ResponseWriter, r *http.Request) { +func (a *goBlog) serveEditorPost(w http.ResponseWriter, r *http.Request) { blog := r.Context().Value(blogContextKey).(string) if action := r.FormValue("editoraction"); action != "" { switch action { case "loaddelete": - render(w, r, templateEditor, &renderData{ + a.render(w, r, templateEditor, &renderData{ BlogString: blog, Data: map[string]interface{}{ "DeleteURL": r.FormValue("url"), - "Drafts": loadDrafts(blog), + "Drafts": a.db.getDrafts(blog), }, }) case "loadupdate": parsedURL, err := url.Parse(r.FormValue("url")) if err != nil { - serveError(w, r, err.Error(), http.StatusBadRequest) + a.serveError(w, r, err.Error(), http.StatusBadRequest) return } - post, err := getPost(parsedURL.Path) + post, err := a.db.getPost(parsedURL.Path) if err != nil { - serveError(w, r, err.Error(), http.StatusBadRequest) + a.serveError(w, r, err.Error(), http.StatusBadRequest) return } - render(w, r, templateEditor, &renderData{ + a.render(w, r, templateEditor, &renderData{ BlogString: blog, Data: map[string]interface{}{ "UpdatePostURL": parsedURL.String(), - "UpdatePostContent": post.toMfItem().Properties.Content[0], - "Drafts": loadDrafts(blog), + "UpdatePostContent": a.toMfItem(post).Properties.Content[0], + "Drafts": a.db.getDrafts(blog), }, }) case "updatepost": @@ -63,37 +63,32 @@ func serveEditorPost(w http.ResponseWriter, r *http.Request) { }, }) if err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return } req, err := http.NewRequest(http.MethodPost, "", bytes.NewReader(jsonBytes)) if err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return } req.Header.Set(contentType, contentTypeJSON) - editorMicropubPost(w, req, false) + a.editorMicropubPost(w, req, false) case "upload": - editorMicropubPost(w, r, true) + a.editorMicropubPost(w, r, true) default: - serveError(w, r, "Unknown editoraction", http.StatusBadRequest) + a.serveError(w, r, "Unknown editoraction", http.StatusBadRequest) } return } - editorMicropubPost(w, r, false) + a.editorMicropubPost(w, r, false) } -func loadDrafts(blog string) []*post { - ps, _ := getPosts(&postsRequestConfig{status: statusDraft, blog: blog}) - return ps -} - -func editorMicropubPost(w http.ResponseWriter, r *http.Request, media bool) { +func (a *goBlog) editorMicropubPost(w http.ResponseWriter, r *http.Request, media bool) { recorder := httptest.NewRecorder() if media { - addAllScopes(http.HandlerFunc(serveMicropubMedia)).ServeHTTP(recorder, r) + addAllScopes(http.HandlerFunc(a.serveMicropubMedia)).ServeHTTP(recorder, r) } else { - addAllScopes(http.HandlerFunc(serveMicropubPost)).ServeHTTP(recorder, r) + addAllScopes(http.HandlerFunc(a.serveMicropubPost)).ServeHTTP(recorder, r) } result := recorder.Result() if location := result.Header.Get("Location"); location != "" { diff --git a/errors.go b/errors.go index a3353ad..53f25a4 100644 --- a/errors.go +++ b/errors.go @@ -12,19 +12,19 @@ type errorData struct { Message string } -func serve404(w http.ResponseWriter, r *http.Request) { - serveError(w, r, fmt.Sprintf("%s was not found", r.RequestURI), http.StatusNotFound) +func (a *goBlog) serve404(w http.ResponseWriter, r *http.Request) { + a.serveError(w, r, fmt.Sprintf("%s was not found", r.RequestURI), http.StatusNotFound) } -func serveNotAllowed(w http.ResponseWriter, r *http.Request) { - serveError(w, r, "", http.StatusMethodNotAllowed) +func (a *goBlog) serveNotAllowed(w http.ResponseWriter, r *http.Request) { + a.serveError(w, r, "", http.StatusMethodNotAllowed) } var errorCheckMediaTypes = []contenttype.MediaType{ contenttype.NewMediaType(contentTypeHTML), } -func serveError(w http.ResponseWriter, r *http.Request, message string, status int) { +func (a *goBlog) serveError(w http.ResponseWriter, r *http.Request, message string, status int) { if mt, _, err := contenttype.GetAcceptableMediaType(r, errorCheckMediaTypes); err != nil || mt.String() != errorCheckMediaTypes[0].String() { // Request doesn't accept HTML http.Error(w, message, status) @@ -35,7 +35,7 @@ func serveError(w http.ResponseWriter, r *http.Request, message string, status i message = http.StatusText(status) } w.WriteHeader(status) - render(w, r, templateError, &renderData{ + a.render(w, r, templateError, &renderData{ Data: &errorData{ Title: title, Message: message, diff --git a/feeds.go b/feeds.go index 707de38..9565547 100644 --- a/feeds.go +++ b/feeds.go @@ -22,25 +22,25 @@ const ( feedAudioLength = "audiolength" ) -func generateFeed(blog string, f feedType, w http.ResponseWriter, r *http.Request, posts []*post, title string, description string) { +func (a *goBlog) generateFeed(blog string, f feedType, w http.ResponseWriter, r *http.Request, posts []*post, title string, description string) { now := time.Now() if title == "" { - title = appConfig.Blogs[blog].Title + title = a.cfg.Blogs[blog].Title } if description == "" { - description = appConfig.Blogs[blog].Description + description = a.cfg.Blogs[blog].Description } feed := &feeds.Feed{ Title: title, Description: description, - Link: &feeds.Link{Href: appConfig.Server.PublicAddress + strings.TrimSuffix(r.URL.Path, "."+string(f))}, + Link: &feeds.Link{Href: a.cfg.Server.PublicAddress + strings.TrimSuffix(r.URL.Path, "."+string(f))}, Created: now, Author: &feeds.Author{ - Name: appConfig.User.Name, - Email: appConfig.User.Email, + Name: a.cfg.User.Name, + Email: a.cfg.User.Email, }, Image: &feeds.Image{ - Url: appConfig.User.Picture, + Url: a.cfg.User.Picture, }, } for _, p := range posts { @@ -56,10 +56,10 @@ func generateFeed(blog string, f feedType, w http.ResponseWriter, r *http.Reques } feed.Add(&feeds.Item{ Title: p.title(), - Link: &feeds.Link{Href: p.fullURL()}, - Description: p.summary(), + Link: &feeds.Link{Href: a.fullPostURL(p)}, + Description: a.summary(p), Id: p.Path, - Content: string(p.absoluteHTML()), + Content: string(a.absoluteHTML(p)), Created: created, Updated: updated, Enclosure: enc, @@ -82,7 +82,7 @@ func generateFeed(blog string, f feedType, w http.ResponseWriter, r *http.Reques } if err != nil { w.Header().Del(contentType) - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return } w.Header().Set(contentType, feedMediaType+charsetUtf8Suffix) diff --git a/go.mod b/go.mod index dac7077..6dbd1cc 100644 --- a/go.mod +++ b/go.mod @@ -14,11 +14,9 @@ require ( github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de github.com/boombuler/barcode v1.0.1 // indirect github.com/caddyserver/certmagic v0.13.1 - // master - github.com/cretz/bine v0.1.1-0.20200124154328-f9f678b84cca + github.com/cretz/bine v0.2.0 github.com/dchest/captcha v0.0.0-20200903113550-03f5f0333e1f - // master - github.com/dgraph-io/ristretto v0.0.4-0.20210504190834-0bf2acd73aa3 + github.com/dgraph-io/ristretto v0.1.0 github.com/elnormous/contenttype v1.0.0 github.com/felixge/httpsnoop v1.0.2 // indirect github.com/go-chi/chi/v5 v5.0.3 @@ -38,6 +36,7 @@ require ( github.com/lestrrat-go/file-rotatelogs v2.4.0+incompatible github.com/lestrrat-go/strftime v1.0.4 // indirect github.com/lib/pq v1.9.0 // indirect + github.com/libdns/libdns v0.2.1 // indirect github.com/lopezator/migrator v0.3.0 github.com/magiconair/properties v1.8.5 // indirect github.com/mattn/go-sqlite3 v1.14.7 @@ -46,7 +45,7 @@ require ( github.com/mitchellh/go-server-timing v1.0.1 github.com/mitchellh/mapstructure v1.4.1 // indirect github.com/paulmach/go.geojson v1.4.0 - github.com/pelletier/go-toml v1.9.1 // indirect + github.com/pelletier/go-toml v1.9.2 // indirect github.com/pquerna/otp v1.3.0 github.com/schollz/sqlite3dump v1.2.4 github.com/smartystreets/assertions v1.2.0 // indirect @@ -63,10 +62,9 @@ require ( github.com/yuin/goldmark-emoji v1.0.1 go.uber.org/multierr v1.7.0 // indirect go.uber.org/zap v1.17.0 // indirect - golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a // indirect golang.org/x/net v0.0.0-20210525063256-abc453219eb5 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c - golang.org/x/sys v0.0.0-20210525143221-35b2ab0089ea // indirect + golang.org/x/sys v0.0.0-20210603125802-9665404d3644 // indirect golang.org/x/term v0.0.0-20201210144234-2321bbc49cbf // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/ini.v1 v1.62.0 // indirect diff --git a/go.sum b/go.sum index 2ca76ac..0ca0126 100644 --- a/go.sum +++ b/go.sum @@ -59,15 +59,15 @@ github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3Ee github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/cretz/bine v0.1.1-0.20200124154328-f9f678b84cca h1:Q2r7AxHdJwWfLtBZwvW621M3sPqxPc6ITv2j1FGsYpw= -github.com/cretz/bine v0.1.1-0.20200124154328-f9f678b84cca/go.mod h1:6PF6fWAvYtwjRGkAuDEJeWNOv3a2hUouSP/yRYXmvHw= +github.com/cretz/bine v0.2.0 h1:8GiDRGlTgz+o8H9DSnsl+5MeBK4HsExxgl6WgzOCuZo= +github.com/cretz/bine v0.2.0/go.mod h1:WU4o9QR9wWp8AVKtTM1XD5vUHkEqnf2vVSo6dBqbetI= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dchest/captcha v0.0.0-20200903113550-03f5f0333e1f h1:q/DpyjJjZs94bziQ7YkBmIlpqbVP7yw179rnzoNVX1M= github.com/dchest/captcha v0.0.0-20200903113550-03f5f0333e1f/go.mod h1:QGrK8vMWWHQYQ3QU9bw9Y9OPNfxccGzfb41qjvVeXtY= -github.com/dgraph-io/ristretto v0.0.4-0.20210504190834-0bf2acd73aa3 h1:jU/wpYsEL+8JPLf/QcjkQKI5g0dOjSuwcMjkThxt5x0= -github.com/dgraph-io/ristretto v0.0.4-0.20210504190834-0bf2acd73aa3/go.mod h1:fux0lOrBhrVCJd3lcTHsIJhq1T2rokOu6v9Vcb3Q9ug= +github.com/dgraph-io/ristretto v0.1.0 h1:Jv3CGQHp9OjuMBSne1485aDpUkTKEcUqF+jm/LuerPI= +github.com/dgraph-io/ristretto v0.1.0/go.mod h1:fux0lOrBhrVCJd3lcTHsIJhq1T2rokOu6v9Vcb3Q9ug= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= @@ -214,8 +214,9 @@ github.com/lestrrat-go/strftime v1.0.4/go.mod h1:E1nN3pCbtMSu1yjSVeyuRFVm/U0xoR7 github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.9.0 h1:L8nSXQQzAYByakOFMTwpjRoHsMJklur4Gi59b6VivR8= github.com/lib/pq v1.9.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/libdns/libdns v0.2.0 h1:ewg3ByWrdUrxrje8ChPVMBNcotg7H9LQYg+u5De2RzI= github.com/libdns/libdns v0.2.0/go.mod h1:yQCXzk1lEZmmCPa857bnk4TsOiqYasqpyOEeSObbb40= +github.com/libdns/libdns v0.2.1 h1:Wu59T7wSHRgtA0cfxC+n1c/e+O3upJGWytknkmFEDis= +github.com/libdns/libdns v0.2.1/go.mod h1:yQCXzk1lEZmmCPa857bnk4TsOiqYasqpyOEeSObbb40= github.com/lopezator/migrator v0.3.0 h1:VW/rR+J8NYwPdkBxjrFdjwejpgvP59LbmANJxXuNbuk= github.com/lopezator/migrator v0.3.0/go.mod h1:bpVAVPkWSvTw8ya2Pk7E/KiNAyDWNImgivQY79o8/8I= github.com/magiconair/properties v1.7.4-0.20170902060319-8d7837e64d3c/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= @@ -262,8 +263,8 @@ github.com/paulmach/go.geojson v1.4.0 h1:5x5moCkCtDo5x8af62P9IOAYGQcYHtxz2QJ3x1D github.com/paulmach/go.geojson v1.4.0/go.mod h1:YaKx1hKpWF+T2oj2lFJPsW/t1Q5e1jQI61eoQSTwpIs= github.com/pelletier/go-toml v1.0.1-0.20170904195809-1d6b12b7cb29/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= -github.com/pelletier/go-toml v1.9.1 h1:a6qW1EVNZWH9WGI6CsYdD8WAylkoXBS5yv0XHlh17Tc= -github.com/pelletier/go-toml v1.9.1/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= +github.com/pelletier/go-toml v1.9.2 h1:7NiByeVF4jKSG1lDF3X8LTIkq2/bu+1uYbIm1eS5tzk= +github.com/pelletier/go-toml v1.9.2/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -450,8 +451,8 @@ golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210420072515-93ed5bcd2bfe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210525143221-35b2ab0089ea h1:+WiDlPBBaO+h9vPNZi8uJ3k4BkKQB7Iow3aqwHVA5hI= -golang.org/x/sys v0.0.0-20210525143221-35b2ab0089ea/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210603125802-9665404d3644 h1:CA1DEQ4NdKphKeL70tvsWNdT5oFh1lOjihRcEDROi0I= +golang.org/x/sys v0.0.0-20210603125802-9665404d3644/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201210144234-2321bbc49cbf h1:MZ2shdL+ZM/XzY3ZGOnh4Nlpnxz5GSOhOmtHo3iPU6M= golang.org/x/term v0.0.0-20201210144234-2321bbc49cbf/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= diff --git a/healthcheck.go b/healthcheck.go index c405f59..23d6e01 100644 --- a/healthcheck.go +++ b/healthcheck.go @@ -6,8 +6,8 @@ import ( "net/http" ) -func healthcheck() bool { - req, err := http.NewRequest(http.MethodGet, appConfig.Server.PublicAddress+"/ping", nil) +func (a *goBlog) healthcheck() bool { + req, err := http.NewRequest(http.MethodGet, a.cfg.Server.PublicAddress+"/ping", nil) if err != nil { fmt.Println(err.Error()) return false @@ -22,8 +22,8 @@ func healthcheck() bool { return resp.StatusCode == 200 } -func healthcheckExitCode() int { - if healthcheck() { +func (a *goBlog) healthcheckExitCode() int { + if a.healthcheck() { return 0 } else { return 1 diff --git a/hooks.go b/hooks.go index eaf3d91..9877cef 100644 --- a/hooks.go +++ b/hooks.go @@ -8,68 +8,62 @@ import ( "time" ) -func preStartHooks() { - for _, cmd := range appConfig.Hooks.PreStart { +func (a *goBlog) preStartHooks() { + for _, cmd := range a.cfg.Hooks.PreStart { func(cmd string) { log.Println("Executing pre-start hook:", cmd) - executeCommand(cmd) + a.cfg.Hooks.executeCommand(cmd) }(cmd) } } type postHookFunc func(*post) -var ( - postPostHooks []postHookFunc - postUpdateHooks []postHookFunc - postDeleteHooks []postHookFunc -) - -func (p *post) postPostHooks() { +func (a *goBlog) postPostHooks(p *post) { // Hooks after post published - for _, cmdTmplString := range appConfig.Hooks.PostPost { + for _, cmdTmplString := range a.cfg.Hooks.PostPost { go func(p *post, cmdTmplString string) { - executeTemplateCommand("post-post", cmdTmplString, map[string]interface{}{ - "URL": p.fullURL(), + a.cfg.Hooks.executeTemplateCommand("post-post", cmdTmplString, map[string]interface{}{ + "URL": a.fullPostURL(p), "Post": p, }) }(p, cmdTmplString) } - for _, f := range postPostHooks { + for _, f := range a.pPostHooks { go f(p) } } -func (p *post) postUpdateHooks() { +func (a *goBlog) postUpdateHooks(p *post) { // Hooks after post updated - for _, cmdTmplString := range appConfig.Hooks.PostUpdate { + for _, cmdTmplString := range a.cfg.Hooks.PostUpdate { go func(p *post, cmdTmplString string) { - executeTemplateCommand("post-update", cmdTmplString, map[string]interface{}{ - "URL": p.fullURL(), + a.cfg.Hooks.executeTemplateCommand("post-update", cmdTmplString, map[string]interface{}{ + "URL": a.fullPostURL(p), "Post": p, }) }(p, cmdTmplString) } - for _, f := range postUpdateHooks { + for _, f := range a.pUpdateHooks { go f(p) } } -func (p *post) postDeleteHooks() { - for _, cmdTmplString := range appConfig.Hooks.PostDelete { +func (a *goBlog) postDeleteHooks(p *post) { + for _, cmdTmplString := range a.cfg.Hooks.PostDelete { go func(p *post, cmdTmplString string) { - executeTemplateCommand("post-delete", cmdTmplString, map[string]interface{}{ - "URL": p.fullURL(), + a.cfg.Hooks.executeTemplateCommand("post-delete", cmdTmplString, map[string]interface{}{ + "URL": a.fullPostURL(p), "Post": p, }) }(p, cmdTmplString) } - for _, f := range postDeleteHooks { + for _, f := range a.pDeleteHooks { go f(p) } } -func executeTemplateCommand(hookType string, tmpl string, data map[string]interface{}) { +func (cfg *configHooks) executeTemplateCommand(hookType string, tmpl string, data map[string]interface{}) { cmdTmpl, err := template.New("cmd").Parse(tmpl) if err != nil { log.Println("Failed to parse cmd template:", err.Error()) @@ -82,18 +76,18 @@ func executeTemplateCommand(hookType string, tmpl string, data map[string]interf } cmd := cmdBuf.String() log.Println("Executing "+hookType+" hook:", cmd) - executeCommand(cmd) + cfg.executeCommand(cmd) } var hourlyHooks = []func(){} -func startHourlyHooks() { +func (a *goBlog) startHourlyHooks() { // Add configured hourly hooks - for _, cmd := range appConfig.Hooks.Hourly { + for _, cmd := range a.cfg.Hooks.Hourly { c := cmd f := func() { log.Println("Executing hourly hook:", c) - executeCommand(c) + a.cfg.Hooks.executeCommand(c) } hourlyHooks = append(hourlyHooks, f) } @@ -121,8 +115,8 @@ func startHourlyHooks() { } } -func executeCommand(cmd string) { - out, err := exec.Command(appConfig.Hooks.Shell, "-c", cmd).CombinedOutput() +func (cfg *configHooks) executeCommand(cmd string) { + out, err := exec.Command(cfg.Shell, "-c", cmd).CombinedOutput() if err != nil { log.Println("Failed to execute command:", err.Error()) } diff --git a/http.go b/http.go index 0c2ac78..dd4267c 100644 --- a/http.go +++ b/http.go @@ -44,35 +44,33 @@ const ( appUserAgent = "GoBlog" ) -var d *dynamicHandler - -func startServer() (err error) { +func (a *goBlog) startServer() (err error) { // Start - d = &dynamicHandler{} + a.d = &dynamicHandler{} // Set basic middlewares - var finalHandler http.Handler = d - if appConfig.Server.PublicHTTPS || appConfig.Server.SecurityHeaders { - finalHandler = securityHeaders(finalHandler) + var finalHandler http.Handler = a.d + if a.cfg.Server.PublicHTTPS || a.cfg.Server.SecurityHeaders { + finalHandler = a.securityHeaders(finalHandler) } finalHandler = servertiming.Middleware(finalHandler, nil) finalHandler = middleware.Heartbeat("/ping")(finalHandler) finalHandler = middleware.Compress(flate.DefaultCompression)(finalHandler) finalHandler = middleware.Recoverer(finalHandler) - if appConfig.Server.Logging { - finalHandler = logMiddleware(finalHandler) + if a.cfg.Server.Logging { + finalHandler = a.logMiddleware(finalHandler) } // Create routers that don't change - if err = buildStaticHandlersRouters(); err != nil { + if err = a.buildStaticHandlersRouters(); err != nil { return err } // Load router - if err = reloadRouter(); err != nil { + if err = a.reloadRouter(); err != nil { return err } // Start Onion service - if appConfig.Server.Tor { + if a.cfg.Server.Tor { go func() { - if err := startOnionService(finalHandler); err != nil { + if err := a.startOnionService(finalHandler); err != nil { log.Println("Tor failed:", err.Error()) } }() @@ -84,10 +82,10 @@ func startServer() (err error) { WriteTimeout: 5 * time.Minute, } addShutdownFunc(shutdownServer(s, "main server")) - if appConfig.Server.PublicHTTPS { + if a.cfg.Server.PublicHTTPS { // Configure certmagic.Default.Storage = &certmagic.FileStorage{Path: "data/https"} - certmagic.DefaultACME.Email = appConfig.Server.LetsEncryptMail + certmagic.DefaultACME.Email = a.cfg.Server.LetsEncryptMail certmagic.DefaultACME.CA = certmagic.LetsEncryptProductionCA // Start HTTP server for redirects httpServer := &http.Server{ @@ -104,9 +102,9 @@ func startServer() (err error) { }() // Start HTTPS s.Addr = ":https" - hosts := []string{appConfig.Server.publicHostname} - if appConfig.Server.shortPublicHostname != "" { - hosts = append(hosts, appConfig.Server.shortPublicHostname) + hosts := []string{a.cfg.Server.publicHostname} + if a.cfg.Server.shortPublicHostname != "" { + hosts = append(hosts, a.cfg.Server.shortPublicHostname) } listener, e := certmagic.Listen(hosts) if e != nil { @@ -116,7 +114,7 @@ func startServer() (err error) { return err } } else { - s.Addr = ":" + strconv.Itoa(appConfig.Server.Port) + s.Addr = ":" + strconv.Itoa(a.cfg.Server.Port) if err = s.ListenAndServe(); err != nil && err != http.ErrServerClosed { return err } @@ -142,13 +140,13 @@ func redirectToHttps(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, fmt.Sprintf("https://%s%s", requestHost, r.URL.RequestURI()), http.StatusMovedPermanently) } -func reloadRouter() error { - h, err := buildDynamicRouter() +func (a *goBlog) reloadRouter() error { + h, err := a.buildDynamicRouter() if err != nil { return err } - d.swapHandler(h) - purgeCache() + a.d.swapHandler(h) + a.cache.purge() return nil } @@ -157,107 +155,101 @@ const ( feedPath = ".{feed:rss|json|atom}" ) -var ( - privateMode = false - privateModeHandler = []func(http.Handler) http.Handler{} - - captchaHandler http.Handler - - micropubRouter, indieAuthRouter, webmentionsRouter, notificationsRouter, activitypubRouter, editorRouter, commentsRouter, searchRouter *chi.Mux - - setBlogMiddlewares = map[string]func(http.Handler) http.Handler{} - sectionMiddlewares = map[string]func(http.Handler) http.Handler{} - taxonomyMiddlewares = map[string]func(http.Handler) http.Handler{} - photosMiddlewares = map[string]func(http.Handler) http.Handler{} - searchMiddlewares = map[string]func(http.Handler) http.Handler{} - customPagesMiddlewares = map[string]func(http.Handler) http.Handler{} - commentsMiddlewares = map[string]func(http.Handler) http.Handler{} -) - -func buildStaticHandlersRouters() error { - if pm := appConfig.PrivateMode; pm != nil && pm.Enabled { - privateMode = true - privateModeHandler = append(privateModeHandler, authMiddleware) +func (a *goBlog) buildStaticHandlersRouters() error { + if pm := a.cfg.PrivateMode; pm != nil && pm.Enabled { + a.privateMode = true + a.privateModeHandler = append(a.privateModeHandler, a.authMiddleware) + } else { + a.privateMode = false + a.privateModeHandler = []func(http.Handler) http.Handler{} } - captchaHandler = captcha.Server(500, 250) + a.captchaHandler = captcha.Server(500, 250) - micropubRouter = chi.NewRouter() - micropubRouter.Use(checkIndieAuth) - micropubRouter.Get("/", serveMicropubQuery) - micropubRouter.Post("/", serveMicropubPost) - micropubRouter.Post(micropubMediaSubPath, serveMicropubMedia) + a.micropubRouter = chi.NewRouter() + a.micropubRouter.Use(a.checkIndieAuth) + a.micropubRouter.Get("/", a.serveMicropubQuery) + a.micropubRouter.Post("/", a.serveMicropubPost) + a.micropubRouter.Post(micropubMediaSubPath, a.serveMicropubMedia) - indieAuthRouter = chi.NewRouter() - indieAuthRouter.Get("/", indieAuthRequest) - indieAuthRouter.With(authMiddleware).Post("/accept", indieAuthAccept) - indieAuthRouter.Post("/", indieAuthVerification) - indieAuthRouter.Get("/token", indieAuthToken) - indieAuthRouter.Post("/token", indieAuthToken) + a.indieAuthRouter = chi.NewRouter() + a.indieAuthRouter.Get("/", a.indieAuthRequest) + a.indieAuthRouter.With(a.authMiddleware).Post("/accept", a.indieAuthAccept) + a.indieAuthRouter.Post("/", a.indieAuthVerification) + a.indieAuthRouter.Get("/token", a.indieAuthToken) + a.indieAuthRouter.Post("/token", a.indieAuthToken) - webmentionsRouter = chi.NewRouter() - if wm := appConfig.Webmention; wm != nil && !wm.DisableReceiving { - webmentionsRouter.Post("/", handleWebmention) - webmentionsRouter.Group(func(r chi.Router) { + a.webmentionsRouter = chi.NewRouter() + if wm := a.cfg.Webmention; wm != nil && !wm.DisableReceiving { + a.webmentionsRouter.Post("/", a.handleWebmention) + a.webmentionsRouter.Group(func(r chi.Router) { // Authenticated routes - r.Use(authMiddleware) - r.Get("/", webmentionAdmin) - r.Get(paginationPath, webmentionAdmin) - r.Post("/delete", webmentionAdminDelete) - r.Post("/approve", webmentionAdminApprove) - r.Post("/reverify", webmentionAdminReverify) + r.Use(a.authMiddleware) + r.Get("/", a.webmentionAdmin) + r.Get(paginationPath, a.webmentionAdmin) + r.Post("/delete", a.webmentionAdminDelete) + r.Post("/approve", a.webmentionAdminApprove) + r.Post("/reverify", a.webmentionAdminReverify) }) } - notificationsRouter = chi.NewRouter() - notificationsRouter.Use(authMiddleware) - notificationsRouter.Get("/", notificationsAdmin) - notificationsRouter.Get(paginationPath, notificationsAdmin) - notificationsRouter.Post("/delete", notificationsAdminDelete) + a.notificationsRouter = chi.NewRouter() + a.notificationsRouter.Use(a.authMiddleware) + a.notificationsRouter.Get("/", a.notificationsAdmin) + a.notificationsRouter.Get(paginationPath, a.notificationsAdmin) + a.notificationsRouter.Post("/delete", a.notificationsAdminDelete) - if ap := appConfig.ActivityPub; ap != nil && ap.Enabled { - activitypubRouter = chi.NewRouter() - activitypubRouter.Post("/inbox/{blog}", apHandleInbox) - activitypubRouter.Post("/{blog}/inbox", apHandleInbox) + if ap := a.cfg.ActivityPub; ap != nil && ap.Enabled { + a.activitypubRouter = chi.NewRouter() + a.activitypubRouter.Post("/inbox/{blog}", a.apHandleInbox) + a.activitypubRouter.Post("/{blog}/inbox", a.apHandleInbox) } - editorRouter = chi.NewRouter() - editorRouter.Use(authMiddleware) - editorRouter.Get("/", serveEditor) - editorRouter.Post("/", serveEditorPost) + a.editorRouter = chi.NewRouter() + a.editorRouter.Use(a.authMiddleware) + a.editorRouter.Get("/", a.serveEditor) + a.editorRouter.Post("/", a.serveEditorPost) - commentsRouter = chi.NewRouter() - commentsRouter.Use(privateModeHandler...) - commentsRouter.With(cacheMiddleware, noIndexHeader).Get("/{id:[0-9]+}", serveComment) - commentsRouter.With(captchaMiddleware).Post("/", createComment) - commentsRouter.Group(func(r chi.Router) { + a.commentsRouter = chi.NewRouter() + a.commentsRouter.Use(a.privateModeHandler...) + a.commentsRouter.With(a.cache.cacheMiddleware, noIndexHeader).Get("/{id:[0-9]+}", a.serveComment) + a.commentsRouter.With(a.captchaMiddleware).Post("/", a.createComment) + a.commentsRouter.Group(func(r chi.Router) { // Admin - r.Use(authMiddleware) - r.Get("/", commentsAdmin) - r.Get(paginationPath, commentsAdmin) - r.Post("/delete", commentsAdminDelete) + r.Use(a.authMiddleware) + r.Get("/", a.commentsAdmin) + r.Get(paginationPath, a.commentsAdmin) + r.Post("/delete", a.commentsAdminDelete) }) - searchRouter = chi.NewRouter() - searchRouter.Use(privateModeHandler...) - searchRouter.Use(cacheMiddleware) - searchRouter.Get("/", serveSearch) - searchRouter.Post("/", serveSearch) + a.searchRouter = chi.NewRouter() + a.searchRouter.Use(a.privateModeHandler...) + a.searchRouter.Use(a.cache.cacheMiddleware) + a.searchRouter.Get("/", a.serveSearch) + a.searchRouter.Post("/", a.serveSearch) searchResultPath := "/" + searchPlaceholder - searchRouter.Get(searchResultPath, serveSearchResult) - searchRouter.Get(searchResultPath+feedPath, serveSearchResult) - searchRouter.Get(searchResultPath+paginationPath, serveSearchResult) + a.searchRouter.Get(searchResultPath, a.serveSearchResult) + a.searchRouter.Get(searchResultPath+feedPath, a.serveSearchResult) + a.searchRouter.Get(searchResultPath+paginationPath, a.serveSearchResult) - for blog, blogConfig := range appConfig.Blogs { + a.setBlogMiddlewares = map[string]func(http.Handler) http.Handler{} + a.sectionMiddlewares = map[string]func(http.Handler) http.Handler{} + a.taxonomyMiddlewares = map[string]func(http.Handler) http.Handler{} + a.photosMiddlewares = map[string]func(http.Handler) http.Handler{} + a.searchMiddlewares = map[string]func(http.Handler) http.Handler{} + a.customPagesMiddlewares = map[string]func(http.Handler) http.Handler{} + a.commentsMiddlewares = map[string]func(http.Handler) http.Handler{} + + for blog, blogConfig := range a.cfg.Blogs { sbm := middleware.WithValue(blogContextKey, blog) - setBlogMiddlewares[blog] = sbm + a.setBlogMiddlewares[blog] = sbm - blogPath := blogPath(blog) + blogPath := a.blogPath(blog) for _, section := range blogConfig.Sections { if section.Name != "" { secPath := blogPath + "/" + section.Name - sectionMiddlewares[secPath] = middleware.WithValue(indexConfigKey, &indexConfig{ + a.sectionMiddlewares[secPath] = middleware.WithValue(indexConfigKey, &indexConfig{ path: secPath, section: section, }) @@ -267,12 +259,12 @@ func buildStaticHandlersRouters() error { for _, taxonomy := range blogConfig.Taxonomies { if taxonomy.Name != "" { taxPath := blogPath + "/" + taxonomy.Name - taxonomyMiddlewares[taxPath] = middleware.WithValue(taxonomyContextKey, taxonomy) + a.taxonomyMiddlewares[taxPath] = middleware.WithValue(taxonomyContextKey, taxonomy) } } if blogConfig.Photos != nil && blogConfig.Photos.Enabled { - photosMiddlewares[blog] = middleware.WithValue(indexConfigKey, &indexConfig{ + a.photosMiddlewares[blog] = middleware.WithValue(indexConfigKey, &indexConfig{ path: blogPath + blogConfig.Photos.Path, parameter: blogConfig.Photos.Parameter, title: blogConfig.Photos.Title, @@ -282,15 +274,15 @@ func buildStaticHandlersRouters() error { } if blogConfig.Search != nil && blogConfig.Search.Enabled { - searchMiddlewares[blog] = middleware.WithValue(pathContextKey, blogPath+blogConfig.Search.Path) + a.searchMiddlewares[blog] = middleware.WithValue(pathContextKey, blogPath+blogConfig.Search.Path) } for _, cp := range blogConfig.CustomPages { - customPagesMiddlewares[cp.Path] = middleware.WithValue(customPageContextKey, cp) + a.customPagesMiddlewares[cp.Path] = middleware.WithValue(customPageContextKey, cp) } if commentsConfig := blogConfig.Comments; commentsConfig != nil && commentsConfig.Enabled { - commentsMiddlewares[blog] = middleware.WithValue(pathContextKey, blogPath+"/comment") + a.commentsMiddlewares[blog] = middleware.WithValue(pathContextKey, blogPath+"/comment") } } @@ -301,127 +293,127 @@ var ( taxValueMiddlewares = map[string]func(http.Handler) http.Handler{} ) -func buildDynamicRouter() (*chi.Mux, error) { +func (a *goBlog) buildDynamicRouter() (*chi.Mux, error) { r := chi.NewRouter() // Basic middleware - r.Use(redirectShortDomain) + r.Use(a.redirectShortDomain) r.Use(middleware.RedirectSlashes) r.Use(middleware.CleanPath) r.Use(middleware.GetHead) - if !appConfig.Cache.Enable { + if !a.cfg.Cache.Enable { r.Use(middleware.NoCache) } // No Index Header - if privateMode { + if a.privateMode { r.Use(noIndexHeader) } // Login middleware etc. - r.Use(checkIsLogin) - r.Use(checkIsCaptcha) - r.Use(checkLoggedIn) + r.Use(a.checkIsLogin) + r.Use(a.checkIsCaptcha) + r.Use(a.checkLoggedIn) // Logout - r.With(authMiddleware).Get("/login", serveLogin) - r.With(authMiddleware).Get("/logout", serveLogout) + r.With(a.authMiddleware).Get("/login", serveLogin) + r.With(a.authMiddleware).Get("/logout", a.serveLogout) // Micropub - r.Mount(micropubPath, micropubRouter) + r.Mount(micropubPath, a.micropubRouter) // IndieAuth - r.Mount("/indieauth", indieAuthRouter) + r.Mount("/indieauth", a.indieAuthRouter) // ActivityPub and stuff - if ap := appConfig.ActivityPub; ap != nil && ap.Enabled { - r.Mount("/activitypub", activitypubRouter) - r.With(cacheMiddleware).Get("/.well-known/webfinger", apHandleWebfinger) - r.With(cacheMiddleware).Get("/.well-known/host-meta", handleWellKnownHostMeta) - r.With(cacheMiddleware).Get("/.well-known/nodeinfo", serveNodeInfoDiscover) - r.With(cacheMiddleware).Get("/nodeinfo", serveNodeInfo) + if ap := a.cfg.ActivityPub; ap != nil && ap.Enabled { + r.Mount("/activitypub", a.activitypubRouter) + r.With(a.cache.cacheMiddleware).Get("/.well-known/webfinger", a.apHandleWebfinger) + r.With(a.cache.cacheMiddleware).Get("/.well-known/host-meta", handleWellKnownHostMeta) + r.With(a.cache.cacheMiddleware).Get("/.well-known/nodeinfo", a.serveNodeInfoDiscover) + r.With(a.cache.cacheMiddleware).Get("/nodeinfo", a.serveNodeInfo) } // Webmentions - r.Mount(webmentionPath, webmentionsRouter) + r.Mount(webmentionPath, a.webmentionsRouter) // Notifications - r.Mount(notificationsPath, notificationsRouter) + r.Mount(notificationsPath, a.notificationsRouter) // Posts - pp, err := allPostPaths(statusPublished) + pp, err := a.db.allPostPaths(statusPublished) if err != nil { return nil, err } r.Group(func(r chi.Router) { - r.Use(privateModeHandler...) - r.Use(checkActivityStreamsRequest, cacheMiddleware) + r.Use(a.privateModeHandler...) + r.Use(a.checkActivityStreamsRequest, a.cache.cacheMiddleware) for _, path := range pp { - r.Get(path, servePost) + r.Get(path, a.servePost) } }) // Drafts - dp, err := allPostPaths(statusDraft) + dp, err := a.db.allPostPaths(statusDraft) if err != nil { return nil, err } r.Group(func(r chi.Router) { - r.Use(authMiddleware) + r.Use(a.authMiddleware) for _, path := range dp { - r.Get(path, servePost) + r.Get(path, a.servePost) } }) // Post aliases - allPostAliases, err := allPostAliases() + allPostAliases, err := a.db.allPostAliases() if err != nil { return nil, err } r.Group(func(r chi.Router) { - r.Use(privateModeHandler...) - r.Use(cacheMiddleware) + r.Use(a.privateModeHandler...) + r.Use(a.cache.cacheMiddleware) for _, path := range allPostAliases { - r.Get(path, servePostAlias) + r.Get(path, a.servePostAlias) } }) // Assets - for _, path := range allAssetPaths() { - r.Get(path, serveAsset) + for _, path := range a.allAssetPaths() { + r.Get(path, a.serveAsset) } // Static files for _, path := range allStaticPaths() { - r.Get(path, serveStaticFile) + r.Get(path, a.serveStaticFile) } // Media files - r.With(privateModeHandler...).Get(`/m/{file:[0-9a-fA-F]+(\.[0-9a-zA-Z]+)?}`, serveMediaFile) + r.With(a.privateModeHandler...).Get(`/m/{file:[0-9a-fA-F]+(\.[0-9a-zA-Z]+)?}`, a.serveMediaFile) // Captcha - r.Handle("/captcha/*", captchaHandler) + r.Handle("/captcha/*", a.captchaHandler) // Short paths - r.With(privateModeHandler...).With(cacheMiddleware).Get("/s/{id:[0-9a-fA-F]+}", redirectToLongPath) + r.With(a.privateModeHandler...).With(a.cache.cacheMiddleware).Get("/s/{id:[0-9a-fA-F]+}", a.redirectToLongPath) - for blog, blogConfig := range appConfig.Blogs { - blogPath := blogPath(blog) + for blog, blogConfig := range a.cfg.Blogs { + blogPath := a.blogPath(blog) - sbm := setBlogMiddlewares[blog] + sbm := a.setBlogMiddlewares[blog] // Sections r.Group(func(r chi.Router) { - r.Use(privateModeHandler...) - r.Use(cacheMiddleware, sbm) + r.Use(a.privateModeHandler...) + r.Use(a.cache.cacheMiddleware, sbm) for _, section := range blogConfig.Sections { if section.Name != "" { secPath := blogPath + "/" + section.Name r.Group(func(r chi.Router) { - r.Use(sectionMiddlewares[secPath]) - r.Get(secPath, serveIndex) - r.Get(secPath+feedPath, serveIndex) - r.Get(secPath+paginationPath, serveIndex) + r.Use(a.sectionMiddlewares[secPath]) + r.Get(secPath, a.serveIndex) + r.Get(secPath+feedPath, a.serveIndex) + r.Get(secPath+paginationPath, a.serveIndex) }) } } @@ -431,14 +423,14 @@ func buildDynamicRouter() (*chi.Mux, error) { for _, taxonomy := range blogConfig.Taxonomies { if taxonomy.Name != "" { taxPath := blogPath + "/" + taxonomy.Name - taxValues, err := allTaxonomyValues(blog, taxonomy.Name) + taxValues, err := a.db.allTaxonomyValues(blog, taxonomy.Name) if err != nil { return nil, err } r.Group(func(r chi.Router) { - r.Use(privateModeHandler...) - r.Use(cacheMiddleware, sbm) - r.With(taxonomyMiddlewares[taxPath]).Get(taxPath, serveTaxonomy) + r.Use(a.privateModeHandler...) + r.Use(a.cache.cacheMiddleware, sbm) + r.With(a.taxonomyMiddlewares[taxPath]).Get(taxPath, a.serveTaxonomy) for _, tv := range taxValues { r.Group(func(r chi.Router) { vPath := taxPath + "/" + urlize(tv) @@ -450,9 +442,9 @@ func buildDynamicRouter() (*chi.Mux, error) { }) } r.Use(taxValueMiddlewares[vPath]) - r.Get(vPath, serveIndex) - r.Get(vPath+feedPath, serveIndex) - r.Get(vPath+paginationPath, serveIndex) + r.Get(vPath, a.serveIndex) + r.Get(vPath+feedPath, a.serveIndex) + r.Get(vPath+paginationPath, a.serveIndex) }) } }) @@ -462,75 +454,75 @@ func buildDynamicRouter() (*chi.Mux, error) { // Photos if blogConfig.Photos != nil && blogConfig.Photos.Enabled { r.Group(func(r chi.Router) { - r.Use(privateModeHandler...) - r.Use(cacheMiddleware, sbm, photosMiddlewares[blog]) + r.Use(a.privateModeHandler...) + r.Use(a.cache.cacheMiddleware, sbm, a.photosMiddlewares[blog]) photoPath := blogPath + blogConfig.Photos.Path - r.Get(photoPath, serveIndex) - r.Get(photoPath+feedPath, serveIndex) - r.Get(photoPath+paginationPath, serveIndex) + r.Get(photoPath, a.serveIndex) + r.Get(photoPath+feedPath, a.serveIndex) + r.Get(photoPath+paginationPath, a.serveIndex) }) } // Search if blogConfig.Search != nil && blogConfig.Search.Enabled { searchPath := blogPath + blogConfig.Search.Path - r.With(sbm, searchMiddlewares[blog]).Mount(searchPath, searchRouter) + r.With(sbm, a.searchMiddlewares[blog]).Mount(searchPath, a.searchRouter) } // Stats if blogConfig.BlogStats != nil && blogConfig.BlogStats.Enabled { statsPath := blogPath + blogConfig.BlogStats.Path r.Group(func(r chi.Router) { - r.Use(privateModeHandler...) - r.Use(cacheMiddleware, sbm) - r.Get(statsPath, serveBlogStats) - r.Get(statsPath+".table.html", serveBlogStatsTable) + r.Use(a.privateModeHandler...) + r.Use(a.cache.cacheMiddleware, sbm) + r.Get(statsPath, a.serveBlogStats) + r.Get(statsPath+".table.html", a.serveBlogStatsTable) }) } // Date archives r.Group(func(r chi.Router) { - r.Use(privateModeHandler...) - r.Use(cacheMiddleware, sbm) + r.Use(a.privateModeHandler...) + r.Use(a.cache.cacheMiddleware, sbm) yearRegex := `/{year:x|\d\d\d\d}` monthRegex := `/{month:x|\d\d}` dayRegex := `/{day:\d\d}` yearPath := blogPath + yearRegex - r.Get(yearPath, serveDate) - r.Get(yearPath+feedPath, serveDate) - r.Get(yearPath+paginationPath, serveDate) + r.Get(yearPath, a.serveDate) + r.Get(yearPath+feedPath, a.serveDate) + r.Get(yearPath+paginationPath, a.serveDate) monthPath := yearPath + monthRegex - r.Get(monthPath, serveDate) - r.Get(monthPath+feedPath, serveDate) - r.Get(monthPath+paginationPath, serveDate) + r.Get(monthPath, a.serveDate) + r.Get(monthPath+feedPath, a.serveDate) + r.Get(monthPath+paginationPath, a.serveDate) dayPath := monthPath + dayRegex - r.Get(dayPath, serveDate) - r.Get(dayPath+feedPath, serveDate) - r.Get(dayPath+paginationPath, serveDate) + r.Get(dayPath, a.serveDate) + r.Get(dayPath+feedPath, a.serveDate) + r.Get(dayPath+paginationPath, a.serveDate) }) // Blog if !blogConfig.PostAsHome { r.Group(func(r chi.Router) { - r.Use(privateModeHandler...) + r.Use(a.privateModeHandler...) r.Use(sbm) - r.With(checkActivityStreamsRequest, cacheMiddleware).Get(blogConfig.Path, serveHome) - r.With(cacheMiddleware).Get(blogConfig.Path+feedPath, serveHome) - r.With(cacheMiddleware).Get(blogPath+paginationPath, serveHome) + r.With(a.checkActivityStreamsRequest, a.cache.cacheMiddleware).Get(blogConfig.Path, a.serveHome) + r.With(a.cache.cacheMiddleware).Get(blogConfig.Path+feedPath, a.serveHome) + r.With(a.cache.cacheMiddleware).Get(blogPath+paginationPath, a.serveHome) }) } // Custom pages for _, cp := range blogConfig.CustomPages { - scp := customPagesMiddlewares[cp.Path] + scp := a.customPagesMiddlewares[cp.Path] if cp.Cache { - r.With(privateModeHandler...).With(cacheMiddleware, sbm, scp).Get(cp.Path, serveCustomPage) + r.With(a.privateModeHandler...).With(a.cache.cacheMiddleware, sbm, scp).Get(cp.Path, a.serveCustomPage) } else { - r.With(privateModeHandler...).With(sbm, scp).Get(cp.Path, serveCustomPage) + r.With(a.privateModeHandler...).With(sbm, scp).Get(cp.Path, a.serveCustomPage) } } @@ -540,50 +532,50 @@ func buildDynamicRouter() (*chi.Mux, error) { if randomPath == "" { randomPath = "/random" } - r.With(privateModeHandler...).With(sbm).Get(blogPath+randomPath, redirectToRandomPost) + r.With(a.privateModeHandler...).With(sbm).Get(blogPath+randomPath, a.redirectToRandomPost) } // Editor - r.With(sbm).Mount(blogPath+"/editor", editorRouter) + r.With(sbm).Mount(blogPath+"/editor", a.editorRouter) // Comments if commentsConfig := blogConfig.Comments; commentsConfig != nil && commentsConfig.Enabled { commentsPath := blogPath + "/comment" - r.With(sbm, commentsMiddlewares[blog]).Mount(commentsPath, commentsRouter) + r.With(sbm, a.commentsMiddlewares[blog]).Mount(commentsPath, a.commentsRouter) } // Blogroll if brConfig := blogConfig.Blogroll; brConfig != nil && brConfig.Enabled { brPath := blogPath + brConfig.Path r.Group(func(r chi.Router) { - r.Use(privateModeHandler...) - r.Use(cacheMiddleware, sbm) - r.Get(brPath, serveBlogroll) - r.Get(brPath+".opml", serveBlogrollExport) + r.Use(a.privateModeHandler...) + r.Use(a.cache.cacheMiddleware, sbm) + r.Get(brPath, a.serveBlogroll) + r.Get(brPath+".opml", a.serveBlogrollExport) }) } } // Sitemap - r.With(privateModeHandler...).With(cacheMiddleware).Get(sitemapPath, serveSitemap) + r.With(a.privateModeHandler...).With(a.cache.cacheMiddleware).Get(sitemapPath, a.serveSitemap) // Robots.txt - doesn't need cache, because it's too simple - if !privateMode { - r.Get("/robots.txt", serveRobotsTXT) + if !a.privateMode { + r.Get("/robots.txt", a.serveRobotsTXT) } else { r.Get("/robots.txt", servePrivateRobotsTXT) } // Check redirects, then serve 404 - r.With(cacheMiddleware, checkRegexRedirects).NotFound(serve404) + r.With(a.cache.cacheMiddleware, a.checkRegexRedirects).NotFound(a.serve404) - r.MethodNotAllowed(serveNotAllowed) + r.MethodNotAllowed(a.serveNotAllowed) return r, nil } -func blogPath(blog string) string { - blogPath := appConfig.Blogs[blog].Path +func (a *goBlog) blogPath(blog string) string { + blogPath := a.cfg.Blogs[blog].Path if blogPath == "/" { return "" } @@ -595,20 +587,20 @@ const pathContextKey requestContextKey = "httpPath" var cspDomains = "" -func refreshCSPDomains() { +func (a *goBlog) refreshCSPDomains() { cspDomains = "" - if mp := appConfig.Micropub.MediaStorage; mp != nil && mp.MediaURL != "" { + if mp := a.cfg.Micropub.MediaStorage; mp != nil && mp.MediaURL != "" { if u, err := url.Parse(mp.MediaURL); err == nil { cspDomains += " " + u.Hostname() } } - if len(appConfig.Server.CSPDomains) > 0 { - cspDomains += " " + strings.Join(appConfig.Server.CSPDomains, " ") + if len(a.cfg.Server.CSPDomains) > 0 { + cspDomains += " " + strings.Join(a.cfg.Server.CSPDomains, " ") } } -func securityHeaders(next http.Handler) http.Handler { - refreshCSPDomains() +func (a *goBlog) securityHeaders(next http.Handler) http.Handler { + a.refreshCSPDomains() return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Strict-Transport-Security", "max-age=31536000;") w.Header().Set("Referrer-Policy", "no-referrer") @@ -616,8 +608,8 @@ func securityHeaders(next http.Handler) http.Handler { w.Header().Set("X-Frame-Options", "SAMEORIGIN") w.Header().Set("X-Xss-Protection", "1; mode=block") w.Header().Set("Content-Security-Policy", "default-src 'self'"+cspDomains) - if appConfig.Server.Tor && torAddress != "" { - w.Header().Set("Onion-Location", fmt.Sprintf("http://%v%v", torAddress, r.RequestURI)) + if a.cfg.Server.Tor && a.torAddress != "" { + w.Header().Set("Onion-Location", fmt.Sprintf("http://%v%v", a.torAddress, r.RequestURI)) } next.ServeHTTP(w, r) }) diff --git a/httpLogs.go b/httpLogs.go index f182d2d..5ee3a20 100644 --- a/httpLogs.go +++ b/httpLogs.go @@ -8,15 +8,13 @@ import ( rotatelogs "github.com/lestrrat-go/file-rotatelogs" ) -var logf *rotatelogs.RotateLogs - -func initHTTPLog() (err error) { - if !appConfig.Server.Logging { +func (a *goBlog) initHTTPLog() (err error) { + if !a.cfg.Server.Logging { return nil } - logf, err = rotatelogs.New( - appConfig.Server.LogFile+".%Y%m%d", - rotatelogs.WithLinkName(appConfig.Server.LogFile), + a.logf, err = rotatelogs.New( + a.cfg.Server.LogFile+".%Y%m%d", + rotatelogs.WithLinkName(a.cfg.Server.LogFile), rotatelogs.WithClock(rotatelogs.UTC), rotatelogs.WithMaxAge(30*24*time.Hour), rotatelogs.WithRotationTime(24*time.Hour), @@ -24,8 +22,8 @@ func initHTTPLog() (err error) { return } -func logMiddleware(next http.Handler) http.Handler { - h := handlers.CombinedLoggingHandler(logf, next) +func (a *goBlog) logMiddleware(next http.Handler) http.Handler { + h := handlers.CombinedLoggingHandler(a.logf, next) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Remove remote address for privacy r.RemoteAddr = "" diff --git a/indieAuth.go b/indieAuth.go index c5bbcd8..7f6503a 100644 --- a/indieAuth.go +++ b/indieAuth.go @@ -8,15 +8,15 @@ import ( const indieAuthScope requestContextKey = "scope" -func checkIndieAuth(next http.Handler) http.Handler { +func (a *goBlog) checkIndieAuth(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { bearerToken := r.Header.Get("Authorization") if len(bearerToken) == 0 { bearerToken = r.URL.Query().Get("access_token") } - tokenData, err := verifyIndieAuthToken(bearerToken) + tokenData, err := a.db.verifyIndieAuthToken(bearerToken) if err != nil { - serveError(w, r, err.Error(), http.StatusUnauthorized) + a.serveError(w, r, err.Error(), http.StatusUnauthorized) return } next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), indieAuthScope, strings.Join(tokenData.Scopes, " ")))) diff --git a/indieAuthServer.go b/indieAuthServer.go index 561a7b1..2da8a6e 100644 --- a/indieAuthServer.go +++ b/indieAuthServer.go @@ -27,10 +27,10 @@ type indieAuthData struct { time time.Time } -func indieAuthRequest(w http.ResponseWriter, r *http.Request) { +func (a *goBlog) indieAuthRequest(w http.ResponseWriter, r *http.Request) { // Authorization request if err := r.ParseForm(); err != nil { - serveError(w, r, err.Error(), http.StatusBadRequest) + a.serveError(w, r, err.Error(), http.StatusBadRequest) return } data := &indieAuthData{ @@ -39,21 +39,21 @@ func indieAuthRequest(w http.ResponseWriter, r *http.Request) { State: r.Form.Get("state"), } if rt := r.Form.Get("response_type"); rt != "code" && rt != "id" && rt != "" { - serveError(w, r, "response_type must be code", http.StatusBadRequest) + a.serveError(w, r, "response_type must be code", http.StatusBadRequest) return } if scope := r.Form.Get("scope"); scope != "" { data.Scopes = strings.Split(scope, " ") } if !isValidProfileURL(data.ClientID) || !isValidProfileURL(data.RedirectURI) { - serveError(w, r, "client_id and redirect_uri need to by valid URLs", http.StatusBadRequest) + a.serveError(w, r, "client_id and redirect_uri need to by valid URLs", http.StatusBadRequest) return } if data.State == "" { - serveError(w, r, "state must not be empty", http.StatusBadRequest) + a.serveError(w, r, "state must not be empty", http.StatusBadRequest) return } - render(w, r, "indieauth", &renderData{ + a.render(w, r, "indieauth", &renderData{ Data: data, }) } @@ -79,10 +79,10 @@ func isValidProfileURL(profileURL string) bool { return true } -func indieAuthAccept(w http.ResponseWriter, r *http.Request) { +func (a *goBlog) indieAuthAccept(w http.ResponseWriter, r *http.Request) { // Authentication flow if err := r.ParseForm(); err != nil { - serveError(w, r, err.Error(), http.StatusBadRequest) + a.serveError(w, r, err.Error(), http.StatusBadRequest) return } data := &indieAuthData{ @@ -94,13 +94,13 @@ func indieAuthAccept(w http.ResponseWriter, r *http.Request) { } sha := sha1.New() if _, err := sha.Write([]byte(data.time.String() + data.ClientID)); err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return } data.code = fmt.Sprintf("%x", sha.Sum(nil)) - err := data.saveAuthorization() + err := a.db.saveAuthorization(data) if err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return } http.Redirect(w, r, data.RedirectURI+"?code="+data.code+"&state="+data.State, http.StatusFound) @@ -114,10 +114,10 @@ type tokenResponse struct { ClientID string `json:"client_id,omitempty"` } -func indieAuthVerification(w http.ResponseWriter, r *http.Request) { +func (a *goBlog) indieAuthVerification(w http.ResponseWriter, r *http.Request) { // Authorization verification if err := r.ParseForm(); err != nil { - serveError(w, r, err.Error(), http.StatusBadRequest) + a.serveError(w, r, err.Error(), http.StatusBadRequest) return } data := &indieAuthData{ @@ -125,33 +125,33 @@ func indieAuthVerification(w http.ResponseWriter, r *http.Request) { ClientID: r.Form.Get("client_id"), RedirectURI: r.Form.Get("redirect_uri"), } - valid, err := data.verifyAuthorization() + valid, err := a.db.verifyAuthorization(data) if err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return } if !valid { - serveError(w, r, "Authentication not valid", http.StatusForbidden) + a.serveError(w, r, "Authentication not valid", http.StatusForbidden) return } b, _ := json.Marshal(tokenResponse{ - Me: appConfig.Server.PublicAddress, + Me: a.cfg.Server.PublicAddress, }) w.Header().Set(contentType, contentTypeJSONUTF8) _, _ = writeMinified(w, contentTypeJSON, b) } -func indieAuthToken(w http.ResponseWriter, r *http.Request) { +func (a *goBlog) indieAuthToken(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodGet { // Token verification - data, err := verifyIndieAuthToken(r.Header.Get("Authorization")) + data, err := a.db.verifyIndieAuthToken(r.Header.Get("Authorization")) if err != nil { - serveError(w, r, "Invalid token or token not found", http.StatusUnauthorized) + a.serveError(w, r, "Invalid token or token not found", http.StatusUnauthorized) return } res := &tokenResponse{ Scope: strings.Join(data.Scopes, " "), - Me: appConfig.Server.PublicAddress, + Me: a.cfg.Server.PublicAddress, ClientID: data.ClientID, } b, _ := json.Marshal(res) @@ -160,12 +160,12 @@ func indieAuthToken(w http.ResponseWriter, r *http.Request) { return } else if r.Method == http.MethodPost { if err := r.ParseForm(); err != nil { - serveError(w, r, err.Error(), http.StatusBadRequest) + a.serveError(w, r, err.Error(), http.StatusBadRequest) return } // Token Revocation if r.Form.Get("action") == "revoke" { - revokeIndieAuthToken(r.Form.Get("token")) + a.db.revokeIndieAuthToken(r.Form.Get("token")) w.WriteHeader(http.StatusOK) return } @@ -176,55 +176,55 @@ func indieAuthToken(w http.ResponseWriter, r *http.Request) { ClientID: r.Form.Get("client_id"), RedirectURI: r.Form.Get("redirect_uri"), } - valid, err := data.verifyAuthorization() + valid, err := a.db.verifyAuthorization(data) if err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return } if !valid { - serveError(w, r, "Authentication not valid", http.StatusForbidden) + a.serveError(w, r, "Authentication not valid", http.StatusForbidden) return } if len(data.Scopes) < 1 { - serveError(w, r, "No scope", http.StatusBadRequest) + a.serveError(w, r, "No scope", http.StatusBadRequest) return } data.time = time.Now() sha := sha1.New() if _, err := sha.Write([]byte(data.time.String() + data.ClientID)); err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return } data.token = fmt.Sprintf("%x", sha.Sum(nil)) - err = data.saveToken() + err = a.db.saveToken(data) if err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return } res := &tokenResponse{ TokenType: "Bearer", AccessToken: data.token, Scope: strings.Join(data.Scopes, " "), - Me: appConfig.Server.PublicAddress, + Me: a.cfg.Server.PublicAddress, } b, _ := json.Marshal(res) w.Header().Set(contentType, contentTypeJSONUTF8) _, _ = writeMinified(w, contentTypeJSON, b) return } - serveError(w, r, "", http.StatusBadRequest) + a.serveError(w, r, "", http.StatusBadRequest) return } } -func (data *indieAuthData) saveAuthorization() (err error) { - _, err = appDb.exec("insert into indieauthauth (time, code, client, redirect, scope) values (?, ?, ?, ?, ?)", data.time.Unix(), data.code, data.ClientID, data.RedirectURI, strings.Join(data.Scopes, " ")) +func (db *database) saveAuthorization(data *indieAuthData) (err error) { + _, err = db.exec("insert into indieauthauth (time, code, client, redirect, scope) values (?, ?, ?, ?, ?)", data.time.Unix(), data.code, data.ClientID, data.RedirectURI, strings.Join(data.Scopes, " ")) return } -func (data *indieAuthData) verifyAuthorization() (valid bool, err error) { +func (db *database) verifyAuthorization(data *indieAuthData) (valid bool, err error) { // code valid for 600 seconds - row, err := appDb.queryRow("select code, client, redirect, scope from indieauthauth where time >= ? and code = ? and client = ? and redirect = ?", time.Now().Unix()-600, data.code, data.ClientID, data.RedirectURI) + row, err := db.queryRow("select code, client, redirect, scope from indieauthauth where time >= ? and code = ? and client = ? and redirect = ?", time.Now().Unix()-600, data.code, data.ClientID, data.RedirectURI) if err != nil { return false, err } @@ -239,22 +239,22 @@ func (data *indieAuthData) verifyAuthorization() (valid bool, err error) { data.Scopes = strings.Split(scope, " ") } valid = true - _, err = appDb.exec("delete from indieauthauth where code = ? or time < ?", data.code, time.Now().Unix()-600) + _, err = db.exec("delete from indieauthauth where code = ? or time < ?", data.code, time.Now().Unix()-600) data.code = "" return } -func (data *indieAuthData) saveToken() (err error) { - _, err = appDb.exec("insert into indieauthtoken (time, token, client, scope) values (?, ?, ?, ?)", data.time.Unix(), data.token, data.ClientID, strings.Join(data.Scopes, " ")) +func (db *database) saveToken(data *indieAuthData) (err error) { + _, err = db.exec("insert into indieauthtoken (time, token, client, scope) values (?, ?, ?, ?)", data.time.Unix(), data.token, data.ClientID, strings.Join(data.Scopes, " ")) return } -func verifyIndieAuthToken(token string) (data *indieAuthData, err error) { +func (db *database) verifyIndieAuthToken(token string) (data *indieAuthData, err error) { token = strings.ReplaceAll(token, "Bearer ", "") data = &indieAuthData{ Scopes: []string{}, } - row, err := appDb.queryRow("select time, token, client, scope from indieauthtoken where token = @token", sql.Named("token", token)) + row, err := db.queryRow("select time, token, client, scope from indieauthtoken where token = @token", sql.Named("token", token)) if err != nil { return nil, err } @@ -273,8 +273,8 @@ func verifyIndieAuthToken(token string) (data *indieAuthData, err error) { return } -func revokeIndieAuthToken(token string) { +func (db *database) revokeIndieAuthToken(token string) { if token != "" { - _, _ = appDb.exec("delete from indieauthtoken where token=?", token) + _, _ = db.exec("delete from indieauthtoken where token=?", token) } } diff --git a/main.go b/main.go index bbea259..e8fcdcb 100644 --- a/main.go +++ b/main.go @@ -47,9 +47,11 @@ func main() { }() } + app := &goBlog{} + // Initialize config log.Println("Initialize configuration...") - if err = initConfig(); err != nil { + if err = app.initConfig(); err != nil { logErrAndQuit("Failed to init config:", err.Error()) return } @@ -57,7 +59,7 @@ func main() { // Healthcheck tool if len(os.Args) >= 2 && os.Args[1] == "healthcheck" { // Connect to public address + "/ping" and exit with 0 when successful - health := healthcheckExitCode() + health := app.healthcheckExitCode() shutdown() os.Exit(health) return @@ -66,8 +68,8 @@ func main() { // Tool to generate TOTP secret if len(os.Args) >= 2 && os.Args[1] == "totp-secret" { key, err := totp.Generate(totp.GenerateOpts{ - Issuer: appConfig.Server.PublicAddress, - AccountName: appConfig.User.Nick, + Issuer: app.cfg.Server.PublicAddress, + AccountName: app.cfg.User.Nick, }) if err != nil { logErrAndQuit(err.Error()) @@ -82,65 +84,64 @@ func main() { initGC() // Execute pre-start hooks - preStartHooks() + app.preStartHooks() // Initialize database and markdown log.Println("Initialize database...") - if err = initDatabase(); err != nil { + if err = app.initDatabase(); err != nil { logErrAndQuit("Failed to init database:", err.Error()) return } log.Println("Initialize server components...") - initMarkdown() + app.initMarkdown() // Link check tool after init of markdown if len(os.Args) >= 2 && os.Args[1] == "check" { - checkAllExternalLinks() + app.checkAllExternalLinks() shutdown() return } // More initializations - initMinify() - if err = initTemplateAssets(); err != nil { // Needs minify + if err = app.initTemplateAssets(); err != nil { // Needs minify logErrAndQuit("Failed to init template assets:", err.Error()) return } - if err = initTemplateStrings(); err != nil { + if err = app.initTemplateStrings(); err != nil { logErrAndQuit("Failed to init template translations:", err.Error()) return } - if err = initRendering(); err != nil { // Needs assets and minify + if err = app.initRendering(); err != nil { // Needs assets and minify logErrAndQuit("Failed to init HTML rendering:", err.Error()) return } - if err = initCache(); err != nil { + if err = app.initCache(); err != nil { logErrAndQuit("Failed to init HTTP cache:", err.Error()) return } - if err = initRegexRedirects(); err != nil { + if err = app.initRegexRedirects(); err != nil { logErrAndQuit("Failed to init redirects:", err.Error()) return } - if err = initHTTPLog(); err != nil { + if err = app.initHTTPLog(); err != nil { logErrAndQuit("Failed to init HTTP logging:", err.Error()) return } - if err = initActivityPub(); err != nil { + if err = app.initActivityPub(); err != nil { logErrAndQuit("Failed to init ActivityPub:", err.Error()) return } - initWebmention() - initTelegram() - initBlogStats() - initSessions() + app.initWebmention() + app.initTelegram() + app.initBlogStats() + app.initSessions() // Start cron hooks - startHourlyHooks() + app.startHourlyHooks() // Start the server log.Println("Starting server(s)...") - err = startServer() + err = app.startServer() if err != nil { logErrAndQuit("Failed to start server(s):", err.Error()) return diff --git a/markdown.go b/markdown.go index e1a6fc9..9ad8385 100644 --- a/markdown.go +++ b/markdown.go @@ -15,9 +15,7 @@ import ( "github.com/yuin/goldmark/util" ) -var defaultMarkdown, absoluteMarkdown goldmark.Markdown - -func initMarkdown() { +func (a *goBlog) initMarkdown() { defaultGoldmarkOptions := []goldmark.Option{ goldmark.WithRendererOptions( html.WithUnsafe(), @@ -35,22 +33,28 @@ func initMarkdown() { emoji.Emoji, ), } - defaultMarkdown = goldmark.New(append(defaultGoldmarkOptions, goldmark.WithExtensions(&customExtension{absoluteLinks: false}))...) - absoluteMarkdown = goldmark.New(append(defaultGoldmarkOptions, goldmark.WithExtensions(&customExtension{absoluteLinks: true}))...) + a.md = goldmark.New(append(defaultGoldmarkOptions, goldmark.WithExtensions(&customExtension{ + absoluteLinks: false, + publicAddress: a.cfg.Server.PublicAddress, + }))...) + a.absoluteMd = goldmark.New(append(defaultGoldmarkOptions, goldmark.WithExtensions(&customExtension{ + absoluteLinks: true, + publicAddress: a.cfg.Server.PublicAddress, + }))...) } -func renderMarkdown(source string, absoluteLinks bool) (rendered []byte, err error) { +func (a *goBlog) renderMarkdown(source string, absoluteLinks bool) (rendered []byte, err error) { var buffer bytes.Buffer if absoluteLinks { - err = absoluteMarkdown.Convert([]byte(source), &buffer) + err = a.absoluteMd.Convert([]byte(source), &buffer) } else { - err = defaultMarkdown.Convert([]byte(source), &buffer) + err = a.md.Convert([]byte(source), &buffer) } return buffer.Bytes(), err } -func renderText(s string) string { - h, err := renderMarkdown(s, false) +func (a *goBlog) renderText(s string) string { + h, err := a.renderMarkdown(s, false) if err != nil { return "" } @@ -66,18 +70,21 @@ func renderText(s string) string { // Links type customExtension struct { absoluteLinks bool + publicAddress string } func (l *customExtension) Extend(m goldmark.Markdown) { m.Renderer().AddOptions(renderer.WithNodeRenderers( util.Prioritized(&customRenderer{ absoluteLinks: l.absoluteLinks, + publicAddress: l.publicAddress, }, 500), )) } type customRenderer struct { absoluteLinks bool + publicAddress string } func (c *customRenderer) RegisterFuncs(r renderer.NodeRendererFuncRegisterer) { @@ -91,8 +98,8 @@ func (c *customRenderer) renderLink(w util.BufWriter, _ []byte, node ast.Node, e _, _ = w.WriteString(" 0 { - entry.Parameters[appConfig.Micropub.CategoryParam] = mf.Properties.Category + entry.Parameters[a.cfg.Micropub.CategoryParam] = mf.Properties.Category } if len(mf.Properties.InReplyTo) == 1 { - entry.Parameters[appConfig.Micropub.ReplyParam] = mf.Properties.InReplyTo + entry.Parameters[a.cfg.Micropub.ReplyParam] = mf.Properties.InReplyTo } if len(mf.Properties.LikeOf) == 1 { - entry.Parameters[appConfig.Micropub.LikeParam] = mf.Properties.LikeOf + entry.Parameters[a.cfg.Micropub.LikeParam] = mf.Properties.LikeOf } if len(mf.Properties.BookmarkOf) == 1 { - entry.Parameters[appConfig.Micropub.BookmarkParam] = mf.Properties.BookmarkOf + entry.Parameters[a.cfg.Micropub.BookmarkParam] = mf.Properties.BookmarkOf } if len(mf.Properties.Audio) > 0 { - entry.Parameters[appConfig.Micropub.AudioParam] = mf.Properties.Audio + entry.Parameters[a.cfg.Micropub.AudioParam] = mf.Properties.Audio } if len(mf.Properties.Photo) > 0 { for _, photo := range mf.Properties.Photo { if theString, justString := photo.(string); justString { - entry.Parameters[appConfig.Micropub.PhotoParam] = append(entry.Parameters[appConfig.Micropub.PhotoParam], theString) - entry.Parameters[appConfig.Micropub.PhotoDescriptionParam] = append(entry.Parameters[appConfig.Micropub.PhotoDescriptionParam], "") + entry.Parameters[a.cfg.Micropub.PhotoParam] = append(entry.Parameters[a.cfg.Micropub.PhotoParam], theString) + entry.Parameters[a.cfg.Micropub.PhotoDescriptionParam] = append(entry.Parameters[a.cfg.Micropub.PhotoDescriptionParam], "") } else if thePhoto, isPhoto := photo.(map[string]interface{}); isPhoto { - entry.Parameters[appConfig.Micropub.PhotoParam] = append(entry.Parameters[appConfig.Micropub.PhotoParam], cast.ToString(thePhoto["value"])) - entry.Parameters[appConfig.Micropub.PhotoDescriptionParam] = append(entry.Parameters[appConfig.Micropub.PhotoDescriptionParam], cast.ToString(thePhoto["alt"])) + entry.Parameters[a.cfg.Micropub.PhotoParam] = append(entry.Parameters[a.cfg.Micropub.PhotoParam], cast.ToString(thePhoto["value"])) + entry.Parameters[a.cfg.Micropub.PhotoDescriptionParam] = append(entry.Parameters[a.cfg.Micropub.PhotoDescriptionParam], cast.ToString(thePhoto["alt"])) } } } if len(mf.Properties.MpSlug) == 1 { entry.Slug = mf.Properties.MpSlug[0] } - err := entry.computeExtraPostParameters() + err := a.computeExtraPostParameters(entry) if err != nil { return nil, err } @@ -375,7 +375,7 @@ func convertMPMfToPost(mf *microformatItem) (*post, error) { } -func (p *post) computeExtraPostParameters() error { +func (a *goBlog) computeExtraPostParameters(p *post) error { p.Content = regexp.MustCompile("\r\n").ReplaceAllString(p.Content, "\n") if split := strings.Split(p.Content, "---\n"); len(split) >= 3 && len(strings.TrimSpace(split[0])) == 0 { // Contains frontmatter @@ -405,7 +405,7 @@ func (p *post) computeExtraPostParameters() error { p.Blog = blog[0] delete(p.Parameters, "blog") } else { - p.Blog = appConfig.DefaultBlog + p.Blog = a.cfg.DefaultBlog } if path := p.Parameters["path"]; len(path) == 1 { p.Path = path[0] @@ -433,15 +433,15 @@ func (p *post) computeExtraPostParameters() error { } if p.Path == "" && p.Section == "" { // Has no path or section -> default section - p.Section = appConfig.Blogs[p.Blog].DefaultSection + p.Section = a.cfg.Blogs[p.Blog].DefaultSection } if p.Published == "" && p.Section != "" { // Has no published date, but section -> published now p.Published = time.Now().Local().String() } // Add images not in content - images := p.Parameters[appConfig.Micropub.PhotoParam] - imageAlts := p.Parameters[appConfig.Micropub.PhotoDescriptionParam] + images := p.Parameters[a.cfg.Micropub.PhotoParam] + imageAlts := p.Parameters[a.cfg.Micropub.PhotoDescriptionParam] useAlts := len(images) == len(imageAlts) for i, image := range images { if !strings.Contains(p.Content, image) { @@ -455,26 +455,26 @@ func (p *post) computeExtraPostParameters() error { return nil } -func micropubDelete(w http.ResponseWriter, r *http.Request, u *url.URL) { +func (a *goBlog) micropubDelete(w http.ResponseWriter, r *http.Request, u *url.URL) { if !strings.Contains(r.Context().Value(indieAuthScope).(string), "delete") { - serveError(w, r, "delete scope missing", http.StatusForbidden) + a.serveError(w, r, "delete scope missing", http.StatusForbidden) return } - if err := deletePost(u.Path); err != nil { - serveError(w, r, err.Error(), http.StatusBadRequest) + if err := a.deletePost(u.Path); err != nil { + a.serveError(w, r, err.Error(), http.StatusBadRequest) return } http.Redirect(w, r, u.String(), http.StatusNoContent) } -func micropubUpdate(w http.ResponseWriter, r *http.Request, u *url.URL, mf *microformatItem) { +func (a *goBlog) micropubUpdate(w http.ResponseWriter, r *http.Request, u *url.URL, mf *microformatItem) { if !strings.Contains(r.Context().Value(indieAuthScope).(string), "update") { - serveError(w, r, "update scope missing", http.StatusForbidden) + a.serveError(w, r, "update scope missing", http.StatusForbidden) return } - p, err := getPost(u.Path) + p, err := a.db.getPost(u.Path) if err != nil { - serveError(w, r, err.Error(), http.StatusBadRequest) + a.serveError(w, r, err.Error(), http.StatusBadRequest) return } oldPath := p.Path @@ -491,15 +491,15 @@ func micropubUpdate(w http.ResponseWriter, r *http.Request, u *url.URL, mf *micr case "name": p.Parameters["title"] = cast.ToStringSlice(value) case "category": - p.Parameters[appConfig.Micropub.CategoryParam] = cast.ToStringSlice(value) + p.Parameters[a.cfg.Micropub.CategoryParam] = cast.ToStringSlice(value) case "in-reply-to": - p.Parameters[appConfig.Micropub.ReplyParam] = cast.ToStringSlice(value) + p.Parameters[a.cfg.Micropub.ReplyParam] = cast.ToStringSlice(value) case "like-of": - p.Parameters[appConfig.Micropub.LikeParam] = cast.ToStringSlice(value) + p.Parameters[a.cfg.Micropub.LikeParam] = cast.ToStringSlice(value) case "bookmark-of": - p.Parameters[appConfig.Micropub.BookmarkParam] = cast.ToStringSlice(value) + p.Parameters[a.cfg.Micropub.BookmarkParam] = cast.ToStringSlice(value) case "audio": - p.Parameters[appConfig.Micropub.AudioParam] = cast.ToStringSlice(value) + p.Parameters[a.cfg.Micropub.AudioParam] = cast.ToStringSlice(value) // TODO: photo } } @@ -514,23 +514,23 @@ func micropubUpdate(w http.ResponseWriter, r *http.Request, u *url.URL, mf *micr case "updated": p.Updated = strings.TrimSpace(strings.Join(cast.ToStringSlice(value), " ")) case "category": - category := p.Parameters[appConfig.Micropub.CategoryParam] + category := p.Parameters[a.cfg.Micropub.CategoryParam] if category == nil { category = []string{} } - p.Parameters[appConfig.Micropub.CategoryParam] = append(category, cast.ToStringSlice(value)...) + p.Parameters[a.cfg.Micropub.CategoryParam] = append(category, cast.ToStringSlice(value)...) case "in-reply-to": - p.Parameters[appConfig.Micropub.ReplyParam] = cast.ToStringSlice(value) + p.Parameters[a.cfg.Micropub.ReplyParam] = cast.ToStringSlice(value) case "like-of": - p.Parameters[appConfig.Micropub.LikeParam] = cast.ToStringSlice(value) + p.Parameters[a.cfg.Micropub.LikeParam] = cast.ToStringSlice(value) case "bookmark-of": - p.Parameters[appConfig.Micropub.BookmarkParam] = cast.ToStringSlice(value) + p.Parameters[a.cfg.Micropub.BookmarkParam] = cast.ToStringSlice(value) case "audio": - audio := p.Parameters[appConfig.Micropub.CategoryParam] + audio := p.Parameters[a.cfg.Micropub.CategoryParam] if audio == nil { audio = []string{} } - p.Parameters[appConfig.Micropub.AudioParam] = append(audio, cast.ToStringSlice(value)...) + p.Parameters[a.cfg.Micropub.AudioParam] = append(audio, cast.ToStringSlice(value)...) // TODO: photo } } @@ -548,18 +548,18 @@ func micropubUpdate(w http.ResponseWriter, r *http.Request, u *url.URL, mf *micr case "updated": p.Updated = "" case "category": - delete(p.Parameters, appConfig.Micropub.CategoryParam) + delete(p.Parameters, a.cfg.Micropub.CategoryParam) case "in-reply-to": - delete(p.Parameters, appConfig.Micropub.ReplyParam) + delete(p.Parameters, a.cfg.Micropub.ReplyParam) case "like-of": - delete(p.Parameters, appConfig.Micropub.LikeParam) + delete(p.Parameters, a.cfg.Micropub.LikeParam) case "bookmark-of": - delete(p.Parameters, appConfig.Micropub.BookmarkParam) + delete(p.Parameters, a.cfg.Micropub.BookmarkParam) case "audio": - delete(p.Parameters, appConfig.Micropub.AudioParam) + delete(p.Parameters, a.cfg.Micropub.AudioParam) case "photo": - delete(p.Parameters, appConfig.Micropub.PhotoParam) - delete(p.Parameters, appConfig.Micropub.PhotoDescriptionParam) + delete(p.Parameters, a.cfg.Micropub.PhotoParam) + delete(p.Parameters, a.cfg.Micropub.PhotoDescriptionParam) } } } @@ -576,11 +576,11 @@ func micropubUpdate(w http.ResponseWriter, r *http.Request, u *url.URL, mf *micr case "updated": p.Updated = "" case "in-reply-to": - delete(p.Parameters, appConfig.Micropub.ReplyParam) + delete(p.Parameters, a.cfg.Micropub.ReplyParam) case "like-of": - delete(p.Parameters, appConfig.Micropub.LikeParam) + delete(p.Parameters, a.cfg.Micropub.LikeParam) case "bookmark-of": - delete(p.Parameters, appConfig.Micropub.BookmarkParam) + delete(p.Parameters, a.cfg.Micropub.BookmarkParam) // Use content to edit other parameters } } @@ -588,15 +588,15 @@ func micropubUpdate(w http.ResponseWriter, r *http.Request, u *url.URL, mf *micr } } } - err = p.computeExtraPostParameters() + err = a.computeExtraPostParameters(p) if err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return } - err = p.replace(oldPath, oldStatus) + err = a.replacePost(p, oldPath, oldStatus) if err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return } - http.Redirect(w, r, p.fullURL(), http.StatusNoContent) + http.Redirect(w, r, a.fullPostURL(p), http.StatusNoContent) } diff --git a/micropubMedia.go b/micropubMedia.go index 52d8ae6..49909be 100644 --- a/micropubMedia.go +++ b/micropubMedia.go @@ -15,23 +15,23 @@ import ( const micropubMediaSubPath = "/media" -func serveMicropubMedia(w http.ResponseWriter, r *http.Request) { +func (a *goBlog) serveMicropubMedia(w http.ResponseWriter, r *http.Request) { if !strings.Contains(r.Context().Value(indieAuthScope).(string), "media") { - serveError(w, r, "media scope missing", http.StatusForbidden) + a.serveError(w, r, "media scope missing", http.StatusForbidden) return } if ct := r.Header.Get(contentType); !strings.Contains(ct, contentTypeMultipartForm) { - serveError(w, r, "wrong content-type", http.StatusBadRequest) + a.serveError(w, r, "wrong content-type", http.StatusBadRequest) return } err := r.ParseMultipartForm(0) if err != nil { - serveError(w, r, err.Error(), http.StatusBadRequest) + a.serveError(w, r, err.Error(), http.StatusBadRequest) return } file, header, err := r.FormFile("file") if err != nil { - serveError(w, r, err.Error(), http.StatusBadRequest) + a.serveError(w, r, err.Error(), http.StatusBadRequest) return } defer func() { _ = file.Close() }() @@ -39,7 +39,7 @@ func serveMicropubMedia(w http.ResponseWriter, r *http.Request) { defer func() { _ = hashFile.Close() }() fileName, err := getSHA256(hashFile) if err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return } fileExtension := filepath.Ext(header.Filename) @@ -55,22 +55,22 @@ func serveMicropubMedia(w http.ResponseWriter, r *http.Request) { } fileName += strings.ToLower(fileExtension) // Save file - location, err := uploadFile(fileName, file) + location, err := a.uploadFile(fileName, file) if err != nil { - serveError(w, r, "failed to save original file: "+err.Error(), http.StatusInternalServerError) + a.serveError(w, r, "failed to save original file: "+err.Error(), http.StatusInternalServerError) return } // Try to compress file (only when not in private mode) - if pm := appConfig.PrivateMode; !(pm != nil && pm.Enabled) { + if pm := a.cfg.PrivateMode; !(pm != nil && pm.Enabled) { serveCompressionError := func(ce error) { - serveError(w, r, "failed to compress file: "+ce.Error(), http.StatusInternalServerError) + a.serveError(w, r, "failed to compress file: "+ce.Error(), http.StatusInternalServerError) } var compressedLocation string var compressionErr error - if ms := appConfig.Micropub.MediaStorage; ms != nil { + if ms := a.cfg.Micropub.MediaStorage; ms != nil { // Default ShortPixel if ms.ShortPixelKey != "" { - compressedLocation, compressionErr = shortPixel(location, ms) + compressedLocation, compressionErr = a.shortPixel(location, ms) } if compressionErr != nil { serveCompressionError(compressionErr) @@ -78,7 +78,7 @@ func serveMicropubMedia(w http.ResponseWriter, r *http.Request) { } // Fallback Tinify if compressedLocation == "" && ms.TinifyKey != "" { - compressedLocation, compressionErr = tinify(location, ms) + compressedLocation, compressionErr = a.tinify(location, ms) } if compressionErr != nil { serveCompressionError(compressionErr) @@ -86,7 +86,7 @@ func serveMicropubMedia(w http.ResponseWriter, r *http.Request) { } // Fallback Cloudflare if compressedLocation == "" && ms.CloudflareCompressionEnabled { - compressedLocation, compressionErr = cloudflare(location) + compressedLocation, compressionErr = a.cloudflare(location) } if compressionErr != nil { serveCompressionError(compressionErr) @@ -101,10 +101,10 @@ func serveMicropubMedia(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, location, http.StatusCreated) } -func uploadFile(filename string, f io.Reader) (string, error) { - ms := appConfig.Micropub.MediaStorage +func (a *goBlog) uploadFile(filename string, f io.Reader) (string, error) { + ms := a.cfg.Micropub.MediaStorage if ms != nil && ms.BunnyStorageKey != "" && ms.BunnyStorageName != "" { - return uploadToBunny(filename, f, ms) + return ms.uploadToBunny(filename, f) } loc, err := saveMediaFile(filename, f) if err != nil { @@ -113,10 +113,10 @@ func uploadFile(filename string, f io.Reader) (string, error) { if ms != nil && ms.MediaURL != "" { return ms.MediaURL + loc, nil } - return appConfig.Server.PublicAddress + loc, nil + return a.cfg.Server.PublicAddress + loc, nil } -func uploadToBunny(filename string, f io.Reader, config *configMicropubMedia) (location string, err error) { +func (config *configMicropubMedia) uploadToBunny(filename string, f io.Reader) (location string, err error) { if config == nil || config.BunnyStorageName == "" || config.BunnyStorageKey == "" || config.MediaURL == "" { return "", errors.New("Bunny storage not completely configured") } diff --git a/minify.go b/minify.go index e5d1ccc..6b51b59 100644 --- a/minify.go +++ b/minify.go @@ -2,6 +2,7 @@ package main import ( "io" + "sync" "github.com/tdewolff/minify/v2" mCss "github.com/tdewolff/minify/v2/css" @@ -11,22 +12,28 @@ import ( mXml "github.com/tdewolff/minify/v2/xml" ) -var minifier *minify.M +var ( + initMinify sync.Once + minifier *minify.M +) -func initMinify() { - minifier = minify.New() - minifier.AddFunc(contentTypeHTML, mHtml.Minify) - minifier.AddFunc("text/css", mCss.Minify) - minifier.AddFunc(contentTypeXML, mXml.Minify) - minifier.AddFunc("application/javascript", mJs.Minify) - minifier.AddFunc(contentTypeRSS, mXml.Minify) - minifier.AddFunc(contentTypeATOM, mXml.Minify) - minifier.AddFunc(contentTypeJSONFeed, mJson.Minify) - minifier.AddFunc(contentTypeAS, mJson.Minify) +func getMinifier() *minify.M { + initMinify.Do(func() { + minifier = minify.New() + minifier.AddFunc(contentTypeHTML, mHtml.Minify) + minifier.AddFunc("text/css", mCss.Minify) + minifier.AddFunc(contentTypeXML, mXml.Minify) + minifier.AddFunc("application/javascript", mJs.Minify) + minifier.AddFunc(contentTypeRSS, mXml.Minify) + minifier.AddFunc(contentTypeATOM, mXml.Minify) + minifier.AddFunc(contentTypeJSONFeed, mJson.Minify) + minifier.AddFunc(contentTypeAS, mJson.Minify) + }) + return minifier } func writeMinified(w io.Writer, mediatype string, b []byte) (int, error) { - mw := minifier.Writer(mediatype, w) - defer func() { mw.Close() }() + mw := getMinifier().Writer(mediatype, w) + defer func() { _ = mw.Close() }() return mw.Write(b) } diff --git a/nodeinfo.go b/nodeinfo.go index c28892e..473eaae 100644 --- a/nodeinfo.go +++ b/nodeinfo.go @@ -5,11 +5,11 @@ import ( "net/http" ) -func serveNodeInfoDiscover(w http.ResponseWriter, r *http.Request) { +func (a *goBlog) serveNodeInfoDiscover(w http.ResponseWriter, r *http.Request) { b, _ := json.Marshal(map[string]interface{}{ "links": []map[string]interface{}{ { - "href": appConfig.Server.PublicAddress + "/nodeinfo", + "href": a.cfg.Server.PublicAddress + "/nodeinfo", "rel": "http://nodeinfo.diaspora.software/ns/schema/2.1", }, }, @@ -18,8 +18,8 @@ func serveNodeInfoDiscover(w http.ResponseWriter, r *http.Request) { _, _ = writeMinified(w, contentTypeJSON, b) } -func serveNodeInfo(w http.ResponseWriter, r *http.Request) { - localPosts, _ := countPosts(&postsRequestConfig{ +func (a *goBlog) serveNodeInfo(w http.ResponseWriter, r *http.Request) { + localPosts, _ := a.db.countPosts(&postsRequestConfig{ status: statusPublished, }) b, _ := json.Marshal(map[string]interface{}{ @@ -30,7 +30,7 @@ func serveNodeInfo(w http.ResponseWriter, r *http.Request) { }, "usage": map[string]interface{}{ "users": map[string]interface{}{ - "total": len(appConfig.Blogs), + "total": len(a.cfg.Blogs), }, "localPosts": localPosts, }, diff --git a/notifications.go b/notifications.go index b4ef948..4401df2 100644 --- a/notifications.go +++ b/notifications.go @@ -21,15 +21,15 @@ type notification struct { Text string } -func sendNotification(text string) { +func (a *goBlog) sendNotification(text string) { n := ¬ification{ Time: time.Now().Unix(), Text: text, } - if err := saveNotification(n); err != nil { + if err := a.db.saveNotification(n); err != nil { log.Println("Failed to save notification:", err.Error()) } - if an := appConfig.Notifications; an != nil { + if an := a.cfg.Notifications; an != nil { if tg := an.Telegram; tg != nil && tg.Enabled { err := sendTelegramMessage(n.Text, "", tg.BotToken, tg.ChatID) if err != nil { @@ -39,15 +39,15 @@ func sendNotification(text string) { } } -func saveNotification(n *notification) error { - if _, err := appDb.exec("insert into notifications (time, text) values (@time, @text)", sql.Named("time", n.Time), sql.Named("text", n.Text)); err != nil { +func (db *database) saveNotification(n *notification) error { + if _, err := db.exec("insert into notifications (time, text) values (@time, @text)", sql.Named("time", n.Time), sql.Named("text", n.Text)); err != nil { return err } return nil } -func deleteNotification(id int) error { - _, err := appDb.exec("delete from notifications where id = @id", sql.Named("id", id)) +func (db *database) deleteNotification(id int) error { + _, err := db.exec("delete from notifications where id = @id", sql.Named("id", id)) return err } @@ -65,10 +65,10 @@ func buildNotificationsQuery(config *notificationsRequestConfig) (query string, return } -func getNotifications(config *notificationsRequestConfig) ([]*notification, error) { +func (db *database) getNotifications(config *notificationsRequestConfig) ([]*notification, error) { notifications := []*notification{} query, args := buildNotificationsQuery(config) - rows, err := appDb.query(query, args...) + rows, err := db.query(query, args...) if err != nil { return nil, err } @@ -83,10 +83,10 @@ func getNotifications(config *notificationsRequestConfig) ([]*notification, erro return notifications, nil } -func countNotifications(config *notificationsRequestConfig) (count int, err error) { +func (db *database) countNotifications(config *notificationsRequestConfig) (count int, err error) { query, params := buildNotificationsQuery(config) query = "select count(*) from (" + query + ")" - row, err := appDb.queryRow(query, params...) + row, err := db.queryRow(query, params...) if err != nil { return } @@ -97,11 +97,12 @@ func countNotifications(config *notificationsRequestConfig) (count int, err erro type notificationsPaginationAdapter struct { config *notificationsRequestConfig nums int64 + db *database } func (p *notificationsPaginationAdapter) Nums() (int64, error) { if p.nums == 0 { - nums, _ := countNotifications(p.config) + nums, _ := p.db.countNotifications(p.config) p.nums = int64(nums) } return p.nums, nil @@ -112,21 +113,21 @@ func (p *notificationsPaginationAdapter) Slice(offset, length int, data interfac modifiedConfig.offset = offset modifiedConfig.limit = length - notifications, err := getNotifications(&modifiedConfig) + notifications, err := p.db.getNotifications(&modifiedConfig) reflect.ValueOf(data).Elem().Set(reflect.ValueOf(¬ifications).Elem()) return err } -func notificationsAdmin(w http.ResponseWriter, r *http.Request) { +func (a *goBlog) notificationsAdmin(w http.ResponseWriter, r *http.Request) { // Adapter pageNoString := chi.URLParam(r, "page") pageNo, _ := strconv.Atoi(pageNoString) - p := paginator.New(¬ificationsPaginationAdapter{config: ¬ificationsRequestConfig{}}, 10) + p := paginator.New(¬ificationsPaginationAdapter{config: ¬ificationsRequestConfig{}, db: a.db}, 10) p.SetPage(pageNo) var notifications []*notification err := p.Results(¬ifications) if err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return } // Navigation @@ -152,7 +153,7 @@ func notificationsAdmin(w http.ResponseWriter, r *http.Request) { } nextPath = fmt.Sprintf("%s/page/%d", notificationsPath, nextPage) // Render - render(w, r, templateNotificationsAdmin, &renderData{ + a.render(w, r, templateNotificationsAdmin, &renderData{ Data: map[string]interface{}{ "Notifications": notifications, "HasPrev": hasPrev, @@ -163,15 +164,15 @@ func notificationsAdmin(w http.ResponseWriter, r *http.Request) { }) } -func notificationsAdminDelete(w http.ResponseWriter, r *http.Request) { +func (a *goBlog) notificationsAdminDelete(w http.ResponseWriter, r *http.Request) { id, err := strconv.Atoi(r.FormValue("notificationid")) if err != nil { - serveError(w, r, err.Error(), http.StatusBadRequest) + a.serveError(w, r, err.Error(), http.StatusBadRequest) return } - err = deleteNotification(id) + err = a.db.deleteNotification(id) if err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return } http.Redirect(w, r, ".", http.StatusFound) diff --git a/persistentCache.go b/persistentCache.go index bd26b13..07bf090 100644 --- a/persistentCache.go +++ b/persistentCache.go @@ -3,21 +3,17 @@ package main import ( "database/sql" "time" - - "golang.org/x/sync/singleflight" ) -func cachePersistently(key string, data []byte) error { +func (db *database) cachePersistently(key string, data []byte) error { date, _ := toLocal(time.Now().String()) - _, err := appDb.exec("insert or replace into persistent_cache(key, data, date) values(@key, @data, @date)", sql.Named("key", key), sql.Named("data", data), sql.Named("date", date)) + _, err := db.exec("insert or replace into persistent_cache(key, data, date) values(@key, @data, @date)", sql.Named("key", key), sql.Named("data", data), sql.Named("date", date)) return err } -var persistentCacheGroup singleflight.Group - -func retrievePersistentCache(key string) (data []byte, err error) { - d, err, _ := persistentCacheGroup.Do(key, func() (interface{}, error) { - if row, err := appDb.queryRow("select data from persistent_cache where key = @key", sql.Named("key", key)); err == sql.ErrNoRows { +func (db *database) retrievePersistentCache(key string) (data []byte, err error) { + d, err, _ := db.persistentCacheGroup.Do(key, func() (interface{}, error) { + if row, err := db.queryRow("select data from persistent_cache where key = @key", sql.Named("key", key)); err == sql.ErrNoRows { return nil, nil } else if err != nil { return nil, err @@ -32,7 +28,7 @@ func retrievePersistentCache(key string) (data []byte, err error) { return d.([]byte), nil } -func clearPersistentCache(pattern string) error { - _, err := appDb.exec("delete from persistent_cache where key like @pattern", sql.Named("pattern", pattern)) +func (db *database) clearPersistentCache(pattern string) error { + _, err := db.exec("delete from persistent_cache where key like @pattern", sql.Named("pattern", pattern)) return err } diff --git a/postAliases.go b/postAliases.go index cac696d..e663eae 100644 --- a/postAliases.go +++ b/postAliases.go @@ -5,9 +5,9 @@ import ( "net/http" ) -func allPostAliases() ([]string, error) { +func (db *database) allPostAliases() ([]string, error) { var aliases []string - rows, err := appDb.query("select distinct value from post_parameters where parameter = 'aliases' and value != path") + rows, err := db.query("select distinct value from post_parameters where parameter = 'aliases' and value != path") if err != nil { return nil, err } @@ -21,18 +21,18 @@ func allPostAliases() ([]string, error) { return aliases, nil } -func servePostAlias(w http.ResponseWriter, r *http.Request) { - row, err := appDb.queryRow("select path from post_parameters where parameter = 'aliases' and value = @alias", sql.Named("alias", r.URL.Path)) +func (a *goBlog) servePostAlias(w http.ResponseWriter, r *http.Request) { + row, err := a.db.queryRow("select path from post_parameters where parameter = 'aliases' and value = @alias", sql.Named("alias", r.URL.Path)) if err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return } var path string if err := row.Scan(&path); err == sql.ErrNoRows { - serve404(w, r) + a.serve404(w, r) return } else if err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return } http.Redirect(w, r, path, http.StatusFound) diff --git a/posts.go b/posts.go index c25693d..e4f7f45 100644 --- a/posts.go +++ b/posts.go @@ -41,45 +41,45 @@ const ( statusDraft postStatus = "draft" ) -func servePost(w http.ResponseWriter, r *http.Request) { +func (a *goBlog) servePost(w http.ResponseWriter, r *http.Request) { t := servertiming.FromContext(r.Context()).NewMetric("gp").Start() - p, err := getPost(r.URL.Path) + p, err := a.db.getPost(r.URL.Path) t.Stop() if err == errPostNotFound { - serve404(w, r) + a.serve404(w, r) return } else if err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return } if asRequest, ok := r.Context().Value(asRequestKey).(bool); ok && asRequest { - if r.URL.Path == blogPath(p.Blog) { - appConfig.Blogs[p.Blog].serveActivityStreams(p.Blog, w, r) + if r.URL.Path == a.blogPath(p.Blog) { + a.serveActivityStreams(p.Blog, w, r) return } - p.serveActivityStreams(w) + a.serveActivityStreamsPost(p, w) return } canonical := p.firstParameter("original") if canonical == "" { - canonical = p.fullURL() + canonical = a.fullPostURL(p) } template := templatePost - if p.Path == appConfig.Blogs[p.Blog].Path { + if p.Path == a.cfg.Blogs[p.Blog].Path { template = templateStaticHome } - w.Header().Add("Link", fmt.Sprintf("<%s>; rel=shortlink", p.shortURL())) - render(w, r, template, &renderData{ + w.Header().Add("Link", fmt.Sprintf("<%s>; rel=shortlink", a.shortPostURL(p))) + a.render(w, r, template, &renderData{ BlogString: p.Blog, Canonical: canonical, Data: p, }) } -func redirectToRandomPost(rw http.ResponseWriter, r *http.Request) { - randomPath, err := getRandomPostPath(r.Context().Value(blogContextKey).(string)) +func (a *goBlog) redirectToRandomPost(rw http.ResponseWriter, r *http.Request) { + randomPath, err := a.getRandomPostPath(r.Context().Value(blogContextKey).(string)) if err != nil { - serveError(rw, r, err.Error(), http.StatusInternalServerError) + a.serveError(rw, r, err.Error(), http.StatusInternalServerError) return } http.Redirect(rw, r, randomPath, http.StatusFound) @@ -88,11 +88,12 @@ func redirectToRandomPost(rw http.ResponseWriter, r *http.Request) { type postPaginationAdapter struct { config *postsRequestConfig nums int64 + db *database } func (p *postPaginationAdapter) Nums() (int64, error) { if p.nums == 0 { - nums, _ := countPosts(p.config) + nums, _ := p.db.countPosts(p.config) p.nums = int64(nums) } return p.nums, nil @@ -103,23 +104,23 @@ func (p *postPaginationAdapter) Slice(offset, length int, data interface{}) erro modifiedConfig.offset = offset modifiedConfig.limit = length - posts, err := getPosts(&modifiedConfig) + posts, err := p.db.getPosts(&modifiedConfig) reflect.ValueOf(data).Elem().Set(reflect.ValueOf(&posts).Elem()) return err } -func serveHome(w http.ResponseWriter, r *http.Request) { +func (a *goBlog) serveHome(w http.ResponseWriter, r *http.Request) { blog := r.Context().Value(blogContextKey).(string) if asRequest, ok := r.Context().Value(asRequestKey).(bool); ok && asRequest { - appConfig.Blogs[blog].serveActivityStreams(blog, w, r) + a.serveActivityStreams(blog, w, r) return } - serveIndex(w, r.WithContext(context.WithValue(r.Context(), indexConfigKey, &indexConfig{ - path: blogPath(blog), + a.serveIndex(w, r.WithContext(context.WithValue(r.Context(), indexConfigKey, &indexConfig{ + path: a.blogPath(blog), }))) } -func serveDate(w http.ResponseWriter, r *http.Request) { +func (a *goBlog) serveDate(w http.ResponseWriter, r *http.Request) { var year, month, day int if ys := chi.URLParam(r, "year"); ys != "" && ys != "x" { year, _ = strconv.Atoi(ys) @@ -131,11 +132,11 @@ func serveDate(w http.ResponseWriter, r *http.Request) { day, _ = strconv.Atoi(ds) } if year == 0 && month == 0 && day == 0 { - serve404(w, r) + a.serve404(w, r) return } var title, dPath strings.Builder - dPath.WriteString(blogPath(r.Context().Value(blogContextKey).(string)) + "/") + dPath.WriteString(a.blogPath(r.Context().Value(blogContextKey).(string)) + "/") if year != 0 { ys := fmt.Sprintf("%0004d", year) title.WriteString(ys) @@ -155,7 +156,7 @@ func serveDate(w http.ResponseWriter, r *http.Request) { title.WriteString(fmt.Sprintf("-%02d", day)) dPath.WriteString(fmt.Sprintf("/%02d", day)) } - serveIndex(w, r.WithContext(context.WithValue(r.Context(), indexConfigKey, &indexConfig{ + a.serveIndex(w, r.WithContext(context.WithValue(r.Context(), indexConfigKey, &indexConfig{ path: dPath.String(), year: year, month: month, @@ -179,7 +180,7 @@ type indexConfig struct { const indexConfigKey requestContextKey = "indexConfig" -func serveIndex(w http.ResponseWriter, r *http.Request) { +func (a *goBlog) serveIndex(w http.ResponseWriter, r *http.Request) { ic := r.Context().Value(indexConfigKey).(*indexConfig) blog := ic.blog if blog == "" { @@ -195,7 +196,7 @@ func serveIndex(w http.ResponseWriter, r *http.Request) { if ic.section != nil { sections = []string{ic.section.Name} } else { - for sectionKey := range appConfig.Blogs[blog].Sections { + for sectionKey := range a.cfg.Blogs[blog].Sections { sections = append(sections, sectionKey) } } @@ -210,14 +211,14 @@ func serveIndex(w http.ResponseWriter, r *http.Request) { publishedMonth: ic.month, publishedDay: ic.day, status: statusPublished, - }}, appConfig.Blogs[blog].Pagination) + }, db: a.db}, a.cfg.Blogs[blog].Pagination) p.SetPage(pageNo) var posts []*post t := servertiming.FromContext(r.Context()).NewMetric("gp").Start() err := p.Results(&posts) t.Stop() if err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return } // Meta @@ -229,13 +230,13 @@ func serveIndex(w http.ResponseWriter, r *http.Request) { title = ic.section.Title description = ic.section.Description } else if search != "" { - title = fmt.Sprintf("%s: %s", appConfig.Blogs[blog].Search.Title, search) + title = fmt.Sprintf("%s: %s", a.cfg.Blogs[blog].Search.Title, search) } // Clean title title = bluemonday.StrictPolicy().Sanitize(title) // Check if feed if ft := feedType(chi.URLParam(r, "feed")); ft != noFeed { - generateFeed(blog, ft, w, r, posts, title, description) + a.generateFeed(blog, ft, w, r, posts, title, description) return } // Path @@ -269,9 +270,9 @@ func serveIndex(w http.ResponseWriter, r *http.Request) { if summaryTemplate == "" { summaryTemplate = templateSummary } - render(w, r, templateIndex, &renderData{ + a.render(w, r, templateIndex, &renderData{ BlogString: blog, - Canonical: appConfig.Server.PublicAddress + path, + Canonical: a.cfg.Server.PublicAddress + path, Data: map[string]interface{}{ "Title": title, "Description": description, diff --git a/postsDb.go b/postsDb.go index a6553d9..bc3192f 100644 --- a/postsDb.go +++ b/postsDb.go @@ -13,7 +13,7 @@ import ( "github.com/araddon/dateparse" ) -func (p *post) checkPost() (err error) { +func (a *goBlog) checkPost(p *post) (err error) { if p == nil { return errors.New("no post") } @@ -57,13 +57,13 @@ func (p *post) checkPost() (err error) { } // Check blog if p.Blog == "" { - p.Blog = appConfig.DefaultBlog + p.Blog = a.cfg.DefaultBlog } - if _, ok := appConfig.Blogs[p.Blog]; !ok { + if _, ok := a.cfg.Blogs[p.Blog]; !ok { return errors.New("blog doesn't exist") } // Check if section exists - if _, ok := appConfig.Blogs[p.Blog].Sections[p.Section]; p.Section != "" && !ok { + if _, ok := a.cfg.Blogs[p.Blog].Sections[p.Section]; p.Section != "" && !ok { return errors.New("section doesn't exist") } // Check path @@ -72,14 +72,14 @@ func (p *post) checkPost() (err error) { } if p.Path == "" { if p.Section == "" { - p.Section = appConfig.Blogs[p.Blog].DefaultSection + p.Section = a.cfg.Blogs[p.Blog].DefaultSection } if p.Slug == "" { random := generateRandomString(5) p.Slug = fmt.Sprintf("%v-%02d-%02d-%v", now.Year(), int(now.Month()), now.Day(), random) } published, _ := dateparse.ParseLocal(p.Published) - pathTmplString := appConfig.Blogs[p.Blog].Sections[p.Section].PathTemplate + pathTmplString := a.cfg.Blogs[p.Blog].Sections[p.Section].PathTemplate if pathTmplString == "" { return errors.New("path template empty") } @@ -89,7 +89,7 @@ func (p *post) checkPost() (err error) { } var pathBuffer bytes.Buffer err = pathTmpl.Execute(&pathBuffer, map[string]interface{}{ - "BlogPath": appConfig.Blogs[p.Blog].Path, + "BlogPath": a.cfg.Blogs[p.Blog].Path, "Year": published.Year(), "Month": int(published.Month()), "Day": published.Day(), @@ -107,12 +107,12 @@ func (p *post) checkPost() (err error) { return nil } -func (p *post) create() error { - return p.createOrReplace(&postCreationOptions{new: true}) +func (a *goBlog) createPost(p *post) error { + return a.createOrReplacePost(p, &postCreationOptions{new: true}) } -func (p *post) replace(oldPath string, oldStatus postStatus) error { - return p.createOrReplace(&postCreationOptions{new: false, oldPath: oldPath, oldStatus: oldStatus}) +func (a *goBlog) replacePost(p *post, oldPath string, oldStatus postStatus) error { + return a.createOrReplacePost(p, &postCreationOptions{new: false, oldPath: oldPath, oldStatus: oldStatus}) } type postCreationOptions struct { @@ -123,8 +123,8 @@ type postCreationOptions struct { var postCreationMutex sync.Mutex -func (p *post) createOrReplace(o *postCreationOptions) error { - err := p.checkPost() +func (a *goBlog) createOrReplacePost(p *post, o *postCreationOptions) error { + err := a.checkPost(p) if err != nil { return err } @@ -135,7 +135,7 @@ func (p *post) createOrReplace(o *postCreationOptions) error { if o.new || (p.Path != o.oldPath) { // Post is new or post path was changed newPathExists := false - row, err := appDb.queryRow("select exists(select 1 from posts where path = @path)", sql.Named("path", p.Path)) + row, err := a.db.queryRow("select exists(select 1 from posts where path = @path)", sql.Named("path", p.Path)) if err != nil { return err } @@ -169,65 +169,37 @@ func (p *post) createOrReplace(o *postCreationOptions) error { } } // Execute - _, err = appDb.execMulti(sqlBuilder.String(), sqlArgs...) + _, err = a.db.execMulti(sqlBuilder.String(), sqlArgs...) if err != nil { return err } // Update FTS index, trigger hooks and reload router - rebuildFTSIndex() + a.db.rebuildFTSIndex() if p.Status == statusPublished { if o.new || o.oldStatus == statusDraft { - defer p.postPostHooks() + defer a.postPostHooks(p) } else { - defer p.postUpdateHooks() + defer a.postUpdateHooks(p) } } - return reloadRouter() + return a.reloadRouter() } -func deletePost(path string) error { +func (a *goBlog) deletePost(path string) error { if path == "" { return nil } - p, err := getPost(path) + p, err := a.db.getPost(path) if err != nil { return err } - _, err = appDb.exec("delete from posts where path = @path", sql.Named("path", p.Path)) + _, err = a.db.exec("delete from posts where path = @path", sql.Named("path", p.Path)) if err != nil { return err } - rebuildFTSIndex() - defer p.postDeleteHooks() - return reloadRouter() -} - -func rebuildFTSIndex() { - _, _ = appDb.exec("insert into posts_fts(posts_fts) values ('rebuild')") -} - -func getPost(path string) (*post, error) { - posts, err := getPosts(&postsRequestConfig{path: path}) - if err != nil { - return nil, err - } else if len(posts) == 0 { - return nil, errPostNotFound - } - return posts[0], nil -} - -func getRandomPostPath(blog string) (string, error) { - var sections []string - for sectionKey := range appConfig.Blogs[blog].Sections { - sections = append(sections, sectionKey) - } - posts, err := getPosts(&postsRequestConfig{randomOrder: true, limit: 1, blog: blog, sections: sections}) - if err != nil { - return "", err - } else if len(posts) == 0 { - return "", errPostNotFound - } - return posts[0].Path, nil + a.db.rebuildFTSIndex() + defer a.postDeleteHooks(p) + return a.reloadRouter() } type postsRequestConfig struct { @@ -246,39 +218,39 @@ type postsRequestConfig struct { randomOrder bool } -func buildPostsQuery(config *postsRequestConfig) (query string, args []interface{}) { +func buildPostsQuery(c *postsRequestConfig) (query string, args []interface{}) { args = []interface{}{} defaultSelection := "select p.path as path, coalesce(content, '') as content, coalesce(published, '') as published, coalesce(updated, '') as updated, coalesce(blog, '') as blog, coalesce(section, '') as section, coalesce(status, '') as status, coalesce(parameter, '') as parameter, coalesce(value, '') as value " postsTable := "posts" - if config.search != "" { + if c.search != "" { postsTable = "posts_fts(@search)" - args = append(args, sql.Named("search", config.search)) + args = append(args, sql.Named("search", c.search)) } - if config.status != "" && config.status != statusNil { + if c.status != "" && c.status != statusNil { postsTable = "(select * from " + postsTable + " where status = @status)" - args = append(args, sql.Named("status", config.status)) + args = append(args, sql.Named("status", c.status)) } - if config.blog != "" { + if c.blog != "" { postsTable = "(select * from " + postsTable + " where blog = @blog)" - args = append(args, sql.Named("blog", config.blog)) + args = append(args, sql.Named("blog", c.blog)) } - if config.parameter != "" { + if c.parameter != "" { postsTable = "(select distinct p.* from " + postsTable + " p left outer join post_parameters pp on p.path = pp.path where pp.parameter = @param " - args = append(args, sql.Named("param", config.parameter)) - if config.parameterValue != "" { + args = append(args, sql.Named("param", c.parameter)) + if c.parameterValue != "" { postsTable += "and pp.value = @paramval)" - args = append(args, sql.Named("paramval", config.parameterValue)) + args = append(args, sql.Named("paramval", c.parameterValue)) } else { postsTable += "and length(coalesce(pp.value, '')) > 1)" } } - if config.taxonomy != nil && len(config.taxonomyValue) > 0 { + if c.taxonomy != nil && len(c.taxonomyValue) > 0 { postsTable = "(select distinct p.* from " + postsTable + " p left outer join post_parameters pp on p.path = pp.path where pp.parameter = @taxname and lower(pp.value) = lower(@taxval))" - args = append(args, sql.Named("taxname", config.taxonomy.Name), sql.Named("taxval", config.taxonomyValue)) + args = append(args, sql.Named("taxname", c.taxonomy.Name), sql.Named("taxval", c.taxonomyValue)) } - if len(config.sections) > 0 { + if len(c.sections) > 0 { postsTable = "(select * from " + postsTable + " where" - for i, section := range config.sections { + for i, section := range c.sections { if i > 0 { postsTable += " or" } @@ -288,38 +260,38 @@ func buildPostsQuery(config *postsRequestConfig) (query string, args []interface } postsTable += ")" } - if config.publishedYear != 0 { + if c.publishedYear != 0 { postsTable = "(select * from " + postsTable + " p where substr(p.published, 1, 4) = @publishedyear)" - args = append(args, sql.Named("publishedyear", fmt.Sprintf("%0004d", config.publishedYear))) + args = append(args, sql.Named("publishedyear", fmt.Sprintf("%0004d", c.publishedYear))) } - if config.publishedMonth != 0 { + if c.publishedMonth != 0 { postsTable = "(select * from " + postsTable + " p where substr(p.published, 6, 2) = @publishedmonth)" - args = append(args, sql.Named("publishedmonth", fmt.Sprintf("%02d", config.publishedMonth))) + args = append(args, sql.Named("publishedmonth", fmt.Sprintf("%02d", c.publishedMonth))) } - if config.publishedDay != 0 { + if c.publishedDay != 0 { postsTable = "(select * from " + postsTable + " p where substr(p.published, 9, 2) = @publishedday)" - args = append(args, sql.Named("publishedday", fmt.Sprintf("%02d", config.publishedDay))) + args = append(args, sql.Named("publishedday", fmt.Sprintf("%02d", c.publishedDay))) } defaultTables := " from " + postsTable + " p left outer join post_parameters pp on p.path = pp.path " defaultSorting := " order by p.published desc " - if config.randomOrder { + if c.randomOrder { defaultSorting = " order by random() " } - if config.path != "" { + if c.path != "" { query = defaultSelection + defaultTables + " where p.path = @path" + defaultSorting - args = append(args, sql.Named("path", config.path)) - } else if config.limit != 0 || config.offset != 0 { + args = append(args, sql.Named("path", c.path)) + } else if c.limit != 0 || c.offset != 0 { query = defaultSelection + " from (select * from " + postsTable + " p " + defaultSorting + " limit @limit offset @offset) p left outer join post_parameters pp on p.path = pp.path " - args = append(args, sql.Named("limit", config.limit), sql.Named("offset", config.offset)) + args = append(args, sql.Named("limit", c.limit), sql.Named("offset", c.offset)) } else { query = defaultSelection + defaultTables + defaultSorting } return } -func getPosts(config *postsRequestConfig) (posts []*post, err error) { +func (d *database) getPosts(config *postsRequestConfig) (posts []*post, err error) { query, queryParams := buildPostsQuery(config) - rows, err := appDb.query(query, queryParams...) + rows, err := d.query(query, queryParams...) if err != nil { return nil, err } @@ -351,10 +323,25 @@ func getPosts(config *postsRequestConfig) (posts []*post, err error) { return posts, nil } -func countPosts(config *postsRequestConfig) (count int, err error) { +func (d *database) getPost(path string) (*post, error) { + posts, err := d.getPosts(&postsRequestConfig{path: path}) + if err != nil { + return nil, err + } else if len(posts) == 0 { + return nil, errPostNotFound + } + return posts[0], nil +} + +func (d *database) getDrafts(blog string) []*post { + ps, _ := d.getPosts(&postsRequestConfig{status: statusDraft, blog: blog}) + return ps +} + +func (d *database) countPosts(config *postsRequestConfig) (count int, err error) { query, params := buildPostsQuery(config) query = "select count(distinct path) from (" + query + ")" - row, err := appDb.queryRow(query, params...) + row, err := d.queryRow(query, params...) if err != nil { return } @@ -362,9 +349,9 @@ func countPosts(config *postsRequestConfig) (count int, err error) { return } -func allPostPaths(status postStatus) ([]string, error) { +func (d *database) allPostPaths(status postStatus) ([]string, error) { var postPaths []string - rows, err := appDb.query("select path from posts where status = @status", sql.Named("status", status)) + rows, err := d.query("select path from posts where status = @status", sql.Named("status", status)) if err != nil { return nil, err } @@ -378,9 +365,23 @@ func allPostPaths(status postStatus) ([]string, error) { return postPaths, nil } -func allTaxonomyValues(blog string, taxonomy string) ([]string, error) { +func (a *goBlog) getRandomPostPath(blog string) (string, error) { + var sections []string + for sectionKey := range a.cfg.Blogs[blog].Sections { + sections = append(sections, sectionKey) + } + posts, err := a.db.getPosts(&postsRequestConfig{randomOrder: true, limit: 1, blog: blog, sections: sections}) + if err != nil { + return "", err + } else if len(posts) == 0 { + return "", errPostNotFound + } + return posts[0].Path, nil +} + +func (d *database) allTaxonomyValues(blog string, taxonomy string) ([]string, error) { var values []string - rows, err := appDb.query("select distinct pp.value from posts p left outer join post_parameters pp on p.path = pp.path where pp.parameter = @tax and length(coalesce(pp.value, '')) > 1 and blog = @blog and status = @status", sql.Named("tax", taxonomy), sql.Named("blog", blog), sql.Named("status", statusPublished)) + rows, err := d.query("select distinct pp.value from posts p left outer join post_parameters pp on p.path = pp.path where pp.parameter = @tax and length(coalesce(pp.value, '')) > 1 and blog = @blog and status = @status", sql.Named("tax", taxonomy), sql.Named("blog", blog), sql.Named("status", statusPublished)) if err != nil { return nil, err } @@ -396,8 +397,8 @@ type publishedDate struct { year, month, day int } -func allPublishedDates(blog string) (dates []publishedDate, err error) { - rows, err := appDb.query("select distinct substr(published, 1, 4) as year, substr(published, 6, 2) as month, substr(published, 9, 2) as day from posts where blog = @blog and status = @status and year != '' and month != '' and day != ''", sql.Named("blog", blog), sql.Named("status", statusPublished)) +func (d *database) allPublishedDates(blog string) (dates []publishedDate, err error) { + rows, err := d.query("select distinct substr(published, 1, 4) as year, substr(published, 6, 2) as month, substr(published, 9, 2) as day from posts where blog = @blog and status = @status and year != '' and month != '' and day != ''", sql.Named("blog", blog), sql.Named("status", statusPublished)) if err != nil { return nil, err } diff --git a/postsFuncs.go b/postsFuncs.go index 4dfc60f..8a39ef7 100644 --- a/postsFuncs.go +++ b/postsFuncs.go @@ -8,19 +8,19 @@ import ( "github.com/PuerkitoBio/goquery" ) -func (p *post) fullURL() string { - return appConfig.Server.PublicAddress + p.Path +func (a *goBlog) fullPostURL(p *post) string { + return a.cfg.Server.PublicAddress + p.Path } -func (p *post) shortURL() string { - s, err := shortenPath(p.Path) +func (a *goBlog) shortPostURL(p *post) string { + s, err := a.db.shortenPath(p.Path) if err != nil { return "" } - if appConfig.Server.ShortPublicAddress != "" { - return appConfig.Server.ShortPublicAddress + s + if a.cfg.Server.ShortPublicAddress != "" { + return a.cfg.Server.ShortPublicAddress + s } - return appConfig.Server.PublicAddress + s + return a.cfg.Server.PublicAddress + s } func (p *post) firstParameter(parameter string) (result string) { @@ -34,11 +34,11 @@ func (p *post) title() string { return p.firstParameter("title") } -func (p *post) html() template.HTML { +func (a *goBlog) html(p *post) template.HTML { if p.rendered != "" { return p.rendered } - htmlContent, err := renderMarkdown(p.Content, false) + htmlContent, err := a.renderMarkdown(p.Content, false) if err != nil { log.Fatal(err) return "" @@ -47,11 +47,11 @@ func (p *post) html() template.HTML { return p.rendered } -func (p *post) absoluteHTML() template.HTML { +func (a *goBlog) absoluteHTML(p *post) template.HTML { if p.absoluteRendered != "" { return p.absoluteRendered } - htmlContent, err := renderMarkdown(p.Content, true) + htmlContent, err := a.renderMarkdown(p.Content, true) if err != nil { log.Fatal(err) return "" @@ -62,12 +62,12 @@ func (p *post) absoluteHTML() template.HTML { const summaryDivider = "" -func (p *post) summary() (summary string) { +func (a *goBlog) summary(p *post) (summary string) { summary = p.firstParameter("summary") if summary != "" { return } - html := string(p.html()) + html := string(a.html(p)) if splitted := strings.Split(html, summaryDivider); len(splitted) > 1 { doc, _ := goquery.NewDocumentFromReader(strings.NewReader(splitted[0])) summary = doc.Text() @@ -78,12 +78,12 @@ func (p *post) summary() (summary string) { return } -func (p *post) translations() []*post { +func (a *goBlog) translations(p *post) []*post { translationkey := p.firstParameter("translationkey") if translationkey == "" { return nil } - posts, err := getPosts(&postsRequestConfig{ + posts, err := a.db.getPosts(&postsRequestConfig{ parameter: "translationkey", parameterValue: translationkey, }) diff --git a/queue.go b/queue.go index 56e4f2d..6bd53a2 100644 --- a/queue.go +++ b/queue.go @@ -8,11 +8,11 @@ import ( "github.com/araddon/dateparse" ) -func enqueue(name string, content []byte, schedule time.Time) error { +func (db *database) enqueue(name string, content []byte, schedule time.Time) error { if len(content) == 0 { return errors.New("empty content") } - _, err := appDb.exec("insert into queue (name, content, schedule) values (@name, @content, @schedule)", + _, err := db.exec("insert into queue (name, content, schedule) values (@name, @content, @schedule)", sql.Named("name", name), sql.Named("content", content), sql.Named("schedule", schedule.UTC().String())) return err } @@ -24,18 +24,18 @@ type queueItem struct { schedule *time.Time } -func (qi *queueItem) reschedule(dur time.Duration) error { - _, err := appDb.exec("update queue set schedule = @schedule, content = @content where id = @id", sql.Named("schedule", qi.schedule.Add(dur).UTC().String()), sql.Named("content", qi.content), sql.Named("id", qi.id)) +func (db *database) reschedule(qi *queueItem, dur time.Duration) error { + _, err := db.exec("update queue set schedule = @schedule, content = @content where id = @id", sql.Named("schedule", qi.schedule.Add(dur).UTC().String()), sql.Named("content", qi.content), sql.Named("id", qi.id)) return err } -func (qi *queueItem) dequeue() error { - _, err := appDb.exec("delete from queue where id = @id", sql.Named("id", qi.id)) +func (db *database) dequeue(qi *queueItem) error { + _, err := db.exec("delete from queue where id = @id", sql.Named("id", qi.id)) return err } -func peekQueue(name string) (*queueItem, error) { - row, err := appDb.queryRow("select id, name, content, schedule from queue where schedule <= @schedule and name = @name order by schedule asc limit 1", sql.Named("name", name), sql.Named("schedule", time.Now().UTC().String())) +func (db *database) peekQueue(name string) (*queueItem, error) { + row, err := db.queryRow("select id, name, content, schedule from queue where schedule <= @schedule and name = @name order by schedule asc limit 1", sql.Named("name", name), sql.Named("schedule", time.Now().UTC().String())) if err != nil { return nil, err } diff --git a/regexRedirects.go b/regexRedirects.go index 8a96cc1..139159c 100644 --- a/regexRedirects.go +++ b/regexRedirects.go @@ -5,16 +5,14 @@ import ( "regexp" ) -var regexRedirects []*regexRedirect - type regexRedirect struct { From *regexp.Regexp To string Type int } -func initRegexRedirects() error { - for _, cr := range appConfig.PathRedirects { +func (a *goBlog) initRegexRedirects() error { + for _, cr := range a.cfg.PathRedirects { re, err := regexp.Compile(cr.From) if err != nil { return err @@ -27,14 +25,14 @@ func initRegexRedirects() error { if r.Type == 0 { r.Type = http.StatusFound } - regexRedirects = append(regexRedirects, r) + a.regexRedirects = append(a.regexRedirects, r) } return nil } -func checkRegexRedirects(next http.Handler) http.Handler { +func (a *goBlog) checkRegexRedirects(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - for _, re := range regexRedirects { + for _, re := range a.regexRedirects { if newPath := re.From.ReplaceAllString(r.URL.Path, re.To); r.URL.Path != newPath { r.URL.Path = newPath http.Redirect(w, r, r.URL.String(), re.Type) diff --git a/render.go b/render.go index 6714464..f80187c 100644 --- a/render.go +++ b/render.go @@ -45,18 +45,17 @@ const ( templateBlogroll = "blogroll" ) -var templates map[string]*template.Template = map[string]*template.Template{} - -func initRendering() error { +func (a *goBlog) initRendering() error { + a.templates = map[string]*template.Template{} templateFunctions := template.FuncMap{ "menu": func(blog *configBlog, id string) *menu { return blog.Menus[id] }, "user": func() *configUser { - return appConfig.User + return a.cfg.User }, "md": func(content string) template.HTML { - htmlContent, err := renderMarkdown(content, false) + htmlContent, err := a.renderMarkdown(content, false) if err != nil { log.Fatal(err) return "" @@ -80,16 +79,16 @@ func initRendering() error { return p.title() }, "content": func(p *post) template.HTML { - return p.html() + return a.html(p) }, "summary": func(p *post) string { - return p.summary() + return a.summary(p) }, "translations": func(p *post) []*post { - return p.translations() + return a.translations(p) }, "shorturl": func(p *post) string { - return p.shortURL() + return a.shortPostURL(p) }, // Others "dateformat": dateFormat, @@ -120,9 +119,9 @@ func initRendering() error { } return d.Before(b) }, - "asset": assetFileName, - "assetsri": assetSRI, - "string": appTs.GetTemplateStringVariantFunc(), + "asset": a.assetFileName, + "assetsri": a.assetSRI, + "string": a.ts.GetTemplateStringVariantFunc(), "include": func(templateName string, data ...interface{}) (template.HTML, error) { if len(data) == 0 || len(data) > 2 { return "", errors.New("wrong argument count") @@ -134,7 +133,7 @@ func initRendering() error { rd = &nrd } var buf bytes.Buffer - err := templates[templateName].ExecuteTemplate(&buf, templateName, rd) + err := a.templates[templateName].ExecuteTemplate(&buf, templateName, rd) return template.HTML(buf.String()), err } return "", errors.New("wrong arguments") @@ -142,7 +141,7 @@ func initRendering() error { "urlize": urlize, "sort": sortedStrings, "absolute": func(path string) string { - return appConfig.Server.PublicAddress + path + return a.cfg.Server.PublicAddress + path }, "blogrelative": func(blog *configBlog, path string) string { return blog.getRelativePath(path) @@ -161,7 +160,7 @@ func initRendering() error { return parsed }, "mentions": func(absolute string) []*mention { - mentions, _ := getWebmentions(&webmentionsRequestConfig{ + mentions, _ := a.db.getWebmentions(&webmentionsRequestConfig{ target: absolute, status: webmentionStatusApproved, asc: true, @@ -181,7 +180,7 @@ func initRendering() error { } return }, - "geotitle": geoTitle, + "geotitle": a.db.geoTitle, } baseTemplate, err := template.New("base").Funcs(templateFunctions).ParseFiles(path.Join(templatesDir, templateBase+templatesExt)) @@ -194,7 +193,7 @@ func initRendering() error { } if info.Mode().IsRegular() && path.Ext(p) == templatesExt { if name := strings.TrimSuffix(path.Base(p), templatesExt); name != templateBase { - if templates[name], err = template.Must(baseTemplate.Clone()).New(name).ParseFiles(p); err != nil { + if a.templates[name], err = template.Must(baseTemplate.Clone()).New(name).ParseFiles(p); err != nil { return err } } @@ -219,26 +218,26 @@ type renderData struct { TorUsed bool } -func render(w http.ResponseWriter, r *http.Request, template string, data *renderData) { +func (a *goBlog) render(w http.ResponseWriter, r *http.Request, template string, data *renderData) { // Server timing t := servertiming.FromContext(r.Context()).NewMetric("r").Start() // Check render data if data.Blog == nil { if len(data.BlogString) == 0 { - data.BlogString = appConfig.DefaultBlog + data.BlogString = a.cfg.DefaultBlog } - data.Blog = appConfig.Blogs[data.BlogString] + data.Blog = a.cfg.Blogs[data.BlogString] } if data.BlogString == "" { - for s, b := range appConfig.Blogs { + for s, b := range a.cfg.Blogs { if b == data.Blog { data.BlogString = s break } } } - if appConfig.Server.Tor && torAddress != "" { - data.TorAddress = fmt.Sprintf("http://%v%v", torAddress, r.RequestURI) + if a.cfg.Server.Tor && a.torAddress != "" { + data.TorAddress = fmt.Sprintf("http://%v%v", a.torAddress, r.RequestURI) } if data.Data == nil { data.Data = map[string]interface{}{} @@ -250,23 +249,25 @@ func render(w http.ResponseWriter, r *http.Request, template string, data *rende // Check if comments enabled data.CommentsEnabled = data.Blog.Comments != nil && data.Blog.Comments.Enabled // Check if able to receive webmentions - data.WebmentionReceivingEnabled = appConfig.Webmention == nil || !appConfig.Webmention.DisableReceiving + data.WebmentionReceivingEnabled = a.cfg.Webmention == nil || !a.cfg.Webmention.DisableReceiving // Check if Tor request if torUsed, ok := r.Context().Value(torUsedKey).(bool); ok && torUsed { data.TorUsed = true } + // Set content type + w.Header().Set(contentType, contentTypeHTMLUTF8) // Minify and write response - mw := minifier.Writer(contentTypeHTML, w) - defer func() { - _ = mw.Close() - }() - err := templates[template].ExecuteTemplate(mw, template, data) + var tw bytes.Buffer + err := a.templates[template].ExecuteTemplate(&tw, template, data) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + _, err = writeMinified(w, contentTypeHTML, tw.Bytes()) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } - // Set content type - w.Header().Set(contentType, contentTypeHTMLUTF8) // Server timing t.Stop() } diff --git a/reverseGeo.go b/reverseGeo.go index 579e29f..c8b9889 100644 --- a/reverseGeo.go +++ b/reverseGeo.go @@ -12,8 +12,8 @@ import ( "github.com/thoas/go-funk" ) -func geoTitle(lat, lon float64, lang string) string { - ba, err := photonReverse(lat, lon, lang) +func (db *database) geoTitle(lat, lon float64, lang string) string { + ba, err := db.photonReverse(lat, lon, lang) if err != nil { return "" } @@ -29,9 +29,9 @@ func geoTitle(lat, lon float64, lang string) string { return strings.Join(funk.FilterString([]string{name, city, state, country}, func(s string) bool { return s != "" }), ", ") } -func photonReverse(lat, lon float64, lang string) ([]byte, error) { +func (db *database) photonReverse(lat, lon float64, lang string) ([]byte, error) { cacheKey := fmt.Sprintf("photon-%v-%v-%v", lat, lon, lang) - cache, _ := retrievePersistentCache(cacheKey) + cache, _ := db.retrievePersistentCache(cacheKey) if cache != nil { return cache, nil } @@ -63,6 +63,6 @@ func photonReverse(lat, lon float64, lang string) ([]byte, error) { if err != nil { return nil, err } - _ = cachePersistently(cacheKey, ba) + _ = db.cachePersistently(cacheKey, ba) return ba, nil } diff --git a/robotstxt.go b/robotstxt.go index c33bf51..e76cddc 100644 --- a/robotstxt.go +++ b/robotstxt.go @@ -5,8 +5,8 @@ import ( "net/http" ) -func serveRobotsTXT(w http.ResponseWriter, r *http.Request) { - _, _ = w.Write([]byte(fmt.Sprintf("User-agent: *\nSitemap: %v", appConfig.Server.PublicAddress+sitemapPath))) +func (a *goBlog) serveRobotsTXT(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(fmt.Sprintf("User-agent: *\nSitemap: %v", a.cfg.Server.PublicAddress+sitemapPath))) } func servePrivateRobotsTXT(w http.ResponseWriter, r *http.Request) { diff --git a/search.go b/search.go index a03057a..10fc07e 100644 --- a/search.go +++ b/search.go @@ -11,26 +11,26 @@ import ( const searchPlaceholder = "{search}" -func serveSearch(w http.ResponseWriter, r *http.Request) { +func (a *goBlog) serveSearch(w http.ResponseWriter, r *http.Request) { blog := r.Context().Value(blogContextKey).(string) servePath := r.Context().Value(pathContextKey).(string) err := r.ParseForm() if err != nil { - serveError(w, r, err.Error(), http.StatusBadRequest) + a.serveError(w, r, err.Error(), http.StatusBadRequest) return } if q := r.Form.Get("q"); q != "" { http.Redirect(w, r, path.Join(servePath, searchEncode(q)), http.StatusFound) return } - render(w, r, templateSearch, &renderData{ + a.render(w, r, templateSearch, &renderData{ BlogString: blog, - Canonical: appConfig.Server.PublicAddress + servePath, + Canonical: a.cfg.Server.PublicAddress + servePath, }) } -func serveSearchResult(w http.ResponseWriter, r *http.Request) { - serveIndex(w, r.WithContext(context.WithValue(r.Context(), indexConfigKey, &indexConfig{ +func (a *goBlog) serveSearchResult(w http.ResponseWriter, r *http.Request) { + a.serveIndex(w, r.WithContext(context.WithValue(r.Context(), indexConfigKey, &indexConfig{ path: r.Context().Value(pathContextKey).(string) + "/" + searchPlaceholder, }))) } diff --git a/sessions.go b/sessions.go index bab1881..32c7a27 100644 --- a/sessions.go +++ b/sessions.go @@ -13,46 +13,47 @@ import ( "github.com/gorilla/sessions" ) -var loginSessionsStore, captchaSessionsStore *dbSessionStore - const ( sessionCreatedOn = "created" sessionModifiedOn = "modified" sessionExpiresOn = "expires" ) -func initSessions() { +func (a *goBlog) initSessions() { deleteExpiredSessions := func() { - if _, err := appDb.exec("delete from sessions where expires < @now", + if _, err := a.db.exec("delete from sessions where expires < @now", sql.Named("now", time.Now().Local().String())); err != nil { log.Println("Failed to delete expired sessions:", err.Error()) } } deleteExpiredSessions() hourlyHooks = append(hourlyHooks, deleteExpiredSessions) - loginSessionsStore = &dbSessionStore{ - codecs: securecookie.CodecsFromPairs(jwtKey()), + a.loginSessions = &dbSessionStore{ + codecs: securecookie.CodecsFromPairs(a.jwtKey()), options: &sessions.Options{ - Secure: httpsConfigured(), + Secure: a.httpsConfigured(), HttpOnly: true, SameSite: http.SameSiteLaxMode, MaxAge: int((7 * 24 * time.Hour).Seconds()), }, + db: a.db, } - captchaSessionsStore = &dbSessionStore{ - codecs: securecookie.CodecsFromPairs(jwtKey()), + a.captchaSessions = &dbSessionStore{ + codecs: securecookie.CodecsFromPairs(a.jwtKey()), options: &sessions.Options{ - Secure: httpsConfigured(), + Secure: a.httpsConfigured(), HttpOnly: true, SameSite: http.SameSiteLaxMode, MaxAge: int((24 * time.Hour).Seconds()), }, + db: a.db, } } type dbSessionStore struct { options *sessions.Options codecs []securecookie.Codec + db *database } func (s *dbSessionStore) Get(r *http.Request, name string) (*sessions.Session, error) { @@ -101,14 +102,14 @@ func (s *dbSessionStore) Delete(r *http.Request, w http.ResponseWriter, session for k := range session.Values { delete(session.Values, k) } - if _, err := appDb.exec("delete from sessions where id = @id", sql.Named("id", session.ID)); err != nil { + if _, err := s.db.exec("delete from sessions where id = @id", sql.Named("id", session.ID)); err != nil { return err } return nil } func (s *dbSessionStore) load(session *sessions.Session) (err error) { - row, err := appDb.queryRow("select data, created, modified, expires from sessions where id = @id", sql.Named("id", session.ID)) + row, err := s.db.queryRow("select data, created, modified, expires from sessions where id = @id", sql.Named("id", session.ID)) if err != nil { return err } @@ -144,7 +145,7 @@ func (s *dbSessionStore) insert(session *sessions.Session) (err error) { if err != nil { return err } - res, err := appDb.exec("insert into sessions(data, created, modified, expires) values(@data, @created, @modified, @expires)", + res, err := s.db.exec("insert into sessions(data, created, modified, expires) values(@data, @created, @modified, @expires)", sql.Named("data", encoded), sql.Named("created", created.Local().String()), sql.Named("modified", modified.Local().String()), sql.Named("expires", expires.Local().String())) if err != nil { return err @@ -168,7 +169,7 @@ func (s *dbSessionStore) save(session *sessions.Session) (err error) { if err != nil { return err } - _, err = appDb.exec("update sessions set data = @data, modified = @modified where id = @id", + _, err = s.db.exec("update sessions set data = @data, modified = @modified where id = @id", sql.Named("data", encoded), sql.Named("modified", time.Now().Local().String()), sql.Named("id", session.ID)) if err != nil { return err diff --git a/shortDomain.go b/shortDomain.go index 1f5136a..8fbb4de 100644 --- a/shortDomain.go +++ b/shortDomain.go @@ -4,10 +4,10 @@ import ( "net/http" ) -func redirectShortDomain(next http.Handler) http.Handler { +func (a *goBlog) redirectShortDomain(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - if appConfig.Server.shortPublicHostname != "" && r.Host == appConfig.Server.shortPublicHostname { - http.Redirect(rw, r, appConfig.Server.PublicAddress+r.RequestURI, http.StatusMovedPermanently) + if a.cfg.Server.shortPublicHostname != "" && r.Host == a.cfg.Server.shortPublicHostname { + http.Redirect(rw, r, a.cfg.Server.PublicAddress+r.RequestURI, http.StatusMovedPermanently) return } next.ServeHTTP(rw, r) diff --git a/shortPath.go b/shortPath.go index 30b3b59..6e00fae 100644 --- a/shortPath.go +++ b/shortPath.go @@ -10,17 +10,17 @@ import ( "github.com/go-chi/chi/v5" ) -func shortenPath(p string) (string, error) { +func (db *database) shortenPath(p string) (string, error) { if p == "" { return "", errors.New("empty path") } - id := getShortPathID(p) + id := db.getShortPathID(p) if id == -1 { - _, err := appDb.exec("insert or ignore into shortpath (path) values (@path)", sql.Named("path", p)) + _, err := db.exec("insert or ignore into shortpath (path) values (@path)", sql.Named("path", p)) if err != nil { return "", err } - id = getShortPathID(p) + id = db.getShortPathID(p) } if id == -1 { return "", errors.New("failed to retrieve short path for " + p) @@ -28,11 +28,11 @@ func shortenPath(p string) (string, error) { return fmt.Sprintf("/s/%x", id), nil } -func getShortPathID(p string) (id int) { +func (db *database) getShortPathID(p string) (id int) { if p == "" { return -1 } - row, err := appDb.queryRow("select id from shortpath where path = @path", sql.Named("path", p)) + row, err := db.queryRow("select id from shortpath where path = @path", sql.Named("path", p)) if err != nil { return -1 } @@ -43,21 +43,21 @@ func getShortPathID(p string) (id int) { return id } -func redirectToLongPath(rw http.ResponseWriter, r *http.Request) { +func (a *goBlog) redirectToLongPath(rw http.ResponseWriter, r *http.Request) { id, err := strconv.ParseInt(chi.URLParam(r, "id"), 16, 64) if err != nil { - serve404(rw, r) + a.serve404(rw, r) return } - row, err := appDb.queryRow("select path from shortpath where id = @id", sql.Named("id", id)) + row, err := a.db.queryRow("select path from shortpath where id = @id", sql.Named("id", id)) if err != nil { - serve404(rw, r) + a.serve404(rw, r) return } var path string err = row.Scan(&path) if err != nil { - serve404(rw, r) + a.serve404(rw, r) return } http.Redirect(rw, r, path, http.StatusMovedPermanently) diff --git a/sitemap.go b/sitemap.go index 41552fc..a8df206 100644 --- a/sitemap.go +++ b/sitemap.go @@ -11,24 +11,24 @@ import ( const sitemapPath = "/sitemap.xml" -func serveSitemap(w http.ResponseWriter, r *http.Request) { +func (a *goBlog) serveSitemap(w http.ResponseWriter, r *http.Request) { sm := sitemap.New() sm.Minify = true // Blogs - for b, bc := range appConfig.Blogs { + for b, bc := range a.cfg.Blogs { // Blog blogPath := bc.Path if blogPath == "/" { blogPath = "" } sm.Add(&sitemap.URL{ - Loc: appConfig.Server.PublicAddress + blogPath, + Loc: a.cfg.Server.PublicAddress + blogPath, }) // Sections for _, section := range bc.Sections { if section.Name != "" { sm.Add(&sitemap.URL{ - Loc: appConfig.Server.PublicAddress + bc.getRelativePath("/"+section.Name), + Loc: a.cfg.Server.PublicAddress + bc.getRelativePath("/"+section.Name), }) } } @@ -38,27 +38,27 @@ func serveSitemap(w http.ResponseWriter, r *http.Request) { // Taxonomy taxPath := bc.getRelativePath("/" + taxonomy.Name) sm.Add(&sitemap.URL{ - Loc: appConfig.Server.PublicAddress + taxPath, + Loc: a.cfg.Server.PublicAddress + taxPath, }) // Values - if taxValues, err := allTaxonomyValues(b, taxonomy.Name); err == nil { + if taxValues, err := a.db.allTaxonomyValues(b, taxonomy.Name); err == nil { for _, tv := range taxValues { sm.Add(&sitemap.URL{ - Loc: appConfig.Server.PublicAddress + taxPath + "/" + urlize(tv), + Loc: a.cfg.Server.PublicAddress + taxPath + "/" + urlize(tv), }) } } } } // Year / month archives - if dates, err := allPublishedDates(b); err == nil { + if dates, err := a.db.allPublishedDates(b); err == nil { already := map[string]bool{} for _, d := range dates { // Year yearPath := bc.getRelativePath("/" + fmt.Sprintf("%0004d", d.year)) if !already[yearPath] { sm.Add(&sitemap.URL{ - Loc: appConfig.Server.PublicAddress + yearPath, + Loc: a.cfg.Server.PublicAddress + yearPath, }) already[yearPath] = true } @@ -66,7 +66,7 @@ func serveSitemap(w http.ResponseWriter, r *http.Request) { monthPath := yearPath + "/" + fmt.Sprintf("%02d", d.month) if !already[monthPath] { sm.Add(&sitemap.URL{ - Loc: appConfig.Server.PublicAddress + monthPath, + Loc: a.cfg.Server.PublicAddress + monthPath, }) already[monthPath] = true } @@ -74,7 +74,7 @@ func serveSitemap(w http.ResponseWriter, r *http.Request) { dayPath := monthPath + "/" + fmt.Sprintf("%02d", d.day) if !already[dayPath] { sm.Add(&sitemap.URL{ - Loc: appConfig.Server.PublicAddress + dayPath, + Loc: a.cfg.Server.PublicAddress + dayPath, }) already[dayPath] = true } @@ -82,7 +82,7 @@ func serveSitemap(w http.ResponseWriter, r *http.Request) { genericMonthPath := blogPath + "/x/" + fmt.Sprintf("%02d", d.month) if !already[genericMonthPath] { sm.Add(&sitemap.URL{ - Loc: appConfig.Server.PublicAddress + genericMonthPath, + Loc: a.cfg.Server.PublicAddress + genericMonthPath, }) already[genericMonthPath] = true } @@ -90,7 +90,7 @@ func serveSitemap(w http.ResponseWriter, r *http.Request) { genericMonthDayPath := genericMonthPath + "/" + fmt.Sprintf("%02d", d.day) if !already[genericMonthDayPath] { sm.Add(&sitemap.URL{ - Loc: appConfig.Server.PublicAddress + genericMonthDayPath, + Loc: a.cfg.Server.PublicAddress + genericMonthDayPath, }) already[genericMonthDayPath] = true } @@ -99,38 +99,38 @@ func serveSitemap(w http.ResponseWriter, r *http.Request) { // Photos if bc.Photos != nil && bc.Photos.Enabled { sm.Add(&sitemap.URL{ - Loc: appConfig.Server.PublicAddress + bc.getRelativePath(bc.Photos.Path), + Loc: a.cfg.Server.PublicAddress + bc.getRelativePath(bc.Photos.Path), }) } // Search if bc.Search != nil && bc.Search.Enabled { sm.Add(&sitemap.URL{ - Loc: appConfig.Server.PublicAddress + bc.getRelativePath(bc.Search.Path), + Loc: a.cfg.Server.PublicAddress + bc.getRelativePath(bc.Search.Path), }) } // Stats if bc.BlogStats != nil && bc.BlogStats.Enabled { sm.Add(&sitemap.URL{ - Loc: appConfig.Server.PublicAddress + bc.getRelativePath(bc.BlogStats.Path), + Loc: a.cfg.Server.PublicAddress + bc.getRelativePath(bc.BlogStats.Path), }) } // Blogroll if bc.Blogroll != nil && bc.Blogroll.Enabled { sm.Add(&sitemap.URL{ - Loc: appConfig.Server.PublicAddress + bc.getRelativePath(bc.Blogroll.Path), + Loc: a.cfg.Server.PublicAddress + bc.getRelativePath(bc.Blogroll.Path), }) } // Custom pages for _, cp := range bc.CustomPages { sm.Add(&sitemap.URL{ - Loc: appConfig.Server.PublicAddress + cp.Path, + Loc: a.cfg.Server.PublicAddress + cp.Path, }) } } // Posts - if posts, err := getPosts(&postsRequestConfig{status: statusPublished}); err == nil { + if posts, err := a.db.getPosts(&postsRequestConfig{status: statusPublished}); err == nil { for _, p := range posts { - item := &sitemap.URL{Loc: p.fullURL()} + item := &sitemap.URL{Loc: a.fullPostURL(p)} var lastMod time.Time if p.Updated != "" { lastMod, _ = dateparse.ParseLocal(p.Updated) diff --git a/staticFiles.go b/staticFiles.go index 8ad97e6..fc0b703 100644 --- a/staticFiles.go +++ b/staticFiles.go @@ -28,7 +28,7 @@ func allStaticPaths() (paths []string) { } // Gets only called by registered paths -func serveStaticFile(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Cache-Control", fmt.Sprintf("public,max-age=%d,s-max-age=%d,stale-while-revalidate=%d", appConfig.Cache.Expiration, appConfig.Cache.Expiration/3, appConfig.Cache.Expiration)) +func (a *goBlog) serveStaticFile(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Cache-Control", fmt.Sprintf("public,max-age=%d,s-max-age=%d,stale-while-revalidate=%d", a.cfg.Cache.Expiration, a.cfg.Cache.Expiration/3, a.cfg.Cache.Expiration)) http.ServeFile(w, r, filepath.Join(staticFolder, r.URL.Path)) } diff --git a/taxonomies.go b/taxonomies.go index 4e2cddd..cb893f6 100644 --- a/taxonomies.go +++ b/taxonomies.go @@ -4,17 +4,17 @@ import "net/http" const taxonomyContextKey = "taxonomy" -func serveTaxonomy(w http.ResponseWriter, r *http.Request) { +func (a *goBlog) serveTaxonomy(w http.ResponseWriter, r *http.Request) { blog := r.Context().Value(blogContextKey).(string) tax := r.Context().Value(taxonomyContextKey).(*taxonomy) - allValues, err := allTaxonomyValues(blog, tax.Name) + allValues, err := a.db.allTaxonomyValues(blog, tax.Name) if err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return } - render(w, r, templateTaxonomy, &renderData{ + a.render(w, r, templateTaxonomy, &renderData{ BlogString: blog, - Canonical: appConfig.Server.PublicAddress + r.URL.Path, + Canonical: a.cfg.Server.PublicAddress + r.URL.Path, Data: map[string]interface{}{ "Taxonomy": tax, "ValueGroups": groupStrings(allValues), diff --git a/telegram.go b/telegram.go index 74af3fe..083ef5f 100644 --- a/telegram.go +++ b/telegram.go @@ -13,40 +13,39 @@ import ( const telegramBaseURL = "https://api.telegram.org/bot" -func initTelegram() { +func (a *goBlog) initTelegram() { enable := false - for _, b := range appConfig.Blogs { + for _, b := range a.cfg.Blogs { if tg := b.Telegram; tg != nil && tg.Enabled && tg.BotToken != "" && tg.ChatID != "" { enable = true } } if enable { - postPostHooks = append(postPostHooks, func(p *post) { + a.pPostHooks = append(a.pPostHooks, func(p *post) { if p.isPublishedSectionPost() { - p.tgPost() + tgPost(a.cfg.Blogs[p.Blog].Telegram, p.title(), a.fullPostURL(p), a.shortPostURL(p)) } }) } } -func (p *post) tgPost() { - tg := appConfig.Blogs[p.Blog].Telegram +func tgPost(tg *configTelegram, title, fullURL, shortURL string) { if tg == nil || !tg.Enabled || tg.BotToken == "" || tg.ChatID == "" { return } replacer := strings.NewReplacer("<", "<", ">", ">", "&", "&") var message bytes.Buffer - if title := p.title(); title != "" { + if title != "" { message.WriteString(replacer.Replace(title)) message.WriteString("\n\n") } if tg.InstantViewHash != "" { - message.WriteString("") - message.WriteString(replacer.Replace(p.shortURL())) + message.WriteString("") + message.WriteString(replacer.Replace(shortURL)) message.WriteString("") } else { - message.WriteString("") - message.WriteString(replacer.Replace(p.shortURL())) + message.WriteString("") + message.WriteString(replacer.Replace(shortURL)) message.WriteString("") } if err := sendTelegramMessage(message.String(), "HTML", tg.BotToken, tg.ChatID); err != nil { diff --git a/templateAssets.go b/templateAssets.go index c22a3f5..4b406a7 100644 --- a/templateAssets.go +++ b/templateAssets.go @@ -16,24 +16,23 @@ import ( const assetsFolder = "templates/assets" -var assetFileNames map[string]string = map[string]string{} -var assetFiles map[string]*assetFile = map[string]*assetFile{} - type assetFile struct { contentType string sri string body []byte } -func initTemplateAssets() (err error) { +func (a *goBlog) initTemplateAssets() (err error) { + a.assetFileNames = map[string]string{} + a.assetFiles = map[string]*assetFile{} err = filepath.Walk(assetsFolder, func(path string, info os.FileInfo, err error) error { if info.Mode().IsRegular() { - compiled, err := compileAsset(path) + compiled, err := a.compileAsset(path) if err != nil { return err } if compiled != "" { - assetFileNames[strings.TrimPrefix(path, assetsFolder+"/")] = compiled + a.assetFileNames[strings.TrimPrefix(path, assetsFolder+"/")] = compiled } } return nil @@ -44,21 +43,22 @@ func initTemplateAssets() (err error) { return nil } -func compileAsset(name string) (string, error) { +func (a *goBlog) compileAsset(name string) (string, error) { content, err := os.ReadFile(name) if err != nil { return "", err } ext := path.Ext(name) compiledExt := ext + m := getMinifier() switch ext { case ".js": - content, err = minifier.Bytes("application/javascript", content) + content, err = m.Bytes("application/javascript", content) if err != nil { return "", err } case ".css": - content, err = minifier.Bytes("text/css", content) + content, err = m.Bytes("text/css", content) if err != nil { return "", err } @@ -76,7 +76,7 @@ func compileAsset(name string) (string, error) { // SRI sriHash := fmt.Sprintf("sha512-%s", base64.StdEncoding.EncodeToString(sha512Hash.Sum(nil))) // Create struct - assetFiles[compiledFileName] = &assetFile{ + a.assetFiles[compiledFileName] = &assetFile{ contentType: mime.TypeByExtension(compiledExt), sri: sriHash, body: content, @@ -85,27 +85,27 @@ func compileAsset(name string) (string, error) { } // Function for templates -func assetFileName(fileName string) string { - return "/" + assetFileNames[fileName] +func (a *goBlog) assetFileName(fileName string) string { + return "/" + a.assetFileNames[fileName] } -func assetSRI(fileName string) string { - return assetFiles[assetFileNames[fileName]].sri +func (a *goBlog) assetSRI(fileName string) string { + return a.assetFiles[a.assetFileNames[fileName]].sri } -func allAssetPaths() []string { +func (a *goBlog) allAssetPaths() []string { var paths []string - for _, name := range assetFileNames { + for _, name := range a.assetFileNames { paths = append(paths, "/"+name) } return paths } // Gets only called by registered paths -func serveAsset(w http.ResponseWriter, r *http.Request) { - af, ok := assetFiles[strings.TrimPrefix(r.URL.Path, "/")] +func (a *goBlog) serveAsset(w http.ResponseWriter, r *http.Request) { + af, ok := a.assetFiles[strings.TrimPrefix(r.URL.Path, "/")] if !ok { - serve404(w, r) + a.serve404(w, r) return } w.Header().Set("Cache-Control", "public,max-age=31536000,immutable") diff --git a/templateStrings.go b/templateStrings.go index 3949d7d..deee57b 100644 --- a/templateStrings.go +++ b/templateStrings.go @@ -4,13 +4,11 @@ import ( ts "git.jlel.se/jlelse/template-strings" ) -var appTs *ts.TemplateStrings - -func initTemplateStrings() (err error) { +func (a *goBlog) initTemplateStrings() (err error) { var blogLangs []string - for _, b := range appConfig.Blogs { + for _, b := range a.cfg.Blogs { blogLangs = append(blogLangs, b.Lang) } - appTs, err = ts.InitTemplateStrings("templates/strings", ".yaml", "default", blogLangs...) + a.ts, err = ts.InitTemplateStrings("templates/strings", ".yaml", "default", blogLangs...) return err } diff --git a/tor.go b/tor.go index ab98e45..8172ce9 100644 --- a/tor.go +++ b/tor.go @@ -16,13 +16,9 @@ import ( "github.com/go-chi/chi/v5/middleware" ) -var ( - torAddress string -) - var torUsedKey requestContextKey = "tor" -func startOnionService(h http.Handler) error { +func (a *goBlog) startOnionService(h http.Handler) error { torDataPath, err := filepath.Abs("data/tor") if err != nil { return err @@ -76,10 +72,10 @@ func startOnionService(h http.Handler) error { return err } defer onion.Close() - torAddress = onion.String() - log.Println("Onion service published on http://" + torAddress) + a.torAddress = onion.String() + log.Println("Onion service published on http://" + a.torAddress) // Clear cache - purgeCache() + a.cache.purge() // Serve handler s := &http.Server{ Handler: middleware.WithValue(torUsedKey, true)(h), diff --git a/webmention.go b/webmention.go index 4e91a3c..eb56c29 100644 --- a/webmention.go +++ b/webmention.go @@ -30,32 +30,32 @@ type mention struct { Status webmentionStatus } -func initWebmention() { +func (a *goBlog) initWebmention() { // Add hooks hookFunc := func(p *post) { if p.Status == statusPublished { - _ = p.sendWebmentions() + _ = a.sendWebmentions(p) } } - postPostHooks = append(postPostHooks, hookFunc) - postUpdateHooks = append(postUpdateHooks, hookFunc) - postDeleteHooks = append(postDeleteHooks, hookFunc) + a.pPostHooks = append(a.pPostHooks, hookFunc) + a.pUpdateHooks = append(a.pUpdateHooks, hookFunc) + a.pDeleteHooks = append(a.pDeleteHooks, hookFunc) // Start verifier - initWebmentionQueue() + a.initWebmentionQueue() } -func handleWebmention(w http.ResponseWriter, r *http.Request) { +func (a *goBlog) handleWebmention(w http.ResponseWriter, r *http.Request) { m, err := extractMention(r) if err != nil { - serveError(w, r, err.Error(), http.StatusBadRequest) + a.serveError(w, r, err.Error(), http.StatusBadRequest) return } - if !isAllowedHost(httptest.NewRequest(http.MethodGet, m.Target, nil), appConfig.Server.publicHostname) { - serveError(w, r, "target not allowed", http.StatusBadRequest) + if !isAllowedHost(httptest.NewRequest(http.MethodGet, m.Target, nil), a.cfg.Server.publicHostname) { + a.serveError(w, r, "target not allowed", http.StatusBadRequest) return } - if err = queueMention(m); err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + if err = a.queueMention(m); err != nil { + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return } w.WriteHeader(http.StatusAccepted) @@ -82,9 +82,9 @@ func extractMention(r *http.Request) (*mention, error) { }, nil } -func webmentionExists(source, target string) bool { +func (db *database) webmentionExists(source, target string) bool { result := 0 - row, err := appDb.queryRow("select exists(select 1 from webmentions where source = ? and target = ?)", source, target) + row, err := db.queryRow("select exists(select 1 from webmentions where source = ? and target = ?)", source, target) if err != nil { return false } @@ -94,26 +94,26 @@ func webmentionExists(source, target string) bool { return result == 1 } -func createWebmention(source, target string) (err error) { - return queueMention(&mention{ +func (a *goBlog) createWebmention(source, target string) (err error) { + return a.queueMention(&mention{ Source: source, Target: unescapedPath(target), Created: time.Now().Unix(), }) } -func deleteWebmention(id int) error { - _, err := appDb.exec("delete from webmentions where id = @id", sql.Named("id", id)) +func (db *database) deleteWebmention(id int) error { + _, err := db.exec("delete from webmentions where id = @id", sql.Named("id", id)) return err } -func approveWebmention(id int) error { - _, err := appDb.exec("update webmentions set status = ? where id = ?", webmentionStatusApproved, id) +func (db *database) approveWebmention(id int) error { + _, err := db.exec("update webmentions set status = ? where id = ?", webmentionStatusApproved, id) return err } -func reverifyWebmention(id int) error { - m, err := getWebmentions(&webmentionsRequestConfig{ +func (a *goBlog) reverifyWebmention(id int) error { + m, err := a.db.getWebmentions(&webmentionsRequestConfig{ id: id, limit: 1, }) @@ -121,7 +121,7 @@ func reverifyWebmention(id int) error { return err } if len(m) > 0 { - err = queueMention(m[0]) + err = a.queueMention(m[0]) } return err } @@ -169,10 +169,10 @@ func buildWebmentionsQuery(config *webmentionsRequestConfig) (query string, args return query, args } -func getWebmentions(config *webmentionsRequestConfig) ([]*mention, error) { +func (db *database) getWebmentions(config *webmentionsRequestConfig) ([]*mention, error) { mentions := []*mention{} query, args := buildWebmentionsQuery(config) - rows, err := appDb.query(query, args...) + rows, err := db.query(query, args...) if err != nil { return nil, err } @@ -187,10 +187,10 @@ func getWebmentions(config *webmentionsRequestConfig) ([]*mention, error) { return mentions, nil } -func countWebmentions(config *webmentionsRequestConfig) (count int, err error) { +func (db *database) countWebmentions(config *webmentionsRequestConfig) (count int, err error) { query, params := buildWebmentionsQuery(config) query = "select count(*) from (" + query + ")" - row, err := appDb.queryRow(query, params...) + row, err := db.queryRow(query, params...) if err != nil { return } diff --git a/webmentionAdmin.go b/webmentionAdmin.go index 7eaa935..e2f5a9d 100644 --- a/webmentionAdmin.go +++ b/webmentionAdmin.go @@ -14,11 +14,12 @@ import ( type webmentionPaginationAdapter struct { config *webmentionsRequestConfig nums int64 + db *database } func (p *webmentionPaginationAdapter) Nums() (int64, error) { if p.nums == 0 { - nums, _ := countWebmentions(p.config) + nums, _ := p.db.countWebmentions(p.config) p.nums = int64(nums) } return p.nums, nil @@ -29,12 +30,12 @@ func (p *webmentionPaginationAdapter) Slice(offset, length int, data interface{} modifiedConfig.offset = offset modifiedConfig.limit = length - wms, err := getWebmentions(&modifiedConfig) + wms, err := p.db.getWebmentions(&modifiedConfig) reflect.ValueOf(data).Elem().Set(reflect.ValueOf(&wms).Elem()) return err } -func webmentionAdmin(w http.ResponseWriter, r *http.Request) { +func (a *goBlog) webmentionAdmin(w http.ResponseWriter, r *http.Request) { pageNoString := chi.URLParam(r, "page") pageNo, _ := strconv.Atoi(pageNoString) var status webmentionStatus = "" @@ -48,12 +49,12 @@ func webmentionAdmin(w http.ResponseWriter, r *http.Request) { p := paginator.New(&webmentionPaginationAdapter{config: &webmentionsRequestConfig{ status: status, sourcelike: sourcelike, - }}, 10) + }, db: a.db}, 10) p.SetPage(pageNo) var mentions []*mention err := p.Results(&mentions) if err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return } // Navigation @@ -91,7 +92,7 @@ func webmentionAdmin(w http.ResponseWriter, r *http.Request) { query = "?" + params.Encode() } // Render - render(w, r, templateWebmentionAdmin, &renderData{ + a.render(w, r, templateWebmentionAdmin, &renderData{ Data: map[string]interface{}{ "Mentions": mentions, "HasPrev": hasPrev, @@ -102,45 +103,45 @@ func webmentionAdmin(w http.ResponseWriter, r *http.Request) { }) } -func webmentionAdminDelete(w http.ResponseWriter, r *http.Request) { +func (a *goBlog) webmentionAdminDelete(w http.ResponseWriter, r *http.Request) { id, err := strconv.Atoi(r.FormValue("mentionid")) if err != nil { - serveError(w, r, err.Error(), http.StatusBadRequest) + a.serveError(w, r, err.Error(), http.StatusBadRequest) return } - err = deleteWebmention(id) + err = a.db.deleteWebmention(id) if err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return } - purgeCache() + a.cache.purge() http.Redirect(w, r, ".", http.StatusFound) } -func webmentionAdminApprove(w http.ResponseWriter, r *http.Request) { +func (a *goBlog) webmentionAdminApprove(w http.ResponseWriter, r *http.Request) { id, err := strconv.Atoi(r.FormValue("mentionid")) if err != nil { - serveError(w, r, err.Error(), http.StatusBadRequest) + a.serveError(w, r, err.Error(), http.StatusBadRequest) return } - err = approveWebmention(id) + err = a.db.approveWebmention(id) if err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return } - purgeCache() + a.cache.purge() http.Redirect(w, r, ".", http.StatusFound) } -func webmentionAdminReverify(w http.ResponseWriter, r *http.Request) { +func (a *goBlog) webmentionAdminReverify(w http.ResponseWriter, r *http.Request) { id, err := strconv.Atoi(r.FormValue("mentionid")) if err != nil { - serveError(w, r, err.Error(), http.StatusBadRequest) + a.serveError(w, r, err.Error(), http.StatusBadRequest) return } - err = reverifyWebmention(id) + err = a.reverifyWebmention(id) if err != nil { - serveError(w, r, err.Error(), http.StatusInternalServerError) + a.serveError(w, r, err.Error(), http.StatusInternalServerError) return } http.Redirect(w, r, ".", http.StatusFound) diff --git a/webmentionSending.go b/webmentionSending.go index aa2f5c4..1bc388a 100644 --- a/webmentionSending.go +++ b/webmentionSending.go @@ -13,40 +13,44 @@ import ( "github.com/tomnomnom/linkheader" ) -func (p *post) sendWebmentions() error { - if wm := appConfig.Webmention; wm != nil && wm.DisableSending { +func (a *goBlog) sendWebmentions(p *post) error { + if wm := a.cfg.Webmention; wm != nil && wm.DisableSending { // Just ignore the mentions return nil } links := []string{} - contentLinks, err := allLinksFromHTML(strings.NewReader(string(p.html())), p.fullURL()) + contentLinks, err := allLinksFromHTML(strings.NewReader(string(a.html(p))), a.fullPostURL(p)) if err != nil { return err } links = append(links, contentLinks...) - links = append(links, p.firstParameter("link"), p.firstParameter(appConfig.Micropub.LikeParam), p.firstParameter(appConfig.Micropub.ReplyParam), p.firstParameter(appConfig.Micropub.BookmarkParam)) + links = append(links, p.firstParameter("link"), p.firstParameter(a.cfg.Micropub.LikeParam), p.firstParameter(a.cfg.Micropub.ReplyParam), p.firstParameter(a.cfg.Micropub.BookmarkParam)) for _, link := range funk.UniqString(links) { if link == "" { continue } // Internal mention - if strings.HasPrefix(link, appConfig.Server.PublicAddress) { + if strings.HasPrefix(link, a.cfg.Server.PublicAddress) { // Save mention directly - if err := createWebmention(p.fullURL(), link); err != nil { + if err := a.createWebmention(a.fullPostURL(p), link); err != nil { log.Println("Failed to create webmention:", err.Error()) } continue } // External mention - if pm := appConfig.PrivateMode; pm != nil && pm.Enabled { + if pm := a.cfg.PrivateMode; pm != nil && pm.Enabled { // Private mode, don't send external mentions continue } + if wm := a.cfg.Webmention; wm != nil && wm.DisableSending { + // Just ignore the mention + continue + } endpoint := discoverEndpoint(link) if endpoint == "" { continue } - if err = sendWebmention(endpoint, p.fullURL(), link); err != nil { + if err = sendWebmention(endpoint, a.fullPostURL(p), link); err != nil { log.Println("Sending webmention to " + link + " failed") continue } @@ -56,10 +60,6 @@ func (p *post) sendWebmentions() error { } func sendWebmention(endpoint, source, target string) error { - if wm := appConfig.Webmention; wm != nil && wm.DisableSending { - // Just ignore the mention - return nil - } req, err := http.NewRequest(http.MethodPost, endpoint, strings.NewReader(url.Values{ "source": []string{source}, "target": []string{target}, diff --git a/webmentionVerification.go b/webmentionVerification.go index 1f1b78a..e2aa615 100644 --- a/webmentionVerification.go +++ b/webmentionVerification.go @@ -20,10 +20,10 @@ import ( "willnorris.com/go/microformats" ) -func initWebmentionQueue() { +func (a *goBlog) initWebmentionQueue() { go func() { for { - qi, err := peekQueue("wm") + qi, err := a.db.peekQueue("wm") if err != nil { log.Println(err.Error()) continue @@ -32,14 +32,14 @@ func initWebmentionQueue() { err = gob.NewDecoder(bytes.NewReader(qi.content)).Decode(&m) if err != nil { log.Println(err.Error()) - _ = qi.dequeue() + _ = a.db.dequeue(qi) continue } - err = m.verifyMention() + err = a.verifyMention(&m) if err != nil { log.Println(fmt.Sprintf("Failed to verify webmention from %s to %s: %s", m.Source, m.Target, err.Error())) } - err = qi.dequeue() + err = a.db.dequeue(qi) if err != nil { log.Println(err.Error()) } @@ -51,26 +51,30 @@ func initWebmentionQueue() { }() } -func queueMention(m *mention) error { - if wm := appConfig.Webmention; wm != nil && wm.DisableReceiving { +func (a *goBlog) queueMention(m *mention) error { + if wm := a.cfg.Webmention; wm != nil && wm.DisableReceiving { return errors.New("webmention receiving disabled") } var buf bytes.Buffer if err := gob.NewEncoder(&buf).Encode(m); err != nil { return err } - return enqueue("wm", buf.Bytes(), time.Now()) + return a.db.enqueue("wm", buf.Bytes(), time.Now()) } -func (m *mention) verifyMention() error { +func (a *goBlog) verifyMention(m *mention) error { req, err := http.NewRequest(http.MethodGet, m.Source, nil) if err != nil { return err } var resp *http.Response - if strings.HasPrefix(m.Source, appConfig.Server.PublicAddress) { + if strings.HasPrefix(m.Source, a.cfg.Server.PublicAddress) { rec := httptest.NewRecorder() - d.ServeHTTP(rec, req.WithContext(context.WithValue(req.Context(), loggedInKey, true))) + for a.d == nil { + // Server not yet started + time.Sleep(10 * time.Second) + } + a.d.ServeHTTP(rec, req.WithContext(context.WithValue(req.Context(), loggedInKey, true))) resp = rec.Result() } else { req.Header.Set(userAgent, appUserAgent) @@ -82,7 +86,7 @@ func (m *mention) verifyMention() error { err = m.verifyReader(resp.Body) _ = resp.Body.Close() if err != nil { - _, err := appDb.exec("delete from webmentions where source = @source and target = @target", sql.Named("source", m.Source), sql.Named("target", m.Target)) + _, err := a.db.exec("delete from webmentions where source = @source and target = @target", sql.Named("source", m.Source), sql.Named("target", m.Target)) return err } if len(m.Content) > 500 { @@ -92,13 +96,13 @@ func (m *mention) verifyMention() error { m.Title = m.Title[0:57] + "…" } newStatus := webmentionStatusVerified - if webmentionExists(m.Source, m.Target) { - _, err = appDb.exec("update webmentions set status = @status, title = @title, content = @content, author = @author where source = @source and target = @target", + if a.db.webmentionExists(m.Source, m.Target) { + _, err = a.db.exec("update webmentions set status = @status, title = @title, content = @content, author = @author where source = @source and target = @target", sql.Named("status", newStatus), sql.Named("title", m.Title), sql.Named("content", m.Content), sql.Named("author", m.Author), sql.Named("source", m.Source), sql.Named("target", m.Target)) } else { - _, err = appDb.exec("insert into webmentions (source, target, created, status, title, content, author) values (@source, @target, @created, @status, @title, @content, @author)", + _, err = a.db.exec("insert into webmentions (source, target, created, status, title, content, author) values (@source, @target, @created, @status, @title, @content, @author)", sql.Named("source", m.Source), sql.Named("target", m.Target), sql.Named("created", m.Created), sql.Named("status", newStatus), sql.Named("title", m.Title), sql.Named("content", m.Content), sql.Named("author", m.Author)) - sendNotification(fmt.Sprintf("New webmention from %s to %s", m.Source, m.Target)) + a.sendNotification(fmt.Sprintf("New webmention from %s to %s", m.Source, m.Target)) } return err }