From 08b345acd2b36807003f9d81d180aadff4bceca8 Mon Sep 17 00:00:00 2001 From: Oleg Jukovec Date: Mon, 21 Nov 2022 19:00:47 +0300 Subject: [PATCH] api: add events subscription support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit WIP: ConnectionPool/ConnectionMulti A user can create watcher by the Connection.NewWatcher() call: watcher = conn.NewWatcker("key", func(event WatchEvent) { // The callback code. }) After that, the watcher callback is invoked for the first time. In this case, the callback is triggered whether or not the key has already been broadcast. All subsequent invocations are triggered with box.broadcast() called on the remote host. If a watcher is subscribed for a key that has not been broadcast yet, the callback is triggered only once, after the registration of the watcher. If the key is updated while the watcher callback is running, the callback will be invoked again with the latest value as soon as it returns. Multiple watchers can be created for one key. If you don’t need the watcher anymore, you can unregister it using the Unregister method: watcher.Unregister() The api is similar to net.box implementation [1]. It also adds a BroadcastRequest to make it easier to send broadcast messages. 1. https://www.tarantool.io/ru/doc/latest/reference/reference_lua/net_box/#conn-watch Closes #119 --- CHANGELOG.md | 2 + connection.go | 298 +++++++++++++++++++++- connection_pool/connection_pool.go | 244 ++++++++++++++++-- connection_pool/connection_pool_test.go | 153 ++++++++++++ connection_pool/round_robin.go | 25 +- connection_pool/round_robin_test.go | 19 ++ const.go | 5 + multi/multi.go | 46 +++- request.go | 8 + request_test.go | 65 +++++ tarantool_test.go | 319 ++++++++++++++++++++++++ test_helpers/request_mock.go | 4 + test_helpers/utils.go | 15 +- watch.go | 138 ++++++++++ 14 files changed, 1303 insertions(+), 38 deletions(-) create mode 100644 watch.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 4fb75664b..782dda025 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ Versioning](http://semver.org/spec/v2.0.0.html) except to the first release. ### Added +- Event subscription support (#119) + ### Changed ### Fixed diff --git a/connection.go b/connection.go index a4ae8cc36..62cea354b 100644 --- a/connection.go +++ b/connection.go @@ -53,6 +53,8 @@ const ( // LogUnexpectedResultId is logged when response with unknown id was received. // Most probably it is due to request timeout. LogUnexpectedResultId + // LogReadWatchEventFailed is logged when failed to read a watch event. + LogReadWatchEventFailed ) // ConnEvent is sent throw Notify channel specified in Opts. @@ -62,6 +64,12 @@ type ConnEvent struct { When time.Time } +// A raw watch event. +type connWatchEvent struct { + key string + value interface{} +} + var epoch = time.Now() // Logger is logger type expected to be passed in options. @@ -83,6 +91,9 @@ func (d defaultLogger) Report(event ConnLogKind, conn *Connection, v ...interfac case LogUnexpectedResultId: resp := v[0].(*Response) log.Printf("tarantool: connection %s got unexpected resultId (%d) in response", conn.addr, resp.RequestId) + case LogReadWatchEventFailed: + err := v[0].(error) + log.Printf("tarantool: unable to parse watch event: %s\n", err) default: args := append([]interface{}{"tarantool: unexpected event ", event, conn}, v...) log.Print(args...) @@ -146,6 +157,9 @@ type Connection struct { lenbuf [PacketLengthBytes]byte lastStreamId uint64 + + // watchMap is a map of key -> watcherSharedData. + watchMap sync.Map } var _ = Connector(&Connection{}) // Check compatibility with connector interface. @@ -502,7 +516,7 @@ func (conn *Connection) dial() (err error) { conn.Greeting.Version = bytes.NewBuffer(greeting[:64]).String() conn.Greeting.auth = bytes.NewBuffer(greeting[64:108]).String() - // Auth + // Auth. if opts.User != "" { scr, err := scramble(conn.Greeting.auth, opts.Pass) if err != nil { @@ -520,7 +534,20 @@ func (conn *Connection) dial() (err error) { } } - // Only if connected and authenticated. + // Watchers. + conn.watchMap.Range(func(key, value interface{}) bool { + req := newWatchRequest(key.(string)) + if err = conn.writeRequest(w, req); err != nil { + return false + } + return true + }) + + if err != nil { + return fmt.Errorf("unable to register watch: %w", err) + } + + // Only if connected and fully initialized. conn.lockShards() conn.c = connection atomic.StoreUint32(&conn.state, connConnected) @@ -581,23 +608,33 @@ func pack(h *smallWBuf, enc *encoder, reqid uint32, return } -func (conn *Connection) writeAuthRequest(w *bufio.Writer, scramble []byte) (err error) { +func (conn *Connection) writeRequest(w *bufio.Writer, req Request) (err error) { var packet smallWBuf - req := newAuthRequest(conn.opts.User, string(scramble)) err = pack(&packet, newEncoder(&packet), 0, req, ignoreStreamId, conn.Schema) if err != nil { - return errors.New("auth: pack error " + err.Error()) + return fmt.Errorf("pack error %w", err) } if err := write(w, packet.b); err != nil { - return errors.New("auth: write error " + err.Error()) + return fmt.Errorf("write error %w", err) } if err = w.Flush(); err != nil { - return errors.New("auth: flush error " + err.Error()) + return fmt.Errorf("flush error %w", err) } return } +func (conn *Connection) writeAuthRequest(w *bufio.Writer, scramble []byte) (err error) { + req := newAuthRequest(conn.opts.User, string(scramble)) + + err = conn.writeRequest(w, req) + if err != nil { + return fmt.Errorf("auth: %w", err) + } + + return nil +} + func (conn *Connection) readAuthResponse(r io.Reader) (err error) { respBytes, err := conn.read(r) if err != nil { @@ -774,7 +811,50 @@ func (conn *Connection) writer(w *bufio.Writer, c net.Conn) { } } +func readWatchEvent(reader io.Reader) (connWatchEvent, error) { + keyExist := false + event := connWatchEvent{} + d := newDecoder(reader) + + if l, err := d.DecodeMapLen(); err == nil { + for ; l > 0; l-- { + if cd, err := d.DecodeInt(); err == nil { + switch cd { + case KeyEvent: + if event.key, err = d.DecodeString(); err != nil { + return event, err + } + keyExist = true + case KeyEventData: + if event.value, err = d.DecodeInterface(); err != nil { + return event, err + } + default: + if err = d.Skip(); err != nil { + return event, err + } + } + } else { + return event, err + } + } + } else { + return event, err + } + + if !keyExist { + return event, errors.New("watch event does not have a key") + } + + return event, nil +} + func (conn *Connection) reader(r *bufio.Reader, c net.Conn) { + events := make(chan connWatchEvent, 1024) + defer close(events) + + go conn.eventer(events) + for atomic.LoadUint32(&conn.state) != connClosed { respBytes, err := conn.read(r) if err != nil { @@ -789,7 +869,14 @@ func (conn *Connection) reader(r *bufio.Reader, c net.Conn) { } var fut *Future = nil - if resp.Code == PushCode { + if resp.Code == EventCode { + if event, err := readWatchEvent(&resp.buf); err == nil { + events<- event + } else { + conn.opts.Logger.Report(LogReadWatchEventFailed, conn, err) + } + continue + } else if resp.Code == PushCode { if fut = conn.peekFuture(resp.RequestId); fut != nil { fut.AppendPush(resp) } @@ -799,12 +886,42 @@ func (conn *Connection) reader(r *bufio.Reader, c net.Conn) { conn.markDone(fut) } } + if fut == nil { conn.opts.Logger.Report(LogUnexpectedResultId, conn, resp) } } } +// eventer goroutine gets watch events and updates values for watchers. +func (conn *Connection) eventer(events <-chan connWatchEvent) { + for { + event, ok := <-events + if !ok { + // The channel is closed. + break + } + + if value, ok := conn.watchMap.Load(event.key); ok { + shared := value.(*watchDataShared) + shared.condMutex.Lock() + shared.value = event.value + shared.version += 1 + shared.condMutex.Unlock() + + shared.cond.Broadcast() + + if atomic.LoadUint32(&conn.state) == connConnected { + shared.watchMutex.Lock() + if shared.watchCnt > 0 { + conn.Do(newWatchRequest(event.key)) + } + shared.watchMutex.Unlock() + } + } + } +} + func (conn *Connection) newFuture(ctx context.Context) (fut *Future) { fut = NewFuture() if conn.rlimit != nil && conn.opts.RLimitAction == RLimitDrop { @@ -960,6 +1077,18 @@ func (conn *Connection) putFuture(fut *Future, req Request, streamId uint64) { return } shard.bufmut.Unlock() + + if req.Async() { + if fut = conn.fetchFuture(reqid); fut != nil { + resp := &Response{ + RequestId: reqid, + Code: OkCode, + } + fut.SetResponse(resp) + conn.markDone(fut) + } + } + if firstWritten { conn.dirtyShard <- shardn } @@ -1163,3 +1292,156 @@ func (conn *Connection) NewStream() (*Stream, error) { Conn: conn, }, nil } + +// watchDataShared is a shared between watchers of some key. +type watchDataShared struct { + // value is a last value of the key. + value interface{} + // version is a version of the value. It increases by 1 on the value + // update. + version uint + // The goroutine that gets events updates the value and the version under + // the write lock. Watchers read the value and the version under read lock. + cond *sync.Cond + condMutex sync.RWMutex + + // watchCnt is a number of active watchers. + watchCnt int32 + // watchMutex helps to send IPROTO_WATCH/IPROTO_UNWATCH without duplicates + // and intersections. + watchMutex sync.Mutex +} + +// connWatcher is an internal implementation of the Watcher interface. +type connWatcher struct { + shared *watchDataShared + done chan struct{} + finished chan struct{} + unregister sync.Once +} + +// Unregister unregisters the connection watcher. +func (w *connWatcher) Unregister() { + w.unregister.Do(func(){ + close(w.done) + // The Lock/Unlock helps to sync the w.done state with the watcher + // goroutine. + w.shared.condMutex.Lock() + w.shared.condMutex.Unlock() + w.shared.cond.Broadcast() + }) + <-w.finished +} + +// NewWatcher creates a new Watcher object for the connection. +// +// After watcher creation, the watcher callback is invoked for the first time. +// In this case, the callback is triggered whether or not the key has already +// been broadcast. All subsequent invocations are triggered with +// box.broadcast() called on the remote host. If a watcher is subscribed for a +// key that has not been broadcast yet, the callback is triggered only once, +// after the registration of the watcher. +// +// The watcher callbacks are always invoked in a separate goroutine. A watcher +// callback is never executed in parallel with itself, but they can be executed +// in parallel to other watchers. +// +// If the key is updated while the watcher callback is running, the callback +// will be invoked again with the latest value as soon as it returns. +// +// Watchers survive reconnection. All registered watchers are automatically +// resubscribed when the connection is reestablished. +// +// Keep in mind that garbage collection of a watcher handle doesn’t lead to the +// watcher’s destruction. In this case, the watcher remains registered. You +// need to call Unregister() directly. +// +// Unregister() guarantees that there will be no the watcher's callback calls +// after it, but Unregister() call from the callback leads to a deadlock. +// +// See: +// https://www.tarantool.io/en/doc/latest/reference/reference_lua/box_events/#box-watchers +// +// Since 1.10.0 +func (conn *Connection) NewWatcher(key string, callback WatchCallback) (Watcher, error) { + var shared *watchDataShared + // Get or create a shared data for the key. + if val, ok := conn.watchMap.Load(key); !ok { + shared = &watchDataShared{ + value: nil, + version: 0, + watchCnt: 0, + } + shared.cond = sync.NewCond(shared.condMutex.RLocker()) + + if val, ok := conn.watchMap.LoadOrStore(key, shared); ok { + shared = val.(*watchDataShared) + } + } else { + shared = val.(*watchDataShared) + } + + // Send an initial watch request. + shared.watchMutex.Lock() + if shared.watchCnt == 0 { + shared.condMutex.Lock() + shared.version = 0 + shared.condMutex.Unlock() + + if _, err := conn.Do(newWatchRequest(key)).Get(); err != nil { + shared.watchMutex.Unlock() + return nil, err + } + } + shared.watchCnt += 1 + shared.watchMutex.Unlock() + + // Start the watcher goroutine. + version := uint(0) + done := make(chan struct{}) + finished := make(chan struct{}) + + go func() { + for { + shared.cond.L.Lock() + for { + select{ + case <-done: + shared.cond.L.Unlock() + + shared.watchMutex.Lock() + shared.watchCnt -= 1 + if shared.watchCnt == 0 { + // A last one sends IPROTO_UNWATCH. + conn.Do(newUnwatchRequest(key)) + } + shared.watchMutex.Unlock() + + close(finished) + return + default: + } + if version != shared.version { + break + } + shared.cond.Wait() + } + + value := shared.value + version = shared.version + shared.cond.L.Unlock() + + callback(WatchEvent{ + Conn: conn, + Key: key, + Value: value, + }) + } + }() + + return &connWatcher{ + shared: shared, + done: done, + finished: finished, + }, nil +} diff --git a/connection_pool/connection_pool.go b/connection_pool/connection_pool.go index 6597e2dd0..f4fd9d505 100644 --- a/connection_pool/connection_pool.go +++ b/connection_pool/connection_pool.go @@ -87,16 +87,17 @@ Main features: - Automatic master discovery by mode parameter. */ type ConnectionPool struct { - addrs []string - connOpts tarantool.Opts - opts OptsPool + addrs []string + connOpts tarantool.Opts + opts OptsPool - state state - done chan struct{} - roPool *RoundRobinStrategy - rwPool *RoundRobinStrategy - anyPool *RoundRobinStrategy - poolsMutex sync.RWMutex + state state + done chan struct{} + roPool *RoundRobinStrategy + rwPool *RoundRobinStrategy + anyPool *RoundRobinStrategy + poolsMutex sync.RWMutex + watcherContainer watcherContainer } var _ Pooler = (*ConnectionPool)(nil) @@ -682,6 +683,169 @@ func (connPool *ConnectionPool) NewPrepared(expr string, userMode Mode) (*tarant return conn.NewPrepared(expr) } +// watcherContainer is a very simple implementation of a thread-safe container +// for watchares. It is not expected that there will be too many watchers and +// they will registered/unregistered too frequently. +// +// Otherwise, the implementation will need to be optimized. +type watcherContainer struct { + head *poolWatcher + mutex sync.RWMutex +} + +// add adds a watcher to the container. +func (c *watcherContainer) add(watcher *poolWatcher) { + c.mutex.Lock() + defer c.mutex.Unlock() + + watcher.next = c.head + c.head = watcher +} + +// remove removes a watcher from the container. +func (c *watcherContainer) remove(watcher *poolWatcher) { + c.mutex.Lock() + defer c.mutex.Unlock() + + if watcher == c.head { + c.head = watcher.next + } else { + cur := c.head + for cur.next != nil { + if cur.next == watcher { + cur.next = watcher.next + break + } + cur = cur.next + } + } +} + +// foreach iterates over the container to the end or until the call returns +// false. +func (c *watcherContainer) foreach(call func(watcher *poolWatcher) error) error { + cur := c.head + for cur != nil { + if err := call(cur); err != nil { + return err + } + cur = cur.next + } + return nil +} + +// poolWatcher is an internal implementation of the tarantool.Watcher interface. +type poolWatcher struct { + // The watcher container data. We can split the structure into two parts + // in the future: a watcher data and a watcher container data, but it looks + // simple at now. + next *poolWatcher + + // The watcher data. + container *watcherContainer + mode Mode + key string + callback tarantool.WatchCallback + watchers map[string]tarantool.Watcher + unregistered bool + mutex sync.Mutex +} + +// Unregister unregisters the pool watcher. +func (w *poolWatcher) Unregister() { + w.mutex.Lock() + defer w.mutex.Unlock() + + if !w.unregistered { + w.container.remove(w) + w.unregistered = true + for _, watcher := range w.watchers { + watcher.Unregister() + } + } +} + +func (w *poolWatcher) watch(conn *tarantool.Connection) error { + addr := conn.Addr() + + w.mutex.Lock() + defer w.mutex.Unlock() + + if !w.unregistered { + if _, ok := w.watchers[addr]; ok { + return nil + } + + if watcher, err := conn.NewWatcher(w.key, w.callback); err == nil { + w.watchers[addr] = watcher + return nil + } else { + return err + } + } + return nil +} + +func (w *poolWatcher) unwatch(conn *tarantool.Connection) { + addr := conn.Addr() + + w.mutex.Lock() + defer w.mutex.Unlock() + + if !w.unregistered { + if watcher, ok := w.watchers[addr]; ok { + watcher.Unregister() + delete(w.watchers, addr) + } + } +} + +// NewWatcher creates a new Watcher object for the connection pool. +// +// The behavior is same as if Connection.NewWatcher() called for each +// connection with a suitable role. +// +// Keep in mind that garbage collection of a watcher handle doesn’t lead to the +// watcher’s destruction. In this case, the watcher remains registered. You +// need to call Unregister() directly. +// +// Unregister() guarantees that there will be no the watcher's callback calls +// after it, but Unregister() call from the callback leads to a deadlock. +// +// See: +// https://www.tarantool.io/en/doc/latest/reference/reference_lua/box_events/#box-watchers +// +// Since 1.10.0 +func (pool *ConnectionPool) NewWatcher(key string, + callback tarantool.WatchCallback, mode Mode) (tarantool.Watcher, error) { + watcher := &poolWatcher { + container: &pool.watcherContainer, + mode: mode, + key: key, + callback: callback, + watchers: make(map[string]tarantool.Watcher), + unregistered: false, + } + + watcher.container.add(watcher) + + rr := pool.anyPool + if mode == RW { + rr = pool.rwPool + } else if mode == RO { + rr = pool.roPool + } + + conns := rr.GetConnections() + for _, conn := range conns { + if err := watcher.watch(conn); err != nil { + conn.Close() + } + } + + return watcher, nil +} + // // private // @@ -742,17 +906,47 @@ func (connPool *ConnectionPool) deleteConnection(addr string) { } } -func (connPool *ConnectionPool) addConnection(addr string, - conn *tarantool.Connection, role Role) { +func (pool *ConnectionPool) addConnection(addr string, + conn *tarantool.Connection, role Role) error { + // The connection initialization. + pool.watcherContainer.mutex.RLock() + defer pool.watcherContainer.mutex.RUnlock() - connPool.anyPool.AddConn(addr, conn) + watched := []tarantool.Watcher{} + err := pool.watcherContainer.foreach(func(watcher *poolWatcher) error { + watch := false + if watcher.mode == RW { + watch = role == MasterRole + } else if watcher.mode == RO { + watch = role == ReplicaRole + } else { + watch = true + } + if watch { + if err := watcher.watch(conn); err != nil { + return err + } + watched = append(watched, watcher) + } + return nil + }) + if err != nil { + for _, watcher := range(watched) { + watcher.Unregister() + } + log.Printf("tarantool: failed initialize watchers for %s: %s", addr, err) + return err + } + + pool.anyPool.AddConn(addr, conn) switch role { case MasterRole: - connPool.rwPool.AddConn(addr, conn) + pool.rwPool.AddConn(addr, conn) case ReplicaRole: - connPool.roPool.AddConn(addr, conn) + pool.roPool.AddConn(addr, conn) } + return nil } func (connPool *ConnectionPool) handlerDiscovered(conn *tarantool.Connection, @@ -811,7 +1005,10 @@ func (connPool *ConnectionPool) fillPools() ([]connState, bool) { } if connPool.handlerDiscovered(conn, role) { - connPool.addConnection(addr, conn, role) + if connPool.addConnection(addr, conn, role) != nil { + conn.Close() + connPool.handlerDeactivated(conn, role) + } if conn.ConnectedNow() { states[i].conn = conn @@ -864,7 +1061,15 @@ func (pool *ConnectionPool) updateConnection(s connState) connState { return s } - pool.addConnection(s.addr, s.conn, role) + if pool.addConnection(s.addr, s.conn, role) != nil { + pool.poolsMutex.Unlock() + + s.conn.Close() + pool.handlerDeactivated(s.conn, role) + s.conn = nil + s.role = UnknownRole + return s + } s.role = role } } @@ -911,7 +1116,12 @@ func (pool *ConnectionPool) tryConnect(s connState) connState { return s } - pool.addConnection(s.addr, conn, role) + if pool.addConnection(s.addr, conn, role) != nil { + pool.poolsMutex.Unlock() + conn.Close() + pool.handlerDeactivated(conn, role) + return s + } s.conn = conn s.role = role } diff --git a/connection_pool/connection_pool_test.go b/connection_pool/connection_pool_test.go index 60c9b4f91..58e64ee50 100644 --- a/connection_pool/connection_pool_test.go +++ b/connection_pool/connection_pool_test.go @@ -2048,6 +2048,159 @@ func TestStream_TxnIsolationLevel(t *testing.T) { } } +func TestConnectionPool_NewWatcher_modes(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestConnectionPool_NewWatcher_modes" + + roles := []bool{true, false, false, true, true} + + err := test_helpers.SetClusterRO(servers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + connPool, err := connection_pool.Connect(servers, connOpts) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + defer connPool.Close() + + modes := []connection_pool.Mode{ + connection_pool.ANY, + connection_pool.RW, + connection_pool.RO, + connection_pool.PreferRW, + connection_pool.PreferRO, + } + for _, mode := range modes { + m := mode + t.Run(fmt.Sprintf("%d", mode), func(t *testing.T) { + expectedServers := []string{} + for i, server := range servers { + if roles[i] && mode == connection_pool.RW { + continue + } else if !roles[i] && mode == connection_pool.RO { + continue + } + expectedServers = append(expectedServers, server) + } + + events := make(chan tarantool.WatchEvent, 1024) + defer close(events) + + watcher, err := connPool.NewWatcher(key, func(event tarantool.WatchEvent) { + if event.Key != key { + t.Errorf("Unexpected key: %s", event.Key) + } + if event.Value != nil { + t.Errorf("Unexpected value: %v", event.Value) + } + events<- event + }, m) + require.Nilf(t, err, "failed to register a watcher") + defer watcher.Unregister() + + testMap := make(map[string]int) + + for i := 0; i < len(expectedServers); i++ { + select { + case event := <-events: + require.NotNil(t, event.Conn) + addr := event.Conn.Addr() + if val, ok := testMap[addr]; ok { + testMap[addr] = val + 1 + } else { + testMap[addr] = 1 + } + case <-time.After(time.Second): + t.Fatalf("Failed to get watch event.") + } + } + + for _, server := range expectedServers { + if val, ok := testMap[server]; !ok { + t.Errorf("Server not found: %s", server) + } else { + require.Equal(t, val, 1, fmt.Sprintf("for server %s", server)) + } + } + }) + } +} + +func TestConnectionPool_NewWatcher_concurrent(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const testConcurrency = 1000 + const key = "TestConnectionPool_NewWatcher_concurrent" + + roles := []bool{true, false, false, true, true} + + err := test_helpers.SetClusterRO(servers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + connPool, err := connection_pool.Connect(servers, connOpts) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + defer connPool.Close() + + var wg sync.WaitGroup + wg.Add(testConcurrency) + + mode := connection_pool.ANY + callback := func(event tarantool.WatchEvent){} + for i := 0; i < testConcurrency; i++ { + go func(i int) { + defer wg.Done() + + //fmt.Printf("NewWatcher %d\n", i) + watcher, err := connPool.NewWatcher(key, callback, mode) + //fmt.Printf("NewWatcher %d done\n", i) + if err != nil { + t.Errorf("Failed to create a watcher: %s", err) + } else { + //fmt.Printf("Unregister %d\n", i) + watcher.Unregister() + //fmt.Printf("Unregister %d done\n", i) + } + }(i) + } + wg.Wait() +} + +func TestWatcher_Unregister_concurrent(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const testConcurrency = 1000 + const key = "TestWatcher_Unregister_concurrent" + + roles := []bool{true, false, false, true, true} + + err := test_helpers.SetClusterRO(servers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + connPool, err := connection_pool.Connect(servers, connOpts) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + defer connPool.Close() + + mode := connection_pool.ANY + watcher, err := connPool.NewWatcher(key, func(event tarantool.WatchEvent){ + }, mode) + if err != nil { + t.Fatalf("Failed to create a watch: %s", err) + } + + var wg sync.WaitGroup + wg.Add(testConcurrency) + + for i := 0; i < testConcurrency; i++ { + go func() { + defer wg.Done() + watcher.Unregister() + }() + } + wg.Wait() +} + // runTestMain is a body of TestMain function // (see https://pkg.go.dev/testing#hdr-Main). // Using defer + os.Exit is not works so TestMain body diff --git a/connection_pool/round_robin.go b/connection_pool/round_robin.go index b83d877d9..f59c82f38 100644 --- a/connection_pool/round_robin.go +++ b/connection_pool/round_robin.go @@ -14,6 +14,16 @@ type RoundRobinStrategy struct { current uint } +func NewEmptyRoundRobin(size int) *RoundRobinStrategy { + return &RoundRobinStrategy{ + conns: make([]*tarantool.Connection, 0, size), + indexByAddr: make(map[string]uint), + size: 0, + current: 0, + } +} + + func (r *RoundRobinStrategy) GetConnByAddr(addr string) *tarantool.Connection { r.mutex.RLock() defer r.mutex.RUnlock() @@ -71,13 +81,14 @@ func (r *RoundRobinStrategy) GetNextConnection() *tarantool.Connection { return r.conns[r.nextIndex()] } -func NewEmptyRoundRobin(size int) *RoundRobinStrategy { - return &RoundRobinStrategy{ - conns: make([]*tarantool.Connection, 0, size), - indexByAddr: make(map[string]uint), - size: 0, - current: 0, - } +func (r *RoundRobinStrategy) GetConnections() []*tarantool.Connection { + r.mutex.RLock() + defer r.mutex.RUnlock() + + ret := make([]*tarantool.Connection, len(r.conns)) + copy(ret, r.conns) + + return ret } func (r *RoundRobinStrategy) AddConn(addr string, conn *tarantool.Connection) { diff --git a/connection_pool/round_robin_test.go b/connection_pool/round_robin_test.go index 6b54ecfd8..25a681a25 100644 --- a/connection_pool/round_robin_test.go +++ b/connection_pool/round_robin_test.go @@ -69,3 +69,22 @@ func TestRoundRobinGetNextConnection(t *testing.T) { } } } + +func TestRoundRobinStrategy_GetConnections(t *testing.T) { + rr := NewEmptyRoundRobin(10) + + addrs := []string{validAddr1, validAddr2} + conns := []*tarantool.Connection{&tarantool.Connection{}, &tarantool.Connection{}} + + for i, addr := range addrs { + rr.AddConn(addr, conns[i]) + } + + rr.GetConnections()[1] = conns[0] // GetConnections() retuns a copy. + rrConns := rr.GetConnections() + for i, expected := range conns { + if expected != rrConns[i] { + t.Errorf("Unexpected connection on %d call", i) + } + } +} diff --git a/const.go b/const.go index 4a3cb6833..95a0d366d 100644 --- a/const.go +++ b/const.go @@ -18,6 +18,8 @@ const ( RollbackRequestCode = 16 PingRequestCode = 64 SubscribeRequestCode = 66 + WatchRequestCode = 74 + UnwatchRequestCode = 75 KeyCode = 0x00 KeySync = 0x01 @@ -42,6 +44,8 @@ const ( KeySQLInfo = 0x42 KeyStmtID = 0x43 KeyTimeout = 0x56 + KeyEvent = 0x57 + KeyEventData = 0x58 KeyTxnIsolation = 0x59 KeyFieldName = 0x00 @@ -70,6 +74,7 @@ const ( RLimitWait = 2 OkCode = uint32(0) + EventCode = uint32(0x4c) PushCode = uint32(0x80) ErrorCodeBit = 0x8000 PacketLengthBytes = 5 diff --git a/multi/multi.go b/multi/multi.go index 67f450c5c..0a7a5bd44 100644 --- a/multi/multi.go +++ b/multi/multi.go @@ -116,7 +116,7 @@ func (connMulti *ConnectionMulti) warmUp() (somebodyAlive bool, errs []error) { errs = make([]error, len(connMulti.addrs)) for i, addr := range connMulti.addrs { - conn, err := tarantool.Connect(addr, connMulti.connOpts) + conn, err := connMulti.newConnection(addr) errs[i] = err if conn != nil && err == nil { if connMulti.fallback == nil { @@ -135,6 +135,18 @@ func (connMulti *ConnectionMulti) getState() uint32 { return atomic.LoadUint32(&connMulti.state) } +func (connMulti *ConnectionMulti) newConnection(addr string) (*tarantool.Connection, error) { + conn, err := tarantool.Connect(addr, connMulti.connOpts) + if err != nil { + return nil, err + } + return conn, nil +} + +func (connMulti *ConnectionMulti) unregisterWatchers(conn *tarantool.Connection) { + +} + func (connMulti *ConnectionMulti) getConnectionFromPool(addr string) (*tarantool.Connection, bool) { connMulti.mutex.RLock() defer connMulti.mutex.RUnlock() @@ -145,12 +157,17 @@ func (connMulti *ConnectionMulti) getConnectionFromPool(addr string) (*tarantool func (connMulti *ConnectionMulti) setConnectionToPool(addr string, conn *tarantool.Connection) { connMulti.mutex.Lock() defer connMulti.mutex.Unlock() + if old, ok := connMulti.pool[addr]; ok { + connMulti.unregisterWatchers(old) + } connMulti.pool[addr] = conn } func (connMulti *ConnectionMulti) deleteConnectionFromPool(addr string) { connMulti.mutex.Lock() defer connMulti.mutex.Unlock() + old := connMulti.pool[addr] + connMulti.unregisterWatchers(old) delete(connMulti.pool, addr) } @@ -175,7 +192,7 @@ func (connMulti *ConnectionMulti) checker() { if _, ok := connMulti.getConnectionFromPool(addr); !ok { continue } - conn, _ := tarantool.Connect(addr, connMulti.connOpts) + conn, _ := connMulti.newConnection(addr) if conn != nil { connMulti.setConnectionToPool(addr, conn) } else { @@ -196,7 +213,7 @@ func (connMulti *ConnectionMulti) checker() { // Fill pool with new connections. for _, v := range addrs { if indexOf(v, connMulti.addrs) < 0 { - conn, _ := tarantool.Connect(v, connMulti.connOpts) + conn, _ := connMulti.newConnection(v) if conn != nil { connMulti.setConnectionToPool(v, conn) } @@ -224,7 +241,7 @@ func (connMulti *ConnectionMulti) checker() { continue } } - conn, _ := tarantool.Connect(addr, connMulti.connOpts) + conn, _ := connMulti.newConnection(addr) if conn != nil { connMulti.setConnectionToPool(addr, conn) } @@ -264,6 +281,7 @@ func (connMulti *ConnectionMulti) Close() (err error) { connMulti.state = connClosed for _, conn := range connMulti.pool { + connMulti.unregisterWatchers(conn) if err == nil { err = conn.Close() } else { @@ -271,6 +289,7 @@ func (connMulti *ConnectionMulti) Close() (err error) { } } if connMulti.fallback != nil { + connMulti.unregisterWatchers(connMulti.fallback) connMulti.fallback.Close() } @@ -507,6 +526,25 @@ func (connMulti *ConnectionMulti) NewStream() (*tarantool.Stream, error) { return connMulti.getCurrentConnection().NewStream() } +type watchDataShared struct { + +} + +type multiWatcher struct { + shared *watchDataShared + done chan struct{} +} + +func (w *multiWatcher) Unregister() { + +} + +func (connMulti *ConnectionMulti) NewWatcher(key string, callback tarantool.WatchCallback) (tarantool.Watcher, error) { + return &multiWatcher{ + + }, nil +} + // Do sends the request and returns a future. func (connMulti *ConnectionMulti) Do(req tarantool.Request) *tarantool.Future { if connectedReq, ok := req.(tarantool.ConnectedRequest); ok { diff --git a/request.go b/request.go index cfa40e522..30ef8d1f5 100644 --- a/request.go +++ b/request.go @@ -538,6 +538,8 @@ type Request interface { Body(resolver SchemaResolver, enc *encoder) error // Ctx returns a context of the request. Ctx() context.Context + // Async returns true if the request does not expect response. + Async() bool } // ConnectedRequest is an interface that provides the info about a Connection @@ -550,6 +552,7 @@ type ConnectedRequest interface { type baseRequest struct { requestCode int32 + async bool ctx context.Context } @@ -558,6 +561,11 @@ func (req *baseRequest) Code() int32 { return req.requestCode } +// Async returns true if the request does not require response. +func (req *baseRequest) Async() bool { + return req.async +} + // Ctx returns a context of the request. func (req *baseRequest) Ctx() context.Context { return req.ctx diff --git a/request_test.go b/request_test.go index 89d1d8884..a680cdcbb 100644 --- a/request_test.go +++ b/request_test.go @@ -19,6 +19,7 @@ const invalidIndex = 2 const validSpace = 1 // Any valid value != default. const validIndex = 3 // Any valid value != default. const validExpr = "any string" // We don't check the value here. +const validKey = "foo" // Any string. const defaultSpace = 0 // And valid too. const defaultIndex = 0 // And valid too. @@ -183,6 +184,7 @@ func TestRequestsCodes(t *testing.T) { {req: NewBeginRequest(), code: BeginRequestCode}, {req: NewCommitRequest(), code: CommitRequestCode}, {req: NewRollbackRequest(), code: RollbackRequestCode}, + {req: NewBroadcastRequest(validKey), code: EvalRequestCode}, } for _, test := range tests { @@ -192,6 +194,38 @@ func TestRequestsCodes(t *testing.T) { } } +func TestRequestsAsync(t *testing.T) { + tests := []struct { + req Request + async bool + }{ + {req: NewSelectRequest(validSpace), async: false}, + {req: NewUpdateRequest(validSpace), async: false}, + {req: NewUpsertRequest(validSpace), async: false}, + {req: NewInsertRequest(validSpace), async: false}, + {req: NewReplaceRequest(validSpace), async: false}, + {req: NewDeleteRequest(validSpace), async: false}, + {req: NewCall16Request(validExpr), async: false}, + {req: NewCall17Request(validExpr), async: false}, + {req: NewEvalRequest(validExpr), async: false}, + {req: NewExecuteRequest(validExpr), async: false}, + {req: NewPingRequest(), async: false}, + {req: NewPrepareRequest(validExpr), async: false}, + {req: NewUnprepareRequest(validStmt), async: false}, + {req: NewExecutePreparedRequest(validStmt), async: false}, + {req: NewBeginRequest(), async: false}, + {req: NewCommitRequest(), async: false}, + {req: NewRollbackRequest(), async: false}, + {req: NewBroadcastRequest(validKey), async: false}, + } + + for _, test := range tests { + if async := test.req.Async(); async != test.async { + t.Errorf("An invalid async %t, expected %t", async, test.async) + } + } +} + func TestPingRequestDefaultValues(t *testing.T) { var refBuf bytes.Buffer @@ -649,3 +683,34 @@ func TestRollbackRequestDefaultValues(t *testing.T) { req := NewRollbackRequest() assertBodyEqual(t, refBuf.Bytes(), req) } + +func TestBroadcastRequestDefaultValues(t *testing.T) { + var refBuf bytes.Buffer + + refEnc := NewEncoder(&refBuf) + expectedArgs := []interface{}{validKey} + err := RefImplEvalBody(refEnc, "box.broadcast(...)", expectedArgs) + if err != nil { + t.Errorf("An unexpected RefImplEvalBody() error: %q", err.Error()) + return + } + + req := NewBroadcastRequest(validKey) + assertBodyEqual(t, refBuf.Bytes(), req) +} + +func TestBroadcastRequestSetters(t *testing.T) { + value := []interface{}{uint(34), int(12)} + var refBuf bytes.Buffer + + refEnc := NewEncoder(&refBuf) + expectedArgs := []interface{}{validKey, value} + err := RefImplEvalBody(refEnc, "box.broadcast(...)", expectedArgs) + if err != nil { + t.Errorf("An unexpected RefImplEvalBody() error: %q", err.Error()) + return + } + + req := NewBroadcastRequest(validKey).Value(value) + assertBodyEqual(t, refBuf.Bytes(), req) +} diff --git a/tarantool_test.go b/tarantool_test.go index 1350390f9..ab0abb6ac 100644 --- a/tarantool_test.go +++ b/tarantool_test.go @@ -2830,6 +2830,325 @@ func TestStream_DoWithClosedConn(t *testing.T) { } } +func TestConnection_NewWatcher(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestConnection_NewWatcher" + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + events := make(chan WatchEvent) + defer close(events) + + watcher, err := conn.NewWatcher(key, func(event WatchEvent) { + events <- event + }) + if err != nil { + t.Fatalf("Failed to create a watch: %s", err) + } + defer watcher.Unregister() + + select { + case event := <-events: + if event.Conn != conn { + t.Errorf("Unexpected event connection: %v", event.Conn) + } + if event.Key != key { + t.Errorf("Unexpected event key: %s", event.Key) + } + if event.Value != nil { + t.Errorf("Unexpected event value: %v", event.Value) + } + case <-time.After(2 * time.Second): + t.Fatalf("Failed to get watch event.") + } +} + +func TestConnection_NewWatcher_reconnect(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestConnection_NewWatcher_reconnect" + const server = "127.0.0.1:3014" + + inst, err := test_helpers.StartTarantool(test_helpers.StartOpts{ + InitScript: "config.lua", + Listen: server, + WorkDir: "work_dir", + User: opts.User, + Pass: opts.Pass, + WaitStart: 100 * time.Millisecond, + ConnectRetry: 3, + RetryTimeout: 500 * time.Millisecond, + }) + defer test_helpers.StopTarantoolWithCleanup(inst) + if err != nil { + t.Fatalf("Unable to start Tarantool: %s", err) + } + + reconnectOpts := opts + reconnectOpts.Reconnect = 100 * time.Millisecond + reconnectOpts.MaxReconnects = 10 + conn := test_helpers.ConnectWithValidation(t, server, reconnectOpts) + defer conn.Close() + + events := make(chan WatchEvent) + defer close(events) + watcher, err := conn.NewWatcher(key, func(event WatchEvent) { + events <- event + }) + if err != nil { + t.Fatalf("Failed to create a watch: %s", err) + } + defer watcher.Unregister() + + <-events + + test_helpers.StopTarantool(inst) + if err := test_helpers.RestartTarantool(&inst); err != nil { + t.Fatalf("Unable to restart Tarantool: %s", err) + } + + maxTime := reconnectOpts.Reconnect * time.Duration(reconnectOpts.MaxReconnects) + select { + case <-events: + case <-time.After(maxTime): + t.Fatalf("Failed to get watch event.") + } +} + +func TestBroadcastRequest(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestBroadcastRequest" + const value = "bar" + + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + resp, err := conn.Do(NewBroadcastRequest(key).Value(value)).Get() + if err != nil { + t.Fatalf("Got broadcast error: %s", err) + } + if resp.Code != OkCode { + t.Errorf("Got unexpected broadcast response code: %d", resp.Code) + } + if !reflect.DeepEqual(resp.Data, []interface{}{}) { + t.Errorf("Got unexpected broadcast response data: %v", resp.Data) + } + + events := make(chan WatchEvent) + defer close(events) + + watcher, err := conn.NewWatcher(key, func(event WatchEvent) { + events <- event + }) + if err != nil { + t.Fatalf("Failed to create a watch: %s", err) + } + defer watcher.Unregister() + + select { + case event := <-events: + if event.Conn != conn { + t.Errorf("Unexpected event connection: %v", event.Conn) + } + if event.Key != key { + t.Errorf("Unexpected event key: %s", event.Key) + } + if event.Value != value { + t.Errorf("Unexpected event value: %v", event.Value) + } + case <-time.After(2 * time.Second): + t.Fatalf("Failed to get watch event.") + } +} + +func TestBroadcastRequest_multi(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestBroadcastRequest_multi" + + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + events := make(chan WatchEvent) + defer close(events) + + watcher, err := conn.NewWatcher(key, func(event WatchEvent) { + events <- event + }) + if err != nil { + t.Fatalf("Failed to create a watch: %s", err) + } + defer watcher.Unregister() + + <-events // Skip an initial event. + for i := 0; i < 10; i++ { + val := fmt.Sprintf("%d", i) + _, err := conn.Do(NewBroadcastRequest(key).Value(val)).Get() + if err != nil { + t.Fatalf("Failed to send a broadcast request: %s", err) + } + select { + case event := <-events: + if event.Conn != conn { + t.Errorf("Unexpected event connection: %v", event.Conn) + } + if event.Key != key { + t.Errorf("Unexpected event key: %s", event.Key) + } + if event.Value.(string) != val { + t.Errorf("Unexpected event value: %v", event.Value) + } + case <-time.After(2 * time.Second): + t.Fatalf("Failed to get watch event %d", i) + } + } +} + +func TestConnection_NewWatcher_multiOnKey(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestConnection_NewWatcher_multiOnKey" + const value = "bar" + + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + events := []chan WatchEvent{ + make(chan WatchEvent), + make(chan WatchEvent), + } + for _, ch := range events { + defer close(ch) + } + + for _, ch := range events { + channel := ch + watcher, err := conn.NewWatcher(key, func(event WatchEvent) { + channel <- event + }) + if err != nil { + t.Fatalf("Failed to create a watch: %s", err) + } + defer watcher.Unregister() + } + + for i, ch := range events { + select { + case <-ch: // Skip an initial event. + case <-time.After(2 * time.Second): + t.Fatalf("Failed to skip watch event for %d callback", i) + } + } + + _, err := conn.Do(NewBroadcastRequest(key).Value(value)).Get() + if err != nil { + t.Fatalf("Failed to send a broadcast request: %s", err) + } + + for i, ch := range events { + select { + case event := <-ch: + if event.Conn != conn { + t.Errorf("Unexpected event connection: %v", event.Conn) + } + if event.Key != key { + t.Errorf("Unexpected event key: %s", event.Key) + } + if event.Value.(string) != value { + t.Errorf("Unexpected event value: %v", event.Value) + } + case <-time.After(2 * time.Second): + t.Fatalf("Failed to get watch event from callback %d", i) + } + } +} + +func TestWatcher_Unregister(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestWatcher_Unregister" + const value = "bar" + + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + events := make(chan WatchEvent) + defer close(events) + watcher, err := conn.NewWatcher(key, func(event WatchEvent) { + events <- event + }) + if err != nil { + t.Fatalf("Failed to create a watch: %s", err) + } + + <-events + watcher.Unregister() + + _, err = conn.Do(NewBroadcastRequest(key).Value(value)).Get() + if err != nil { + t.Fatalf("Got broadcast error: %s", err) + } + + select { + case event := <-events: + t.Fatalf("Get unexpected events: %v", event) + case <-time.After(time.Second): + } +} + +func TestConnection_NewWatcher_concurrent(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const testConcurrency = 1000 + const key = "TestConnection_NewWatcher_concurrent" + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + var wg sync.WaitGroup + wg.Add(testConcurrency) + + for i := 0; i < testConcurrency; i++ { + go func() { + defer wg.Done() + + watcher, err := conn.NewWatcher(key, func(event WatchEvent) {}) + if err != nil { + t.Fatalf("Failed to create a watch: %s", err) + } else { + watcher.Unregister() + } + }() + } + wg.Wait() +} + +func TestWatcher_Unregister_concurrent(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const testConcurrency = 1000 + const key = "TestWatcher_Unregister_concurrent" + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + watcher, err := conn.NewWatcher(key, func(event WatchEvent) {}) + if err != nil { + t.Fatalf("Failed to create a watch: %s", err) + } + + var wg sync.WaitGroup + wg.Add(testConcurrency) + + for i := 0; i < testConcurrency; i++ { + go func() { + defer wg.Done() + watcher.Unregister() + }() + } + wg.Wait() +} + // runTestMain is a body of TestMain function // (see https://pkg.go.dev/testing#hdr-Main). // Using defer + os.Exit is not works so TestMain body diff --git a/test_helpers/request_mock.go b/test_helpers/request_mock.go index 93551e34a..19c18545e 100644 --- a/test_helpers/request_mock.go +++ b/test_helpers/request_mock.go @@ -17,6 +17,10 @@ func (sr *StrangerRequest) Code() int32 { return 0 } +func (sr *StrangerRequest) Async() bool { + return false +} + func (sr *StrangerRequest) Body(resolver tarantool.SchemaResolver, enc *encoder) error { return nil } diff --git a/test_helpers/utils.go b/test_helpers/utils.go index c936e90b3..4672a7de8 100644 --- a/test_helpers/utils.go +++ b/test_helpers/utils.go @@ -53,10 +53,9 @@ func SkipIfSQLUnsupported(t testing.TB) { } } -func SkipIfStreamsUnsupported(t *testing.T) { +func skipIfLess2_10(t *testing.T) { t.Helper() - // Tarantool supports streams and interactive transactions since version 2.10.0 isLess, err := IsTarantoolVersionLess(2, 10, 0) if err != nil { t.Fatalf("Could not check the Tarantool version") @@ -66,3 +65,15 @@ func SkipIfStreamsUnsupported(t *testing.T) { t.Skip("Skipping test for Tarantool without streams support") } } + +func SkipIfStreamsUnsupported(t *testing.T) { + t.Helper() + + skipIfLess2_10(t) +} + +func SkipIfWatchersUnsupported(t *testing.T) { + t.Helper() + + skipIfLess2_10(t) +} diff --git a/watch.go b/watch.go new file mode 100644 index 000000000..75b253d17 --- /dev/null +++ b/watch.go @@ -0,0 +1,138 @@ +package tarantool + +import ( + "context" +) + +// BroadcastRequest helps to send broadcast messages. See: +// https://www.tarantool.io/en/doc/latest/reference/reference_lua/box_events/broadcast/ +type BroadcastRequest struct { + eval *EvalRequest + key string +} + +// NewBroadcastRequest returns a new broadcast request for a specified key. +func NewBroadcastRequest(key string) *BroadcastRequest { + req := new(BroadcastRequest) + req.key = key + req.eval = NewEvalRequest("box.broadcast(...)").Args([]interface{}{key}) + return req +} + +// Value sets the value for the broadcast request. +// Note: default value is nil. +func (req *BroadcastRequest) Value(value interface{}) *BroadcastRequest { + req.eval = req.eval.Args([]interface{}{req.key, value}) + return req +} + +// Context sets a passed context to the broadcast request. +func (req *BroadcastRequest) Context(ctx context.Context) *BroadcastRequest { + req.eval = req.eval.Context(ctx) + return req +} + +// Code returns IPROTO code for the broadcast request. +func (req *BroadcastRequest) Code() int32 { + return req.eval.Code() +} + +// Body fills an encoder with the broadcast request body. +func (req *BroadcastRequest) Body(res SchemaResolver, enc *encoder) error { + return req.eval.Body(res, enc) +} + +// Ctx returns a context of the broadcast request. +func (req *BroadcastRequest) Ctx() context.Context { + return req.eval.Ctx() +} + +// Async returns is the broadcast request expects a response. +func (req *BroadcastRequest) Async() bool { + return req.eval.Async() +} + +// watchRequest subscribes to the updates of a specified key defined on the +// server. After receiving the notification, you should send a new +// watchRequest to acknowledge the notification. +type watchRequest struct { + baseRequest + key string + ctx context.Context +} + +// newWatchRequest returns a new watchRequest. +func newWatchRequest(key string) *watchRequest { + req := new(watchRequest) + req.requestCode = WatchRequestCode + req.async = true + req.key = key + return req +} + +// Body fills an encoder with the watch request body. +func (req *watchRequest) Body(res SchemaResolver, enc *encoder) error { + if err := enc.EncodeMapLen(1); err != nil { + return err + } + if err := encodeUint(enc, KeyEvent); err != nil { + return err + } + return enc.EncodeString(req.key) +} + +// Context sets a passed context to the request. +func (req *watchRequest) Context(ctx context.Context) *watchRequest { + req.ctx = ctx + return req +} + +// unwatchRequest unregisters a watcher subscribed to the given notification +// key. +type unwatchRequest struct { + baseRequest + key string + ctx context.Context +} + +// newUnwatchRequest returns a new unwatchRequest. +func newUnwatchRequest(key string) *unwatchRequest { + req := new(unwatchRequest) + req.requestCode = UnwatchRequestCode + req.async = true + req.key = key + return req +} + +// Body fills an encoder with the unwatch request body. +func (req *unwatchRequest) Body(res SchemaResolver, enc *encoder) error { + if err := enc.EncodeMapLen(1); err != nil { + return err + } + if err := encodeUint(enc, KeyEvent); err != nil { + return err + } + return enc.EncodeString(req.key) +} + +// Context sets a passed context to the request. +func (req *unwatchRequest) Context(ctx context.Context) *unwatchRequest { + req.ctx = ctx + return req +} + +// WatchEvent is a watch notification event received from a server. +type WatchEvent struct { + Conn *Connection // A source connection. + Key string // A key. + Value interface{} // A new value. +} + +// Watcher is a subscription to broadcast events. +type Watcher interface { + // Unregister unregisters the watcher. + Unregister() +} + +// WatchCallback is a callback to invoke when the key value is updated. +type WatchCallback func(event WatchEvent)