traefik/pkg/proxy/fast/upgrade.go

109 lines
2.7 KiB
Go
Raw Normal View History

package fast
import (
"bytes"
"context"
"fmt"
"io"
"net"
"net/http"
"strings"
"github.com/traefik/traefik/v3/pkg/proxy/httputil"
"github.com/valyala/fasthttp"
"golang.org/x/net/http/httpguts"
)
// switchProtocolCopier holds both ends of a connection after a protocol
// switch. Data is moved by its methods rather than by anonymous closures so
// that the goroutines copying in each direction show up under descriptive
// names in stack traces.
type switchProtocolCopier struct {
	user    io.ReadWriter // client side of the switched connection
	backend io.ReadWriter // backend side of the switched connection
}

// copyFromBackend relays bytes read from the backend to the user until the
// backend is exhausted or either side fails, then reports the io.Copy result
// on errCh (nil when the backend simply reached EOF).
func (c switchProtocolCopier) copyFromBackend(errCh chan<- error) {
	dst, src := c.user, c.backend
	_, copyErr := io.Copy(dst, src)
	errCh <- copyErr
}

// copyToBackend relays bytes read from the user to the backend until the
// user side is exhausted or either side fails, then reports the io.Copy
// result on errCh (nil when the user side simply reached EOF).
func (c switchProtocolCopier) copyToBackend(errCh chan<- error) {
	dst, src := c.backend, c.user
	_, copyErr := io.Copy(dst, src)
	errCh <- copyErr
}
// upgradeHandler takes over an exchange whose backend response switches
// protocols. It is given the client's ResponseWriter, the backend response,
// and the raw connection to the backend.
type upgradeHandler func(w http.ResponseWriter, resp *fasthttp.Response, backendConn net.Conn)
// upgradeResponseHandler returns an upgradeHandler that completes a protocol
// switch requested through the Upgrade header.
//
// reqUpType is the upgrade protocol the client requested. The returned handler:
//   - reports an error through httputil.ErrorHandlerWithContext and closes
//     backConn when the backend's Upgrade value does not match reqUpType
//     (case-insensitively), or when rw cannot be hijacked;
//   - otherwise hijacks the client connection, adds the headers already set
//     on rw to the backend response headers, and writes those headers to the
//     client;
//   - then copies bytes in both directions between the client connection and
//     backConn, returning as soon as either direction stops.
//
// backConn is closed on every path. Once the hijack stage is reached,
// cancelling ctx also closes it early.
func upgradeResponseHandler(ctx context.Context, reqUpType string) upgradeHandler {
	return func(rw http.ResponseWriter, res *fasthttp.Response, backConn net.Conn) {
		resUpType := upgradeTypeFastHTTP(&res.Header)
		if !strings.EqualFold(reqUpType, resUpType) {
			httputil.ErrorHandlerWithContext(ctx, rw, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
			backConn.Close()
			return
		}
		hj, ok := rw.(http.Hijacker)
		if !ok {
			httputil.ErrorHandlerWithContext(ctx, rw, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
			backConn.Close()
			return
		}
		// From here on backConn is closed by this goroutine, either when ctx is
		// cancelled or when the deferred close(backConnCloseCh) runs as the
		// handler returns, whichever happens first.
		backConnCloseCh := make(chan bool)
		go func() {
			// Ensure that the cancellation of a request closes the backend.
			// See issue https://golang.org/issue/35559.
			select {
			case <-ctx.Done():
			case <-backConnCloseCh:
			}
			_ = backConn.Close()
		}()
		defer close(backConnCloseCh)
		conn, brw, err := hj.Hijack()
		if err != nil {
			httputil.ErrorHandlerWithContext(ctx, rw, fmt.Errorf("hijack failed on protocol switch: %w", err))
			return
		}
		defer conn.Close()
		// Merge headers already set on rw into the backend response headers so
		// both sets reach the client.
		for k, values := range rw.Header() {
			for _, v := range values {
				res.Header.Add(k, v)
			}
		}
		// NOTE(review): rw has been hijacked at this point, so the error
		// handler calls below cannot send an HTTP error response to the
		// client (net/http documents that writes after Hijack fail). They only
		// surface the error; the deferred closes still tear down both sides.
		if err := res.Header.Write(brw.Writer); err != nil {
			httputil.ErrorHandlerWithContext(ctx, rw, fmt.Errorf("response write: %w", err))
			return
		}
		if err := brw.Flush(); err != nil {
			httputil.ErrorHandlerWithContext(ctx, rw, fmt.Errorf("response flush: %w", err))
			return
		}
		// Wait only for the first copier to finish. Returning then runs the
		// deferred conn.Close and close(backConnCloseCh), which closes both
		// connections and unblocks the other copier. After the receive below,
		// the 1-slot buffer leaves room for the second copier's send, so it
		// never blocks forever. The copy error itself is discarded.
		errCh := make(chan error, 1)
		spc := switchProtocolCopier{user: conn, backend: backConn}
		go spc.copyToBackend(errCh)
		go spc.copyFromBackend(errCh)
		<-errCh
	}
}
// upgradeType returns the value of the Upgrade header in h when the
// Connection header lists the "Upgrade" token (matched case-insensitively as
// a comma-separated token). Otherwise it returns "".
func upgradeType(h http.Header) string {
	if httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
		return h.Get("Upgrade")
	}
	return ""
}
// upgradeTypeFastHTTP is the fasthttp counterpart of upgradeType. It returns
// the value of the Upgrade header in h when the Connection header lists the
// "Upgrade" connection option, and "" otherwise.
//
// Connection options are comma-separated, case-insensitive tokens, so each
// element is trimmed of spaces/tabs and compared with bytes.EqualFold. This
// matches upgradeType's use of httpguts.HeaderValuesContainsToken. A plain
// substring search would reject a backend answering "Connection: upgrade"
// and would wrongly accept options such as "X-Upgrade-Hint".
func upgradeTypeFastHTTP(h fasthttpHeader) string {
	rest := h.Peek("Connection")
	for len(rest) > 0 {
		var opt []byte
		opt, rest, _ = bytes.Cut(rest, []byte(","))
		if bytes.EqualFold(bytes.Trim(opt, " \t"), []byte("Upgrade")) {
			return string(h.Peek("Upgrade"))
		}
	}
	return ""
}