Fix recovered panic when websocket is mirrored

This commit is contained in:
Julien Salleyron 2019-08-29 10:28:05 +02:00 committed by Traefiker Bot
parent b113972bcf
commit 38508f9a9c
2 changed files with 64 additions and 3 deletions

View file

@ -1,8 +1,10 @@
package mirror package mirror
import ( import (
"bufio"
"context" "context"
"errors" "errors"
"net"
"net/http" "net/http"
"sync" "sync"
@ -75,7 +77,7 @@ func (m *Mirroring) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// AddMirror adds an httpHandler to mirror to. // AddMirror adds an httpHandler to mirror to.
func (m *Mirroring) AddMirror(handler http.Handler, percent int) error { func (m *Mirroring) AddMirror(handler http.Handler, percent int) error {
if percent < 0 || percent >= 100 { if percent < 0 || percent > 100 {
return errors.New("percent must be between 0 and 100") return errors.New("percent must be between 0 and 100")
} }
m.mirrorHandlers = append(m.mirrorHandlers, &mirrorHandler{Handler: handler, percent: percent}) m.mirrorHandlers = append(m.mirrorHandlers, &mirrorHandler{Handler: handler, percent: percent})
@ -84,6 +86,12 @@ func (m *Mirroring) AddMirror(handler http.Handler, percent int) error {
type blackholeResponseWriter struct{} type blackholeResponseWriter struct{}
func (b blackholeResponseWriter) Flush() {}
func (b blackholeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, errors.New("you can hijack connection on blackholeResponseWriter")
}
func (b blackholeResponseWriter) Header() http.Header { func (b blackholeResponseWriter) Header() http.Header {
return http.Header{} return http.Header{}
} }
@ -92,8 +100,7 @@ func (b blackholeResponseWriter) Write(bytes []byte) (int, error) {
return len(bytes), nil return len(bytes), nil
} }
func (b blackholeResponseWriter) WriteHeader(statusCode int) { func (b blackholeResponseWriter) WriteHeader(statusCode int) {}
}
type contextStopPropagation struct { type contextStopPropagation struct {
context.Context context.Context

View file

@ -76,4 +76,58 @@ func TestInvalidPercent(t *testing.T) {
err = mirror.AddMirror(nil, 101) err = mirror.AddMirror(nil, 101)
assert.Error(t, err) assert.Error(t, err)
err = mirror.AddMirror(nil, 100)
assert.NoError(t, err)
err = mirror.AddMirror(nil, 0)
assert.NoError(t, err)
}
func TestHijack(t *testing.T) {
handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
})
pool := safe.NewPool(context.Background())
mirror := New(handler, pool)
var mirrorRequest bool
err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
hijacker, ok := rw.(http.Hijacker)
assert.Equal(t, true, ok)
_, _, err := hijacker.Hijack()
assert.Error(t, err)
mirrorRequest = true
}), 100)
assert.NoError(t, err)
mirror.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
pool.Stop()
assert.Equal(t, true, mirrorRequest)
}
func TestFlush(t *testing.T) {
handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
})
pool := safe.NewPool(context.Background())
mirror := New(handler, pool)
var mirrorRequest bool
err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
hijacker, ok := rw.(http.Flusher)
assert.Equal(t, true, ok)
hijacker.Flush()
mirrorRequest = true
}), 100)
assert.NoError(t, err)
mirror.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
pool.Stop()
assert.Equal(t, true, mirrorRequest)
} }