Fix recovered panic when websocket is mirrored
This commit is contained in:
parent
b113972bcf
commit
38508f9a9c
2 changed files with 64 additions and 3 deletions
|
@ -1,8 +1,10 @@
|
||||||
package mirror
|
package mirror
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
@ -75,7 +77,7 @@ func (m *Mirroring) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
|
||||||
// AddMirror adds an httpHandler to mirror to.
|
// AddMirror adds an httpHandler to mirror to.
|
||||||
func (m *Mirroring) AddMirror(handler http.Handler, percent int) error {
|
func (m *Mirroring) AddMirror(handler http.Handler, percent int) error {
|
||||||
if percent < 0 || percent >= 100 {
|
if percent < 0 || percent > 100 {
|
||||||
return errors.New("percent must be between 0 and 100")
|
return errors.New("percent must be between 0 and 100")
|
||||||
}
|
}
|
||||||
m.mirrorHandlers = append(m.mirrorHandlers, &mirrorHandler{Handler: handler, percent: percent})
|
m.mirrorHandlers = append(m.mirrorHandlers, &mirrorHandler{Handler: handler, percent: percent})
|
||||||
|
@ -84,6 +86,12 @@ func (m *Mirroring) AddMirror(handler http.Handler, percent int) error {
|
||||||
|
|
||||||
type blackholeResponseWriter struct{}
|
type blackholeResponseWriter struct{}
|
||||||
|
|
||||||
|
func (b blackholeResponseWriter) Flush() {}
|
||||||
|
|
||||||
|
func (b blackholeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
return nil, nil, errors.New("you can hijack connection on blackholeResponseWriter")
|
||||||
|
}
|
||||||
|
|
||||||
func (b blackholeResponseWriter) Header() http.Header {
|
func (b blackholeResponseWriter) Header() http.Header {
|
||||||
return http.Header{}
|
return http.Header{}
|
||||||
}
|
}
|
||||||
|
@ -92,8 +100,7 @@ func (b blackholeResponseWriter) Write(bytes []byte) (int, error) {
|
||||||
return len(bytes), nil
|
return len(bytes), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b blackholeResponseWriter) WriteHeader(statusCode int) {
|
func (b blackholeResponseWriter) WriteHeader(statusCode int) {}
|
||||||
}
|
|
||||||
|
|
||||||
type contextStopPropagation struct {
|
type contextStopPropagation struct {
|
||||||
context.Context
|
context.Context
|
||||||
|
|
|
@ -76,4 +76,58 @@ func TestInvalidPercent(t *testing.T) {
|
||||||
|
|
||||||
err = mirror.AddMirror(nil, 101)
|
err = mirror.AddMirror(nil, 101)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
err = mirror.AddMirror(nil, 100)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = mirror.AddMirror(nil, 0)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHijack(t *testing.T) {
|
||||||
|
handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
rw.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
pool := safe.NewPool(context.Background())
|
||||||
|
mirror := New(handler, pool)
|
||||||
|
|
||||||
|
var mirrorRequest bool
|
||||||
|
err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
hijacker, ok := rw.(http.Hijacker)
|
||||||
|
assert.Equal(t, true, ok)
|
||||||
|
|
||||||
|
_, _, err := hijacker.Hijack()
|
||||||
|
assert.Error(t, err)
|
||||||
|
mirrorRequest = true
|
||||||
|
}), 100)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
mirror.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
|
||||||
|
|
||||||
|
pool.Stop()
|
||||||
|
assert.Equal(t, true, mirrorRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFlush(t *testing.T) {
|
||||||
|
handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
rw.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
pool := safe.NewPool(context.Background())
|
||||||
|
mirror := New(handler, pool)
|
||||||
|
|
||||||
|
var mirrorRequest bool
|
||||||
|
err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
hijacker, ok := rw.(http.Flusher)
|
||||||
|
assert.Equal(t, true, ok)
|
||||||
|
|
||||||
|
hijacker.Flush()
|
||||||
|
|
||||||
|
mirrorRequest = true
|
||||||
|
}), 100)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
mirror.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
|
||||||
|
|
||||||
|
pool.Stop()
|
||||||
|
assert.Equal(t, true, mirrorRequest)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue