1
mirror of https://github.com/jlelse/GoBlog synced 2024-07-15 12:22:58 +00:00

Refactor some http things

This commit is contained in:
Jan-Lukas Else 2024-06-29 13:22:55 +02:00
parent 45e03e2246
commit 02f0a34a16
7 changed files with 265 additions and 238 deletions

12
go.mod
View File

@ -2,8 +2,6 @@ module go.goblog.app/app
go 1.22.0
replace github.com/yuin/goldmark-emoji v1.0.2 => github.com/jlelse/goldmark-emoji v0.0.0-20240604064618-68e4be972ba7
require (
git.jlel.se/jlelse/go-geouri v0.0.0-20210525190615-a9c1d50f42d6
git.jlel.se/jlelse/go-shutdowner v0.0.0-20210707065515-773db8099c30
@ -24,7 +22,7 @@ require (
github.com/go-ap/activitypub v0.0.0-20240408091739-ba76b44c2594
github.com/go-ap/client v0.0.0-20240408093509-f0721baa55c2
github.com/go-ap/jsonld v0.0.0-20221030091449-f2a191312c73
github.com/go-chi/chi/v5 v5.0.14
github.com/go-chi/chi/v5 v5.1.0
github.com/go-fed/httpsig v1.1.0
github.com/google/uuid v1.6.0
github.com/gorilla/handlers v1.5.2
@ -44,7 +42,7 @@ require (
github.com/paulmach/go.geojson v1.5.0
github.com/posener/wstest v1.2.0
github.com/pquerna/otp v1.4.0
github.com/samber/lo v1.39.0
github.com/samber/lo v1.43.0
github.com/schollz/sqlite3dump v1.3.1
github.com/snabb/sitemap v1.0.4
github.com/sourcegraph/conc v0.3.0
@ -58,8 +56,8 @@ require (
github.com/traefik/yaegi v0.16.1
github.com/vcraescu/go-paginator/v2 v2.0.0
github.com/xhit/go-simple-mail/v2 v2.16.0
github.com/yuin/goldmark v1.7.2
github.com/yuin/goldmark-emoji v1.0.2
github.com/yuin/goldmark v1.7.4
github.com/yuin/goldmark-emoji v1.0.3
go.hacdias.com/indielib v0.3.0
golang.org/x/crypto v0.24.0
golang.org/x/net v0.26.0
@ -121,7 +119,7 @@ require (
go.mau.fi/util v0.5.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 // indirect
golang.org/x/image v0.17.0 // indirect
golang.org/x/image v0.18.0 // indirect
golang.org/x/oauth2 v0.21.0 // indirect
golang.org/x/sys v0.21.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect

20
go.sum
View File

@ -79,8 +79,8 @@ github.com/go-ap/errors v0.0.0-20240304112515-6077fa9c17b0 h1:H9MGShwybHLSln6K8R
github.com/go-ap/errors v0.0.0-20240304112515-6077fa9c17b0/go.mod h1:5x8a6P/dhmMGFxWLcyYlyOuJ2lRNaHGhRv+yu8BaTSI=
github.com/go-ap/jsonld v0.0.0-20221030091449-f2a191312c73 h1:GMKIYXyXPGIp+hYiWOhfqK4A023HdgisDT4YGgf99mw=
github.com/go-ap/jsonld v0.0.0-20221030091449-f2a191312c73/go.mod h1:jyveZeGw5LaADntW+UEsMjl3IlIwk+DxlYNsbofQkGA=
github.com/go-chi/chi/v5 v5.0.14 h1:PyEwo2Vudraa0x/Wl6eDRRW2NXBvekgfxyydcM0WGE0=
github.com/go-chi/chi/v5 v5.0.14/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw=
github.com/go-chi/chi/v5 v5.1.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
github.com/go-fed/httpsig v1.1.0 h1:9M+hb0jkEICD8/cAiNqEB66R87tTINszBRTjwjQzWcI=
github.com/go-fed/httpsig v1.1.0/go.mod h1:RCMrTZvN1bJYtofsG4rd5NaO5obxQ5xBkdiS7xsT7bM=
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
@ -141,8 +141,6 @@ github.com/jlaffaye/ftp v0.2.0 h1:lXNvW7cBu7R/68bknOX3MrRIIqZ61zELs1P2RAiA3lg=
github.com/jlaffaye/ftp v0.2.0/go.mod h1:is2Ds5qkhceAPy2xD6RLI6hmp/qysSoymZ+Z2uTnspI=
github.com/jlelse/feeds v1.3.0 h1:Vdks2qJ3XyxLYPle2UYa2Ucpw6GB48pBvpARJxz9fys=
github.com/jlelse/feeds v1.3.0/go.mod h1:2cAT6A2cQ4zcIz3FrCZKGXjHuJiGYe62MeM46/R0RxM=
github.com/jlelse/goldmark-emoji v0.0.0-20240604064618-68e4be972ba7 h1:eK8tIO23uVoX6pLxDtgKMAsMau+Ia8CM9c2BHGyl2HU=
github.com/jlelse/goldmark-emoji v0.0.0-20240604064618-68e4be972ba7/go.mod h1:tTkZEbwu5wkPmgTcitqddVxY9osFZiavD+r4AzQrh1U=
github.com/jonboulle/clockwork v0.3.0 h1:9BSCMi8C+0qdApAp4auwX0RkLGUjs956h0EkuQymUhg=
github.com/jonboulle/clockwork v0.3.0/go.mod h1:Pkfl5aHPm1nk2H9h0bjmnJD/BcgbGXUBGnn1kMkgxc8=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
@ -218,8 +216,8 @@ github.com/sagikazarmark/locafero v0.6.0 h1:ON7AQg37yzcRPU69mt7gwhFEBwxI6P9T4Qu3
github.com/sagikazarmark/locafero v0.6.0/go.mod h1:77OmuIc6VTraTXKXIs/uvUxKGUXjE1GbemJYHqdNjX0=
github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ=
github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
github.com/samber/lo v1.43.0 h1:ts0VhPi8+ZQZFVLv/2Vkgt2Cds05FM2v3Enmv+YMBtg=
github.com/samber/lo v1.43.0/go.mod h1:w7R6fO7h2lrnx/s0bWcZ55vXJI89p5UPM6+kyDL373E=
github.com/schollz/sqlite3dump v1.3.1 h1:QXizJ7XEJ7hggjqjZ3YRtF3+javm8zKtzNByYtEkPRA=
github.com/schollz/sqlite3dump v1.3.1/go.mod h1:mzSTjZpJH4zAb1FN3iNlhWPbbdyeBpOaTW0hukyMHyI=
github.com/scylladb/termtables v0.0.0-20191203121021-c4c0b6d42ff4/go.mod h1:C1a7PQSMz9NShzorzCiG2fk9+xuCgLkPeCvMHYR2OWg=
@ -291,8 +289,10 @@ github.com/xhit/go-simple-mail/v2 v2.16.0/go.mod h1:b7P5ygho6SYE+VIqpxA6QkYfv4te
github.com/yuin/goldmark v1.3.7/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
github.com/yuin/goldmark v1.7.2 h1:NjGd7lO7zrUn/A7eKwn5PEOt4ONYGqpxSEeZuduvgxc=
github.com/yuin/goldmark v1.7.2/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg=
github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
github.com/yuin/goldmark-emoji v1.0.3 h1:aLRkLHOuBR2czCY4R8olwMjID+tENfhyFDMCRhbIQY4=
github.com/yuin/goldmark-emoji v1.0.3/go.mod h1:tTkZEbwu5wkPmgTcitqddVxY9osFZiavD+r4AzQrh1U=
go.hacdias.com/indielib v0.3.0 h1:rbF1gnxSiFtEUQjnmqVdOzofPd5pO/3KEoa7iBIu2Is=
go.hacdias.com/indielib v0.3.0/go.mod h1:6wtl0LcTQ1JPoNld1yVy29qBEPlHeoLBKnGXYg8+dO4=
go.mau.fi/util v0.5.0 h1:8yELAl+1CDRrwGe9NUmREgVclSs26Z68pTWePHVxuDo=
@ -307,8 +307,8 @@ golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5D
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 h1:yixxcjnhBmY0nkL253HFVIm0JsFHwrHdT3Yh6szTnfY=
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8/go.mod h1:jj3sYF3dwk5D+ghuXyeI3r5MFf+NT2An6/9dOA95KSI=
golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/image v0.17.0 h1:nTRVVdajgB8zCMZVsViyzhnMKPwYeroEERRC64JuLco=
golang.org/x/image v0.17.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E=
golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ=
golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=

View File

@ -40,7 +40,7 @@ func (a *goBlog) startServer() (err error) {
if a.cfg.Server.Logging {
h = h.Append(a.logMiddleware)
}
h = h.Append(middleware.Recoverer, httpcompress.Compress())
h = h.Append(middleware.Recoverer, httpcompress.CompressMiddleware)
if a.cfg.Server.SecurityHeaders {
h = h.Append(a.securityHeaders)
}

View File

@ -65,7 +65,7 @@ func (a *goBlog) webmentionsRouter(r chi.Router) {
return
}
// Endpoint
r.With(bodylimit.BodyLimit(bodylimit.MB)).Post("/", a.handleWebmention)
r.With(bodylimit.BodyLimit(10*bodylimit.KB)).Post("/", a.handleWebmention)
// Authenticated routes
r.Group(func(r chi.Router) {
r.Use(a.authMiddleware)
@ -125,7 +125,7 @@ func (a *goBlog) otherRoutesRouter(r chi.Router) {
// Reactions
if a.reactionsEnabled() {
r.Get("/reactions", a.getReactions)
r.With(bodylimit.BodyLimit(100*bodylimit.KB)).Post("/reactions", a.postReaction)
r.With(bodylimit.BodyLimit(10*bodylimit.KB)).Post("/reactions", a.postReaction)
}
}
@ -312,7 +312,7 @@ func (a *goBlog) blogSearchRouter(conf *configBlog) func(r chi.Router) {
middleware.WithValue(pathKey, searchPath),
)
r.Get("/", a.serveSearch)
r.With(bodylimit.BodyLimit(100*bodylimit.KB)).Post("/", a.serveSearch)
r.With(bodylimit.BodyLimit(10*bodylimit.KB)).Post("/", a.serveSearch)
searchResultPath := "/" + searchPlaceholder
r.Get(searchResultPath, a.serveSearchResult)
r.Get(searchResultPath+feedPath, a.serveSearchResult)

View File

@ -2,270 +2,176 @@ package httpcompress
import (
"bufio"
"errors"
"io"
"net"
"net/http"
"slices"
"strings"
"sync"
"github.com/klauspost/compress/gzip"
"github.com/klauspost/compress/zstd"
"github.com/samber/lo"
"go.goblog.app/app/pkgs/contenttype"
)
var defaultCompressibleContentTypes = []string{
contenttype.AS,
contenttype.ATOM,
contenttype.CSS,
contenttype.HTML,
contenttype.JS,
contenttype.JSON,
contenttype.JSONFeed,
contenttype.LDJSON,
contenttype.RSS,
contenttype.Text,
contenttype.XML,
"application/opensearchdescription+xml",
"application/jrd+json",
"application/xrd+xml",
}
// Compress is a middleware that compresses response
// body of a given content types to a data format based
// on Accept-Encoding request header.
func Compress(types ...string) func(next http.Handler) http.Handler {
return NewCompressor(types...).Handler
}
// Compressor represents a set of encoding configurations.
type Compressor struct {
// The mapping of pooled encoders to pools.
pooledEncoders map[string]*sync.Pool
// The set of content types allowed to be compressed.
allowedTypes map[string]any
// The list of encoders in order of decreasing precedence.
encodingPrecedence []string
}
// NewCompressor creates a new Compressor that will handle encoding responses.
//
// The types are the content types that are allowed to be compressed.
func NewCompressor(types ...string) *Compressor {
// If types are provided, set those as the allowed types. If none are
// provided, use the default list.
allowedTypes := lo.SliceToMap(
lo.If(len(types) > 0, types).Else(defaultCompressibleContentTypes),
func(t string) (string, any) { return t, nil },
)
c := &Compressor{
pooledEncoders: map[string]*sync.Pool{},
allowedTypes: allowedTypes,
}
c.SetEncoder("gzip", encoderGzip)
c.SetEncoder("zstd", encoderZstd)
return c
}
// SetEncoder can be used to set the implementation of a compression algorithm.
//
// The encoding should be a standardised identifier. See:
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding
func (c *Compressor) SetEncoder(encoding string, fn EncoderFunc) {
encoding = strings.ToLower(encoding)
if encoding == "" {
panic("the encoding can not be empty")
}
if fn == nil {
panic("attempted to set a nil encoder function")
}
// Deleted already registered encoder
delete(c.pooledEncoders, encoding)
c.encodingPrecedence = slices.DeleteFunc(c.encodingPrecedence, func(e string) bool { return e == encoding })
// Register new encoder
c.pooledEncoders[encoding] = &sync.Pool{
New: func() any {
return fn(io.Discard)
var (
zstdWriterPool = sync.Pool{
New: func() interface{} {
w, _ := zstd.NewWriter(nil)
return w
},
}
c.encodingPrecedence = append([]string{encoding}, c.encodingPrecedence...)
}
gzipWriterPool = sync.Pool{
New: func() interface{} {
w, _ := gzip.NewWriterLevel(nil, gzip.DefaultCompression)
return w
},
}
// Global list of compressible content types
compressibleTypes = []string{
contenttype.AS,
contenttype.ATOM,
contenttype.CSS,
contenttype.HTML,
contenttype.JS,
contenttype.JSON,
contenttype.JSONFeed,
contenttype.LDJSON,
contenttype.RSS,
contenttype.Text,
contenttype.XML,
"application/opensearchdescription+xml",
"application/jrd+json",
"application/xrd+xml",
}
)
// Handler returns a new middleware that will compress the response based on the
// current Compressor.
func (c *Compressor) Handler(next http.Handler) http.Handler {
func CompressMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cw := &compressResponseWriter{
compressor: c,
ResponseWriter: w,
request: r,
acceptEncoding := r.Header.Get("Accept-Encoding")
supportsGzip := strings.Contains(acceptEncoding, "gzip")
supportsZstd := strings.Contains(acceptEncoding, "zstd")
if !supportsGzip && !supportsZstd {
next.ServeHTTP(w, r)
return
}
cw := &compressWriter{
ResponseWriter: w,
supportsGzip: supportsGzip,
supportsZstd: supportsZstd,
}
defer cw.Close()
next.ServeHTTP(cw, r)
_ = cw.Close()
cw.doCleanup()
})
}
// An EncoderFunc is a function that wraps the provided io.Writer with a
// streaming compression algorithm and returns it.
//
// In case of failure, the function should return nil.
type EncoderFunc func(w io.Writer) compressWriter
// Interface for types that allow resetting io.Writers.
type compressWriter interface {
io.Writer
Reset(w io.Writer)
Flush() error
type compressWriter struct {
http.ResponseWriter
supportsGzip, supportsZstd bool
writer io.Writer
contentType string
headerWritten, compressionSet bool
statusCode int
}
type compressResponseWriter struct {
http.ResponseWriter // The response writer to delegate to.
encoder compressWriter // The encoder to use (if any).
cleanup func() // Cleanup function to reset and repool encoder.
compressor *Compressor // Holds the compressor configuration.
request *http.Request // The request that is being handled.
wroteHeader bool // Whether the header has been written.
}
func (cw *compressResponseWriter) isCompressable() bool {
// Parse the first part of the Content-Type response header.
contentType := cw.Header().Get("Content-Type")
if idx := strings.Index(contentType, ";"); idx >= 0 {
contentType = contentType[0:idx]
func (cw *compressWriter) WriteHeader(statusCode int) {
if cw.headerWritten {
return
}
cw.statusCode = statusCode
// Is the content type compressable?
_, ok := cw.compressor.allowedTypes[contentType]
return ok
cw.contentType = cw.Header().Get("Content-Type")
cw.setupCompression()
cw.ResponseWriter.WriteHeader(statusCode)
cw.headerWritten = true
}
func (cw *compressResponseWriter) writer() io.Writer {
if cw.encoder != nil {
return cw.encoder
func (cw *compressWriter) Write(p []byte) (int, error) {
if !cw.headerWritten {
cw.WriteHeader(http.StatusOK)
}
return cw.ResponseWriter
return cw.writer.Write(p)
}
// selectEncoder returns the encoder, the name of the encoder, and a closer function.
func (cw *compressResponseWriter) selectEncoder() (compressWriter, string, func()) {
// Parse the names of all accepted algorithms from the header.
accepted := strings.Split(strings.ToLower(strings.ReplaceAll(cw.request.Header.Get("Accept-Encoding"), " ", "")), ",")
func (cw *compressWriter) setupCompression() {
if cw.compressionSet {
return
}
cw.compressionSet = true
// Find supported encoder by accepted list by precedence
for _, name := range cw.compressor.encodingPrecedence {
if slices.Contains(accepted, name) {
if pool, ok := cw.compressor.pooledEncoders[name]; ok {
encoder := pool.Get().(compressWriter)
cleanup := func() {
encoder.Reset(nil)
pool.Put(encoder)
}
encoder.Reset(cw.ResponseWriter)
return encoder, name, cleanup
}
shouldCompress := false
for _, t := range compressibleTypes {
if strings.HasPrefix(cw.contentType, t) {
shouldCompress = true
break
}
}
// No encoder found to match the accepted encoding
return nil, "", nil
}
func (cw *compressResponseWriter) doCleanup() {
if cw.encoder != nil {
cw.encoder = nil
cw.cleanup()
cw.cleanup = nil
}
}
func (cw *compressResponseWriter) WriteHeader(code int) {
defer cw.ResponseWriter.WriteHeader(code)
if cw.wroteHeader {
return
}
cw.wroteHeader = true
if cw.Header().Get("Content-Encoding") != "" {
// Data has already been compressed.
return
}
if !cw.isCompressable() {
// Data is not compressable.
return
}
var encoding string
cw.encoder, encoding, cw.cleanup = cw.selectEncoder()
if encoding != "" {
cw.Header().Set("Content-Encoding", encoding)
if shouldCompress {
if cw.supportsZstd {
zw := zstdWriterPool.Get().(*zstd.Encoder)
zw.Reset(cw.ResponseWriter)
cw.writer = zw
cw.Header().Set("Content-Encoding", "zstd")
} else if cw.supportsGzip {
gw := gzipWriterPool.Get().(*gzip.Writer)
gw.Reset(cw.ResponseWriter)
cw.writer = gw
cw.Header().Set("Content-Encoding", "gzip")
}
cw.Header().Add("Vary", "Accept-Encoding")
// The content-length after compression is unknown
cw.Header().Del("Content-Length")
} else {
cw.writer = cw.ResponseWriter
}
}
func (cw *compressResponseWriter) Write(p []byte) (int, error) {
if !cw.wroteHeader {
cw.WriteHeader(http.StatusOK)
func (cw *compressWriter) Close() (err error) {
if cw.writer != nil {
if zw, ok := cw.writer.(*zstd.Encoder); ok {
err = zw.Close()
zw.Reset(io.Discard)
zstdWriterPool.Put(zw)
} else if gw, ok := cw.writer.(*gzip.Writer); ok {
err = gw.Close()
gw.Reset(io.Discard)
gzipWriterPool.Put(gw)
}
cw.writer = nil
}
return cw.writer().Write(p)
return err
}
func (cw *compressResponseWriter) Flush() {
if cw.encoder != nil {
cw.encoder.Flush()
// Flush implements the http.Flusher interface.
func (cw *compressWriter) Flush() {
if !cw.headerWritten {
cw.WriteHeader(cw.statusCode)
}
if f, ok := cw.ResponseWriter.(http.Flusher); ok {
if cw.writer != nil {
if gw, ok := cw.writer.(*gzip.Writer); ok {
gw.Flush()
}
// Note: zstd.Encoder doesn't have a Flush method
}
f.Flush()
}
}
func (cw *compressResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hj, ok := cw.writer().(http.Hijacker); ok {
// Hijack implements the http.Hijacker interface.
func (cw *compressWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hj, ok := cw.ResponseWriter.(http.Hijacker); ok {
return hj.Hijack()
}
return nil, nil, errors.New("http.Hijacker is unavailable on the writer")
return nil, nil, http.ErrNotSupported
}
func (cw *compressResponseWriter) Push(target string, opts *http.PushOptions) error {
if ps, ok := cw.writer().(http.Pusher); ok {
return ps.Push(target, opts)
// Push implements the http.Pusher interface.
func (cw *compressWriter) Push(target string, opts *http.PushOptions) error {
if p, ok := cw.ResponseWriter.(http.Pusher); ok {
return p.Push(target, opts)
}
return errors.New("http.Pusher is unavailable on the writer")
}
func (cw *compressResponseWriter) Close() error {
if c, ok := cw.writer().(io.WriteCloser); ok {
return c.Close()
}
return errors.New("io.WriteCloser is unavailable on the writer")
}
func encoderGzip(w io.Writer) compressWriter {
gw, err := gzip.NewWriterLevel(w, 5)
if err != nil {
return nil
}
return gw
}
func encoderZstd(w io.Writer) compressWriter {
dw, err := zstd.NewWriter(w, zstd.WithEncoderLevel(zstd.SpeedDefault))
if err != nil {
return nil
}
return dw
return http.ErrNotSupported
}

View File

@ -0,0 +1,118 @@
package httpcompress
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/klauspost/compress/gzip"
"github.com/klauspost/compress/zstd"
)
func TestCompressMiddleware(t *testing.T) {
tests := []struct {
name string
acceptEncoding string
contentType string
body string
shouldEncode bool
}{
{
name: "No compression",
acceptEncoding: "",
contentType: "text/plain",
body: "Hello, World!",
shouldEncode: false,
},
{
name: "Gzip compression",
acceptEncoding: "gzip",
contentType: "text/html",
body: "<html><body>Hello, World!</body></html>",
shouldEncode: true,
},
{
name: "Zstd compression",
acceptEncoding: "zstd",
contentType: "application/json",
body: `{"message": "Hello, World!"}`,
shouldEncode: true,
},
{
name: "Non-compressible content type",
acceptEncoding: "gzip,zstd",
contentType: "image/jpeg",
body: "binary data",
shouldEncode: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", tt.contentType)
w.WriteHeader(http.StatusOK)
w.Write([]byte(tt.body))
})
compressHandler := CompressMiddleware(handler)
req := httptest.NewRequest("GET", "http://example.com", nil)
req.Header.Set("Accept-Encoding", tt.acceptEncoding)
rec := httptest.NewRecorder()
compressHandler.ServeHTTP(rec, req)
resp := rec.Result()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status %d, got %d", http.StatusOK, resp.StatusCode)
}
decompressedBody, err := decompressBody(body, resp.Header.Get("Content-Encoding"))
if err != nil {
t.Fatalf("Failed to decompress body: %v", err)
}
if string(decompressedBody) != tt.body {
t.Errorf("Response body doesn't match. Expected %q, got %q", tt.body, string(decompressedBody))
}
contentEncoding := resp.Header.Get("Content-Encoding")
if tt.shouldEncode {
if contentEncoding == "" {
t.Errorf("Expected Content-Encoding header to be set for compressible content")
}
if resp.Header.Get("Vary") != "Accept-Encoding" {
t.Errorf("Expected Vary header to be set")
}
} else if contentEncoding != "" {
t.Errorf("Content-Encoding header should not be set for non-compressible content")
}
})
}
}
func decompressBody(body []byte, encoding string) ([]byte, error) {
switch encoding {
case "gzip":
reader, err := gzip.NewReader(bytes.NewReader(body))
if err != nil {
return nil, err
}
defer reader.Close()
return io.ReadAll(reader)
case "zstd":
reader, err := zstd.NewReader(bytes.NewReader(body))
if err != nil {
return nil, err
}
defer reader.Close()
return io.ReadAll(reader)
default:
return body, nil
}
}

View File

@ -13,6 +13,7 @@ import (
"time"
"github.com/samber/lo"
"go.goblog.app/app/pkgs/bodylimit"
"go.goblog.app/app/pkgs/bufferpool"
"go.goblog.app/app/pkgs/contenttype"
)
@ -69,7 +70,9 @@ func (a *goBlog) verifyMention(m *mention) error {
}
}
// Request source
sourceReq, err := http.NewRequestWithContext(context.Background(), http.MethodGet, m.Source, nil)
timeoutCtx, timeoutCancel := context.WithTimeout(context.Background(), 10*time.Second)
defer timeoutCancel()
sourceReq, err := http.NewRequestWithContext(timeoutCtx, http.MethodGet, m.Source, nil)
if err != nil {
return err
}
@ -137,7 +140,7 @@ func (a *goBlog) verifyReader(m *mention, body io.Reader) error {
defer bufferpool.Put(mfBuffer)
pr, pw := io.Pipe()
go func() {
_, err := io.Copy(io.MultiWriter(pw, mfBuffer), body)
_, err := io.Copy(io.MultiWriter(pw, mfBuffer), io.LimitReader(body, 10*bodylimit.MB))
_ = pw.CloseWithError(err)
}()
// Check if source mentions target
@ -154,7 +157,9 @@ func (a *goBlog) verifyReader(m *mention, body io.Reader) error {
return false
}
// Check if link is or redirects to target
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, m.Target, nil)
timeoutCtx, timeoutCancel := context.WithTimeout(context.Background(), 10*time.Second)
defer timeoutCancel()
req, err := http.NewRequestWithContext(timeoutCtx, http.MethodGet, m.Target, nil)
if err != nil {
return false
}