Add support for Zstandard to the Compression middleware
This commit is contained in:
parent
3f48e6f8ef
commit
b795f128d7
7 changed files with 576 additions and 213 deletions
|
@ -10,7 +10,7 @@ Compress Allows Compressing Responses before Sending them to the Client
|
||||||
|
|
||||||
![Compress](../../assets/img/middleware/compress.png)
|
![Compress](../../assets/img/middleware/compress.png)
|
||||||
|
|
||||||
The Compress middleware supports gzip and Brotli compression.
|
The Compress middleware supports Gzip, Brotli and Zstandard compression.
|
||||||
The activation of compression, and the compression method choice rely (among other things) on the request's `Accept-Encoding` header.
|
The activation of compression, and the compression method choice rely (among other things) on the request's `Accept-Encoding` header.
|
||||||
|
|
||||||
## Configuration Examples
|
## Configuration Examples
|
||||||
|
@ -54,8 +54,8 @@ http:
|
||||||
|
|
||||||
Responses are compressed when the following criteria are all met:
|
Responses are compressed when the following criteria are all met:
|
||||||
|
|
||||||
* The `Accept-Encoding` request header contains `gzip`, `*`, and/or `br` with or without [quality values](https://developer.mozilla.org/en-US/docs/Glossary/Quality_values).
|
* The `Accept-Encoding` request header contains `gzip`, and/or `*`, and/or `br`, and/or `zstd` with or without [quality values](https://developer.mozilla.org/en-US/docs/Glossary/Quality_values).
|
||||||
If the `Accept-Encoding` request header is absent, the response won't be encoded.
|
If the `Accept-Encoding` request header is absent and no [defaultEncoding](#defaultencoding) is configured, the response won't be encoded.
|
||||||
If it is present, but its value is the empty string, then compression is disabled.
|
If it is present, but its value is the empty string, then compression is disabled.
|
||||||
* The response is not already compressed, i.e. the `Content-Encoding` response header is not already set.
|
* The response is not already compressed, i.e. the `Content-Encoding` response header is not already set.
|
||||||
* The response`Content-Type` header is not one among the [excludedContentTypes options](#excludedcontenttypes), or is one among the [includedContentTypes options](#includedcontenttypes).
|
* The response`Content-Type` header is not one among the [excludedContentTypes options](#excludedcontenttypes), or is one among the [includedContentTypes options](#includedcontenttypes).
|
||||||
|
|
|
@ -11,6 +11,7 @@ const acceptEncodingHeader = "Accept-Encoding"
|
||||||
const (
|
const (
|
||||||
brotliName = "br"
|
brotliName = "br"
|
||||||
gzipName = "gzip"
|
gzipName = "gzip"
|
||||||
|
zstdName = "zstd"
|
||||||
identityName = "identity"
|
identityName = "identity"
|
||||||
wildcardName = "*"
|
wildcardName = "*"
|
||||||
notAcceptable = "not_acceptable"
|
notAcceptable = "not_acceptable"
|
||||||
|
@ -51,7 +52,7 @@ func getCompressionType(acceptEncoding []string, defaultType string) string {
|
||||||
return encoding.Type
|
return encoding.Type
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, dt := range []string{brotliName, gzipName} {
|
for _, dt := range []string{zstdName, brotliName, gzipName} {
|
||||||
if slices.ContainsFunc(encodings, func(e Encoding) bool { return e.Type == dt }) {
|
if slices.ContainsFunc(encodings, func(e Encoding) bool { return e.Type == dt }) {
|
||||||
return dt
|
return dt
|
||||||
}
|
}
|
||||||
|
@ -76,7 +77,7 @@ func parseAcceptEncoding(acceptEncoding []string) ([]Encoding, bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
switch parsed[0] {
|
switch parsed[0] {
|
||||||
case brotliName, gzipName, identityName, wildcardName:
|
case zstdName, brotliName, gzipName, identityName, wildcardName:
|
||||||
// supported encoding
|
// supported encoding
|
||||||
default:
|
default:
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -18,6 +18,11 @@ func Test_getCompressionType(t *testing.T) {
|
||||||
values: []string{"gzip, br"},
|
values: []string{"gzip, br"},
|
||||||
expected: brotliName,
|
expected: brotliName,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
desc: "zstd > br > gzip (no weight)",
|
||||||
|
values: []string{"zstd, gzip, br"},
|
||||||
|
expected: zstdName,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
desc: "known compression type (no weight)",
|
desc: "known compression type (no weight)",
|
||||||
values: []string{"compress, gzip"},
|
values: []string{"compress, gzip"},
|
||||||
|
@ -49,6 +54,11 @@ func Test_getCompressionType(t *testing.T) {
|
||||||
values: []string{"compress;q=1.0, gzip;q=0.5"},
|
values: []string{"compress;q=1.0, gzip;q=0.5"},
|
||||||
expected: gzipName,
|
expected: gzipName,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
desc: "fallback on non-zero compression type",
|
||||||
|
values: []string{"compress;q=1.0, gzip, identity;q=0"},
|
||||||
|
expected: gzipName,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
desc: "not acceptable (identity)",
|
desc: "not acceptable (identity)",
|
||||||
values: []string{"compress;q=1.0, identity;q=0"},
|
values: []string{"compress;q=1.0, identity;q=0"},
|
||||||
|
@ -86,9 +96,10 @@ func Test_parseAcceptEncoding(t *testing.T) {
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
desc: "weight",
|
desc: "weight",
|
||||||
values: []string{"br;q=1.0, gzip;q=0.8, *;q=0.1"},
|
values: []string{"br;q=1.0, zstd;q=0.9, gzip;q=0.8, *;q=0.1"},
|
||||||
expected: []Encoding{
|
expected: []Encoding{
|
||||||
{Type: brotliName, Weight: ptr[float64](1)},
|
{Type: brotliName, Weight: ptr[float64](1)},
|
||||||
|
{Type: zstdName, Weight: ptr(0.9)},
|
||||||
{Type: gzipName, Weight: ptr(0.8)},
|
{Type: gzipName, Weight: ptr(0.8)},
|
||||||
{Type: wildcardName, Weight: ptr(0.1)},
|
{Type: wildcardName, Weight: ptr(0.1)},
|
||||||
},
|
},
|
||||||
|
@ -96,9 +107,10 @@ func Test_parseAcceptEncoding(t *testing.T) {
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "mixed",
|
desc: "mixed",
|
||||||
values: []string{"gzip, br;q=1.0, *;q=0"},
|
values: []string{"zstd,gzip, br;q=1.0, *;q=0"},
|
||||||
expected: []Encoding{
|
expected: []Encoding{
|
||||||
{Type: brotliName, Weight: ptr[float64](1)},
|
{Type: brotliName, Weight: ptr[float64](1)},
|
||||||
|
{Type: zstdName},
|
||||||
{Type: gzipName},
|
{Type: gzipName},
|
||||||
{Type: wildcardName, Weight: ptr[float64](0)},
|
{Type: wildcardName, Weight: ptr[float64](0)},
|
||||||
},
|
},
|
||||||
|
@ -106,8 +118,9 @@ func Test_parseAcceptEncoding(t *testing.T) {
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "no weight",
|
desc: "no weight",
|
||||||
values: []string{"gzip, br, *"},
|
values: []string{"zstd, gzip, br, *"},
|
||||||
expected: []Encoding{
|
expected: []Encoding{
|
||||||
|
{Type: zstdName},
|
||||||
{Type: gzipName},
|
{Type: gzipName},
|
||||||
{Type: brotliName},
|
{Type: brotliName},
|
||||||
{Type: wildcardName},
|
{Type: wildcardName},
|
||||||
|
|
|
@ -11,7 +11,6 @@ import (
|
||||||
"github.com/klauspost/compress/gzhttp"
|
"github.com/klauspost/compress/gzhttp"
|
||||||
"github.com/traefik/traefik/v3/pkg/config/dynamic"
|
"github.com/traefik/traefik/v3/pkg/config/dynamic"
|
||||||
"github.com/traefik/traefik/v3/pkg/middlewares"
|
"github.com/traefik/traefik/v3/pkg/middlewares"
|
||||||
"github.com/traefik/traefik/v3/pkg/middlewares/compress/brotli"
|
|
||||||
"go.opentelemetry.io/otel/trace"
|
"go.opentelemetry.io/otel/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -32,6 +31,7 @@ type compress struct {
|
||||||
|
|
||||||
brotliHandler http.Handler
|
brotliHandler http.Handler
|
||||||
gzipHandler http.Handler
|
gzipHandler http.Handler
|
||||||
|
zstdHandler http.Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new compress middleware.
|
// New creates a new compress middleware.
|
||||||
|
@ -77,7 +77,13 @@ func New(ctx context.Context, next http.Handler, conf dynamic.Compress, name str
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
c.brotliHandler, err = c.newBrotliHandler()
|
|
||||||
|
c.zstdHandler, err = c.newCompressionHandler(zstdName, name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.brotliHandler, err = c.newCompressionHandler(brotliName, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -130,6 +136,8 @@ func (c *compress) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
|
||||||
func (c *compress) chooseHandler(typ string, rw http.ResponseWriter, req *http.Request) {
|
func (c *compress) chooseHandler(typ string, rw http.ResponseWriter, req *http.Request) {
|
||||||
switch typ {
|
switch typ {
|
||||||
|
case zstdName:
|
||||||
|
c.zstdHandler.ServeHTTP(rw, req)
|
||||||
case brotliName:
|
case brotliName:
|
||||||
c.brotliHandler.ServeHTTP(rw, req)
|
c.brotliHandler.ServeHTTP(rw, req)
|
||||||
case gzipName:
|
case gzipName:
|
||||||
|
@ -166,18 +174,13 @@ func (c *compress) newGzipHandler() (http.Handler, error) {
|
||||||
return wrapper(c.next), nil
|
return wrapper(c.next), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *compress) newBrotliHandler() (http.Handler, error) {
|
func (c *compress) newCompressionHandler(algo string, middlewareName string) (http.Handler, error) {
|
||||||
cfg := brotli.Config{MinSize: c.minSize}
|
cfg := Config{MinSize: c.minSize, Algorithm: algo, MiddlewareName: middlewareName}
|
||||||
if len(c.includes) > 0 {
|
if len(c.includes) > 0 {
|
||||||
cfg.IncludedContentTypes = c.includes
|
cfg.IncludedContentTypes = c.includes
|
||||||
} else {
|
} else {
|
||||||
cfg.ExcludedContentTypes = c.excludes
|
cfg.ExcludedContentTypes = c.excludes
|
||||||
}
|
}
|
||||||
|
|
||||||
wrapper, err := brotli.NewWrapper(cfg)
|
return NewCompressionHandler(cfg, c.next)
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("new brotli wrapper: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return wrapper(c.next), nil
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,32 +41,52 @@ func TestNegotiation(t *testing.T) {
|
||||||
{
|
{
|
||||||
desc: "accept any header",
|
desc: "accept any header",
|
||||||
acceptEncHeader: "*",
|
acceptEncHeader: "*",
|
||||||
expEncoding: "br",
|
expEncoding: brotliName,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "gzip accept header",
|
desc: "gzip accept header",
|
||||||
acceptEncHeader: "gzip",
|
acceptEncHeader: "gzip",
|
||||||
expEncoding: "gzip",
|
expEncoding: gzipName,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "br accept header",
|
desc: "br accept header",
|
||||||
acceptEncHeader: "br",
|
acceptEncHeader: "br",
|
||||||
expEncoding: "br",
|
expEncoding: brotliName,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "multi accept header, prefer br",
|
desc: "multi accept header, prefer br",
|
||||||
acceptEncHeader: "br;q=0.8, gzip;q=0.6",
|
acceptEncHeader: "br;q=0.8, gzip;q=0.6",
|
||||||
expEncoding: "br",
|
expEncoding: brotliName,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "multi accept header, prefer gzip",
|
desc: "multi accept header, prefer gzip",
|
||||||
acceptEncHeader: "gzip;q=1.0, br;q=0.8",
|
acceptEncHeader: "gzip;q=1.0, br;q=0.8",
|
||||||
expEncoding: "gzip",
|
expEncoding: gzipName,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "multi accept header list, prefer br",
|
desc: "multi accept header list, prefer br",
|
||||||
acceptEncHeader: "gzip, br",
|
acceptEncHeader: "gzip, br",
|
||||||
expEncoding: "br",
|
expEncoding: brotliName,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "zstd accept header",
|
||||||
|
acceptEncHeader: "zstd",
|
||||||
|
expEncoding: zstdName,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "multi accept header, prefer zstd",
|
||||||
|
acceptEncHeader: "zstd;q=0.9, br;q=0.8, gzip;q=0.6",
|
||||||
|
expEncoding: zstdName,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "multi accept header, prefer gzip",
|
||||||
|
acceptEncHeader: "gzip;q=1.0, br;q=0.8, zstd;q=0.7",
|
||||||
|
expEncoding: gzipName,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "multi accept header list, prefer zstd",
|
||||||
|
acceptEncHeader: "gzip, br, zstd",
|
||||||
|
expEncoding: zstdName,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package brotli
|
package compress
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
@ -10,6 +10,9 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/andybalholm/brotli"
|
"github.com/andybalholm/brotli"
|
||||||
|
"github.com/klauspost/compress/zstd"
|
||||||
|
"github.com/traefik/traefik/v3/pkg/middlewares"
|
||||||
|
"github.com/traefik/traefik/v3/pkg/middlewares/observability"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -30,10 +33,26 @@ type Config struct {
|
||||||
IncludedContentTypes []string
|
IncludedContentTypes []string
|
||||||
// MinSize is the minimum size (in bytes) required to enable compression.
|
// MinSize is the minimum size (in bytes) required to enable compression.
|
||||||
MinSize int
|
MinSize int
|
||||||
|
// Algorithm used for the compression (currently Brotli and Zstandard)
|
||||||
|
Algorithm string
|
||||||
|
// MiddlewareName use for logging purposes
|
||||||
|
MiddlewareName string
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompressionHandler handles Brolti and Zstd compression.
|
||||||
|
type CompressionHandler struct {
|
||||||
|
cfg Config
|
||||||
|
excludedContentTypes []parsedContentType
|
||||||
|
includedContentTypes []parsedContentType
|
||||||
|
next http.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCompressionHandler returns a new compressing handler.
|
||||||
|
func NewCompressionHandler(cfg Config, next http.Handler) (http.Handler, error) {
|
||||||
|
if cfg.Algorithm == "" {
|
||||||
|
return nil, errors.New("compression algorithm undefined")
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewWrapper returns a new Brotli compressing wrapper.
|
|
||||||
func NewWrapper(cfg Config) (func(http.Handler) http.HandlerFunc, error) {
|
|
||||||
if cfg.MinSize < 0 {
|
if cfg.MinSize < 0 {
|
||||||
return nil, errors.New("minimum size must be greater than or equal to zero")
|
return nil, errors.New("minimum size must be greater than or equal to zero")
|
||||||
}
|
}
|
||||||
|
@ -62,30 +81,89 @@ func NewWrapper(cfg Config) (func(http.Handler) http.HandlerFunc, error) {
|
||||||
includedContentTypes = append(includedContentTypes, parsedContentType{mediaType, params})
|
includedContentTypes = append(includedContentTypes, parsedContentType{mediaType, params})
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(h http.Handler) http.HandlerFunc {
|
return &CompressionHandler{
|
||||||
return func(rw http.ResponseWriter, r *http.Request) {
|
cfg: cfg,
|
||||||
rw.Header().Add(vary, acceptEncoding)
|
|
||||||
|
|
||||||
brw := &responseWriter{
|
|
||||||
rw: rw,
|
|
||||||
bw: brotli.NewWriter(rw),
|
|
||||||
minSize: cfg.MinSize,
|
|
||||||
statusCode: http.StatusOK,
|
|
||||||
excludedContentTypes: excludedContentTypes,
|
excludedContentTypes: excludedContentTypes,
|
||||||
includedContentTypes: includedContentTypes,
|
includedContentTypes: includedContentTypes,
|
||||||
}
|
next: next,
|
||||||
defer brw.close()
|
|
||||||
|
|
||||||
h.ServeHTTP(brw, r)
|
|
||||||
}
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *CompressionHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||||
|
rw.Header().Add(vary, acceptEncoding)
|
||||||
|
|
||||||
|
compressionWriter, err := newCompressionWriter(c.cfg.Algorithm, rw)
|
||||||
|
if err != nil {
|
||||||
|
logger := middlewares.GetLogger(r.Context(), c.cfg.MiddlewareName, typeName)
|
||||||
|
logMessage := fmt.Sprintf("create compression handler: %v", err)
|
||||||
|
logger.Debug().Msg(logMessage)
|
||||||
|
observability.SetStatusErrorf(r.Context(), logMessage)
|
||||||
|
rw.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
responseWriter := &responseWriter{
|
||||||
|
rw: rw,
|
||||||
|
compressionWriter: compressionWriter,
|
||||||
|
minSize: c.cfg.MinSize,
|
||||||
|
statusCode: http.StatusOK,
|
||||||
|
excludedContentTypes: c.excludedContentTypes,
|
||||||
|
includedContentTypes: c.includedContentTypes,
|
||||||
|
}
|
||||||
|
defer responseWriter.close()
|
||||||
|
|
||||||
|
c.next.ServeHTTP(responseWriter, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
type compression interface {
|
||||||
|
// Write data to the encoder.
|
||||||
|
// Input data will be buffered and as the buffer fills up
|
||||||
|
// content will be compressed and written to the output.
|
||||||
|
// When done writing, use Close to flush the remaining output
|
||||||
|
// and write CRC if requested.
|
||||||
|
Write(p []byte) (n int, err error)
|
||||||
|
// Flush will send the currently written data to output
|
||||||
|
// and block until everything has been written.
|
||||||
|
// This should only be used on rare occasions where pushing the currently queued data is critical.
|
||||||
|
Flush() error
|
||||||
|
// Close closes the underlying writers if/when appropriate.
|
||||||
|
// Note that the compressed writer should not be closed if we never used it,
|
||||||
|
// as it would otherwise send some extra "end of compression" bytes.
|
||||||
|
// Close also makes sure to flush whatever was left to write from the buffer.
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
type compressionWriter struct {
|
||||||
|
compression
|
||||||
|
alg string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCompressionWriter(algo string, in io.Writer) (*compressionWriter, error) {
|
||||||
|
switch algo {
|
||||||
|
case brotliName:
|
||||||
|
return &compressionWriter{compression: brotli.NewWriter(in), alg: algo}, nil
|
||||||
|
|
||||||
|
case zstdName:
|
||||||
|
writer, err := zstd.NewWriter(in)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating zstd writer: %w", err)
|
||||||
|
}
|
||||||
|
return &compressionWriter{compression: writer, alg: algo}, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown compression algo: %s", algo)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *compressionWriter) ContentEncoding() string {
|
||||||
|
return c.alg
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: check whether we want to implement content-type sniffing (as gzip does)
|
// TODO: check whether we want to implement content-type sniffing (as gzip does)
|
||||||
// TODO: check whether we should support Accept-Ranges (as gzip does, see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Ranges)
|
// TODO: check whether we should support Accept-Ranges (as gzip does, see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Ranges)
|
||||||
type responseWriter struct {
|
type responseWriter struct {
|
||||||
rw http.ResponseWriter
|
rw http.ResponseWriter
|
||||||
bw *brotli.Writer
|
compressionWriter *compressionWriter
|
||||||
|
|
||||||
minSize int
|
minSize int
|
||||||
excludedContentTypes []parsedContentType
|
excludedContentTypes []parsedContentType
|
||||||
|
@ -133,7 +211,7 @@ func (r *responseWriter) Write(p []byte) (int, error) {
|
||||||
// We are now in compression cruise mode until the end of times.
|
// We are now in compression cruise mode until the end of times.
|
||||||
if r.compressionStarted {
|
if r.compressionStarted {
|
||||||
// If compressionStarted we assume we have sent headers already
|
// If compressionStarted we assume we have sent headers already
|
||||||
return r.bw.Write(p)
|
return r.compressionWriter.Write(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we detect a contentEncoding, we know we are never going to compress.
|
// If we detect a contentEncoding, we know we are never going to compress.
|
||||||
|
@ -187,13 +265,13 @@ func (r *responseWriter) Write(p []byte) (int, error) {
|
||||||
// Since we know we are going to compress we will never be able to know the actual length.
|
// Since we know we are going to compress we will never be able to know the actual length.
|
||||||
r.rw.Header().Del(contentLength)
|
r.rw.Header().Del(contentLength)
|
||||||
|
|
||||||
r.rw.Header().Set(contentEncoding, "br")
|
r.rw.Header().Set(contentEncoding, r.compressionWriter.ContentEncoding())
|
||||||
r.rw.WriteHeader(r.statusCode)
|
r.rw.WriteHeader(r.statusCode)
|
||||||
r.headersSent = true
|
r.headersSent = true
|
||||||
|
|
||||||
// Start with sending what we have previously buffered, before actually writing
|
// Start with sending what we have previously buffered, before actually writing
|
||||||
// the bytes in argument.
|
// the bytes in argument.
|
||||||
n, err := r.bw.Write(r.buf)
|
n, err := r.compressionWriter.Write(r.buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.buf = r.buf[n:]
|
r.buf = r.buf[n:]
|
||||||
// Return zero because we haven't taken care of the bytes in argument yet.
|
// Return zero because we haven't taken care of the bytes in argument yet.
|
||||||
|
@ -212,7 +290,7 @@ func (r *responseWriter) Write(p []byte) (int, error) {
|
||||||
r.buf = r.buf[:0]
|
r.buf = r.buf[:0]
|
||||||
|
|
||||||
// Now that we emptied the buffer, we can actually write the given bytes.
|
// Now that we emptied the buffer, we can actually write the given bytes.
|
||||||
return r.bw.Write(p)
|
return r.compressionWriter.Write(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush flushes data to the appropriate underlying writer(s), although it does
|
// Flush flushes data to the appropriate underlying writer(s), although it does
|
||||||
|
@ -250,7 +328,7 @@ func (r *responseWriter) Flush() {
|
||||||
// we have to do it ourselves.
|
// we have to do it ourselves.
|
||||||
defer func() {
|
defer func() {
|
||||||
// because we also ignore the error returned by Write anyway
|
// because we also ignore the error returned by Write anyway
|
||||||
_ = r.bw.Flush()
|
_ = r.compressionWriter.Flush()
|
||||||
|
|
||||||
if rw, ok := r.rw.(http.Flusher); ok {
|
if rw, ok := r.rw.(http.Flusher); ok {
|
||||||
rw.Flush()
|
rw.Flush()
|
||||||
|
@ -258,7 +336,7 @@ func (r *responseWriter) Flush() {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// We empty whatever is left of the buffer that Write never took care of.
|
// We empty whatever is left of the buffer that Write never took care of.
|
||||||
n, err := r.bw.Write(r.buf)
|
n, err := r.compressionWriter.Write(r.buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -313,7 +391,7 @@ func (r *responseWriter) close() error {
|
||||||
|
|
||||||
if len(r.buf) == 0 {
|
if len(r.buf) == 0 {
|
||||||
// If we got here we know compression has started, so we can safely flush on bw.
|
// If we got here we know compression has started, so we can safely flush on bw.
|
||||||
return r.bw.Close()
|
return r.compressionWriter.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// There is still data in the buffer, because we never reached minSize (to
|
// There is still data in the buffer, because we never reached minSize (to
|
||||||
|
@ -331,16 +409,16 @@ func (r *responseWriter) close() error {
|
||||||
|
|
||||||
// There is still data in the buffer, simply because Write did not take care of it all.
|
// There is still data in the buffer, simply because Write did not take care of it all.
|
||||||
// We flush it to the compressed writer.
|
// We flush it to the compressed writer.
|
||||||
n, err := r.bw.Write(r.buf)
|
n, err := r.compressionWriter.Write(r.buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.bw.Close()
|
r.compressionWriter.Close()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if n < len(r.buf) {
|
if n < len(r.buf) {
|
||||||
r.bw.Close()
|
r.compressionWriter.Close()
|
||||||
return io.ErrShortWrite
|
return io.ErrShortWrite
|
||||||
}
|
}
|
||||||
return r.bw.Close()
|
return r.compressionWriter.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// parsedContentType is the parsed representation of one of the inputs to ContentTypes.
|
// parsedContentType is the parsed representation of one of the inputs to ContentTypes.
|
|
@ -1,4 +1,4 @@
|
||||||
package brotli
|
package compress
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
@ -9,6 +9,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/andybalholm/brotli"
|
"github.com/andybalholm/brotli"
|
||||||
|
"github.com/klauspost/compress/zstd"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
@ -19,44 +20,107 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_Vary(t *testing.T) {
|
func Test_Vary(t *testing.T) {
|
||||||
h := newTestHandler(t, smallTestBody)
|
testCases := []struct {
|
||||||
|
desc string
|
||||||
|
h http.Handler
|
||||||
|
acceptEncoding string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "brotli",
|
||||||
|
h: newTestBrotliHandler(t, smallTestBody),
|
||||||
|
acceptEncoding: "br",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "zstd",
|
||||||
|
h: newTestZstandardHandler(t, smallTestBody),
|
||||||
|
acceptEncoding: "zstd",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range testCases {
|
||||||
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
|
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
|
||||||
req.Header.Set(acceptEncoding, "br")
|
req.Header.Set(acceptEncoding, test.acceptEncoding)
|
||||||
|
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
h.ServeHTTP(rw, req)
|
test.h.ServeHTTP(rw, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusAccepted, rw.Code)
|
assert.Equal(t, http.StatusAccepted, rw.Code)
|
||||||
assert.Equal(t, acceptEncoding, rw.Header().Get(vary))
|
assert.Equal(t, acceptEncoding, rw.Header().Get(vary))
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_SmallBodyNoCompression(t *testing.T) {
|
func Test_SmallBodyNoCompression(t *testing.T) {
|
||||||
h := newTestHandler(t, smallTestBody)
|
testCases := []struct {
|
||||||
|
desc string
|
||||||
|
h http.Handler
|
||||||
|
acceptEncoding string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "brotli",
|
||||||
|
h: newTestBrotliHandler(t, smallTestBody),
|
||||||
|
acceptEncoding: "br",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "zstd",
|
||||||
|
h: newTestZstandardHandler(t, smallTestBody),
|
||||||
|
acceptEncoding: "zstd",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range testCases {
|
||||||
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
|
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
|
||||||
req.Header.Set(acceptEncoding, "br")
|
req.Header.Set(acceptEncoding, test.acceptEncoding)
|
||||||
|
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
h.ServeHTTP(rw, req)
|
test.h.ServeHTTP(rw, req)
|
||||||
|
|
||||||
// With less than 1024 bytes the response should not be compressed.
|
// With less than 1024 bytes the response should not be compressed.
|
||||||
assert.Equal(t, http.StatusAccepted, rw.Code)
|
assert.Equal(t, http.StatusAccepted, rw.Code)
|
||||||
assert.Empty(t, rw.Header().Get(contentEncoding))
|
assert.Empty(t, rw.Header().Get(contentEncoding))
|
||||||
assert.Equal(t, smallTestBody, rw.Body.Bytes())
|
assert.Equal(t, smallTestBody, rw.Body.Bytes())
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_AlreadyCompressed(t *testing.T) {
|
func Test_AlreadyCompressed(t *testing.T) {
|
||||||
h := newTestHandler(t, bigTestBody)
|
testCases := []struct {
|
||||||
|
desc string
|
||||||
|
h http.Handler
|
||||||
|
acceptEncoding string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "brotli",
|
||||||
|
h: newTestBrotliHandler(t, bigTestBody),
|
||||||
|
acceptEncoding: "br",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "zstd",
|
||||||
|
h: newTestZstandardHandler(t, bigTestBody),
|
||||||
|
acceptEncoding: "zstd",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range testCases {
|
||||||
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
req, _ := http.NewRequest(http.MethodGet, "/compressed", nil)
|
req, _ := http.NewRequest(http.MethodGet, "/compressed", nil)
|
||||||
req.Header.Set(acceptEncoding, "br")
|
req.Header.Set(acceptEncoding, test.acceptEncoding)
|
||||||
|
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
h.ServeHTTP(rw, req)
|
test.h.ServeHTTP(rw, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusAccepted, rw.Code)
|
assert.Equal(t, http.StatusAccepted, rw.Code)
|
||||||
assert.Equal(t, bigTestBody, rw.Body.Bytes())
|
assert.Equal(t, bigTestBody, rw.Body.Bytes())
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_NoBody(t *testing.T) {
|
func Test_NoBody(t *testing.T) {
|
||||||
|
@ -91,15 +155,17 @@ func Test_NoBody(t *testing.T) {
|
||||||
t.Run(test.desc, func(t *testing.T) {
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
h := mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
rw.WriteHeader(test.statusCode)
|
rw.WriteHeader(test.statusCode)
|
||||||
|
|
||||||
_, err := rw.Write(test.body)
|
_, err := rw.Write(test.body)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}))
|
})
|
||||||
|
|
||||||
|
h := mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: zstdName}, next)
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
req.Header.Set(acceptEncoding, "br")
|
req.Header.Set(acceptEncoding, "zstd")
|
||||||
|
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
h.ServeHTTP(rw, req)
|
h.ServeHTTP(rw, req)
|
||||||
|
@ -116,11 +182,12 @@ func Test_NoBody(t *testing.T) {
|
||||||
func Test_MinSize(t *testing.T) {
|
func Test_MinSize(t *testing.T) {
|
||||||
cfg := Config{
|
cfg := Config{
|
||||||
MinSize: 128,
|
MinSize: 128,
|
||||||
|
Algorithm: zstdName,
|
||||||
}
|
}
|
||||||
|
|
||||||
var bodySize int
|
var bodySize int
|
||||||
h := mustNewWrapper(t, cfg)(http.HandlerFunc(
|
|
||||||
func(rw http.ResponseWriter, req *http.Request) {
|
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
for range bodySize {
|
for range bodySize {
|
||||||
// We make sure to Write at least once less than minSize so that both
|
// We make sure to Write at least once less than minSize so that both
|
||||||
// cases below go through the same algo: i.e. they start buffering
|
// cases below go through the same algo: i.e. they start buffering
|
||||||
|
@ -128,11 +195,12 @@ func Test_MinSize(t *testing.T) {
|
||||||
_, err := rw.Write([]byte{'x'})
|
_, err := rw.Write([]byte{'x'})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
},
|
})
|
||||||
))
|
|
||||||
|
h := mustNewCompressionHandler(t, cfg, next)
|
||||||
|
|
||||||
req, _ := http.NewRequest(http.MethodGet, "/whatever", &bytes.Buffer{})
|
req, _ := http.NewRequest(http.MethodGet, "/whatever", &bytes.Buffer{})
|
||||||
req.Header.Add(acceptEncoding, "br")
|
req.Header.Add(acceptEncoding, "zstd")
|
||||||
|
|
||||||
// Short response is not compressed
|
// Short response is not compressed
|
||||||
bodySize = cfg.MinSize - 1
|
bodySize = cfg.MinSize - 1
|
||||||
|
@ -146,18 +214,20 @@ func Test_MinSize(t *testing.T) {
|
||||||
rw = httptest.NewRecorder()
|
rw = httptest.NewRecorder()
|
||||||
h.ServeHTTP(rw, req)
|
h.ServeHTTP(rw, req)
|
||||||
|
|
||||||
assert.Equal(t, "br", rw.Result().Header.Get(contentEncoding))
|
assert.Equal(t, "zstd", rw.Result().Header.Get(contentEncoding))
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_MultipleWriteHeader(t *testing.T) {
|
func Test_MultipleWriteHeader(t *testing.T) {
|
||||||
h := mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
// We ensure that the subsequent call to WriteHeader is a noop.
|
// We ensure that the subsequent call to WriteHeader is a noop.
|
||||||
rw.WriteHeader(http.StatusInternalServerError)
|
rw.WriteHeader(http.StatusInternalServerError)
|
||||||
rw.WriteHeader(http.StatusNotFound)
|
rw.WriteHeader(http.StatusNotFound)
|
||||||
}))
|
})
|
||||||
|
|
||||||
|
h := mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: zstdName}, next)
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
req.Header.Set(acceptEncoding, "br")
|
req.Header.Set(acceptEncoding, "zstd")
|
||||||
|
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
h.ServeHTTP(rw, req)
|
h.ServeHTTP(rw, req)
|
||||||
|
@ -166,19 +236,49 @@ func Test_MultipleWriteHeader(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_FlushBeforeWrite(t *testing.T) {
|
func Test_FlushBeforeWrite(t *testing.T) {
|
||||||
srv := httptest.NewServer(mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
testCases := []struct {
|
||||||
|
desc string
|
||||||
|
cfg Config
|
||||||
|
readerBuilder func(io.Reader) (io.Reader, error)
|
||||||
|
acceptEncoding string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "brotli",
|
||||||
|
cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"},
|
||||||
|
readerBuilder: func(reader io.Reader) (io.Reader, error) {
|
||||||
|
return brotli.NewReader(reader), nil
|
||||||
|
},
|
||||||
|
acceptEncoding: "br",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "zstd",
|
||||||
|
cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"},
|
||||||
|
readerBuilder: func(reader io.Reader) (io.Reader, error) {
|
||||||
|
return zstd.NewReader(reader)
|
||||||
|
},
|
||||||
|
acceptEncoding: "zstd",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range testCases {
|
||||||
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
rw.WriteHeader(http.StatusOK)
|
rw.WriteHeader(http.StatusOK)
|
||||||
rw.(http.Flusher).Flush()
|
rw.(http.Flusher).Flush()
|
||||||
|
|
||||||
_, err := rw.Write(bigTestBody)
|
_, err := rw.Write(bigTestBody)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
})))
|
})
|
||||||
|
|
||||||
|
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next))
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
|
||||||
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
|
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
req.Header.Set(acceptEncoding, "br")
|
req.Header.Set(acceptEncoding, test.acceptEncoding)
|
||||||
|
|
||||||
res, err := http.DefaultClient.Do(req)
|
res, err := http.DefaultClient.Do(req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -186,15 +286,46 @@ func Test_FlushBeforeWrite(t *testing.T) {
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||||
assert.Equal(t, "br", res.Header.Get(contentEncoding))
|
assert.Equal(t, test.acceptEncoding, res.Header.Get(contentEncoding))
|
||||||
|
|
||||||
got, err := io.ReadAll(brotli.NewReader(res.Body))
|
reader, err := test.readerBuilder(res.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
got, err := io.ReadAll(reader)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, bigTestBody, got)
|
assert.Equal(t, bigTestBody, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_FlushAfterWrite(t *testing.T) {
|
func Test_FlushAfterWrite(t *testing.T) {
|
||||||
srv := httptest.NewServer(mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
testCases := []struct {
|
||||||
|
desc string
|
||||||
|
cfg Config
|
||||||
|
readerBuilder func(io.Reader) (io.Reader, error)
|
||||||
|
acceptEncoding string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "brotli",
|
||||||
|
cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"},
|
||||||
|
readerBuilder: func(reader io.Reader) (io.Reader, error) {
|
||||||
|
return brotli.NewReader(reader), nil
|
||||||
|
},
|
||||||
|
acceptEncoding: "br",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "zstd",
|
||||||
|
cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"},
|
||||||
|
readerBuilder: func(reader io.Reader) (io.Reader, error) {
|
||||||
|
return zstd.NewReader(reader)
|
||||||
|
},
|
||||||
|
acceptEncoding: "zstd",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range testCases {
|
||||||
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
|
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
rw.WriteHeader(http.StatusOK)
|
rw.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
_, err := rw.Write(bigTestBody[0:1])
|
_, err := rw.Write(bigTestBody[0:1])
|
||||||
|
@ -205,13 +336,15 @@ func Test_FlushAfterWrite(t *testing.T) {
|
||||||
_, err := rw.Write([]byte{b})
|
_, err := rw.Write([]byte{b})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
})))
|
})
|
||||||
|
|
||||||
|
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next))
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
|
||||||
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
|
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
req.Header.Set(acceptEncoding, "br")
|
req.Header.Set(acceptEncoding, test.acceptEncoding)
|
||||||
|
|
||||||
res, err := http.DefaultClient.Do(req)
|
res, err := http.DefaultClient.Do(req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -219,28 +352,61 @@ func Test_FlushAfterWrite(t *testing.T) {
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||||
assert.Equal(t, "br", res.Header.Get(contentEncoding))
|
assert.Equal(t, test.acceptEncoding, res.Header.Get(contentEncoding))
|
||||||
|
|
||||||
got, err := io.ReadAll(brotli.NewReader(res.Body))
|
reader, err := test.readerBuilder(res.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
got, err := io.ReadAll(reader)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, bigTestBody, got)
|
assert.Equal(t, bigTestBody, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_FlushAfterWriteNil(t *testing.T) {
|
func Test_FlushAfterWriteNil(t *testing.T) {
|
||||||
srv := httptest.NewServer(mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
testCases := []struct {
|
||||||
|
desc string
|
||||||
|
cfg Config
|
||||||
|
readerBuilder func(io.Reader) (io.Reader, error)
|
||||||
|
acceptEncoding string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "brotli",
|
||||||
|
cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"},
|
||||||
|
readerBuilder: func(reader io.Reader) (io.Reader, error) {
|
||||||
|
return brotli.NewReader(reader), nil
|
||||||
|
},
|
||||||
|
acceptEncoding: "br",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "zstd",
|
||||||
|
cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"},
|
||||||
|
readerBuilder: func(reader io.Reader) (io.Reader, error) {
|
||||||
|
return zstd.NewReader(reader)
|
||||||
|
},
|
||||||
|
acceptEncoding: "zstd",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range testCases {
|
||||||
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
|
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
rw.WriteHeader(http.StatusOK)
|
rw.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
_, err := rw.Write(nil)
|
_, err := rw.Write(nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
rw.(http.Flusher).Flush()
|
rw.(http.Flusher).Flush()
|
||||||
})))
|
})
|
||||||
|
|
||||||
|
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next))
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
|
||||||
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
|
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
req.Header.Set(acceptEncoding, "br")
|
req.Header.Set(acceptEncoding, test.acceptEncoding)
|
||||||
|
|
||||||
res, err := http.DefaultClient.Do(req)
|
res, err := http.DefaultClient.Do(req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -250,25 +416,58 @@ func Test_FlushAfterWriteNil(t *testing.T) {
|
||||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||||
assert.Empty(t, res.Header.Get(contentEncoding))
|
assert.Empty(t, res.Header.Get(contentEncoding))
|
||||||
|
|
||||||
got, err := io.ReadAll(brotli.NewReader(res.Body))
|
reader, err := test.readerBuilder(res.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
got, err := io.ReadAll(reader)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Empty(t, got)
|
assert.Empty(t, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_FlushAfterAllWrites(t *testing.T) {
|
func Test_FlushAfterAllWrites(t *testing.T) {
|
||||||
srv := httptest.NewServer(mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
testCases := []struct {
|
||||||
|
desc string
|
||||||
|
cfg Config
|
||||||
|
readerBuilder func(io.Reader) (io.Reader, error)
|
||||||
|
acceptEncoding string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "brotli",
|
||||||
|
cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"},
|
||||||
|
readerBuilder: func(reader io.Reader) (io.Reader, error) {
|
||||||
|
return brotli.NewReader(reader), nil
|
||||||
|
},
|
||||||
|
acceptEncoding: "br",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "zstd",
|
||||||
|
cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"},
|
||||||
|
readerBuilder: func(reader io.Reader) (io.Reader, error) {
|
||||||
|
return zstd.NewReader(reader)
|
||||||
|
},
|
||||||
|
acceptEncoding: "zstd",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range testCases {
|
||||||
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
|
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
for i := range bigTestBody {
|
for i := range bigTestBody {
|
||||||
_, err := rw.Write(bigTestBody[i : i+1])
|
_, err := rw.Write(bigTestBody[i : i+1])
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
rw.(http.Flusher).Flush()
|
rw.(http.Flusher).Flush()
|
||||||
})))
|
})
|
||||||
|
|
||||||
|
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next))
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
|
||||||
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
|
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
req.Header.Set(acceptEncoding, "br")
|
req.Header.Set(acceptEncoding, test.acceptEncoding)
|
||||||
|
|
||||||
res, err := http.DefaultClient.Do(req)
|
res, err := http.DefaultClient.Do(req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -276,11 +475,16 @@ func Test_FlushAfterAllWrites(t *testing.T) {
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||||
assert.Equal(t, "br", res.Header.Get(contentEncoding))
|
assert.Equal(t, test.acceptEncoding, res.Header.Get(contentEncoding))
|
||||||
|
|
||||||
got, err := io.ReadAll(brotli.NewReader(res.Body))
|
reader, err := test.readerBuilder(res.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
got, err := io.ReadAll(reader)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, bigTestBody, got)
|
assert.Equal(t, bigTestBody, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_ExcludedContentTypes(t *testing.T) {
|
func Test_ExcludedContentTypes(t *testing.T) {
|
||||||
|
@ -352,18 +556,22 @@ func Test_ExcludedContentTypes(t *testing.T) {
|
||||||
cfg := Config{
|
cfg := Config{
|
||||||
MinSize: 1024,
|
MinSize: 1024,
|
||||||
ExcludedContentTypes: test.excludedContentTypes,
|
ExcludedContentTypes: test.excludedContentTypes,
|
||||||
|
Algorithm: zstdName,
|
||||||
}
|
}
|
||||||
h := mustNewWrapper(t, cfg)(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
|
||||||
|
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
rw.Header().Set(contentType, test.contentType)
|
rw.Header().Set(contentType, test.contentType)
|
||||||
|
|
||||||
rw.WriteHeader(http.StatusAccepted)
|
rw.WriteHeader(http.StatusAccepted)
|
||||||
|
|
||||||
_, err := rw.Write(bigTestBody)
|
_, err := rw.Write(bigTestBody)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}))
|
})
|
||||||
|
|
||||||
|
h := mustNewCompressionHandler(t, cfg, next)
|
||||||
|
|
||||||
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
|
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
|
||||||
req.Header.Set(acceptEncoding, "br")
|
req.Header.Set(acceptEncoding, zstdName)
|
||||||
|
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
h.ServeHTTP(rw, req)
|
h.ServeHTTP(rw, req)
|
||||||
|
@ -371,13 +579,16 @@ func Test_ExcludedContentTypes(t *testing.T) {
|
||||||
assert.Equal(t, http.StatusAccepted, rw.Code)
|
assert.Equal(t, http.StatusAccepted, rw.Code)
|
||||||
|
|
||||||
if test.expCompression {
|
if test.expCompression {
|
||||||
assert.Equal(t, "br", rw.Header().Get(contentEncoding))
|
assert.Equal(t, zstdName, rw.Header().Get(contentEncoding))
|
||||||
|
|
||||||
got, err := io.ReadAll(brotli.NewReader(rw.Body))
|
reader, err := zstd.NewReader(rw.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
got, err := io.ReadAll(reader)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, bigTestBody, got)
|
assert.Equal(t, bigTestBody, got)
|
||||||
} else {
|
} else {
|
||||||
assert.NotEqual(t, "br", rw.Header().Get("Content-Encoding"))
|
assert.NotEqual(t, zstdName, rw.Header().Get("Content-Encoding"))
|
||||||
|
|
||||||
got, err := io.ReadAll(rw.Body)
|
got, err := io.ReadAll(rw.Body)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
@ -456,18 +667,22 @@ func Test_IncludedContentTypes(t *testing.T) {
|
||||||
cfg := Config{
|
cfg := Config{
|
||||||
MinSize: 1024,
|
MinSize: 1024,
|
||||||
IncludedContentTypes: test.includedContentTypes,
|
IncludedContentTypes: test.includedContentTypes,
|
||||||
|
Algorithm: zstdName,
|
||||||
}
|
}
|
||||||
h := mustNewWrapper(t, cfg)(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
|
||||||
|
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
rw.Header().Set(contentType, test.contentType)
|
rw.Header().Set(contentType, test.contentType)
|
||||||
|
|
||||||
rw.WriteHeader(http.StatusAccepted)
|
rw.WriteHeader(http.StatusAccepted)
|
||||||
|
|
||||||
_, err := rw.Write(bigTestBody)
|
_, err := rw.Write(bigTestBody)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}))
|
})
|
||||||
|
|
||||||
|
h := mustNewCompressionHandler(t, cfg, next)
|
||||||
|
|
||||||
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
|
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
|
||||||
req.Header.Set(acceptEncoding, "br")
|
req.Header.Set(acceptEncoding, zstdName)
|
||||||
|
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
h.ServeHTTP(rw, req)
|
h.ServeHTTP(rw, req)
|
||||||
|
@ -475,13 +690,16 @@ func Test_IncludedContentTypes(t *testing.T) {
|
||||||
assert.Equal(t, http.StatusAccepted, rw.Code)
|
assert.Equal(t, http.StatusAccepted, rw.Code)
|
||||||
|
|
||||||
if test.expCompression {
|
if test.expCompression {
|
||||||
assert.Equal(t, "br", rw.Header().Get(contentEncoding))
|
assert.Equal(t, zstdName, rw.Header().Get(contentEncoding))
|
||||||
|
|
||||||
got, err := io.ReadAll(brotli.NewReader(rw.Body))
|
reader, err := zstd.NewReader(rw.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
got, err := io.ReadAll(reader)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, bigTestBody, got)
|
assert.Equal(t, bigTestBody, got)
|
||||||
} else {
|
} else {
|
||||||
assert.NotEqual(t, "br", rw.Header().Get("Content-Encoding"))
|
assert.NotEqual(t, zstdName, rw.Header().Get("Content-Encoding"))
|
||||||
|
|
||||||
got, err := io.ReadAll(rw.Body)
|
got, err := io.ReadAll(rw.Body)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
@ -560,8 +778,10 @@ func Test_FlushExcludedContentTypes(t *testing.T) {
|
||||||
cfg := Config{
|
cfg := Config{
|
||||||
MinSize: 1024,
|
MinSize: 1024,
|
||||||
ExcludedContentTypes: test.excludedContentTypes,
|
ExcludedContentTypes: test.excludedContentTypes,
|
||||||
|
Algorithm: zstdName,
|
||||||
}
|
}
|
||||||
h := mustNewWrapper(t, cfg)(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
|
||||||
|
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
rw.Header().Set(contentType, test.contentType)
|
rw.Header().Set(contentType, test.contentType)
|
||||||
rw.WriteHeader(http.StatusOK)
|
rw.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
|
@ -581,10 +801,12 @@ func Test_FlushExcludedContentTypes(t *testing.T) {
|
||||||
rw.(http.Flusher).Flush()
|
rw.(http.Flusher).Flush()
|
||||||
tb = tb[toWrite:]
|
tb = tb[toWrite:]
|
||||||
}
|
}
|
||||||
}))
|
})
|
||||||
|
|
||||||
|
h := mustNewCompressionHandler(t, cfg, next)
|
||||||
|
|
||||||
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
|
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
|
||||||
req.Header.Set(acceptEncoding, "br")
|
req.Header.Set(acceptEncoding, zstdName)
|
||||||
|
|
||||||
// This doesn't allow checking flushes, but we validate if content is correct.
|
// This doesn't allow checking flushes, but we validate if content is correct.
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
|
@ -593,13 +815,16 @@ func Test_FlushExcludedContentTypes(t *testing.T) {
|
||||||
assert.Equal(t, http.StatusOK, rw.Code)
|
assert.Equal(t, http.StatusOK, rw.Code)
|
||||||
|
|
||||||
if test.expCompression {
|
if test.expCompression {
|
||||||
assert.Equal(t, "br", rw.Header().Get(contentEncoding))
|
assert.Equal(t, zstdName, rw.Header().Get(contentEncoding))
|
||||||
|
|
||||||
got, err := io.ReadAll(brotli.NewReader(rw.Body))
|
reader, err := zstd.NewReader(rw.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
got, err := io.ReadAll(reader)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, bigTestBody, got)
|
assert.Equal(t, bigTestBody, got)
|
||||||
} else {
|
} else {
|
||||||
assert.NotEqual(t, "br", rw.Header().Get(contentEncoding))
|
assert.NotEqual(t, zstdName, rw.Header().Get(contentEncoding))
|
||||||
|
|
||||||
got, err := io.ReadAll(rw.Body)
|
got, err := io.ReadAll(rw.Body)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
@ -678,8 +903,10 @@ func Test_FlushIncludedContentTypes(t *testing.T) {
|
||||||
cfg := Config{
|
cfg := Config{
|
||||||
MinSize: 1024,
|
MinSize: 1024,
|
||||||
IncludedContentTypes: test.includedContentTypes,
|
IncludedContentTypes: test.includedContentTypes,
|
||||||
|
Algorithm: zstdName,
|
||||||
}
|
}
|
||||||
h := mustNewWrapper(t, cfg)(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
|
||||||
|
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
rw.Header().Set(contentType, test.contentType)
|
rw.Header().Set(contentType, test.contentType)
|
||||||
rw.WriteHeader(http.StatusOK)
|
rw.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
|
@ -699,10 +926,12 @@ func Test_FlushIncludedContentTypes(t *testing.T) {
|
||||||
rw.(http.Flusher).Flush()
|
rw.(http.Flusher).Flush()
|
||||||
tb = tb[toWrite:]
|
tb = tb[toWrite:]
|
||||||
}
|
}
|
||||||
}))
|
})
|
||||||
|
|
||||||
|
h := mustNewCompressionHandler(t, cfg, next)
|
||||||
|
|
||||||
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
|
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
|
||||||
req.Header.Set(acceptEncoding, "br")
|
req.Header.Set(acceptEncoding, zstdName)
|
||||||
|
|
||||||
// This doesn't allow checking flushes, but we validate if content is correct.
|
// This doesn't allow checking flushes, but we validate if content is correct.
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
|
@ -711,13 +940,16 @@ func Test_FlushIncludedContentTypes(t *testing.T) {
|
||||||
assert.Equal(t, http.StatusOK, rw.Code)
|
assert.Equal(t, http.StatusOK, rw.Code)
|
||||||
|
|
||||||
if test.expCompression {
|
if test.expCompression {
|
||||||
assert.Equal(t, "br", rw.Header().Get(contentEncoding))
|
assert.Equal(t, zstdName, rw.Header().Get(contentEncoding))
|
||||||
|
|
||||||
got, err := io.ReadAll(brotli.NewReader(rw.Body))
|
reader, err := zstd.NewReader(rw.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
got, err := io.ReadAll(reader)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, bigTestBody, got)
|
assert.Equal(t, bigTestBody, got)
|
||||||
} else {
|
} else {
|
||||||
assert.NotEqual(t, "br", rw.Header().Get(contentEncoding))
|
assert.NotEqual(t, zstdName, rw.Header().Get(contentEncoding))
|
||||||
|
|
||||||
got, err := io.ReadAll(rw.Body)
|
got, err := io.ReadAll(rw.Body)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
@ -727,32 +959,48 @@ func Test_FlushIncludedContentTypes(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func mustNewWrapper(t *testing.T, cfg Config) func(http.Handler) http.HandlerFunc {
|
func mustNewCompressionHandler(t *testing.T, cfg Config, next http.Handler) http.Handler {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
w, err := NewWrapper(cfg)
|
w, err := NewCompressionHandler(cfg, next)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
return w
|
return w
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestHandler(t *testing.T, body []byte) http.Handler {
|
func newTestBrotliHandler(t *testing.T, body []byte) http.Handler {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
return mustNewWrapper(t, Config{MinSize: 1024})(
|
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
|
||||||
if req.URL.Path == "/compressed" {
|
if req.URL.Path == "/compressed" {
|
||||||
rw.Header().Set("Content-Encoding", "br")
|
rw.Header().Set("Content-Encoding", brotliName)
|
||||||
}
|
}
|
||||||
|
|
||||||
rw.WriteHeader(http.StatusAccepted)
|
rw.WriteHeader(http.StatusAccepted)
|
||||||
_, err := rw.Write(body)
|
_, err := rw.Write(body)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}),
|
})
|
||||||
)
|
|
||||||
|
return mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Compress"}, next)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseContentType_equals(t *testing.T) {
|
func newTestZstandardHandler(t *testing.T, body []byte) http.Handler {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
if req.URL.Path == "/compressed" {
|
||||||
|
rw.Header().Set("Content-Encoding", zstdName)
|
||||||
|
}
|
||||||
|
|
||||||
|
rw.WriteHeader(http.StatusAccepted)
|
||||||
|
_, err := rw.Write(body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
return mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Compress"}, next)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_ParseContentType_equals(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
desc string
|
desc string
|
||||||
pct parsedContentType
|
pct parsedContentType
|
Loading…
Reference in a new issue