1
mirror of https://github.com/jlelse/GoBlog synced 2024-07-01 11:57:35 +00:00

Fix http compress middleware, add zstd

This commit is contained in:
Jan-Lukas Else 2024-05-14 17:41:22 +02:00
parent 48fa07a28c
commit e14104cd16
3 changed files with 28 additions and 31 deletions

View File

@ -14,7 +14,6 @@ import (
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware" "github.com/go-chi/chi/v5/middleware"
"github.com/justinas/alice" "github.com/justinas/alice"
"github.com/klauspost/compress/flate"
"github.com/samber/lo" "github.com/samber/lo"
"go.goblog.app/app/pkgs/httpcompress" "go.goblog.app/app/pkgs/httpcompress"
"go.goblog.app/app/pkgs/maprouter" "go.goblog.app/app/pkgs/maprouter"
@ -41,7 +40,7 @@ func (a *goBlog) startServer() (err error) {
if a.cfg.Server.Logging { if a.cfg.Server.Logging {
h = h.Append(a.logMiddleware) h = h.Append(a.logMiddleware)
} }
h = h.Append(middleware.Recoverer, httpcompress.Compress(flate.BestCompression)) h = h.Append(middleware.Recoverer, httpcompress.Compress())
if a.cfg.Server.SecurityHeaders { if a.cfg.Server.SecurityHeaders {
h = h.Append(a.securityHeaders) h = h.Append(a.securityHeaders)
} }

View File

@ -3,8 +3,8 @@ package main
import ( import (
"net/http" "net/http"
"net/url" "net/url"
"slices"
"github.com/samber/lo"
"github.com/tiptophelmet/cspolicy" "github.com/tiptophelmet/cspolicy"
"github.com/tiptophelmet/cspolicy/directives" "github.com/tiptophelmet/cspolicy/directives"
"github.com/tiptophelmet/cspolicy/directives/constraint" "github.com/tiptophelmet/cspolicy/directives/constraint"
@ -88,7 +88,7 @@ func keepSelectedQueryParams(paramsToKeep ...string) func(http.Handler) http.Han
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query() query := r.URL.Query()
for param := range query { for param := range query {
if !lo.Contains(paramsToKeep, param) { if !slices.Contains(paramsToKeep, param) {
query.Del(param) query.Del(param)
} }
} }

View File

@ -6,11 +6,13 @@ import (
"io" "io"
"net" "net"
"net/http" "net/http"
"slices"
"strings" "strings"
"sync" "sync"
"github.com/klauspost/compress/flate" "github.com/klauspost/compress/flate"
"github.com/klauspost/compress/gzip" "github.com/klauspost/compress/gzip"
"github.com/klauspost/compress/zstd"
"github.com/samber/lo" "github.com/samber/lo"
"go.goblog.app/app/pkgs/contenttype" "go.goblog.app/app/pkgs/contenttype"
@ -35,12 +37,9 @@ var defaultCompressibleContentTypes = []string{
// Compress is a middleware that compresses response // Compress is a middleware that compresses response
// body of a given content types to a data format based // body of a given content types to a data format based
// on Accept-Encoding request header. It uses a given // on Accept-Encoding request header.
// compression level. func Compress(types ...string) func(next http.Handler) http.Handler {
// return NewCompressor(types...).Handler
// Passing a compression level of 5 is sensible value
func Compress(level int, types ...string) func(next http.Handler) http.Handler {
return NewCompressor(level, types...).Handler
} }
// Compressor represents a set of encoding configurations. // Compressor represents a set of encoding configurations.
@ -51,15 +50,12 @@ type Compressor struct {
allowedTypes map[string]any allowedTypes map[string]any
// The list of encoders in order of decreasing precedence. // The list of encoders in order of decreasing precedence.
encodingPrecedence []string encodingPrecedence []string
// The compression level.
level int
} }
// NewCompressor creates a new Compressor that will handle encoding responses. // NewCompressor creates a new Compressor that will handle encoding responses.
// //
// The level should be one of the ones defined in the flate package.
// The types are the content types that are allowed to be compressed. // The types are the content types that are allowed to be compressed.
func NewCompressor(level int, types ...string) *Compressor { func NewCompressor(types ...string) *Compressor {
// If types are provided, set those as the allowed types. If none are // If types are provided, set those as the allowed types. If none are
// provided, use the default list. // provided, use the default list.
allowedTypes := lo.SliceToMap( allowedTypes := lo.SliceToMap(
@ -68,13 +64,13 @@ func NewCompressor(level int, types ...string) *Compressor {
) )
c := &Compressor{ c := &Compressor{
level: level,
pooledEncoders: map[string]*sync.Pool{}, pooledEncoders: map[string]*sync.Pool{},
allowedTypes: allowedTypes, allowedTypes: allowedTypes,
} }
c.SetEncoder("deflate", encoderDeflate) c.SetEncoder("deflate", encoderDeflate)
c.SetEncoder("gzip", encoderGzip) c.SetEncoder("gzip", encoderGzip)
c.SetEncoder("zstd", encoderZstd)
return c return c
} }
@ -94,20 +90,14 @@ func (c *Compressor) SetEncoder(encoding string, fn EncoderFunc) {
// Deleted already registered encoder // Deleted already registered encoder
delete(c.pooledEncoders, encoding) delete(c.pooledEncoders, encoding)
c.encodingPrecedence = slices.DeleteFunc(c.encodingPrecedence, func(e string) bool { return e == encoding })
// Register new encoder
c.pooledEncoders[encoding] = &sync.Pool{ c.pooledEncoders[encoding] = &sync.Pool{
New: func() any { New: func() any {
return fn(io.Discard, c.level) return fn(io.Discard)
}, },
} }
for i, v := range c.encodingPrecedence {
if v == encoding {
c.encodingPrecedence = append(c.encodingPrecedence[:i], c.encodingPrecedence[i+1:]...)
break
}
}
c.encodingPrecedence = append([]string{encoding}, c.encodingPrecedence...) c.encodingPrecedence = append([]string{encoding}, c.encodingPrecedence...)
} }
@ -130,7 +120,7 @@ func (c *Compressor) Handler(next http.Handler) http.Handler {
// streaming compression algorithm and returns it. // streaming compression algorithm and returns it.
// //
// In case of failure, the function should return nil. // In case of failure, the function should return nil.
type EncoderFunc func(w io.Writer, level int) compressWriter type EncoderFunc func(w io.Writer) compressWriter
// Interface for types that allow resetting io.Writers. // Interface for types that allow resetting io.Writers.
type compressWriter interface { type compressWriter interface {
@ -170,11 +160,11 @@ func (cw *compressResponseWriter) writer() io.Writer {
// selectEncoder returns the encoder, the name of the encoder, and a closer function. // selectEncoder returns the encoder, the name of the encoder, and a closer function.
func (cw *compressResponseWriter) selectEncoder() (compressWriter, string, func()) { func (cw *compressResponseWriter) selectEncoder() (compressWriter, string, func()) {
// Parse the names of all accepted algorithms from the header. // Parse the names of all accepted algorithms from the header.
accepted := strings.Split(strings.ToLower(cw.request.Header.Get("Accept-Encoding")), ",") accepted := strings.Split(strings.ToLower(strings.ReplaceAll(cw.request.Header.Get("Accept-Encoding"), " ", "")), ",")
// Find supported encoder by accepted list by precedence // Find supported encoder by accepted list by precedence
for _, name := range cw.compressor.encodingPrecedence { for _, name := range cw.compressor.encodingPrecedence {
if lo.Contains(accepted, name) { if slices.Contains(accepted, name) {
if pool, ok := cw.compressor.pooledEncoders[name]; ok { if pool, ok := cw.compressor.pooledEncoders[name]; ok {
encoder := pool.Get().(compressWriter) encoder := pool.Get().(compressWriter)
cleanup := func() { cleanup := func() {
@ -266,16 +256,24 @@ func (cw *compressResponseWriter) Close() error {
return errors.New("io.WriteCloser is unavailable on the writer") return errors.New("io.WriteCloser is unavailable on the writer")
} }
func encoderGzip(w io.Writer, level int) compressWriter { func encoderGzip(w io.Writer) compressWriter {
gw, err := gzip.NewWriterLevel(w, level) gw, err := gzip.NewWriterLevel(w, gzip.DefaultCompression)
if err != nil { if err != nil {
return nil return nil
} }
return gw return gw
} }
func encoderDeflate(w io.Writer, level int) compressWriter { func encoderDeflate(w io.Writer) compressWriter {
dw, err := flate.NewWriter(w, level) dw, err := flate.NewWriter(w, flate.DefaultCompression)
if err != nil {
return nil
}
return dw
}
func encoderZstd(w io.Writer) compressWriter {
dw, err := zstd.NewWriter(w, zstd.WithEncoderLevel(zstd.SpeedDefault))
if err != nil { if err != nil {
return nil return nil
} }