diff --git a/CHANGELOG.md b/CHANGELOG.md index 3019df4e0..5ac83d068 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ Versioning](http://semver.org/spec/v2.0.0.html) except to the first release. ### Added - Support iproto feature discovery (#120). +- Event subscription support (#119) ### Changed diff --git a/connection.go b/connection.go index 7255d2435..6d2989015 100644 --- a/connection.go +++ b/connection.go @@ -54,6 +54,8 @@ const ( // LogUnexpectedResultId is logged when response with unknown id was received. // Most probably it is due to request timeout. LogUnexpectedResultId + // LogWatchEventReadFailed is logged when failed to read a watch event. + LogWatchEventReadFailed ) // ConnEvent is sent throw Notify channel specified in Opts. @@ -63,6 +65,12 @@ type ConnEvent struct { When time.Time } +// A raw watch event. +type connWatchEvent struct { + key string + value interface{} +} + var epoch = time.Now() // Logger is logger type expected to be passed in options. @@ -84,6 +92,9 @@ func (d defaultLogger) Report(event ConnLogKind, conn *Connection, v ...interfac case LogUnexpectedResultId: resp := v[0].(*Response) log.Printf("tarantool: connection %s got unexpected resultId (%d) in response", conn.addr, resp.RequestId) + case LogWatchEventReadFailed: + err := v[0].(error) + log.Printf("tarantool: unable to parse watch event: %s", err) default: args := append([]interface{}{"tarantool: unexpected event ", event, conn}, v...) log.Print(args...) @@ -149,6 +160,8 @@ type Connection struct { lastStreamId uint64 serverProtocolInfo ProtocolInfo + // watchMap is a map of key -> watchSharedData. + watchMap sync.Map } var _ = Connector(&Connection{}) // Check compatibility with connector interface. @@ -531,7 +544,7 @@ func (conn *Connection) dial() (err error) { return fmt.Errorf("identify: %w", err) } - // Auth + // Auth. if opts.User != "" { scr, err := scramble(conn.Greeting.auth, opts.Pass) if err != nil { @@ -549,7 +562,41 @@ func (conn *Connection) dial() (err error) { } } - // Only if connected and authenticated. + // Watchers. + watchersChecked := false + conn.watchMap.Range(func(key, value interface{}) bool { + if !watchersChecked { + required := ProtocolInfo{Features: []ProtocolFeature{WatchersFeature}} + err = checkProtocolInfo(required, conn.ServerProtocolInfo()) + if err != nil { + return false + } + watchersChecked = true + } + + st := value.(chan watchState) + state := <-st + if state.unready != nil { + return true + } + + if state.cnt > 0 { + req := newWatchRequest(key.(string)) + if err = conn.writeRequest(w, req); err != nil { + st <- state + return false + } + state.ack = true + } + st <- state + return true + }) + + if err != nil { + return fmt.Errorf("unable to register watch: %w", err) + } + + // Only if connected and fully initialized. conn.lockShards() conn.c = connection atomic.StoreUint32(&conn.state, connConnected) @@ -843,7 +890,52 @@ func (conn *Connection) writer(w *bufio.Writer, c net.Conn) { } } +func readWatchEvent(reader io.Reader) (connWatchEvent, error) { + keyExist := false + event := connWatchEvent{} + d := newDecoder(reader) + + l, err := d.DecodeMapLen() + if err != nil { + return event, err + } + + for ; l > 0; l-- { + cd, err := d.DecodeInt() + if err != nil { + return event, err + } + + switch cd { + case KeyEvent: + if event.key, err = d.DecodeString(); err != nil { + return event, err + } + keyExist = true + case KeyEventData: + if event.value, err = d.DecodeInterface(); err != nil { + return event, err + } + default: + if err = d.Skip(); err != nil { + return event, err + } + } + } + + if !keyExist { + return event, errors.New("watch event does not have a key") + } + + return event, nil +} + func (conn *Connection) reader(r *bufio.Reader, c net.Conn) { + events := make(chan connWatchEvent, 1024) + defer close(events) + + go conn.eventer(events) + for atomic.LoadUint32(&conn.state) != connClosed { respBytes, err := conn.read(r) if err != nil { @@ -858,7 +950,14 @@ func (conn *Connection) reader(r *bufio.Reader, c net.Conn) { } var fut *Future = nil - if resp.Code == PushCode { + if resp.Code == EventCode { + if event, err := readWatchEvent(&resp.buf); err == nil { + events <- event + } else { + conn.opts.Logger.Report(LogWatchEventReadFailed, conn, err) + } + continue + } else if resp.Code == PushCode { if fut = conn.peekFuture(resp.RequestId); fut != nil { fut.AppendPush(resp) } @@ -868,12 +967,38 @@ func (conn *Connection) reader(r *bufio.Reader, c net.Conn) { conn.markDone(fut) } } + if fut == nil { conn.opts.Logger.Report(LogUnexpectedResultId, conn, resp) } } } +// eventer goroutine gets watch events and updates values for watchers. +func (conn *Connection) eventer(events <-chan connWatchEvent) { + for event := range events { + if value, ok := conn.watchMap.Load(event.key); ok { + st := value.(chan watchState) + state := <-st + state.value = event.value + if state.version == math.MaxUint64 { + state.version = initWatchEventVersion + 1 + } else { + state.version += 1 + } + state.ack = false + if state.changed != nil { + close(state.changed) + state.changed = nil + } + st <- state + } + // It is possible to get IPROTO_EVENT after we already send + // IPROTO_UNWATCH due to processing on a Tarantool side or slow + // read from the network, so it looks like an expected behavior. + } +} + func (conn *Connection) newFuture(ctx context.Context) (fut *Future) { fut = NewFuture() if conn.rlimit != nil && conn.opts.RLimitAction == RLimitDrop { @@ -1029,6 +1154,18 @@ func (conn *Connection) putFuture(fut *Future, req Request, streamId uint64) { return } shard.bufmut.Unlock() + + if req.Async() { + if fut = conn.fetchFuture(reqid); fut != nil { + resp := &Response{ + RequestId: reqid, + Code: OkCode, + } + fut.SetResponse(resp) + conn.markDone(fut) + } + } + if firstWritten { conn.dirtyShard <- shardn } @@ -1233,6 +1370,212 @@ func (conn *Connection) NewStream() (*Stream, error) { }, nil } +// watchState is the current state of the watcher. See the idea at p. 70, 105: +// https://drive.google.com/file/d/1nPdvhB0PutEJzdCq5ms6UI58dp50fcAN/view +type watchState struct { + // value is a current value. + value interface{} + // version is a current version of the value. The only reason for uint64: + // go 1.13 has no math.Uint. + version uint64 + // ack true if the acknowledge is already sended. + ack bool + // cnt is a count of active watchers for the key. + cnt int + // changed is a channel for broadcast the value changes. + changed chan struct{} + // unready channel exists if a state is not ready to work (subscribtion + // or unsubscribtion in progress). + unready chan struct{} +} + +// initWatchEventVersion is an initial version until no events from Tarantool. +const initWatchEventVersion = 0 + +// connWatcher is an internal implementation of the Watcher interface. +type connWatcher struct { + unregister sync.Once + done chan struct{} + finished chan struct{} +} + +// Unregister unregisters the connection watcher. +func (w *connWatcher) Unregister() { + w.unregister.Do(func() { + close(w.done) + }) + <-w.finished +} + +// subscribeWatchChannel returns an existing one or a new watch state channel +// for the key. It also increases a counter of active watchers for the channel. +func subscribeWatchChannel(conn *Connection, key string) (chan watchState, error) { + var st chan watchState + + for st == nil { + if val, ok := conn.watchMap.Load(key); !ok { + st = make(chan watchState, 1) + state := watchState{ + value: nil, + version: initWatchEventVersion, + ack: false, + cnt: 0, + changed: nil, + unready: make(chan struct{}), + } + st <- state + + if val, ok := conn.watchMap.LoadOrStore(key, st); !ok { + if _, err := conn.Do(newWatchRequest(key)).Get(); err != nil { + conn.watchMap.Delete(key) + close(state.unready) + return nil, err + } + // It is a successful subsctibtion to a watch events by itself. + state = <-st + state.cnt = 1 + close(state.unready) + state.unready = nil + st <- state + continue + } else { + close(state.unready) + close(st) + st = val.(chan watchState) + } + } else { + st = val.(chan watchState) + } + + // It is an existing channel created outside. It may be in the + // unready state. + state := <-st + if state.unready == nil { + state.cnt += 1 + } + st <- state + + if state.unready != nil { + // Wait for an update and retry. + <-state.unready + st = nil + } + } + + return st, nil +} + +// NewWatcher creates a new Watcher object for the connection. +// +// After watcher creation, the watcher callback is invoked for the first time. +// In this case, the callback is triggered whether or not the key has already +// been broadcast. All subsequent invocations are triggered with +// box.broadcast() called on the remote host. If a watcher is subscribed for a +// key that has not been broadcast yet, the callback is triggered only once, +// after the registration of the watcher. +// +// The watcher callbacks are always invoked in a separate goroutine. A watcher +// callback is never executed in parallel with itself, but they can be executed +// in parallel to other watchers. +// +// If the key is updated while the watcher callback is running, the callback +// will be invoked again with the latest value as soon as it returns. +// +// Watchers survive reconnection. All registered watchers are automatically +// resubscribed when the connection is reestablished. +// +// Keep in mind that garbage collection of a watcher handle doesn’t lead to the +// watcher’s destruction. In this case, the watcher remains registered. You +// need to call Unregister() directly. +// +// Unregister() guarantees that there will be no the watcher's callback calls +// after it, but Unregister() call from the callback leads to a deadlock. +// +// See: +// https://www.tarantool.io/en/doc/latest/reference/reference_lua/box_events/#box-watchers +// +// Since 1.10.0 +func (conn *Connection) NewWatcher(key string, callback WatchCallback) (Watcher, error) { + // We need to check the feature because the IPROTO_WATCH request is + // asynchronous. We do not expect any response from a Tarantool instance + // That's why we can't just check the Tarantool response for an unsupported + // request error. + required := ProtocolInfo{Features: []ProtocolFeature{WatchersFeature}} + if err := checkProtocolInfo(required, conn.ServerProtocolInfo()); err != nil { + return nil, err + } + + st, err := subscribeWatchChannel(conn, key) + if err != nil { + return nil, err + } + + // Start the watcher goroutine. + done := make(chan struct{}) + finished := make(chan struct{}) + + go func() { + var version uint64 = initWatchEventVersion + for { + state := <-st + if state.changed == nil { + state.changed = make(chan struct{}) + } + st <- state + + if state.version != version { + callback(WatchEvent{ + Conn: conn, + Key: key, + Value: state.value, + }) + version = state.version + + // Do we need to acknowledge the notification? + state = <-st + ack := !state.ack && version == state.version + if ack { + state.ack = true + } + st <- state + + if ack { + conn.Do(newWatchRequest(key)).Get() + // We expect a reconnect and re-subscribe if it fails to + // send the watch request. So it looks ok do not check a + // result. + } + } + + select { + case <-done: + state := <-st + state.cnt -= 1 + if state.cnt == 0 { + state.unready = make(chan struct{}) + } + st <- state + + if state.cnt == 0 { + // The last one sends IPROTO_UNWATCH. + conn.Do(newUnwatchRequest(key)).Get() + conn.watchMap.Delete(key) + close(state.unready) + } + + close(finished) + return + case <-state.changed: + } + } + }() + + return &connWatcher{ + done: done, + finished: finished, + }, nil +} + // checkProtocolInfo checks that expected protocol version is // and protocol features are supported. func checkProtocolInfo(expected ProtocolInfo, actual ProtocolInfo) error { diff --git a/connection_pool/connection_pool.go b/connection_pool/connection_pool.go index 2891a88ab..4d1333d8c 100644 --- a/connection_pool/connection_pool.go +++ b/connection_pool/connection_pool.go @@ -91,12 +91,13 @@ type ConnectionPool struct { connOpts tarantool.Opts opts OptsPool - state state - done chan struct{} - roPool *RoundRobinStrategy - rwPool *RoundRobinStrategy - anyPool *RoundRobinStrategy - poolsMutex sync.RWMutex + state state + done chan struct{} + roPool *RoundRobinStrategy + rwPool *RoundRobinStrategy + anyPool *RoundRobinStrategy + poolsMutex sync.RWMutex + watcherContainer watcherContainer } var _ Pooler = (*ConnectionPool)(nil) @@ -640,25 +641,6 @@ func (connPool *ConnectionPool) ExecuteAsync(expr string, args interface{}, user return conn.ExecuteAsync(expr, args) } -// Do sends the request and returns a future. -// For requests that belong to an only one connection (e.g. Unprepare or ExecutePrepared) -// the argument of type Mode is unused. -func (connPool *ConnectionPool) Do(req tarantool.Request, userMode Mode) *tarantool.Future { - if connectedReq, ok := req.(tarantool.ConnectedRequest); ok { - conn, _ := connPool.getConnectionFromPool(connectedReq.Conn().Addr()) - if conn == nil { - return newErrorFuture(fmt.Errorf("the passed connected request doesn't belong to the current connection or connection pool")) - } - return connectedReq.Conn().Do(req) - } - conn, err := connPool.getNextConnection(userMode) - if err != nil { - return newErrorFuture(err) - } - - return conn.Do(req) -} - // NewStream creates new Stream object for connection selected // by userMode from connPool. // @@ -682,6 +664,212 @@ func (connPool *ConnectionPool) NewPrepared(expr string, userMode Mode) (*tarant return conn.NewPrepared(expr) } +// watcherContainer is a very simple implementation of a thread-safe container +// for watchers. It is not expected that there will be too many watchers and +// they will registered/unregistered too frequently. +// +// Otherwise, the implementation will need to be optimized. +type watcherContainer struct { + head *poolWatcher + mutex sync.RWMutex +} + +// add adds a watcher to the container. +func (c *watcherContainer) add(watcher *poolWatcher) { + c.mutex.Lock() + defer c.mutex.Unlock() + + watcher.next = c.head + c.head = watcher +} + +// remove removes a watcher from the container. +func (c *watcherContainer) remove(watcher *poolWatcher) { + c.mutex.Lock() + defer c.mutex.Unlock() + + if watcher == c.head { + c.head = watcher.next + } else { + cur := c.head + for cur.next != nil { + if cur.next == watcher { + cur.next = watcher.next + break + } + cur = cur.next + } + } +} + +// foreach iterates over the container to the end or until the call returns +// false. +func (c *watcherContainer) foreach(call func(watcher *poolWatcher) error) error { + cur := c.head + for cur != nil { + if err := call(cur); err != nil { + return err + } + cur = cur.next + } + return nil +} + +// poolWatcher is an internal implementation of the tarantool.Watcher interface. +type poolWatcher struct { + // The watcher container data. We can split the structure into two parts + // in the future: a watcher data and a watcher container data, but it looks + // simple at now. + + // next item in the watcher container. + next *poolWatcher + // container is the container for all active poolWatcher objects. + container *watcherContainer + + // The watcher data. + // mode of the watcher. + mode Mode + key string + callback tarantool.WatchCallback + // watchers is a map connection -> connection watcher. + watchers map[string]tarantool.Watcher + // unregistered is true if the watcher already unregistered. + unregistered bool + // mutex for the pool watcher. + mutex sync.Mutex +} + +// Unregister unregisters the pool watcher. +func (w *poolWatcher) Unregister() { + w.mutex.Lock() + defer w.mutex.Unlock() + + if !w.unregistered { + w.container.remove(w) + w.unregistered = true + for _, watcher := range w.watchers { + watcher.Unregister() + } + } +} + +// watch adds a watcher for the connection. +func (w *poolWatcher) watch(conn *tarantool.Connection) error { + addr := conn.Addr() + + w.mutex.Lock() + defer w.mutex.Unlock() + + if !w.unregistered { + if _, ok := w.watchers[addr]; ok { + return nil + } + + if watcher, err := conn.NewWatcher(w.key, w.callback); err == nil { + w.watchers[addr] = watcher + return nil + } else { + return err + } + } + return nil +} + +// unwatch removes a watcher for the connection. +func (w *poolWatcher) unwatch(conn *tarantool.Connection) { + addr := conn.Addr() + + w.mutex.Lock() + defer w.mutex.Unlock() + + if !w.unregistered { + if watcher, ok := w.watchers[addr]; ok { + watcher.Unregister() + delete(w.watchers, addr) + } + } +} + +// NewWatcher creates a new Watcher object for the connection pool. +// +// You need to require WatchersFeature to use watchers from the pool, see +// examples for the function. +// +// The behavior is same as if Connection.NewWatcher() called for each +// connection with a suitable role. +// +// Keep in mind that garbage collection of a watcher handle doesn’t lead to the +// watcher’s destruction. In this case, the watcher remains registered. You +// need to call Unregister() directly. +// +// Unregister() guarantees that there will be no the watcher's callback calls +// after it, but Unregister() call from the callback leads to a deadlock. +// +// See: +// https://www.tarantool.io/en/doc/latest/reference/reference_lua/box_events/#box-watchers +// +// Since 1.10.0 +func (pool *ConnectionPool) NewWatcher(key string, + callback tarantool.WatchCallback, mode Mode) (tarantool.Watcher, error) { + watchersRequired := false + for _, feature := range pool.connOpts.RequiredProtocolInfo.Features { + if tarantool.WatchersFeature == feature { + watchersRequired = true + break + } + } + if !watchersRequired { + return nil, errors.New("the feature WatchersFeature must be " + + "required by connection options to create a watcher") + } + + watcher := &poolWatcher{ + container: &pool.watcherContainer, + mode: mode, + key: key, + callback: callback, + watchers: make(map[string]tarantool.Watcher), + unregistered: false, + } + + watcher.container.add(watcher) + + rr := pool.anyPool + if mode == RW { + rr = pool.rwPool + } else if mode == RO { + rr = pool.roPool + } + + conns := rr.GetConnections() + for _, conn := range conns { + if err := watcher.watch(conn); err != nil { + conn.Close() + } + } + + return watcher, nil +} + +// Do sends the request and returns a future. +// For requests that belong to the only one connection (e.g. Unprepare or ExecutePrepared) +// the argument of type Mode is unused. +func (connPool *ConnectionPool) Do(req tarantool.Request, userMode Mode) *tarantool.Future { + if connectedReq, ok := req.(tarantool.ConnectedRequest); ok { + conn, _ := connPool.getConnectionFromPool(connectedReq.Conn().Addr()) + if conn == nil { + return newErrorFuture(fmt.Errorf("the passed connected request doesn't belong to the current connection or connection pool")) + } + return connectedReq.Conn().Do(req) + } + conn, err := connPool.getNextConnection(userMode) + if err != nil { + return newErrorFuture(err) + } + + return conn.Do(req) +} + // // private // @@ -733,26 +921,63 @@ func (connPool *ConnectionPool) getConnectionFromPool(addr string) (*tarantool.C return connPool.anyPool.GetConnByAddr(addr), UnknownRole } -func (connPool *ConnectionPool) deleteConnection(addr string) { - if conn := connPool.anyPool.DeleteConnByAddr(addr); conn != nil { - if conn := connPool.rwPool.DeleteConnByAddr(addr); conn != nil { - return +func (pool *ConnectionPool) deleteConnection(addr string) { + if conn := pool.anyPool.DeleteConnByAddr(addr); conn != nil { + if conn := pool.rwPool.DeleteConnByAddr(addr); conn == nil { + pool.roPool.DeleteConnByAddr(addr) + } + // The internal connection deinitialization. + pool.watcherContainer.mutex.RLock() + defer pool.watcherContainer.mutex.RUnlock() + + pool.watcherContainer.foreach(func(watcher *poolWatcher) error { + watcher.unwatch(conn) + return nil + }) + } +} + +func (pool *ConnectionPool) addConnection(addr string, + conn *tarantool.Connection, role Role) error { + // The internal connection initialization. + pool.watcherContainer.mutex.RLock() + defer pool.watcherContainer.mutex.RUnlock() + + watched := []*poolWatcher{} + err := pool.watcherContainer.foreach(func(watcher *poolWatcher) error { + watch := false + if watcher.mode == RW { + watch = role == MasterRole + } else if watcher.mode == RO { + watch = role == ReplicaRole + } else { + watch = true + } + if watch { + if err := watcher.watch(conn); err != nil { + return err + } + watched = append(watched, watcher) + } + return nil + }) + if err != nil { + for _, watcher := range watched { + watcher.unwatch(conn) } - connPool.roPool.DeleteConnByAddr(addr) + log.Printf("tarantool: failed initialize watchers for %s: %s", addr, err) + return err } -} - -func (connPool *ConnectionPool) addConnection(addr string, - conn *tarantool.Connection, role Role) { - connPool.anyPool.AddConn(addr, conn) + pool.anyPool.AddConn(addr, conn) switch role { case MasterRole: - connPool.rwPool.AddConn(addr, conn) + pool.rwPool.AddConn(addr, conn) case ReplicaRole: - connPool.roPool.AddConn(addr, conn) + pool.roPool.AddConn(addr, conn) } + return nil } func (connPool *ConnectionPool) handlerDiscovered(conn *tarantool.Connection, @@ -811,7 +1036,10 @@ func (connPool *ConnectionPool) fillPools() ([]connState, bool) { } if connPool.handlerDiscovered(conn, role) { - connPool.addConnection(addr, conn, role) + if connPool.addConnection(addr, conn, role) != nil { + conn.Close() + connPool.handlerDeactivated(conn, role) + } if conn.ConnectedNow() { states[i].conn = conn @@ -864,7 +1092,15 @@ func (pool *ConnectionPool) updateConnection(s connState) connState { return s } - pool.addConnection(s.addr, s.conn, role) + if pool.addConnection(s.addr, s.conn, role) != nil { + pool.poolsMutex.Unlock() + + s.conn.Close() + pool.handlerDeactivated(s.conn, role) + s.conn = nil + s.role = UnknownRole + return s + } s.role = role } } @@ -911,7 +1147,12 @@ func (pool *ConnectionPool) tryConnect(s connState) connState { return s } - pool.addConnection(s.addr, conn, role) + if pool.addConnection(s.addr, conn, role) != nil { + pool.poolsMutex.Unlock() + conn.Close() + pool.handlerDeactivated(conn, role) + return s + } s.conn = conn s.role = role } diff --git a/connection_pool/connection_pool_test.go b/connection_pool/connection_pool_test.go index 60c9b4f91..a185ad8ba 100644 --- a/connection_pool/connection_pool_test.go +++ b/connection_pool/connection_pool_test.go @@ -2048,6 +2048,329 @@ func TestStream_TxnIsolationLevel(t *testing.T) { } } +func TestConnectionPool_NewWatcher_noWatchersFeature(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestConnectionPool_NewWatcher_noWatchersFeature" + + roles := []bool{true, false, false, true, true} + + opts := connOpts.Clone() + opts.RequiredProtocolInfo.Features = []tarantool.ProtocolFeature{} + err := test_helpers.SetClusterRO(servers, opts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + pool, err := connection_pool.Connect(servers, opts) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, pool, "conn is nil after Connect") + defer pool.Close() + + watcher, err := pool.NewWatcher(key, + func(event tarantool.WatchEvent) {}, connection_pool.ANY) + require.Nilf(t, watcher, "watcher must not be created") + require.NotNilf(t, err, "an error is expected") + expected := "the feature WatchersFeature must be required by connection " + + "options to create a watcher" + require.Equal(t, expected, err.Error()) +} + +func TestConnectionPool_NewWatcher_modes(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestConnectionPool_NewWatcher_modes" + + roles := []bool{true, false, false, true, true} + + opts := connOpts.Clone() + opts.RequiredProtocolInfo.Features = []tarantool.ProtocolFeature{ + tarantool.WatchersFeature, + } + err := test_helpers.SetClusterRO(servers, opts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + pool, err := connection_pool.Connect(servers, opts) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, pool, "conn is nil after Connect") + defer pool.Close() + + modes := []connection_pool.Mode{ + connection_pool.ANY, + connection_pool.RW, + connection_pool.RO, + connection_pool.PreferRW, + connection_pool.PreferRO, + } + for _, mode := range modes { + t.Run(fmt.Sprintf("%d", mode), func(t *testing.T) { + expectedServers := []string{} + for i, server := range servers { + if roles[i] && mode == connection_pool.RW { + continue + } else if !roles[i] && mode == connection_pool.RO { + continue + } + expectedServers = append(expectedServers, server) + } + + events := make(chan tarantool.WatchEvent, 1024) + defer close(events) + + watcher, err := pool.NewWatcher(key, func(event tarantool.WatchEvent) { + require.Equal(t, key, event.Key) + require.Equal(t, nil, event.Value) + events <- event + }, mode) + require.Nilf(t, err, "failed to register a watcher") + defer watcher.Unregister() + + testMap := make(map[string]int) + + for i := 0; i < len(expectedServers); i++ { + select { + case event := <-events: + require.NotNil(t, event.Conn) + addr := event.Conn.Addr() + if val, ok := testMap[addr]; ok { + testMap[addr] = val + 1 + } else { + testMap[addr] = 1 + } + case <-time.After(time.Second): + t.Errorf("Failed to get a watch event.") + break + } + } + + for _, server := range expectedServers { + if val, ok := testMap[server]; !ok { + t.Errorf("Server not found: %s", server) + } else if val != 1 { + t.Errorf("Too many events %d for server %s", val, server) + } + } + }) + } +} + +func TestConnectionPool_NewWatcher_update(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestConnectionPool_NewWatcher_update" + const mode = connection_pool.RW + const initCnt = 2 + roles := []bool{true, false, false, true, true} + + opts := connOpts.Clone() + opts.RequiredProtocolInfo.Features = []tarantool.ProtocolFeature{ + tarantool.WatchersFeature, + } + err := test_helpers.SetClusterRO(servers, opts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + pool, err := connection_pool.Connect(servers, opts) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, pool, "conn is nil after Connect") + defer pool.Close() + + events := make(chan tarantool.WatchEvent, 1024) + defer close(events) + + watcher, err := pool.NewWatcher(key, func(event tarantool.WatchEvent) { + require.Equal(t, key, event.Key) + require.Equal(t, nil, event.Value) + events <- event + }, mode) + require.Nilf(t, err, "failed to create a watcher") + defer watcher.Unregister() + + // Wait for all initial events. + testMap := make(map[string]int) + for i := 0; i < initCnt; i++ { + select { + case event := <-events: + require.NotNil(t, event.Conn) + addr := event.Conn.Addr() + if val, ok := testMap[addr]; ok { + testMap[addr] = val + 1 + } else { + testMap[addr] = 1 + } + case <-time.After(time.Second): + t.Errorf("Failed to get a watch init event.") + break + } + } + + // Just invert roles for simplify the test. + for i, role := range roles { + roles[i] = !role + } + err = test_helpers.SetClusterRO(servers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + // Wait for all updated events. + for i := 0; i < len(servers)-initCnt; i++ { + select { + case event := <-events: + require.NotNil(t, event.Conn) + addr := event.Conn.Addr() + if val, ok := testMap[addr]; ok { + testMap[addr] = val + 1 + } else { + testMap[addr] = 1 + } + case <-time.After(time.Second): + t.Errorf("Failed to get a watch update event.") + break + } + } + + // Check that all an event happen for an each connection. + for _, server := range servers { + if val, ok := testMap[server]; !ok { + t.Errorf("Server not found: %s", server) + } else { + require.Equal(t, val, 1, fmt.Sprintf("for server %s", server)) + } + } +} + +func TestWatcher_Unregister(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestWatcher_Unregister" + const mode = connection_pool.RW + const expectedCnt = 2 + roles := []bool{true, false, false, true, true} + + opts := connOpts.Clone() + opts.RequiredProtocolInfo.Features = []tarantool.ProtocolFeature{ + tarantool.WatchersFeature, + } + err := test_helpers.SetClusterRO(servers, opts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + pool, err := connection_pool.Connect(servers, opts) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, pool, "conn is nil after Connect") + defer pool.Close() + + events := make(chan tarantool.WatchEvent, 1024) + defer close(events) + + watcher, err := pool.NewWatcher(key, func(event tarantool.WatchEvent) { + require.Equal(t, key, event.Key) + require.Equal(t, nil, event.Value) + events <- event + }, mode) + require.Nilf(t, err, "failed to create a watcher") + + for i := 0; i < expectedCnt; i++ { + select { + case <-events: + case <-time.After(time.Second): + t.Fatalf("Failed to skip initial events.") + } + } + watcher.Unregister() + + broadcast := tarantool.NewBroadcastRequest(key).Value("foo") + for i := 0; i < expectedCnt; i++ { + _, err := pool.Do(broadcast, mode).Get() + require.Nilf(t, err, "failed to send a broadcast request") + } + + select { + case event := <-events: + t.Fatalf("Get unexpected event: %v", event) + case <-time.After(time.Second): + } + + // Reset to the initial state. + broadcast = tarantool.NewBroadcastRequest(key) + for i := 0; i < expectedCnt; i++ { + _, err := pool.Do(broadcast, mode).Get() + require.Nilf(t, err, "failed to send a broadcast request") + } +} + +func TestConnectionPool_NewWatcher_concurrent(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const testConcurrency = 1000 + const key = "TestConnectionPool_NewWatcher_concurrent" + + roles := []bool{true, false, false, true, true} + + opts := connOpts.Clone() + opts.RequiredProtocolInfo.Features = []tarantool.ProtocolFeature{ + tarantool.WatchersFeature, + } + err := test_helpers.SetClusterRO(servers, opts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + pool, err := connection_pool.Connect(servers, opts) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, pool, "conn is nil after Connect") + defer pool.Close() + + var wg sync.WaitGroup + wg.Add(testConcurrency) + + mode := connection_pool.ANY + callback := func(event tarantool.WatchEvent) {} + for i := 0; i < testConcurrency; i++ { + go func(i int) { + defer wg.Done() + + watcher, err := pool.NewWatcher(key, callback, mode) + if err != nil { + t.Errorf("Failed to create a watcher: %s", err) + } else { + watcher.Unregister() + } + }(i) + } + wg.Wait() +} + +func TestWatcher_Unregister_concurrent(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const testConcurrency = 1000 + const key = "TestWatcher_Unregister_concurrent" + + roles := []bool{true, false, false, true, true} + + opts := connOpts.Clone() + opts.RequiredProtocolInfo.Features = []tarantool.ProtocolFeature{ + tarantool.WatchersFeature, + } + err := test_helpers.SetClusterRO(servers, opts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + pool, err := connection_pool.Connect(servers, opts) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, pool, "conn is nil after Connect") + defer pool.Close() + + mode := connection_pool.ANY + watcher, err := pool.NewWatcher(key, func(event tarantool.WatchEvent) { + }, mode) + require.Nilf(t, err, "failed to create a watcher") + + var wg sync.WaitGroup + wg.Add(testConcurrency) + + for i := 0; i < testConcurrency; i++ { + go func() { + defer wg.Done() + watcher.Unregister() + }() + } + wg.Wait() +} + // runTestMain is a body of TestMain function // (see https://pkg.go.dev/testing#hdr-Main). // Using defer + os.Exit is not works so TestMain body diff --git a/connection_pool/connector.go b/connection_pool/connector.go index e52109d92..c108aba0b 100644 --- a/connection_pool/connector.go +++ b/connection_pool/connector.go @@ -299,6 +299,14 @@ func (c *ConnectorAdapter) NewStream() (*tarantool.Stream, error) { return c.pool.NewStream(c.mode) } +// NewWatcher creates new Watcher object for the pool +// +// Since 1.10.0 +func (c *ConnectorAdapter) NewWatcher(key string, + callback tarantool.WatchCallback) (tarantool.Watcher, error) { + return c.pool.NewWatcher(key, callback, c.mode) +} + // Do performs a request asynchronously on the connection. func (c *ConnectorAdapter) Do(req tarantool.Request) *tarantool.Future { return c.pool.Do(req, c.mode) diff --git a/connection_pool/connector_test.go b/connection_pool/connector_test.go index fa7cf06ba..f53a05b22 100644 --- a/connection_pool/connector_test.go +++ b/connection_pool/connector_test.go @@ -1139,6 +1139,45 @@ func TestConnectorNewStream(t *testing.T) { require.Equalf(t, testMode, m.mode, "unexpected proxy mode") } +type watcherMock struct{} + +func (w *watcherMock) Unregister() {} + +const reqWatchKey = "foo" + +var reqWatcher tarantool.Watcher = &watcherMock{} + +type newWatcherMock struct { + Pooler + key string + callback tarantool.WatchCallback + called int + mode Mode +} + +func (m *newWatcherMock) NewWatcher(key string, + callback tarantool.WatchCallback, mode Mode) (tarantool.Watcher, error) { + m.called++ + m.key = key + m.callback = callback + m.mode = mode + return reqWatcher, reqErr +} + +func TestConnectorNewWatcher(t *testing.T) { + m := &newWatcherMock{} + c := NewConnectorAdapter(m, testMode) + + w, err := c.NewWatcher(reqWatchKey, func(event tarantool.WatchEvent) {}) + + require.Equalf(t, reqWatcher, w, "unexpected watcher") + require.Equalf(t, reqErr, err, "unexpected error") + require.Equalf(t, 1, m.called, "should be called only once") + require.Equalf(t, reqWatchKey, m.key, "unexpected key") + require.NotNilf(t, m.callback, "callback must be set") + require.Equalf(t, testMode, m.mode, "unexpected proxy mode") +} + var reqRequest tarantool.Request = tarantool.NewPingRequest() type doMock struct { diff --git a/connection_pool/example_test.go b/connection_pool/example_test.go index 02715a2bc..cf59455ca 100644 --- a/connection_pool/example_test.go +++ b/connection_pool/example_test.go @@ -575,6 +575,59 @@ func ExampleConnectionPool_NewPrepared() { } } +func ExampleConnectionPool_NewWatcher() { + const key = "foo" + const value = "bar" + + opts := connOpts.Clone() + opts.RequiredProtocolInfo.Features = []tarantool.ProtocolFeature{ + tarantool.WatchersFeature, + } + + pool, err := examplePool(testRoles, connOpts) + if err != nil { + fmt.Println(err) + } + defer pool.Close() + + callback := func(event tarantool.WatchEvent) { + fmt.Printf("event connection: %s\n", event.Conn.Addr()) + fmt.Printf("event key: %s\n", event.Key) + fmt.Printf("event value: %v\n", event.Value) + } + mode := connection_pool.ANY + watcher, err := pool.NewWatcher(key, callback, mode) + if err != nil { + fmt.Printf("Unexpected error: %s\n", err) + return + } + defer watcher.Unregister() + + pool.Do(tarantool.NewBroadcastRequest(key).Value(value), mode).Get() + time.Sleep(time.Second) +} + +func ExampleConnectionPool_NewWatcher_noWatchersFeature() { + const key = "foo" + + opts := connOpts.Clone() + opts.RequiredProtocolInfo.Features = []tarantool.ProtocolFeature{} + + pool, err := examplePool(testRoles, connOpts) + if err != nil { + fmt.Println(err) + } + defer pool.Close() + + callback := func(event tarantool.WatchEvent) {} + watcher, err := pool.NewWatcher(key, callback, connection_pool.ANY) + fmt.Println(watcher) + fmt.Println(err) + // Output: + // + // the feature WatchersFeature must be required by connection options to create a watcher +} + func getTestTxnOpts() tarantool.Opts { txnOpts := connOpts.Clone() diff --git a/connection_pool/pooler.go b/connection_pool/pooler.go index a9dbe09f9..856f5d5be 100644 --- a/connection_pool/pooler.go +++ b/connection_pool/pooler.go @@ -84,6 +84,8 @@ type Pooler interface { NewPrepared(expr string, mode Mode) (*tarantool.Prepared, error) NewStream(mode Mode) (*tarantool.Stream, error) + NewWatcher(key string, callback tarantool.WatchCallback, + mode Mode) (tarantool.Watcher, error) Do(req tarantool.Request, mode Mode) (fut *tarantool.Future) } diff --git a/connection_pool/round_robin.go b/connection_pool/round_robin.go index b83d877d9..a7fb73e18 100644 --- a/connection_pool/round_robin.go +++ b/connection_pool/round_robin.go @@ -14,6 +14,15 @@ type RoundRobinStrategy struct { current uint } +func NewEmptyRoundRobin(size int) *RoundRobinStrategy { + return &RoundRobinStrategy{ + conns: make([]*tarantool.Connection, 0, size), + indexByAddr: make(map[string]uint), + size: 0, + current: 0, + } +} + func (r *RoundRobinStrategy) GetConnByAddr(addr string) *tarantool.Connection { r.mutex.RLock() defer r.mutex.RUnlock() @@ -71,13 +80,14 @@ func (r *RoundRobinStrategy) GetNextConnection() *tarantool.Connection { return r.conns[r.nextIndex()] } -func NewEmptyRoundRobin(size int) *RoundRobinStrategy { - return &RoundRobinStrategy{ - conns: make([]*tarantool.Connection, 0, size), - indexByAddr: make(map[string]uint), - size: 0, - current: 0, - } +func (r *RoundRobinStrategy) GetConnections() []*tarantool.Connection { + r.mutex.RLock() + defer r.mutex.RUnlock() + + ret := make([]*tarantool.Connection, len(r.conns)) + copy(ret, r.conns) + + return ret } func (r *RoundRobinStrategy) AddConn(addr string, conn *tarantool.Connection) { diff --git a/connection_pool/round_robin_test.go b/connection_pool/round_robin_test.go index 6b54ecfd8..03038eada 100644 --- a/connection_pool/round_robin_test.go +++ b/connection_pool/round_robin_test.go @@ -69,3 +69,22 @@ func TestRoundRobinGetNextConnection(t *testing.T) { } } } + +func TestRoundRobinStrategy_GetConnections(t *testing.T) { + rr := NewEmptyRoundRobin(10) + + addrs := []string{validAddr1, validAddr2} + conns := []*tarantool.Connection{&tarantool.Connection{}, &tarantool.Connection{}} + + for i, addr := range addrs { + rr.AddConn(addr, conns[i]) + } + + rr.GetConnections()[1] = conns[0] // GetConnections() returns a copy. + rrConns := rr.GetConnections() + for i, expected := range conns { + if expected != rrConns[i] { + t.Errorf("Unexpected connection on %d call", i) + } + } +} diff --git a/connector.go b/connector.go index d6c44c8dd..d93c69ec8 100644 --- a/connector.go +++ b/connector.go @@ -46,6 +46,7 @@ type Connector interface { NewPrepared(expr string) (*Prepared, error) NewStream() (*Stream, error) + NewWatcher(key string, callback WatchCallback) (Watcher, error) Do(req Request) (fut *Future) } diff --git a/const.go b/const.go index 35ec83380..4a6e3e2ef 100644 --- a/const.go +++ b/const.go @@ -19,6 +19,8 @@ const ( PingRequestCode = 64 SubscribeRequestCode = 66 IdRequestCode = 73 + WatchRequestCode = 74 + UnwatchRequestCode = 75 KeyCode = 0x00 KeySync = 0x01 @@ -45,6 +47,8 @@ const ( KeyVersion = 0x54 KeyFeatures = 0x55 KeyTimeout = 0x56 + KeyEvent = 0x57 + KeyEventData = 0x58 KeyTxnIsolation = 0x59 KeyFieldName = 0x00 @@ -73,6 +77,7 @@ const ( RLimitWait = 2 OkCode = uint32(0) + EventCode = uint32(0x4c) PushCode = uint32(0x80) ErrorCodeBit = 0x8000 PacketLengthBytes = 5 diff --git a/example_test.go b/example_test.go index 15574d099..e6db28567 100644 --- a/example_test.go +++ b/example_test.go @@ -329,7 +329,7 @@ func ExampleProtocolVersion() { fmt.Println("Connector client protocol features:", clientProtocolInfo.Features) // Output: // Connector client protocol version: 4 - // Connector client protocol features: [StreamsFeature TransactionsFeature] + // Connector client protocol features: [StreamsFeature TransactionsFeature WatchersFeature] } func getTestTxnOpts() tarantool.Opts { diff --git a/multi/multi.go b/multi/multi.go index 390186f27..ea950cfdf 100644 --- a/multi/multi.go +++ b/multi/multi.go @@ -507,6 +507,16 @@ func (connMulti *ConnectionMulti) NewStream() (*tarantool.Stream, error) { return connMulti.getCurrentConnection().NewStream() } +// NewWatcher does not supported by the ConnectionMulti. The ConnectionMulti is +// deprecated: use ConnectionPool instead. +// +// Since 1.10.0 +func (connMulti *ConnectionMulti) NewWatcher(key string, + callback tarantool.WatchCallback) (tarantool.Watcher, error) { + return nil, errors.New("ConnectionMulti is deprecated " + + "use ConnectionPool") +} + // Do sends the request and returns a future. func (connMulti *ConnectionMulti) Do(req tarantool.Request) *tarantool.Future { if connectedReq, ok := req.(tarantool.ConnectedRequest); ok { diff --git a/multi/multi_test.go b/multi/multi_test.go index 2d43bb179..ef07d629b 100644 --- a/multi/multi_test.go +++ b/multi/multi_test.go @@ -548,6 +548,30 @@ func TestStream_Rollback(t *testing.T) { } } +func TestConnectionMulti_NewWatcher(t *testing.T) { + test_helpers.SkipIfStreamsUnsupported(t) + + multiConn, err := Connect([]string{server1, server2}, connOpts) + if err != nil { + t.Fatalf("Failed to connect: %s", err.Error()) + } + if multiConn == nil { + t.Fatalf("conn is nil after Connect") + } + defer multiConn.Close() + + watcher, err := multiConn.NewWatcher("foo", func(event tarantool.WatchEvent) {}) + if watcher != nil { + t.Errorf("Unexpected watcher") + } + if err == nil { + t.Fatalf("Unexpected success") + } + if err.Error() != "ConnectionMulti is deprecated use ConnectionPool" { + t.Fatalf("Unexpected error: %s", err) + } +} + // runTestMain is a body of TestMain function // (see https://pkg.go.dev/testing#hdr-Main). // Using defer + os.Exit is not works so TestMain body diff --git a/protocol.go b/protocol.go index 1eaf60e2b..82608b6ab 100644 --- a/protocol.go +++ b/protocol.go @@ -80,6 +80,7 @@ var clientProtocolInfo ProtocolInfo = ProtocolInfo{ Features: []ProtocolFeature{ StreamsFeature, TransactionsFeature, + WatchersFeature, }, } diff --git a/request.go b/request.go index cfa40e522..66eb4be41 100644 --- a/request.go +++ b/request.go @@ -538,6 +538,8 @@ type Request interface { Body(resolver SchemaResolver, enc *encoder) error // Ctx returns a context of the request. Ctx() context.Context + // Async returns true if the request does not expect response. + Async() bool } // ConnectedRequest is an interface that provides the info about a Connection @@ -550,6 +552,7 @@ type ConnectedRequest interface { type baseRequest struct { requestCode int32 + async bool ctx context.Context } @@ -558,6 +561,11 @@ func (req *baseRequest) Code() int32 { return req.requestCode } +// Async returns true if the request does not require a response. +func (req *baseRequest) Async() bool { + return req.async +} + // Ctx returns a context of the request. func (req *baseRequest) Ctx() context.Context { return req.ctx diff --git a/request_test.go b/request_test.go index 89d1d8884..a078f6514 100644 --- a/request_test.go +++ b/request_test.go @@ -19,6 +19,7 @@ const invalidIndex = 2 const validSpace = 1 // Any valid value != default. const validIndex = 3 // Any valid value != default. const validExpr = "any string" // We don't check the value here. +const validKey = "foo" // Any string. const defaultSpace = 0 // And valid too. const defaultIndex = 0 // And valid too. @@ -183,6 +184,7 @@ func TestRequestsCodes(t *testing.T) { {req: NewBeginRequest(), code: BeginRequestCode}, {req: NewCommitRequest(), code: CommitRequestCode}, {req: NewRollbackRequest(), code: RollbackRequestCode}, + {req: NewBroadcastRequest(validKey), code: CallRequestCode}, } for _, test := range tests { @@ -192,6 +194,38 @@ func TestRequestsCodes(t *testing.T) { } } +func TestRequestsAsync(t *testing.T) { + tests := []struct { + req Request + async bool + }{ + {req: NewSelectRequest(validSpace), async: false}, + {req: NewUpdateRequest(validSpace), async: false}, + {req: NewUpsertRequest(validSpace), async: false}, + {req: NewInsertRequest(validSpace), async: false}, + {req: NewReplaceRequest(validSpace), async: false}, + {req: NewDeleteRequest(validSpace), async: false}, + {req: NewCall16Request(validExpr), async: false}, + {req: NewCall17Request(validExpr), async: false}, + {req: NewEvalRequest(validExpr), async: false}, + {req: NewExecuteRequest(validExpr), async: false}, + {req: NewPingRequest(), async: false}, + {req: NewPrepareRequest(validExpr), async: false}, + {req: NewUnprepareRequest(validStmt), async: false}, + {req: NewExecutePreparedRequest(validStmt), async: false}, + {req: NewBeginRequest(), async: false}, + {req: NewCommitRequest(), async: false}, + {req: NewRollbackRequest(), async: false}, + {req: NewBroadcastRequest(validKey), async: false}, + } + + for _, test := range tests { + if async := test.req.Async(); async != test.async { + t.Errorf("An invalid async %t, expected %t", async, test.async) + } + } +} + func TestPingRequestDefaultValues(t *testing.T) { var refBuf bytes.Buffer @@ -649,3 +683,34 @@ func TestRollbackRequestDefaultValues(t *testing.T) { req := NewRollbackRequest() assertBodyEqual(t, refBuf.Bytes(), req) } + +func TestBroadcastRequestDefaultValues(t *testing.T) { + var refBuf bytes.Buffer + + refEnc := NewEncoder(&refBuf) + expectedArgs := []interface{}{validKey} + err := RefImplCallBody(refEnc, "box.broadcast", expectedArgs) + if err != nil { + t.Errorf("An unexpected RefImplCallBody() error: %q", err.Error()) + return + } + + req := NewBroadcastRequest(validKey) + assertBodyEqual(t, refBuf.Bytes(), req) +} + +func TestBroadcastRequestSetters(t *testing.T) { + value := []interface{}{uint(34), int(12)} + var refBuf bytes.Buffer + + refEnc := NewEncoder(&refBuf) + expectedArgs := []interface{}{validKey, value} + err := RefImplCallBody(refEnc, "box.broadcast", expectedArgs) + if err != nil { + t.Errorf("An unexpected RefImplCallBody() error: %q", err.Error()) + return + } + + req := NewBroadcastRequest(validKey).Value(value) + assertBodyEqual(t, refBuf.Bytes(), req) +} diff --git a/tarantool_test.go b/tarantool_test.go index 31d287272..a04862f7e 100644 --- a/tarantool_test.go +++ b/tarantool_test.go @@ -2868,8 +2868,12 @@ func TestConnectionProtocolInfoSupported(t *testing.T) { require.Equal(t, clientProtocolInfo, ProtocolInfo{ - Version: ProtocolVersion(4), - Features: []ProtocolFeature{StreamsFeature, TransactionsFeature}, + Version: ProtocolVersion(4), + Features: []ProtocolFeature{ + StreamsFeature, + TransactionsFeature, + WatchersFeature, + }, }) serverProtocolInfo := conn.ServerProtocolInfo() @@ -2997,8 +3001,12 @@ func TestConnectionProtocolInfoUnsupported(t *testing.T) { require.Equal(t, clientProtocolInfo, ProtocolInfo{ - Version: ProtocolVersion(4), - Features: []ProtocolFeature{StreamsFeature, TransactionsFeature}, + Version: ProtocolVersion(4), + Features: []ProtocolFeature{ + StreamsFeature, + TransactionsFeature, + WatchersFeature, + }, }) serverProtocolInfo := conn.ServerProtocolInfo() @@ -3148,6 +3156,358 @@ func TestConnectionFeatureOptsImmutable(t *testing.T) { require.True(t, connected, "Reconnect success") } +func TestConnection_NewWatcher(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestConnection_NewWatcher" + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + events := make(chan WatchEvent) + defer close(events) + + watcher, err := conn.NewWatcher(key, func(event WatchEvent) { + events <- event + }) + if err != nil { + t.Fatalf("Failed to create a watch: %s", err) + } + defer watcher.Unregister() + + select { + case event := <-events: + if event.Conn != conn { + t.Errorf("Unexpected event connection: %v", event.Conn) + } + if event.Key != key { + t.Errorf("Unexpected event key: %s", event.Key) + } + if event.Value != nil { + t.Errorf("Unexpected event value: %v", event.Value) + } + case <-time.After(time.Second): + t.Fatalf("Failed to get watch event.") + } +} + +func TestConnection_NewWatcher_unsupported(t *testing.T) { + test_helpers.SkipIfWatchersSupported(t) + + const key = "TestConnection_NewWatcher_unsupported" + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + watcher, err := conn.NewWatcher(key, func(event WatchEvent) {}) + if watcher != nil { + t.Errorf("Unexpected watcher: %v", watcher) + } + if err == nil { + t.Fatalf("An error expected.") + } + + const expected = "protocol feature WatchersFeature is not supported" + require.Equal(t, expected, err.Error()) +} + +func TestConnection_NewWatcher_reconnect(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestConnection_NewWatcher_reconnect" + const server = "127.0.0.1:3014" + + inst, err := test_helpers.StartTarantool(test_helpers.StartOpts{ + InitScript: "config.lua", + Listen: server, + WorkDir: "work_dir", + User: opts.User, + Pass: opts.Pass, + WaitStart: 100 * time.Millisecond, + ConnectRetry: 3, + RetryTimeout: 500 * time.Millisecond, + }) + defer test_helpers.StopTarantoolWithCleanup(inst) + if err != nil { + t.Fatalf("Unable to start Tarantool: %s", err) + } + + reconnectOpts := opts + reconnectOpts.Reconnect = 100 * time.Millisecond + reconnectOpts.MaxReconnects = 10 + conn := test_helpers.ConnectWithValidation(t, server, reconnectOpts) + defer conn.Close() + + events := make(chan WatchEvent) + defer close(events) + watcher, err := conn.NewWatcher(key, func(event WatchEvent) { + events <- event + }) + if err != nil { + t.Fatalf("Failed to create a watch: %s", err) + } + defer watcher.Unregister() + + <-events + + test_helpers.StopTarantool(inst) + if err := test_helpers.RestartTarantool(&inst); err != nil { + t.Fatalf("Unable to restart Tarantool: %s", err) + } + + maxTime := reconnectOpts.Reconnect * time.Duration(reconnectOpts.MaxReconnects) + select { + case <-events: + case <-time.After(maxTime): + t.Fatalf("Failed to get watch event.") + } +} + +func TestBroadcastRequest(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestBroadcastRequest" + const value = "bar" + + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + resp, err := conn.Do(NewBroadcastRequest(key).Value(value)).Get() + if err != nil { + t.Fatalf("Got broadcast error: %s", err) + } + if resp.Code != OkCode { + t.Errorf("Got unexpected broadcast response code: %d", resp.Code) + } + if !reflect.DeepEqual(resp.Data, []interface{}{}) { + t.Errorf("Got unexpected broadcast response data: %v", resp.Data) + } + + events := make(chan WatchEvent) + defer close(events) + + watcher, err := conn.NewWatcher(key, func(event WatchEvent) { + events <- event + }) + if err != nil { + t.Fatalf("Failed to create a watch: %s", err) + } + defer watcher.Unregister() + + select { + case event := <-events: + if event.Conn != conn { + t.Errorf("Unexpected event connection: %v", event.Conn) + } + if event.Key != key { + t.Errorf("Unexpected event key: %s", event.Key) + } + if event.Value != value { + t.Errorf("Unexpected event value: %v", event.Value) + } + case <-time.After(time.Second): + t.Fatalf("Failed to get watch event.") + } +} + +func TestBroadcastRequest_multi(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestBroadcastRequest_multi" + + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + events := make(chan WatchEvent) + defer close(events) + + watcher, err := conn.NewWatcher(key, func(event WatchEvent) { + events <- event + }) + if err != nil { + t.Fatalf("Failed to create a watch: %s", err) + } + defer watcher.Unregister() + + <-events // Skip an initial event. + for i := 0; i < 10; i++ { + val := fmt.Sprintf("%d", i) + _, err := conn.Do(NewBroadcastRequest(key).Value(val)).Get() + if err != nil { + t.Fatalf("Failed to send a broadcast request: %s", err) + } + select { + case event := <-events: + if event.Conn != conn { + t.Errorf("Unexpected event connection: %v", event.Conn) + } + if event.Key != key { + t.Errorf("Unexpected event key: %s", event.Key) + } + if event.Value.(string) != val { + t.Errorf("Unexpected event value: %v", event.Value) + } + case <-time.After(time.Second): + t.Fatalf("Failed to get watch event %d", i) + } + } +} + +func TestConnection_NewWatcher_multiOnKey(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestConnection_NewWatcher_multiOnKey" + const value = "bar" + + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + events := []chan WatchEvent{ + make(chan WatchEvent), + make(chan WatchEvent), + } + for _, ch := range events { + defer close(ch) + } + + for _, ch := range events { + channel := ch + watcher, err := conn.NewWatcher(key, func(event WatchEvent) { + channel <- event + }) + if err != nil { + t.Fatalf("Failed to create a watch: %s", err) + } + defer watcher.Unregister() + } + + for i, ch := range events { + select { + case <-ch: // Skip an initial event. + case <-time.After(2 * time.Second): + t.Fatalf("Failed to skip watch event for %d callback", i) + } + } + + _, err := conn.Do(NewBroadcastRequest(key).Value(value)).Get() + if err != nil { + t.Fatalf("Failed to send a broadcast request: %s", err) + } + + for i, ch := range events { + select { + case event := <-ch: + if event.Conn != conn { + t.Errorf("Unexpected event connection: %v", event.Conn) + } + if event.Key != key { + t.Errorf("Unexpected event key: %s", event.Key) + } + if event.Value.(string) != value { + t.Errorf("Unexpected event value: %v", event.Value) + } + case <-time.After(2 * time.Second): + t.Fatalf("Failed to get watch event from callback %d", i) + } + } +} + +func TestWatcher_Unregister(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const key = "TestWatcher_Unregister" + const value = "bar" + + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + events := make(chan WatchEvent) + defer close(events) + watcher, err := conn.NewWatcher(key, func(event WatchEvent) { + events <- event + }) + if err != nil { + t.Fatalf("Failed to create a watch: %s", err) + } + + <-events + watcher.Unregister() + + _, err = conn.Do(NewBroadcastRequest(key).Value(value)).Get() + if err != nil { + t.Fatalf("Got broadcast error: %s", err) + } + + select { + case event := <-events: + t.Fatalf("Get unexpected events: %v", event) + case <-time.After(time.Second): + } +} + +func TestConnection_NewWatcher_concurrent(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const testConcurrency = 1000 + const key = "TestConnection_NewWatcher_concurrent" + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + var wg sync.WaitGroup + wg.Add(testConcurrency) + + var ret error + for i := 0; i < testConcurrency; i++ { + go func(i int) { + defer wg.Done() + + events := make(chan struct{}) + + watcher, err := conn.NewWatcher(key, func(event WatchEvent) { + close(events) + }) + if err != nil { + ret = err + } else { + select { + case <-events: + case <-time.After(time.Second): + ret = fmt.Errorf("Unable to get an event %d", i) + } + watcher.Unregister() + } + }(i) + } + wg.Wait() + + if ret != nil { + t.Fatalf("An error found: %s", ret) + } +} + +func TestWatcher_Unregister_concurrent(t *testing.T) { + test_helpers.SkipIfWatchersUnsupported(t) + + const testConcurrency = 1000 + const key = "TestWatcher_Unregister_concurrent" + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + watcher, err := conn.NewWatcher(key, func(event WatchEvent) {}) + if err != nil { + t.Fatalf("Failed to create a watch: %s", err) + } + + var wg sync.WaitGroup + wg.Add(testConcurrency) + + for i := 0; i < testConcurrency; i++ { + go func() { + defer wg.Done() + watcher.Unregister() + }() + } + wg.Wait() +} + // runTestMain is a body of TestMain function // (see https://pkg.go.dev/testing#hdr-Main). // Using defer + os.Exit is not works so TestMain body diff --git a/test_helpers/request_mock.go b/test_helpers/request_mock.go index 93551e34a..19c18545e 100644 --- a/test_helpers/request_mock.go +++ b/test_helpers/request_mock.go @@ -17,6 +17,10 @@ func (sr *StrangerRequest) Code() int32 { return 0 } +func (sr *StrangerRequest) Async() bool { + return false +} + func (sr *StrangerRequest) Body(resolver tarantool.SchemaResolver, enc *encoder) error { return nil } diff --git a/test_helpers/utils.go b/test_helpers/utils.go index dff0bb357..cdd1190da 100644 --- a/test_helpers/utils.go +++ b/test_helpers/utils.go @@ -74,23 +74,20 @@ func SkipIfSQLUnsupported(t testing.TB) { } } -func SkipIfStreamsUnsupported(t *testing.T) { +func skipIfLess2_10(t *testing.T, feature string) { t.Helper() - // Tarantool supports streams and interactive transactions since version 2.10.0 isLess, err := IsTarantoolVersionLess(2, 10, 0) if err != nil { t.Fatalf("Could not check the Tarantool version") } if isLess { - t.Skip("Skipping test for Tarantool without streams support") + t.Skipf("Skipping test for Tarantool without %s support", feature) } } -// SkipIfIdUnsupported skips test run if Tarantool without -// IPROTO_ID support is used. -func SkipIfIdUnsupported(t *testing.T) { +func skipIfGreaterOrEqual2_10(t *testing.T, feature string) { t.Helper() // Tarantool supports Id requests since version 2.10.0 @@ -99,24 +96,48 @@ func SkipIfIdUnsupported(t *testing.T) { t.Fatalf("Could not check the Tarantool version") } - if isLess { - t.Skip("Skipping test for Tarantool without id requests support") + if !isLess { + t.Skipf("Skipping test for Tarantool with supported %s", feature) } } +// SkipOfStreamsUnsupported skips test run if Tarantool without streams +// support is used. +func SkipIfStreamsUnsupported(t *testing.T) { + t.Helper() + + skipIfLess2_10(t, "streams") +} + +// SkipOfStreamsUnsupported skips test run if Tarantool without watchers +// support is used. +func SkipIfWatchersUnsupported(t *testing.T) { + t.Helper() + + skipIfLess2_10(t, "watchers") +} + +// SkipIfWatchersSupported skips test run if Tarantool with watchers +// support is used. +func SkipIfWatchersSupported(t *testing.T) { + t.Helper() + + skipIfGreaterOrEqual2_10(t, "watchers") +} + +// SkipIfIdUnsupported skips test run if Tarantool without +// IPROTO_ID support is used. +func SkipIfIdUnsupported(t *testing.T) { + t.Helper() + + skipIfLess2_10(t, "id requests") +} + // SkipIfIdSupported skips test run if Tarantool with // IPROTO_ID support is used. Skip is useful for tests validating // that protocol info is processed as expected even for pre-IPROTO_ID instances. func SkipIfIdSupported(t *testing.T) { t.Helper() - // Tarantool supports Id requests since version 2.10.0 - isLess, err := IsTarantoolVersionLess(2, 10, 0) - if err != nil { - t.Fatalf("Could not check the Tarantool version") - } - - if !isLess { - t.Skip("Skipping test for Tarantool with non-zero protocol version and features") - } + skipIfGreaterOrEqual2_10(t, "id requests") } diff --git a/watch.go b/watch.go new file mode 100644 index 000000000..61631657c --- /dev/null +++ b/watch.go @@ -0,0 +1,138 @@ +package tarantool + +import ( + "context" +) + +// BroadcastRequest helps to send broadcast messages. See: +// https://www.tarantool.io/en/doc/latest/reference/reference_lua/box_events/broadcast/ +type BroadcastRequest struct { + call *CallRequest + key string +} + +// NewBroadcastRequest returns a new broadcast request for a specified key. +func NewBroadcastRequest(key string) *BroadcastRequest { + req := new(BroadcastRequest) + req.key = key + req.call = NewCallRequest("box.broadcast").Args([]interface{}{key}) + return req +} + +// Value sets the value for the broadcast request. +// Note: default value is nil. +func (req *BroadcastRequest) Value(value interface{}) *BroadcastRequest { + req.call = req.call.Args([]interface{}{req.key, value}) + return req +} + +// Context sets a passed context to the broadcast request. +func (req *BroadcastRequest) Context(ctx context.Context) *BroadcastRequest { + req.call = req.call.Context(ctx) + return req +} + +// Code returns IPROTO code for the broadcast request. +func (req *BroadcastRequest) Code() int32 { + return req.call.Code() +} + +// Body fills an encoder with the broadcast request body. +func (req *BroadcastRequest) Body(res SchemaResolver, enc *encoder) error { + return req.call.Body(res, enc) +} + +// Ctx returns a context of the broadcast request. +func (req *BroadcastRequest) Ctx() context.Context { + return req.call.Ctx() +} + +// Async returns is the broadcast request expects a response. +func (req *BroadcastRequest) Async() bool { + return req.call.Async() +} + +// watchRequest subscribes to the updates of a specified key defined on the +// server. After receiving the notification, you should send a new +// watchRequest to acknowledge the notification. +type watchRequest struct { + baseRequest + key string + ctx context.Context +} + +// newWatchRequest returns a new watchRequest. +func newWatchRequest(key string) *watchRequest { + req := new(watchRequest) + req.requestCode = WatchRequestCode + req.async = true + req.key = key + return req +} + +// Body fills an encoder with the watch request body. +func (req *watchRequest) Body(res SchemaResolver, enc *encoder) error { + if err := enc.EncodeMapLen(1); err != nil { + return err + } + if err := encodeUint(enc, KeyEvent); err != nil { + return err + } + return enc.EncodeString(req.key) +} + +// Context sets a passed context to the request. +func (req *watchRequest) Context(ctx context.Context) *watchRequest { + req.ctx = ctx + return req +} + +// unwatchRequest unregisters a watcher subscribed to the given notification +// key. +type unwatchRequest struct { + baseRequest + key string + ctx context.Context +} + +// newUnwatchRequest returns a new unwatchRequest. +func newUnwatchRequest(key string) *unwatchRequest { + req := new(unwatchRequest) + req.requestCode = UnwatchRequestCode + req.async = true + req.key = key + return req +} + +// Body fills an encoder with the unwatch request body. +func (req *unwatchRequest) Body(res SchemaResolver, enc *encoder) error { + if err := enc.EncodeMapLen(1); err != nil { + return err + } + if err := encodeUint(enc, KeyEvent); err != nil { + return err + } + return enc.EncodeString(req.key) +} + +// Context sets a passed context to the request. +func (req *unwatchRequest) Context(ctx context.Context) *unwatchRequest { + req.ctx = ctx + return req +} + +// WatchEvent is a watch notification event received from a server. +type WatchEvent struct { + Conn *Connection // A source connection. + Key string // A key. + Value interface{} // A value. +} + +// Watcher is a subscription to broadcast events. +type Watcher interface { + // Unregister unregisters the watcher. + Unregister() +} + +// WatchCallback is a callback to invoke when the key value is updated. +type WatchCallback func(event WatchEvent)