1
mirror of https://github.com/jlelse/GoBlog synced 2024-06-29 09:07:35 +00:00

Fix http compress middleware, add zstd

This commit is contained in:
Jan-Lukas Else 2024-05-14 17:41:22 +02:00
parent 48fa07a28c
commit e14104cd16
3 changed files with 28 additions and 31 deletions

View File

@ -14,7 +14,6 @@ import (
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/justinas/alice"
"github.com/klauspost/compress/flate"
"github.com/samber/lo"
"go.goblog.app/app/pkgs/httpcompress"
"go.goblog.app/app/pkgs/maprouter"
@ -41,7 +40,7 @@ func (a *goBlog) startServer() (err error) {
if a.cfg.Server.Logging {
h = h.Append(a.logMiddleware)
}
h = h.Append(middleware.Recoverer, httpcompress.Compress(flate.BestCompression))
h = h.Append(middleware.Recoverer, httpcompress.Compress())
if a.cfg.Server.SecurityHeaders {
h = h.Append(a.securityHeaders)
}

View File

@ -3,8 +3,8 @@ package main
import (
"net/http"
"net/url"
"slices"
"github.com/samber/lo"
"github.com/tiptophelmet/cspolicy"
"github.com/tiptophelmet/cspolicy/directives"
"github.com/tiptophelmet/cspolicy/directives/constraint"
@ -88,7 +88,7 @@ func keepSelectedQueryParams(paramsToKeep ...string) func(http.Handler) http.Han
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
for param := range query {
if !lo.Contains(paramsToKeep, param) {
if !slices.Contains(paramsToKeep, param) {
query.Del(param)
}
}

View File

@ -6,11 +6,13 @@ import (
"io"
"net"
"net/http"
"slices"
"strings"
"sync"
"github.com/klauspost/compress/flate"
"github.com/klauspost/compress/gzip"
"github.com/klauspost/compress/zstd"
"github.com/samber/lo"
"go.goblog.app/app/pkgs/contenttype"
@ -35,12 +37,9 @@ var defaultCompressibleContentTypes = []string{
// Compress is a middleware that compresses response
// body of a given content types to a data format based
// on Accept-Encoding request header. It uses a given
// compression level.
//
// Passing a compression level of 5 is sensible value
func Compress(level int, types ...string) func(next http.Handler) http.Handler {
return NewCompressor(level, types...).Handler
// on Accept-Encoding request header.
func Compress(types ...string) func(next http.Handler) http.Handler {
return NewCompressor(types...).Handler
}
// Compressor represents a set of encoding configurations.
@ -51,15 +50,12 @@ type Compressor struct {
allowedTypes map[string]any
// The list of encoders in order of decreasing precedence.
encodingPrecedence []string
// The compression level.
level int
}
// NewCompressor creates a new Compressor that will handle encoding responses.
//
// The level should be one of the ones defined in the flate package.
// The types are the content types that are allowed to be compressed.
func NewCompressor(level int, types ...string) *Compressor {
func NewCompressor(types ...string) *Compressor {
// If types are provided, set those as the allowed types. If none are
// provided, use the default list.
allowedTypes := lo.SliceToMap(
@ -68,13 +64,13 @@ func NewCompressor(level int, types ...string) *Compressor {
)
c := &Compressor{
level: level,
pooledEncoders: map[string]*sync.Pool{},
allowedTypes: allowedTypes,
}
c.SetEncoder("deflate", encoderDeflate)
c.SetEncoder("gzip", encoderGzip)
c.SetEncoder("zstd", encoderZstd)
return c
}
@ -94,20 +90,14 @@ func (c *Compressor) SetEncoder(encoding string, fn EncoderFunc) {
// Deleted already registered encoder
delete(c.pooledEncoders, encoding)
c.encodingPrecedence = slices.DeleteFunc(c.encodingPrecedence, func(e string) bool { return e == encoding })
// Register new encoder
c.pooledEncoders[encoding] = &sync.Pool{
New: func() any {
return fn(io.Discard, c.level)
return fn(io.Discard)
},
}
for i, v := range c.encodingPrecedence {
if v == encoding {
c.encodingPrecedence = append(c.encodingPrecedence[:i], c.encodingPrecedence[i+1:]...)
break
}
}
c.encodingPrecedence = append([]string{encoding}, c.encodingPrecedence...)
}
@ -130,7 +120,7 @@ func (c *Compressor) Handler(next http.Handler) http.Handler {
// streaming compression algorithm and returns it.
//
// In case of failure, the function should return nil.
type EncoderFunc func(w io.Writer, level int) compressWriter
type EncoderFunc func(w io.Writer) compressWriter
// Interface for types that allow resetting io.Writers.
type compressWriter interface {
@ -170,11 +160,11 @@ func (cw *compressResponseWriter) writer() io.Writer {
// selectEncoder returns the encoder, the name of the encoder, and a closer function.
func (cw *compressResponseWriter) selectEncoder() (compressWriter, string, func()) {
// Parse the names of all accepted algorithms from the header.
accepted := strings.Split(strings.ToLower(cw.request.Header.Get("Accept-Encoding")), ",")
accepted := strings.Split(strings.ToLower(strings.ReplaceAll(cw.request.Header.Get("Accept-Encoding"), " ", "")), ",")
// Find supported encoder by accepted list by precedence
for _, name := range cw.compressor.encodingPrecedence {
if lo.Contains(accepted, name) {
if slices.Contains(accepted, name) {
if pool, ok := cw.compressor.pooledEncoders[name]; ok {
encoder := pool.Get().(compressWriter)
cleanup := func() {
@ -266,16 +256,24 @@ func (cw *compressResponseWriter) Close() error {
return errors.New("io.WriteCloser is unavailable on the writer")
}
func encoderGzip(w io.Writer, level int) compressWriter {
gw, err := gzip.NewWriterLevel(w, level)
func encoderGzip(w io.Writer) compressWriter {
gw, err := gzip.NewWriterLevel(w, gzip.DefaultCompression)
if err != nil {
return nil
}
return gw
}
func encoderDeflate(w io.Writer, level int) compressWriter {
dw, err := flate.NewWriter(w, level)
func encoderDeflate(w io.Writer) compressWriter {
dw, err := flate.NewWriter(w, flate.DefaultCompression)
if err != nil {
return nil
}
return dw
}
func encoderZstd(w io.Writer) compressWriter {
dw, err := zstd.NewWriter(w, zstd.WithEncoderLevel(zstd.SpeedDefault))
if err != nil {
return nil
}