From 38508f9a9cc072dbbac2298f5ede9a3f8aa7d9b7 Mon Sep 17 00:00:00 2001 From: Julien Salleyron Date: Thu, 29 Aug 2019 10:28:05 +0200 Subject: [PATCH] Fix recovered panic when websocket is mirrored --- .../service/loadbalancer/mirror/mirror.go | 13 +++-- .../loadbalancer/mirror/mirror_test.go | 54 +++++++++++++++++++ 2 files changed, 64 insertions(+), 3 deletions(-) diff --git a/pkg/server/service/loadbalancer/mirror/mirror.go b/pkg/server/service/loadbalancer/mirror/mirror.go index 593f20858..feba0d343 100644 --- a/pkg/server/service/loadbalancer/mirror/mirror.go +++ b/pkg/server/service/loadbalancer/mirror/mirror.go @@ -1,8 +1,10 @@ package mirror import ( + "bufio" "context" "errors" + "net" "net/http" "sync" @@ -75,7 +77,7 @@ func (m *Mirroring) ServeHTTP(rw http.ResponseWriter, req *http.Request) { // AddMirror adds an httpHandler to mirror to. func (m *Mirroring) AddMirror(handler http.Handler, percent int) error { - if percent < 0 || percent >= 100 { + if percent < 0 || percent > 100 { return errors.New("percent must be between 0 and 100") } m.mirrorHandlers = append(m.mirrorHandlers, &mirrorHandler{Handler: handler, percent: percent}) @@ -84,6 +86,12 @@ func (m *Mirroring) AddMirror(handler http.Handler, percent int) error { type blackholeResponseWriter struct{} +func (b blackholeResponseWriter) Flush() {} + +func (b blackholeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, errors.New("you can hijack connection on blackholeResponseWriter") +} + func (b blackholeResponseWriter) Header() http.Header { return http.Header{} } @@ -92,8 +100,7 @@ func (b blackholeResponseWriter) Write(bytes []byte) (int, error) { return len(bytes), nil } -func (b blackholeResponseWriter) WriteHeader(statusCode int) { -} +func (b blackholeResponseWriter) WriteHeader(statusCode int) {} type contextStopPropagation struct { context.Context diff --git a/pkg/server/service/loadbalancer/mirror/mirror_test.go b/pkg/server/service/loadbalancer/mirror/mirror_test.go index fb2c516c5..f223d72fd 100644 --- a/pkg/server/service/loadbalancer/mirror/mirror_test.go +++ b/pkg/server/service/loadbalancer/mirror/mirror_test.go @@ -76,4 +76,58 @@ func TestInvalidPercent(t *testing.T) { err = mirror.AddMirror(nil, 101) assert.Error(t, err) + + err = mirror.AddMirror(nil, 100) + assert.NoError(t, err) + + err = mirror.AddMirror(nil, 0) + assert.NoError(t, err) +} + +func TestHijack(t *testing.T) { + handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(http.StatusOK) + }) + pool := safe.NewPool(context.Background()) + mirror := New(handler, pool) + + var mirrorRequest bool + err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + hijacker, ok := rw.(http.Hijacker) + assert.Equal(t, true, ok) + + _, _, err := hijacker.Hijack() + assert.Error(t, err) + mirrorRequest = true + }), 100) + assert.NoError(t, err) + + mirror.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil)) + + pool.Stop() + assert.Equal(t, true, mirrorRequest) +} + +func TestFlush(t *testing.T) { + handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(http.StatusOK) + }) + pool := safe.NewPool(context.Background()) + mirror := New(handler, pool) + + var mirrorRequest bool + err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + hijacker, ok := rw.(http.Flusher) + assert.Equal(t, true, ok) + + hijacker.Flush() + + mirrorRequest = true + }), 100) + assert.NoError(t, err) + + mirror.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil)) + + pool.Stop() + assert.Equal(t, true, mirrorRequest) }