package integration

import (
	"crypto/tls"
	"crypto/x509"
	"io/ioutil"
	"net"
	"net/http"
	"net/http/httptest"
	"os"
	"time"

	"github.com/containous/traefik/integration/try"
	"github.com/go-check/check"
	gorillawebsocket "github.com/gorilla/websocket"
	checker "github.com/vdemeester/shakers"
	"golang.org/x/net/websocket"
)

// WebsocketSuite groups the integration tests that open WebSocket
// connections through a running Traefik process. Every test starts a local
// echo backend, renders a fixture configuration that points at it, launches
// Traefik with that configuration and then talks to 127.0.0.1:8000.
type WebsocketSuite struct{ BaseSuite }

// newWebsocketEchoServer starts an httptest server whose only handler
// upgrades the incoming request to a WebSocket with the given upgrader and
// then writes back every message it reads, preserving the message type.
//
// The handler returns silently when the upgrade fails (Upgrade has already
// written the HTTP error response in that case) or as soon as a read or a
// write on the connection fails. The caller owns the returned server and is
// responsible for closing it.
func newWebsocketEchoServer(upgrader gorillawebsocket.Upgrader) *httptest.Server {
	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			return
		}
		defer conn.Close()

		for {
			mt, message, err := conn.ReadMessage()
			if err != nil {
				return
			}
			if err := conn.WriteMessage(mt, message); err != nil {
				return
			}
		}
	}))
}

// TestBase checks that a text message sent through Traefik over a plain
// ws:// connection is echoed back unchanged by the backend.
func (s *WebsocketSuite) TestBase(c *check.C) {
	srv := newWebsocketEchoServer(gorillawebsocket.Upgrader{}) // use default options
	defer srv.Close()

	file := s.adaptFile(c, "fixtures/websocket/config.toml", struct {
		WebsocketServer string
	}{
		WebsocketServer: srv.URL,
	})
	defer os.Remove(file)

	cmd, display := s.traefikCmd(withConfigFile(file), "--debug")
	defer display(c)

	err := cmd.Start()
	c.Assert(err, checker.IsNil)
	// Best-effort teardown: the process may already have exited.
	defer cmd.Process.Kill()

	// wait for traefik
	err = try.GetRequest("http://127.0.0.1:8080/api/providers", 10*time.Second, try.BodyContains("127.0.0.1"))
	c.Assert(err, checker.IsNil)

	conn, _, err := gorillawebsocket.DefaultDialer.Dial("ws://127.0.0.1:8000/ws", nil)
	c.Assert(err, checker.IsNil)
	defer conn.Close()

	err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK"))
	c.Assert(err, checker.IsNil)

	_, msg, err := conn.ReadMessage()
	c.Assert(err, checker.IsNil)
	c.Assert(string(msg), checker.Equals, "OK")
}

// TestWrongOrigin checks that the handshake fails with "bad status" when the
// client announces an Origin (port 800) that differs from the host it dials
// (port 8000) and the backend uses the upgrader's default options.
func (s *WebsocketSuite) TestWrongOrigin(c *check.C) {
	srv := newWebsocketEchoServer(gorillawebsocket.Upgrader{}) // use default options
	defer srv.Close()

	file := s.adaptFile(c, "fixtures/websocket/config.toml", struct {
		WebsocketServer string
	}{
		WebsocketServer: srv.URL,
	})
	defer os.Remove(file)

	cmd, display := s.traefikCmd(withConfigFile(file), "--debug")
	defer display(c)

	err := cmd.Start()
	c.Assert(err, checker.IsNil)
	// Best-effort teardown: the process may already have exited.
	defer cmd.Process.Kill()

	// wait for traefik
	err = try.GetRequest("http://127.0.0.1:8080/api/providers", 10*time.Second, try.BodyContains("127.0.0.1"))
	c.Assert(err, checker.IsNil)

	config, err := websocket.NewConfig("ws://127.0.0.1:8000/ws", "ws://127.0.0.1:800")
	c.Assert(err, checker.IsNil)

	conn, err := net.DialTimeout("tcp", "127.0.0.1:8000", time.Second)
	c.Assert(err, checker.IsNil)
	// The handshake below is expected to fail, so the raw connection is never
	// handed over to a client and has to be released here.
	defer conn.Close()

	_, err = websocket.NewClient(config, conn)
	c.Assert(err, checker.NotNil)
	c.Assert(err, checker.ErrorMatches, "bad status")
}

