diff --git a/go.mod b/go.mod index d4e045bb3..27de095eb 100644 --- a/go.mod +++ b/go.mod @@ -35,7 +35,7 @@ require ( github.com/influxdata/influxdb-client-go/v2 v2.7.0 github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d github.com/instana/go-sensor v1.38.3 - github.com/klauspost/compress v1.15.0 + github.com/klauspost/compress v1.16.6 github.com/kvtools/consul v1.0.2 github.com/kvtools/etcdv3 v1.0.2 github.com/kvtools/redis v1.0.2 diff --git a/go.sum b/go.sum index 1de390a94..c444150a2 100644 --- a/go.sum +++ b/go.sum @@ -1158,8 +1158,8 @@ github.com/kisielk/sqlstruct v0.0.0-20150923205031-648daed35d49/go.mod h1:yyMNCy github.com/kisom/goutils v1.1.0/go.mod h1:+UBTfd78habUYWFbNWTJNG+jNG/i/lGURakr4A/yNRw= github.com/klauspost/compress v1.11.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.11.13/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= -github.com/klauspost/compress v1.15.0 h1:xqfchp4whNFxn5A4XFyyYtitiWI8Hy5EW59jEwcyL6U= -github.com/klauspost/compress v1.15.0/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= +github.com/klauspost/compress v1.16.6 h1:91SKEy4K37vkp255cJ8QesJhjyRO0hn9i9G0GoUwLsk= +github.com/klauspost/compress v1.16.6/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= github.com/kolo/xmlrpc v0.0.0-20200310150728-e0350524596b h1:DzHy0GlWeF0KAglaTMY7Q+khIFoG8toHP+wLFBVBQJc= github.com/kolo/xmlrpc v0.0.0-20200310150728-e0350524596b/go.mod h1:o03bZfuBwAXHetKXuInt4S7omeXUu62/A845kiycsSQ= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= diff --git a/pkg/middlewares/compress/compress_test.go b/pkg/middlewares/compress/compress_test.go index 8f7e85ae0..4a1fd2196 100644 --- a/pkg/middlewares/compress/compress_test.go +++ b/pkg/middlewares/compress/compress_test.go @@ -5,6 +5,8 @@ import ( "io" "net/http" "net/http/httptest" + "net/http/httptrace" + "net/textproto" "testing" "github.com/klauspost/compress/gzhttp" @@ -356,6 +358,86 @@ func TestMinResponseBodyBytes(t *testing.T) { } } +// This test is an adapted version of net/http/httputil.Test1xxResponses test. +func Test1xxResponses(t *testing.T) { + fakeBody := generateBytes(100000) + + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h := w.Header() + h.Add("Link", "; rel=preload; as=style") + h.Add("Link", "; rel=preload; as=script") + w.WriteHeader(http.StatusEarlyHints) + + h.Add("Link", "; rel=preload; as=script") + w.WriteHeader(http.StatusProcessing) + + if _, err := w.Write(fakeBody); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + }) + + compress, err := New(context.Background(), next, dynamic.Compress{MinResponseBodyBytes: 1024}, "testing") + require.NoError(t, err) + + server := httptest.NewServer(compress) + t.Cleanup(server.Close) + frontendClient := server.Client() + + checkLinkHeaders := func(t *testing.T, expected, got []string) { + t.Helper() + + if len(expected) != len(got) { + t.Errorf("Expected %d link headers; got %d", len(expected), len(got)) + } + + for i := range expected { + if i >= len(got) { + t.Errorf("Expected %q link header; got nothing", expected[i]) + + continue + } + + if expected[i] != got[i] { + t.Errorf("Expected %q link header; got %q", expected[i], got[i]) + } + } + } + + var respCounter uint8 + trace := &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + switch code { + case http.StatusEarlyHints: + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script"}, header["Link"]) + case http.StatusProcessing: + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script", "; rel=preload; as=script"}, header["Link"]) + default: + t.Error("Unexpected 1xx response") + } + + respCounter++ + + return nil + }, + } + req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), http.MethodGet, server.URL, nil) + req.Header.Add(acceptEncodingHeader, gzipValue) + + res, err := frontendClient.Do(req) + assert.Nil(t, err) + + defer res.Body.Close() + + if respCounter != 2 { + t.Errorf("Expected 2 1xx responses; got %d", respCounter) + } + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script", "; rel=preload; as=script"}, res.Header["Link"]) + + assert.Equal(t, gzipValue, res.Header.Get(contentEncodingHeader)) + body, _ := io.ReadAll(res.Body) + assert.NotEqualValues(t, body, fakeBody) +} + func BenchmarkCompress(b *testing.B) { testCases := []struct { name string diff --git a/pkg/middlewares/customerrors/custom_errors.go b/pkg/middlewares/customerrors/custom_errors.go index c764b2cf8..39d411622 100644 --- a/pkg/middlewares/customerrors/custom_errors.go +++ b/pkg/middlewares/customerrors/custom_errors.go @@ -193,11 +193,25 @@ func (cc *codeCatcher) Write(buf []byte) (int, error) { return cc.responseWriter.Write(buf) } +// WriteHeader is, in the specific case of 1xx status codes, a direct call to the wrapped ResponseWriter, without marking headers as sent, +// allowing so further calls. func (cc *codeCatcher) WriteHeader(code int) { if cc.headersSent || cc.caughtFilteredCode { return } + // Handling informational headers. + if code >= 100 && code <= 199 { + // Multiple informational status codes can be used, + // so here the copy is not appending the values to not repeat them. + for k, v := range cc.Header() { + cc.responseWriter.Header()[k] = v + } + + cc.responseWriter.WriteHeader(code) + return + } + cc.code = code for _, block := range cc.httpCodeRanges { if cc.code >= block[0] && cc.code <= block[1] { @@ -208,7 +222,11 @@ func (cc *codeCatcher) WriteHeader(code int) { } } - utils.CopyHeaders(cc.responseWriter.Header(), cc.Header()) + // The copy is not appending the values, + // to not repeat them in case any informational status code has been written. + for k, v := range cc.Header() { + cc.responseWriter.Header()[k] = v + } cc.responseWriter.WriteHeader(cc.code) cc.headersSent = true } @@ -303,12 +321,28 @@ func (r *codeModifierWithoutCloseNotify) Write(buf []byte) (int, error) { // WriteHeader sends the headers, with the enforced code (the code in argument is always ignored), // if it hasn't already been done. -func (r *codeModifierWithoutCloseNotify) WriteHeader(_ int) { +// WriteHeader is, in the specific case of 1xx status codes, a direct call to the wrapped ResponseWriter, without marking headers as sent, +// allowing so further calls. +func (r *codeModifierWithoutCloseNotify) WriteHeader(code int) { if r.headerSent { return } - utils.CopyHeaders(r.responseWriter.Header(), r.Header()) + // Handling informational headers. + if code >= 100 && code <= 199 { + // Multiple informational status codes can be used, + // so here the copy is not appending the values to not repeat them. + for k, v := range r.headerMap { + r.responseWriter.Header()[k] = v + } + + r.responseWriter.WriteHeader(code) + return + } + + for k, v := range r.headerMap { + r.responseWriter.Header()[k] = v + } r.responseWriter.WriteHeader(r.code) r.headerSent = true } diff --git a/pkg/middlewares/customerrors/custom_errors_test.go b/pkg/middlewares/customerrors/custom_errors_test.go index c291b15d5..84e3579b7 100644 --- a/pkg/middlewares/customerrors/custom_errors_test.go +++ b/pkg/middlewares/customerrors/custom_errors_test.go @@ -3,8 +3,11 @@ package customerrors import ( "context" "fmt" + "io" "net/http" "net/http/httptest" + "net/http/httptrace" + "net/textproto" "testing" "github.com/stretchr/testify/assert" @@ -181,6 +184,88 @@ func TestHandler(t *testing.T) { } } +// This test is an adapted version of net/http/httputil.Test1xxResponses test. +func Test1xxResponses(t *testing.T) { + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h := w.Header() + h.Add("Link", "; rel=preload; as=style") + h.Add("Link", "; rel=preload; as=script") + w.WriteHeader(http.StatusEarlyHints) + + h.Add("Link", "; rel=preload; as=script") + w.WriteHeader(http.StatusProcessing) + + h.Add("User-Agent", "foobar") + _, _ = w.Write([]byte("Hello")) + w.WriteHeader(http.StatusBadGateway) + }) + + serviceBuilderMock := &mockServiceBuilder{handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprintln(w, "My error page.") + })} + + config := dynamic.ErrorPage{Service: "error", Query: "/", Status: []string{"200"}} + + errorPageHandler, err := New(context.Background(), next, config, serviceBuilderMock, "test") + require.NoError(t, err) + + server := httptest.NewServer(errorPageHandler) + t.Cleanup(server.Close) + frontendClient := server.Client() + + checkLinkHeaders := func(t *testing.T, expected, got []string) { + t.Helper() + + if len(expected) != len(got) { + t.Errorf("Expected %d link headers; got %d", len(expected), len(got)) + } + + for i := range expected { + if i >= len(got) { + t.Errorf("Expected %q link header; got nothing", expected[i]) + + continue + } + + if expected[i] != got[i] { + t.Errorf("Expected %q link header; got %q", expected[i], got[i]) + } + } + } + + var respCounter uint8 + trace := &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + switch code { + case http.StatusEarlyHints: + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script"}, header["Link"]) + case http.StatusProcessing: + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script", "; rel=preload; as=script"}, header["Link"]) + default: + t.Error("Unexpected 1xx response") + } + + respCounter++ + + return nil + }, + } + req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), http.MethodGet, server.URL, nil) + + res, err := frontendClient.Do(req) + assert.Nil(t, err) + + defer res.Body.Close() + + if respCounter != 2 { + t.Errorf("Expected 2 1xx responses; got %d", respCounter) + } + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script", "; rel=preload; as=script"}, res.Header["Link"]) + + body, _ := io.ReadAll(res.Body) + assert.Equal(t, "My error page.\n", string(body)) +} + type mockServiceBuilder struct { handler http.Handler } diff --git a/pkg/middlewares/headers/headers_test.go b/pkg/middlewares/headers/headers_test.go index efbcdb0b2..c345efd60 100644 --- a/pkg/middlewares/headers/headers_test.go +++ b/pkg/middlewares/headers/headers_test.go @@ -4,8 +4,11 @@ package headers import ( "context" + "io" "net/http" "net/http/httptest" + "net/http/httptrace" + "net/textproto" "testing" "github.com/stretchr/testify/assert" @@ -111,3 +114,87 @@ func Test_headers_getTracingInformation(t *testing.T) { assert.Equal(t, "testing", name) assert.Equal(t, tracing.SpanKindNoneEnum, trace) } + +// This test is an adapted version of net/http/httputil.Test1xxResponses test. +func Test1xxResponses(t *testing.T) { + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h := w.Header() + h.Add("Link", "; rel=preload; as=style") + h.Add("Link", "; rel=preload; as=script") + w.WriteHeader(http.StatusEarlyHints) + + h.Add("Link", "; rel=preload; as=script") + w.WriteHeader(http.StatusProcessing) + + _, _ = w.Write([]byte("Hello")) + }) + + cfg := dynamic.Headers{ + CustomResponseHeaders: map[string]string{ + "X-Custom-Response-Header": "test_response", + }, + } + + mid, err := New(context.Background(), next, cfg, "testing") + require.NoError(t, err) + + server := httptest.NewServer(mid) + t.Cleanup(server.Close) + frontendClient := server.Client() + + checkLinkHeaders := func(t *testing.T, expected, got []string) { + t.Helper() + + if len(expected) != len(got) { + t.Errorf("Expected %d link headers; got %d", len(expected), len(got)) + } + + for i := range expected { + if i >= len(got) { + t.Errorf("Expected %q link header; got nothing", expected[i]) + + continue + } + + if expected[i] != got[i] { + t.Errorf("Expected %q link header; got %q", expected[i], got[i]) + } + } + } + + var respCounter uint8 + trace := &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + switch code { + case http.StatusEarlyHints: + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script"}, header["Link"]) + case http.StatusProcessing: + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script", "; rel=preload; as=script"}, header["Link"]) + default: + t.Error("Unexpected 1xx response") + } + + respCounter++ + + return nil + }, + } + req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), http.MethodGet, server.URL, nil) + + res, err := frontendClient.Do(req) + assert.Nil(t, err) + + defer res.Body.Close() + + if respCounter != 2 { + t.Errorf("Expected 2 1xx responses; got %d", respCounter) + } + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script", "; rel=preload; as=script"}, res.Header["Link"]) + + body, _ := io.ReadAll(res.Body) + if string(body) != "Hello" { + t.Errorf("Read body %q; want Hello", body) + } + + assert.Equal(t, "test_response", res.Header.Get("X-Custom-Response-Header")) +} diff --git a/pkg/middlewares/headers/responsewriter.go b/pkg/middlewares/headers/responsewriter.go index 15201b1fd..c36674a05 100644 --- a/pkg/middlewares/headers/responsewriter.go +++ b/pkg/middlewares/headers/responsewriter.go @@ -36,10 +36,19 @@ func newResponseModifier(w http.ResponseWriter, r *http.Request, modifier func(* return rm } +// WriteHeader is, in the specific case of 1xx status codes, a direct call to the wrapped ResponseWriter, without marking headers as sent, +// allowing so further calls. func (r *responseModifier) WriteHeader(code int) { if r.headersSent { return } + + // Handling informational headers. + if code >= 100 && code <= 199 { + r.rw.WriteHeader(code) + return + } + defer func() { r.code = code r.headersSent = true diff --git a/pkg/middlewares/pipelining/pipelining_test.go b/pkg/middlewares/pipelining/pipelining_test.go index b26e5fc6b..181fcb328 100644 --- a/pkg/middlewares/pipelining/pipelining_test.go +++ b/pkg/middlewares/pipelining/pipelining_test.go @@ -2,8 +2,11 @@ package pipelining import ( "context" + "io" "net/http" "net/http/httptest" + "net/http/httptrace" + "net/textproto" "testing" "github.com/stretchr/testify/assert" @@ -68,3 +71,80 @@ func TestNew(t *testing.T) { }) } } + +// This test is an adapted version of net/http/httputil.Test1xxResponses test. +// This test is only here to guarantee that there would not be any regression in the future, +// because the pipelining middleware is already supporting informational headers. +func Test1xxResponses(t *testing.T) { + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h := w.Header() + h.Add("Link", "; rel=preload; as=style") + h.Add("Link", "; rel=preload; as=script") + w.WriteHeader(http.StatusEarlyHints) + + h.Add("Link", "; rel=preload; as=script") + w.WriteHeader(http.StatusProcessing) + + _, _ = w.Write([]byte("Hello")) + }) + + pipe := New(context.Background(), next, "pipe") + + server := httptest.NewServer(pipe) + t.Cleanup(server.Close) + frontendClient := server.Client() + + checkLinkHeaders := func(t *testing.T, expected, got []string) { + t.Helper() + + if len(expected) != len(got) { + t.Errorf("Expected %d link headers; got %d", len(expected), len(got)) + } + + for i := range expected { + if i >= len(got) { + t.Errorf("Expected %q link header; got nothing", expected[i]) + + continue + } + + if expected[i] != got[i] { + t.Errorf("Expected %q link header; got %q", expected[i], got[i]) + } + } + } + + var respCounter uint8 + trace := &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + switch code { + case http.StatusEarlyHints: + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script"}, header["Link"]) + case http.StatusProcessing: + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script", "; rel=preload; as=script"}, header["Link"]) + default: + t.Error("Unexpected 1xx response") + } + + respCounter++ + + return nil + }, + } + req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), http.MethodGet, server.URL, nil) + + res, err := frontendClient.Do(req) + assert.Nil(t, err) + + defer res.Body.Close() + + if respCounter != 2 { + t.Errorf("Expected 2 1xx responses; got %d", respCounter) + } + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script", "; rel=preload; as=script"}, res.Header["Link"]) + + body, _ := io.ReadAll(res.Body) + if string(body) != "Hello" { + t.Errorf("Read body %q; want Hello", body) + } +} diff --git a/pkg/middlewares/retry/retry.go b/pkg/middlewares/retry/retry.go index 37dd4abef..75367ebc7 100644 --- a/pkg/middlewares/retry/retry.go +++ b/pkg/middlewares/retry/retry.go @@ -209,7 +209,7 @@ func (r *responseWriterWithoutCloseNotify) WriteHeader(code int) { r.DisableRetries() } - if r.ShouldRetry() { + if r.ShouldRetry() || r.written { return } @@ -223,6 +223,13 @@ func (r *responseWriterWithoutCloseNotify) WriteHeader(code int) { } r.responseWriter.WriteHeader(code) + + // Handling informational headers. + // This allows to keep writing to r.headers map until a final status code is written. + if code >= 100 && code <= 199 { + return + } + r.written = true } diff --git a/pkg/middlewares/retry/retry_test.go b/pkg/middlewares/retry/retry_test.go index 4622093af..75ef5831b 100644 --- a/pkg/middlewares/retry/retry_test.go +++ b/pkg/middlewares/retry/retry_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "net/http/httptrace" + "net/textproto" "strings" "testing" "time" @@ -309,3 +310,82 @@ func TestRetryWebsocket(t *testing.T) { }) } } + +// This test is an adapted version of net/http/httputil.Test1xxResponses test. +func Test1xxResponses(t *testing.T) { + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h := w.Header() + h.Add("Link", "; rel=preload; as=style") + h.Add("Link", "; rel=preload; as=script") + w.WriteHeader(http.StatusEarlyHints) + + h.Add("Link", "; rel=preload; as=script") + w.WriteHeader(http.StatusProcessing) + + _, _ = w.Write([]byte("Hello")) + }) + + retryListener := &countingRetryListener{} + retry, err := New(context.Background(), next, dynamic.Retry{Attempts: 1}, retryListener, "traefikTest") + require.NoError(t, err) + + server := httptest.NewServer(retry) + t.Cleanup(server.Close) + frontendClient := server.Client() + + checkLinkHeaders := func(t *testing.T, expected, got []string) { + t.Helper() + + if len(expected) != len(got) { + t.Errorf("Expected %d link headers; got %d", len(expected), len(got)) + } + + for i := range expected { + if i >= len(got) { + t.Errorf("Expected %q link header; got nothing", expected[i]) + + continue + } + + if expected[i] != got[i] { + t.Errorf("Expected %q link header; got %q", expected[i], got[i]) + } + } + } + + var respCounter uint8 + trace := &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + switch code { + case http.StatusEarlyHints: + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script"}, header["Link"]) + case http.StatusProcessing: + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script", "; rel=preload; as=script"}, header["Link"]) + default: + t.Error("Unexpected 1xx response") + } + + respCounter++ + + return nil + }, + } + req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), http.MethodGet, server.URL, nil) + + res, err := frontendClient.Do(req) + assert.Nil(t, err) + + defer res.Body.Close() + + if respCounter != 2 { + t.Errorf("Expected 2 1xx responses; got %d", respCounter) + } + checkLinkHeaders(t, []string{"; rel=preload; as=style", "; rel=preload; as=script", "; rel=preload; as=script"}, res.Header["Link"]) + + body, _ := io.ReadAll(res.Body) + if string(body) != "Hello" { + t.Errorf("Read body %q; want Hello", body) + } + + assert.Equal(t, 0, retryListener.timesCalled) +}