traefik/vendor/github.com/vulcand/oxy/forward/fwd.go
2018-12-04 11:42:03 +01:00

576 lines
16 KiB
Go

// Package forward implements an HTTP handler that forwards requests to a remote server
// and serves back the response.
// Websocket proxying support is based on https://github.com/yhat/wsutil
package forward
import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"os"
"reflect"
"strings"
"time"
"github.com/gorilla/websocket"
log "github.com/sirupsen/logrus"
"github.com/vulcand/oxy/utils"
)
// OxyLogger is the logger interface used internally by the forwarder.
// It extends logrus' FieldLogger with access to the configured log level,
// which the forwarder checks before building debug-only log output.
type OxyLogger interface {
log.FieldLogger
// GetLevel returns the logger's current level.
GetLevel() log.Level
}
// internalLogger wraps a *logrus.Logger so that it can be used where an
// OxyLogger is expected; the GetLevel method is provided on this wrapper.
type internalLogger struct {
*log.Logger
}
// GetLevel returns the level currently set on the wrapped logrus logger.
func (i *internalLogger) GetLevel() log.Level {
	level := i.Logger.Level
	return level
}
// ReqRewriter can alter request headers and body.
// The forwarder calls Rewrite on the outgoing (copied) request before it is
// sent to the backend.
type ReqRewriter interface {
Rewrite(r *http.Request)
}
// optSetter is a functional option applied to a Forwarder by New.
// Returning a non-nil error makes New abort and return that error.
type optSetter func(f *Forwarder) error
// PassHostHeader controls whether the client's Host header is passed on to
// the backend (true) or replaced by the target host (false).
func PassHostHeader(b bool) optSetter {
	return func(fwd *Forwarder) error {
		fwd.passHost = b
		return nil
	}
}
// RoundTripper installs the http.RoundTripper used to reach backends.
// When this option is not supplied, New falls back to http.DefaultTransport.
func RoundTripper(r http.RoundTripper) optSetter {
	return func(fwd *Forwarder) error {
		fwd.roundTripper = r
		return nil
	}
}
// Rewriter installs the request rewriter applied to outgoing requests.
// When this option is not supplied, New installs a HeaderRewriter.
func Rewriter(r ReqRewriter) optSetter {
	apply := func(fwd *Forwarder) error {
		fwd.rewriter = r
		return nil
	}
	return apply
}
// WebsocketTLSClientConfig defines the TLS configuration used by the websocket
// client when dialing "wss" backends. When it is not supplied, New reuses the
// TLSClientConfig of the round tripper if that is an *http.Transport.
func WebsocketTLSClientConfig(tcc *tls.Config) optSetter {
	return func(fwd *Forwarder) error {
		fwd.tlsClientConfig = tcc
		return nil
	}
}
// ErrorHandler is a functional option that sets the handler invoked when
// forwarding fails. New defaults it to utils.DefaultHandler.
func ErrorHandler(h utils.ErrorHandler) optSetter {
	return func(fwd *Forwarder) error {
		fwd.handlerContext.errHandler = h
		return nil
	}
}
// BufferPool specifies the buffer pool handed to the httputil.ReverseProxy
// that serves plain HTTP requests.
func BufferPool(pool httputil.BufferPool) optSetter {
	return func(fwd *Forwarder) error {
		fwd.httpForwarder.bufferPool = pool
		return nil
	}
}
// Stream specifies if HTTP responses should be streamed.
// When streaming is off, New forces the flush interval to 0; when it is on and
// no interval was configured, New applies the default flush interval.
func Stream(stream bool) optSetter {
	apply := func(fwd *Forwarder) error {
		fwd.stream = stream
		return nil
	}
	return apply
}
// Logger defines the logger the forwarder will use.
//
// It defaults to logrus.StandardLogger(), the global logger used by logrus.
// The supplied logger must either implement OxyLogger or be a *logrus.Logger
// (which is wrapped to provide GetLevel); any other type makes the option
// return an error.
func Logger(l log.FieldLogger) optSetter {
	return func(fwd *Forwarder) error {
		// Case order matters: an OxyLogger is used as-is before the
		// *logrus.Logger wrapping is considered.
		switch typed := l.(type) {
		case OxyLogger:
			fwd.log = typed
		case *log.Logger:
			fwd.log = &internalLogger{Logger: typed}
		default:
			return errors.New("the type of the logger must be OxyLogger or logrus.Logger")
		}
		return nil
	}
}
// StateListener registers a callback that ServeHTTP invokes with
// StateConnected when it starts handling a request and with
// StateDisconnected when it returns.
func StateListener(stateListener UrlForwardingStateListener) optSetter {
	return func(fwd *Forwarder) error {
		fwd.stateListener = stateListener
		return nil
	}
}
// WebsocketConnectionClosedHook defines a hook called once a proxied websocket
// session is over, after both the client and the backend connections have
// been closed. It receives the original request and the client's underlying
// network connection.
func WebsocketConnectionClosedHook(hook func(req *http.Request, conn net.Conn)) optSetter {
	return func(fwd *Forwarder) error {
		fwd.websocketConnectionClosedHook = hook
		return nil
	}
}
// ResponseModifier defines the function installed as ModifyResponse on the
// reverse proxy that serves plain HTTP requests.
func ResponseModifier(responseModifier func(*http.Response) error) optSetter {
	apply := func(fwd *Forwarder) error {
		fwd.modifyResponse = responseModifier
		return nil
	}
	return apply
}
// StreamingFlushInterval defines the flush interval used while streaming
// responses. New only keeps it when streaming is enabled (see Stream).
func StreamingFlushInterval(flushInterval time.Duration) optSetter {
	return func(fwd *Forwarder) error {
		fwd.flushInterval = flushInterval
		return nil
	}
}
// ErrorHandlingRoundTripper is a round tripper that converts transport errors
// into HTTP responses by means of an error handler.
type ErrorHandlingRoundTripper struct {
// RoundTripper is the wrapped transport that performs the actual round trip.
http.RoundTripper
// errorHandler renders a response when the wrapped transport fails.
errorHandler utils.ErrorHandler
}
// RoundTrip delegates to the wrapped RoundTripper. A transport failure is not
// returned to the caller: the error handler renders it into an in-memory
// recorder, and the recorded response is returned together with a nil error.
func (rt ErrorHandlingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	res, err := rt.RoundTripper.RoundTrip(req)
	if err == nil {
		return res, nil
	}

	// The recorder from httptest is used because there isn't another public
	// implementation of a response recorder.
	rec := httptest.NewRecorder()
	rt.errorHandler.ServeHTTP(rec, req, err)
	return rec.Result(), nil
}
// Forwarder wraps two traffic forwarding implementations: HTTP and websockets.
// It decides based on the specified request which implementation to use.
type Forwarder struct {
// httpForwarder holds the forwarding configuration and implements both
// the HTTP and the websocket code paths.
*httpForwarder
// handlerContext carries the error handler used by both code paths.
*handlerContext
// stateListener, when non-nil, is notified when request handling starts
// and ends.
stateListener UrlForwardingStateListener
// stream tells New whether a flush interval should be kept/defaulted
// (true) or forced to zero (false).
stream bool
}
// handlerContext defines a handler context for error reporting and logging.
type handlerContext struct {
// errHandler is invoked to write an error response when forwarding fails.
errHandler utils.ErrorHandler
}
// httpForwarder is a handler that can reverse proxy
// HTTP traffic (and, via serveWebSocket, websocket traffic).
type httpForwarder struct {
// roundTripper is the transport handed to the reverse proxy.
roundTripper http.RoundTripper
// rewriter, when non-nil, is applied to every outgoing request.
rewriter ReqRewriter
// passHost keeps the client's Host header instead of replacing it with
// the target host.
passHost bool
// flushInterval is passed to the reverse proxy as its FlushInterval.
flushInterval time.Duration
// modifyResponse is passed to the reverse proxy as its ModifyResponse hook.
modifyResponse func(*http.Response) error
// tlsClientConfig is cloned for the websocket dialer when the backend
// scheme is "wss".
tlsClientConfig *tls.Config
// log is the logger used by all forwarding code paths.
log OxyLogger
// bufferPool is passed to the reverse proxy as its BufferPool.
bufferPool httputil.BufferPool
// websocketConnectionClosedHook, when non-nil, is called after a proxied
// websocket session has been torn down.
websocketConnectionClosedHook func(req *http.Request, conn net.Conn)
}
// defaultFlushInterval is the flush interval (100ms) New applies when
// streaming is enabled and no explicit interval was configured.
const defaultFlushInterval = time.Duration(100) * time.Millisecond
// Connection states reported to a UrlForwardingStateListener.
const (
// StateConnected is reported when ServeHTTP starts handling a request.
StateConnected = iota
// StateDisconnected is reported when ServeHTTP returns.
StateDisconnected
)
// UrlForwardingStateListener is a URL forwarding state listener. It is called
// with the request URL and one of the connection state constants
// (StateConnected, StateDisconnected).
type UrlForwardingStateListener func(*url.URL, int)
// New creates an instance of Forwarder based on the provided list of
// configuration options.
//
// Options are applied in order; the first one returning an error aborts
// construction. Afterwards defaults are filled in for everything left unset
// (flush interval, rewriter, round tripper, error handler, websocket TLS
// config) and the round tripper is wrapped in an ErrorHandlingRoundTripper.
func New(setters ...optSetter) (*Forwarder, error) {
	fwd := &Forwarder{
		httpForwarder:  &httpForwarder{log: &internalLogger{Logger: log.StandardLogger()}},
		handlerContext: &handlerContext{},
	}

	for _, apply := range setters {
		if err := apply(fwd); err != nil {
			return nil, err
		}
	}

	// A flush interval only makes sense when streaming; default it if unset.
	switch {
	case !fwd.stream:
		fwd.flushInterval = 0
	case fwd.flushInterval == 0:
		fwd.flushInterval = defaultFlushInterval
	}

	if fwd.rewriter == nil {
		hostname, err := os.Hostname()
		if err != nil {
			hostname = "localhost"
		}
		fwd.rewriter = &HeaderRewriter{TrustForwardHeader: true, Hostname: hostname}
	}

	if fwd.roundTripper == nil {
		fwd.roundTripper = http.DefaultTransport
	}

	if fwd.errHandler == nil {
		fwd.errHandler = utils.DefaultHandler
	}

	// Reuse the transport's TLS settings for websockets unless configured
	// explicitly. This must happen before the transport is wrapped below.
	if fwd.tlsClientConfig == nil {
		if transport, ok := fwd.roundTripper.(*http.Transport); ok {
			fwd.tlsClientConfig = transport.TLSClientConfig
		}
	}

	fwd.roundTripper = ErrorHandlingRoundTripper{
		RoundTripper: fwd.roundTripper,
		errorHandler: fwd.errHandler,
	}

	fwd.postConfig()

	return fwd, nil
}
// ServeHTTP decides which forwarder to use based on the specified
// request and delegates to the proper implementation: websocket handshake
// requests go to serveWebSocket, everything else to serveHTTP.
// A configured state listener is notified before and after the request.
func (f *Forwarder) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	if f.log.GetLevel() >= log.DebugLevel {
		entry := f.log.WithField("Request", utils.DumpHttpRequest(req))
		entry.Debug("vulcand/oxy/forward: begin ServeHttp on request")
		defer entry.Debug("vulcand/oxy/forward: completed ServeHttp on request")
	}

	if listener := f.stateListener; listener != nil {
		listener(req.URL, StateConnected)
		defer listener(req.URL, StateDisconnected)
	}

	if !IsWebsocketRequest(req) {
		f.httpForwarder.serveHTTP(w, req, f.handlerContext)
		return
	}
	f.httpForwarder.serveWebSocket(w, req, f.handlerContext)
}
// getUrlFromRequest returns the URL that describes the path and query of req.
//
// If the Request was created by Go via a real HTTP request, RequestURI
// contains the original request target and is parsed and returned. If the
// Request was created in code, RequestURI is empty and req.URL is returned.
// req.URL is also returned (after logging a warning) when RequestURI cannot
// be parsed.
func (f *httpForwarder) getUrlFromRequest(req *http.Request) *url.URL {
	if req.RequestURI == "" {
		return req.URL
	}

	parsed, err := url.ParseRequestURI(req.RequestURI)
	if err != nil {
		f.log.Warnf("vulcand/oxy/forward: error when parsing RequestURI: %s", err)
		return req.URL
	}
	return parsed
}
// modifyRequest rewrites outReq in place so that it addresses target: scheme
// and host come from target, path and query from the original request target
// (see getUrlFromRequest), and the protocol is forced to HTTP/1.1. The
// configured rewriter, if any, is applied afterwards.
func (f *httpForwarder) modifyRequest(outReq *http.Request, target *url.URL) {
	rewritten := utils.CopyURL(outReq.URL)
	rewritten.Scheme = target.Scheme
	rewritten.Host = target.Host
	outReq.URL = rewritten

	// Must run before RequestURI is cleared: the original request target is
	// read from it.
	src := f.getUrlFromRequest(outReq)
	outReq.URL.Path, outReq.URL.RawPath, outReq.URL.RawQuery = src.Path, src.RawPath, src.RawQuery
	outReq.RequestURI = "" // Outgoing request should not have RequestURI

	outReq.Proto, outReq.ProtoMajor, outReq.ProtoMinor = "HTTP/1.1", 1, 1

	if f.rewriter != nil {
		f.rewriter.Rewrite(outReq)
	}

	// Do not pass client Host header unless optsetter PassHostHeader is set.
	if f.passHost {
		return
	}
	outReq.Host = target.Host
}
// serveWebSocket forwards websocket traffic.
//
// It dials the backend described by a rewritten copy of req and then:
//   - if the dial fails without any backend response, delegates to the error
//     handler;
//   - if the dial fails with a backend response (the handshake was refused),
//     hijacks the client connection and relays that response as-is;
//   - otherwise upgrades the client connection and replicates messages in
//     both directions until either direction reports an error or a close.
//
// Both connections are closed on return, after which the optional
// websocketConnectionClosedHook is invoked.
func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request, ctx *handlerContext) {
	if f.log.GetLevel() >= log.DebugLevel {
		logEntry := f.log.WithField("Request", utils.DumpHttpRequest(req))
		logEntry.Debug("vulcand/oxy/forward/websocket: begin ServeHttp on request")
		defer logEntry.Debug("vulcand/oxy/forward/websocket: completed ServeHttp on request")
	}

	outReq := f.copyWebSocketRequest(req)

	// Work on a copy of the default dialer. websocket.DefaultDialer is a
	// package-level pointer shared by the whole process; assigning
	// TLSClientConfig through it would leak this forwarder's TLS settings to
	// every other user of the default dialer and race between concurrent
	// requests.
	dialer := *websocket.DefaultDialer
	if outReq.URL.Scheme == "wss" && f.tlsClientConfig != nil {
		dialer.TLSClientConfig = f.tlsClientConfig.Clone()
		// WebSocket is only in http/1.1
		dialer.TLSClientConfig.NextProtos = []string{"http/1.1"}
	}

	targetConn, resp, err := dialer.DialContext(outReq.Context(), outReq.URL.String(), outReq.Header)
	if err != nil {
		if resp == nil {
			ctx.errHandler.ServeHTTP(w, req, err)
		} else {
			// The backend answered the handshake with a regular HTTP response:
			// relay it to the client over the raw connection.
			f.log.Errorf("vulcand/oxy/forward/websocket: Error dialing %q: %v with resp: %d %s", outReq.Host, err, resp.StatusCode, resp.Status)
			hijacker, ok := w.(http.Hijacker)
			if !ok {
				f.log.Errorf("vulcand/oxy/forward/websocket: %s can not be hijack", reflect.TypeOf(w))
				ctx.errHandler.ServeHTTP(w, req, err)
				return
			}

			conn, _, errHijack := hijacker.Hijack()
			if errHijack != nil {
				f.log.Errorf("vulcand/oxy/forward/websocket: Failed to hijack responseWriter")
				ctx.errHandler.ServeHTTP(w, req, errHijack)
				return
			}
			defer conn.Close()

			errWrite := resp.Write(conn)
			if errWrite != nil {
				f.log.Errorf("vulcand/oxy/forward/websocket: Failed to forward response")
				ctx.errHandler.ServeHTTP(w, req, errWrite)
				return
			}
		}
		return
	}

	// Only the targetConn choose to CheckOrigin or not
	upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool {
		return true
	}}

	// Pass the backend's handshake response headers on to the client, minus
	// the upgrade headers, which the upgrader generates itself.
	utils.RemoveHeaders(resp.Header, WebsocketUpgradeHeaders...)
	utils.CopyHeaders(resp.Header, w.Header())

	underlyingConn, err := upgrader.Upgrade(w, req, resp.Header)
	if err != nil {
		f.log.Errorf("vulcand/oxy/forward/websocket: Error while upgrading connection : %v", err)
		return
	}
	defer func() {
		underlyingConn.Close()
		targetConn.Close()
		if f.websocketConnectionClosedHook != nil {
			f.websocketConnectionClosedHook(req, underlyingConn.UnderlyingConn())
		}
	}()

	// Buffered so that each replicating goroutine can report its single
	// terminal error and exit even if nobody is receiving anymore.
	errClient := make(chan error, 1)
	errBackend := make(chan error, 1)

	// replicateWebsocketConn copies every message read from src to dst until
	// reading or writing fails, then reports the error on errc.
	replicateWebsocketConn := func(dst, src *websocket.Conn, errc chan error) {
		forward := func(messageType int, reader io.Reader) error {
			writer, err := dst.NextWriter(messageType)
			if err != nil {
				return err
			}
			_, err = io.Copy(writer, reader)
			if err != nil {
				return err
			}
			return writer.Close()
		}

		// Relay control frames to the peer instead of answering them locally.
		src.SetPingHandler(func(data string) error {
			return forward(websocket.PingMessage, bytes.NewReader([]byte(data)))
		})
		src.SetPongHandler(func(data string) error {
			return forward(websocket.PongMessage, bytes.NewReader([]byte(data)))
		})

		for {
			msgType, reader, err := src.NextReader()
			if err != nil {
				// m is the close frame sent to dst; nil means no frame is sent.
				m := websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("%v", err))
				if e, ok := err.(*websocket.CloseError); ok {
					if e.Code != websocket.CloseNoStatusReceived {
						m = nil
						// Following codes are not valid on the wire so just close the
						// underlying TCP connection without sending a close frame.
						if e.Code != websocket.CloseAbnormalClosure &&
							e.Code != websocket.CloseTLSHandshake {
							m = websocket.FormatCloseMessage(e.Code, e.Text)
						}
					}
				}
				errc <- err
				if m != nil {
					// Best effort: the session is ending anyway, so a failure
					// to deliver the close frame is ignored.
					forward(websocket.CloseMessage, bytes.NewReader([]byte(m)))
				}
				break
			}
			err = forward(msgType, reader)
			if err != nil {
				errc <- err
				break
			}
		}
	}

	go replicateWebsocketConn(underlyingConn, targetConn, errClient)
	go replicateWebsocketConn(targetConn, underlyingConn, errBackend)

	// Wait for the first direction to terminate; the deferred close above then
	// tears down both connections.
	var message string
	select {
	case err = <-errClient:
		message = "vulcand/oxy/forward/websocket: Error when copying from backend to client: %v"
	case err = <-errBackend:
		message = "vulcand/oxy/forward/websocket: Error when copying from client to backend: %v"
	}

	// Regular close errors are expected; only log everything else.
	if e, ok := err.(*websocket.CloseError); !ok || e.Code == websocket.CloseAbnormalClosure {
		f.log.Errorf(message, err)
	}
}
// copyWebSocketRequest makes a copy of the specified request that is suitable
// for dialing the websocket backend: the URL scheme is translated to ws/wss,
// path and query are taken from the original request target, the headers are
// copied without the websocket dial headers, and the configured rewriter (if
// any) is applied.
func (f *httpForwarder) copyWebSocketRequest(req *http.Request) (outReq *http.Request) {
	clone := *req // shallow copy; URL and Header are replaced below
	outReq = &clone

	outReq.URL = utils.CopyURL(req.URL)
	outReq.URL.Host = req.URL.Host

	// Sometimes backends might be registered as HTTP/HTTPS servers, so
	// translate those URLs to websocket URLs.
	switch req.URL.Scheme {
	case "https":
		outReq.URL.Scheme = "wss"
	case "http":
		outReq.URL.Scheme = "ws"
	default:
		outReq.URL.Scheme = req.URL.Scheme
	}

	// Must run before RequestURI is cleared: the original request target is
	// read from it.
	src := f.getUrlFromRequest(outReq)
	outReq.URL.Path, outReq.URL.RawPath, outReq.URL.RawQuery = src.Path, src.RawPath, src.RawQuery
	outReq.RequestURI = "" // Outgoing request should not have RequestURI

	if !f.passHost {
		outReq.Host = req.URL.Host
	}

	headers := make(http.Header)
	// gorilla websocket use this header to set the request.Host tested in checkSameOrigin
	headers.Set("Host", outReq.Host)
	utils.CopyHeaders(headers, req.Header)
	utils.RemoveHeaders(headers, WebsocketDialHeaders...)
	outReq.Header = headers

	if f.rewriter != nil {
		f.rewriter.Rewrite(outReq)
	}
	return outReq
}
// serveHTTP forwards HTTP traffic using the configured transport.
//
// The request is served through an httputil.ReverseProxy whose Director
// rewrites the outgoing request towards inReq.URL (see modifyRequest). When
// debug logging is enabled, the response is observed through a proxy writer
// so that status code, length and round-trip duration can be logged.
func (f *httpForwarder) serveHTTP(w http.ResponseWriter, inReq *http.Request, ctx *handlerContext) {
	if f.log.GetLevel() >= log.DebugLevel {
		logEntry := f.log.WithField("Request", utils.DumpHttpRequest(inReq))
		logEntry.Debug("vulcand/oxy/forward/http: begin ServeHttp on request")
		defer logEntry.Debug("vulcand/oxy/forward/http: completed ServeHttp on request")
	}

	// Keep the monotonic clock reading (calling UTC() would strip it) so the
	// logged duration is not affected by wall-clock adjustments.
	start := time.Now()

	outReq := new(http.Request)
	*outReq = *inReq // includes shallow copies of maps, but we handle this in Director

	revproxy := httputil.ReverseProxy{
		Director: func(req *http.Request) {
			f.modifyRequest(req, inReq.URL)
		},
		Transport:      f.roundTripper,
		FlushInterval:  f.flushInterval,
		ModifyResponse: f.modifyResponse,
		BufferPool:     f.bufferPool,
	}

	if f.log.GetLevel() >= log.DebugLevel {
		pw := utils.NewProxyWriter(w)
		revproxy.ServeHTTP(pw, outReq)

		if inReq.TLS != nil {
			f.log.Debugf("vulcand/oxy/forward/http: Round trip: %v, code: %v, Length: %v, duration: %v tls:version: %x, tls:resume:%t, tls:csuite:%x, tls:server:%v",
				inReq.URL, pw.StatusCode(), pw.GetLength(), time.Since(start),
				inReq.TLS.Version,
				inReq.TLS.DidResume,
				inReq.TLS.CipherSuite,
				inReq.TLS.ServerName)
		} else {
			f.log.Debugf("vulcand/oxy/forward/http: Round trip: %v, code: %v, Length: %v, duration: %v",
				inReq.URL, pw.StatusCode(), pw.GetLength(), time.Since(start))
		}
	} else {
		revproxy.ServeHTTP(w, outReq)
	}

	// If the response carries trailers (header keys with the trailer prefix),
	// flush once so they are sent out.
	for key := range w.Header() {
		if strings.HasPrefix(key, http.TrailerPrefix) {
			if fl, ok := w.(http.Flusher); ok {
				fl.Flush()
			}
			break
		}
	}
}
// IsWebsocketRequest determines if the specified HTTP request is a
// websocket handshake request, i.e. its Connection header lists "upgrade"
// and its Upgrade header lists "websocket" (both matched case-insensitively).
//
// Every occurrence of each header is inspected: a header may legally be sent
// as several lines instead of one comma-separated list, and looking only at
// the first line would misclassify such handshakes as plain HTTP.
func IsWebsocketRequest(req *http.Request) bool {
	containsHeader := func(name, value string) bool {
		for _, line := range req.Header[http.CanonicalHeaderKey(name)] {
			for _, item := range strings.Split(line, ",") {
				if value == strings.ToLower(strings.TrimSpace(item)) {
					return true
				}
			}
		}
		return false
	}
	return containsHeader(Connection, "upgrade") && containsHeader(Upgrade, "websocket")
}