Add support for Zstandard to the Compression middleware

This commit is contained in:
Antoine Aflalo 2024-06-12 05:38:04 -04:00 committed by GitHub
parent 3f48e6f8ef
commit b795f128d7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 576 additions and 213 deletions

View file

@ -10,7 +10,7 @@ Compress Allows Compressing Responses before Sending them to the Client
![Compress](../../assets/img/middleware/compress.png) ![Compress](../../assets/img/middleware/compress.png)
The Compress middleware supports gzip and Brotli compression. The Compress middleware supports Gzip, Brotli and Zstandard compression.
The activation of compression, and the compression method choice rely (among other things) on the request's `Accept-Encoding` header. The activation of compression, and the compression method choice rely (among other things) on the request's `Accept-Encoding` header.
## Configuration Examples ## Configuration Examples
@ -54,8 +54,8 @@ http:
Responses are compressed when the following criteria are all met: Responses are compressed when the following criteria are all met:
* The `Accept-Encoding` request header contains `gzip`, `*`, and/or `br` with or without [quality values](https://developer.mozilla.org/en-US/docs/Glossary/Quality_values). * The `Accept-Encoding` request header contains `gzip`, and/or `*`, and/or `br`, and/or `zstd` with or without [quality values](https://developer.mozilla.org/en-US/docs/Glossary/Quality_values).
If the `Accept-Encoding` request header is absent, the response won't be encoded. If the `Accept-Encoding` request header is absent and no [defaultEncoding](#defaultencoding) is configured, the response won't be encoded.
If it is present, but its value is the empty string, then compression is disabled. If it is present, but its value is the empty string, then compression is disabled.
* The response is not already compressed, i.e. the `Content-Encoding` response header is not already set. * The response is not already compressed, i.e. the `Content-Encoding` response header is not already set.
* The response`Content-Type` header is not one among the [excludedContentTypes options](#excludedcontenttypes), or is one among the [includedContentTypes options](#includedcontenttypes). * The response`Content-Type` header is not one among the [excludedContentTypes options](#excludedcontenttypes), or is one among the [includedContentTypes options](#includedcontenttypes).

View file

@ -11,6 +11,7 @@ const acceptEncodingHeader = "Accept-Encoding"
const ( const (
brotliName = "br" brotliName = "br"
gzipName = "gzip" gzipName = "gzip"
zstdName = "zstd"
identityName = "identity" identityName = "identity"
wildcardName = "*" wildcardName = "*"
notAcceptable = "not_acceptable" notAcceptable = "not_acceptable"
@ -51,7 +52,7 @@ func getCompressionType(acceptEncoding []string, defaultType string) string {
return encoding.Type return encoding.Type
} }
for _, dt := range []string{brotliName, gzipName} { for _, dt := range []string{zstdName, brotliName, gzipName} {
if slices.ContainsFunc(encodings, func(e Encoding) bool { return e.Type == dt }) { if slices.ContainsFunc(encodings, func(e Encoding) bool { return e.Type == dt }) {
return dt return dt
} }
@ -76,7 +77,7 @@ func parseAcceptEncoding(acceptEncoding []string) ([]Encoding, bool) {
} }
switch parsed[0] { switch parsed[0] {
case brotliName, gzipName, identityName, wildcardName: case zstdName, brotliName, gzipName, identityName, wildcardName:
// supported encoding // supported encoding
default: default:
continue continue

View file

@ -18,6 +18,11 @@ func Test_getCompressionType(t *testing.T) {
values: []string{"gzip, br"}, values: []string{"gzip, br"},
expected: brotliName, expected: brotliName,
}, },
{
desc: "zstd > br > gzip (no weight)",
values: []string{"zstd, gzip, br"},
expected: zstdName,
},
{ {
desc: "known compression type (no weight)", desc: "known compression type (no weight)",
values: []string{"compress, gzip"}, values: []string{"compress, gzip"},
@ -49,6 +54,11 @@ func Test_getCompressionType(t *testing.T) {
values: []string{"compress;q=1.0, gzip;q=0.5"}, values: []string{"compress;q=1.0, gzip;q=0.5"},
expected: gzipName, expected: gzipName,
}, },
{
desc: "fallback on non-zero compression type",
values: []string{"compress;q=1.0, gzip, identity;q=0"},
expected: gzipName,
},
{ {
desc: "not acceptable (identity)", desc: "not acceptable (identity)",
values: []string{"compress;q=1.0, identity;q=0"}, values: []string{"compress;q=1.0, identity;q=0"},
@ -86,9 +96,10 @@ func Test_parseAcceptEncoding(t *testing.T) {
}{ }{
{ {
desc: "weight", desc: "weight",
values: []string{"br;q=1.0, gzip;q=0.8, *;q=0.1"}, values: []string{"br;q=1.0, zstd;q=0.9, gzip;q=0.8, *;q=0.1"},
expected: []Encoding{ expected: []Encoding{
{Type: brotliName, Weight: ptr[float64](1)}, {Type: brotliName, Weight: ptr[float64](1)},
{Type: zstdName, Weight: ptr(0.9)},
{Type: gzipName, Weight: ptr(0.8)}, {Type: gzipName, Weight: ptr(0.8)},
{Type: wildcardName, Weight: ptr(0.1)}, {Type: wildcardName, Weight: ptr(0.1)},
}, },
@ -96,9 +107,10 @@ func Test_parseAcceptEncoding(t *testing.T) {
}, },
{ {
desc: "mixed", desc: "mixed",
values: []string{"gzip, br;q=1.0, *;q=0"}, values: []string{"zstd,gzip, br;q=1.0, *;q=0"},
expected: []Encoding{ expected: []Encoding{
{Type: brotliName, Weight: ptr[float64](1)}, {Type: brotliName, Weight: ptr[float64](1)},
{Type: zstdName},
{Type: gzipName}, {Type: gzipName},
{Type: wildcardName, Weight: ptr[float64](0)}, {Type: wildcardName, Weight: ptr[float64](0)},
}, },
@ -106,8 +118,9 @@ func Test_parseAcceptEncoding(t *testing.T) {
}, },
{ {
desc: "no weight", desc: "no weight",
values: []string{"gzip, br, *"}, values: []string{"zstd, gzip, br, *"},
expected: []Encoding{ expected: []Encoding{
{Type: zstdName},
{Type: gzipName}, {Type: gzipName},
{Type: brotliName}, {Type: brotliName},
{Type: wildcardName}, {Type: wildcardName},

View file

@ -11,7 +11,6 @@ import (
"github.com/klauspost/compress/gzhttp" "github.com/klauspost/compress/gzhttp"
"github.com/traefik/traefik/v3/pkg/config/dynamic" "github.com/traefik/traefik/v3/pkg/config/dynamic"
"github.com/traefik/traefik/v3/pkg/middlewares" "github.com/traefik/traefik/v3/pkg/middlewares"
"github.com/traefik/traefik/v3/pkg/middlewares/compress/brotli"
"go.opentelemetry.io/otel/trace" "go.opentelemetry.io/otel/trace"
) )
@ -32,6 +31,7 @@ type compress struct {
brotliHandler http.Handler brotliHandler http.Handler
gzipHandler http.Handler gzipHandler http.Handler
zstdHandler http.Handler
} }
// New creates a new compress middleware. // New creates a new compress middleware.
@ -77,7 +77,13 @@ func New(ctx context.Context, next http.Handler, conf dynamic.Compress, name str
} }
var err error var err error
c.brotliHandler, err = c.newBrotliHandler()
c.zstdHandler, err = c.newCompressionHandler(zstdName, name)
if err != nil {
return nil, err
}
c.brotliHandler, err = c.newCompressionHandler(brotliName, name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -130,6 +136,8 @@ func (c *compress) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
func (c *compress) chooseHandler(typ string, rw http.ResponseWriter, req *http.Request) { func (c *compress) chooseHandler(typ string, rw http.ResponseWriter, req *http.Request) {
switch typ { switch typ {
case zstdName:
c.zstdHandler.ServeHTTP(rw, req)
case brotliName: case brotliName:
c.brotliHandler.ServeHTTP(rw, req) c.brotliHandler.ServeHTTP(rw, req)
case gzipName: case gzipName:
@ -166,18 +174,13 @@ func (c *compress) newGzipHandler() (http.Handler, error) {
return wrapper(c.next), nil return wrapper(c.next), nil
} }
func (c *compress) newBrotliHandler() (http.Handler, error) { func (c *compress) newCompressionHandler(algo string, middlewareName string) (http.Handler, error) {
cfg := brotli.Config{MinSize: c.minSize} cfg := Config{MinSize: c.minSize, Algorithm: algo, MiddlewareName: middlewareName}
if len(c.includes) > 0 { if len(c.includes) > 0 {
cfg.IncludedContentTypes = c.includes cfg.IncludedContentTypes = c.includes
} else { } else {
cfg.ExcludedContentTypes = c.excludes cfg.ExcludedContentTypes = c.excludes
} }
wrapper, err := brotli.NewWrapper(cfg) return NewCompressionHandler(cfg, c.next)
if err != nil {
return nil, fmt.Errorf("new brotli wrapper: %w", err)
}
return wrapper(c.next), nil
} }

View file

@ -41,32 +41,52 @@ func TestNegotiation(t *testing.T) {
{ {
desc: "accept any header", desc: "accept any header",
acceptEncHeader: "*", acceptEncHeader: "*",
expEncoding: "br", expEncoding: brotliName,
}, },
{ {
desc: "gzip accept header", desc: "gzip accept header",
acceptEncHeader: "gzip", acceptEncHeader: "gzip",
expEncoding: "gzip", expEncoding: gzipName,
}, },
{ {
desc: "br accept header", desc: "br accept header",
acceptEncHeader: "br", acceptEncHeader: "br",
expEncoding: "br", expEncoding: brotliName,
}, },
{ {
desc: "multi accept header, prefer br", desc: "multi accept header, prefer br",
acceptEncHeader: "br;q=0.8, gzip;q=0.6", acceptEncHeader: "br;q=0.8, gzip;q=0.6",
expEncoding: "br", expEncoding: brotliName,
}, },
{ {
desc: "multi accept header, prefer gzip", desc: "multi accept header, prefer gzip",
acceptEncHeader: "gzip;q=1.0, br;q=0.8", acceptEncHeader: "gzip;q=1.0, br;q=0.8",
expEncoding: "gzip", expEncoding: gzipName,
}, },
{ {
desc: "multi accept header list, prefer br", desc: "multi accept header list, prefer br",
acceptEncHeader: "gzip, br", acceptEncHeader: "gzip, br",
expEncoding: "br", expEncoding: brotliName,
},
{
desc: "zstd accept header",
acceptEncHeader: "zstd",
expEncoding: zstdName,
},
{
desc: "multi accept header, prefer zstd",
acceptEncHeader: "zstd;q=0.9, br;q=0.8, gzip;q=0.6",
expEncoding: zstdName,
},
{
desc: "multi accept header, prefer gzip",
acceptEncHeader: "gzip;q=1.0, br;q=0.8, zstd;q=0.7",
expEncoding: gzipName,
},
{
desc: "multi accept header list, prefer zstd",
acceptEncHeader: "gzip, br, zstd",
expEncoding: zstdName,
}, },
} }

View file

@ -1,4 +1,4 @@
package brotli package compress
import ( import (
"bufio" "bufio"
@ -10,6 +10,9 @@ import (
"net/http" "net/http"
"github.com/andybalholm/brotli" "github.com/andybalholm/brotli"
"github.com/klauspost/compress/zstd"
"github.com/traefik/traefik/v3/pkg/middlewares"
"github.com/traefik/traefik/v3/pkg/middlewares/observability"
) )
const ( const (
@ -30,10 +33,26 @@ type Config struct {
IncludedContentTypes []string IncludedContentTypes []string
// MinSize is the minimum size (in bytes) required to enable compression. // MinSize is the minimum size (in bytes) required to enable compression.
MinSize int MinSize int
// Algorithm used for the compression (currently Brotli and Zstandard)
Algorithm string
// MiddlewareName use for logging purposes
MiddlewareName string
} }
// NewWrapper returns a new Brotli compressing wrapper. // CompressionHandler handles Brolti and Zstd compression.
func NewWrapper(cfg Config) (func(http.Handler) http.HandlerFunc, error) { type CompressionHandler struct {
cfg Config
excludedContentTypes []parsedContentType
includedContentTypes []parsedContentType
next http.Handler
}
// NewCompressionHandler returns a new compressing handler.
func NewCompressionHandler(cfg Config, next http.Handler) (http.Handler, error) {
if cfg.Algorithm == "" {
return nil, errors.New("compression algorithm undefined")
}
if cfg.MinSize < 0 { if cfg.MinSize < 0 {
return nil, errors.New("minimum size must be greater than or equal to zero") return nil, errors.New("minimum size must be greater than or equal to zero")
} }
@ -62,30 +81,89 @@ func NewWrapper(cfg Config) (func(http.Handler) http.HandlerFunc, error) {
includedContentTypes = append(includedContentTypes, parsedContentType{mediaType, params}) includedContentTypes = append(includedContentTypes, parsedContentType{mediaType, params})
} }
return func(h http.Handler) http.HandlerFunc { return &CompressionHandler{
return func(rw http.ResponseWriter, r *http.Request) { cfg: cfg,
rw.Header().Add(vary, acceptEncoding)
brw := &responseWriter{
rw: rw,
bw: brotli.NewWriter(rw),
minSize: cfg.MinSize,
statusCode: http.StatusOK,
excludedContentTypes: excludedContentTypes, excludedContentTypes: excludedContentTypes,
includedContentTypes: includedContentTypes, includedContentTypes: includedContentTypes,
} next: next,
defer brw.close()
h.ServeHTTP(brw, r)
}
}, nil }, nil
} }
func (c *CompressionHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
rw.Header().Add(vary, acceptEncoding)
compressionWriter, err := newCompressionWriter(c.cfg.Algorithm, rw)
if err != nil {
logger := middlewares.GetLogger(r.Context(), c.cfg.MiddlewareName, typeName)
logMessage := fmt.Sprintf("create compression handler: %v", err)
logger.Debug().Msg(logMessage)
observability.SetStatusErrorf(r.Context(), logMessage)
rw.WriteHeader(http.StatusInternalServerError)
return
}
responseWriter := &responseWriter{
rw: rw,
compressionWriter: compressionWriter,
minSize: c.cfg.MinSize,
statusCode: http.StatusOK,
excludedContentTypes: c.excludedContentTypes,
includedContentTypes: c.includedContentTypes,
}
defer responseWriter.close()
c.next.ServeHTTP(responseWriter, r)
}
type compression interface {
// Write data to the encoder.
// Input data will be buffered and as the buffer fills up
// content will be compressed and written to the output.
// When done writing, use Close to flush the remaining output
// and write CRC if requested.
Write(p []byte) (n int, err error)
// Flush will send the currently written data to output
// and block until everything has been written.
// This should only be used on rare occasions where pushing the currently queued data is critical.
Flush() error
// Close closes the underlying writers if/when appropriate.
// Note that the compressed writer should not be closed if we never used it,
// as it would otherwise send some extra "end of compression" bytes.
// Close also makes sure to flush whatever was left to write from the buffer.
Close() error
}
type compressionWriter struct {
compression
alg string
}
func newCompressionWriter(algo string, in io.Writer) (*compressionWriter, error) {
switch algo {
case brotliName:
return &compressionWriter{compression: brotli.NewWriter(in), alg: algo}, nil
case zstdName:
writer, err := zstd.NewWriter(in)
if err != nil {
return nil, fmt.Errorf("creating zstd writer: %w", err)
}
return &compressionWriter{compression: writer, alg: algo}, nil
default:
return nil, fmt.Errorf("unknown compression algo: %s", algo)
}
}
func (c *compressionWriter) ContentEncoding() string {
return c.alg
}
// TODO: check whether we want to implement content-type sniffing (as gzip does) // TODO: check whether we want to implement content-type sniffing (as gzip does)
// TODO: check whether we should support Accept-Ranges (as gzip does, see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Ranges) // TODO: check whether we should support Accept-Ranges (as gzip does, see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Ranges)
type responseWriter struct { type responseWriter struct {
rw http.ResponseWriter rw http.ResponseWriter
bw *brotli.Writer compressionWriter *compressionWriter
minSize int minSize int
excludedContentTypes []parsedContentType excludedContentTypes []parsedContentType
@ -133,7 +211,7 @@ func (r *responseWriter) Write(p []byte) (int, error) {
// We are now in compression cruise mode until the end of times. // We are now in compression cruise mode until the end of times.
if r.compressionStarted { if r.compressionStarted {
// If compressionStarted we assume we have sent headers already // If compressionStarted we assume we have sent headers already
return r.bw.Write(p) return r.compressionWriter.Write(p)
} }
// If we detect a contentEncoding, we know we are never going to compress. // If we detect a contentEncoding, we know we are never going to compress.
@ -187,13 +265,13 @@ func (r *responseWriter) Write(p []byte) (int, error) {
// Since we know we are going to compress we will never be able to know the actual length. // Since we know we are going to compress we will never be able to know the actual length.
r.rw.Header().Del(contentLength) r.rw.Header().Del(contentLength)
r.rw.Header().Set(contentEncoding, "br") r.rw.Header().Set(contentEncoding, r.compressionWriter.ContentEncoding())
r.rw.WriteHeader(r.statusCode) r.rw.WriteHeader(r.statusCode)
r.headersSent = true r.headersSent = true
// Start with sending what we have previously buffered, before actually writing // Start with sending what we have previously buffered, before actually writing
// the bytes in argument. // the bytes in argument.
n, err := r.bw.Write(r.buf) n, err := r.compressionWriter.Write(r.buf)
if err != nil { if err != nil {
r.buf = r.buf[n:] r.buf = r.buf[n:]
// Return zero because we haven't taken care of the bytes in argument yet. // Return zero because we haven't taken care of the bytes in argument yet.
@ -212,7 +290,7 @@ func (r *responseWriter) Write(p []byte) (int, error) {
r.buf = r.buf[:0] r.buf = r.buf[:0]
// Now that we emptied the buffer, we can actually write the given bytes. // Now that we emptied the buffer, we can actually write the given bytes.
return r.bw.Write(p) return r.compressionWriter.Write(p)
} }
// Flush flushes data to the appropriate underlying writer(s), although it does // Flush flushes data to the appropriate underlying writer(s), although it does
@ -250,7 +328,7 @@ func (r *responseWriter) Flush() {
// we have to do it ourselves. // we have to do it ourselves.
defer func() { defer func() {
// because we also ignore the error returned by Write anyway // because we also ignore the error returned by Write anyway
_ = r.bw.Flush() _ = r.compressionWriter.Flush()
if rw, ok := r.rw.(http.Flusher); ok { if rw, ok := r.rw.(http.Flusher); ok {
rw.Flush() rw.Flush()
@ -258,7 +336,7 @@ func (r *responseWriter) Flush() {
}() }()
// We empty whatever is left of the buffer that Write never took care of. // We empty whatever is left of the buffer that Write never took care of.
n, err := r.bw.Write(r.buf) n, err := r.compressionWriter.Write(r.buf)
if err != nil { if err != nil {
return return
} }
@ -313,7 +391,7 @@ func (r *responseWriter) close() error {
if len(r.buf) == 0 { if len(r.buf) == 0 {
// If we got here we know compression has started, so we can safely flush on bw. // If we got here we know compression has started, so we can safely flush on bw.
return r.bw.Close() return r.compressionWriter.Close()
} }
// There is still data in the buffer, because we never reached minSize (to // There is still data in the buffer, because we never reached minSize (to
@ -331,16 +409,16 @@ func (r *responseWriter) close() error {
// There is still data in the buffer, simply because Write did not take care of it all. // There is still data in the buffer, simply because Write did not take care of it all.
// We flush it to the compressed writer. // We flush it to the compressed writer.
n, err := r.bw.Write(r.buf) n, err := r.compressionWriter.Write(r.buf)
if err != nil { if err != nil {
r.bw.Close() r.compressionWriter.Close()
return err return err
} }
if n < len(r.buf) { if n < len(r.buf) {
r.bw.Close() r.compressionWriter.Close()
return io.ErrShortWrite return io.ErrShortWrite
} }
return r.bw.Close() return r.compressionWriter.Close()
} }
// parsedContentType is the parsed representation of one of the inputs to ContentTypes. // parsedContentType is the parsed representation of one of the inputs to ContentTypes.

View file

@ -1,4 +1,4 @@
package brotli package compress
import ( import (
"bytes" "bytes"
@ -9,6 +9,7 @@ import (
"testing" "testing"
"github.com/andybalholm/brotli" "github.com/andybalholm/brotli"
"github.com/klauspost/compress/zstd"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -19,44 +20,107 @@ var (
) )
func Test_Vary(t *testing.T) { func Test_Vary(t *testing.T) {
h := newTestHandler(t, smallTestBody) testCases := []struct {
desc string
h http.Handler
acceptEncoding string
}{
{
desc: "brotli",
h: newTestBrotliHandler(t, smallTestBody),
acceptEncoding: "br",
},
{
desc: "zstd",
h: newTestZstandardHandler(t, smallTestBody),
acceptEncoding: "zstd",
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil) req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
req.Header.Set(acceptEncoding, "br") req.Header.Set(acceptEncoding, test.acceptEncoding)
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
h.ServeHTTP(rw, req) test.h.ServeHTTP(rw, req)
assert.Equal(t, http.StatusAccepted, rw.Code) assert.Equal(t, http.StatusAccepted, rw.Code)
assert.Equal(t, acceptEncoding, rw.Header().Get(vary)) assert.Equal(t, acceptEncoding, rw.Header().Get(vary))
})
}
} }
func Test_SmallBodyNoCompression(t *testing.T) { func Test_SmallBodyNoCompression(t *testing.T) {
h := newTestHandler(t, smallTestBody) testCases := []struct {
desc string
h http.Handler
acceptEncoding string
}{
{
desc: "brotli",
h: newTestBrotliHandler(t, smallTestBody),
acceptEncoding: "br",
},
{
desc: "zstd",
h: newTestZstandardHandler(t, smallTestBody),
acceptEncoding: "zstd",
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil) req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
req.Header.Set(acceptEncoding, "br") req.Header.Set(acceptEncoding, test.acceptEncoding)
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
h.ServeHTTP(rw, req) test.h.ServeHTTP(rw, req)
// With less than 1024 bytes the response should not be compressed. // With less than 1024 bytes the response should not be compressed.
assert.Equal(t, http.StatusAccepted, rw.Code) assert.Equal(t, http.StatusAccepted, rw.Code)
assert.Empty(t, rw.Header().Get(contentEncoding)) assert.Empty(t, rw.Header().Get(contentEncoding))
assert.Equal(t, smallTestBody, rw.Body.Bytes()) assert.Equal(t, smallTestBody, rw.Body.Bytes())
})
}
} }
func Test_AlreadyCompressed(t *testing.T) { func Test_AlreadyCompressed(t *testing.T) {
h := newTestHandler(t, bigTestBody) testCases := []struct {
desc string
h http.Handler
acceptEncoding string
}{
{
desc: "brotli",
h: newTestBrotliHandler(t, bigTestBody),
acceptEncoding: "br",
},
{
desc: "zstd",
h: newTestZstandardHandler(t, bigTestBody),
acceptEncoding: "zstd",
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
req, _ := http.NewRequest(http.MethodGet, "/compressed", nil) req, _ := http.NewRequest(http.MethodGet, "/compressed", nil)
req.Header.Set(acceptEncoding, "br") req.Header.Set(acceptEncoding, test.acceptEncoding)
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
h.ServeHTTP(rw, req) test.h.ServeHTTP(rw, req)
assert.Equal(t, http.StatusAccepted, rw.Code) assert.Equal(t, http.StatusAccepted, rw.Code)
assert.Equal(t, bigTestBody, rw.Body.Bytes()) assert.Equal(t, bigTestBody, rw.Body.Bytes())
})
}
} }
func Test_NoBody(t *testing.T) { func Test_NoBody(t *testing.T) {
@ -91,15 +155,17 @@ func Test_NoBody(t *testing.T) {
t.Run(test.desc, func(t *testing.T) { t.Run(test.desc, func(t *testing.T) {
t.Parallel() t.Parallel()
h := mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(test.statusCode) rw.WriteHeader(test.statusCode)
_, err := rw.Write(test.body) _, err := rw.Write(test.body)
require.NoError(t, err) require.NoError(t, err)
})) })
h := mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: zstdName}, next)
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(acceptEncoding, "br") req.Header.Set(acceptEncoding, "zstd")
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
h.ServeHTTP(rw, req) h.ServeHTTP(rw, req)
@ -116,11 +182,12 @@ func Test_NoBody(t *testing.T) {
func Test_MinSize(t *testing.T) { func Test_MinSize(t *testing.T) {
cfg := Config{ cfg := Config{
MinSize: 128, MinSize: 128,
Algorithm: zstdName,
} }
var bodySize int var bodySize int
h := mustNewWrapper(t, cfg)(http.HandlerFunc(
func(rw http.ResponseWriter, req *http.Request) { next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
for range bodySize { for range bodySize {
// We make sure to Write at least once less than minSize so that both // We make sure to Write at least once less than minSize so that both
// cases below go through the same algo: i.e. they start buffering // cases below go through the same algo: i.e. they start buffering
@ -128,11 +195,12 @@ func Test_MinSize(t *testing.T) {
_, err := rw.Write([]byte{'x'}) _, err := rw.Write([]byte{'x'})
require.NoError(t, err) require.NoError(t, err)
} }
}, })
))
h := mustNewCompressionHandler(t, cfg, next)
req, _ := http.NewRequest(http.MethodGet, "/whatever", &bytes.Buffer{}) req, _ := http.NewRequest(http.MethodGet, "/whatever", &bytes.Buffer{})
req.Header.Add(acceptEncoding, "br") req.Header.Add(acceptEncoding, "zstd")
// Short response is not compressed // Short response is not compressed
bodySize = cfg.MinSize - 1 bodySize = cfg.MinSize - 1
@ -146,18 +214,20 @@ func Test_MinSize(t *testing.T) {
rw = httptest.NewRecorder() rw = httptest.NewRecorder()
h.ServeHTTP(rw, req) h.ServeHTTP(rw, req)
assert.Equal(t, "br", rw.Result().Header.Get(contentEncoding)) assert.Equal(t, "zstd", rw.Result().Header.Get(contentEncoding))
} }
func Test_MultipleWriteHeader(t *testing.T) { func Test_MultipleWriteHeader(t *testing.T) {
h := mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
// We ensure that the subsequent call to WriteHeader is a noop. // We ensure that the subsequent call to WriteHeader is a noop.
rw.WriteHeader(http.StatusInternalServerError) rw.WriteHeader(http.StatusInternalServerError)
rw.WriteHeader(http.StatusNotFound) rw.WriteHeader(http.StatusNotFound)
})) })
h := mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: zstdName}, next)
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(acceptEncoding, "br") req.Header.Set(acceptEncoding, "zstd")
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
h.ServeHTTP(rw, req) h.ServeHTTP(rw, req)
@ -166,19 +236,49 @@ func Test_MultipleWriteHeader(t *testing.T) {
} }
func Test_FlushBeforeWrite(t *testing.T) { func Test_FlushBeforeWrite(t *testing.T) {
srv := httptest.NewServer(mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { testCases := []struct {
desc string
cfg Config
readerBuilder func(io.Reader) (io.Reader, error)
acceptEncoding string
}{
{
desc: "brotli",
cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"},
readerBuilder: func(reader io.Reader) (io.Reader, error) {
return brotli.NewReader(reader), nil
},
acceptEncoding: "br",
},
{
desc: "zstd",
cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"},
readerBuilder: func(reader io.Reader) (io.Reader, error) {
return zstd.NewReader(reader)
},
acceptEncoding: "zstd",
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK) rw.WriteHeader(http.StatusOK)
rw.(http.Flusher).Flush() rw.(http.Flusher).Flush()
_, err := rw.Write(bigTestBody) _, err := rw.Write(bigTestBody)
require.NoError(t, err) require.NoError(t, err)
}))) })
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next))
defer srv.Close() defer srv.Close()
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody) req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
require.NoError(t, err) require.NoError(t, err)
req.Header.Set(acceptEncoding, "br") req.Header.Set(acceptEncoding, test.acceptEncoding)
res, err := http.DefaultClient.Do(req) res, err := http.DefaultClient.Do(req)
require.NoError(t, err) require.NoError(t, err)
@ -186,15 +286,46 @@ func Test_FlushBeforeWrite(t *testing.T) {
defer res.Body.Close() defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode) assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, "br", res.Header.Get(contentEncoding)) assert.Equal(t, test.acceptEncoding, res.Header.Get(contentEncoding))
got, err := io.ReadAll(brotli.NewReader(res.Body)) reader, err := test.readerBuilder(res.Body)
require.NoError(t, err)
got, err := io.ReadAll(reader)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, bigTestBody, got) assert.Equal(t, bigTestBody, got)
})
}
} }
func Test_FlushAfterWrite(t *testing.T) { func Test_FlushAfterWrite(t *testing.T) {
srv := httptest.NewServer(mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { testCases := []struct {
desc string
cfg Config
readerBuilder func(io.Reader) (io.Reader, error)
acceptEncoding string
}{
{
desc: "brotli",
cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"},
readerBuilder: func(reader io.Reader) (io.Reader, error) {
return brotli.NewReader(reader), nil
},
acceptEncoding: "br",
},
{
desc: "zstd",
cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"},
readerBuilder: func(reader io.Reader) (io.Reader, error) {
return zstd.NewReader(reader)
},
acceptEncoding: "zstd",
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK) rw.WriteHeader(http.StatusOK)
_, err := rw.Write(bigTestBody[0:1]) _, err := rw.Write(bigTestBody[0:1])
@ -205,13 +336,15 @@ func Test_FlushAfterWrite(t *testing.T) {
_, err := rw.Write([]byte{b}) _, err := rw.Write([]byte{b})
require.NoError(t, err) require.NoError(t, err)
} }
}))) })
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next))
defer srv.Close() defer srv.Close()
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody) req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
require.NoError(t, err) require.NoError(t, err)
req.Header.Set(acceptEncoding, "br") req.Header.Set(acceptEncoding, test.acceptEncoding)
res, err := http.DefaultClient.Do(req) res, err := http.DefaultClient.Do(req)
require.NoError(t, err) require.NoError(t, err)
@ -219,28 +352,61 @@ func Test_FlushAfterWrite(t *testing.T) {
defer res.Body.Close() defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode) assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, "br", res.Header.Get(contentEncoding)) assert.Equal(t, test.acceptEncoding, res.Header.Get(contentEncoding))
got, err := io.ReadAll(brotli.NewReader(res.Body)) reader, err := test.readerBuilder(res.Body)
require.NoError(t, err)
got, err := io.ReadAll(reader)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, bigTestBody, got) assert.Equal(t, bigTestBody, got)
})
}
} }
func Test_FlushAfterWriteNil(t *testing.T) { func Test_FlushAfterWriteNil(t *testing.T) {
srv := httptest.NewServer(mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { testCases := []struct {
desc string
cfg Config
readerBuilder func(io.Reader) (io.Reader, error)
acceptEncoding string
}{
{
desc: "brotli",
cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"},
readerBuilder: func(reader io.Reader) (io.Reader, error) {
return brotli.NewReader(reader), nil
},
acceptEncoding: "br",
},
{
desc: "zstd",
cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"},
readerBuilder: func(reader io.Reader) (io.Reader, error) {
return zstd.NewReader(reader)
},
acceptEncoding: "zstd",
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK) rw.WriteHeader(http.StatusOK)
_, err := rw.Write(nil) _, err := rw.Write(nil)
require.NoError(t, err) require.NoError(t, err)
rw.(http.Flusher).Flush() rw.(http.Flusher).Flush()
}))) })
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next))
defer srv.Close() defer srv.Close()
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody) req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
require.NoError(t, err) require.NoError(t, err)
req.Header.Set(acceptEncoding, "br") req.Header.Set(acceptEncoding, test.acceptEncoding)
res, err := http.DefaultClient.Do(req) res, err := http.DefaultClient.Do(req)
require.NoError(t, err) require.NoError(t, err)
@ -250,25 +416,58 @@ func Test_FlushAfterWriteNil(t *testing.T) {
assert.Equal(t, http.StatusOK, res.StatusCode) assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Empty(t, res.Header.Get(contentEncoding)) assert.Empty(t, res.Header.Get(contentEncoding))
got, err := io.ReadAll(brotli.NewReader(res.Body)) reader, err := test.readerBuilder(res.Body)
require.NoError(t, err)
got, err := io.ReadAll(reader)
require.NoError(t, err) require.NoError(t, err)
assert.Empty(t, got) assert.Empty(t, got)
})
}
} }
func Test_FlushAfterAllWrites(t *testing.T) { func Test_FlushAfterAllWrites(t *testing.T) {
srv := httptest.NewServer(mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { testCases := []struct {
desc string
cfg Config
readerBuilder func(io.Reader) (io.Reader, error)
acceptEncoding string
}{
{
desc: "brotli",
cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"},
readerBuilder: func(reader io.Reader) (io.Reader, error) {
return brotli.NewReader(reader), nil
},
acceptEncoding: "br",
},
{
desc: "zstd",
cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"},
readerBuilder: func(reader io.Reader) (io.Reader, error) {
return zstd.NewReader(reader)
},
acceptEncoding: "zstd",
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
for i := range bigTestBody { for i := range bigTestBody {
_, err := rw.Write(bigTestBody[i : i+1]) _, err := rw.Write(bigTestBody[i : i+1])
require.NoError(t, err) require.NoError(t, err)
} }
rw.(http.Flusher).Flush() rw.(http.Flusher).Flush()
}))) })
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next))
defer srv.Close() defer srv.Close()
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody) req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
require.NoError(t, err) require.NoError(t, err)
req.Header.Set(acceptEncoding, "br") req.Header.Set(acceptEncoding, test.acceptEncoding)
res, err := http.DefaultClient.Do(req) res, err := http.DefaultClient.Do(req)
require.NoError(t, err) require.NoError(t, err)
@ -276,11 +475,16 @@ func Test_FlushAfterAllWrites(t *testing.T) {
defer res.Body.Close() defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode) assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, "br", res.Header.Get(contentEncoding)) assert.Equal(t, test.acceptEncoding, res.Header.Get(contentEncoding))
got, err := io.ReadAll(brotli.NewReader(res.Body)) reader, err := test.readerBuilder(res.Body)
require.NoError(t, err)
got, err := io.ReadAll(reader)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, bigTestBody, got) assert.Equal(t, bigTestBody, got)
})
}
} }
func Test_ExcludedContentTypes(t *testing.T) { func Test_ExcludedContentTypes(t *testing.T) {
@ -352,18 +556,22 @@ func Test_ExcludedContentTypes(t *testing.T) {
cfg := Config{ cfg := Config{
MinSize: 1024, MinSize: 1024,
ExcludedContentTypes: test.excludedContentTypes, ExcludedContentTypes: test.excludedContentTypes,
Algorithm: zstdName,
} }
h := mustNewWrapper(t, cfg)(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set(contentType, test.contentType) rw.Header().Set(contentType, test.contentType)
rw.WriteHeader(http.StatusAccepted) rw.WriteHeader(http.StatusAccepted)
_, err := rw.Write(bigTestBody) _, err := rw.Write(bigTestBody)
require.NoError(t, err) require.NoError(t, err)
})) })
h := mustNewCompressionHandler(t, cfg, next)
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil) req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
req.Header.Set(acceptEncoding, "br") req.Header.Set(acceptEncoding, zstdName)
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
h.ServeHTTP(rw, req) h.ServeHTTP(rw, req)
@ -371,13 +579,16 @@ func Test_ExcludedContentTypes(t *testing.T) {
assert.Equal(t, http.StatusAccepted, rw.Code) assert.Equal(t, http.StatusAccepted, rw.Code)
if test.expCompression { if test.expCompression {
assert.Equal(t, "br", rw.Header().Get(contentEncoding)) assert.Equal(t, zstdName, rw.Header().Get(contentEncoding))
got, err := io.ReadAll(brotli.NewReader(rw.Body)) reader, err := zstd.NewReader(rw.Body)
require.NoError(t, err)
got, err := io.ReadAll(reader)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, bigTestBody, got) assert.Equal(t, bigTestBody, got)
} else { } else {
assert.NotEqual(t, "br", rw.Header().Get("Content-Encoding")) assert.NotEqual(t, zstdName, rw.Header().Get("Content-Encoding"))
got, err := io.ReadAll(rw.Body) got, err := io.ReadAll(rw.Body)
assert.NoError(t, err) assert.NoError(t, err)
@ -456,18 +667,22 @@ func Test_IncludedContentTypes(t *testing.T) {
cfg := Config{ cfg := Config{
MinSize: 1024, MinSize: 1024,
IncludedContentTypes: test.includedContentTypes, IncludedContentTypes: test.includedContentTypes,
Algorithm: zstdName,
} }
h := mustNewWrapper(t, cfg)(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set(contentType, test.contentType) rw.Header().Set(contentType, test.contentType)
rw.WriteHeader(http.StatusAccepted) rw.WriteHeader(http.StatusAccepted)
_, err := rw.Write(bigTestBody) _, err := rw.Write(bigTestBody)
require.NoError(t, err) require.NoError(t, err)
})) })
h := mustNewCompressionHandler(t, cfg, next)
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil) req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
req.Header.Set(acceptEncoding, "br") req.Header.Set(acceptEncoding, zstdName)
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
h.ServeHTTP(rw, req) h.ServeHTTP(rw, req)
@ -475,13 +690,16 @@ func Test_IncludedContentTypes(t *testing.T) {
assert.Equal(t, http.StatusAccepted, rw.Code) assert.Equal(t, http.StatusAccepted, rw.Code)
if test.expCompression { if test.expCompression {
assert.Equal(t, "br", rw.Header().Get(contentEncoding)) assert.Equal(t, zstdName, rw.Header().Get(contentEncoding))
got, err := io.ReadAll(brotli.NewReader(rw.Body)) reader, err := zstd.NewReader(rw.Body)
require.NoError(t, err)
got, err := io.ReadAll(reader)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, bigTestBody, got) assert.Equal(t, bigTestBody, got)
} else { } else {
assert.NotEqual(t, "br", rw.Header().Get("Content-Encoding")) assert.NotEqual(t, zstdName, rw.Header().Get("Content-Encoding"))
got, err := io.ReadAll(rw.Body) got, err := io.ReadAll(rw.Body)
assert.NoError(t, err) assert.NoError(t, err)
@ -560,8 +778,10 @@ func Test_FlushExcludedContentTypes(t *testing.T) {
cfg := Config{ cfg := Config{
MinSize: 1024, MinSize: 1024,
ExcludedContentTypes: test.excludedContentTypes, ExcludedContentTypes: test.excludedContentTypes,
Algorithm: zstdName,
} }
h := mustNewWrapper(t, cfg)(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set(contentType, test.contentType) rw.Header().Set(contentType, test.contentType)
rw.WriteHeader(http.StatusOK) rw.WriteHeader(http.StatusOK)
@ -581,10 +801,12 @@ func Test_FlushExcludedContentTypes(t *testing.T) {
rw.(http.Flusher).Flush() rw.(http.Flusher).Flush()
tb = tb[toWrite:] tb = tb[toWrite:]
} }
})) })
h := mustNewCompressionHandler(t, cfg, next)
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil) req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
req.Header.Set(acceptEncoding, "br") req.Header.Set(acceptEncoding, zstdName)
// This doesn't allow checking flushes, but we validate if content is correct. // This doesn't allow checking flushes, but we validate if content is correct.
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
@ -593,13 +815,16 @@ func Test_FlushExcludedContentTypes(t *testing.T) {
assert.Equal(t, http.StatusOK, rw.Code) assert.Equal(t, http.StatusOK, rw.Code)
if test.expCompression { if test.expCompression {
assert.Equal(t, "br", rw.Header().Get(contentEncoding)) assert.Equal(t, zstdName, rw.Header().Get(contentEncoding))
got, err := io.ReadAll(brotli.NewReader(rw.Body)) reader, err := zstd.NewReader(rw.Body)
require.NoError(t, err)
got, err := io.ReadAll(reader)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, bigTestBody, got) assert.Equal(t, bigTestBody, got)
} else { } else {
assert.NotEqual(t, "br", rw.Header().Get(contentEncoding)) assert.NotEqual(t, zstdName, rw.Header().Get(contentEncoding))
got, err := io.ReadAll(rw.Body) got, err := io.ReadAll(rw.Body)
assert.NoError(t, err) assert.NoError(t, err)
@ -678,8 +903,10 @@ func Test_FlushIncludedContentTypes(t *testing.T) {
cfg := Config{ cfg := Config{
MinSize: 1024, MinSize: 1024,
IncludedContentTypes: test.includedContentTypes, IncludedContentTypes: test.includedContentTypes,
Algorithm: zstdName,
} }
h := mustNewWrapper(t, cfg)(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set(contentType, test.contentType) rw.Header().Set(contentType, test.contentType)
rw.WriteHeader(http.StatusOK) rw.WriteHeader(http.StatusOK)
@ -699,10 +926,12 @@ func Test_FlushIncludedContentTypes(t *testing.T) {
rw.(http.Flusher).Flush() rw.(http.Flusher).Flush()
tb = tb[toWrite:] tb = tb[toWrite:]
} }
})) })
h := mustNewCompressionHandler(t, cfg, next)
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil) req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
req.Header.Set(acceptEncoding, "br") req.Header.Set(acceptEncoding, zstdName)
// This doesn't allow checking flushes, but we validate if content is correct. // This doesn't allow checking flushes, but we validate if content is correct.
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
@ -711,13 +940,16 @@ func Test_FlushIncludedContentTypes(t *testing.T) {
assert.Equal(t, http.StatusOK, rw.Code) assert.Equal(t, http.StatusOK, rw.Code)
if test.expCompression { if test.expCompression {
assert.Equal(t, "br", rw.Header().Get(contentEncoding)) assert.Equal(t, zstdName, rw.Header().Get(contentEncoding))
got, err := io.ReadAll(brotli.NewReader(rw.Body)) reader, err := zstd.NewReader(rw.Body)
require.NoError(t, err)
got, err := io.ReadAll(reader)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, bigTestBody, got) assert.Equal(t, bigTestBody, got)
} else { } else {
assert.NotEqual(t, "br", rw.Header().Get(contentEncoding)) assert.NotEqual(t, zstdName, rw.Header().Get(contentEncoding))
got, err := io.ReadAll(rw.Body) got, err := io.ReadAll(rw.Body)
assert.NoError(t, err) assert.NoError(t, err)
@ -727,32 +959,48 @@ func Test_FlushIncludedContentTypes(t *testing.T) {
} }
} }
func mustNewWrapper(t *testing.T, cfg Config) func(http.Handler) http.HandlerFunc { func mustNewCompressionHandler(t *testing.T, cfg Config, next http.Handler) http.Handler {
t.Helper() t.Helper()
w, err := NewWrapper(cfg) w, err := NewCompressionHandler(cfg, next)
require.NoError(t, err) require.NoError(t, err)
return w return w
} }
func newTestHandler(t *testing.T, body []byte) http.Handler { func newTestBrotliHandler(t *testing.T, body []byte) http.Handler {
t.Helper() t.Helper()
return mustNewWrapper(t, Config{MinSize: 1024})( next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.URL.Path == "/compressed" { if req.URL.Path == "/compressed" {
rw.Header().Set("Content-Encoding", "br") rw.Header().Set("Content-Encoding", brotliName)
} }
rw.WriteHeader(http.StatusAccepted) rw.WriteHeader(http.StatusAccepted)
_, err := rw.Write(body) _, err := rw.Write(body)
require.NoError(t, err) require.NoError(t, err)
}), })
)
return mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Compress"}, next)
} }
func TestParseContentType_equals(t *testing.T) { func newTestZstandardHandler(t *testing.T, body []byte) http.Handler {
t.Helper()
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.URL.Path == "/compressed" {
rw.Header().Set("Content-Encoding", zstdName)
}
rw.WriteHeader(http.StatusAccepted)
_, err := rw.Write(body)
require.NoError(t, err)
})
return mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Compress"}, next)
}
func Test_ParseContentType_equals(t *testing.T) {
testCases := []struct { testCases := []struct {
desc string desc string
pct parsedContentType pct parsedContentType