Error pages and header merge

This commit is contained in:
Ludovic Fernandez 2018-05-28 15:00:04 +02:00 committed by Traefiker Bot
parent 3f5772c62a
commit fb5aa4c9c1
2 changed files with 31 additions and 37 deletions

View file

@ -101,12 +101,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request, next http.
h.backendHandler.ServeHTTP(recorderErrorPage, pageReq.WithContext(req.Context()))
utils.CopyHeaders(w.Header(), recorder.Header())
for key := range recorderErrorPage.Header() {
w.Header().Del(key)
}
utils.CopyHeaders(w.Header(), recorderErrorPage.Header())
w.WriteHeader(recorder.GetCode())
w.Write(recorderErrorPage.GetBody().Bytes())
return
@ -174,64 +169,65 @@ type responseRecorderWithCloseNotify struct {
// CloseNotify returns a channel that receives at most a
// single value (true) when the client connection has gone away.
func (rw *responseRecorderWithCloseNotify) CloseNotify() <-chan bool {
return rw.responseWriter.(http.CloseNotifier).CloseNotify()
func (r *responseRecorderWithCloseNotify) CloseNotify() <-chan bool {
return r.responseWriter.(http.CloseNotifier).CloseNotify()
}
// Header returns the response headers.
func (rw *responseRecorderWithoutCloseNotify) Header() http.Header {
if rw.HeaderMap == nil {
rw.HeaderMap = make(http.Header)
func (r *responseRecorderWithoutCloseNotify) Header() http.Header {
if r.HeaderMap == nil {
r.HeaderMap = make(http.Header)
}
return rw.HeaderMap
return r.HeaderMap
}
func (rw *responseRecorderWithoutCloseNotify) GetCode() int {
return rw.Code
func (r *responseRecorderWithoutCloseNotify) GetCode() int {
return r.Code
}
func (rw *responseRecorderWithoutCloseNotify) GetBody() *bytes.Buffer {
return rw.Body
func (r *responseRecorderWithoutCloseNotify) GetBody() *bytes.Buffer {
return r.Body
}
func (rw *responseRecorderWithoutCloseNotify) IsStreamingResponseStarted() bool {
return rw.streamingResponseStarted
func (r *responseRecorderWithoutCloseNotify) IsStreamingResponseStarted() bool {
return r.streamingResponseStarted
}
// Write always succeeds and writes to rw.Body, if not nil.
func (rw *responseRecorderWithoutCloseNotify) Write(buf []byte) (int, error) {
if rw.err != nil {
return 0, rw.err
func (r *responseRecorderWithoutCloseNotify) Write(buf []byte) (int, error) {
if r.err != nil {
return 0, r.err
}
return rw.Body.Write(buf)
return r.Body.Write(buf)
}
// WriteHeader sets rw.Code.
func (rw *responseRecorderWithoutCloseNotify) WriteHeader(code int) {
rw.Code = code
func (r *responseRecorderWithoutCloseNotify) WriteHeader(code int) {
r.Code = code
}
// Hijack hijacks the connection
func (rw *responseRecorderWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return rw.responseWriter.(http.Hijacker).Hijack()
func (r *responseRecorderWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return r.responseWriter.(http.Hijacker).Hijack()
}
// Flush sends any buffered data to the client.
func (rw *responseRecorderWithoutCloseNotify) Flush() {
if !rw.streamingResponseStarted {
utils.CopyHeaders(rw.responseWriter.Header(), rw.Header())
rw.responseWriter.WriteHeader(rw.Code)
rw.streamingResponseStarted = true
func (r *responseRecorderWithoutCloseNotify) Flush() {
if !r.streamingResponseStarted {
utils.CopyHeaders(r.responseWriter.Header(), r.Header())
r.responseWriter.WriteHeader(r.Code)
r.streamingResponseStarted = true
}
_, err := rw.responseWriter.Write(rw.Body.Bytes())
_, err := r.responseWriter.Write(r.Body.Bytes())
if err != nil {
log.Errorf("Error writing response in responseRecorder: %s", err)
rw.err = err
log.Errorf("Error writing response in responseRecorder: %v", err)
r.err = err
}
rw.Body.Reset()
r.Body.Reset()
if flusher, ok := rw.responseWriter.(http.Flusher); ok {
if flusher, ok := r.responseWriter.(http.Flusher); ok {
flusher.Flush()
}
}

View file

@ -318,7 +318,6 @@ func TestHandlerOldWayIntegration(t *testing.T) {
require.NoError(t, err)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Foo", "bar")
w.WriteHeader(test.backendCode)
fmt.Fprintln(w, http.StatusText(test.backendCode))
})
@ -331,7 +330,6 @@ func TestHandlerOldWayIntegration(t *testing.T) {
n.ServeHTTP(recorder, req)
test.validate(t, recorder)
assert.Equal(t, "bar", recorder.Header().Get("X-Foo"), "missing header")
})
}
}