package fast

import (
	"bufio"
	"bytes"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/http/httptrace"
	"net/http/httputil"
	"net/url"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/rs/zerolog/log"
	proxyhttputil "github.com/traefik/traefik/v3/pkg/proxy/httputil"
	"github.com/valyala/fasthttp"
	"golang.org/x/net/http/httpguts"
)

const (
	bufferSize = 32 * 1024
	bufioSize  = 64 * 1024
)

var hopHeaders = []string{
	"Connection",
	"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te",      // canonicalized version of "TE"
	"Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
	"Transfer-Encoding",
	"Upgrade",
}
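
// pool is a small generic wrapper around sync.Pool that spares callers the
// type assertion and returns the zero value of T when the pool is empty.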
type pool[T any] struct {
	pool sync.Pool
}

func (p *pool[T]) Get() T {
	if tmp := p.pool.Get(); tmp != nil {
		return tmp.(T)
	}

	var res T
	return res
}

func (p *pool[T]) Put(x T) {
	p.pool.Put(x)
}
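
// buffConn reads through the buffered reader sitting on top of the connection,
// so bytes already buffered during the response read are not lost when the
// connection is handed over (e.g. after a protocol upgrade).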
type buffConn struct {
	*bufio.Reader
	net.Conn
}

func (b buffConn) Read(p []byte) (int, error) {
	return b.Reader.Read(p)
}
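
// writeDetector records whether any bytes were written to the underlying
// connection, which tells the retry logic whether the request may have
// already reached the backend.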
type writeDetector struct {
	net.Conn

	written bool
}

func (w *writeDetector) Write(p []byte) (int, error) {
	n, err := w.Conn.Write(p)
	if n > 0 {
		w.written = true
	}

	return n, err
}
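
// writeFlusher flushes the underlying writer after every write, so streamed
// (chunked) response data reaches the client without buffering delays.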
type writeFlusher struct {
	io.Writer
}

func (w *writeFlusher) Write(b []byte) (int, error) {
	n, err := w.Writer.Write(b)
	if f, ok := w.Writer.(http.Flusher); ok {
		f.Flush()
	}

	return n, err
}
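
// timeoutError wraps an error and implements net.Error, reporting it as a
// non-temporary timeout.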
type timeoutError struct {
	error
}

func (t timeoutError) Timeout() bool {
	return true
}

func (t timeoutError) Temporary() bool {
	return false
}

// ReverseProxy is the FastProxy reverse proxy implementation.
type ReverseProxy struct {
	debug bool

	connPool *connPool

	bufferPool      pool[[]byte]
	readerPool      pool[*bufio.Reader]
	writerPool      pool[*bufio.Writer]
	limitReaderPool pool[*io.LimitedReader]

	proxyAuth string

	targetURL             *url.URL
	passHostHeader        bool
	responseHeaderTimeout time.Duration
}

// NewReverseProxy creates a new ReverseProxy.
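//
// Illustrative usage sketch within this package (cp is assumed to be a
// *connPool obtained elsewhere; this is not a prescribed API):
//
//	target, _ := url.Parse("http://127.0.0.1:8080")
//	proxy, _ := NewReverseProxy(target, nil, false, true, 30*time.Second, cp)
//	_ = http.ListenAndServe(":8000", proxy)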
func NewReverseProxy(targetURL *url.URL, proxyURL *url.URL, debug, passHostHeader bool, responseHeaderTimeout time.Duration, connPool *connPool) (*ReverseProxy, error) {
	var proxyAuth string
	if proxyURL != nil && proxyURL.User != nil && targetURL.Scheme == "http" {
		username := proxyURL.User.Username()
		password, _ := proxyURL.User.Password()
		proxyAuth = "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
	}

	return &ReverseProxy{
		debug:                 debug,
		passHostHeader:        passHostHeader,
		targetURL:             targetURL,
		proxyAuth:             proxyAuth,
		connPool:              connPool,
		responseHeaderTimeout: responseHeaderTimeout,
	}, nil
}
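
// ServeHTTP implements http.Handler: it builds a fasthttp request from req,
// forwards it to the configured target, and writes the backend response back
// to rw.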
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	if req.Body != nil {
		defer req.Body.Close()
	}

	outReq := fasthttp.AcquireRequest()
	defer fasthttp.ReleaseRequest(outReq)

	// Normalizing is not required, as the headers are already normalized by net/http.
	outReq.Header.DisableNormalizing()

	for k, v := range req.Header {
		for _, s := range v {
			outReq.Header.Add(k, s)
		}
	}

	removeConnectionHeaders(&outReq.Header)

	for _, header := range hopHeaders {
		outReq.Header.Del(header)
	}

	if p.proxyAuth != "" {
		outReq.Header.Set("Proxy-Authorization", p.proxyAuth)
	}

	if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
		outReq.Header.Set("Te", "trailers")
	}

	if p.debug {
		outReq.Header.Set("X-Traefik-Fast-Proxy", "enabled")
	}

	reqUpType := upgradeType(req.Header)
	if !isGraphic(reqUpType) {
		proxyhttputil.ErrorHandler(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType))
		return
	}

	if reqUpType != "" {
		outReq.Header.Set("Connection", "Upgrade")
		outReq.Header.Set("Upgrade", reqUpType)
		if reqUpType == "websocket" {
			cleanWebSocketHeaders(&outReq.Header)
		}
	}
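
	// Rebuild the outgoing URL on top of the configured target, keeping the
	// incoming path and query.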
	u2 := new(url.URL)
	*u2 = *req.URL
	u2.Scheme = p.targetURL.Scheme
	u2.Host = p.targetURL.Host

	u := req.URL
	if req.RequestURI != "" {
		parsedURL, err := url.ParseRequestURI(req.RequestURI)
		if err == nil {
			u = parsedURL
		}
	}

	u2.Path = u.Path
	u2.RawPath = u.RawPath
	u2.RawQuery = strings.ReplaceAll(u.RawQuery, ";", "&")

	outReq.SetHost(u2.Host)
	outReq.Header.SetHost(u2.Host)

	if p.passHostHeader {
		outReq.Header.SetHost(req.Host)
	}

	outReq.SetRequestURI(u2.RequestURI())

	outReq.SetBodyStream(req.Body, int(req.ContentLength))

	outReq.Header.SetMethod(req.Method)

	if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
		// If we aren't the first proxy retain prior
		// X-Forwarded-For information as a comma+space
		// separated list and fold multiple headers into one.
		prior, ok := req.Header["X-Forwarded-For"]
		if len(prior) > 0 {
			clientIP = strings.Join(prior, ", ") + ", " + clientIP
		}

		omit := ok && prior == nil // Go Issue 38079: nil now means don't populate the header
		if !omit {
			outReq.Header.Set("X-Forwarded-For", clientIP)
		}
	}

	if err := p.roundTrip(rw, req, outReq, reqUpType); err != nil {
		proxyhttputil.ErrorHandler(rw, req, err)
	}
}

// Note that unlike the net/http RoundTrip:
// - we do not support forwarding "100 Continue" responses as-is to the client.
// - we do not ask for a compressed response automatically, because that would add an extra cost when the
//   client asks for an uncompressed response (we would have to un-compress it), and nowadays most clients
//   already ask for a compressed response (allowing "passthrough" compression).
func (p *ReverseProxy) roundTrip(rw http.ResponseWriter, req *http.Request, outReq *fasthttp.Request, reqUpType string) error {
	ctx := req.Context()
	trace := httptrace.ContextClientTrace(ctx)

	var co *conn
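	// Write the request on a pooled connection. If the write fails, the connection
	// is discarded and the write is retried on a fresh connection, unless bytes
	// were already written and the request cannot safely be replayed.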
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()

		default:
		}

		var err error
		co, err = p.connPool.AcquireConn()
		if err != nil {
			return fmt.Errorf("acquire connection: %w", err)
		}

		wd := &writeDetector{Conn: co}

		err = p.writeRequest(wd, outReq)
		if wd.written && trace != nil && trace.WroteRequest != nil {
			// WroteRequest hook is used by the tracing middleware to detect if the request has been written.
			trace.WroteRequest(httptrace.WroteRequestInfo{})
		}
		if err == nil {
			break
		}

		log.Ctx(ctx).Debug().Err(err).Msg("Error while writing request")

		co.Close()

		if wd.written && !isReplayable(req) {
			return err
		}
	}

	br := p.readerPool.Get()
	if br == nil {
		br = bufio.NewReaderSize(co, bufioSize)
	}
	defer p.readerPool.Put(br)

	br.Reset(co)

	res := fasthttp.AcquireResponse()
	defer fasthttp.ReleaseResponse(res)

	res.Header.SetNoDefaultContentType(true)
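
	// Read the response headers, forwarding 1xx informational responses to the
	// client and looping until a final status (or 101) is received.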
	for {
		var timer *time.Timer
		errTimeout := atomic.Pointer[timeoutError]{}
		if p.responseHeaderTimeout > 0 {
			timer = time.AfterFunc(p.responseHeaderTimeout, func() {
				errTimeout.Store(&timeoutError{errors.New("timeout awaiting response headers")})
				co.Close()
			})
		}

		res.Header.SetNoDefaultContentType(true)
		if err := res.Header.Read(br); err != nil {
			if p.responseHeaderTimeout > 0 {
				if errT := errTimeout.Load(); errT != nil {
					return errT
				}
			}
			co.Close()
			return err
		}

		if timer != nil {
			timer.Stop()
		}

		fixPragmaCacheControl(&res.Header)

		resCode := res.StatusCode()
		is1xx := 100 <= resCode && resCode <= 199
		// treat 101 as a terminal status, see issue 26161
		is1xxNonTerminal := is1xx && resCode != http.StatusSwitchingProtocols
		if is1xxNonTerminal {
			removeConnectionHeaders(&res.Header)
			h := rw.Header()

			for _, header := range hopHeaders {
				res.Header.Del(header)
			}

			res.Header.VisitAll(func(key, value []byte) {
				rw.Header().Add(string(key), string(value))
			})

			rw.WriteHeader(res.StatusCode())
			// Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses
			for k := range h {
				delete(h, k)
			}

			res.Reset()
			res.Header.Reset()
			res.Header.SetNoDefaultContentType(true)

			continue
		}
		break
	}

	announcedTrailers := res.Header.Peek("Trailer")

	// Deal with 101 Switching Protocols responses (WebSocket, h2c, etc.).
	if res.StatusCode() == http.StatusSwitchingProtocols {
		// As the connection has been hijacked, it cannot be added back to the pool.
		handleUpgradeResponse(rw, req, reqUpType, res, buffConn{Conn: co, Reader: br})
		return nil
	}

	removeConnectionHeaders(&res.Header)

	for _, header := range hopHeaders {
		res.Header.Del(header)
	}

	if len(announcedTrailers) > 0 {
		res.Header.Add("Trailer", string(announcedTrailers))
	}

	res.Header.VisitAll(func(key, value []byte) {
		rw.Header().Add(string(key), string(value))
	})

	rw.WriteHeader(res.StatusCode())

	// Chunked response: fasthttp sets ContentLength to -1 when a "Transfer-Encoding: chunked" header is received.
	if res.Header.ContentLength() == -1 {
		cbr := httputil.NewChunkedReader(br)

		b := p.bufferPool.Get()
		if b == nil {
			b = make([]byte, bufferSize)
		}
		defer p.bufferPool.Put(b)

		if _, err := io.CopyBuffer(&writeFlusher{rw}, cbr, b); err != nil {
			co.Close()
			return err
		}

		res.Header.Reset()
		res.Header.SetNoDefaultContentType(true)
		if err := res.Header.ReadTrailer(br); err != nil {
			co.Close()
			return err
		}

		if res.Header.Len() > 0 {
			var announcedTrailersKey []string
			if len(announcedTrailers) > 0 {
				announcedTrailersKey = strings.Split(string(announcedTrailers), ",")
			}

			res.Header.VisitAll(func(key, value []byte) {
				for _, s := range announcedTrailersKey {
					if strings.EqualFold(s, strings.TrimSpace(string(key))) {
						rw.Header().Add(string(key), string(value))
						return
					}
				}

				rw.Header().Add(http.TrailerPrefix+string(key), string(value))
			})
		}

		p.connPool.ReleaseConn(co)

		return nil
	}
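
	// Fixed-length response: copy exactly Content-Length bytes to the client, then
	// return the connection to the pool.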
	brl := p.limitReaderPool.Get()
	if brl == nil {
		brl = &io.LimitedReader{}
	}
	defer p.limitReaderPool.Put(brl)

	brl.R = br
	brl.N = int64(res.Header.ContentLength())

	b := p.bufferPool.Get()
	if b == nil {
		b = make([]byte, bufferSize)
	}
	defer p.bufferPool.Put(b)

	if _, err := io.CopyBuffer(rw, brl, b); err != nil {
		co.Close()
		return err
	}

	p.connPool.ReleaseConn(co)

	return nil
}
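
// writeRequest writes outReq to co through a pooled bufio.Writer and flushes it.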
func (p *ReverseProxy) writeRequest(co net.Conn, outReq *fasthttp.Request) error {
	bw := p.writerPool.Get()
	if bw == nil {
		bw = bufio.NewWriterSize(co, bufioSize)
	}
	defer p.writerPool.Put(bw)

	bw.Reset(co)

	if err := outReq.Write(bw); err != nil {
		return err
	}

	return bw.Flush()
}

// isReplayable returns whether the request is replayable.
func isReplayable(req *http.Request) bool {
	if req.Body == nil || req.Body == http.NoBody {
		switch req.Method {
		case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
			return true
		}

		// The Idempotency-Key, while non-standard, is widely used to
		// mean a POST or other request is idempotent. See
		// https://golang.org/issue/19943#issuecomment-421092421
		if _, ok := req.Header["Idempotency-Key"]; ok {
			return true
		}

		if _, ok := req.Header["X-Idempotency-Key"]; ok {
			return true
		}
	}

	return false
}

// isGraphic returns whether s is ASCII and printable according to
// https://tools.ietf.org/html/rfc20#section-4.2.
func isGraphic(s string) bool {
	for i := range len(s) {
		if s[i] < ' ' || s[i] > '~' {
			return false
		}
	}

	return true
}
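
// fasthttpHeader is the subset of the fasthttp header API used by the helpers
// below, so they can operate on both request and response headers.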
type fasthttpHeader interface {
	Peek(key string) []byte
	Set(key string, value string)
	SetBytesV(key string, value []byte)
	DelBytes(key []byte)
	Del(key string)
}

// removeConnectionHeaders removes hop-by-hop headers listed in the "Connection" header of h.
// See RFC 7230, section 6.1.
func removeConnectionHeaders(h fasthttpHeader) {
	f := h.Peek(fasthttp.HeaderConnection)
	for _, sf := range bytes.Split(f, []byte{','}) {
		if sf = bytes.TrimSpace(sf); len(sf) > 0 {
			h.DelBytes(sf)
		}
	}
}

// fixPragmaCacheControl treats "Pragma: no-cache" like "Cache-Control: no-cache",
// per RFC 7234, section 5.4.
func fixPragmaCacheControl(header fasthttpHeader) {
	if pragma := header.Peek("Pragma"); bytes.Equal(pragma, []byte("no-cache")) {
		if len(header.Peek("Cache-Control")) == 0 {
			header.Set("Cache-Control", "no-cache")
		}
	}
}

// cleanWebSocketHeaders ensures the Sec-WebSocket-* headers keep their expected casing.
// Even though the WebSocket RFC says header names are case-insensitive, some servers
// require Sec-WebSocket-Key, Sec-WebSocket-Extensions, Sec-WebSocket-Accept,
// Sec-WebSocket-Protocol and Sec-WebSocket-Version to be spelled exactly this way.
// https://tools.ietf.org/html/rfc6455#page-20
func cleanWebSocketHeaders(headers fasthttpHeader) {
	headers.SetBytesV("Sec-WebSocket-Key", headers.Peek("Sec-Websocket-Key"))
	headers.Del("Sec-Websocket-Key")

	headers.SetBytesV("Sec-WebSocket-Extensions", headers.Peek("Sec-Websocket-Extensions"))
	headers.Del("Sec-Websocket-Extensions")

	headers.SetBytesV("Sec-WebSocket-Accept", headers.Peek("Sec-Websocket-Accept"))
	headers.Del("Sec-Websocket-Accept")

	headers.SetBytesV("Sec-WebSocket-Protocol", headers.Peek("Sec-Websocket-Protocol"))
	headers.Del("Sec-Websocket-Protocol")

	headers.SetBytesV("Sec-WebSocket-Version", headers.Peek("Sec-Websocket-Version"))
	headers.Del("Sec-Websocket-Version")
}