Error pages and header merge

This commit is contained in:
Ludovic Fernandez 2018-05-28 15:00:04 +02:00 committed by Traefiker Bot
parent 3f5772c62a
commit fb5aa4c9c1
2 changed files with 31 additions and 37 deletions

View file

@ -101,12 +101,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request, next http.
h.backendHandler.ServeHTTP(recorderErrorPage, pageReq.WithContext(req.Context())) h.backendHandler.ServeHTTP(recorderErrorPage, pageReq.WithContext(req.Context()))
utils.CopyHeaders(w.Header(), recorder.Header())
for key := range recorderErrorPage.Header() {
w.Header().Del(key)
}
utils.CopyHeaders(w.Header(), recorderErrorPage.Header()) utils.CopyHeaders(w.Header(), recorderErrorPage.Header())
w.WriteHeader(recorder.GetCode()) w.WriteHeader(recorder.GetCode())
w.Write(recorderErrorPage.GetBody().Bytes()) w.Write(recorderErrorPage.GetBody().Bytes())
return return
@ -174,64 +169,65 @@ type responseRecorderWithCloseNotify struct {
// CloseNotify returns a channel that receives at most a // CloseNotify returns a channel that receives at most a
// single value (true) when the client connection has gone away. // single value (true) when the client connection has gone away.
func (rw *responseRecorderWithCloseNotify) CloseNotify() <-chan bool { func (r *responseRecorderWithCloseNotify) CloseNotify() <-chan bool {
return rw.responseWriter.(http.CloseNotifier).CloseNotify() return r.responseWriter.(http.CloseNotifier).CloseNotify()
} }
// Header returns the response headers. // Header returns the response headers.
func (rw *responseRecorderWithoutCloseNotify) Header() http.Header { func (r *responseRecorderWithoutCloseNotify) Header() http.Header {
if rw.HeaderMap == nil { if r.HeaderMap == nil {
rw.HeaderMap = make(http.Header) r.HeaderMap = make(http.Header)
}
return rw.HeaderMap
} }
func (rw *responseRecorderWithoutCloseNotify) GetCode() int { return r.HeaderMap
return rw.Code
} }
func (rw *responseRecorderWithoutCloseNotify) GetBody() *bytes.Buffer { func (r *responseRecorderWithoutCloseNotify) GetCode() int {
return rw.Body return r.Code
} }
func (rw *responseRecorderWithoutCloseNotify) IsStreamingResponseStarted() bool { func (r *responseRecorderWithoutCloseNotify) GetBody() *bytes.Buffer {
return rw.streamingResponseStarted return r.Body
}
func (r *responseRecorderWithoutCloseNotify) IsStreamingResponseStarted() bool {
return r.streamingResponseStarted
} }
// Write always succeeds and writes to rw.Body, if not nil. // Write always succeeds and writes to rw.Body, if not nil.
func (rw *responseRecorderWithoutCloseNotify) Write(buf []byte) (int, error) { func (r *responseRecorderWithoutCloseNotify) Write(buf []byte) (int, error) {
if rw.err != nil { if r.err != nil {
return 0, rw.err return 0, r.err
} }
return rw.Body.Write(buf) return r.Body.Write(buf)
} }
// WriteHeader sets rw.Code. // WriteHeader sets rw.Code.
func (rw *responseRecorderWithoutCloseNotify) WriteHeader(code int) { func (r *responseRecorderWithoutCloseNotify) WriteHeader(code int) {
rw.Code = code r.Code = code
} }
// Hijack hijacks the connection // Hijack hijacks the connection
func (rw *responseRecorderWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) { func (r *responseRecorderWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return rw.responseWriter.(http.Hijacker).Hijack() return r.responseWriter.(http.Hijacker).Hijack()
} }
// Flush sends any buffered data to the client. // Flush sends any buffered data to the client.
func (rw *responseRecorderWithoutCloseNotify) Flush() { func (r *responseRecorderWithoutCloseNotify) Flush() {
if !rw.streamingResponseStarted { if !r.streamingResponseStarted {
utils.CopyHeaders(rw.responseWriter.Header(), rw.Header()) utils.CopyHeaders(r.responseWriter.Header(), r.Header())
rw.responseWriter.WriteHeader(rw.Code) r.responseWriter.WriteHeader(r.Code)
rw.streamingResponseStarted = true r.streamingResponseStarted = true
} }
_, err := rw.responseWriter.Write(rw.Body.Bytes()) _, err := r.responseWriter.Write(r.Body.Bytes())
if err != nil { if err != nil {
log.Errorf("Error writing response in responseRecorder: %s", err) log.Errorf("Error writing response in responseRecorder: %v", err)
rw.err = err r.err = err
} }
rw.Body.Reset() r.Body.Reset()
if flusher, ok := rw.responseWriter.(http.Flusher); ok { if flusher, ok := r.responseWriter.(http.Flusher); ok {
flusher.Flush() flusher.Flush()
} }
} }

View file

@ -318,7 +318,6 @@ func TestHandlerOldWayIntegration(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Foo", "bar")
w.WriteHeader(test.backendCode) w.WriteHeader(test.backendCode)
fmt.Fprintln(w, http.StatusText(test.backendCode)) fmt.Fprintln(w, http.StatusText(test.backendCode))
}) })
@ -331,7 +330,6 @@ func TestHandlerOldWayIntegration(t *testing.T) {
n.ServeHTTP(recorder, req) n.ServeHTTP(recorder, req)
test.validate(t, recorder) test.validate(t, recorder)
assert.Equal(t, "bar", recorder.Header().Get("X-Foo"), "missing header")
}) })
} }
} }