diff --git a/docs/content/middlewares/http/compress.md b/docs/content/middlewares/http/compress.md index d3f2e1085..f26fe622c 100644 --- a/docs/content/middlewares/http/compress.md +++ b/docs/content/middlewares/http/compress.md @@ -10,7 +10,7 @@ Compress Allows Compressing Responses before Sending them to the Client ![Compress](../../assets/img/middleware/compress.png) -The Compress middleware supports gzip and Brotli compression. +The Compress middleware supports Gzip, Brotli and Zstandard compression. The activation of compression, and the compression method choice rely (among other things) on the request's `Accept-Encoding` header. ## Configuration Examples @@ -54,8 +54,8 @@ http: Responses are compressed when the following criteria are all met: - * The `Accept-Encoding` request header contains `gzip`, `*`, and/or `br` with or without [quality values](https://developer.mozilla.org/en-US/docs/Glossary/Quality_values). - If the `Accept-Encoding` request header is absent, the response won't be encoded. + * The `Accept-Encoding` request header contains `gzip`, and/or `*`, and/or `br`, and/or `zstd` with or without [quality values](https://developer.mozilla.org/en-US/docs/Glossary/Quality_values). + If the `Accept-Encoding` request header is absent and no [defaultEncoding](#defaultencoding) is configured, the response won't be encoded. If it is present, but its value is the empty string, then compression is disabled. * The response is not already compressed, i.e. the `Content-Encoding` response header is not already set. * The response`Content-Type` header is not one among the [excludedContentTypes options](#excludedcontenttypes), or is one among the [includedContentTypes options](#includedcontenttypes). diff --git a/pkg/middlewares/compress/acceptencoding.go b/pkg/middlewares/compress/acceptencoding.go index 084a8f263..3f9fc4f4a 100644 --- a/pkg/middlewares/compress/acceptencoding.go +++ b/pkg/middlewares/compress/acceptencoding.go @@ -11,6 +11,7 @@ const acceptEncodingHeader = "Accept-Encoding" const ( brotliName = "br" gzipName = "gzip" + zstdName = "zstd" identityName = "identity" wildcardName = "*" notAcceptable = "not_acceptable" @@ -51,7 +52,7 @@ func getCompressionType(acceptEncoding []string, defaultType string) string { return encoding.Type } - for _, dt := range []string{brotliName, gzipName} { + for _, dt := range []string{zstdName, brotliName, gzipName} { if slices.ContainsFunc(encodings, func(e Encoding) bool { return e.Type == dt }) { return dt } @@ -76,7 +77,7 @@ func parseAcceptEncoding(acceptEncoding []string) ([]Encoding, bool) { } switch parsed[0] { - case brotliName, gzipName, identityName, wildcardName: + case zstdName, brotliName, gzipName, identityName, wildcardName: // supported encoding default: continue diff --git a/pkg/middlewares/compress/acceptencoding_test.go b/pkg/middlewares/compress/acceptencoding_test.go index 858e6795e..818c3e06e 100644 --- a/pkg/middlewares/compress/acceptencoding_test.go +++ b/pkg/middlewares/compress/acceptencoding_test.go @@ -18,6 +18,11 @@ func Test_getCompressionType(t *testing.T) { values: []string{"gzip, br"}, expected: brotliName, }, + { + desc: "zstd > br > gzip (no weight)", + values: []string{"zstd, gzip, br"}, + expected: zstdName, + }, { desc: "known compression type (no weight)", values: []string{"compress, gzip"}, @@ -49,6 +54,11 @@ func Test_getCompressionType(t *testing.T) { values: []string{"compress;q=1.0, gzip;q=0.5"}, expected: gzipName, }, + { + desc: "fallback on non-zero compression type", + values: []string{"compress;q=1.0, gzip, identity;q=0"}, + expected: gzipName, + }, { desc: "not acceptable (identity)", values: []string{"compress;q=1.0, identity;q=0"}, @@ -86,9 +96,10 @@ func Test_parseAcceptEncoding(t *testing.T) { }{ { desc: "weight", - values: []string{"br;q=1.0, gzip;q=0.8, *;q=0.1"}, + values: []string{"br;q=1.0, zstd;q=0.9, gzip;q=0.8, *;q=0.1"}, expected: []Encoding{ {Type: brotliName, Weight: ptr[float64](1)}, + {Type: zstdName, Weight: ptr(0.9)}, {Type: gzipName, Weight: ptr(0.8)}, {Type: wildcardName, Weight: ptr(0.1)}, }, @@ -96,9 +107,10 @@ func Test_parseAcceptEncoding(t *testing.T) { }, { desc: "mixed", - values: []string{"gzip, br;q=1.0, *;q=0"}, + values: []string{"zstd,gzip, br;q=1.0, *;q=0"}, expected: []Encoding{ {Type: brotliName, Weight: ptr[float64](1)}, + {Type: zstdName}, {Type: gzipName}, {Type: wildcardName, Weight: ptr[float64](0)}, }, @@ -106,8 +118,9 @@ func Test_parseAcceptEncoding(t *testing.T) { }, { desc: "no weight", - values: []string{"gzip, br, *"}, + values: []string{"zstd, gzip, br, *"}, expected: []Encoding{ + {Type: zstdName}, {Type: gzipName}, {Type: brotliName}, {Type: wildcardName}, diff --git a/pkg/middlewares/compress/compress.go b/pkg/middlewares/compress/compress.go index 514fddb2e..3ccac9f2d 100644 --- a/pkg/middlewares/compress/compress.go +++ b/pkg/middlewares/compress/compress.go @@ -11,7 +11,6 @@ import ( "github.com/klauspost/compress/gzhttp" "github.com/traefik/traefik/v3/pkg/config/dynamic" "github.com/traefik/traefik/v3/pkg/middlewares" - "github.com/traefik/traefik/v3/pkg/middlewares/compress/brotli" "go.opentelemetry.io/otel/trace" ) @@ -32,6 +31,7 @@ type compress struct { brotliHandler http.Handler gzipHandler http.Handler + zstdHandler http.Handler } // New creates a new compress middleware. @@ -77,7 +77,13 @@ func New(ctx context.Context, next http.Handler, conf dynamic.Compress, name str } var err error - c.brotliHandler, err = c.newBrotliHandler() + + c.zstdHandler, err = c.newCompressionHandler(zstdName, name) + if err != nil { + return nil, err + } + + c.brotliHandler, err = c.newCompressionHandler(brotliName, name) if err != nil { return nil, err } @@ -130,6 +136,8 @@ func (c *compress) ServeHTTP(rw http.ResponseWriter, req *http.Request) { func (c *compress) chooseHandler(typ string, rw http.ResponseWriter, req *http.Request) { switch typ { + case zstdName: + c.zstdHandler.ServeHTTP(rw, req) case brotliName: c.brotliHandler.ServeHTTP(rw, req) case gzipName: @@ -166,18 +174,13 @@ func (c *compress) newGzipHandler() (http.Handler, error) { return wrapper(c.next), nil } -func (c *compress) newBrotliHandler() (http.Handler, error) { - cfg := brotli.Config{MinSize: c.minSize} +func (c *compress) newCompressionHandler(algo string, middlewareName string) (http.Handler, error) { + cfg := Config{MinSize: c.minSize, Algorithm: algo, MiddlewareName: middlewareName} if len(c.includes) > 0 { cfg.IncludedContentTypes = c.includes } else { cfg.ExcludedContentTypes = c.excludes } - wrapper, err := brotli.NewWrapper(cfg) - if err != nil { - return nil, fmt.Errorf("new brotli wrapper: %w", err) - } - - return wrapper(c.next), nil + return NewCompressionHandler(cfg, c.next) } diff --git a/pkg/middlewares/compress/compress_test.go b/pkg/middlewares/compress/compress_test.go index 556d9f5fa..af127df0c 100644 --- a/pkg/middlewares/compress/compress_test.go +++ b/pkg/middlewares/compress/compress_test.go @@ -41,32 +41,52 @@ func TestNegotiation(t *testing.T) { { desc: "accept any header", acceptEncHeader: "*", - expEncoding: "br", + expEncoding: brotliName, }, { desc: "gzip accept header", acceptEncHeader: "gzip", - expEncoding: "gzip", + expEncoding: gzipName, }, { desc: "br accept header", acceptEncHeader: "br", - expEncoding: "br", + expEncoding: brotliName, }, { desc: "multi accept header, prefer br", acceptEncHeader: "br;q=0.8, gzip;q=0.6", - expEncoding: "br", + expEncoding: brotliName, }, { desc: "multi accept header, prefer gzip", acceptEncHeader: "gzip;q=1.0, br;q=0.8", - expEncoding: "gzip", + expEncoding: gzipName, }, { desc: "multi accept header list, prefer br", acceptEncHeader: "gzip, br", - expEncoding: "br", + expEncoding: brotliName, + }, + { + desc: "zstd accept header", + acceptEncHeader: "zstd", + expEncoding: zstdName, + }, + { + desc: "multi accept header, prefer zstd", + acceptEncHeader: "zstd;q=0.9, br;q=0.8, gzip;q=0.6", + expEncoding: zstdName, + }, + { + desc: "multi accept header, prefer gzip", + acceptEncHeader: "gzip;q=1.0, br;q=0.8, zstd;q=0.7", + expEncoding: gzipName, + }, + { + desc: "multi accept header list, prefer zstd", + acceptEncHeader: "gzip, br, zstd", + expEncoding: zstdName, }, } diff --git a/pkg/middlewares/compress/brotli/brotli.go b/pkg/middlewares/compress/compression_handler.go similarity index 72% rename from pkg/middlewares/compress/brotli/brotli.go rename to pkg/middlewares/compress/compression_handler.go index 27802b0f6..78214acc5 100644 --- a/pkg/middlewares/compress/brotli/brotli.go +++ b/pkg/middlewares/compress/compression_handler.go @@ -1,4 +1,4 @@ -package brotli +package compress import ( "bufio" @@ -10,6 +10,9 @@ import ( "net/http" "github.com/andybalholm/brotli" + "github.com/klauspost/compress/zstd" + "github.com/traefik/traefik/v3/pkg/middlewares" + "github.com/traefik/traefik/v3/pkg/middlewares/observability" ) const ( @@ -30,10 +33,26 @@ type Config struct { IncludedContentTypes []string // MinSize is the minimum size (in bytes) required to enable compression. MinSize int + // Algorithm used for the compression (currently Brotli and Zstandard) + Algorithm string + // MiddlewareName use for logging purposes + MiddlewareName string } -// NewWrapper returns a new Brotli compressing wrapper. -func NewWrapper(cfg Config) (func(http.Handler) http.HandlerFunc, error) { +// CompressionHandler handles Brolti and Zstd compression. +type CompressionHandler struct { + cfg Config + excludedContentTypes []parsedContentType + includedContentTypes []parsedContentType + next http.Handler +} + +// NewCompressionHandler returns a new compressing handler. +func NewCompressionHandler(cfg Config, next http.Handler) (http.Handler, error) { + if cfg.Algorithm == "" { + return nil, errors.New("compression algorithm undefined") + } + if cfg.MinSize < 0 { return nil, errors.New("minimum size must be greater than or equal to zero") } @@ -62,30 +81,89 @@ func NewWrapper(cfg Config) (func(http.Handler) http.HandlerFunc, error) { includedContentTypes = append(includedContentTypes, parsedContentType{mediaType, params}) } - return func(h http.Handler) http.HandlerFunc { - return func(rw http.ResponseWriter, r *http.Request) { - rw.Header().Add(vary, acceptEncoding) - - brw := &responseWriter{ - rw: rw, - bw: brotli.NewWriter(rw), - minSize: cfg.MinSize, - statusCode: http.StatusOK, - excludedContentTypes: excludedContentTypes, - includedContentTypes: includedContentTypes, - } - defer brw.close() - - h.ServeHTTP(brw, r) - } + return &CompressionHandler{ + cfg: cfg, + excludedContentTypes: excludedContentTypes, + includedContentTypes: includedContentTypes, + next: next, }, nil } +func (c *CompressionHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + rw.Header().Add(vary, acceptEncoding) + + compressionWriter, err := newCompressionWriter(c.cfg.Algorithm, rw) + if err != nil { + logger := middlewares.GetLogger(r.Context(), c.cfg.MiddlewareName, typeName) + logMessage := fmt.Sprintf("create compression handler: %v", err) + logger.Debug().Msg(logMessage) + observability.SetStatusErrorf(r.Context(), logMessage) + rw.WriteHeader(http.StatusInternalServerError) + return + } + + responseWriter := &responseWriter{ + rw: rw, + compressionWriter: compressionWriter, + minSize: c.cfg.MinSize, + statusCode: http.StatusOK, + excludedContentTypes: c.excludedContentTypes, + includedContentTypes: c.includedContentTypes, + } + defer responseWriter.close() + + c.next.ServeHTTP(responseWriter, r) +} + +type compression interface { + // Write data to the encoder. + // Input data will be buffered and as the buffer fills up + // content will be compressed and written to the output. + // When done writing, use Close to flush the remaining output + // and write CRC if requested. + Write(p []byte) (n int, err error) + // Flush will send the currently written data to output + // and block until everything has been written. + // This should only be used on rare occasions where pushing the currently queued data is critical. + Flush() error + // Close closes the underlying writers if/when appropriate. + // Note that the compressed writer should not be closed if we never used it, + // as it would otherwise send some extra "end of compression" bytes. + // Close also makes sure to flush whatever was left to write from the buffer. + Close() error +} + +type compressionWriter struct { + compression + alg string +} + +func newCompressionWriter(algo string, in io.Writer) (*compressionWriter, error) { + switch algo { + case brotliName: + return &compressionWriter{compression: brotli.NewWriter(in), alg: algo}, nil + + case zstdName: + writer, err := zstd.NewWriter(in) + if err != nil { + return nil, fmt.Errorf("creating zstd writer: %w", err) + } + return &compressionWriter{compression: writer, alg: algo}, nil + + default: + return nil, fmt.Errorf("unknown compression algo: %s", algo) + } +} + +func (c *compressionWriter) ContentEncoding() string { + return c.alg +} + // TODO: check whether we want to implement content-type sniffing (as gzip does) // TODO: check whether we should support Accept-Ranges (as gzip does, see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Ranges) type responseWriter struct { - rw http.ResponseWriter - bw *brotli.Writer + rw http.ResponseWriter + compressionWriter *compressionWriter minSize int excludedContentTypes []parsedContentType @@ -133,7 +211,7 @@ func (r *responseWriter) Write(p []byte) (int, error) { // We are now in compression cruise mode until the end of times. if r.compressionStarted { // If compressionStarted we assume we have sent headers already - return r.bw.Write(p) + return r.compressionWriter.Write(p) } // If we detect a contentEncoding, we know we are never going to compress. @@ -187,13 +265,13 @@ func (r *responseWriter) Write(p []byte) (int, error) { // Since we know we are going to compress we will never be able to know the actual length. r.rw.Header().Del(contentLength) - r.rw.Header().Set(contentEncoding, "br") + r.rw.Header().Set(contentEncoding, r.compressionWriter.ContentEncoding()) r.rw.WriteHeader(r.statusCode) r.headersSent = true // Start with sending what we have previously buffered, before actually writing // the bytes in argument. - n, err := r.bw.Write(r.buf) + n, err := r.compressionWriter.Write(r.buf) if err != nil { r.buf = r.buf[n:] // Return zero because we haven't taken care of the bytes in argument yet. @@ -212,7 +290,7 @@ func (r *responseWriter) Write(p []byte) (int, error) { r.buf = r.buf[:0] // Now that we emptied the buffer, we can actually write the given bytes. - return r.bw.Write(p) + return r.compressionWriter.Write(p) } // Flush flushes data to the appropriate underlying writer(s), although it does @@ -250,7 +328,7 @@ func (r *responseWriter) Flush() { // we have to do it ourselves. defer func() { // because we also ignore the error returned by Write anyway - _ = r.bw.Flush() + _ = r.compressionWriter.Flush() if rw, ok := r.rw.(http.Flusher); ok { rw.Flush() @@ -258,7 +336,7 @@ func (r *responseWriter) Flush() { }() // We empty whatever is left of the buffer that Write never took care of. - n, err := r.bw.Write(r.buf) + n, err := r.compressionWriter.Write(r.buf) if err != nil { return } @@ -313,7 +391,7 @@ func (r *responseWriter) close() error { if len(r.buf) == 0 { // If we got here we know compression has started, so we can safely flush on bw. - return r.bw.Close() + return r.compressionWriter.Close() } // There is still data in the buffer, because we never reached minSize (to @@ -331,16 +409,16 @@ func (r *responseWriter) close() error { // There is still data in the buffer, simply because Write did not take care of it all. // We flush it to the compressed writer. - n, err := r.bw.Write(r.buf) + n, err := r.compressionWriter.Write(r.buf) if err != nil { - r.bw.Close() + r.compressionWriter.Close() return err } if n < len(r.buf) { - r.bw.Close() + r.compressionWriter.Close() return io.ErrShortWrite } - return r.bw.Close() + return r.compressionWriter.Close() } // parsedContentType is the parsed representation of one of the inputs to ContentTypes. diff --git a/pkg/middlewares/compress/brotli/brotli_test.go b/pkg/middlewares/compress/compression_handler_test.go similarity index 57% rename from pkg/middlewares/compress/brotli/brotli_test.go rename to pkg/middlewares/compress/compression_handler_test.go index 67c794c46..1df9e9588 100644 --- a/pkg/middlewares/compress/brotli/brotli_test.go +++ b/pkg/middlewares/compress/compression_handler_test.go @@ -1,4 +1,4 @@ -package brotli +package compress import ( "bytes" @@ -9,6 +9,7 @@ import ( "testing" "github.com/andybalholm/brotli" + "github.com/klauspost/compress/zstd" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -19,44 +20,107 @@ var ( ) func Test_Vary(t *testing.T) { - h := newTestHandler(t, smallTestBody) + testCases := []struct { + desc string + h http.Handler + acceptEncoding string + }{ + { + desc: "brotli", + h: newTestBrotliHandler(t, smallTestBody), + acceptEncoding: "br", + }, + { + desc: "zstd", + h: newTestZstandardHandler(t, smallTestBody), + acceptEncoding: "zstd", + }, + } - req, _ := http.NewRequest(http.MethodGet, "/whatever", nil) - req.Header.Set(acceptEncoding, "br") + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + t.Parallel() - rw := httptest.NewRecorder() - h.ServeHTTP(rw, req) + req, _ := http.NewRequest(http.MethodGet, "/whatever", nil) + req.Header.Set(acceptEncoding, test.acceptEncoding) - assert.Equal(t, http.StatusAccepted, rw.Code) - assert.Equal(t, acceptEncoding, rw.Header().Get(vary)) + rw := httptest.NewRecorder() + test.h.ServeHTTP(rw, req) + + assert.Equal(t, http.StatusAccepted, rw.Code) + assert.Equal(t, acceptEncoding, rw.Header().Get(vary)) + }) + } } func Test_SmallBodyNoCompression(t *testing.T) { - h := newTestHandler(t, smallTestBody) + testCases := []struct { + desc string + h http.Handler + acceptEncoding string + }{ + { + desc: "brotli", + h: newTestBrotliHandler(t, smallTestBody), + acceptEncoding: "br", + }, + { + desc: "zstd", + h: newTestZstandardHandler(t, smallTestBody), + acceptEncoding: "zstd", + }, + } - req, _ := http.NewRequest(http.MethodGet, "/whatever", nil) - req.Header.Set(acceptEncoding, "br") + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + t.Parallel() - rw := httptest.NewRecorder() - h.ServeHTTP(rw, req) + req, _ := http.NewRequest(http.MethodGet, "/whatever", nil) + req.Header.Set(acceptEncoding, test.acceptEncoding) - // With less than 1024 bytes the response should not be compressed. - assert.Equal(t, http.StatusAccepted, rw.Code) - assert.Empty(t, rw.Header().Get(contentEncoding)) - assert.Equal(t, smallTestBody, rw.Body.Bytes()) + rw := httptest.NewRecorder() + test.h.ServeHTTP(rw, req) + + // With less than 1024 bytes the response should not be compressed. + assert.Equal(t, http.StatusAccepted, rw.Code) + assert.Empty(t, rw.Header().Get(contentEncoding)) + assert.Equal(t, smallTestBody, rw.Body.Bytes()) + }) + } } func Test_AlreadyCompressed(t *testing.T) { - h := newTestHandler(t, bigTestBody) + testCases := []struct { + desc string + h http.Handler + acceptEncoding string + }{ + { + desc: "brotli", + h: newTestBrotliHandler(t, bigTestBody), + acceptEncoding: "br", + }, + { + desc: "zstd", + h: newTestZstandardHandler(t, bigTestBody), + acceptEncoding: "zstd", + }, + } - req, _ := http.NewRequest(http.MethodGet, "/compressed", nil) - req.Header.Set(acceptEncoding, "br") + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + t.Parallel() - rw := httptest.NewRecorder() - h.ServeHTTP(rw, req) + req, _ := http.NewRequest(http.MethodGet, "/compressed", nil) + req.Header.Set(acceptEncoding, test.acceptEncoding) - assert.Equal(t, http.StatusAccepted, rw.Code) - assert.Equal(t, bigTestBody, rw.Body.Bytes()) + rw := httptest.NewRecorder() + test.h.ServeHTTP(rw, req) + + assert.Equal(t, http.StatusAccepted, rw.Code) + assert.Equal(t, bigTestBody, rw.Body.Bytes()) + }) + } } func Test_NoBody(t *testing.T) { @@ -91,15 +155,17 @@ func Test_NoBody(t *testing.T) { t.Run(test.desc, func(t *testing.T) { t.Parallel() - h := mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(test.statusCode) _, err := rw.Write(test.body) require.NoError(t, err) - })) + }) + + h := mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: zstdName}, next) req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(acceptEncoding, "br") + req.Header.Set(acceptEncoding, "zstd") rw := httptest.NewRecorder() h.ServeHTTP(rw, req) @@ -115,24 +181,26 @@ func Test_NoBody(t *testing.T) { func Test_MinSize(t *testing.T) { cfg := Config{ - MinSize: 128, + MinSize: 128, + Algorithm: zstdName, } var bodySize int - h := mustNewWrapper(t, cfg)(http.HandlerFunc( - func(rw http.ResponseWriter, req *http.Request) { - for range bodySize { - // We make sure to Write at least once less than minSize so that both - // cases below go through the same algo: i.e. they start buffering - // because they haven't reached minSize. - _, err := rw.Write([]byte{'x'}) - require.NoError(t, err) - } - }, - )) + + next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + for range bodySize { + // We make sure to Write at least once less than minSize so that both + // cases below go through the same algo: i.e. they start buffering + // because they haven't reached minSize. + _, err := rw.Write([]byte{'x'}) + require.NoError(t, err) + } + }) + + h := mustNewCompressionHandler(t, cfg, next) req, _ := http.NewRequest(http.MethodGet, "/whatever", &bytes.Buffer{}) - req.Header.Add(acceptEncoding, "br") + req.Header.Add(acceptEncoding, "zstd") // Short response is not compressed bodySize = cfg.MinSize - 1 @@ -146,18 +214,20 @@ func Test_MinSize(t *testing.T) { rw = httptest.NewRecorder() h.ServeHTTP(rw, req) - assert.Equal(t, "br", rw.Result().Header.Get(contentEncoding)) + assert.Equal(t, "zstd", rw.Result().Header.Get(contentEncoding)) } func Test_MultipleWriteHeader(t *testing.T) { - h := mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { // We ensure that the subsequent call to WriteHeader is a noop. rw.WriteHeader(http.StatusInternalServerError) rw.WriteHeader(http.StatusNotFound) - })) + }) + + h := mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: zstdName}, next) req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(acceptEncoding, "br") + req.Header.Set(acceptEncoding, "zstd") rw := httptest.NewRecorder() h.ServeHTTP(rw, req) @@ -166,121 +236,255 @@ func Test_MultipleWriteHeader(t *testing.T) { } func Test_FlushBeforeWrite(t *testing.T) { - srv := httptest.NewServer(mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.WriteHeader(http.StatusOK) - rw.(http.Flusher).Flush() + testCases := []struct { + desc string + cfg Config + readerBuilder func(io.Reader) (io.Reader, error) + acceptEncoding string + }{ + { + desc: "brotli", + cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"}, + readerBuilder: func(reader io.Reader) (io.Reader, error) { + return brotli.NewReader(reader), nil + }, + acceptEncoding: "br", + }, + { + desc: "zstd", + cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"}, + readerBuilder: func(reader io.Reader) (io.Reader, error) { + return zstd.NewReader(reader) + }, + acceptEncoding: "zstd", + }, + } - _, err := rw.Write(bigTestBody) - require.NoError(t, err) - }))) - defer srv.Close() + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + t.Parallel() - req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody) - require.NoError(t, err) + next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(http.StatusOK) + rw.(http.Flusher).Flush() - req.Header.Set(acceptEncoding, "br") + _, err := rw.Write(bigTestBody) + require.NoError(t, err) + }) - res, err := http.DefaultClient.Do(req) - require.NoError(t, err) + srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next)) + defer srv.Close() - defer res.Body.Close() + req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody) + require.NoError(t, err) - assert.Equal(t, http.StatusOK, res.StatusCode) - assert.Equal(t, "br", res.Header.Get(contentEncoding)) + req.Header.Set(acceptEncoding, test.acceptEncoding) - got, err := io.ReadAll(brotli.NewReader(res.Body)) - require.NoError(t, err) - assert.Equal(t, bigTestBody, got) + res, err := http.DefaultClient.Do(req) + require.NoError(t, err) + + defer res.Body.Close() + + assert.Equal(t, http.StatusOK, res.StatusCode) + assert.Equal(t, test.acceptEncoding, res.Header.Get(contentEncoding)) + + reader, err := test.readerBuilder(res.Body) + require.NoError(t, err) + + got, err := io.ReadAll(reader) + require.NoError(t, err) + assert.Equal(t, bigTestBody, got) + }) + } } func Test_FlushAfterWrite(t *testing.T) { - srv := httptest.NewServer(mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.WriteHeader(http.StatusOK) + testCases := []struct { + desc string + cfg Config + readerBuilder func(io.Reader) (io.Reader, error) + acceptEncoding string + }{ + { + desc: "brotli", + cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"}, + readerBuilder: func(reader io.Reader) (io.Reader, error) { + return brotli.NewReader(reader), nil + }, + acceptEncoding: "br", + }, + { + desc: "zstd", + cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"}, + readerBuilder: func(reader io.Reader) (io.Reader, error) { + return zstd.NewReader(reader) + }, + acceptEncoding: "zstd", + }, + } - _, err := rw.Write(bigTestBody[0:1]) - require.NoError(t, err) + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(http.StatusOK) - rw.(http.Flusher).Flush() - for _, b := range bigTestBody[1:] { - _, err := rw.Write([]byte{b}) + _, err := rw.Write(bigTestBody[0:1]) + require.NoError(t, err) + + rw.(http.Flusher).Flush() + for _, b := range bigTestBody[1:] { + _, err := rw.Write([]byte{b}) + require.NoError(t, err) + } + }) + + srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next)) + defer srv.Close() + + req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody) require.NoError(t, err) - } - }))) - defer srv.Close() - req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody) - require.NoError(t, err) + req.Header.Set(acceptEncoding, test.acceptEncoding) - req.Header.Set(acceptEncoding, "br") + res, err := http.DefaultClient.Do(req) + require.NoError(t, err) - res, err := http.DefaultClient.Do(req) - require.NoError(t, err) + defer res.Body.Close() - defer res.Body.Close() + assert.Equal(t, http.StatusOK, res.StatusCode) + assert.Equal(t, test.acceptEncoding, res.Header.Get(contentEncoding)) - assert.Equal(t, http.StatusOK, res.StatusCode) - assert.Equal(t, "br", res.Header.Get(contentEncoding)) + reader, err := test.readerBuilder(res.Body) + require.NoError(t, err) - got, err := io.ReadAll(brotli.NewReader(res.Body)) - require.NoError(t, err) - assert.Equal(t, bigTestBody, got) + got, err := io.ReadAll(reader) + require.NoError(t, err) + assert.Equal(t, bigTestBody, got) + }) + } } func Test_FlushAfterWriteNil(t *testing.T) { - srv := httptest.NewServer(mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.WriteHeader(http.StatusOK) + testCases := []struct { + desc string + cfg Config + readerBuilder func(io.Reader) (io.Reader, error) + acceptEncoding string + }{ + { + desc: "brotli", + cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"}, + readerBuilder: func(reader io.Reader) (io.Reader, error) { + return brotli.NewReader(reader), nil + }, + acceptEncoding: "br", + }, + { + desc: "zstd", + cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"}, + readerBuilder: func(reader io.Reader) (io.Reader, error) { + return zstd.NewReader(reader) + }, + acceptEncoding: "zstd", + }, + } - _, err := rw.Write(nil) - require.NoError(t, err) + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(http.StatusOK) - rw.(http.Flusher).Flush() - }))) - defer srv.Close() + _, err := rw.Write(nil) + require.NoError(t, err) - req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody) - require.NoError(t, err) + rw.(http.Flusher).Flush() + }) - req.Header.Set(acceptEncoding, "br") + srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next)) + defer srv.Close() - res, err := http.DefaultClient.Do(req) - require.NoError(t, err) + req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody) + require.NoError(t, err) - defer res.Body.Close() + req.Header.Set(acceptEncoding, test.acceptEncoding) - assert.Equal(t, http.StatusOK, res.StatusCode) - assert.Empty(t, res.Header.Get(contentEncoding)) + res, err := http.DefaultClient.Do(req) + require.NoError(t, err) - got, err := io.ReadAll(brotli.NewReader(res.Body)) - require.NoError(t, err) - assert.Empty(t, got) + defer res.Body.Close() + + assert.Equal(t, http.StatusOK, res.StatusCode) + assert.Empty(t, res.Header.Get(contentEncoding)) + + reader, err := test.readerBuilder(res.Body) + require.NoError(t, err) + + got, err := io.ReadAll(reader) + require.NoError(t, err) + assert.Empty(t, got) + }) + } } func Test_FlushAfterAllWrites(t *testing.T) { - srv := httptest.NewServer(mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - for i := range bigTestBody { - _, err := rw.Write(bigTestBody[i : i+1]) + testCases := []struct { + desc string + cfg Config + readerBuilder func(io.Reader) (io.Reader, error) + acceptEncoding string + }{ + { + desc: "brotli", + cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"}, + readerBuilder: func(reader io.Reader) (io.Reader, error) { + return brotli.NewReader(reader), nil + }, + acceptEncoding: "br", + }, + { + desc: "zstd", + cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"}, + readerBuilder: func(reader io.Reader) (io.Reader, error) { + return zstd.NewReader(reader) + }, + acceptEncoding: "zstd", + }, + } + + for _, test := range testCases { + t.Run(test.desc, func(t *testing.T) { + next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + for i := range bigTestBody { + _, err := rw.Write(bigTestBody[i : i+1]) + require.NoError(t, err) + } + rw.(http.Flusher).Flush() + }) + + srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next)) + defer srv.Close() + + req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody) require.NoError(t, err) - } - rw.(http.Flusher).Flush() - }))) - defer srv.Close() - req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody) - require.NoError(t, err) + req.Header.Set(acceptEncoding, test.acceptEncoding) - req.Header.Set(acceptEncoding, "br") + res, err := http.DefaultClient.Do(req) + require.NoError(t, err) - res, err := http.DefaultClient.Do(req) - require.NoError(t, err) + defer res.Body.Close() - defer res.Body.Close() + assert.Equal(t, http.StatusOK, res.StatusCode) + assert.Equal(t, test.acceptEncoding, res.Header.Get(contentEncoding)) - assert.Equal(t, http.StatusOK, res.StatusCode) - assert.Equal(t, "br", res.Header.Get(contentEncoding)) + reader, err := test.readerBuilder(res.Body) + require.NoError(t, err) - got, err := io.ReadAll(brotli.NewReader(res.Body)) - require.NoError(t, err) - assert.Equal(t, bigTestBody, got) + got, err := io.ReadAll(reader) + require.NoError(t, err) + assert.Equal(t, bigTestBody, got) + }) + } } func Test_ExcludedContentTypes(t *testing.T) { @@ -352,18 +556,22 @@ func Test_ExcludedContentTypes(t *testing.T) { cfg := Config{ MinSize: 1024, ExcludedContentTypes: test.excludedContentTypes, + Algorithm: zstdName, } - h := mustNewWrapper(t, cfg)(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + + next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set(contentType, test.contentType) rw.WriteHeader(http.StatusAccepted) _, err := rw.Write(bigTestBody) require.NoError(t, err) - })) + }) + + h := mustNewCompressionHandler(t, cfg, next) req, _ := http.NewRequest(http.MethodGet, "/whatever", nil) - req.Header.Set(acceptEncoding, "br") + req.Header.Set(acceptEncoding, zstdName) rw := httptest.NewRecorder() h.ServeHTTP(rw, req) @@ -371,13 +579,16 @@ func Test_ExcludedContentTypes(t *testing.T) { assert.Equal(t, http.StatusAccepted, rw.Code) if test.expCompression { - assert.Equal(t, "br", rw.Header().Get(contentEncoding)) + assert.Equal(t, zstdName, rw.Header().Get(contentEncoding)) - got, err := io.ReadAll(brotli.NewReader(rw.Body)) + reader, err := zstd.NewReader(rw.Body) + require.NoError(t, err) + + got, err := io.ReadAll(reader) assert.NoError(t, err) assert.Equal(t, bigTestBody, got) } else { - assert.NotEqual(t, "br", rw.Header().Get("Content-Encoding")) + assert.NotEqual(t, zstdName, rw.Header().Get("Content-Encoding")) got, err := io.ReadAll(rw.Body) assert.NoError(t, err) @@ -456,18 +667,22 @@ func Test_IncludedContentTypes(t *testing.T) { cfg := Config{ MinSize: 1024, IncludedContentTypes: test.includedContentTypes, + Algorithm: zstdName, } - h := mustNewWrapper(t, cfg)(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + + next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set(contentType, test.contentType) rw.WriteHeader(http.StatusAccepted) _, err := rw.Write(bigTestBody) require.NoError(t, err) - })) + }) + + h := mustNewCompressionHandler(t, cfg, next) req, _ := http.NewRequest(http.MethodGet, "/whatever", nil) - req.Header.Set(acceptEncoding, "br") + req.Header.Set(acceptEncoding, zstdName) rw := httptest.NewRecorder() h.ServeHTTP(rw, req) @@ -475,13 +690,16 @@ func Test_IncludedContentTypes(t *testing.T) { assert.Equal(t, http.StatusAccepted, rw.Code) if test.expCompression { - assert.Equal(t, "br", rw.Header().Get(contentEncoding)) + assert.Equal(t, zstdName, rw.Header().Get(contentEncoding)) - got, err := io.ReadAll(brotli.NewReader(rw.Body)) + reader, err := zstd.NewReader(rw.Body) + require.NoError(t, err) + + got, err := io.ReadAll(reader) assert.NoError(t, err) assert.Equal(t, bigTestBody, got) } else { - assert.NotEqual(t, "br", rw.Header().Get("Content-Encoding")) + assert.NotEqual(t, zstdName, rw.Header().Get("Content-Encoding")) got, err := io.ReadAll(rw.Body) assert.NoError(t, err) @@ -560,8 +778,10 @@ func Test_FlushExcludedContentTypes(t *testing.T) { cfg := Config{ MinSize: 1024, ExcludedContentTypes: test.excludedContentTypes, + Algorithm: zstdName, } - h := mustNewWrapper(t, cfg)(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + + next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set(contentType, test.contentType) rw.WriteHeader(http.StatusOK) @@ -581,10 +801,12 @@ func Test_FlushExcludedContentTypes(t *testing.T) { rw.(http.Flusher).Flush() tb = tb[toWrite:] } - })) + }) + + h := mustNewCompressionHandler(t, cfg, next) req, _ := http.NewRequest(http.MethodGet, "/whatever", nil) - req.Header.Set(acceptEncoding, "br") + req.Header.Set(acceptEncoding, zstdName) // This doesn't allow checking flushes, but we validate if content is correct. rw := httptest.NewRecorder() @@ -593,13 +815,16 @@ func Test_FlushExcludedContentTypes(t *testing.T) { assert.Equal(t, http.StatusOK, rw.Code) if test.expCompression { - assert.Equal(t, "br", rw.Header().Get(contentEncoding)) + assert.Equal(t, zstdName, rw.Header().Get(contentEncoding)) - got, err := io.ReadAll(brotli.NewReader(rw.Body)) + reader, err := zstd.NewReader(rw.Body) + require.NoError(t, err) + + got, err := io.ReadAll(reader) assert.NoError(t, err) assert.Equal(t, bigTestBody, got) } else { - assert.NotEqual(t, "br", rw.Header().Get(contentEncoding)) + assert.NotEqual(t, zstdName, rw.Header().Get(contentEncoding)) got, err := io.ReadAll(rw.Body) assert.NoError(t, err) @@ -678,8 +903,10 @@ func Test_FlushIncludedContentTypes(t *testing.T) { cfg := Config{ MinSize: 1024, IncludedContentTypes: test.includedContentTypes, + Algorithm: zstdName, } - h := mustNewWrapper(t, cfg)(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + + next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set(contentType, test.contentType) rw.WriteHeader(http.StatusOK) @@ -699,10 +926,12 @@ func Test_FlushIncludedContentTypes(t *testing.T) { rw.(http.Flusher).Flush() tb = tb[toWrite:] } - })) + }) + + h := mustNewCompressionHandler(t, cfg, next) req, _ := http.NewRequest(http.MethodGet, "/whatever", nil) - req.Header.Set(acceptEncoding, "br") + req.Header.Set(acceptEncoding, zstdName) // This doesn't allow checking flushes, but we validate if content is correct. rw := httptest.NewRecorder() @@ -711,13 +940,16 @@ func Test_FlushIncludedContentTypes(t *testing.T) { assert.Equal(t, http.StatusOK, rw.Code) if test.expCompression { - assert.Equal(t, "br", rw.Header().Get(contentEncoding)) + assert.Equal(t, zstdName, rw.Header().Get(contentEncoding)) - got, err := io.ReadAll(brotli.NewReader(rw.Body)) + reader, err := zstd.NewReader(rw.Body) + require.NoError(t, err) + + got, err := io.ReadAll(reader) assert.NoError(t, err) assert.Equal(t, bigTestBody, got) } else { - assert.NotEqual(t, "br", rw.Header().Get(contentEncoding)) + assert.NotEqual(t, zstdName, rw.Header().Get(contentEncoding)) got, err := io.ReadAll(rw.Body) assert.NoError(t, err) @@ -727,32 +959,48 @@ func Test_FlushIncludedContentTypes(t *testing.T) { } } -func mustNewWrapper(t *testing.T, cfg Config) func(http.Handler) http.HandlerFunc { +func mustNewCompressionHandler(t *testing.T, cfg Config, next http.Handler) http.Handler { t.Helper() - w, err := NewWrapper(cfg) + w, err := NewCompressionHandler(cfg, next) require.NoError(t, err) return w } -func newTestHandler(t *testing.T, body []byte) http.Handler { +func newTestBrotliHandler(t *testing.T, body []byte) http.Handler { t.Helper() - return mustNewWrapper(t, Config{MinSize: 1024})( - http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - if req.URL.Path == "/compressed" { - rw.Header().Set("Content-Encoding", "br") - } + next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if req.URL.Path == "/compressed" { + rw.Header().Set("Content-Encoding", brotliName) + } - rw.WriteHeader(http.StatusAccepted) - _, err := rw.Write(body) - require.NoError(t, err) - }), - ) + rw.WriteHeader(http.StatusAccepted) + _, err := rw.Write(body) + require.NoError(t, err) + }) + + return mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Compress"}, next) } -func TestParseContentType_equals(t *testing.T) { +func newTestZstandardHandler(t *testing.T, body []byte) http.Handler { + t.Helper() + + next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if req.URL.Path == "/compressed" { + rw.Header().Set("Content-Encoding", zstdName) + } + + rw.WriteHeader(http.StatusAccepted) + _, err := rw.Write(body) + require.NoError(t, err) + }) + + return mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Compress"}, next) +} + +func Test_ParseContentType_equals(t *testing.T) { testCases := []struct { desc string pct parsedContentType