// Package zk is a native Go client library for the ZooKeeper orchestration service. package zk /* TODO: * make sure a ping response comes back in a reasonable time Possible watcher events: * Event{Type: EventNotWatching, State: StateDisconnected, Path: path, Err: err} */ import ( "crypto/rand" "encoding/binary" "errors" "fmt" "io" "net" "strconv" "strings" "sync" "sync/atomic" "time" ) // ErrNoServer indicates that an operation cannot be completed // because attempts to connect to all servers in the list failed. var ErrNoServer = errors.New("zk: could not connect to a server") // ErrInvalidPath indicates that an operation was being attempted on // an invalid path. (e.g. empty path) var ErrInvalidPath = errors.New("zk: invalid path") // DefaultLogger uses the stdlib log package for logging. var DefaultLogger Logger = defaultLogger{} const ( bufferSize = 1536 * 1024 eventChanSize = 6 sendChanSize = 16 protectedPrefix = "_c_" ) type watchType int const ( watchTypeData = iota watchTypeExist = iota watchTypeChild = iota ) type watchPathType struct { path string wType watchType } type Dialer func(network, address string, timeout time.Duration) (net.Conn, error) // Logger is an interface that can be implemented to provide custom log output. 
type Logger interface {
	Printf(string, ...interface{})
}

// Conn is a client connection to a ZooKeeper ensemble, created by Connect.
//
// NOTE(review): lastZxid and sessionID are kept as the first fields on
// purpose. sessionID (and lastZxid) are accessed with 64-bit sync/atomic
// operations, which on 32-bit platforms require 8-byte alignment. Only the
// first word of an allocated struct is guaranteed to have that alignment,
// so do not reorder these fields.
type Conn struct {
	lastZxid         int64
	sessionID        int64
	state            State // must be 32-bit aligned
	xid              uint32 // incremented atomically by nextXid
	sessionTimeoutMs int32  // session timeout in milliseconds
	passwd           []byte

	dialer         Dialer
	hostProvider   HostProvider
	serverMu       sync.Mutex // protects server
	server         string     // remember the address/port of the current server
	conn           net.Conn
	eventChan      chan Event    // session/watch events; sends never block (events are dropped when full)
	shouldQuit     chan struct{} // closed by Close to stop the connection loop
	pingInterval   time.Duration
	recvTimeout    time.Duration
	connectTimeout time.Duration

	sendChan     chan *request              // requests waiting to be written by sendLoop
	requests     map[int32]*request         // Xid -> pending request
	requestsLock sync.Mutex                 // guards requests
	watchers     map[watchPathType][]chan Event
	watchersLock sync.Mutex // guards watchers

	// Debug (used by unit tests)
	reconnectDelay time.Duration

	logger Logger
}

// connOption represents a connection option. Connect applies the options
// in order to the new Conn before the connection loop is started.
type connOption func(c *Conn)

// request is a single queued ZooKeeper operation together with the channel
// on which its response is delivered.
type request struct {
	xid        int32
	opcode     int32
	pkt        interface{} // request body encoded after the header
	recvStruct interface{} // destination the response body is decoded into
	recvChan   chan response

	// Because sending and receiving happen in separate go routines, there's
	// a possible race condition when creating watches from outside the read
	// loop. We must ensure that a watcher gets added to the list synchronously
	// with the response from the server on any request that creates a watch.
	// In order to not hard code the watch logic for each opcode in the recv
	// loop the caller can use recvFunc to insert some synchronous code
	// after a response.
	recvFunc func(*request, *responseHeader, error)
}

// response carries the zxid and error for a completed request.
type response struct {
	zxid int64
	err  error
}

// Event is a session state change or a watch notification. Session events
// are sent on the channel returned by Connect; watch notifications are sent
// there as well as on the channel of each matching watch.
type Event struct {
	Type   EventType
	State  State
	Path   string // For non-session events, the path of the watched node.
	Err    error
	Server string // For connection events
}

