diff --git a/middlewares/websocket.go b/middlewares/websocket.go new file mode 100644 index 000000000..ac7d6a035 --- /dev/null +++ b/middlewares/websocket.go @@ -0,0 +1,52 @@ +/* +Copyright +*/ +package middlewares + +import ( + log "github.com/Sirupsen/logrus" + "github.com/mailgun/oxy/roundrobin" + "net/http" + "strings" + "time" +) + +type WebsocketUpgrader struct { + rr *roundrobin.RoundRobin +} + +func NewWebsocketUpgrader(rr *roundrobin.RoundRobin) *WebsocketUpgrader { + wu := WebsocketUpgrader{ + rr: rr, + } + return &wu +} + +func (u *WebsocketUpgrader) ServeHTTP(w http.ResponseWriter, req *http.Request) { + // If request is websocket, serve with golang websocket server to do protocol handshake + if strings.Join(req.Header["Upgrade"], "") == "websocket" { + start := time.Now().UTC() + url, err := u.rr.NextServer() + if err != nil { + log.Errorf("Can't round robin in websocket middleware") + return + } + log.Debugf("Websocket forward to %s", url.String()) + NewProxy(url).ServeHTTP(w, req) + + if req.TLS != nil { + log.Debugf("Round trip: %v, duration: %v tls:version: %x, tls:resume:%t, tls:csuite:%x, tls:server:%v", + req.URL, time.Now().UTC().Sub(start), + req.TLS.Version, + req.TLS.DidResume, + req.TLS.CipherSuite, + req.TLS.ServerName) + } else { + log.Debugf("Round trip: %v, duration: %v", + req.URL, time.Now().UTC().Sub(start)) + } + + return + } + u.rr.ServeHTTP(w, req) +} diff --git a/middlewares/websocketproxy.go b/middlewares/websocketproxy.go new file mode 100644 index 000000000..abb82688f --- /dev/null +++ b/middlewares/websocketproxy.go @@ -0,0 +1,170 @@ +package middlewares + +import ( + "io" + "net" + "net/http" + "net/url" + "strings" + + log "github.com/Sirupsen/logrus" + "github.com/gorilla/websocket" +) + +// Original developpement made by https://github.com/koding/websocketproxy +var ( + // DefaultUpgrader specifies the parameters for upgrading an HTTP + // connection to a WebSocket connection. + DefaultUpgrader = &websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + } + + // DefaultDialer is a dialer with all fields set to the default zero values. + DefaultDialer = websocket.DefaultDialer +) + +// WebsocketProxy is an HTTP Handler that takes an incoming WebSocket +// connection and proxies it to another server. +type WebsocketProxy struct { + // Backend returns the backend URL which the proxy uses to reverse proxy + // the incoming WebSocket connection. Request is the initial incoming and + // unmodified request. + Backend func(*http.Request) *url.URL + + // Upgrader specifies the parameters for upgrading a incoming HTTP + // connection to a WebSocket connection. If nil, DefaultUpgrader is used. + Upgrader *websocket.Upgrader + + // Dialer contains options for connecting to the backend WebSocket server. + // If nil, DefaultDialer is used. + Dialer *websocket.Dialer +} + +// ProxyHandler returns a new http.Handler interface that reverse proxies the +// request to the given target. +func ProxyHandler(target *url.URL) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + NewProxy(target).ServeHTTP(rw, req) + }) +} + +// NewProxy returns a new Websocket reverse proxy that rewrites the +// URL's to the scheme, host and base path provider in target. +func NewProxy(target *url.URL) *WebsocketProxy { + backend := func(r *http.Request) *url.URL { + // Shallow copy + u := *target + u.Fragment = r.URL.Fragment + u.Path = r.URL.Path + u.RawQuery = r.URL.RawQuery + rurl := u.String() + if strings.HasPrefix(rurl, "http") { + u.Scheme = "ws" + } + if strings.HasPrefix(rurl, "https") { + u.Scheme = "wss" + } + return &u + } + return &WebsocketProxy{Backend: backend} +} + +// ServeHTTP implements the http.Handler that proxies WebSocket connections. +func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + if w.Backend == nil { + log.Println("websocketproxy: backend function is not defined") + http.Error(rw, "internal server error (code: 1)", http.StatusInternalServerError) + return + } + + backendURL := w.Backend(req) + if backendURL == nil { + log.Println("websocketproxy: backend URL is nil") + http.Error(rw, "internal server error (code: 2)", http.StatusInternalServerError) + return + } + + dialer := w.Dialer + if w.Dialer == nil { + dialer = DefaultDialer + } + + // Pass headers from the incoming request to the dialer to forward them to + // the final destinations. + requestHeader := http.Header{} + requestHeader.Add("Origin", req.Header.Get("Origin")) + for _, prot := range req.Header[http.CanonicalHeaderKey("Sec-WebSocket-Protocol")] { + requestHeader.Add("Sec-WebSocket-Protocol", prot) + } + for _, cookie := range req.Header[http.CanonicalHeaderKey("Cookie")] { + requestHeader.Add("Cookie", cookie) + } + + // Pass X-Forwarded-For headers too, code below is a part of + // httputil.ReverseProxy. See http://en.wikipedia.org/wiki/X-Forwarded-For + // for more information + // TODO: use RFC7239 http://tools.ietf.org/html/rfc7239 + if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { + // If we aren't the first proxy retain prior + // X-Forwarded-For information as a comma+space + // separated list and fold multiple headers into one. + if prior, ok := req.Header["X-Forwarded-For"]; ok { + clientIP = strings.Join(prior, ", ") + ", " + clientIP + } + requestHeader.Set("X-Forwarded-For", clientIP) + } + + // Set the originating protocol of the incoming HTTP request. The SSL might + // be terminated on our site and because we doing proxy adding this would + // be helpful for applications on the backend. + requestHeader.Set("X-Forwarded-Proto", "http") + if req.TLS != nil { + requestHeader.Set("X-Forwarded-Proto", "https") + } + + // Connect to the backend URL, also pass the headers we get from the requst + // together with the Forwarded headers we prepared above. + // TODO: support multiplexing on the same backend connection instead of + // opening a new TCP connection time for each request. This should be + // optional: + // http://tools.ietf.org/html/draft-ietf-hybi-websocket-multiplexing-01 + connBackend, resp, err := dialer.Dial(backendURL.String(), nil) + if err != nil { + log.Printf("websocketproxy: couldn't dial to remote backend url %s, %s, %+v", backendURL.String(), err, resp) + return + } + defer connBackend.Close() + + upgrader := w.Upgrader + if w.Upgrader == nil { + upgrader = DefaultUpgrader + } + + // Only pass those headers to the upgrader. + upgradeHeader := http.Header{} + upgradeHeader.Set("Sec-WebSocket-Protocol", + resp.Header.Get(http.CanonicalHeaderKey("Sec-WebSocket-Protocol"))) + upgradeHeader.Set("Set-Cookie", + resp.Header.Get(http.CanonicalHeaderKey("Set-Cookie"))) + + // Now upgrade the existing incoming request to a WebSocket connection. + // Also pass the header that we gathered from the Dial handshake. + connPub, err := upgrader.Upgrade(rw, req, upgradeHeader) + if err != nil { + log.Printf("websocketproxy: couldn't upgrade %s\n", err) + return + } + defer connPub.Close() + + errc := make(chan error, 2) + cp := func(dst io.Writer, src io.Reader) { + _, err := io.Copy(dst, src) + errc <- err + } + + // Start our proxy now, everything is ready... + go cp(connBackend.UnderlyingConn(), connPub.UnderlyingConn()) + go cp(connPub.UnderlyingConn(), connBackend.UnderlyingConn()) + <-errc +} diff --git a/traefik.go b/traefik.go index 539dadf2c..a928b1178 100644 --- a/traefik.go +++ b/traefik.go @@ -278,7 +278,7 @@ func LoadConfig(configurations configs, globalConfiguration *GlobalConfiguration } case wrr: log.Infof("Creating load-balancer wrr") - lb = rr + lb = middlewares.NewWebsocketUpgrader(rr) for serverName, server := range configuration.Backends[frontend.Backend].Servers { url, err := url.Parse(server.URL) if err != nil {