191 lines
4.2 KiB
Go
191 lines
4.2 KiB
Go
|
package zk
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"io/ioutil"
|
||
|
"math/rand"
|
||
|
"os"
|
||
|
"path/filepath"
|
||
|
"strings"
|
||
|
"time"
|
||
|
)
|
||
|
|
||
|
func init() {
|
||
|
rand.Seed(time.Now().UnixNano())
|
||
|
}
|
||
|
|
||
|
type TestServer struct {
|
||
|
Port int
|
||
|
Path string
|
||
|
Srv *Server
|
||
|
}
|
||
|
|
||
|
type TestCluster struct {
|
||
|
Path string
|
||
|
Servers []TestServer
|
||
|
}
|
||
|
|
||
|
func StartTestCluster(size int, stdout, stderr io.Writer) (*TestCluster, error) {
|
||
|
tmpPath, err := ioutil.TempDir("", "gozk")
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
success := false
|
||
|
startPort := int(rand.Int31n(6000) + 10000)
|
||
|
cluster := &TestCluster{Path: tmpPath}
|
||
|
defer func() {
|
||
|
if !success {
|
||
|
cluster.Stop()
|
||
|
}
|
||
|
}()
|
||
|
for serverN := 0; serverN < size; serverN++ {
|
||
|
srvPath := filepath.Join(tmpPath, fmt.Sprintf("srv%d", serverN))
|
||
|
if err := os.Mkdir(srvPath, 0700); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
port := startPort + serverN*3
|
||
|
cfg := ServerConfig{
|
||
|
ClientPort: port,
|
||
|
DataDir: srvPath,
|
||
|
}
|
||
|
for i := 0; i < size; i++ {
|
||
|
cfg.Servers = append(cfg.Servers, ServerConfigServer{
|
||
|
ID: i + 1,
|
||
|
Host: "127.0.0.1",
|
||
|
PeerPort: startPort + i*3 + 1,
|
||
|
LeaderElectionPort: startPort + i*3 + 2,
|
||
|
})
|
||
|
}
|
||
|
cfgPath := filepath.Join(srvPath, "zoo.cfg")
|
||
|
fi, err := os.Create(cfgPath)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
err = cfg.Marshall(fi)
|
||
|
fi.Close()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
fi, err = os.Create(filepath.Join(srvPath, "myid"))
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
_, err = fmt.Fprintf(fi, "%d\n", serverN+1)
|
||
|
fi.Close()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
srv := &Server{
|
||
|
ConfigPath: cfgPath,
|
||
|
Stdout: stdout,
|
||
|
Stderr: stderr,
|
||
|
}
|
||
|
if err := srv.Start(); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
cluster.Servers = append(cluster.Servers, TestServer{
|
||
|
Path: srvPath,
|
||
|
Port: cfg.ClientPort,
|
||
|
Srv: srv,
|
||
|
})
|
||
|
}
|
||
|
if err := cluster.waitForStart(10, time.Second); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
success = true
|
||
|
return cluster, nil
|
||
|
}
|
||
|
|
||
|
func (ts *TestCluster) Connect(idx int) (*Conn, error) {
|
||
|
zk, _, err := Connect([]string{fmt.Sprintf("127.0.0.1:%d", ts.Servers[idx].Port)}, time.Second*15)
|
||
|
return zk, err
|
||
|
}
|
||
|
|
||
|
func (ts *TestCluster) ConnectAll() (*Conn, <-chan Event, error) {
|
||
|
return ts.ConnectAllTimeout(time.Second * 15)
|
||
|
}
|
||
|
|
||
|
func (ts *TestCluster) ConnectAllTimeout(sessionTimeout time.Duration) (*Conn, <-chan Event, error) {
|
||
|
hosts := make([]string, len(ts.Servers))
|
||
|
for i, srv := range ts.Servers {
|
||
|
hosts[i] = fmt.Sprintf("127.0.0.1:%d", srv.Port)
|
||
|
}
|
||
|
zk, ch, err := Connect(hosts, sessionTimeout)
|
||
|
return zk, ch, err
|
||
|
}
|
||
|
|
||
|
func (ts *TestCluster) Stop() error {
|
||
|
for _, srv := range ts.Servers {
|
||
|
srv.Srv.Stop()
|
||
|
}
|
||
|
defer os.RemoveAll(ts.Path)
|
||
|
return ts.waitForStop(5, time.Second)
|
||
|
}
|
||
|
|
||
|
// waitForStart blocks until the cluster is up
|
||
|
func (ts *TestCluster) waitForStart(maxRetry int, interval time.Duration) error {
|
||
|
// verify that the servers are up with SRVR
|
||
|
serverAddrs := make([]string, len(ts.Servers))
|
||
|
for i, s := range ts.Servers {
|
||
|
serverAddrs[i] = fmt.Sprintf("127.0.0.1:%d", s.Port)
|
||
|
}
|
||
|
|
||
|
for i := 0; i < maxRetry; i++ {
|
||
|
_, ok := FLWSrvr(serverAddrs, time.Second)
|
||
|
if ok {
|
||
|
return nil
|
||
|
}
|
||
|
time.Sleep(interval)
|
||
|
}
|
||
|
return fmt.Errorf("unable to verify health of servers")
|
||
|
}
|
||
|
|
||
|
// waitForStop blocks until the cluster is down
|
||
|
func (ts *TestCluster) waitForStop(maxRetry int, interval time.Duration) error {
|
||
|
// verify that the servers are up with RUOK
|
||
|
serverAddrs := make([]string, len(ts.Servers))
|
||
|
for i, s := range ts.Servers {
|
||
|
serverAddrs[i] = fmt.Sprintf("127.0.0.1:%d", s.Port)
|
||
|
}
|
||
|
|
||
|
var success bool
|
||
|
for i := 0; i < maxRetry && !success; i++ {
|
||
|
success = true
|
||
|
for _, ok := range FLWRuok(serverAddrs, time.Second) {
|
||
|
if ok {
|
||
|
success = false
|
||
|
}
|
||
|
}
|
||
|
if !success {
|
||
|
time.Sleep(interval)
|
||
|
}
|
||
|
}
|
||
|
if !success {
|
||
|
return fmt.Errorf("unable to verify servers are down")
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (tc *TestCluster) StartServer(server string) {
|
||
|
for _, s := range tc.Servers {
|
||
|
if strings.HasSuffix(server, fmt.Sprintf(":%d", s.Port)) {
|
||
|
s.Srv.Start()
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
panic(fmt.Sprintf("Unknown server: %s", server))
|
||
|
}
|
||
|
|
||
|
func (tc *TestCluster) StopServer(server string) {
|
||
|
for _, s := range tc.Servers {
|
||
|
if strings.HasSuffix(server, fmt.Sprintf(":%d", s.Port)) {
|
||
|
s.Srv.Stop()
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
panic(fmt.Sprintf("Unknown server: %s", server))
|
||
|
}
|