Add support for Zstandard to the Compression middleware

This commit is contained in:
Antoine Aflalo 2024-06-12 05:38:04 -04:00 committed by GitHub
parent 3f48e6f8ef
commit b795f128d7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 576 additions and 213 deletions

View file

@ -10,7 +10,7 @@ Compress Allows Compressing Responses before Sending them to the Client
![Compress](../../assets/img/middleware/compress.png)
The Compress middleware supports gzip and Brotli compression.
The Compress middleware supports Gzip, Brotli and Zstandard compression.
The activation of compression, and the compression method choice rely (among other things) on the request's `Accept-Encoding` header.
## Configuration Examples
@ -54,8 +54,8 @@ http:
Responses are compressed when the following criteria are all met:
* The `Accept-Encoding` request header contains `gzip`, `*`, and/or `br` with or without [quality values](https://developer.mozilla.org/en-US/docs/Glossary/Quality_values).
If the `Accept-Encoding` request header is absent, the response won't be encoded.
* The `Accept-Encoding` request header contains `gzip`, and/or `*`, and/or `br`, and/or `zstd` with or without [quality values](https://developer.mozilla.org/en-US/docs/Glossary/Quality_values).
If the `Accept-Encoding` request header is absent and no [defaultEncoding](#defaultencoding) is configured, the response won't be encoded.
If it is present, but its value is the empty string, then compression is disabled.
* The response is not already compressed, i.e. the `Content-Encoding` response header is not already set.
* The response`Content-Type` header is not one among the [excludedContentTypes options](#excludedcontenttypes), or is one among the [includedContentTypes options](#includedcontenttypes).

View file

@ -11,6 +11,7 @@ const acceptEncodingHeader = "Accept-Encoding"
const (
brotliName = "br"
gzipName = "gzip"
zstdName = "zstd"
identityName = "identity"
wildcardName = "*"
notAcceptable = "not_acceptable"
@ -51,7 +52,7 @@ func getCompressionType(acceptEncoding []string, defaultType string) string {
return encoding.Type
}
for _, dt := range []string{brotliName, gzipName} {
for _, dt := range []string{zstdName, brotliName, gzipName} {
if slices.ContainsFunc(encodings, func(e Encoding) bool { return e.Type == dt }) {
return dt
}
@ -76,7 +77,7 @@ func parseAcceptEncoding(acceptEncoding []string) ([]Encoding, bool) {
}
switch parsed[0] {
case brotliName, gzipName, identityName, wildcardName:
case zstdName, brotliName, gzipName, identityName, wildcardName:
// supported encoding
default:
continue

View file

@ -18,6 +18,11 @@ func Test_getCompressionType(t *testing.T) {
values: []string{"gzip, br"},
expected: brotliName,
},
{
desc: "zstd > br > gzip (no weight)",
values: []string{"zstd, gzip, br"},
expected: zstdName,
},
{
desc: "known compression type (no weight)",
values: []string{"compress, gzip"},
@ -49,6 +54,11 @@ func Test_getCompressionType(t *testing.T) {
values: []string{"compress;q=1.0, gzip;q=0.5"},
expected: gzipName,
},
{
desc: "fallback on non-zero compression type",
values: []string{"compress;q=1.0, gzip, identity;q=0"},
expected: gzipName,
},
{
desc: "not acceptable (identity)",
values: []string{"compress;q=1.0, identity;q=0"},
@ -86,9 +96,10 @@ func Test_parseAcceptEncoding(t *testing.T) {
}{
{
desc: "weight",
values: []string{"br;q=1.0, gzip;q=0.8, *;q=0.1"},
values: []string{"br;q=1.0, zstd;q=0.9, gzip;q=0.8, *;q=0.1"},
expected: []Encoding{
{Type: brotliName, Weight: ptr[float64](1)},
{Type: zstdName, Weight: ptr(0.9)},
{Type: gzipName, Weight: ptr(0.8)},
{Type: wildcardName, Weight: ptr(0.1)},
},
@ -96,9 +107,10 @@ func Test_parseAcceptEncoding(t *testing.T) {
},
{
desc: "mixed",
values: []string{"gzip, br;q=1.0, *;q=0"},
values: []string{"zstd,gzip, br;q=1.0, *;q=0"},
expected: []Encoding{
{Type: brotliName, Weight: ptr[float64](1)},
{Type: zstdName},
{Type: gzipName},
{Type: wildcardName, Weight: ptr[float64](0)},
},
@ -106,8 +118,9 @@ func Test_parseAcceptEncoding(t *testing.T) {
},
{
desc: "no weight",
values: []string{"gzip, br, *"},
values: []string{"zstd, gzip, br, *"},
expected: []Encoding{
{Type: zstdName},
{Type: gzipName},
{Type: brotliName},
{Type: wildcardName},

View file

@ -11,7 +11,6 @@ import (
"github.com/klauspost/compress/gzhttp"
"github.com/traefik/traefik/v3/pkg/config/dynamic"
"github.com/traefik/traefik/v3/pkg/middlewares"
"github.com/traefik/traefik/v3/pkg/middlewares/compress/brotli"
"go.opentelemetry.io/otel/trace"
)
@ -32,6 +31,7 @@ type compress struct {
brotliHandler http.Handler
gzipHandler http.Handler
zstdHandler http.Handler
}
// New creates a new compress middleware.
@ -77,7 +77,13 @@ func New(ctx context.Context, next http.Handler, conf dynamic.Compress, name str
}
var err error
c.brotliHandler, err = c.newBrotliHandler()
c.zstdHandler, err = c.newCompressionHandler(zstdName, name)
if err != nil {
return nil, err
}
c.brotliHandler, err = c.newCompressionHandler(brotliName, name)
if err != nil {
return nil, err
}
@ -130,6 +136,8 @@ func (c *compress) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
func (c *compress) chooseHandler(typ string, rw http.ResponseWriter, req *http.Request) {
switch typ {
case zstdName:
c.zstdHandler.ServeHTTP(rw, req)
case brotliName:
c.brotliHandler.ServeHTTP(rw, req)
case gzipName:
@ -166,18 +174,13 @@ func (c *compress) newGzipHandler() (http.Handler, error) {
return wrapper(c.next), nil
}
func (c *compress) newBrotliHandler() (http.Handler, error) {
cfg := brotli.Config{MinSize: c.minSize}
func (c *compress) newCompressionHandler(algo string, middlewareName string) (http.Handler, error) {
cfg := Config{MinSize: c.minSize, Algorithm: algo, MiddlewareName: middlewareName}
if len(c.includes) > 0 {
cfg.IncludedContentTypes = c.includes
} else {
cfg.ExcludedContentTypes = c.excludes
}
wrapper, err := brotli.NewWrapper(cfg)
if err != nil {
return nil, fmt.Errorf("new brotli wrapper: %w", err)
}
return wrapper(c.next), nil
return NewCompressionHandler(cfg, c.next)
}

View file

@ -41,32 +41,52 @@ func TestNegotiation(t *testing.T) {
{
desc: "accept any header",
acceptEncHeader: "*",
expEncoding: "br",
expEncoding: brotliName,
},
{
desc: "gzip accept header",
acceptEncHeader: "gzip",
expEncoding: "gzip",
expEncoding: gzipName,
},
{
desc: "br accept header",
acceptEncHeader: "br",
expEncoding: "br",
expEncoding: brotliName,
},
{
desc: "multi accept header, prefer br",
acceptEncHeader: "br;q=0.8, gzip;q=0.6",
expEncoding: "br",
expEncoding: brotliName,
},
{
desc: "multi accept header, prefer gzip",
acceptEncHeader: "gzip;q=1.0, br;q=0.8",
expEncoding: "gzip",
expEncoding: gzipName,
},
{
desc: "multi accept header list, prefer br",
acceptEncHeader: "gzip, br",
expEncoding: "br",
expEncoding: brotliName,
},
{
desc: "zstd accept header",
acceptEncHeader: "zstd",
expEncoding: zstdName,
},
{
desc: "multi accept header, prefer zstd",
acceptEncHeader: "zstd;q=0.9, br;q=0.8, gzip;q=0.6",
expEncoding: zstdName,
},
{
desc: "multi accept header, prefer gzip",
acceptEncHeader: "gzip;q=1.0, br;q=0.8, zstd;q=0.7",
expEncoding: gzipName,
},
{
desc: "multi accept header list, prefer zstd",
acceptEncHeader: "gzip, br, zstd",
expEncoding: zstdName,
},
}

View file

@ -1,4 +1,4 @@
package brotli
package compress
import (
"bufio"
@ -10,6 +10,9 @@ import (
"net/http"
"github.com/andybalholm/brotli"
"github.com/klauspost/compress/zstd"
"github.com/traefik/traefik/v3/pkg/middlewares"
"github.com/traefik/traefik/v3/pkg/middlewares/observability"
)
const (
@ -30,10 +33,26 @@ type Config struct {
IncludedContentTypes []string
// MinSize is the minimum size (in bytes) required to enable compression.
MinSize int
// Algorithm used for the compression (currently Brotli and Zstandard)
Algorithm string
// MiddlewareName use for logging purposes
MiddlewareName string
}
// NewWrapper returns a new Brotli compressing wrapper.
func NewWrapper(cfg Config) (func(http.Handler) http.HandlerFunc, error) {
// CompressionHandler handles Brolti and Zstd compression.
type CompressionHandler struct {
cfg Config
excludedContentTypes []parsedContentType
includedContentTypes []parsedContentType
next http.Handler
}
// NewCompressionHandler returns a new compressing handler.
func NewCompressionHandler(cfg Config, next http.Handler) (http.Handler, error) {
if cfg.Algorithm == "" {
return nil, errors.New("compression algorithm undefined")
}
if cfg.MinSize < 0 {
return nil, errors.New("minimum size must be greater than or equal to zero")
}
@ -62,30 +81,89 @@ func NewWrapper(cfg Config) (func(http.Handler) http.HandlerFunc, error) {
includedContentTypes = append(includedContentTypes, parsedContentType{mediaType, params})
}
return func(h http.Handler) http.HandlerFunc {
return func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Add(vary, acceptEncoding)
brw := &responseWriter{
rw: rw,
bw: brotli.NewWriter(rw),
minSize: cfg.MinSize,
statusCode: http.StatusOK,
excludedContentTypes: excludedContentTypes,
includedContentTypes: includedContentTypes,
}
defer brw.close()
h.ServeHTTP(brw, r)
}
return &CompressionHandler{
cfg: cfg,
excludedContentTypes: excludedContentTypes,
includedContentTypes: includedContentTypes,
next: next,
}, nil
}
func (c *CompressionHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
rw.Header().Add(vary, acceptEncoding)
compressionWriter, err := newCompressionWriter(c.cfg.Algorithm, rw)
if err != nil {
logger := middlewares.GetLogger(r.Context(), c.cfg.MiddlewareName, typeName)
logMessage := fmt.Sprintf("create compression handler: %v", err)
logger.Debug().Msg(logMessage)
observability.SetStatusErrorf(r.Context(), logMessage)
rw.WriteHeader(http.StatusInternalServerError)
return
}
responseWriter := &responseWriter{
rw: rw,
compressionWriter: compressionWriter,
minSize: c.cfg.MinSize,
statusCode: http.StatusOK,
excludedContentTypes: c.excludedContentTypes,
includedContentTypes: c.includedContentTypes,
}
defer responseWriter.close()
c.next.ServeHTTP(responseWriter, r)
}
type compression interface {
// Write data to the encoder.
// Input data will be buffered and as the buffer fills up
// content will be compressed and written to the output.
// When done writing, use Close to flush the remaining output
// and write CRC if requested.
Write(p []byte) (n int, err error)
// Flush will send the currently written data to output
// and block until everything has been written.
// This should only be used on rare occasions where pushing the currently queued data is critical.
Flush() error
// Close closes the underlying writers if/when appropriate.
// Note that the compressed writer should not be closed if we never used it,
// as it would otherwise send some extra "end of compression" bytes.
// Close also makes sure to flush whatever was left to write from the buffer.
Close() error
}
type compressionWriter struct {
compression
alg string
}
func newCompressionWriter(algo string, in io.Writer) (*compressionWriter, error) {
switch algo {
case brotliName:
return &compressionWriter{compression: brotli.NewWriter(in), alg: algo}, nil
case zstdName:
writer, err := zstd.NewWriter(in)
if err != nil {
return nil, fmt.Errorf("creating zstd writer: %w", err)
}
return &compressionWriter{compression: writer, alg: algo}, nil
default:
return nil, fmt.Errorf("unknown compression algo: %s", algo)
}
}
func (c *compressionWriter) ContentEncoding() string {
return c.alg
}
// TODO: check whether we want to implement content-type sniffing (as gzip does)
// TODO: check whether we should support Accept-Ranges (as gzip does, see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Ranges)
type responseWriter struct {
rw http.ResponseWriter
bw *brotli.Writer
rw http.ResponseWriter
compressionWriter *compressionWriter
minSize int
excludedContentTypes []parsedContentType
@ -133,7 +211,7 @@ func (r *responseWriter) Write(p []byte) (int, error) {
// We are now in compression cruise mode until the end of times.
if r.compressionStarted {
// If compressionStarted we assume we have sent headers already
return r.bw.Write(p)
return r.compressionWriter.Write(p)
}
// If we detect a contentEncoding, we know we are never going to compress.
@ -187,13 +265,13 @@ func (r *responseWriter) Write(p []byte) (int, error) {
// Since we know we are going to compress we will never be able to know the actual length.
r.rw.Header().Del(contentLength)
r.rw.Header().Set(contentEncoding, "br")
r.rw.Header().Set(contentEncoding, r.compressionWriter.ContentEncoding())
r.rw.WriteHeader(r.statusCode)
r.headersSent = true
// Start with sending what we have previously buffered, before actually writing
// the bytes in argument.
n, err := r.bw.Write(r.buf)
n, err := r.compressionWriter.Write(r.buf)
if err != nil {
r.buf = r.buf[n:]
// Return zero because we haven't taken care of the bytes in argument yet.
@ -212,7 +290,7 @@ func (r *responseWriter) Write(p []byte) (int, error) {
r.buf = r.buf[:0]
// Now that we emptied the buffer, we can actually write the given bytes.
return r.bw.Write(p)
return r.compressionWriter.Write(p)
}
// Flush flushes data to the appropriate underlying writer(s), although it does
@ -250,7 +328,7 @@ func (r *responseWriter) Flush() {
// we have to do it ourselves.
defer func() {
// because we also ignore the error returned by Write anyway
_ = r.bw.Flush()
_ = r.compressionWriter.Flush()
if rw, ok := r.rw.(http.Flusher); ok {
rw.Flush()
@ -258,7 +336,7 @@ func (r *responseWriter) Flush() {
}()
// We empty whatever is left of the buffer that Write never took care of.
n, err := r.bw.Write(r.buf)
n, err := r.compressionWriter.Write(r.buf)
if err != nil {
return
}
@ -313,7 +391,7 @@ func (r *responseWriter) close() error {
if len(r.buf) == 0 {
// If we got here we know compression has started, so we can safely flush on bw.
return r.bw.Close()
return r.compressionWriter.Close()
}
// There is still data in the buffer, because we never reached minSize (to
@ -331,16 +409,16 @@ func (r *responseWriter) close() error {
// There is still data in the buffer, simply because Write did not take care of it all.
// We flush it to the compressed writer.
n, err := r.bw.Write(r.buf)
n, err := r.compressionWriter.Write(r.buf)
if err != nil {
r.bw.Close()
r.compressionWriter.Close()
return err
}
if n < len(r.buf) {
r.bw.Close()
r.compressionWriter.Close()
return io.ErrShortWrite
}
return r.bw.Close()
return r.compressionWriter.Close()
}
// parsedContentType is the parsed representation of one of the inputs to ContentTypes.

View file

@ -1,4 +1,4 @@
package brotli
package compress
import (
"bytes"
@ -9,6 +9,7 @@ import (
"testing"
"github.com/andybalholm/brotli"
"github.com/klauspost/compress/zstd"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -19,44 +20,107 @@ var (
)
func Test_Vary(t *testing.T) {
h := newTestHandler(t, smallTestBody)
testCases := []struct {
desc string
h http.Handler
acceptEncoding string
}{
{
desc: "brotli",
h: newTestBrotliHandler(t, smallTestBody),
acceptEncoding: "br",
},
{
desc: "zstd",
h: newTestZstandardHandler(t, smallTestBody),
acceptEncoding: "zstd",
},
}
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
req.Header.Set(acceptEncoding, "br")
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
rw := httptest.NewRecorder()
h.ServeHTTP(rw, req)
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
req.Header.Set(acceptEncoding, test.acceptEncoding)
assert.Equal(t, http.StatusAccepted, rw.Code)
assert.Equal(t, acceptEncoding, rw.Header().Get(vary))
rw := httptest.NewRecorder()
test.h.ServeHTTP(rw, req)
assert.Equal(t, http.StatusAccepted, rw.Code)
assert.Equal(t, acceptEncoding, rw.Header().Get(vary))
})
}
}
func Test_SmallBodyNoCompression(t *testing.T) {
h := newTestHandler(t, smallTestBody)
testCases := []struct {
desc string
h http.Handler
acceptEncoding string
}{
{
desc: "brotli",
h: newTestBrotliHandler(t, smallTestBody),
acceptEncoding: "br",
},
{
desc: "zstd",
h: newTestZstandardHandler(t, smallTestBody),
acceptEncoding: "zstd",
},
}
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
req.Header.Set(acceptEncoding, "br")
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
rw := httptest.NewRecorder()
h.ServeHTTP(rw, req)
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
req.Header.Set(acceptEncoding, test.acceptEncoding)
// With less than 1024 bytes the response should not be compressed.
assert.Equal(t, http.StatusAccepted, rw.Code)
assert.Empty(t, rw.Header().Get(contentEncoding))
assert.Equal(t, smallTestBody, rw.Body.Bytes())
rw := httptest.NewRecorder()
test.h.ServeHTTP(rw, req)
// With less than 1024 bytes the response should not be compressed.
assert.Equal(t, http.StatusAccepted, rw.Code)
assert.Empty(t, rw.Header().Get(contentEncoding))
assert.Equal(t, smallTestBody, rw.Body.Bytes())
})
}
}
func Test_AlreadyCompressed(t *testing.T) {
h := newTestHandler(t, bigTestBody)
testCases := []struct {
desc string
h http.Handler
acceptEncoding string
}{
{
desc: "brotli",
h: newTestBrotliHandler(t, bigTestBody),
acceptEncoding: "br",
},
{
desc: "zstd",
h: newTestZstandardHandler(t, bigTestBody),
acceptEncoding: "zstd",
},
}
req, _ := http.NewRequest(http.MethodGet, "/compressed", nil)
req.Header.Set(acceptEncoding, "br")
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
rw := httptest.NewRecorder()
h.ServeHTTP(rw, req)
req, _ := http.NewRequest(http.MethodGet, "/compressed", nil)
req.Header.Set(acceptEncoding, test.acceptEncoding)
assert.Equal(t, http.StatusAccepted, rw.Code)
assert.Equal(t, bigTestBody, rw.Body.Bytes())
rw := httptest.NewRecorder()
test.h.ServeHTTP(rw, req)
assert.Equal(t, http.StatusAccepted, rw.Code)
assert.Equal(t, bigTestBody, rw.Body.Bytes())
})
}
}
func Test_NoBody(t *testing.T) {
@ -91,15 +155,17 @@ func Test_NoBody(t *testing.T) {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
h := mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(test.statusCode)
_, err := rw.Write(test.body)
require.NoError(t, err)
}))
})
h := mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: zstdName}, next)
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(acceptEncoding, "br")
req.Header.Set(acceptEncoding, "zstd")
rw := httptest.NewRecorder()
h.ServeHTTP(rw, req)
@ -115,24 +181,26 @@ func Test_NoBody(t *testing.T) {
func Test_MinSize(t *testing.T) {
cfg := Config{
MinSize: 128,
MinSize: 128,
Algorithm: zstdName,
}
var bodySize int
h := mustNewWrapper(t, cfg)(http.HandlerFunc(
func(rw http.ResponseWriter, req *http.Request) {
for range bodySize {
// We make sure to Write at least once less than minSize so that both
// cases below go through the same algo: i.e. they start buffering
// because they haven't reached minSize.
_, err := rw.Write([]byte{'x'})
require.NoError(t, err)
}
},
))
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
for range bodySize {
// We make sure to Write at least once less than minSize so that both
// cases below go through the same algo: i.e. they start buffering
// because they haven't reached minSize.
_, err := rw.Write([]byte{'x'})
require.NoError(t, err)
}
})
h := mustNewCompressionHandler(t, cfg, next)
req, _ := http.NewRequest(http.MethodGet, "/whatever", &bytes.Buffer{})
req.Header.Add(acceptEncoding, "br")
req.Header.Add(acceptEncoding, "zstd")
// Short response is not compressed
bodySize = cfg.MinSize - 1
@ -146,18 +214,20 @@ func Test_MinSize(t *testing.T) {
rw = httptest.NewRecorder()
h.ServeHTTP(rw, req)
assert.Equal(t, "br", rw.Result().Header.Get(contentEncoding))
assert.Equal(t, "zstd", rw.Result().Header.Get(contentEncoding))
}
func Test_MultipleWriteHeader(t *testing.T) {
h := mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
// We ensure that the subsequent call to WriteHeader is a noop.
rw.WriteHeader(http.StatusInternalServerError)
rw.WriteHeader(http.StatusNotFound)
}))
})
h := mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: zstdName}, next)
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(acceptEncoding, "br")
req.Header.Set(acceptEncoding, "zstd")
rw := httptest.NewRecorder()
h.ServeHTTP(rw, req)
@ -166,121 +236,255 @@ func Test_MultipleWriteHeader(t *testing.T) {
}
func Test_FlushBeforeWrite(t *testing.T) {
srv := httptest.NewServer(mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
rw.(http.Flusher).Flush()
testCases := []struct {
desc string
cfg Config
readerBuilder func(io.Reader) (io.Reader, error)
acceptEncoding string
}{
{
desc: "brotli",
cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"},
readerBuilder: func(reader io.Reader) (io.Reader, error) {
return brotli.NewReader(reader), nil
},
acceptEncoding: "br",
},
{
desc: "zstd",
cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"},
readerBuilder: func(reader io.Reader) (io.Reader, error) {
return zstd.NewReader(reader)
},
acceptEncoding: "zstd",
},
}
_, err := rw.Write(bigTestBody)
require.NoError(t, err)
})))
defer srv.Close()
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
require.NoError(t, err)
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
rw.(http.Flusher).Flush()
req.Header.Set(acceptEncoding, "br")
_, err := rw.Write(bigTestBody)
require.NoError(t, err)
})
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next))
defer srv.Close()
defer res.Body.Close()
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, "br", res.Header.Get(contentEncoding))
req.Header.Set(acceptEncoding, test.acceptEncoding)
got, err := io.ReadAll(brotli.NewReader(res.Body))
require.NoError(t, err)
assert.Equal(t, bigTestBody, got)
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, test.acceptEncoding, res.Header.Get(contentEncoding))
reader, err := test.readerBuilder(res.Body)
require.NoError(t, err)
got, err := io.ReadAll(reader)
require.NoError(t, err)
assert.Equal(t, bigTestBody, got)
})
}
}
func Test_FlushAfterWrite(t *testing.T) {
srv := httptest.NewServer(mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
testCases := []struct {
desc string
cfg Config
readerBuilder func(io.Reader) (io.Reader, error)
acceptEncoding string
}{
{
desc: "brotli",
cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"},
readerBuilder: func(reader io.Reader) (io.Reader, error) {
return brotli.NewReader(reader), nil
},
acceptEncoding: "br",
},
{
desc: "zstd",
cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"},
readerBuilder: func(reader io.Reader) (io.Reader, error) {
return zstd.NewReader(reader)
},
acceptEncoding: "zstd",
},
}
_, err := rw.Write(bigTestBody[0:1])
require.NoError(t, err)
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
rw.(http.Flusher).Flush()
for _, b := range bigTestBody[1:] {
_, err := rw.Write([]byte{b})
_, err := rw.Write(bigTestBody[0:1])
require.NoError(t, err)
rw.(http.Flusher).Flush()
for _, b := range bigTestBody[1:] {
_, err := rw.Write([]byte{b})
require.NoError(t, err)
}
})
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next))
defer srv.Close()
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
require.NoError(t, err)
}
})))
defer srv.Close()
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
require.NoError(t, err)
req.Header.Set(acceptEncoding, test.acceptEncoding)
req.Header.Set(acceptEncoding, "br")
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer res.Body.Close()
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, test.acceptEncoding, res.Header.Get(contentEncoding))
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, "br", res.Header.Get(contentEncoding))
reader, err := test.readerBuilder(res.Body)
require.NoError(t, err)
got, err := io.ReadAll(brotli.NewReader(res.Body))
require.NoError(t, err)
assert.Equal(t, bigTestBody, got)
got, err := io.ReadAll(reader)
require.NoError(t, err)
assert.Equal(t, bigTestBody, got)
})
}
}
func Test_FlushAfterWriteNil(t *testing.T) {
srv := httptest.NewServer(mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
testCases := []struct {
desc string
cfg Config
readerBuilder func(io.Reader) (io.Reader, error)
acceptEncoding string
}{
{
desc: "brotli",
cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"},
readerBuilder: func(reader io.Reader) (io.Reader, error) {
return brotli.NewReader(reader), nil
},
acceptEncoding: "br",
},
{
desc: "zstd",
cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"},
readerBuilder: func(reader io.Reader) (io.Reader, error) {
return zstd.NewReader(reader)
},
acceptEncoding: "zstd",
},
}
_, err := rw.Write(nil)
require.NoError(t, err)
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
rw.(http.Flusher).Flush()
})))
defer srv.Close()
_, err := rw.Write(nil)
require.NoError(t, err)
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
require.NoError(t, err)
rw.(http.Flusher).Flush()
})
req.Header.Set(acceptEncoding, "br")
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next))
defer srv.Close()
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
require.NoError(t, err)
defer res.Body.Close()
req.Header.Set(acceptEncoding, test.acceptEncoding)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Empty(t, res.Header.Get(contentEncoding))
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
got, err := io.ReadAll(brotli.NewReader(res.Body))
require.NoError(t, err)
assert.Empty(t, got)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Empty(t, res.Header.Get(contentEncoding))
reader, err := test.readerBuilder(res.Body)
require.NoError(t, err)
got, err := io.ReadAll(reader)
require.NoError(t, err)
assert.Empty(t, got)
})
}
}
func Test_FlushAfterAllWrites(t *testing.T) {
srv := httptest.NewServer(mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
for i := range bigTestBody {
_, err := rw.Write(bigTestBody[i : i+1])
testCases := []struct {
desc string
cfg Config
readerBuilder func(io.Reader) (io.Reader, error)
acceptEncoding string
}{
{
desc: "brotli",
cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"},
readerBuilder: func(reader io.Reader) (io.Reader, error) {
return brotli.NewReader(reader), nil
},
acceptEncoding: "br",
},
{
desc: "zstd",
cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"},
readerBuilder: func(reader io.Reader) (io.Reader, error) {
return zstd.NewReader(reader)
},
acceptEncoding: "zstd",
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
for i := range bigTestBody {
_, err := rw.Write(bigTestBody[i : i+1])
require.NoError(t, err)
}
rw.(http.Flusher).Flush()
})
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next))
defer srv.Close()
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
require.NoError(t, err)
}
rw.(http.Flusher).Flush()
})))
defer srv.Close()
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
require.NoError(t, err)
req.Header.Set(acceptEncoding, test.acceptEncoding)
req.Header.Set(acceptEncoding, "br")
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer res.Body.Close()
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, test.acceptEncoding, res.Header.Get(contentEncoding))
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, "br", res.Header.Get(contentEncoding))
reader, err := test.readerBuilder(res.Body)
require.NoError(t, err)
got, err := io.ReadAll(brotli.NewReader(res.Body))
require.NoError(t, err)
assert.Equal(t, bigTestBody, got)
got, err := io.ReadAll(reader)
require.NoError(t, err)
assert.Equal(t, bigTestBody, got)
})
}
}
func Test_ExcludedContentTypes(t *testing.T) {
@ -352,18 +556,22 @@ func Test_ExcludedContentTypes(t *testing.T) {
cfg := Config{
MinSize: 1024,
ExcludedContentTypes: test.excludedContentTypes,
Algorithm: zstdName,
}
h := mustNewWrapper(t, cfg)(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set(contentType, test.contentType)
rw.WriteHeader(http.StatusAccepted)
_, err := rw.Write(bigTestBody)
require.NoError(t, err)
}))
})
h := mustNewCompressionHandler(t, cfg, next)
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
req.Header.Set(acceptEncoding, "br")
req.Header.Set(acceptEncoding, zstdName)
rw := httptest.NewRecorder()
h.ServeHTTP(rw, req)
@ -371,13 +579,16 @@ func Test_ExcludedContentTypes(t *testing.T) {
assert.Equal(t, http.StatusAccepted, rw.Code)
if test.expCompression {
assert.Equal(t, "br", rw.Header().Get(contentEncoding))
assert.Equal(t, zstdName, rw.Header().Get(contentEncoding))
got, err := io.ReadAll(brotli.NewReader(rw.Body))
reader, err := zstd.NewReader(rw.Body)
require.NoError(t, err)
got, err := io.ReadAll(reader)
assert.NoError(t, err)
assert.Equal(t, bigTestBody, got)
} else {
assert.NotEqual(t, "br", rw.Header().Get("Content-Encoding"))
assert.NotEqual(t, zstdName, rw.Header().Get("Content-Encoding"))
got, err := io.ReadAll(rw.Body)
assert.NoError(t, err)
@ -456,18 +667,22 @@ func Test_IncludedContentTypes(t *testing.T) {
cfg := Config{
MinSize: 1024,
IncludedContentTypes: test.includedContentTypes,
Algorithm: zstdName,
}
h := mustNewWrapper(t, cfg)(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set(contentType, test.contentType)
rw.WriteHeader(http.StatusAccepted)
_, err := rw.Write(bigTestBody)
require.NoError(t, err)
}))
})
h := mustNewCompressionHandler(t, cfg, next)
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
req.Header.Set(acceptEncoding, "br")
req.Header.Set(acceptEncoding, zstdName)
rw := httptest.NewRecorder()
h.ServeHTTP(rw, req)
@ -475,13 +690,16 @@ func Test_IncludedContentTypes(t *testing.T) {
assert.Equal(t, http.StatusAccepted, rw.Code)
if test.expCompression {
assert.Equal(t, "br", rw.Header().Get(contentEncoding))
assert.Equal(t, zstdName, rw.Header().Get(contentEncoding))
got, err := io.ReadAll(brotli.NewReader(rw.Body))
reader, err := zstd.NewReader(rw.Body)
require.NoError(t, err)
got, err := io.ReadAll(reader)
assert.NoError(t, err)
assert.Equal(t, bigTestBody, got)
} else {
assert.NotEqual(t, "br", rw.Header().Get("Content-Encoding"))
assert.NotEqual(t, zstdName, rw.Header().Get("Content-Encoding"))
got, err := io.ReadAll(rw.Body)
assert.NoError(t, err)
@ -560,8 +778,10 @@ func Test_FlushExcludedContentTypes(t *testing.T) {
cfg := Config{
MinSize: 1024,
ExcludedContentTypes: test.excludedContentTypes,
Algorithm: zstdName,
}
h := mustNewWrapper(t, cfg)(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set(contentType, test.contentType)
rw.WriteHeader(http.StatusOK)
@ -581,10 +801,12 @@ func Test_FlushExcludedContentTypes(t *testing.T) {
rw.(http.Flusher).Flush()
tb = tb[toWrite:]
}
}))
})
h := mustNewCompressionHandler(t, cfg, next)
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
req.Header.Set(acceptEncoding, "br")
req.Header.Set(acceptEncoding, zstdName)
// This doesn't allow checking flushes, but we validate if content is correct.
rw := httptest.NewRecorder()
@ -593,13 +815,16 @@ func Test_FlushExcludedContentTypes(t *testing.T) {
assert.Equal(t, http.StatusOK, rw.Code)
if test.expCompression {
assert.Equal(t, "br", rw.Header().Get(contentEncoding))
assert.Equal(t, zstdName, rw.Header().Get(contentEncoding))
got, err := io.ReadAll(brotli.NewReader(rw.Body))
reader, err := zstd.NewReader(rw.Body)
require.NoError(t, err)
got, err := io.ReadAll(reader)
assert.NoError(t, err)
assert.Equal(t, bigTestBody, got)
} else {
assert.NotEqual(t, "br", rw.Header().Get(contentEncoding))
assert.NotEqual(t, zstdName, rw.Header().Get(contentEncoding))
got, err := io.ReadAll(rw.Body)
assert.NoError(t, err)
@ -678,8 +903,10 @@ func Test_FlushIncludedContentTypes(t *testing.T) {
cfg := Config{
MinSize: 1024,
IncludedContentTypes: test.includedContentTypes,
Algorithm: zstdName,
}
h := mustNewWrapper(t, cfg)(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set(contentType, test.contentType)
rw.WriteHeader(http.StatusOK)
@ -699,10 +926,12 @@ func Test_FlushIncludedContentTypes(t *testing.T) {
rw.(http.Flusher).Flush()
tb = tb[toWrite:]
}
}))
})
h := mustNewCompressionHandler(t, cfg, next)
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
req.Header.Set(acceptEncoding, "br")
req.Header.Set(acceptEncoding, zstdName)
// This doesn't allow checking flushes, but we validate if content is correct.
rw := httptest.NewRecorder()
@ -711,13 +940,16 @@ func Test_FlushIncludedContentTypes(t *testing.T) {
assert.Equal(t, http.StatusOK, rw.Code)
if test.expCompression {
assert.Equal(t, "br", rw.Header().Get(contentEncoding))
assert.Equal(t, zstdName, rw.Header().Get(contentEncoding))
got, err := io.ReadAll(brotli.NewReader(rw.Body))
reader, err := zstd.NewReader(rw.Body)
require.NoError(t, err)
got, err := io.ReadAll(reader)
assert.NoError(t, err)
assert.Equal(t, bigTestBody, got)
} else {
assert.NotEqual(t, "br", rw.Header().Get(contentEncoding))
assert.NotEqual(t, zstdName, rw.Header().Get(contentEncoding))
got, err := io.ReadAll(rw.Body)
assert.NoError(t, err)
@ -727,32 +959,48 @@ func Test_FlushIncludedContentTypes(t *testing.T) {
}
}
func mustNewWrapper(t *testing.T, cfg Config) func(http.Handler) http.HandlerFunc {
func mustNewCompressionHandler(t *testing.T, cfg Config, next http.Handler) http.Handler {
t.Helper()
w, err := NewWrapper(cfg)
w, err := NewCompressionHandler(cfg, next)
require.NoError(t, err)
return w
}
func newTestHandler(t *testing.T, body []byte) http.Handler {
func newTestBrotliHandler(t *testing.T, body []byte) http.Handler {
t.Helper()
return mustNewWrapper(t, Config{MinSize: 1024})(
http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.URL.Path == "/compressed" {
rw.Header().Set("Content-Encoding", "br")
}
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.URL.Path == "/compressed" {
rw.Header().Set("Content-Encoding", brotliName)
}
rw.WriteHeader(http.StatusAccepted)
_, err := rw.Write(body)
require.NoError(t, err)
}),
)
rw.WriteHeader(http.StatusAccepted)
_, err := rw.Write(body)
require.NoError(t, err)
})
return mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Compress"}, next)
}
func TestParseContentType_equals(t *testing.T) {
func newTestZstandardHandler(t *testing.T, body []byte) http.Handler {
t.Helper()
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.URL.Path == "/compressed" {
rw.Header().Set("Content-Encoding", zstdName)
}
rw.WriteHeader(http.StatusAccepted)
_, err := rw.Write(body)
require.NoError(t, err)
})
return mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Compress"}, next)
}
func Test_ParseContentType_equals(t *testing.T) {
testCases := []struct {
desc string
pct parsedContentType