Add support for Zstandard to the Compression middleware
This commit is contained in:
parent
3f48e6f8ef
commit
b795f128d7
7 changed files with 576 additions and 213 deletions
|
@ -10,7 +10,7 @@ Compress Allows Compressing Responses before Sending them to the Client
|
|||
|
||||
![Compress](../../assets/img/middleware/compress.png)
|
||||
|
||||
The Compress middleware supports gzip and Brotli compression.
|
||||
The Compress middleware supports Gzip, Brotli and Zstandard compression.
|
||||
The activation of compression, and the compression method choice rely (among other things) on the request's `Accept-Encoding` header.
|
||||
|
||||
## Configuration Examples
|
||||
|
@ -54,8 +54,8 @@ http:
|
|||
|
||||
Responses are compressed when the following criteria are all met:
|
||||
|
||||
* The `Accept-Encoding` request header contains `gzip`, `*`, and/or `br` with or without [quality values](https://developer.mozilla.org/en-US/docs/Glossary/Quality_values).
|
||||
If the `Accept-Encoding` request header is absent, the response won't be encoded.
|
||||
* The `Accept-Encoding` request header contains `gzip`, and/or `*`, and/or `br`, and/or `zstd` with or without [quality values](https://developer.mozilla.org/en-US/docs/Glossary/Quality_values).
|
||||
If the `Accept-Encoding` request header is absent and no [defaultEncoding](#defaultencoding) is configured, the response won't be encoded.
|
||||
If it is present, but its value is the empty string, then compression is disabled.
|
||||
* The response is not already compressed, i.e. the `Content-Encoding` response header is not already set.
|
||||
* The response`Content-Type` header is not one among the [excludedContentTypes options](#excludedcontenttypes), or is one among the [includedContentTypes options](#includedcontenttypes).
|
||||
|
|
|
@ -11,6 +11,7 @@ const acceptEncodingHeader = "Accept-Encoding"
|
|||
const (
|
||||
brotliName = "br"
|
||||
gzipName = "gzip"
|
||||
zstdName = "zstd"
|
||||
identityName = "identity"
|
||||
wildcardName = "*"
|
||||
notAcceptable = "not_acceptable"
|
||||
|
@ -51,7 +52,7 @@ func getCompressionType(acceptEncoding []string, defaultType string) string {
|
|||
return encoding.Type
|
||||
}
|
||||
|
||||
for _, dt := range []string{brotliName, gzipName} {
|
||||
for _, dt := range []string{zstdName, brotliName, gzipName} {
|
||||
if slices.ContainsFunc(encodings, func(e Encoding) bool { return e.Type == dt }) {
|
||||
return dt
|
||||
}
|
||||
|
@ -76,7 +77,7 @@ func parseAcceptEncoding(acceptEncoding []string) ([]Encoding, bool) {
|
|||
}
|
||||
|
||||
switch parsed[0] {
|
||||
case brotliName, gzipName, identityName, wildcardName:
|
||||
case zstdName, brotliName, gzipName, identityName, wildcardName:
|
||||
// supported encoding
|
||||
default:
|
||||
continue
|
||||
|
|
|
@ -18,6 +18,11 @@ func Test_getCompressionType(t *testing.T) {
|
|||
values: []string{"gzip, br"},
|
||||
expected: brotliName,
|
||||
},
|
||||
{
|
||||
desc: "zstd > br > gzip (no weight)",
|
||||
values: []string{"zstd, gzip, br"},
|
||||
expected: zstdName,
|
||||
},
|
||||
{
|
||||
desc: "known compression type (no weight)",
|
||||
values: []string{"compress, gzip"},
|
||||
|
@ -49,6 +54,11 @@ func Test_getCompressionType(t *testing.T) {
|
|||
values: []string{"compress;q=1.0, gzip;q=0.5"},
|
||||
expected: gzipName,
|
||||
},
|
||||
{
|
||||
desc: "fallback on non-zero compression type",
|
||||
values: []string{"compress;q=1.0, gzip, identity;q=0"},
|
||||
expected: gzipName,
|
||||
},
|
||||
{
|
||||
desc: "not acceptable (identity)",
|
||||
values: []string{"compress;q=1.0, identity;q=0"},
|
||||
|
@ -86,9 +96,10 @@ func Test_parseAcceptEncoding(t *testing.T) {
|
|||
}{
|
||||
{
|
||||
desc: "weight",
|
||||
values: []string{"br;q=1.0, gzip;q=0.8, *;q=0.1"},
|
||||
values: []string{"br;q=1.0, zstd;q=0.9, gzip;q=0.8, *;q=0.1"},
|
||||
expected: []Encoding{
|
||||
{Type: brotliName, Weight: ptr[float64](1)},
|
||||
{Type: zstdName, Weight: ptr(0.9)},
|
||||
{Type: gzipName, Weight: ptr(0.8)},
|
||||
{Type: wildcardName, Weight: ptr(0.1)},
|
||||
},
|
||||
|
@ -96,9 +107,10 @@ func Test_parseAcceptEncoding(t *testing.T) {
|
|||
},
|
||||
{
|
||||
desc: "mixed",
|
||||
values: []string{"gzip, br;q=1.0, *;q=0"},
|
||||
values: []string{"zstd,gzip, br;q=1.0, *;q=0"},
|
||||
expected: []Encoding{
|
||||
{Type: brotliName, Weight: ptr[float64](1)},
|
||||
{Type: zstdName},
|
||||
{Type: gzipName},
|
||||
{Type: wildcardName, Weight: ptr[float64](0)},
|
||||
},
|
||||
|
@ -106,8 +118,9 @@ func Test_parseAcceptEncoding(t *testing.T) {
|
|||
},
|
||||
{
|
||||
desc: "no weight",
|
||||
values: []string{"gzip, br, *"},
|
||||
values: []string{"zstd, gzip, br, *"},
|
||||
expected: []Encoding{
|
||||
{Type: zstdName},
|
||||
{Type: gzipName},
|
||||
{Type: brotliName},
|
||||
{Type: wildcardName},
|
||||
|
|
|
@ -11,7 +11,6 @@ import (
|
|||
"github.com/klauspost/compress/gzhttp"
|
||||
"github.com/traefik/traefik/v3/pkg/config/dynamic"
|
||||
"github.com/traefik/traefik/v3/pkg/middlewares"
|
||||
"github.com/traefik/traefik/v3/pkg/middlewares/compress/brotli"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
|
@ -32,6 +31,7 @@ type compress struct {
|
|||
|
||||
brotliHandler http.Handler
|
||||
gzipHandler http.Handler
|
||||
zstdHandler http.Handler
|
||||
}
|
||||
|
||||
// New creates a new compress middleware.
|
||||
|
@ -77,7 +77,13 @@ func New(ctx context.Context, next http.Handler, conf dynamic.Compress, name str
|
|||
}
|
||||
|
||||
var err error
|
||||
c.brotliHandler, err = c.newBrotliHandler()
|
||||
|
||||
c.zstdHandler, err = c.newCompressionHandler(zstdName, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.brotliHandler, err = c.newCompressionHandler(brotliName, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -130,6 +136,8 @@ func (c *compress) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||
|
||||
func (c *compress) chooseHandler(typ string, rw http.ResponseWriter, req *http.Request) {
|
||||
switch typ {
|
||||
case zstdName:
|
||||
c.zstdHandler.ServeHTTP(rw, req)
|
||||
case brotliName:
|
||||
c.brotliHandler.ServeHTTP(rw, req)
|
||||
case gzipName:
|
||||
|
@ -166,18 +174,13 @@ func (c *compress) newGzipHandler() (http.Handler, error) {
|
|||
return wrapper(c.next), nil
|
||||
}
|
||||
|
||||
func (c *compress) newBrotliHandler() (http.Handler, error) {
|
||||
cfg := brotli.Config{MinSize: c.minSize}
|
||||
func (c *compress) newCompressionHandler(algo string, middlewareName string) (http.Handler, error) {
|
||||
cfg := Config{MinSize: c.minSize, Algorithm: algo, MiddlewareName: middlewareName}
|
||||
if len(c.includes) > 0 {
|
||||
cfg.IncludedContentTypes = c.includes
|
||||
} else {
|
||||
cfg.ExcludedContentTypes = c.excludes
|
||||
}
|
||||
|
||||
wrapper, err := brotli.NewWrapper(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new brotli wrapper: %w", err)
|
||||
}
|
||||
|
||||
return wrapper(c.next), nil
|
||||
return NewCompressionHandler(cfg, c.next)
|
||||
}
|
||||
|
|
|
@ -41,32 +41,52 @@ func TestNegotiation(t *testing.T) {
|
|||
{
|
||||
desc: "accept any header",
|
||||
acceptEncHeader: "*",
|
||||
expEncoding: "br",
|
||||
expEncoding: brotliName,
|
||||
},
|
||||
{
|
||||
desc: "gzip accept header",
|
||||
acceptEncHeader: "gzip",
|
||||
expEncoding: "gzip",
|
||||
expEncoding: gzipName,
|
||||
},
|
||||
{
|
||||
desc: "br accept header",
|
||||
acceptEncHeader: "br",
|
||||
expEncoding: "br",
|
||||
expEncoding: brotliName,
|
||||
},
|
||||
{
|
||||
desc: "multi accept header, prefer br",
|
||||
acceptEncHeader: "br;q=0.8, gzip;q=0.6",
|
||||
expEncoding: "br",
|
||||
expEncoding: brotliName,
|
||||
},
|
||||
{
|
||||
desc: "multi accept header, prefer gzip",
|
||||
acceptEncHeader: "gzip;q=1.0, br;q=0.8",
|
||||
expEncoding: "gzip",
|
||||
expEncoding: gzipName,
|
||||
},
|
||||
{
|
||||
desc: "multi accept header list, prefer br",
|
||||
acceptEncHeader: "gzip, br",
|
||||
expEncoding: "br",
|
||||
expEncoding: brotliName,
|
||||
},
|
||||
{
|
||||
desc: "zstd accept header",
|
||||
acceptEncHeader: "zstd",
|
||||
expEncoding: zstdName,
|
||||
},
|
||||
{
|
||||
desc: "multi accept header, prefer zstd",
|
||||
acceptEncHeader: "zstd;q=0.9, br;q=0.8, gzip;q=0.6",
|
||||
expEncoding: zstdName,
|
||||
},
|
||||
{
|
||||
desc: "multi accept header, prefer gzip",
|
||||
acceptEncHeader: "gzip;q=1.0, br;q=0.8, zstd;q=0.7",
|
||||
expEncoding: gzipName,
|
||||
},
|
||||
{
|
||||
desc: "multi accept header list, prefer zstd",
|
||||
acceptEncHeader: "gzip, br, zstd",
|
||||
expEncoding: zstdName,
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package brotli
|
||||
package compress
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
|
@ -10,6 +10,9 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/andybalholm/brotli"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/traefik/traefik/v3/pkg/middlewares"
|
||||
"github.com/traefik/traefik/v3/pkg/middlewares/observability"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -30,10 +33,26 @@ type Config struct {
|
|||
IncludedContentTypes []string
|
||||
// MinSize is the minimum size (in bytes) required to enable compression.
|
||||
MinSize int
|
||||
// Algorithm used for the compression (currently Brotli and Zstandard)
|
||||
Algorithm string
|
||||
// MiddlewareName use for logging purposes
|
||||
MiddlewareName string
|
||||
}
|
||||
|
||||
// NewWrapper returns a new Brotli compressing wrapper.
|
||||
func NewWrapper(cfg Config) (func(http.Handler) http.HandlerFunc, error) {
|
||||
// CompressionHandler handles Brolti and Zstd compression.
|
||||
type CompressionHandler struct {
|
||||
cfg Config
|
||||
excludedContentTypes []parsedContentType
|
||||
includedContentTypes []parsedContentType
|
||||
next http.Handler
|
||||
}
|
||||
|
||||
// NewCompressionHandler returns a new compressing handler.
|
||||
func NewCompressionHandler(cfg Config, next http.Handler) (http.Handler, error) {
|
||||
if cfg.Algorithm == "" {
|
||||
return nil, errors.New("compression algorithm undefined")
|
||||
}
|
||||
|
||||
if cfg.MinSize < 0 {
|
||||
return nil, errors.New("minimum size must be greater than or equal to zero")
|
||||
}
|
||||
|
@ -62,30 +81,89 @@ func NewWrapper(cfg Config) (func(http.Handler) http.HandlerFunc, error) {
|
|||
includedContentTypes = append(includedContentTypes, parsedContentType{mediaType, params})
|
||||
}
|
||||
|
||||
return func(h http.Handler) http.HandlerFunc {
|
||||
return func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Header().Add(vary, acceptEncoding)
|
||||
|
||||
brw := &responseWriter{
|
||||
rw: rw,
|
||||
bw: brotli.NewWriter(rw),
|
||||
minSize: cfg.MinSize,
|
||||
statusCode: http.StatusOK,
|
||||
return &CompressionHandler{
|
||||
cfg: cfg,
|
||||
excludedContentTypes: excludedContentTypes,
|
||||
includedContentTypes: includedContentTypes,
|
||||
}
|
||||
defer brw.close()
|
||||
|
||||
h.ServeHTTP(brw, r)
|
||||
}
|
||||
next: next,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *CompressionHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Header().Add(vary, acceptEncoding)
|
||||
|
||||
compressionWriter, err := newCompressionWriter(c.cfg.Algorithm, rw)
|
||||
if err != nil {
|
||||
logger := middlewares.GetLogger(r.Context(), c.cfg.MiddlewareName, typeName)
|
||||
logMessage := fmt.Sprintf("create compression handler: %v", err)
|
||||
logger.Debug().Msg(logMessage)
|
||||
observability.SetStatusErrorf(r.Context(), logMessage)
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
responseWriter := &responseWriter{
|
||||
rw: rw,
|
||||
compressionWriter: compressionWriter,
|
||||
minSize: c.cfg.MinSize,
|
||||
statusCode: http.StatusOK,
|
||||
excludedContentTypes: c.excludedContentTypes,
|
||||
includedContentTypes: c.includedContentTypes,
|
||||
}
|
||||
defer responseWriter.close()
|
||||
|
||||
c.next.ServeHTTP(responseWriter, r)
|
||||
}
|
||||
|
||||
type compression interface {
|
||||
// Write data to the encoder.
|
||||
// Input data will be buffered and as the buffer fills up
|
||||
// content will be compressed and written to the output.
|
||||
// When done writing, use Close to flush the remaining output
|
||||
// and write CRC if requested.
|
||||
Write(p []byte) (n int, err error)
|
||||
// Flush will send the currently written data to output
|
||||
// and block until everything has been written.
|
||||
// This should only be used on rare occasions where pushing the currently queued data is critical.
|
||||
Flush() error
|
||||
// Close closes the underlying writers if/when appropriate.
|
||||
// Note that the compressed writer should not be closed if we never used it,
|
||||
// as it would otherwise send some extra "end of compression" bytes.
|
||||
// Close also makes sure to flush whatever was left to write from the buffer.
|
||||
Close() error
|
||||
}
|
||||
|
||||
type compressionWriter struct {
|
||||
compression
|
||||
alg string
|
||||
}
|
||||
|
||||
func newCompressionWriter(algo string, in io.Writer) (*compressionWriter, error) {
|
||||
switch algo {
|
||||
case brotliName:
|
||||
return &compressionWriter{compression: brotli.NewWriter(in), alg: algo}, nil
|
||||
|
||||
case zstdName:
|
||||
writer, err := zstd.NewWriter(in)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating zstd writer: %w", err)
|
||||
}
|
||||
return &compressionWriter{compression: writer, alg: algo}, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown compression algo: %s", algo)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *compressionWriter) ContentEncoding() string {
|
||||
return c.alg
|
||||
}
|
||||
|
||||
// TODO: check whether we want to implement content-type sniffing (as gzip does)
|
||||
// TODO: check whether we should support Accept-Ranges (as gzip does, see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Ranges)
|
||||
type responseWriter struct {
|
||||
rw http.ResponseWriter
|
||||
bw *brotli.Writer
|
||||
compressionWriter *compressionWriter
|
||||
|
||||
minSize int
|
||||
excludedContentTypes []parsedContentType
|
||||
|
@ -133,7 +211,7 @@ func (r *responseWriter) Write(p []byte) (int, error) {
|
|||
// We are now in compression cruise mode until the end of times.
|
||||
if r.compressionStarted {
|
||||
// If compressionStarted we assume we have sent headers already
|
||||
return r.bw.Write(p)
|
||||
return r.compressionWriter.Write(p)
|
||||
}
|
||||
|
||||
// If we detect a contentEncoding, we know we are never going to compress.
|
||||
|
@ -187,13 +265,13 @@ func (r *responseWriter) Write(p []byte) (int, error) {
|
|||
// Since we know we are going to compress we will never be able to know the actual length.
|
||||
r.rw.Header().Del(contentLength)
|
||||
|
||||
r.rw.Header().Set(contentEncoding, "br")
|
||||
r.rw.Header().Set(contentEncoding, r.compressionWriter.ContentEncoding())
|
||||
r.rw.WriteHeader(r.statusCode)
|
||||
r.headersSent = true
|
||||
|
||||
// Start with sending what we have previously buffered, before actually writing
|
||||
// the bytes in argument.
|
||||
n, err := r.bw.Write(r.buf)
|
||||
n, err := r.compressionWriter.Write(r.buf)
|
||||
if err != nil {
|
||||
r.buf = r.buf[n:]
|
||||
// Return zero because we haven't taken care of the bytes in argument yet.
|
||||
|
@ -212,7 +290,7 @@ func (r *responseWriter) Write(p []byte) (int, error) {
|
|||
r.buf = r.buf[:0]
|
||||
|
||||
// Now that we emptied the buffer, we can actually write the given bytes.
|
||||
return r.bw.Write(p)
|
||||
return r.compressionWriter.Write(p)
|
||||
}
|
||||
|
||||
// Flush flushes data to the appropriate underlying writer(s), although it does
|
||||
|
@ -250,7 +328,7 @@ func (r *responseWriter) Flush() {
|
|||
// we have to do it ourselves.
|
||||
defer func() {
|
||||
// because we also ignore the error returned by Write anyway
|
||||
_ = r.bw.Flush()
|
||||
_ = r.compressionWriter.Flush()
|
||||
|
||||
if rw, ok := r.rw.(http.Flusher); ok {
|
||||
rw.Flush()
|
||||
|
@ -258,7 +336,7 @@ func (r *responseWriter) Flush() {
|
|||
}()
|
||||
|
||||
// We empty whatever is left of the buffer that Write never took care of.
|
||||
n, err := r.bw.Write(r.buf)
|
||||
n, err := r.compressionWriter.Write(r.buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -313,7 +391,7 @@ func (r *responseWriter) close() error {
|
|||
|
||||
if len(r.buf) == 0 {
|
||||
// If we got here we know compression has started, so we can safely flush on bw.
|
||||
return r.bw.Close()
|
||||
return r.compressionWriter.Close()
|
||||
}
|
||||
|
||||
// There is still data in the buffer, because we never reached minSize (to
|
||||
|
@ -331,16 +409,16 @@ func (r *responseWriter) close() error {
|
|||
|
||||
// There is still data in the buffer, simply because Write did not take care of it all.
|
||||
// We flush it to the compressed writer.
|
||||
n, err := r.bw.Write(r.buf)
|
||||
n, err := r.compressionWriter.Write(r.buf)
|
||||
if err != nil {
|
||||
r.bw.Close()
|
||||
r.compressionWriter.Close()
|
||||
return err
|
||||
}
|
||||
if n < len(r.buf) {
|
||||
r.bw.Close()
|
||||
r.compressionWriter.Close()
|
||||
return io.ErrShortWrite
|
||||
}
|
||||
return r.bw.Close()
|
||||
return r.compressionWriter.Close()
|
||||
}
|
||||
|
||||
// parsedContentType is the parsed representation of one of the inputs to ContentTypes.
|
|
@ -1,4 +1,4 @@
|
|||
package brotli
|
||||
package compress
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
@ -9,6 +9,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/andybalholm/brotli"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
@ -19,44 +20,107 @@ var (
|
|||
)
|
||||
|
||||
func Test_Vary(t *testing.T) {
|
||||
h := newTestHandler(t, smallTestBody)
|
||||
testCases := []struct {
|
||||
desc string
|
||||
h http.Handler
|
||||
acceptEncoding string
|
||||
}{
|
||||
{
|
||||
desc: "brotli",
|
||||
h: newTestBrotliHandler(t, smallTestBody),
|
||||
acceptEncoding: "br",
|
||||
},
|
||||
{
|
||||
desc: "zstd",
|
||||
h: newTestZstandardHandler(t, smallTestBody),
|
||||
acceptEncoding: "zstd",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
|
||||
req.Header.Set(acceptEncoding, "br")
|
||||
req.Header.Set(acceptEncoding, test.acceptEncoding)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
h.ServeHTTP(rw, req)
|
||||
test.h.ServeHTTP(rw, req)
|
||||
|
||||
assert.Equal(t, http.StatusAccepted, rw.Code)
|
||||
assert.Equal(t, acceptEncoding, rw.Header().Get(vary))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SmallBodyNoCompression(t *testing.T) {
|
||||
h := newTestHandler(t, smallTestBody)
|
||||
testCases := []struct {
|
||||
desc string
|
||||
h http.Handler
|
||||
acceptEncoding string
|
||||
}{
|
||||
{
|
||||
desc: "brotli",
|
||||
h: newTestBrotliHandler(t, smallTestBody),
|
||||
acceptEncoding: "br",
|
||||
},
|
||||
{
|
||||
desc: "zstd",
|
||||
h: newTestZstandardHandler(t, smallTestBody),
|
||||
acceptEncoding: "zstd",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
|
||||
req.Header.Set(acceptEncoding, "br")
|
||||
req.Header.Set(acceptEncoding, test.acceptEncoding)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
h.ServeHTTP(rw, req)
|
||||
test.h.ServeHTTP(rw, req)
|
||||
|
||||
// With less than 1024 bytes the response should not be compressed.
|
||||
assert.Equal(t, http.StatusAccepted, rw.Code)
|
||||
assert.Empty(t, rw.Header().Get(contentEncoding))
|
||||
assert.Equal(t, smallTestBody, rw.Body.Bytes())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_AlreadyCompressed(t *testing.T) {
|
||||
h := newTestHandler(t, bigTestBody)
|
||||
testCases := []struct {
|
||||
desc string
|
||||
h http.Handler
|
||||
acceptEncoding string
|
||||
}{
|
||||
{
|
||||
desc: "brotli",
|
||||
h: newTestBrotliHandler(t, bigTestBody),
|
||||
acceptEncoding: "br",
|
||||
},
|
||||
{
|
||||
desc: "zstd",
|
||||
h: newTestZstandardHandler(t, bigTestBody),
|
||||
acceptEncoding: "zstd",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "/compressed", nil)
|
||||
req.Header.Set(acceptEncoding, "br")
|
||||
req.Header.Set(acceptEncoding, test.acceptEncoding)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
h.ServeHTTP(rw, req)
|
||||
test.h.ServeHTTP(rw, req)
|
||||
|
||||
assert.Equal(t, http.StatusAccepted, rw.Code)
|
||||
assert.Equal(t, bigTestBody, rw.Body.Bytes())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_NoBody(t *testing.T) {
|
||||
|
@ -91,15 +155,17 @@ func Test_NoBody(t *testing.T) {
|
|||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(test.statusCode)
|
||||
|
||||
_, err := rw.Write(test.body)
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
})
|
||||
|
||||
h := mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: zstdName}, next)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(acceptEncoding, "br")
|
||||
req.Header.Set(acceptEncoding, "zstd")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
h.ServeHTTP(rw, req)
|
||||
|
@ -116,11 +182,12 @@ func Test_NoBody(t *testing.T) {
|
|||
func Test_MinSize(t *testing.T) {
|
||||
cfg := Config{
|
||||
MinSize: 128,
|
||||
Algorithm: zstdName,
|
||||
}
|
||||
|
||||
var bodySize int
|
||||
h := mustNewWrapper(t, cfg)(http.HandlerFunc(
|
||||
func(rw http.ResponseWriter, req *http.Request) {
|
||||
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
for range bodySize {
|
||||
// We make sure to Write at least once less than minSize so that both
|
||||
// cases below go through the same algo: i.e. they start buffering
|
||||
|
@ -128,11 +195,12 @@ func Test_MinSize(t *testing.T) {
|
|||
_, err := rw.Write([]byte{'x'})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
},
|
||||
))
|
||||
})
|
||||
|
||||
h := mustNewCompressionHandler(t, cfg, next)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "/whatever", &bytes.Buffer{})
|
||||
req.Header.Add(acceptEncoding, "br")
|
||||
req.Header.Add(acceptEncoding, "zstd")
|
||||
|
||||
// Short response is not compressed
|
||||
bodySize = cfg.MinSize - 1
|
||||
|
@ -146,18 +214,20 @@ func Test_MinSize(t *testing.T) {
|
|||
rw = httptest.NewRecorder()
|
||||
h.ServeHTTP(rw, req)
|
||||
|
||||
assert.Equal(t, "br", rw.Result().Header.Get(contentEncoding))
|
||||
assert.Equal(t, "zstd", rw.Result().Header.Get(contentEncoding))
|
||||
}
|
||||
|
||||
func Test_MultipleWriteHeader(t *testing.T) {
|
||||
h := mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
// We ensure that the subsequent call to WriteHeader is a noop.
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
rw.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
})
|
||||
|
||||
h := mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: zstdName}, next)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(acceptEncoding, "br")
|
||||
req.Header.Set(acceptEncoding, "zstd")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
h.ServeHTTP(rw, req)
|
||||
|
@ -166,19 +236,49 @@ func Test_MultipleWriteHeader(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_FlushBeforeWrite(t *testing.T) {
|
||||
srv := httptest.NewServer(mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
cfg Config
|
||||
readerBuilder func(io.Reader) (io.Reader, error)
|
||||
acceptEncoding string
|
||||
}{
|
||||
{
|
||||
desc: "brotli",
|
||||
cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"},
|
||||
readerBuilder: func(reader io.Reader) (io.Reader, error) {
|
||||
return brotli.NewReader(reader), nil
|
||||
},
|
||||
acceptEncoding: "br",
|
||||
},
|
||||
{
|
||||
desc: "zstd",
|
||||
cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"},
|
||||
readerBuilder: func(reader io.Reader) (io.Reader, error) {
|
||||
return zstd.NewReader(reader)
|
||||
},
|
||||
acceptEncoding: "zstd",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
rw.(http.Flusher).Flush()
|
||||
|
||||
_, err := rw.Write(bigTestBody)
|
||||
require.NoError(t, err)
|
||||
})))
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next))
|
||||
defer srv.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
req.Header.Set(acceptEncoding, "br")
|
||||
req.Header.Set(acceptEncoding, test.acceptEncoding)
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
@ -186,15 +286,46 @@ func Test_FlushBeforeWrite(t *testing.T) {
|
|||
defer res.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
assert.Equal(t, "br", res.Header.Get(contentEncoding))
|
||||
assert.Equal(t, test.acceptEncoding, res.Header.Get(contentEncoding))
|
||||
|
||||
got, err := io.ReadAll(brotli.NewReader(res.Body))
|
||||
reader, err := test.readerBuilder(res.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := io.ReadAll(reader)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, bigTestBody, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_FlushAfterWrite(t *testing.T) {
|
||||
srv := httptest.NewServer(mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
cfg Config
|
||||
readerBuilder func(io.Reader) (io.Reader, error)
|
||||
acceptEncoding string
|
||||
}{
|
||||
{
|
||||
desc: "brotli",
|
||||
cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"},
|
||||
readerBuilder: func(reader io.Reader) (io.Reader, error) {
|
||||
return brotli.NewReader(reader), nil
|
||||
},
|
||||
acceptEncoding: "br",
|
||||
},
|
||||
{
|
||||
desc: "zstd",
|
||||
cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"},
|
||||
readerBuilder: func(reader io.Reader) (io.Reader, error) {
|
||||
return zstd.NewReader(reader)
|
||||
},
|
||||
acceptEncoding: "zstd",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
|
||||
_, err := rw.Write(bigTestBody[0:1])
|
||||
|
@ -205,13 +336,15 @@ func Test_FlushAfterWrite(t *testing.T) {
|
|||
_, err := rw.Write([]byte{b})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})))
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next))
|
||||
defer srv.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
req.Header.Set(acceptEncoding, "br")
|
||||
req.Header.Set(acceptEncoding, test.acceptEncoding)
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
@ -219,28 +352,61 @@ func Test_FlushAfterWrite(t *testing.T) {
|
|||
defer res.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
assert.Equal(t, "br", res.Header.Get(contentEncoding))
|
||||
assert.Equal(t, test.acceptEncoding, res.Header.Get(contentEncoding))
|
||||
|
||||
got, err := io.ReadAll(brotli.NewReader(res.Body))
|
||||
reader, err := test.readerBuilder(res.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := io.ReadAll(reader)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, bigTestBody, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_FlushAfterWriteNil(t *testing.T) {
|
||||
srv := httptest.NewServer(mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
cfg Config
|
||||
readerBuilder func(io.Reader) (io.Reader, error)
|
||||
acceptEncoding string
|
||||
}{
|
||||
{
|
||||
desc: "brotli",
|
||||
cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"},
|
||||
readerBuilder: func(reader io.Reader) (io.Reader, error) {
|
||||
return brotli.NewReader(reader), nil
|
||||
},
|
||||
acceptEncoding: "br",
|
||||
},
|
||||
{
|
||||
desc: "zstd",
|
||||
cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"},
|
||||
readerBuilder: func(reader io.Reader) (io.Reader, error) {
|
||||
return zstd.NewReader(reader)
|
||||
},
|
||||
acceptEncoding: "zstd",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
|
||||
_, err := rw.Write(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
rw.(http.Flusher).Flush()
|
||||
})))
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next))
|
||||
defer srv.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
req.Header.Set(acceptEncoding, "br")
|
||||
req.Header.Set(acceptEncoding, test.acceptEncoding)
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
@ -250,25 +416,58 @@ func Test_FlushAfterWriteNil(t *testing.T) {
|
|||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
assert.Empty(t, res.Header.Get(contentEncoding))
|
||||
|
||||
got, err := io.ReadAll(brotli.NewReader(res.Body))
|
||||
reader, err := test.readerBuilder(res.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := io.ReadAll(reader)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_FlushAfterAllWrites(t *testing.T) {
|
||||
srv := httptest.NewServer(mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
cfg Config
|
||||
readerBuilder func(io.Reader) (io.Reader, error)
|
||||
acceptEncoding string
|
||||
}{
|
||||
{
|
||||
desc: "brotli",
|
||||
cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"},
|
||||
readerBuilder: func(reader io.Reader) (io.Reader, error) {
|
||||
return brotli.NewReader(reader), nil
|
||||
},
|
||||
acceptEncoding: "br",
|
||||
},
|
||||
{
|
||||
desc: "zstd",
|
||||
cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"},
|
||||
readerBuilder: func(reader io.Reader) (io.Reader, error) {
|
||||
return zstd.NewReader(reader)
|
||||
},
|
||||
acceptEncoding: "zstd",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
for i := range bigTestBody {
|
||||
_, err := rw.Write(bigTestBody[i : i+1])
|
||||
require.NoError(t, err)
|
||||
}
|
||||
rw.(http.Flusher).Flush()
|
||||
})))
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next))
|
||||
defer srv.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
req.Header.Set(acceptEncoding, "br")
|
||||
req.Header.Set(acceptEncoding, test.acceptEncoding)
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
@ -276,11 +475,16 @@ func Test_FlushAfterAllWrites(t *testing.T) {
|
|||
defer res.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
assert.Equal(t, "br", res.Header.Get(contentEncoding))
|
||||
assert.Equal(t, test.acceptEncoding, res.Header.Get(contentEncoding))
|
||||
|
||||
got, err := io.ReadAll(brotli.NewReader(res.Body))
|
||||
reader, err := test.readerBuilder(res.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := io.ReadAll(reader)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, bigTestBody, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ExcludedContentTypes(t *testing.T) {
|
||||
|
@ -352,18 +556,22 @@ func Test_ExcludedContentTypes(t *testing.T) {
|
|||
cfg := Config{
|
||||
MinSize: 1024,
|
||||
ExcludedContentTypes: test.excludedContentTypes,
|
||||
Algorithm: zstdName,
|
||||
}
|
||||
h := mustNewWrapper(t, cfg)(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set(contentType, test.contentType)
|
||||
|
||||
rw.WriteHeader(http.StatusAccepted)
|
||||
|
||||
_, err := rw.Write(bigTestBody)
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
})
|
||||
|
||||
h := mustNewCompressionHandler(t, cfg, next)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
|
||||
req.Header.Set(acceptEncoding, "br")
|
||||
req.Header.Set(acceptEncoding, zstdName)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
h.ServeHTTP(rw, req)
|
||||
|
@ -371,13 +579,16 @@ func Test_ExcludedContentTypes(t *testing.T) {
|
|||
assert.Equal(t, http.StatusAccepted, rw.Code)
|
||||
|
||||
if test.expCompression {
|
||||
assert.Equal(t, "br", rw.Header().Get(contentEncoding))
|
||||
assert.Equal(t, zstdName, rw.Header().Get(contentEncoding))
|
||||
|
||||
got, err := io.ReadAll(brotli.NewReader(rw.Body))
|
||||
reader, err := zstd.NewReader(rw.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := io.ReadAll(reader)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, bigTestBody, got)
|
||||
} else {
|
||||
assert.NotEqual(t, "br", rw.Header().Get("Content-Encoding"))
|
||||
assert.NotEqual(t, zstdName, rw.Header().Get("Content-Encoding"))
|
||||
|
||||
got, err := io.ReadAll(rw.Body)
|
||||
assert.NoError(t, err)
|
||||
|
@ -456,18 +667,22 @@ func Test_IncludedContentTypes(t *testing.T) {
|
|||
cfg := Config{
|
||||
MinSize: 1024,
|
||||
IncludedContentTypes: test.includedContentTypes,
|
||||
Algorithm: zstdName,
|
||||
}
|
||||
h := mustNewWrapper(t, cfg)(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set(contentType, test.contentType)
|
||||
|
||||
rw.WriteHeader(http.StatusAccepted)
|
||||
|
||||
_, err := rw.Write(bigTestBody)
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
})
|
||||
|
||||
h := mustNewCompressionHandler(t, cfg, next)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
|
||||
req.Header.Set(acceptEncoding, "br")
|
||||
req.Header.Set(acceptEncoding, zstdName)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
h.ServeHTTP(rw, req)
|
||||
|
@ -475,13 +690,16 @@ func Test_IncludedContentTypes(t *testing.T) {
|
|||
assert.Equal(t, http.StatusAccepted, rw.Code)
|
||||
|
||||
if test.expCompression {
|
||||
assert.Equal(t, "br", rw.Header().Get(contentEncoding))
|
||||
assert.Equal(t, zstdName, rw.Header().Get(contentEncoding))
|
||||
|
||||
got, err := io.ReadAll(brotli.NewReader(rw.Body))
|
||||
reader, err := zstd.NewReader(rw.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := io.ReadAll(reader)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, bigTestBody, got)
|
||||
} else {
|
||||
assert.NotEqual(t, "br", rw.Header().Get("Content-Encoding"))
|
||||
assert.NotEqual(t, zstdName, rw.Header().Get("Content-Encoding"))
|
||||
|
||||
got, err := io.ReadAll(rw.Body)
|
||||
assert.NoError(t, err)
|
||||
|
@ -560,8 +778,10 @@ func Test_FlushExcludedContentTypes(t *testing.T) {
|
|||
cfg := Config{
|
||||
MinSize: 1024,
|
||||
ExcludedContentTypes: test.excludedContentTypes,
|
||||
Algorithm: zstdName,
|
||||
}
|
||||
h := mustNewWrapper(t, cfg)(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set(contentType, test.contentType)
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
|
||||
|
@ -581,10 +801,12 @@ func Test_FlushExcludedContentTypes(t *testing.T) {
|
|||
rw.(http.Flusher).Flush()
|
||||
tb = tb[toWrite:]
|
||||
}
|
||||
}))
|
||||
})
|
||||
|
||||
h := mustNewCompressionHandler(t, cfg, next)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
|
||||
req.Header.Set(acceptEncoding, "br")
|
||||
req.Header.Set(acceptEncoding, zstdName)
|
||||
|
||||
// This doesn't allow checking flushes, but we validate if content is correct.
|
||||
rw := httptest.NewRecorder()
|
||||
|
@ -593,13 +815,16 @@ func Test_FlushExcludedContentTypes(t *testing.T) {
|
|||
assert.Equal(t, http.StatusOK, rw.Code)
|
||||
|
||||
if test.expCompression {
|
||||
assert.Equal(t, "br", rw.Header().Get(contentEncoding))
|
||||
assert.Equal(t, zstdName, rw.Header().Get(contentEncoding))
|
||||
|
||||
got, err := io.ReadAll(brotli.NewReader(rw.Body))
|
||||
reader, err := zstd.NewReader(rw.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := io.ReadAll(reader)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, bigTestBody, got)
|
||||
} else {
|
||||
assert.NotEqual(t, "br", rw.Header().Get(contentEncoding))
|
||||
assert.NotEqual(t, zstdName, rw.Header().Get(contentEncoding))
|
||||
|
||||
got, err := io.ReadAll(rw.Body)
|
||||
assert.NoError(t, err)
|
||||
|
@ -678,8 +903,10 @@ func Test_FlushIncludedContentTypes(t *testing.T) {
|
|||
cfg := Config{
|
||||
MinSize: 1024,
|
||||
IncludedContentTypes: test.includedContentTypes,
|
||||
Algorithm: zstdName,
|
||||
}
|
||||
h := mustNewWrapper(t, cfg)(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set(contentType, test.contentType)
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
|
||||
|
@ -699,10 +926,12 @@ func Test_FlushIncludedContentTypes(t *testing.T) {
|
|||
rw.(http.Flusher).Flush()
|
||||
tb = tb[toWrite:]
|
||||
}
|
||||
}))
|
||||
})
|
||||
|
||||
h := mustNewCompressionHandler(t, cfg, next)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
|
||||
req.Header.Set(acceptEncoding, "br")
|
||||
req.Header.Set(acceptEncoding, zstdName)
|
||||
|
||||
// This doesn't allow checking flushes, but we validate if content is correct.
|
||||
rw := httptest.NewRecorder()
|
||||
|
@ -711,13 +940,16 @@ func Test_FlushIncludedContentTypes(t *testing.T) {
|
|||
assert.Equal(t, http.StatusOK, rw.Code)
|
||||
|
||||
if test.expCompression {
|
||||
assert.Equal(t, "br", rw.Header().Get(contentEncoding))
|
||||
assert.Equal(t, zstdName, rw.Header().Get(contentEncoding))
|
||||
|
||||
got, err := io.ReadAll(brotli.NewReader(rw.Body))
|
||||
reader, err := zstd.NewReader(rw.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := io.ReadAll(reader)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, bigTestBody, got)
|
||||
} else {
|
||||
assert.NotEqual(t, "br", rw.Header().Get(contentEncoding))
|
||||
assert.NotEqual(t, zstdName, rw.Header().Get(contentEncoding))
|
||||
|
||||
got, err := io.ReadAll(rw.Body)
|
||||
assert.NoError(t, err)
|
||||
|
@ -727,32 +959,48 @@ func Test_FlushIncludedContentTypes(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func mustNewWrapper(t *testing.T, cfg Config) func(http.Handler) http.HandlerFunc {
|
||||
func mustNewCompressionHandler(t *testing.T, cfg Config, next http.Handler) http.Handler {
|
||||
t.Helper()
|
||||
|
||||
w, err := NewWrapper(cfg)
|
||||
w, err := NewCompressionHandler(cfg, next)
|
||||
require.NoError(t, err)
|
||||
|
||||
return w
|
||||
}
|
||||
|
||||
func newTestHandler(t *testing.T, body []byte) http.Handler {
|
||||
func newTestBrotliHandler(t *testing.T, body []byte) http.Handler {
|
||||
t.Helper()
|
||||
|
||||
return mustNewWrapper(t, Config{MinSize: 1024})(
|
||||
http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.URL.Path == "/compressed" {
|
||||
rw.Header().Set("Content-Encoding", "br")
|
||||
rw.Header().Set("Content-Encoding", brotliName)
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusAccepted)
|
||||
_, err := rw.Write(body)
|
||||
require.NoError(t, err)
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
return mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Compress"}, next)
|
||||
}
|
||||
|
||||
func TestParseContentType_equals(t *testing.T) {
|
||||
func newTestZstandardHandler(t *testing.T, body []byte) http.Handler {
|
||||
t.Helper()
|
||||
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.URL.Path == "/compressed" {
|
||||
rw.Header().Set("Content-Encoding", zstdName)
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusAccepted)
|
||||
_, err := rw.Write(body)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
return mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Compress"}, next)
|
||||
}
|
||||
|
||||
func Test_ParseContentType_equals(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
pct parsedContentType
|
Loading…
Reference in a new issue