traefik/pkg/middlewares/compress/compression_handler_test.go
Kevin Pollet ef168b801c
Refactor compress handler to make it generic
Co-authored-by: Romain <rtribotte@users.noreply.github.com>
2024-10-10 16:04:04 +02:00

1106 lines
29 KiB
Go

package compress
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/andybalholm/brotli"
"github.com/klauspost/compress/zstd"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// smallTestBody is a 94-byte payload, i.e. well under the 1024-byte MinSize
// the test handlers in this file are configured with.
var smallTestBody = []byte("aaabbc" + strings.Repeat("aaabbbccc", 9) + "aaabbbc")

// bigTestBody is a 4164-byte payload made of seven runs of 66 "aaabbbccc"
// groups separated by single spaces, i.e. well above the 1024-byte MinSize.
var bigTestBody = []byte(strings.TrimSuffix(strings.Repeat(strings.Repeat("aaabbbccc", 66)+" ", 7), " "))
// Test_Vary checks, for both the brotli and the zstd test handlers, that the
// response keeps the status code written by the wrapped handler and that the
// Vary response header carries the acceptEncoding header name.
func Test_Vary(t *testing.T) {
	testCases := []struct {
		desc           string
		h              http.Handler
		acceptEncoding string
	}{
		{
			desc:           "brotli",
			h:              newTestBrotliHandler(t, smallTestBody),
			acceptEncoding: "br",
		},
		{
			desc:           "zstd",
			h:              newTestZstandardHandler(t, smallTestBody),
			acceptEncoding: "zstd",
		},
	}

	for _, test := range testCases {
		t.Run(test.desc, func(t *testing.T) {
			t.Parallel()

			// Fail fast on a malformed request instead of dereferencing a nil one.
			req, err := http.NewRequest(http.MethodGet, "/whatever", nil)
			require.NoError(t, err)
			req.Header.Set(acceptEncoding, test.acceptEncoding)

			rw := httptest.NewRecorder()
			test.h.ServeHTTP(rw, req)

			assert.Equal(t, http.StatusAccepted, rw.Code)
			assert.Equal(t, acceptEncoding, rw.Header().Get(vary))
		})
	}
}
// Test_SmallBodyNoCompression checks that a body smaller than the handlers'
// MinSize is passed through untouched: no Content-Encoding header is set and
// the recorded body equals smallTestBody.
func Test_SmallBodyNoCompression(t *testing.T) {
	testCases := []struct {
		desc           string
		h              http.Handler
		acceptEncoding string
	}{
		{
			desc:           "brotli",
			h:              newTestBrotliHandler(t, smallTestBody),
			acceptEncoding: "br",
		},
		{
			desc:           "zstd",
			h:              newTestZstandardHandler(t, smallTestBody),
			acceptEncoding: "zstd",
		},
	}

	for _, test := range testCases {
		t.Run(test.desc, func(t *testing.T) {
			t.Parallel()

			// Fail fast on a malformed request instead of dereferencing a nil one.
			req, err := http.NewRequest(http.MethodGet, "/whatever", nil)
			require.NoError(t, err)
			req.Header.Set(acceptEncoding, test.acceptEncoding)

			rw := httptest.NewRecorder()
			test.h.ServeHTTP(rw, req)

			// With less than 1024 bytes the response should not be compressed.
			assert.Equal(t, http.StatusAccepted, rw.Code)
			assert.Empty(t, rw.Header().Get(contentEncoding))
			assert.Equal(t, smallTestBody, rw.Body.Bytes())
		})
	}
}
// Test_AlreadyCompressed requests the "/compressed" path, for which the test
// backends set a Content-Encoding header themselves before writing the body,
// and checks that the body reaches the client unchanged.
func Test_AlreadyCompressed(t *testing.T) {
	testCases := []struct {
		desc           string
		h              http.Handler
		acceptEncoding string
	}{
		{
			desc:           "brotli",
			h:              newTestBrotliHandler(t, bigTestBody),
			acceptEncoding: "br",
		},
		{
			desc:           "zstd",
			h:              newTestZstandardHandler(t, bigTestBody),
			acceptEncoding: "zstd",
		},
	}

	for _, test := range testCases {
		t.Run(test.desc, func(t *testing.T) {
			t.Parallel()

			// Fail fast on a malformed request instead of dereferencing a nil one.
			req, err := http.NewRequest(http.MethodGet, "/compressed", nil)
			require.NoError(t, err)
			req.Header.Set(acceptEncoding, test.acceptEncoding)

			rw := httptest.NewRecorder()
			test.h.ServeHTTP(rw, req)

			assert.Equal(t, http.StatusAccepted, rw.Code)
			assert.Equal(t, bigTestBody, rw.Body.Bytes())
		})
	}
}
// Test_NoBody drives a zstd compression handler with backends that write a
// status code followed by a nil or empty body, and expects the response to
// carry neither a Content-Encoding header nor any payload.
func Test_NoBody(t *testing.T) {
	type noBodyCase struct {
		desc       string
		statusCode int
		body       []byte
	}

	cases := []noBodyCase{
		{desc: "status no content", statusCode: http.StatusNoContent, body: nil},
		{desc: "status not modified", statusCode: http.StatusNotModified, body: nil},
		{desc: "status OK with empty body", statusCode: http.StatusOK, body: []byte{}},
		{desc: "status OK with nil body", statusCode: http.StatusOK, body: nil},
	}

	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			t.Parallel()

			backend := func(w http.ResponseWriter, _ *http.Request) {
				w.WriteHeader(tc.statusCode)

				_, err := w.Write(tc.body)
				require.NoError(t, err)
			}
			handler := mustNewCompressionHandler(t, Config{MinSize: 1024}, zstdName, http.HandlerFunc(backend))

			request := httptest.NewRequest(http.MethodGet, "/", nil)
			request.Header.Set(acceptEncoding, "zstd")

			recorder := httptest.NewRecorder()
			handler.ServeHTTP(recorder, request)

			payload, err := io.ReadAll(recorder.Body)
			require.NoError(t, err)

			assert.Empty(t, recorder.Header().Get(contentEncoding))
			assert.Empty(t, payload)
		})
	}
}
// Test_MinSize checks the MinSize threshold of a zstd compression handler:
// a body of MinSize-1 bytes must not get a Content-Encoding header, while a
// body of exactly MinSize bytes must be served with "zstd".
//
// The backend reads bodySize through its closure, so the same handler and
// request are reused for both sizes.
func Test_MinSize(t *testing.T) {
	cfg := Config{
		MinSize: 128,
	}

	var bodySize int
	next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		for range bodySize {
			// We make sure to Write at least once less than minSize so that both
			// cases below go through the same algo: i.e. they start buffering
			// because they haven't reached minSize.
			_, err := rw.Write([]byte{'x'})
			require.NoError(t, err)
		}
	})

	h := mustNewCompressionHandler(t, cfg, zstdName, next)

	// Fail fast on a malformed request instead of dereferencing a nil one.
	req, err := http.NewRequest(http.MethodGet, "/whatever", &bytes.Buffer{})
	require.NoError(t, err)
	req.Header.Add(acceptEncoding, "zstd")

	// Short response is not compressed
	bodySize = cfg.MinSize - 1
	rw := httptest.NewRecorder()
	h.ServeHTTP(rw, req)
	assert.Empty(t, rw.Result().Header.Get(contentEncoding))

	// Long response is compressed
	bodySize = cfg.MinSize
	rw = httptest.NewRecorder()
	h.ServeHTTP(rw, req)
	assert.Equal(t, "zstd", rw.Result().Header.Get(contentEncoding))
}
// Test_MultipleWriteHeader verifies that when the backend calls WriteHeader
// twice, the recorded response carries the first status code.
func Test_MultipleWriteHeader(t *testing.T) {
	backend := func(w http.ResponseWriter, _ *http.Request) {
		// Only the first status code may stick; the second call must be a noop.
		w.WriteHeader(http.StatusInternalServerError)
		w.WriteHeader(http.StatusNotFound)
	}
	handler := mustNewCompressionHandler(t, Config{MinSize: 1024}, zstdName, http.HandlerFunc(backend))

	request := httptest.NewRequest(http.MethodGet, "/", nil)
	request.Header.Set(acceptEncoding, "zstd")

	recorder := httptest.NewRecorder()
	handler.ServeHTTP(recorder, request)

	assert.Equal(t, http.StatusInternalServerError, recorder.Code)
}
// Test_FlushBeforeWrite runs a real HTTP server whose backend flushes right
// after WriteHeader and only then writes bigTestBody. For brotli and zstd it
// expects a 200 response advertising the requested encoding whose decoded
// body equals bigTestBody.
func Test_FlushBeforeWrite(t *testing.T) {
	type flushCase struct {
		desc           string
		cfg            Config
		algo           string
		readerBuilder  func(io.Reader) (io.Reader, error)
		acceptEncoding string
	}

	decodeBrotli := func(src io.Reader) (io.Reader, error) {
		return brotli.NewReader(src), nil
	}
	decodeZstd := func(src io.Reader) (io.Reader, error) {
		return zstd.NewReader(src)
	}

	cases := []flushCase{
		{
			desc:           "brotli",
			cfg:            Config{MinSize: 1024, MiddlewareName: "Test"},
			algo:           brotliName,
			readerBuilder:  decodeBrotli,
			acceptEncoding: "br",
		},
		{
			desc:           "zstd",
			cfg:            Config{MinSize: 1024, MiddlewareName: "Test"},
			algo:           zstdName,
			readerBuilder:  decodeZstd,
			acceptEncoding: "zstd",
		},
	}

	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			t.Parallel()

			backend := func(w http.ResponseWriter, _ *http.Request) {
				w.WriteHeader(http.StatusOK)
				w.(http.Flusher).Flush()

				_, err := w.Write(bigTestBody)
				require.NoError(t, err)
			}

			server := httptest.NewServer(mustNewCompressionHandler(t, tc.cfg, tc.algo, http.HandlerFunc(backend)))
			defer server.Close()

			request, err := http.NewRequest(http.MethodGet, server.URL, http.NoBody)
			require.NoError(t, err)
			request.Header.Set(acceptEncoding, tc.acceptEncoding)

			response, err := http.DefaultClient.Do(request)
			require.NoError(t, err)
			defer response.Body.Close()

			assert.Equal(t, http.StatusOK, response.StatusCode)
			assert.Equal(t, tc.acceptEncoding, response.Header.Get(contentEncoding))

			decoder, err := tc.readerBuilder(response.Body)
			require.NoError(t, err)

			decoded, err := io.ReadAll(decoder)
			require.NoError(t, err)
			assert.Equal(t, bigTestBody, decoded)
		})
	}
}
// Test_FlushAfterWrite runs a real HTTP server whose backend writes the first
// byte of bigTestBody, flushes, and then writes the remaining bytes one at a
// time. For brotli and zstd it expects a 200 response advertising the
// requested encoding whose decoded body equals bigTestBody.
func Test_FlushAfterWrite(t *testing.T) {
	type flushCase struct {
		desc           string
		cfg            Config
		algo           string
		readerBuilder  func(io.Reader) (io.Reader, error)
		acceptEncoding string
	}

	decodeBrotli := func(src io.Reader) (io.Reader, error) {
		return brotli.NewReader(src), nil
	}
	decodeZstd := func(src io.Reader) (io.Reader, error) {
		return zstd.NewReader(src)
	}

	cases := []flushCase{
		{
			desc:           "brotli",
			cfg:            Config{MinSize: 1024, MiddlewareName: "Test"},
			algo:           brotliName,
			readerBuilder:  decodeBrotli,
			acceptEncoding: "br",
		},
		{
			desc:           "zstd",
			cfg:            Config{MinSize: 1024, MiddlewareName: "Test"},
			algo:           zstdName,
			readerBuilder:  decodeZstd,
			acceptEncoding: "zstd",
		},
	}

	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			backend := func(w http.ResponseWriter, _ *http.Request) {
				w.WriteHeader(http.StatusOK)

				_, err := w.Write(bigTestBody[0:1])
				require.NoError(t, err)

				w.(http.Flusher).Flush()

				// Trickle the rest of the payload one byte per Write call.
				for i := 1; i < len(bigTestBody); i++ {
					_, err := w.Write([]byte{bigTestBody[i]})
					require.NoError(t, err)
				}
			}

			server := httptest.NewServer(mustNewCompressionHandler(t, tc.cfg, tc.algo, http.HandlerFunc(backend)))
			defer server.Close()

			request, err := http.NewRequest(http.MethodGet, server.URL, http.NoBody)
			require.NoError(t, err)
			request.Header.Set(acceptEncoding, tc.acceptEncoding)

			response, err := http.DefaultClient.Do(request)
			require.NoError(t, err)
			defer response.Body.Close()

			assert.Equal(t, http.StatusOK, response.StatusCode)
			assert.Equal(t, tc.acceptEncoding, response.Header.Get(contentEncoding))

			decoder, err := tc.readerBuilder(response.Body)
			require.NoError(t, err)

			decoded, err := io.ReadAll(decoder)
			require.NoError(t, err)
			assert.Equal(t, bigTestBody, decoded)
		})
	}
}
// Test_FlushAfterWriteNil runs a real HTTP server whose backend writes a nil
// slice and then flushes. For brotli and zstd it expects a 200 response with
// no Content-Encoding header and an empty body.
func Test_FlushAfterWriteNil(t *testing.T) {
	type flushCase struct {
		desc           string
		cfg            Config
		algo           string
		readerBuilder  func(io.Reader) (io.Reader, error)
		acceptEncoding string
	}

	decodeBrotli := func(src io.Reader) (io.Reader, error) {
		return brotli.NewReader(src), nil
	}
	decodeZstd := func(src io.Reader) (io.Reader, error) {
		return zstd.NewReader(src)
	}

	cases := []flushCase{
		{
			desc:           "brotli",
			cfg:            Config{MinSize: 1024, MiddlewareName: "Test"},
			algo:           brotliName,
			readerBuilder:  decodeBrotli,
			acceptEncoding: "br",
		},
		{
			desc:           "zstd",
			cfg:            Config{MinSize: 1024, MiddlewareName: "Test"},
			algo:           zstdName,
			readerBuilder:  decodeZstd,
			acceptEncoding: "zstd",
		},
	}

	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			backend := func(w http.ResponseWriter, _ *http.Request) {
				w.WriteHeader(http.StatusOK)

				_, err := w.Write(nil)
				require.NoError(t, err)

				w.(http.Flusher).Flush()
			}

			server := httptest.NewServer(mustNewCompressionHandler(t, tc.cfg, tc.algo, http.HandlerFunc(backend)))
			defer server.Close()

			request, err := http.NewRequest(http.MethodGet, server.URL, http.NoBody)
			require.NoError(t, err)
			request.Header.Set(acceptEncoding, tc.acceptEncoding)

			response, err := http.DefaultClient.Do(request)
			require.NoError(t, err)
			defer response.Body.Close()

			assert.Equal(t, http.StatusOK, response.StatusCode)
			assert.Empty(t, response.Header.Get(contentEncoding))

			// The body is still fed through the algorithm's decoder; the test
			// expects that to yield no bytes and no error.
			decoder, err := tc.readerBuilder(response.Body)
			require.NoError(t, err)

			decoded, err := io.ReadAll(decoder)
			require.NoError(t, err)
			assert.Empty(t, decoded)
		})
	}
}
// Test_FlushAfterAllWrites runs a real HTTP server whose backend writes
// bigTestBody one byte per Write call and flushes once at the end. For brotli
// and zstd it expects a 200 response advertising the requested encoding whose
// decoded body equals bigTestBody.
func Test_FlushAfterAllWrites(t *testing.T) {
	type flushCase struct {
		desc           string
		cfg            Config
		algo           string
		readerBuilder  func(io.Reader) (io.Reader, error)
		acceptEncoding string
	}

	decodeBrotli := func(src io.Reader) (io.Reader, error) {
		return brotli.NewReader(src), nil
	}
	decodeZstd := func(src io.Reader) (io.Reader, error) {
		return zstd.NewReader(src)
	}

	cases := []flushCase{
		{
			desc:           "brotli",
			cfg:            Config{MinSize: 1024, MiddlewareName: "Test"},
			algo:           brotliName,
			readerBuilder:  decodeBrotli,
			acceptEncoding: "br",
		},
		{
			desc:           "zstd",
			cfg:            Config{MinSize: 1024, MiddlewareName: "Test"},
			algo:           zstdName,
			readerBuilder:  decodeZstd,
			acceptEncoding: "zstd",
		},
	}

	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			backend := func(w http.ResponseWriter, _ *http.Request) {
				for i := 0; i < len(bigTestBody); i++ {
					chunk := bigTestBody[i : i+1]

					_, err := w.Write(chunk)
					require.NoError(t, err)
				}

				w.(http.Flusher).Flush()
			}

			server := httptest.NewServer(mustNewCompressionHandler(t, tc.cfg, tc.algo, http.HandlerFunc(backend)))
			defer server.Close()

			request, err := http.NewRequest(http.MethodGet, server.URL, http.NoBody)
			require.NoError(t, err)
			request.Header.Set(acceptEncoding, tc.acceptEncoding)

			response, err := http.DefaultClient.Do(request)
			require.NoError(t, err)
			defer response.Body.Close()

			assert.Equal(t, http.StatusOK, response.StatusCode)
			assert.Equal(t, tc.acceptEncoding, response.Header.Get(contentEncoding))

			decoder, err := tc.readerBuilder(response.Body)
			require.NoError(t, err)

			decoded, err := io.ReadAll(decoder)
			require.NoError(t, err)
			assert.Equal(t, bigTestBody, decoded)
		})
	}
}
// Test_ExcludedContentTypes checks how Config.ExcludedContentTypes affects a
// zstd compression handler. Each case sets a response Content-Type, writes
// bigTestBody in one call, and states whether the response is expected to be
// zstd-encoded. In both outcomes the (decoded) body must equal bigTestBody.
func Test_ExcludedContentTypes(t *testing.T) {
	testCases := []struct {
		desc                 string
		contentType          string
		excludedContentTypes []string
		expCompression       bool
	}{
		{
			desc:           "Always compress when content types are empty",
			contentType:    "",
			expCompression: true,
		},
		{
			desc:                 "MIME match",
			contentType:          "application/json",
			excludedContentTypes: []string{"application/json"},
			expCompression:       false,
		},
		{
			desc:                 "MIME no match",
			contentType:          "text/xml",
			excludedContentTypes: []string{"application/json"},
			expCompression:       true,
		},
		{
			desc:                 "MIME match with no other directive ignores non-MIME directives",
			contentType:          "application/json; charset=utf-8",
			excludedContentTypes: []string{"application/json"},
			expCompression:       false,
		},
		{
			desc:                 "MIME match with other directives requires all directives be equal, different charset",
			contentType:          "application/json; charset=ascii",
			excludedContentTypes: []string{"application/json; charset=utf-8"},
			expCompression:       true,
		},
		{
			desc:                 "MIME match with other directives requires all directives be equal, same charset",
			contentType:          "application/json; charset=utf-8",
			excludedContentTypes: []string{"application/json; charset=utf-8"},
			expCompression:       false,
		},
		{
			desc:                 "MIME match with other directives requires all directives be equal, missing charset",
			contentType:          "application/json",
			excludedContentTypes: []string{"application/json; charset=ascii"},
			expCompression:       true,
		},
		{
			desc:                 "MIME match case insensitive",
			contentType:          "Application/Json",
			excludedContentTypes: []string{"application/json"},
			expCompression:       false,
		},
		{
			desc:                 "MIME match ignore whitespace",
			contentType:          "application/json;charset=utf-8",
			excludedContentTypes: []string{"application/json; charset=utf-8"},
			expCompression:       false,
		},
	}

	for _, test := range testCases {
		t.Run(test.desc, func(t *testing.T) {
			t.Parallel()

			cfg := Config{
				MinSize:              1024,
				ExcludedContentTypes: test.excludedContentTypes,
			}

			next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
				rw.Header().Set(contentType, test.contentType)

				rw.WriteHeader(http.StatusAccepted)

				_, err := rw.Write(bigTestBody)
				require.NoError(t, err)
			})

			h := mustNewCompressionHandler(t, cfg, zstdName, next)

			// Fail fast on a malformed request instead of dereferencing a nil one.
			req, err := http.NewRequest(http.MethodGet, "/whatever", nil)
			require.NoError(t, err)
			req.Header.Set(acceptEncoding, zstdName)

			rw := httptest.NewRecorder()
			h.ServeHTTP(rw, req)

			assert.Equal(t, http.StatusAccepted, rw.Code)

			if test.expCompression {
				assert.Equal(t, zstdName, rw.Header().Get(contentEncoding))

				reader, err := zstd.NewReader(rw.Body)
				require.NoError(t, err)

				got, err := io.ReadAll(reader)
				assert.NoError(t, err)
				assert.Equal(t, bigTestBody, got)
			} else {
				assert.NotEqual(t, zstdName, rw.Header().Get(contentEncoding))

				got, err := io.ReadAll(rw.Body)
				assert.NoError(t, err)
				assert.Equal(t, bigTestBody, got)
			}
		})
	}
}
// Test_IncludedContentTypes checks how Config.IncludedContentTypes affects a
// zstd compression handler. Each case sets a response Content-Type, writes
// bigTestBody in one call, and states whether the response is expected to be
// zstd-encoded. In both outcomes the (decoded) body must equal bigTestBody.
func Test_IncludedContentTypes(t *testing.T) {
	testCases := []struct {
		desc                 string
		contentType          string
		includedContentTypes []string
		expCompression       bool
	}{
		{
			desc:           "Always compress when content types are empty",
			contentType:    "",
			expCompression: true,
		},
		{
			desc:                 "MIME match",
			contentType:          "application/json",
			includedContentTypes: []string{"application/json"},
			expCompression:       true,
		},
		{
			desc:                 "MIME no match",
			contentType:          "text/xml",
			includedContentTypes: []string{"application/json"},
			expCompression:       false,
		},
		{
			desc:                 "MIME match with no other directive ignores non-MIME directives",
			contentType:          "application/json; charset=utf-8",
			includedContentTypes: []string{"application/json"},
			expCompression:       true,
		},
		{
			desc:                 "MIME match with other directives requires all directives be equal, different charset",
			contentType:          "application/json; charset=ascii",
			includedContentTypes: []string{"application/json; charset=utf-8"},
			expCompression:       false,
		},
		{
			desc:                 "MIME match with other directives requires all directives be equal, same charset",
			contentType:          "application/json; charset=utf-8",
			includedContentTypes: []string{"application/json; charset=utf-8"},
			expCompression:       true,
		},
		{
			desc:                 "MIME match with other directives requires all directives be equal, missing charset",
			contentType:          "application/json",
			includedContentTypes: []string{"application/json; charset=ascii"},
			expCompression:       false,
		},
		{
			desc:                 "MIME match case insensitive",
			contentType:          "Application/Json",
			includedContentTypes: []string{"application/json"},
			expCompression:       true,
		},
		{
			desc:                 "MIME match ignore whitespace",
			contentType:          "application/json;charset=utf-8",
			includedContentTypes: []string{"application/json; charset=utf-8"},
			expCompression:       true,
		},
	}

	for _, test := range testCases {
		t.Run(test.desc, func(t *testing.T) {
			t.Parallel()

			cfg := Config{
				MinSize:              1024,
				IncludedContentTypes: test.includedContentTypes,
			}

			next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
				rw.Header().Set(contentType, test.contentType)

				rw.WriteHeader(http.StatusAccepted)

				_, err := rw.Write(bigTestBody)
				require.NoError(t, err)
			})

			h := mustNewCompressionHandler(t, cfg, zstdName, next)

			// Fail fast on a malformed request instead of dereferencing a nil one.
			req, err := http.NewRequest(http.MethodGet, "/whatever", nil)
			require.NoError(t, err)
			req.Header.Set(acceptEncoding, zstdName)

			rw := httptest.NewRecorder()
			h.ServeHTTP(rw, req)

			assert.Equal(t, http.StatusAccepted, rw.Code)

			if test.expCompression {
				assert.Equal(t, zstdName, rw.Header().Get(contentEncoding))

				reader, err := zstd.NewReader(rw.Body)
				require.NoError(t, err)

				got, err := io.ReadAll(reader)
				assert.NoError(t, err)
				assert.Equal(t, bigTestBody, got)
			} else {
				assert.NotEqual(t, zstdName, rw.Header().Get(contentEncoding))

				got, err := io.ReadAll(rw.Body)
				assert.NoError(t, err)
				assert.Equal(t, bigTestBody, got)
			}
		})
	}
}
// Test_FlushExcludedContentTypes is the flushing counterpart of the
// ExcludedContentTypes test: the backend writes bigTestBody in chunks of at
// most 100 bytes and flushes after every chunk. Each case states whether the
// response is expected to be zstd-encoded; in both outcomes the (decoded)
// body must equal bigTestBody.
func Test_FlushExcludedContentTypes(t *testing.T) {
	testCases := []struct {
		desc                 string
		contentType          string
		excludedContentTypes []string
		expCompression       bool
	}{
		{
			desc:           "Always compress when content types are empty",
			contentType:    "",
			expCompression: true,
		},
		{
			desc:                 "MIME match",
			contentType:          "application/json",
			excludedContentTypes: []string{"application/json"},
			expCompression:       false,
		},
		{
			desc:                 "MIME no match",
			contentType:          "text/xml",
			excludedContentTypes: []string{"application/json"},
			expCompression:       true,
		},
		{
			desc:                 "MIME match with no other directive ignores non-MIME directives",
			contentType:          "application/json; charset=utf-8",
			excludedContentTypes: []string{"application/json"},
			expCompression:       false,
		},
		{
			desc:                 "MIME match with other directives requires all directives be equal, different charset",
			contentType:          "application/json; charset=ascii",
			excludedContentTypes: []string{"application/json; charset=utf-8"},
			expCompression:       true,
		},
		{
			desc:                 "MIME match with other directives requires all directives be equal, same charset",
			contentType:          "application/json; charset=utf-8",
			excludedContentTypes: []string{"application/json; charset=utf-8"},
			expCompression:       false,
		},
		{
			desc:                 "MIME match with other directives requires all directives be equal, missing charset",
			contentType:          "application/json",
			excludedContentTypes: []string{"application/json; charset=ascii"},
			expCompression:       true,
		},
		{
			desc:                 "MIME match case insensitive",
			contentType:          "Application/Json",
			excludedContentTypes: []string{"application/json"},
			expCompression:       false,
		},
		{
			desc:                 "MIME match ignore whitespace",
			contentType:          "application/json;charset=utf-8",
			excludedContentTypes: []string{"application/json; charset=utf-8"},
			expCompression:       false,
		},
	}

	for _, test := range testCases {
		t.Run(test.desc, func(t *testing.T) {
			t.Parallel()

			cfg := Config{
				MinSize:              1024,
				ExcludedContentTypes: test.excludedContentTypes,
			}

			next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
				rw.Header().Set(contentType, test.contentType)
				rw.WriteHeader(http.StatusOK)

				tb := bigTestBody
				for len(tb) > 0 {
					// Write 100 bytes per run
					// Detection should not be affected (we send 100 bytes)
					toWrite := min(100, len(tb))

					_, err := rw.Write(tb[:toWrite])
					require.NoError(t, err)

					// Flush between each write
					rw.(http.Flusher).Flush()
					tb = tb[toWrite:]
				}
			})

			h := mustNewCompressionHandler(t, cfg, zstdName, next)

			// Fail fast on a malformed request instead of dereferencing a nil one.
			req, err := http.NewRequest(http.MethodGet, "/whatever", nil)
			require.NoError(t, err)
			req.Header.Set(acceptEncoding, zstdName)

			// This doesn't allow checking flushes, but we validate if content is correct.
			rw := httptest.NewRecorder()
			h.ServeHTTP(rw, req)

			assert.Equal(t, http.StatusOK, rw.Code)

			if test.expCompression {
				assert.Equal(t, zstdName, rw.Header().Get(contentEncoding))

				reader, err := zstd.NewReader(rw.Body)
				require.NoError(t, err)

				got, err := io.ReadAll(reader)
				assert.NoError(t, err)
				assert.Equal(t, bigTestBody, got)
			} else {
				assert.NotEqual(t, zstdName, rw.Header().Get(contentEncoding))

				got, err := io.ReadAll(rw.Body)
				assert.NoError(t, err)
				assert.Equal(t, bigTestBody, got)
			}
		})
	}
}
// Test_FlushIncludedContentTypes is the flushing counterpart of the
// IncludedContentTypes test: the backend writes bigTestBody in chunks of at
// most 100 bytes and flushes after every chunk. Each case states whether the
// response is expected to be zstd-encoded; in both outcomes the (decoded)
// body must equal bigTestBody.
func Test_FlushIncludedContentTypes(t *testing.T) {
	testCases := []struct {
		desc                 string
		contentType          string
		includedContentTypes []string
		expCompression       bool
	}{
		{
			desc:           "Always compress when content types are empty",
			contentType:    "",
			expCompression: true,
		},
		{
			desc:                 "MIME match",
			contentType:          "application/json",
			includedContentTypes: []string{"application/json"},
			expCompression:       true,
		},
		{
			desc:                 "MIME no match",
			contentType:          "text/xml",
			includedContentTypes: []string{"application/json"},
			expCompression:       false,
		},
		{
			desc:                 "MIME match with no other directive ignores non-MIME directives",
			contentType:          "application/json; charset=utf-8",
			includedContentTypes: []string{"application/json"},
			expCompression:       true,
		},
		{
			desc:                 "MIME match with other directives requires all directives be equal, different charset",
			contentType:          "application/json; charset=ascii",
			includedContentTypes: []string{"application/json; charset=utf-8"},
			expCompression:       false,
		},
		{
			desc:                 "MIME match with other directives requires all directives be equal, same charset",
			contentType:          "application/json; charset=utf-8",
			includedContentTypes: []string{"application/json; charset=utf-8"},
			expCompression:       true,
		},
		{
			desc:                 "MIME match with other directives requires all directives be equal, missing charset",
			contentType:          "application/json",
			includedContentTypes: []string{"application/json; charset=ascii"},
			expCompression:       false,
		},
		{
			desc:                 "MIME match case insensitive",
			contentType:          "Application/Json",
			includedContentTypes: []string{"application/json"},
			expCompression:       true,
		},
		{
			desc:                 "MIME match ignore whitespace",
			contentType:          "application/json;charset=utf-8",
			includedContentTypes: []string{"application/json; charset=utf-8"},
			expCompression:       true,
		},
	}

	for _, test := range testCases {
		t.Run(test.desc, func(t *testing.T) {
			t.Parallel()

			cfg := Config{
				MinSize:              1024,
				IncludedContentTypes: test.includedContentTypes,
			}

			next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
				rw.Header().Set(contentType, test.contentType)
				rw.WriteHeader(http.StatusOK)

				tb := bigTestBody
				for len(tb) > 0 {
					// Write 100 bytes per run
					// Detection should not be affected (we send 100 bytes)
					toWrite := min(100, len(tb))

					_, err := rw.Write(tb[:toWrite])
					require.NoError(t, err)

					// Flush between each write
					rw.(http.Flusher).Flush()
					tb = tb[toWrite:]
				}
			})

			h := mustNewCompressionHandler(t, cfg, zstdName, next)

			// Fail fast on a malformed request instead of dereferencing a nil one.
			req, err := http.NewRequest(http.MethodGet, "/whatever", nil)
			require.NoError(t, err)
			req.Header.Set(acceptEncoding, zstdName)

			// This doesn't allow checking flushes, but we validate if content is correct.
			rw := httptest.NewRecorder()
			h.ServeHTTP(rw, req)

			assert.Equal(t, http.StatusOK, rw.Code)

			if test.expCompression {
				assert.Equal(t, zstdName, rw.Header().Get(contentEncoding))

				reader, err := zstd.NewReader(rw.Body)
				require.NoError(t, err)

				got, err := io.ReadAll(reader)
				assert.NoError(t, err)
				assert.Equal(t, bigTestBody, got)
			} else {
				assert.NotEqual(t, zstdName, rw.Header().Get(contentEncoding))

				got, err := io.ReadAll(rw.Body)
				assert.NoError(t, err)
				assert.Equal(t, bigTestBody, got)
			}
		})
	}
}
// mustNewCompressionHandler builds a compression handler around next for the
// given algorithm (zstdName or brotliName) and configuration.
//
// It fails the calling test immediately when algo is not one of the two
// supported names or when NewCompressionHandler returns an error.
func mustNewCompressionHandler(t *testing.T, cfg Config, algo string, next http.Handler) http.Handler {
	t.Helper()

	var writer NewCompressionWriter
	switch algo {
	case zstdName:
		writer = func(rw http.ResponseWriter) (CompressionWriter, string, error) {
			zw, err := zstd.NewWriter(rw)
			require.NoError(t, err)
			return zw, zstdName, nil
		}
	case brotliName:
		writer = func(rw http.ResponseWriter) (CompressionWriter, string, error) {
			return brotli.NewWriter(rw), brotliName, nil
		}
	default:
		// Failf takes the failure message first and the format string second;
		// stop here rather than building a handler with a nil writer factory.
		require.Failf(t, "unknown compression algorithm", "algorithm: %s", algo)
	}

	w, err := NewCompressionHandler(cfg, writer, next)
	require.NoError(t, err)

	return w
}
func newTestBrotliHandler(t *testing.T, body []byte) http.Handler {
t.Helper()
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.URL.Path == "/compressed" {
rw.Header().Set("Content-Encoding", brotliName)
}
rw.WriteHeader(http.StatusAccepted)
_, err := rw.Write(body)
require.NoError(t, err)
})
return mustNewCompressionHandler(t, Config{MinSize: 1024, MiddlewareName: "Compress"}, brotliName, next)
}
func newTestZstandardHandler(t *testing.T, body []byte) http.Handler {
t.Helper()
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.URL.Path == "/compressed" {
rw.Header().Set("Content-Encoding", zstdName)
}
rw.WriteHeader(http.StatusAccepted)
_, err := rw.Write(body)
require.NoError(t, err)
})
return mustNewCompressionHandler(t, Config{MinSize: 1024, MiddlewareName: "Compress"}, zstdName, next)
}
// Test_ParseContentType_equals exercises parsedContentType.equals with
// matching and non-matching media types and parameter maps. Each case names
// the assertion (assert.True or assert.False) the result must satisfy.
func Test_ParseContentType_equals(t *testing.T) {
	testCases := []struct {
		desc      string
		pct       parsedContentType
		mediaType string
		params    map[string]string
		expect    assert.BoolAssertionFunc
	}{
		{
			desc:   "empty parsed content type",
			expect: assert.True,
		},
		{
			desc: "simple content type",
			pct: parsedContentType{
				mediaType: "plain/text",
			},
			mediaType: "plain/text",
			expect:    assert.True,
		},
		{
			desc: "content type with params",
			pct: parsedContentType{
				mediaType: "plain/text",
				params: map[string]string{
					"charset": "utf8",
				},
			},
			mediaType: "plain/text",
			params: map[string]string{
				"charset": "utf8",
			},
			expect: assert.True,
		},
		{
			desc: "different content type",
			pct: parsedContentType{
				mediaType: "plain/text",
			},
			mediaType: "application/json",
			expect:    assert.False,
		},
		{
			// Distinct name so this subtest is not reported as a "#01"
			// duplicate of the matching-params case above.
			desc: "content type with different params",
			pct: parsedContentType{
				mediaType: "plain/text",
				params: map[string]string{
					"charset": "utf8",
				},
			},
			mediaType: "plain/text",
			params: map[string]string{
				"charset": "latin-1",
			},
			expect: assert.False,
		},
		{
			desc: "different number of parameters",
			pct: parsedContentType{
				mediaType: "plain/text",
				params: map[string]string{
					"charset": "utf8",
				},
			},
			mediaType: "plain/text",
			params: map[string]string{
				"charset": "utf8",
				"q":       "0.8",
			},
			expect: assert.False,
		},
	}

	for _, test := range testCases {
		t.Run(test.desc, func(t *testing.T) {
			t.Parallel()

			test.expect(t, test.pct.equals(test.mediaType, test.params))
		})
	}
}