diff --git a/pkg/proxy/fast/proxy.go b/pkg/proxy/fast/proxy.go index a400ce646..717b1ff06 100644 --- a/pkg/proxy/fast/proxy.go +++ b/pkg/proxy/fast/proxy.go @@ -171,7 +171,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if reqUpType != "" { outReq.Header.Set("Connection", "Upgrade") outReq.Header.Set("Upgrade", reqUpType) - if reqUpType == "websocket" { + if strings.EqualFold(reqUpType, "websocket") { cleanWebSocketHeaders(&outReq.Header) } } @@ -353,6 +353,7 @@ type fasthttpHeader interface { SetBytesV(key string, value []byte) DelBytes(key []byte) Del(key string) + ConnectionUpgrade() bool } // removeConnectionHeaders removes hop-by-hop headers listed in the "Connection" header of h. diff --git a/pkg/proxy/fast/proxy_websocket_test.go b/pkg/proxy/fast/proxy_websocket_test.go index b057f8b58..ef22895cc 100644 --- a/pkg/proxy/fast/proxy_websocket_test.go +++ b/pkg/proxy/fast/proxy_websocket_test.go @@ -2,7 +2,9 @@ package fast import ( "bufio" + "crypto/sha1" "crypto/tls" + "encoding/base64" "errors" "fmt" "net" @@ -19,6 +21,34 @@ import ( "golang.org/x/net/websocket" ) +func TestWebSocketUpgradeCase(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + challengeKey := r.Header.Get("Sec-Websocket-Key") + + hijacker, ok := w.(http.Hijacker) + require.True(t, ok) + + c, _, err := hijacker.Hijack() + require.NoError(t, err) + + // Force answer with "Connection: upgrade" in lowercase. + _, err = c.Write([]byte("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: upgrade\r\nSec-WebSocket-Accept: " + computeAcceptKey(challengeKey) + "\r\n\n")) + require.NoError(t, err) + })) + defer srv.Close() + + proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil)) + + proxyAddr := proxy.Listener.Addr().String() + _, conn, err := newWebsocketRequest( + withServer(proxyAddr), + withPath("/ws"), + ).open() + require.NoError(t, err) + + conn.Close() +} + func TestWebSocketTCPClose(t *testing.T) { errChan := make(chan error, 1) upgrader := gorillawebsocket.Upgrader{} @@ -691,3 +721,10 @@ func createProxyWithForwarder(t *testing.T, uri string, pool *connPool) *httptes return srv } + +func computeAcceptKey(challengeKey string) string { + h := sha1.New() // #nosec G401 -- (CWE-326) https://datatracker.ietf.org/doc/html/rfc6455#page-54 + h.Write([]byte(challengeKey)) + h.Write([]byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} diff --git a/pkg/proxy/fast/upgrade.go b/pkg/proxy/fast/upgrade.go index 7bec09e49..2570d4c72 100644 --- a/pkg/proxy/fast/upgrade.go +++ b/pkg/proxy/fast/upgrade.go @@ -1,7 +1,6 @@ package fast import ( - "bytes" "context" "fmt" "io" @@ -100,7 +99,7 @@ func upgradeType(h http.Header) string { } func upgradeTypeFastHTTP(h fasthttpHeader) string { - if !bytes.Contains(h.Peek("Connection"), []byte("Upgrade")) { + if !h.ConnectionUpgrade() { return "" }