From ef168b801c43080b373190967c27b48631a009ed Mon Sep 17 00:00:00 2001 From: Kevin Pollet Date: Thu, 10 Oct 2024 16:04:04 +0200 Subject: [PATCH] Refactor compress handler to make it generic Co-authored-by: Romain --- pkg/middlewares/compress/compress.go | 33 +++++- .../compress/compression_handler.go | 103 ++++++++---------- .../compress/compression_handler_test.go | 81 +++++++++----- 3 files changed, 125 insertions(+), 92 deletions(-) diff --git a/pkg/middlewares/compress/compress.go b/pkg/middlewares/compress/compress.go index 3ccac9f2d..fc8fe8d96 100644 --- a/pkg/middlewares/compress/compress.go +++ b/pkg/middlewares/compress/compress.go @@ -8,7 +8,9 @@ import ( "net/http" "slices" + "github.com/andybalholm/brotli" "github.com/klauspost/compress/gzhttp" + "github.com/klauspost/compress/zstd" "github.com/traefik/traefik/v3/pkg/config/dynamic" "github.com/traefik/traefik/v3/pkg/middlewares" "go.opentelemetry.io/otel/trace" @@ -78,12 +80,12 @@ func New(ctx context.Context, next http.Handler, conf dynamic.Compress, name str var err error - c.zstdHandler, err = c.newCompressionHandler(zstdName, name) + c.zstdHandler, err = c.newZstdHandler(name) if err != nil { return nil, err } - c.brotliHandler, err = c.newCompressionHandler(brotliName, name) + c.brotliHandler, err = c.newBrotliHandler(name) if err != nil { return nil, err } @@ -174,13 +176,34 @@ func (c *compress) newGzipHandler() (http.Handler, error) { return wrapper(c.next), nil } -func (c *compress) newCompressionHandler(algo string, middlewareName string) (http.Handler, error) { - cfg := Config{MinSize: c.minSize, Algorithm: algo, MiddlewareName: middlewareName} +func (c *compress) newBrotliHandler(middlewareName string) (http.Handler, error) { + cfg := Config{MinSize: c.minSize, MiddlewareName: middlewareName} if len(c.includes) > 0 { cfg.IncludedContentTypes = c.includes } else { cfg.ExcludedContentTypes = c.excludes } - return NewCompressionHandler(cfg, c.next) + newBrotliWriter := func(rw http.ResponseWriter) (CompressionWriter, string, error) { + return brotli.NewWriter(rw), brotliName, nil + } + return NewCompressionHandler(cfg, newBrotliWriter, c.next) +} + +func (c *compress) newZstdHandler(middlewareName string) (http.Handler, error) { + cfg := Config{MinSize: c.minSize, MiddlewareName: middlewareName} + if len(c.includes) > 0 { + cfg.IncludedContentTypes = c.includes + } else { + cfg.ExcludedContentTypes = c.excludes + } + + newZstdWriter := func(rw http.ResponseWriter) (CompressionWriter, string, error) { + writer, err := zstd.NewWriter(rw) + if err != nil { + return nil, "", fmt.Errorf("creating zstd writer: %w", err) + } + return writer, zstdName, nil + } + return NewCompressionHandler(cfg, newZstdWriter, c.next) } diff --git a/pkg/middlewares/compress/compression_handler.go b/pkg/middlewares/compress/compression_handler.go index a8f08ec83..1c5065ca7 100644 --- a/pkg/middlewares/compress/compression_handler.go +++ b/pkg/middlewares/compress/compression_handler.go @@ -10,8 +10,6 @@ import ( "net/http" "sync" - "github.com/andybalholm/brotli" - "github.com/klauspost/compress/zstd" "github.com/traefik/traefik/v3/pkg/middlewares" "github.com/traefik/traefik/v3/pkg/middlewares/observability" ) @@ -24,6 +22,30 @@ const ( contentType = "Content-Type" ) +// CompressionWriter compresses the written bytes. +type CompressionWriter interface { + // Write data to the encoder. + // Input data will be buffered and as the buffer fills up + // content will be compressed and written to the output. + // When done writing, use Close to flush the remaining output + // and write CRC if requested. + Write(p []byte) (n int, err error) + // Flush will send the currently written data to output + // and block until everything has been written. + // This should only be used on rare occasions where pushing the currently queued data is critical. + Flush() error + // Close closes the underlying writers if/when appropriate. + // Note that the compressed writer should not be closed if we never used it, + // as it would otherwise send some extra "end of compression" bytes. + // Close also makes sure to flush whatever was left to write from the buffer. + Close() error + // Reset reinitializes the state of the encoder, allowing it to be reused. + Reset(w io.Writer) +} + +// NewCompressionWriter returns a new CompressionWriter with its corresponding algorithm. +type NewCompressionWriter func(rw http.ResponseWriter) (CompressionWriter, string, error) + // Config is the Brotli handler configuration. type Config struct { // ExcludedContentTypes is the list of content types for which we should not compress. @@ -34,8 +56,6 @@ type Config struct { IncludedContentTypes []string // MinSize is the minimum size (in bytes) required to enable compression. MinSize int - // Algorithm used for the compression (currently Brotli and Zstandard) - Algorithm string // MiddlewareName use for logging purposes MiddlewareName string } @@ -46,15 +66,13 @@ type CompressionHandler struct { excludedContentTypes []parsedContentType includedContentTypes []parsedContentType next http.Handler - writerPool sync.Pool + + writerPool sync.Pool + newWriter NewCompressionWriter } // NewCompressionHandler returns a new compressing handler. -func NewCompressionHandler(cfg Config, next http.Handler) (http.Handler, error) { - if cfg.Algorithm == "" { - return nil, errors.New("compression algorithm undefined") - } - +func NewCompressionHandler(cfg Config, newWriter NewCompressionWriter, next http.Handler) (http.Handler, error) { if cfg.MinSize < 0 { return nil, errors.New("minimum size must be greater than or equal to zero") } @@ -88,6 +106,7 @@ func NewCompressionHandler(cfg Config, next http.Handler) (http.Handler, error) excludedContentTypes: excludedContentTypes, includedContentTypes: includedContentTypes, next: next, + newWriter: newWriter, }, nil } @@ -117,70 +136,38 @@ func (c *CompressionHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) c.next.ServeHTTP(responseWriter, r) } -type compression interface { - // Write data to the encoder. - // Input data will be buffered and as the buffer fills up - // content will be compressed and written to the output. - // When done writing, use Close to flush the remaining output - // and write CRC if requested. - Write(p []byte) (n int, err error) - // Flush will send the currently written data to output - // and block until everything has been written. - // This should only be used on rare occasions where pushing the currently queued data is critical. - Flush() error - // Close closes the underlying writers if/when appropriate. - // Note that the compressed writer should not be closed if we never used it, - // as it would otherwise send some extra "end of compression" bytes. - // Close also makes sure to flush whatever was left to write from the buffer. - Close() error - // Reset reinitializes the state of the encoder, allowing it to be reused. - Reset(w io.Writer) -} - -type compressionWriter struct { - compression - alg string -} - -func (c *CompressionHandler) getCompressionWriter(rw io.Writer) (*compressionWriter, error) { - if writer, ok := c.writerPool.Get().(*compressionWriter); ok { - writer.compression.Reset(rw) +func (c *CompressionHandler) getCompressionWriter(rw http.ResponseWriter) (*compressionWriterWrapper, error) { + if writer, ok := c.writerPool.Get().(*compressionWriterWrapper); ok { + writer.Reset(rw) return writer, nil } - return newCompressionWriter(c.cfg.Algorithm, rw) + + writer, algo, err := c.newWriter(rw) + if err != nil { + return nil, fmt.Errorf("creating compression writer: %w", err) + } + return &compressionWriterWrapper{CompressionWriter: writer, algo: algo}, nil } -func (c *CompressionHandler) putCompressionWriter(writer *compressionWriter) { +func (c *CompressionHandler) putCompressionWriter(writer *compressionWriterWrapper) { writer.Reset(nil) c.writerPool.Put(writer) } -func newCompressionWriter(algo string, in io.Writer) (*compressionWriter, error) { - switch algo { - case brotliName: - return &compressionWriter{compression: brotli.NewWriter(in), alg: algo}, nil - - case zstdName: - writer, err := zstd.NewWriter(in) - if err != nil { - return nil, fmt.Errorf("creating zstd writer: %w", err) - } - return &compressionWriter{compression: writer, alg: algo}, nil - - default: - return nil, fmt.Errorf("unknown compression algo: %s", algo) - } +type compressionWriterWrapper struct { + CompressionWriter + algo string } -func (c *compressionWriter) ContentEncoding() string { - return c.alg +func (c *compressionWriterWrapper) ContentEncoding() string { + return c.algo } // TODO: check whether we want to implement content-type sniffing (as gzip does) // TODO: check whether we should support Accept-Ranges (as gzip does, see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Ranges) type responseWriter struct { rw http.ResponseWriter - compressionWriter *compressionWriter + compressionWriter *compressionWriterWrapper minSize int excludedContentTypes []parsedContentType diff --git a/pkg/middlewares/compress/compression_handler_test.go b/pkg/middlewares/compress/compression_handler_test.go index 1df9e9588..c702500d7 100644 --- a/pkg/middlewares/compress/compression_handler_test.go +++ b/pkg/middlewares/compress/compression_handler_test.go @@ -162,7 +162,7 @@ func Test_NoBody(t *testing.T) { require.NoError(t, err) }) - h := mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: zstdName}, next) + h := mustNewCompressionHandler(t, Config{MinSize: 1024}, zstdName, next) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(acceptEncoding, "zstd") @@ -181,8 +181,7 @@ func Test_NoBody(t *testing.T) { func Test_MinSize(t *testing.T) { cfg := Config{ - MinSize: 128, - Algorithm: zstdName, + MinSize: 128, } var bodySize int @@ -197,7 +196,7 @@ func Test_MinSize(t *testing.T) { } }) - h := mustNewCompressionHandler(t, cfg, next) + h := mustNewCompressionHandler(t, cfg, zstdName, next) req, _ := http.NewRequest(http.MethodGet, "/whatever", &bytes.Buffer{}) req.Header.Add(acceptEncoding, "zstd") @@ -224,7 +223,7 @@ func Test_MultipleWriteHeader(t *testing.T) { rw.WriteHeader(http.StatusNotFound) }) - h := mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: zstdName}, next) + h := mustNewCompressionHandler(t, Config{MinSize: 1024}, zstdName, next) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(acceptEncoding, "zstd") @@ -239,12 +238,14 @@ func Test_FlushBeforeWrite(t *testing.T) { testCases := []struct { desc string cfg Config + algo string readerBuilder func(io.Reader) (io.Reader, error) acceptEncoding string }{ { desc: "brotli", - cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"}, + cfg: Config{MinSize: 1024, MiddlewareName: "Test"}, + algo: brotliName, readerBuilder: func(reader io.Reader) (io.Reader, error) { return brotli.NewReader(reader), nil }, @@ -252,7 +253,8 @@ func Test_FlushBeforeWrite(t *testing.T) { }, { desc: "zstd", - cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"}, + cfg: Config{MinSize: 1024, MiddlewareName: "Test"}, + algo: zstdName, readerBuilder: func(reader io.Reader) (io.Reader, error) { return zstd.NewReader(reader) }, @@ -272,7 +274,7 @@ func Test_FlushBeforeWrite(t *testing.T) { require.NoError(t, err) }) - srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next)) + srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, test.algo, next)) defer srv.Close() req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody) @@ -302,12 +304,14 @@ func Test_FlushAfterWrite(t *testing.T) { testCases := []struct { desc string cfg Config + algo string readerBuilder func(io.Reader) (io.Reader, error) acceptEncoding string }{ { desc: "brotli", - cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"}, + cfg: Config{MinSize: 1024, MiddlewareName: "Test"}, + algo: brotliName, readerBuilder: func(reader io.Reader) (io.Reader, error) { return brotli.NewReader(reader), nil }, @@ -315,7 +319,8 @@ func Test_FlushAfterWrite(t *testing.T) { }, { desc: "zstd", - cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"}, + cfg: Config{MinSize: 1024, MiddlewareName: "Test"}, + algo: zstdName, readerBuilder: func(reader io.Reader) (io.Reader, error) { return zstd.NewReader(reader) }, @@ -338,7 +343,7 @@ func Test_FlushAfterWrite(t *testing.T) { } }) - srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next)) + srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, test.algo, next)) defer srv.Close() req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody) @@ -368,12 +373,14 @@ func Test_FlushAfterWriteNil(t *testing.T) { testCases := []struct { desc string cfg Config + algo string readerBuilder func(io.Reader) (io.Reader, error) acceptEncoding string }{ { desc: "brotli", - cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"}, + cfg: Config{MinSize: 1024, MiddlewareName: "Test"}, + algo: brotliName, readerBuilder: func(reader io.Reader) (io.Reader, error) { return brotli.NewReader(reader), nil }, @@ -381,7 +388,8 @@ func Test_FlushAfterWriteNil(t *testing.T) { }, { desc: "zstd", - cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"}, + cfg: Config{MinSize: 1024, MiddlewareName: "Test"}, + algo: zstdName, readerBuilder: func(reader io.Reader) (io.Reader, error) { return zstd.NewReader(reader) }, @@ -400,7 +408,7 @@ func Test_FlushAfterWriteNil(t *testing.T) { rw.(http.Flusher).Flush() }) - srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next)) + srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, test.algo, next)) defer srv.Close() req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody) @@ -430,12 +438,14 @@ func Test_FlushAfterAllWrites(t *testing.T) { testCases := []struct { desc string cfg Config + algo string readerBuilder func(io.Reader) (io.Reader, error) acceptEncoding string }{ { desc: "brotli", - cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"}, + cfg: Config{MinSize: 1024, MiddlewareName: "Test"}, + algo: brotliName, readerBuilder: func(reader io.Reader) (io.Reader, error) { return brotli.NewReader(reader), nil }, @@ -443,7 +453,8 @@ func Test_FlushAfterAllWrites(t *testing.T) { }, { desc: "zstd", - cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"}, + cfg: Config{MinSize: 1024, MiddlewareName: "Test"}, + algo: zstdName, readerBuilder: func(reader io.Reader) (io.Reader, error) { return zstd.NewReader(reader) }, @@ -461,7 +472,7 @@ func Test_FlushAfterAllWrites(t *testing.T) { rw.(http.Flusher).Flush() }) - srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next)) + srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, test.algo, next)) defer srv.Close() req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody) @@ -556,7 +567,6 @@ func Test_ExcludedContentTypes(t *testing.T) { cfg := Config{ MinSize: 1024, ExcludedContentTypes: test.excludedContentTypes, - Algorithm: zstdName, } next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { @@ -568,7 +578,7 @@ func Test_ExcludedContentTypes(t *testing.T) { require.NoError(t, err) }) - h := mustNewCompressionHandler(t, cfg, next) + h := mustNewCompressionHandler(t, cfg, zstdName, next) req, _ := http.NewRequest(http.MethodGet, "/whatever", nil) req.Header.Set(acceptEncoding, zstdName) @@ -667,7 +677,6 @@ func Test_IncludedContentTypes(t *testing.T) { cfg := Config{ MinSize: 1024, IncludedContentTypes: test.includedContentTypes, - Algorithm: zstdName, } next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { @@ -679,7 +688,7 @@ func Test_IncludedContentTypes(t *testing.T) { require.NoError(t, err) }) - h := mustNewCompressionHandler(t, cfg, next) + h := mustNewCompressionHandler(t, cfg, zstdName, next) req, _ := http.NewRequest(http.MethodGet, "/whatever", nil) req.Header.Set(acceptEncoding, zstdName) @@ -778,7 +787,6 @@ func Test_FlushExcludedContentTypes(t *testing.T) { cfg := Config{ MinSize: 1024, ExcludedContentTypes: test.excludedContentTypes, - Algorithm: zstdName, } next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { @@ -803,7 +811,7 @@ func Test_FlushExcludedContentTypes(t *testing.T) { } }) - h := mustNewCompressionHandler(t, cfg, next) + h := mustNewCompressionHandler(t, cfg, zstdName, next) req, _ := http.NewRequest(http.MethodGet, "/whatever", nil) req.Header.Set(acceptEncoding, zstdName) @@ -903,7 +911,6 @@ func Test_FlushIncludedContentTypes(t *testing.T) { cfg := Config{ MinSize: 1024, IncludedContentTypes: test.includedContentTypes, - Algorithm: zstdName, } next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { @@ -928,7 +935,7 @@ func Test_FlushIncludedContentTypes(t *testing.T) { } }) - h := mustNewCompressionHandler(t, cfg, next) + h := mustNewCompressionHandler(t, cfg, zstdName, next) req, _ := http.NewRequest(http.MethodGet, "/whatever", nil) req.Header.Set(acceptEncoding, zstdName) @@ -959,10 +966,26 @@ func Test_FlushIncludedContentTypes(t *testing.T) { } } -func mustNewCompressionHandler(t *testing.T, cfg Config, next http.Handler) http.Handler { +func mustNewCompressionHandler(t *testing.T, cfg Config, algo string, next http.Handler) http.Handler { t.Helper() - w, err := NewCompressionHandler(cfg, next) + var writer NewCompressionWriter + switch algo { + case zstdName: + writer = func(rw http.ResponseWriter) (CompressionWriter, string, error) { + writer, err := zstd.NewWriter(rw) + require.NoError(t, err) + return writer, zstdName, nil + } + case brotliName: + writer = func(rw http.ResponseWriter) (CompressionWriter, string, error) { + return brotli.NewWriter(rw), brotliName, nil + } + default: + assert.Failf(t, "unknown compression algorithm: %s", algo) + } + + w, err := NewCompressionHandler(cfg, writer, next) require.NoError(t, err) return w @@ -981,7 +1004,7 @@ func newTestBrotliHandler(t *testing.T, body []byte) http.Handler { require.NoError(t, err) }) - return mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Compress"}, next) + return mustNewCompressionHandler(t, Config{MinSize: 1024, MiddlewareName: "Compress"}, brotliName, next) } func newTestZstandardHandler(t *testing.T, body []byte) http.Handler { @@ -997,7 +1020,7 @@ func newTestZstandardHandler(t *testing.T, body []byte) http.Handler { require.NoError(t, err) }) - return mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Compress"}, next) + return mustNewCompressionHandler(t, Config{MinSize: 1024, MiddlewareName: "Compress"}, zstdName, next) } func Test_ParseContentType_equals(t *testing.T) {