// TestOrigin checks that the handshake succeeds and messages are echoed when
// the announced Origin uses the same host and port as the dialed address.
func (s *WebsocketSuite) TestOrigin(c *check.C) {
	srv := newWebsocketEchoServer(gorillawebsocket.Upgrader{}) // use default options
	defer srv.Close()

	file := s.adaptFile(c, "fixtures/websocket/config.toml", struct {
		WebsocketServer string
	}{
		WebsocketServer: srv.URL,
	})
	defer os.Remove(file)

	cmd, display := s.traefikCmd(withConfigFile(file), "--debug")
	defer display(c)

	err := cmd.Start()
	c.Assert(err, checker.IsNil)
	// Best-effort teardown: the process may already have exited.
	defer cmd.Process.Kill()

	// wait for traefik
	err = try.GetRequest("http://127.0.0.1:8080/api/providers", 10*time.Second, try.BodyContains("127.0.0.1"))
	c.Assert(err, checker.IsNil)

	config, err := websocket.NewConfig("ws://127.0.0.1:8000/ws", "ws://127.0.0.1:8000")
	c.Assert(err, checker.IsNil)

	conn, err := net.DialTimeout("tcp", "127.0.0.1:8000", time.Second)
	c.Assert(err, checker.IsNil)
	defer conn.Close()

	client, err := websocket.NewClient(config, conn)
	c.Assert(err, checker.IsNil)

	n, err := client.Write([]byte("OK"))
	c.Assert(err, checker.IsNil)
	c.Assert(n, checker.Equals, 2)

	msg := make([]byte, 2)
	n, err = client.Read(msg)
	c.Assert(err, checker.IsNil)
	c.Assert(n, checker.Equals, 2)
	c.Assert(string(msg), checker.Equals, "OK")
}

// TestWrongOriginIgnoredByServer checks that a mismatching Origin (port 80
// while dialing port 8000) is accepted end to end when the backend's
// CheckOrigin callback approves every request.
func (s *WebsocketSuite) TestWrongOriginIgnoredByServer(c *check.C) {
	srv := newWebsocketEchoServer(gorillawebsocket.Upgrader{
		CheckOrigin: func(r *http.Request) bool {
			return true
		},
	})
	defer srv.Close()

	file := s.adaptFile(c, "fixtures/websocket/config.toml", struct {
		WebsocketServer string
	}{
		WebsocketServer: srv.URL,
	})
	defer os.Remove(file)

	cmd, display := s.traefikCmd(withConfigFile(file), "--debug")
	defer display(c)

	err := cmd.Start()
	c.Assert(err, checker.IsNil)
	// Best-effort teardown: the process may already have exited.
	defer cmd.Process.Kill()

	// wait for traefik
	err = try.GetRequest("http://127.0.0.1:8080/api/providers", 10*time.Second, try.BodyContains("127.0.0.1"))
	c.Assert(err, checker.IsNil)

	config, err := websocket.NewConfig("ws://127.0.0.1:8000/ws", "ws://127.0.0.1:80")
	c.Assert(err, checker.IsNil)

	conn, err := net.DialTimeout("tcp", "127.0.0.1:8000", time.Second)
	c.Assert(err, checker.IsNil)
	defer conn.Close()

	client, err := websocket.NewClient(config, conn)
	c.Assert(err, checker.IsNil)

	n, err := client.Write([]byte("OK"))
	c.Assert(err, checker.IsNil)
	c.Assert(n, checker.Equals, 2)

	msg := make([]byte, 2)
	n, err = client.Read(msg)
	c.Assert(err, checker.IsNil)
	c.Assert(n, checker.Equals, 2)
	c.Assert(string(msg), checker.Equals, "OK")
}

// TestSSLTermination checks that a wss:// client connection to Traefik is
// relayed to the plain-HTTP echo backend, using the HTTPS fixture
// configuration and trusting the certificate in ./resources/tls/local.cert.
func (s *WebsocketSuite) TestSSLTermination(c *check.C) {
	srv := newWebsocketEchoServer(gorillawebsocket.Upgrader{}) // use default options
	defer srv.Close()

	file := s.adaptFile(c, "fixtures/websocket/config_https.toml", struct {
		WebsocketServer string
	}{
		WebsocketServer: srv.URL,
	})
	defer os.Remove(file)

	cmd, display := s.traefikCmd(withConfigFile(file), "--debug")
	defer display(c)

	err := cmd.Start()
	c.Assert(err, checker.IsNil)
	// Best-effort teardown: the process may already have exited.
	defer cmd.Process.Kill()

	// wait for traefik
	err = try.GetRequest("http://127.0.0.1:8080/api/providers", 10*time.Second, try.BodyContains("127.0.0.1"))
	c.Assert(err, checker.IsNil)

	// Trust the self-signed certificate on the client side. Both the read and
	// the PEM parsing are verified so that a missing or malformed file is
	// reported here instead of surfacing later as an opaque TLS failure.
	roots := x509.NewCertPool()
	certContent, err := ioutil.ReadFile("./resources/tls/local.cert")
	c.Assert(err, checker.IsNil)
	if !roots.AppendCertsFromPEM(certContent) {
		c.Fatal("no certificate could be parsed from ./resources/tls/local.cert")
	}

	// DefaultDialer is shared by the whole process: put the previous TLS
	// configuration back once the test is over so other tests are unaffected.
	previousTLSConfig := gorillawebsocket.DefaultDialer.TLSClientConfig
	gorillawebsocket.DefaultDialer.TLSClientConfig = &tls.Config{
		RootCAs: roots,
	}
	defer func() {
		gorillawebsocket.DefaultDialer.TLSClientConfig = previousTLSConfig
	}()

	conn, _, err := gorillawebsocket.DefaultDialer.Dial("wss://127.0.0.1:8000/ws", nil)
	c.Assert(err, checker.IsNil)
	defer conn.Close()

	err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK"))
	c.Assert(err, checker.IsNil)

	_, msg, err := conn.ReadMessage()
	c.Assert(err, checker.IsNil)
	c.Assert(string(msg), checker.Equals, "OK")
}