package zk import ( "fmt" "io" "io/ioutil" "math/rand" "os" "path/filepath" "strings" "time" ) func init() { rand.Seed(time.Now().UnixNano()) } type TestServer struct { Port int Path string Srv *Server } type TestCluster struct { Path string Servers []TestServer } func StartTestCluster(size int, stdout, stderr io.Writer) (*TestCluster, error) { tmpPath, err := ioutil.TempDir("", "gozk") if err != nil { return nil, err } success := false startPort := int(rand.Int31n(6000) + 10000) cluster := &TestCluster{Path: tmpPath} defer func() { if !success { cluster.Stop() } }() for serverN := 0; serverN < size; serverN++ { srvPath := filepath.Join(tmpPath, fmt.Sprintf("srv%d", serverN)) if err := os.Mkdir(srvPath, 0700); err != nil { return nil, err } port := startPort + serverN*3 cfg := ServerConfig{ ClientPort: port, DataDir: srvPath, } for i := 0; i < size; i++ { cfg.Servers = append(cfg.Servers, ServerConfigServer{ ID: i + 1, Host: "127.0.0.1", PeerPort: startPort + i*3 + 1, LeaderElectionPort: startPort + i*3 + 2, }) } cfgPath := filepath.Join(srvPath, "zoo.cfg") fi, err := os.Create(cfgPath) if err != nil { return nil, err } err = cfg.Marshall(fi) fi.Close() if err != nil { return nil, err } fi, err = os.Create(filepath.Join(srvPath, "myid")) if err != nil { return nil, err } _, err = fmt.Fprintf(fi, "%d\n", serverN+1) fi.Close() if err != nil { return nil, err } srv := &Server{ ConfigPath: cfgPath, Stdout: stdout, Stderr: stderr, } if err := srv.Start(); err != nil { return nil, err } cluster.Servers = append(cluster.Servers, TestServer{ Path: srvPath, Port: cfg.ClientPort, Srv: srv, }) } if err := cluster.waitForStart(10, time.Second); err != nil { return nil, err } success = true return cluster, nil } func (tc *TestCluster) Connect(idx int) (*Conn, error) { zk, _, err := Connect([]string{fmt.Sprintf("127.0.0.1:%d", tc.Servers[idx].Port)}, time.Second*15) return zk, err } func (tc *TestCluster) ConnectAll() (*Conn, <-chan Event, error) { return tc.ConnectAllTimeout(time.Second * 15) } func (tc *TestCluster) ConnectAllTimeout(sessionTimeout time.Duration) (*Conn, <-chan Event, error) { return tc.ConnectWithOptions(sessionTimeout) } func (tc *TestCluster) ConnectWithOptions(sessionTimeout time.Duration, options ...connOption) (*Conn, <-chan Event, error) { hosts := make([]string, len(tc.Servers)) for i, srv := range tc.Servers { hosts[i] = fmt.Sprintf("127.0.0.1:%d", srv.Port) } zk, ch, err := Connect(hosts, sessionTimeout, options...) return zk, ch, err } func (tc *TestCluster) Stop() error { for _, srv := range tc.Servers { srv.Srv.Stop() } defer os.RemoveAll(tc.Path) return tc.waitForStop(5, time.Second) } // waitForStart blocks until the cluster is up func (tc *TestCluster) waitForStart(maxRetry int, interval time.Duration) error { // verify that the servers are up with SRVR serverAddrs := make([]string, len(tc.Servers)) for i, s := range tc.Servers { serverAddrs[i] = fmt.Sprintf("127.0.0.1:%d", s.Port) } for i := 0; i < maxRetry; i++ { _, ok := FLWSrvr(serverAddrs, time.Second) if ok { return nil } time.Sleep(interval) } return fmt.Errorf("unable to verify health of servers") } // waitForStop blocks until the cluster is down func (tc *TestCluster) waitForStop(maxRetry int, interval time.Duration) error { // verify that the servers are up with RUOK serverAddrs := make([]string, len(tc.Servers)) for i, s := range tc.Servers { serverAddrs[i] = fmt.Sprintf("127.0.0.1:%d", s.Port) } var success bool for i := 0; i < maxRetry && !success; i++ { success = true for _, ok := range FLWRuok(serverAddrs, time.Second) { if ok { success = false } } if !success { time.Sleep(interval) } } if !success { return fmt.Errorf("unable to verify servers are down") } return nil } func (tc *TestCluster) StartServer(server string) { for _, s := range tc.Servers { if strings.HasSuffix(server, fmt.Sprintf(":%d", s.Port)) { s.Srv.Start() return } } panic(fmt.Sprintf("Unknown server: %s", server)) } func (tc *TestCluster) StopServer(server string) { for _, s := range tc.Servers { if strings.HasSuffix(server, fmt.Sprintf(":%d", s.Port)) { s.Srv.Stop() return } } panic(fmt.Sprintf("Unknown server: %s", server)) } func (tc *TestCluster) StartAllServers() error { for _, s := range tc.Servers { if err := s.Srv.Start(); err != nil { return fmt.Errorf( "Failed to start server listening on port `%d` : %+v", s.Port, err) } } return nil } func (tc *TestCluster) StopAllServers() error { for _, s := range tc.Servers { if err := s.Srv.Stop(); err != nil { return fmt.Errorf( "Failed to stop server listening on port `%d` : %+v", s.Port, err) } } return nil }