package middlewares import ( "io" "net" "net/http" "net/url" "strings" log "github.com/Sirupsen/logrus" "github.com/gorilla/websocket" ) // Original developpement made by https://github.com/koding/websocketproxy var ( // DefaultUpgrader specifies the parameters for upgrading an HTTP // connection to a WebSocket connection. DefaultUpgrader = &websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, } // DefaultDialer is a dialer with all fields set to the default zero values. DefaultDialer = websocket.DefaultDialer ) // WebsocketProxy is an HTTP Handler that takes an incoming WebSocket // connection and proxies it to another server. type WebsocketProxy struct { // Backend returns the backend URL which the proxy uses to reverse proxy // the incoming WebSocket connection. Request is the initial incoming and // unmodified request. Backend func(*http.Request) *url.URL // Upgrader specifies the parameters for upgrading a incoming HTTP // connection to a WebSocket connection. If nil, DefaultUpgrader is used. Upgrader *websocket.Upgrader // Dialer contains options for connecting to the backend WebSocket server. // If nil, DefaultDialer is used. Dialer *websocket.Dialer } // ProxyHandler returns a new http.Handler interface that reverse proxies the // request to the given target. func ProxyHandler(target *url.URL) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { NewProxy(target).ServeHTTP(rw, req) }) } // NewProxy returns a new Websocket reverse proxy that rewrites the // URL's to the scheme, host and base path provider in target. func NewProxy(target *url.URL) *WebsocketProxy { backend := func(r *http.Request) *url.URL { // Shallow copy u := *target u.Fragment = r.URL.Fragment u.Path = r.URL.Path u.RawQuery = r.URL.RawQuery rurl := u.String() if strings.HasPrefix(rurl, "http") { u.Scheme = "ws" } if strings.HasPrefix(rurl, "https") { u.Scheme = "wss" } return &u } return &WebsocketProxy{Backend: backend} } // ServeHTTP implements the http.Handler that proxies WebSocket connections. func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if w.Backend == nil { log.Errorf("Websocketproxy: backend function is not defined") http.Error(rw, "Backend not found", http.StatusInternalServerError) http.NotFound(rw, req) return } backendURL := w.Backend(req) if backendURL == nil { log.Errorf("Websocketproxy: backend URL is nil") http.Error(rw, "Backend URL is nil", http.StatusInternalServerError) return } dialer := w.Dialer if w.Dialer == nil { dialer = DefaultDialer } // Pass headers from the incoming request to the dialer to forward them to // the final destinations. requestHeader := http.Header{} requestHeader.Add("Origin", req.Header.Get("Origin")) for _, prot := range req.Header[http.CanonicalHeaderKey("Sec-WebSocket-Protocol")] { requestHeader.Add("Sec-WebSocket-Protocol", prot) } for _, cookie := range req.Header[http.CanonicalHeaderKey("Cookie")] { requestHeader.Add("Cookie", cookie) } // Pass X-Forwarded-For headers too, code below is a part of // httputil.ReverseProxy. See http://en.wikipedia.org/wiki/X-Forwarded-For // for more information // TODO: use RFC7239 http://tools.ietf.org/html/rfc7239 if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { // If we aren't the first proxy retain prior // X-Forwarded-For information as a comma+space // separated list and fold multiple headers into one. if prior, ok := req.Header["X-Forwarded-For"]; ok { clientIP = strings.Join(prior, ", ") + ", " + clientIP } requestHeader.Set("X-Forwarded-For", clientIP) } // Set the originating protocol of the incoming HTTP request. The SSL might // be terminated on our site and because we doing proxy adding this would // be helpful for applications on the backend. requestHeader.Set("X-Forwarded-Proto", "http") if req.TLS != nil { requestHeader.Set("X-Forwarded-Proto", "https") } // Connect to the backend URL, also pass the headers we get from the requst // together with the Forwarded headers we prepared above. // TODO: support multiplexing on the same backend connection instead of // opening a new TCP connection time for each request. This should be // optional: // http://tools.ietf.org/html/draft-ietf-hybi-websocket-multiplexing-01 connBackend, resp, err := dialer.Dial(backendURL.String(), nil) if err != nil { log.Errorf("Websocketproxy: couldn't dial to remote backend url %s, %s, %+v", backendURL.String(), err, resp) http.NotFound(rw, req) return } defer connBackend.Close() upgrader := w.Upgrader if w.Upgrader == nil { upgrader = DefaultUpgrader } // Only pass those headers to the upgrader. upgradeHeader := http.Header{} upgradeHeader.Set("Sec-WebSocket-Protocol", resp.Header.Get(http.CanonicalHeaderKey("Sec-WebSocket-Protocol"))) upgradeHeader.Set("Set-Cookie", resp.Header.Get(http.CanonicalHeaderKey("Set-Cookie"))) // Now upgrade the existing incoming request to a WebSocket connection. // Also pass the header that we gathered from the Dial handshake. connPub, err := upgrader.Upgrade(rw, req, upgradeHeader) if err != nil { log.Errorf("Websocketproxy: couldn't upgrade %s", err) http.NotFound(rw, req) return } defer connPub.Close() errc := make(chan error, 2) cp := func(dst io.Writer, src io.Reader) { _, err := io.Copy(dst, src) errc <- err } // Start our proxy now, everything is ready... go cp(connBackend.UnderlyingConn(), connPub.UnderlyingConn()) go cp(connPub.UnderlyingConn(), connBackend.UnderlyingConn()) <-errc }