// HostProvider is used to represent a set of hosts a ZooKeeper client should connect to.
// It is an analog of the Java equivalent: // http://svn.apache.org/viewvc/zookeeper/trunk/src/java/main/org/apache/zookeeper/client/HostProvider.java?view=markup type HostProvider interface { // Init is called first, with the servers specified in the connection string. Init(servers []string) error // Len returns the number of servers. Len() int // Next returns the next server to connect to. retryStart will be true if we've looped through // all known servers without Connected() being called. Next() (server string, retryStart bool) // Notify the HostProvider of a successful connection. Connected() } // ConnectWithDialer establishes a new connection to a pool of zookeeper servers // using a custom Dialer. See Connect for further information about session timeout. // This method is deprecated and provided for compatibility: use the WithDialer option instead. func ConnectWithDialer(servers []string, sessionTimeout time.Duration, dialer Dialer) (*Conn, <-chan Event, error) { return Connect(servers, sessionTimeout, WithDialer(dialer)) } // Connect establishes a new connection to a pool of zookeeper // servers. The provided session timeout sets the amount of time for which // a session is considered valid after losing connection to a server. Within // the session timeout it's possible to reestablish a connection to a different // server and keep the same session. This is means any ephemeral nodes and // watches are maintained. 
func Connect(servers []string, sessionTimeout time.Duration, options ...connOption) (*Conn, <-chan Event, error) { if len(servers) == 0 { return nil, nil, errors.New("zk: server list must not be empty") } srvs := make([]string, len(servers)) for i, addr := range servers { if strings.Contains(addr, ":") { srvs[i] = addr } else { srvs[i] = addr + ":" + strconv.Itoa(DefaultPort) } } // Randomize the order of the servers to avoid creating hotspots stringShuffle(srvs) ec := make(chan Event, eventChanSize) conn := &Conn{ dialer: net.DialTimeout, hostProvider: &DNSHostProvider{}, conn: nil, state: StateDisconnected, eventChan: ec, shouldQuit: make(chan struct{}), connectTimeout: 1 * time.Second, sendChan: make(chan *request, sendChanSize), requests: make(map[int32]*request), watchers: make(map[watchPathType][]chan Event), passwd: emptyPassword, logger: DefaultLogger, // Debug reconnectDelay: 0, } // Set provided options. for _, option := range options { option(conn) } if err := conn.hostProvider.Init(srvs); err != nil { return nil, nil, err } conn.setTimeouts(int32(sessionTimeout / time.Millisecond)) go func() { conn.loop() conn.flushRequests(ErrClosing) conn.invalidateWatches(ErrClosing) close(conn.eventChan) }() return conn, ec, nil } // WithDialer returns a connection option specifying a non-default Dialer. func WithDialer(dialer Dialer) connOption { return func(c *Conn) { c.dialer = dialer } } // WithHostProvider returns a connection option specifying a non-default HostProvider. func WithHostProvider(hostProvider HostProvider) connOption { return func(c *Conn) { c.hostProvider = hostProvider } } func (c *Conn) Close() { close(c.shouldQuit) select { case <-c.queueRequest(opClose, &closeRequest{}, &closeResponse{}, nil): case <-time.After(time.Second): } } // State returns the current state of the connection. func (c *Conn) State() State { return State(atomic.LoadInt32((*int32)(&c.state))) } // SessionId returns the current session id of the connection. 
func (c *Conn) SessionID() int64 { return atomic.LoadInt64(&c.sessionID) } // SetLogger sets the logger to be used for printing errors. // Logger is an interface provided by this package. func (c *Conn) SetLogger(l Logger) { c.logger = l } func (c *Conn) setTimeouts(sessionTimeoutMs int32) { c.sessionTimeoutMs = sessionTimeoutMs sessionTimeout := time.Duration(sessionTimeoutMs) * time.Millisecond c.recvTimeout = sessionTimeout * 2 / 3 c.pingInterval = c.recvTimeout / 2 } func (c *Conn) setState(state State) { atomic.StoreInt32((*int32)(&c.state), int32(state)) select { case c.eventChan <- Event{Type: EventSession, State: state, Server: c.Server()}: default: // panic("zk: event channel full - it must be monitored and never allowed to be full") } } func (c *Conn) connect() error { var retryStart bool for { c.serverMu.Lock() c.server, retryStart = c.hostProvider.Next() c.serverMu.Unlock() c.setState(StateConnecting) if retryStart { c.flushUnsentRequests(ErrNoServer) select { case <-time.After(time.Second): // pass case <-c.shouldQuit: c.setState(StateDisconnected) c.flushUnsentRequests(ErrClosing) return ErrClosing } } zkConn, err := c.dialer("tcp", c.Server(), c.connectTimeout) if err == nil { c.conn = zkConn c.setState(StateConnected) c.logger.Printf("Connected to %s", c.Server()) return nil } c.logger.Printf("Failed to connect to %s: %+v", c.Server(), err) } } func (c *Conn) loop() { for { if err := c.connect(); err != nil { // c.Close() was called return } err := c.authenticate() switch { case err == ErrSessionExpired: c.logger.Printf("Authentication failed: %s", err) c.invalidateWatches(err) case err != nil && c.conn != nil: c.logger.Printf("Authentication failed: %s", err) c.conn.Close() case err == nil: c.logger.Printf("Authenticated: id=%d, timeout=%d", c.SessionID(), c.sessionTimeoutMs) c.hostProvider.Connected() // mark success closeChan := make(chan struct{}) // channel to tell send loop stop var wg sync.WaitGroup wg.Add(1) go func() { err := 
c.sendLoop(c.conn, closeChan) c.logger.Printf("Send loop terminated: err=%v", err) c.conn.Close() // causes recv loop to EOF/exit wg.Done() }() wg.Add(1) go func() { err := c.recvLoop(c.conn) c.logger.Printf("Recv loop terminated: err=%v", err) if err == nil { panic("zk: recvLoop should never return nil error") } close(closeChan) // tell send loop to exit wg.Done() }() c.sendSetWatches() wg.Wait() } c.setState(StateDisconnected) select { case <-c.shouldQuit: c.flushRequests(ErrClosing) return default: } if err != ErrSessionExpired { err = ErrConnectionClosed } c.flushRequests(err) if c.reconnectDelay > 0 { select { case <-c.shouldQuit: return case <-time.After(c.reconnectDelay): } } } } func (c *Conn) flushUnsentRequests(err error) { for { select { default: return case req := <-c.sendChan: req.recvChan <- response{-1, err} } } } // Send error to all pending requests and clear request map func (c *Conn) flushRequests(err error) { c.requestsLock.Lock() for _, req := range c.requests { req.recvChan <- response{-1, err} } c.requests = make(map[int32]*request) c.requestsLock.Unlock() } // Send error to all watchers and clear watchers map func (c *Conn) invalidateWatches(err error) { c.watchersLock.Lock() defer c.watchersLock.Unlock() if len(c.watchers) >= 0 { for pathType, watchers := range c.watchers { ev := Event{Type: EventNotWatching, State: StateDisconnected, Path: pathType.path, Err: err} for _, ch := range watchers { ch <- ev close(ch) } } c.watchers = make(map[watchPathType][]chan Event) } } func (c *Conn) sendSetWatches() { c.watchersLock.Lock() defer c.watchersLock.Unlock() if len(c.watchers) == 0 { return } req := &setWatchesRequest{ RelativeZxid: c.lastZxid, DataWatches: make([]string, 0), ExistWatches: make([]string, 0), ChildWatches: make([]string, 0), } n := 0 for pathType, watchers := range c.watchers { if len(watchers) == 0 { continue } switch pathType.wType { case watchTypeData: req.DataWatches = append(req.DataWatches, pathType.path) case 
watchTypeExist: req.ExistWatches = append(req.ExistWatches, pathType.path) case watchTypeChild: req.ChildWatches = append(req.ChildWatches, pathType.path) } n++ } if n == 0 { return } go func() { res := &setWatchesResponse{} _, err := c.request(opSetWatches, req, res, nil) if err != nil { c.logger.Printf("Failed to set previous watches: %s", err.Error()) } }() } func (c *Conn) authenticate() error { buf := make([]byte, 256) // Encode and send a connect request. n, err := encodePacket(buf[4:], &connectRequest{ ProtocolVersion: protocolVersion, LastZxidSeen: c.lastZxid, TimeOut: c.sessionTimeoutMs, SessionID: c.SessionID(), Passwd: c.passwd, }) if err != nil { return err } binary.BigEndian.PutUint32(buf[:4], uint32(n)) c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout * 10)) _, err = c.conn.Write(buf[:n+4]) c.conn.SetWriteDeadline(time.Time{}) if err != nil { return err } // Receive and decode a connect response. c.conn.SetReadDeadline(time.Now().Add(c.recvTimeout * 10)) _, err = io.ReadFull(c.conn, buf[:4]) c.conn.SetReadDeadline(time.Time{}) if err != nil { return err } blen := int(binary.BigEndian.Uint32(buf[:4])) if cap(buf) < blen { buf = make([]byte, blen) } _, err = io.ReadFull(c.conn, buf[:blen]) if err != nil { return err } r := connectResponse{} _, err = decodePacket(buf[:blen], &r) if err != nil { return err } if r.SessionID == 0 { atomic.StoreInt64(&c.sessionID, int64(0)) c.passwd = emptyPassword c.lastZxid = 0 c.setState(StateExpired) return ErrSessionExpired } atomic.StoreInt64(&c.sessionID, r.SessionID) c.setTimeouts(r.TimeOut) c.passwd = r.Passwd c.setState(StateHasSession) return nil } func (c *Conn) sendLoop(conn net.Conn, closeChan <-chan struct{}) error { pingTicker := time.NewTicker(c.pingInterval) defer pingTicker.Stop() buf := make([]byte, bufferSize) for { select { case req := <-c.sendChan: header := &requestHeader{req.xid, req.opcode} n, err := encodePacket(buf[4:], header) if err != nil { req.recvChan <- response{-1, err} continue } n2, 
err := encodePacket(buf[4+n:], req.pkt) if err != nil { req.recvChan <- response{-1, err} continue } n += n2 binary.BigEndian.PutUint32(buf[:4], uint32(n)) c.requestsLock.Lock() select { case <-closeChan: req.recvChan <- response{-1, ErrConnectionClosed} c.requestsLock.Unlock() return ErrConnectionClosed default: } c.requests[req.xid] = req c.requestsLock.Unlock() conn.SetWriteDeadline(time.Now().Add(c.recvTimeout)) _, err = conn.Write(buf[:n+4]) conn.SetWriteDeadline(time.Time{}) if err != nil { req.recvChan <- response{-1, err} conn.Close() return err } case <-pingTicker.C: n, err := encodePacket(buf[4:], &requestHeader{Xid: -2, Opcode: opPing}) if err != nil { panic("zk: opPing should never fail to serialize") } binary.BigEndian.PutUint32(buf[:4], uint32(n)) conn.SetWriteDeadline(time.Now().Add(c.recvTimeout)) _, err = conn.Write(buf[:n+4]) conn.SetWriteDeadline(time.Time{}) if err != nil { conn.Close() return err } case <-closeChan: return nil } } } func (c *Conn) recvLoop(conn net.Conn) error { buf := make([]byte, bufferSize) for { // package length conn.SetReadDeadline(time.Now().Add(c.recvTimeout)) _, err := io.ReadFull(conn, buf[:4]) if err != nil { return err } blen := int(binary.BigEndian.Uint32(buf[:4])) if cap(buf) < blen { buf = make([]byte, blen) } _, err = io.ReadFull(conn, buf[:blen]) conn.SetReadDeadline(time.Time{}) if err != nil { return err } res := responseHeader{} _, err = decodePacket(buf[:16], &res) if err != nil { return err } if res.Xid == -1 { res := &watcherEvent{} _, err := decodePacket(buf[16:blen], res) if err != nil { return err } ev := Event{ Type: res.Type, State: res.State, Path: res.Path, Err: nil, } select { case c.eventChan <- ev: default: } wTypes := make([]watchType, 0, 2) switch res.Type { case EventNodeCreated: wTypes = append(wTypes, watchTypeExist) case EventNodeDeleted, EventNodeDataChanged: wTypes = append(wTypes, watchTypeExist, watchTypeData, watchTypeChild) case EventNodeChildrenChanged: wTypes = append(wTypes, 
watchTypeChild) } c.watchersLock.Lock() for _, t := range wTypes { wpt := watchPathType{res.Path, t} if watchers := c.watchers[wpt]; watchers != nil && len(watchers) > 0 { for _, ch := range watchers { ch <- ev close(ch) } delete(c.watchers, wpt) } } c.watchersLock.Unlock() } else if res.Xid == -2 { // Ping response. Ignore. } else if res.Xid < 0 { c.logger.Printf("Xid < 0 (%d) but not ping or watcher event", res.Xid) } else { if res.Zxid > 0 { c.lastZxid = res.Zxid } c.requestsLock.Lock() req, ok := c.requests[res.Xid] if ok { delete(c.requests, res.Xid) } c.requestsLock.Unlock() if !ok { c.logger.Printf("Response for unknown request with xid %d", res.Xid) } else { if res.Err != 0 { err = res.Err.toError() } else { _, err = decodePacket(buf[16:blen], req.recvStruct) } if req.recvFunc != nil { req.recvFunc(req, &res, err) } req.recvChan <- response{res.Zxid, err} if req.opcode == opClose { return io.EOF } } } } } func (c *Conn) nextXid() int32 { return int32(atomic.AddUint32(&c.xid, 1) & 0x7fffffff) } func (c *Conn) addWatcher(path string, watchType watchType) <-chan Event { c.watchersLock.Lock() defer c.watchersLock.Unlock() ch := make(chan Event, 1) wpt := watchPathType{path, watchType} c.watchers[wpt] = append(c.watchers[wpt], ch) return ch } func (c *Conn) queueRequest(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) <-chan response { rq := &request{ xid: c.nextXid(), opcode: opcode, pkt: req, recvStruct: res, recvChan: make(chan response, 1), recvFunc: recvFunc, } c.sendChan <- rq return rq.recvChan } func (c *Conn) request(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) (int64, error) { r := <-c.queueRequest(opcode, req, res, recvFunc) return r.zxid, r.err } func (c *Conn) AddAuth(scheme string, auth []byte) error { _, err := c.request(opSetAuth, &setAuthRequest{Type: 0, Scheme: scheme, Auth: auth}, &setAuthResponse{}, nil) return err } func (c *Conn) 
Children(path string) ([]string, *Stat, error) { res := &getChildren2Response{} _, err := c.request(opGetChildren2, &getChildren2Request{Path: path, Watch: false}, res, nil) return res.Children, &res.Stat, err } func (c *Conn) ChildrenW(path string) ([]string, *Stat, <-chan Event, error) { var ech <-chan Event res := &getChildren2Response{} _, err := c.request(opGetChildren2, &getChildren2Request{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) { if err == nil { ech = c.addWatcher(path, watchTypeChild) } }) if err != nil { return nil, nil, nil, err } return res.Children, &res.Stat, ech, err } func (c *Conn) Get(path string) ([]byte, *Stat, error) { res := &getDataResponse{} _, err := c.request(opGetData, &getDataRequest{Path: path, Watch: false}, res, nil) return res.Data, &res.Stat, err } // GetW returns the contents of a znode and sets a watch func (c *Conn) GetW(path string) ([]byte, *Stat, <-chan Event, error) { var ech <-chan Event res := &getDataResponse{} _, err := c.request(opGetData, &getDataRequest{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) { if err == nil { ech = c.addWatcher(path, watchTypeData) } }) if err != nil { return nil, nil, nil, err } return res.Data, &res.Stat, ech, err } func (c *Conn) Set(path string, data []byte, version int32) (*Stat, error) { if path == "" { return nil, ErrInvalidPath } res := &setDataResponse{} _, err := c.request(opSetData, &SetDataRequest{path, data, version}, res, nil) return &res.Stat, err } func (c *Conn) Create(path string, data []byte, flags int32, acl []ACL) (string, error) { res := &createResponse{} _, err := c.request(opCreate, &CreateRequest{path, data, acl, flags}, res, nil) return res.Path, err } // CreateProtectedEphemeralSequential fixes a race condition if the server crashes // after it creates the node. On reconnect the session may still be valid so the // ephemeral node still exists. 
Therefore, on reconnect we need to check if a node // with a GUID generated on create exists. func (c *Conn) CreateProtectedEphemeralSequential(path string, data []byte, acl []ACL) (string, error) { var guid [16]byte _, err := io.ReadFull(rand.Reader, guid[:16]) if err != nil { return "", err } guidStr := fmt.Sprintf("%x", guid) parts := strings.Split(path, "/") parts[len(parts)-1] = fmt.Sprintf("%s%s-%s", protectedPrefix, guidStr, parts[len(parts)-1]) rootPath := strings.Join(parts[:len(parts)-1], "/") protectedPath := strings.Join(parts, "/") var newPath string for i := 0; i < 3; i++ { newPath, err = c.Create(protectedPath, data, FlagEphemeral|FlagSequence, acl) switch err { case ErrSessionExpired: // No need to search for the node since it can't exist. Just try again. case ErrConnectionClosed: children, _, err := c.Children(rootPath) if err != nil { return "", err } for _, p := range children { parts := strings.Split(p, "/") if pth := parts[len(parts)-1]; strings.HasPrefix(pth, protectedPrefix) { if g := pth[len(protectedPrefix) : len(protectedPrefix)+32]; g == guidStr { return rootPath + "/" + p, nil } } } case nil: return newPath, nil default: return "", err } } return "", err } func (c *Conn) Delete(path string, version int32) error { _, err := c.request(opDelete, &DeleteRequest{path, version}, &deleteResponse{}, nil) return err } func (c *Conn) Exists(path string) (bool, *Stat, error) { res := &existsResponse{} _, err := c.request(opExists, &existsRequest{Path: path, Watch: false}, res, nil) exists := true if err == ErrNoNode { exists = false err = nil } return exists, &res.Stat, err } func (c *Conn) ExistsW(path string) (bool, *Stat, <-chan Event, error) { var ech <-chan Event res := &existsResponse{} _, err := c.request(opExists, &existsRequest{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) { if err == nil { ech = c.addWatcher(path, watchTypeData) } else if err == ErrNoNode { ech = c.addWatcher(path, watchTypeExist) } }) 
exists := true if err == ErrNoNode { exists = false err = nil } if err != nil { return false, nil, nil, err } return exists, &res.Stat, ech, err } func (c *Conn) GetACL(path string) ([]ACL, *Stat, error) { res := &getAclResponse{} _, err := c.request(opGetAcl, &getAclRequest{Path: path}, res, nil) return res.Acl, &res.Stat, err } func (c *Conn) SetACL(path string, acl []ACL, version int32) (*Stat, error) { res := &setAclResponse{} _, err := c.request(opSetAcl, &setAclRequest{Path: path, Acl: acl, Version: version}, res, nil) return &res.Stat, err } func (c *Conn) Sync(path string) (string, error) { res := &syncResponse{} _, err := c.request(opSync, &syncRequest{Path: path}, res, nil) return res.Path, err } type MultiResponse struct { Stat *Stat String string } // Multi executes multiple ZooKeeper operations or none of them. The provided // ops must be one of *CreateRequest, *DeleteRequest, *SetDataRequest, or // *CheckVersionRequest. func (c *Conn) Multi(ops ...interface{}) ([]MultiResponse, error) { req := &multiRequest{ Ops: make([]multiRequestOp, 0, len(ops)), DoneHeader: multiHeader{Type: -1, Done: true, Err: -1}, } for _, op := range ops { var opCode int32 switch op.(type) { case *CreateRequest: opCode = opCreate case *SetDataRequest: opCode = opSetData case *DeleteRequest: opCode = opDelete case *CheckVersionRequest: opCode = opCheck default: return nil, fmt.Errorf("unknown operation type %T", op) } req.Ops = append(req.Ops, multiRequestOp{multiHeader{opCode, false, -1}, op}) } res := &multiResponse{} _, err := c.request(opMulti, req, res, nil) mr := make([]MultiResponse, len(res.Ops)) for i, op := range res.Ops { mr[i] = MultiResponse{Stat: op.Stat, String: op.String} } return mr, err } // Server returns the current or last-connected server name. func (c *Conn) Server() string { c.serverMu.Lock() defer c.serverMu.Unlock() return c.server }