diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4fb75664b..782dda025 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,8 @@ Versioning](http://semver.org/spec/v2.0.0.html) except to the first release.

 ### Added

+- Event subscription support (#119)
+
 ### Changed

 ### Fixed
diff --git a/connection.go b/connection.go
index a4ae8cc36..3f06bb173 100644
--- a/connection.go
+++ b/connection.go
@@ -53,6 +53,8 @@ const (
	// LogUnexpectedResultId is logged when response with unknown id was received.
	// Most probably it is due to request timeout.
	LogUnexpectedResultId
+	// LogReadWatchEventFailed is logged when reading a watch event fails.
+	LogReadWatchEventFailed
 )

 // ConnEvent is sent throw Notify channel specified in Opts.
@@ -62,6 +64,12 @@ type ConnEvent struct {
	When time.Time
 }

+// connWatchEvent is a raw watch event.
+type connWatchEvent struct {
+	key   string
+	value interface{}
+}
+
 var epoch = time.Now()

 // Logger is logger type expected to be passed in options.
@@ -83,6 +91,9 @@ func (d defaultLogger) Report(event ConnLogKind, conn *Connection, v ...interfac
	case LogUnexpectedResultId:
		resp := v[0].(*Response)
		log.Printf("tarantool: connection %s got unexpected resultId (%d) in response", conn.addr, resp.RequestId)
+	case LogReadWatchEventFailed:
+		err := v[0].(error)
+		log.Printf("tarantool: unable to parse watch event: %s\n", err)
	default:
		args := append([]interface{}{"tarantool: unexpected event ", event, conn}, v...)
		log.Print(args...)
@@ -146,6 +157,9 @@ type Connection struct {
	lenbuf [PacketLengthBytes]byte

	lastStreamId uint64
+
+	// watchMap is a map of key -> watchSharedData.
+	watchMap sync.Map
 }

 var _ = Connector(&Connection{}) // Check compatibility with connector interface.
@@ -502,7 +516,7 @@ func (conn *Connection) dial() (err error) {
	conn.Greeting.Version = bytes.NewBuffer(greeting[:64]).String()
	conn.Greeting.auth = bytes.NewBuffer(greeting[64:108]).String()

-	// Auth
+	// Auth.
	if opts.User != "" {
		scr, err := scramble(conn.Greeting.auth, opts.Pass)
		if err != nil {
@@ -520,7 +534,20 @@ func (conn *Connection) dial() (err error) {
		}
	}

-	// Only if connected and authenticated.
+	// Watchers.
+	conn.watchMap.Range(func(key, value interface{}) bool {
+		req := newWatchRequest(key.(string))
+		if err = conn.writeRequest(w, req); err != nil {
+			return false
+		}
+		return true
+	})
+
+	if err != nil {
+		return fmt.Errorf("unable to register watch: %w", err)
+	}
+
+	// Only if connected and fully initialized.
	conn.lockShards()
	conn.c = connection
	atomic.StoreUint32(&conn.state, connConnected)
	conn.unlockShards()
	go conn.writer(w, connection)
	go conn.reader(r, connection)

	return
 }
@@ -581,23 +608,33 @@ func pack(h *smallWBuf, enc *encoder, reqid uint32,
	return
 }

-func (conn *Connection) writeAuthRequest(w *bufio.Writer, scramble []byte) (err error) {
+func (conn *Connection) writeRequest(w *bufio.Writer, req Request) (err error) {
	var packet smallWBuf
-	req := newAuthRequest(conn.opts.User, string(scramble))
	err = pack(&packet, newEncoder(&packet), 0, req, ignoreStreamId, conn.Schema)

	if err != nil {
-		return errors.New("auth: pack error " + err.Error())
+		return fmt.Errorf("pack error: %w", err)
	}
	if err := write(w, packet.b); err != nil {
-		return errors.New("auth: write error " + err.Error())
+		return fmt.Errorf("write error: %w", err)
	}
	if err = w.Flush(); err != nil {
-		return errors.New("auth: flush error " + err.Error())
+		return fmt.Errorf("flush error: %w", err)
	}
	return
 }

+func (conn *Connection) writeAuthRequest(w *bufio.Writer, scramble []byte) (err error) {
+	req := newAuthRequest(conn.opts.User, string(scramble))
+
+	err = conn.writeRequest(w, req)
+	if err != nil {
+		return fmt.Errorf("auth: %w", err)
+	}
+
+	return nil
+}
+
 func (conn *Connection) readAuthResponse(r io.Reader) (err error) {
	respBytes, err := conn.read(r)
	if err != nil {
@@ -774,7 +811,50 @@ func (conn *Connection) writer(w *bufio.Writer, c net.Conn) {
	}
 }

+func readWatchEvent(reader io.Reader) (connWatchEvent, error) {
+	keyExist := false
+	event := connWatchEvent{}
+	d := newDecoder(reader)
+
+	if l, err := d.DecodeMapLen(); err == nil {
+		for ; l > 0; l-- {
+			if cd, err := d.DecodeInt(); err == nil {
+				switch cd {
+				case KeyEvent:
+					if event.key, err = d.DecodeString(); err != nil {
+						return event, err
+					}
+					keyExist = true
+				case KeyEventData:
+					if event.value, err = d.DecodeInterface(); err != nil {
+						return event, err
+					}
+				default:
+					if err = d.Skip(); err != nil {
+						return event, err
+					}
+				}
+			} else {
+				return event, err
+			}
+		}
+	} else {
+		return event, err
+	}
+
+	if !keyExist {
+		return event, errors.New("watch event does not have a key")
+	}
+
+	return event, nil
+}
+
 func (conn *Connection) reader(r *bufio.Reader, c net.Conn) {
+	events := make(chan connWatchEvent, 1024)
+	defer close(events)
+
+	go conn.eventer(events)
+
	for atomic.LoadUint32(&conn.state) != connClosed {
		respBytes, err := conn.read(r)
		if err != nil {
@@ -789,7 +869,14 @@ func (conn *Connection) reader(r *bufio.Reader, c net.Conn) {
		}

		var fut *Future = nil
-		if resp.Code == PushCode {
+		if resp.Code == EventCode {
+			if event, err := readWatchEvent(&resp.buf); err == nil {
+				events <- event
+			} else {
+				conn.opts.Logger.Report(LogReadWatchEventFailed, conn, err)
+			}
+			continue
+		} else if resp.Code == PushCode {
			if fut = conn.peekFuture(resp.RequestId); fut != nil {
				fut.AppendPush(resp)
			}
@@ -799,12 +886,42 @@ func (conn *Connection) reader(r *bufio.Reader, c net.Conn) {
				conn.markDone(fut)
			}
		}
+
		if fut == nil {
			conn.opts.Logger.Report(LogUnexpectedResultId, conn, resp)
		}
	}
 }

+// The eventer goroutine receives watch events and updates the values for
+// watchers.
+func (conn *Connection) eventer(events <-chan connWatchEvent) {
+	for {
+		event, ok := <-events
+		if !ok {
+			// The channel is closed.
+			break
+		}
+
+		if value, ok := conn.watchMap.Load(event.key); ok {
+			shared := value.(*watchSharedData)
+			shared.condMutex.Lock()
+			shared.value = event.value
+			shared.version += 1
+			shared.condMutex.Unlock()
+
+			shared.cond.Broadcast()
+
+			if atomic.LoadUint32(&conn.state) == connConnected {
+				shared.watchMutex.Lock()
+				if shared.watchCnt > 0 {
+					conn.Do(newWatchRequest(event.key))
+				}
+				shared.watchMutex.Unlock()
+			}
+		}
+	}
+}
+
 func (conn *Connection) newFuture(ctx context.Context) (fut *Future) {
	fut = NewFuture()
	if conn.rlimit != nil && conn.opts.RLimitAction == RLimitDrop {
@@ -960,6 +1077,18 @@ func (conn *Connection) putFuture(fut *Future, req Request, streamId uint64) {
		return
	}
	shard.bufmut.Unlock()
+
+	if req.Async() {
+		if fut = conn.fetchFuture(reqid); fut != nil {
+			resp := &Response{
+				RequestId: reqid,
+				Code:      OkCode,
+			}
+			fut.SetResponse(resp)
+			conn.markDone(fut)
+		}
+	}
+
	if firstWritten {
		conn.dirtyShard <- shardn
	}
@@ -1163,3 +1292,157 @@ func (conn *Connection) NewStream() (*Stream, error) {
		Conn: conn,
	}, nil
 }
+
+// watchSharedData is the data shared between watchers of the same key.
+type watchSharedData struct {
+	// value is the last known value of the key.
+	value interface{}
+	// version is the version of the value. It is incremented by 1 on every
+	// value update.
+	version uint
+	// The goroutine that gets events updates the value and the version under
+	// the write lock. Watchers read the value and the version under the read
+	// lock.
+	cond      *sync.Cond
+	condMutex sync.RWMutex
+
+	// watchCnt is the number of active watchers.
+	watchCnt int32
+	// watchMutex helps to send IPROTO_WATCH/IPROTO_UNWATCH without
+	// duplicates or overlaps.
+	watchMutex sync.Mutex
+}
+
+// connWatcher is an internal implementation of the Watcher interface.
+type connWatcher struct {
+	shared       *watchSharedData
+	unregister   sync.Once
+	unregistered bool
+	finished     chan struct{}
+}
+
+// Unregister unregisters the connection watcher.
+func (w *connWatcher) Unregister() {
+	w.unregister.Do(func() {
+		// The Lock/Unlock synchronizes the w.unregistered update with the
+		// watcher goroutine.
+		w.shared.condMutex.Lock()
+		w.unregistered = true
+		w.shared.condMutex.Unlock()
+		w.shared.cond.Broadcast()
+	})
+	<-w.finished
+}
+
+// NewWatcher creates a new Watcher object for the connection.
+//
+// After watcher creation, the watcher callback is invoked for the first time.
+// In this case, the callback is triggered whether or not the key has already
+// been broadcast. All subsequent invocations are triggered with
+// box.broadcast() called on the remote host. If a watcher is subscribed for a
+// key that has not been broadcast yet, the callback is triggered only once,
+// after the registration of the watcher.
+//
+// The watcher callbacks are always invoked in a separate goroutine. A watcher
+// callback is never executed in parallel with itself, but they can be executed
+// in parallel to other watchers.
+//
+// If the key is updated while the watcher callback is running, the callback
+// will be invoked again with the latest value as soon as it returns.
+//
+// Watchers survive reconnection. All registered watchers are automatically
+// resubscribed when the connection is reestablished.
+//
+// Keep in mind that garbage collection of a watcher handle doesn’t lead to the
+// watcher’s destruction. In this case, the watcher remains registered. You
+// need to call Unregister() directly.
+//
+// Unregister() guarantees that there will be no watcher callback calls after
+// it returns, but calling Unregister() from inside the callback leads to a
+// deadlock.
+//
+// See:
+// https://www.tarantool.io/en/doc/latest/reference/reference_lua/box_events/#box-watchers
+//
+// Since 1.10.0
+func (conn *Connection) NewWatcher(key string, callback WatchCallback) (Watcher, error) {
+	// TODO: check required features after:
+	//
+	// https://github.com/tarantool/go-tarantool/issues/120
+	var shared *watchSharedData
+	// Get or create the shared data for the key.
+	if val, ok := conn.watchMap.Load(key); !ok {
+		shared = &watchSharedData{
+			value:    nil,
+			version:  0,
+			watchCnt: 0,
+		}
+		shared.cond = sync.NewCond(shared.condMutex.RLocker())
+
+		if val, ok := conn.watchMap.LoadOrStore(key, shared); ok {
+			shared = val.(*watchSharedData)
+		}
+	} else {
+		shared = val.(*watchSharedData)
+	}
+
+	// Send an initial watch request.
+	shared.watchMutex.Lock()
+	if shared.watchCnt == 0 {
+		shared.condMutex.Lock()
+		shared.version = 0
+		shared.condMutex.Unlock()
+
+		if _, err := conn.Do(newWatchRequest(key)).Get(); err != nil {
+			shared.watchMutex.Unlock()
+			return nil, err
+		}
+	}
+	shared.watchCnt += 1
+	shared.watchMutex.Unlock()
+
+	// Start the watcher goroutine.
+	version := uint(0)
+
+	watcher := &connWatcher{
+		shared:       shared,
+		unregistered: false,
+		finished:     make(chan struct{}),
+	}
+
+	go func() {
+		for {
+			shared.cond.L.Lock()
+			for {
+				if watcher.unregistered {
+					shared.cond.L.Unlock()
+
+					shared.watchMutex.Lock()
+					shared.watchCnt -= 1
+					if shared.watchCnt == 0 {
+						// The last one sends IPROTO_UNWATCH.
+						conn.Do(newUnwatchRequest(key))
+					}
+					shared.watchMutex.Unlock()
+
+					close(watcher.finished)
+					return
+				}
+				if version != shared.version {
+					break
+				}
+				shared.cond.Wait()
+			}
+
+			value := shared.value
+			version = shared.version
+			shared.cond.L.Unlock()
+
+			callback(WatchEvent{
+				Conn:  conn,
+				Key:   key,
+				Value: value,
+			})
+		}
+	}()
+
+	return watcher, nil
+}
diff --git a/connection_pool/connection_pool.go b/connection_pool/connection_pool.go
index 6597e2dd0..a9e59a66d 100644
--- a/connection_pool/connection_pool.go
+++ b/connection_pool/connection_pool.go
@@ -91,12 +91,13 @@ type ConnectionPool struct {
	connOpts tarantool.Opts
	opts     OptsPool

-	state      state
-	done       chan struct{}
-	roPool     *RoundRobinStrategy
-	rwPool     *RoundRobinStrategy
-	anyPool    *RoundRobinStrategy
-	poolsMutex sync.RWMutex
+	state            state
+	done             chan struct{}
+	roPool           *RoundRobinStrategy
+	rwPool           *RoundRobinStrategy
+	anyPool          *RoundRobinStrategy
+	poolsMutex       sync.RWMutex
+	watcherContainer watcherContainer
 }

 var _ Pooler = (*ConnectionPool)(nil)
@@ -640,25 +641,6 @@ func (connPool *ConnectionPool) ExecuteAsync(expr string, args interface{}, user
	return conn.ExecuteAsync(expr, args)
 }

-// Do sends the request and returns a future.
-// For requests that belong to an only one connection (e.g. Unprepare or ExecutePrepared)
-// the argument of type Mode is unused.
-func (connPool *ConnectionPool) Do(req tarantool.Request, userMode Mode) *tarantool.Future {
-	if connectedReq, ok := req.(tarantool.ConnectedRequest); ok {
-		conn, _ := connPool.getConnectionFromPool(connectedReq.Conn().Addr())
-		if conn == nil {
-			return newErrorFuture(fmt.Errorf("the passed connected request doesn't belong to the current connection or connection pool"))
-		}
-		return connectedReq.Conn().Do(req)
-	}
-	conn, err := connPool.getNextConnection(userMode)
-	if err != nil {
-		return newErrorFuture(err)
-	}
-
-	return conn.Do(req)
-}
-
 // NewStream creates new Stream object for connection selected
 // by userMode from connPool.
 //
@@ -682,6 +664,200 @@ func (connPool *ConnectionPool) NewPrepared(expr string, userMode Mode) (*tarant
	return conn.NewPrepared(expr)
 }

+// watcherContainer is a very simple implementation of a thread-safe container
+// for watchers. It is not expected that there will be many watchers or that
+// they will be registered/unregistered frequently.
+//
+// Otherwise, the implementation will need to be optimized.
+type watcherContainer struct {
+	head  *poolWatcher
+	mutex sync.RWMutex
+}
+
+// add adds a watcher to the container.
+func (c *watcherContainer) add(watcher *poolWatcher) {
+	c.mutex.Lock()
+	defer c.mutex.Unlock()

+	watcher.next = c.head
+	c.head = watcher
+}
+
+// remove removes a watcher from the container.
+func (c *watcherContainer) remove(watcher *poolWatcher) {
+	c.mutex.Lock()
+	defer c.mutex.Unlock()
+
+	if watcher == c.head {
+		c.head = watcher.next
+	} else {
+		cur := c.head
+		for cur.next != nil {
+			if cur.next == watcher {
+				cur.next = watcher.next
+				break
+			}
+			cur = cur.next
+		}
+	}
+}
+
+// foreach iterates over the container until the end or until the call
+// returns an error.
+func (c *watcherContainer) foreach(call func(watcher *poolWatcher) error) error {
+	cur := c.head
+	for cur != nil {
+		if err := call(cur); err != nil {
+			return err
+		}
+		cur = cur.next
+	}
+	return nil
+}
+
+// poolWatcher is an internal implementation of the tarantool.Watcher interface.
+type poolWatcher struct {
+	// The watcher container data. We could split the structure into two
+	// parts in the future (watcher data and watcher container data), but it
+	// stays simple for now.
+
+	// next item in the watcher container.
+	next *poolWatcher
+	// container is the container for all active poolWatcher objects.
+	container *watcherContainer
+
+	// The watcher data.
+	// mode of the watcher.
+	mode Mode
+	key  string
+	callback tarantool.WatchCallback
+	// watchers is a map connection -> connection watcher.
+	watchers map[string]tarantool.Watcher
+	// unregistered is true if the watcher is already unregistered.
+	unregistered bool
+	// mutex for the pool watcher.
+	mutex sync.Mutex
+}
+
+// Unregister unregisters the pool watcher.
+func (w *poolWatcher) Unregister() {
+	w.mutex.Lock()
+	defer w.mutex.Unlock()
+
+	if !w.unregistered {
+		w.container.remove(w)
+		w.unregistered = true
+		for _, watcher := range w.watchers {
+			watcher.Unregister()
+		}
+	}
+}
+
+// watch adds a watcher for the connection.
+func (w *poolWatcher) watch(conn *tarantool.Connection) error {
+	addr := conn.Addr()
+
+	w.mutex.Lock()
+	defer w.mutex.Unlock()
+
+	if !w.unregistered {
+		if _, ok := w.watchers[addr]; ok {
+			return nil
+		}
+
+		if watcher, err := conn.NewWatcher(w.key, w.callback); err == nil {
+			w.watchers[addr] = watcher
+			return nil
+		} else {
+			return err
+		}
+	}
+	return nil
+}
+
+// unwatch removes a watcher for the connection.
+func (w *poolWatcher) unwatch(conn *tarantool.Connection) {
+	addr := conn.Addr()
+
+	w.mutex.Lock()
+	defer w.mutex.Unlock()
+
+	if !w.unregistered {
+		if watcher, ok := w.watchers[addr]; ok {
+			watcher.Unregister()
+			delete(w.watchers, addr)
+		}
+	}
+}
+
+// NewWatcher creates a new Watcher object for the connection pool.
+//
+// The behavior is the same as if Connection.NewWatcher() were called for each
+// connection with a suitable role.
+//
+// Keep in mind that garbage collection of a watcher handle doesn’t lead to the
+// watcher’s destruction. In this case, the watcher remains registered. You
+// need to call Unregister() directly.
+//
+// Unregister() guarantees that there will be no watcher callback calls after
+// it returns, but calling Unregister() from inside the callback leads to a
+// deadlock.
+//
+// See:
+// https://www.tarantool.io/en/doc/latest/reference/reference_lua/box_events/#box-watchers
+//
+// Since 1.10.0
+func (pool *ConnectionPool) NewWatcher(key string,
+	callback tarantool.WatchCallback, mode Mode) (tarantool.Watcher, error) {
+	watcher := &poolWatcher{
+		container:    &pool.watcherContainer,
+		mode:         mode,
+		key:          key,
+		callback:     callback,
+		watchers:     make(map[string]tarantool.Watcher),
+		unregistered: false,
+	}
+
+	watcher.container.add(watcher)
+
+	rr := pool.anyPool
+	if mode == RW {
+		rr = pool.rwPool
+	} else if mode == RO {
+		rr = pool.roPool
+	}
+
+	conns := rr.GetConnections()
+	for _, conn := range conns {
+		// TODO: check required features after:
+		//
+		// https://github.com/tarantool/go-tarantool/issues/120
+		if err := watcher.watch(conn); err != nil {
+			conn.Close()
+		}
+	}
+
+	return watcher, nil
+}
+
+// Do sends the request and returns a future.
+// For requests that belong to only one connection (e.g. Unprepare or
+// ExecutePrepared) the argument of type Mode is unused.
+func (connPool *ConnectionPool) Do(req tarantool.Request, userMode Mode) *tarantool.Future {
+	if connectedReq, ok := req.(tarantool.ConnectedRequest); ok {
+		conn, _ := connPool.getConnectionFromPool(connectedReq.Conn().Addr())
+		if conn == nil {
+			return newErrorFuture(fmt.Errorf("the passed connected request doesn't belong to the current connection or connection pool"))
+		}
+		return connectedReq.Conn().Do(req)
+	}
+	conn, err := connPool.getNextConnection(userMode)
+	if err != nil {
+		return newErrorFuture(err)
+	}
+
+	return conn.Do(req)
+}
+
 //
 // private
 //
@@ -733,26 +909,63 @@ func (connPool *ConnectionPool) getConnectionFromPool(addr string) (*tarantool.C
	return connPool.anyPool.GetConnByAddr(addr), UnknownRole
 }

-func (connPool *ConnectionPool) deleteConnection(addr string) {
-	if conn := connPool.anyPool.DeleteConnByAddr(addr); conn != nil {
-		if conn := connPool.rwPool.DeleteConnByAddr(addr); conn != nil {
-			return
+func (pool *ConnectionPool) deleteConnection(addr string) {
+	if conn := pool.anyPool.DeleteConnByAddr(addr); conn != nil {
+		if conn := pool.rwPool.DeleteConnByAddr(addr); conn == nil {
+			pool.roPool.DeleteConnByAddr(addr)
+		}
+		// The internal connection deinitialization.
+		pool.watcherContainer.mutex.RLock()
+		defer pool.watcherContainer.mutex.RUnlock()
+
+		pool.watcherContainer.foreach(func(watcher *poolWatcher) error {
+			watcher.unwatch(conn)
+			return nil
+		})
+	}
+}
+
+func (pool *ConnectionPool) addConnection(addr string,
+	conn *tarantool.Connection, role Role) error {
+	// The internal connection initialization.
+	pool.watcherContainer.mutex.RLock()
+	defer pool.watcherContainer.mutex.RUnlock()
+
+	watched := []*poolWatcher{}
+	err := pool.watcherContainer.foreach(func(watcher *poolWatcher) error {
+		watch := false
+		if watcher.mode == RW {
+			watch = role == MasterRole
+		} else if watcher.mode == RO {
+			watch = role == ReplicaRole
+		} else {
+			watch = true
+		}
+		if watch {
+			if err := watcher.watch(conn); err != nil {
+				return err
+			}
+			watched = append(watched, watcher)
		}
-		connPool.roPool.DeleteConnByAddr(addr)
+		return nil
+	})
+	if err != nil {
+		for _, watcher := range watched {
+			watcher.unwatch(conn)
+		}
+		log.Printf("tarantool: failed to initialize watchers for %s: %s", addr, err)
+		return err
	}
-}
-
-func (connPool *ConnectionPool) addConnection(addr string,
-	conn *tarantool.Connection, role Role) {
-	connPool.anyPool.AddConn(addr, conn)
+	pool.anyPool.AddConn(addr, conn)

	switch role {
	case MasterRole:
-		connPool.rwPool.AddConn(addr, conn)
+		pool.rwPool.AddConn(addr, conn)
	case ReplicaRole:
-		connPool.roPool.AddConn(addr, conn)
+		pool.roPool.AddConn(addr, conn)
	}
+	return nil
 }

 func (connPool *ConnectionPool) handlerDiscovered(conn *tarantool.Connection,
@@ -811,7 +1024,10 @@ func (connPool *ConnectionPool) fillPools() ([]connState, bool) {
		}

		if connPool.handlerDiscovered(conn, role) {
-			connPool.addConnection(addr, conn, role)
+			if connPool.addConnection(addr, conn, role) != nil {
+				conn.Close()
+				connPool.handlerDeactivated(conn, role)
+			}

			if conn.ConnectedNow() {
				states[i].conn = conn
@@ -864,7 +1080,15 @@ func (pool *ConnectionPool) updateConnection(s connState) connState {
				return s
			}

-			pool.addConnection(s.addr, s.conn, role)
+			if pool.addConnection(s.addr, s.conn, role) != nil {
+				pool.poolsMutex.Unlock()
+
+				s.conn.Close()
+				pool.handlerDeactivated(s.conn, role)
+				s.conn = nil
+				s.role = UnknownRole
+				return s
+			}
			s.role = role
		}
	}
@@ -911,7 +1135,12 @@ func (pool *ConnectionPool) tryConnect(s connState) connState {
			return s
		}

-		pool.addConnection(s.addr, conn, role)
+		if pool.addConnection(s.addr, conn, role) != nil {
+			pool.poolsMutex.Unlock()
+			conn.Close()
+			pool.handlerDeactivated(conn, role)
+			return s
+		}
		s.conn = conn
		s.role = role
	}
diff --git a/connection_pool/connection_pool_test.go b/connection_pool/connection_pool_test.go
index 60c9b4f91..6250f6619 100644
--- a/connection_pool/connection_pool_test.go
+++ b/connection_pool/connection_pool_test.go
@@ -2048,6 +2048,255 @@ func TestStream_TxnIsolationLevel(t *testing.T) {
	}
 }

+func TestConnectionPool_NewWatcher_modes(t *testing.T) {
+	test_helpers.SkipIfWatchersUnsupported(t)
+
+	const key = "TestConnectionPool_NewWatcher_modes"
+
+	roles := []bool{true, false, false, true, true}
+
+	err := test_helpers.SetClusterRO(servers, connOpts, roles)
+	require.Nilf(t, err, "failed to set roles for cluster")
+
+	pool, err := connection_pool.Connect(servers, connOpts)
+	require.Nilf(t, err, "failed to connect")
+	require.NotNilf(t, pool, "conn is nil after Connect")
+	defer pool.Close()
+
+	modes := []connection_pool.Mode{
+		connection_pool.ANY,
+		connection_pool.RW,
+		connection_pool.RO,
+		connection_pool.PreferRW,
+		connection_pool.PreferRO,
+	}
+	for _, mode := range modes {
+		t.Run(fmt.Sprintf("%d", mode), func(t *testing.T) {
+			expectedServers := []string{}
+			for i, server := range servers {
+				if roles[i] && mode == connection_pool.RW {
+					continue
+				} else if !roles[i] && mode == connection_pool.RO {
+					continue
+				}
+				expectedServers = append(expectedServers, server)
+			}
+
+			events := make(chan tarantool.WatchEvent, 1024)
+			defer close(events)
+
+			watcher, err := pool.NewWatcher(key, func(event tarantool.WatchEvent) {
+				require.Equal(t, key, event.Key)
+				require.Equal(t, nil, event.Value)
+				events <- event
+			}, mode)
+			require.Nilf(t, err, "failed to register a watcher")
+			defer watcher.Unregister()
+
+			testMap := make(map[string]int)
+
+			for i := 0; i < len(expectedServers); i++ {
+				select {
+				case event := <-events:
+					require.NotNil(t, event.Conn)
+					addr := event.Conn.Addr()
+					if val, ok := testMap[addr]; ok {
+						testMap[addr] = val + 1
+					} else {
+						testMap[addr] = 1
+					}
+				case <-time.After(time.Second):
+					t.Fatalf("Failed to get watch event.")
+				}
+			}
+
+			for _, server := range expectedServers {
+				if val, ok := testMap[server]; !ok {
+					t.Errorf("Server not found: %s", server)
+				} else {
+					require.Equal(t, val, 1, fmt.Sprintf("for server %s", server))
+				}
+			}
+		})
+	}
+}
+
+func TestConnectionPool_NewWatcher_update(t *testing.T) {
+	test_helpers.SkipIfWatchersUnsupported(t)
+
+	const key = "TestConnectionPool_NewWatcher_update"
+	const mode = connection_pool.RW
+
+	roles := []bool{true, false, false, true, true}
+
+	err := test_helpers.SetClusterRO(servers, connOpts, roles)
+	require.Nilf(t, err, "failed to set roles for cluster")
+
+	pool, err := connection_pool.Connect(servers, connOpts)
+	require.Nilf(t, err, "failed to connect")
+	require.NotNilf(t, pool, "conn is nil after Connect")
+	defer pool.Close()
+
+	events := make(chan tarantool.WatchEvent, 1024)
+	defer close(events)
+
+	watcher, err := pool.NewWatcher(key, func(event tarantool.WatchEvent) {
+		require.Equal(t, key, event.Key)
+		require.Equal(t, nil, event.Value)
+		events <- event
+	}, mode)
+	require.Nilf(t, err, "failed to create a watcher")
+	defer watcher.Unregister()
+
+	// Just invert the roles to simplify the test.
+	for i, role := range roles {
+		roles[i] = !role
+	}
+	err = test_helpers.SetClusterRO(servers, connOpts, roles)
+	require.Nilf(t, err, "failed to set roles for cluster")
+
+	testMap := make(map[string]int)
+	for i := 0; i < len(servers); i++ {
+		select {
+		case event := <-events:
+			require.NotNil(t, event.Conn)
+			addr := event.Conn.Addr()
+			if val, ok := testMap[addr]; ok {
+				testMap[addr] = val + 1
+			} else {
+				testMap[addr] = 1
+			}
+		case <-time.After(time.Second):
+			t.Fatalf("Failed to get watch event.")
+		}
+	}
+
+	for _, server := range servers {
+		if val, ok := testMap[server]; !ok {
+			t.Errorf("Server not found: %s", server)
+		} else {
+			require.Equal(t, val, 1, fmt.Sprintf("for server %s", server))
+		}
+	}
+}
+
+func TestWatcher_Unregister(t *testing.T) {
+	test_helpers.SkipIfWatchersUnsupported(t)
+
+	const key = "TestWatcher_Unregister"
+	const mode = connection_pool.RW
+	const expectedCnt = 2
+	roles := []bool{true, false, false, true, true}
+
+	err := test_helpers.SetClusterRO(servers, connOpts, roles)
+	require.Nilf(t, err, "failed to set roles for cluster")
+
+	pool, err := connection_pool.Connect(servers, connOpts)
+	require.Nilf(t, err, "failed to connect")
+	require.NotNilf(t, pool, "conn is nil after Connect")
+	defer pool.Close()
+
+	events := make(chan tarantool.WatchEvent, 1024)
+	defer close(events)
+
+	watcher, err := pool.NewWatcher(key, func(event tarantool.WatchEvent) {
+		require.Equal(t, key, event.Key)
+		require.Equal(t, nil, event.Value)
+		events <- event
+	}, mode)
+	require.Nilf(t, err, "failed to create a watcher")
+
+	for i := 0; i < expectedCnt; i++ {
+		select {
+		case <-events:
+		case <-time.After(time.Second):
+			t.Fatalf("Failed to skip initial events.")
+		}
+	}
+	watcher.Unregister()
+
+	broadcast := tarantool.NewBroadcastRequest(key).Value("foo")
+	for i := 0; i < expectedCnt; i++ {
+		_, err := pool.Do(broadcast, mode).Get()
+		require.Nilf(t, err, "failed to send a broadcast request")
+	}
+
+	select {
+	case event := <-events:
+		t.Fatalf("Got unexpected event: %v", event)
+	case <-time.After(time.Second):
+	}
+}
+
+func TestConnectionPool_NewWatcher_concurrent(t *testing.T) {
+	test_helpers.SkipIfWatchersUnsupported(t)
+
+	const testConcurrency = 1000
+	const key = "TestConnectionPool_NewWatcher_concurrent"
+
+	roles := []bool{true, false, false, true, true}
+
+	err := test_helpers.SetClusterRO(servers, connOpts, roles)
+	require.Nilf(t, err, "failed to set roles for cluster")
+
+	pool, err := connection_pool.Connect(servers, connOpts)
+	require.Nilf(t, err, "failed to connect")
+	require.NotNilf(t, pool, "conn is nil after Connect")
+	defer pool.Close()
+
+	var wg sync.WaitGroup
+	wg.Add(testConcurrency)
+
+	mode := connection_pool.ANY
+	callback := func(event tarantool.WatchEvent) {}
+	for i := 0; i < testConcurrency; i++ {
+		go func(i int) {
+			defer wg.Done()
+
+			watcher, err := pool.NewWatcher(key, callback, mode)
+			if err != nil {
+				t.Errorf("Failed to create a watcher: %s", err)
+			} else {
+				watcher.Unregister()
+			}
+		}(i)
+	}
+	wg.Wait()
+}
+
+func TestWatcher_Unregister_concurrent(t *testing.T) {
+	test_helpers.SkipIfWatchersUnsupported(t)
+
+	const testConcurrency = 1000
+	const key = "TestWatcher_Unregister_concurrent"
+
+	roles := []bool{true, false, false, true, true}
+
+	err := test_helpers.SetClusterRO(servers, connOpts, roles)
+	require.Nilf(t, err, "failed to set roles for cluster")
+
+	pool, err := connection_pool.Connect(servers, connOpts)
+	require.Nilf(t, err, "failed to connect")
+	require.NotNilf(t, pool, "conn is nil after Connect")
+	defer pool.Close()
+
+	mode := connection_pool.ANY
+	watcher, err := pool.NewWatcher(key, func(event tarantool.WatchEvent) {
+	}, mode)
+	require.Nilf(t, err, "failed to create a watcher")
+
+	var wg sync.WaitGroup
+	wg.Add(testConcurrency)
+
+	for i := 0; i < testConcurrency; i++ {
+		go func() {
+			defer wg.Done()
+			watcher.Unregister()
+		}()
+	}
+	wg.Wait()
+}
+
 // runTestMain is a body of TestMain function
 // (see https://pkg.go.dev/testing#hdr-Main).
 // Using defer + os.Exit is not works so TestMain body
diff --git a/connection_pool/connector.go b/connection_pool/connector.go
index e52109d92..c108aba0b 100644
--- a/connection_pool/connector.go
+++ b/connection_pool/connector.go
@@ -299,6 +299,14 @@ func (c *ConnectorAdapter) NewStream() (*tarantool.Stream, error) {
	return c.pool.NewStream(c.mode)
 }

+// NewWatcher creates a new Watcher object for the pool.
+//
+// Since 1.10.0
+func (c *ConnectorAdapter) NewWatcher(key string,
+	callback tarantool.WatchCallback) (tarantool.Watcher, error) {
+	return c.pool.NewWatcher(key, callback, c.mode)
+}
+
 // Do performs a request asynchronously on the connection.
 func (c *ConnectorAdapter) Do(req tarantool.Request) *tarantool.Future {
	return c.pool.Do(req, c.mode)
diff --git a/connection_pool/connector_test.go b/connection_pool/connector_test.go
index fa7cf06ba..f53a05b22 100644
--- a/connection_pool/connector_test.go
+++ b/connection_pool/connector_test.go
@@ -1139,6 +1139,45 @@ func TestConnectorNewStream(t *testing.T) {
	require.Equalf(t, testMode, m.mode, "unexpected proxy mode")
 }

+type watcherMock struct{}
+
+func (w *watcherMock) Unregister() {}
+
+const reqWatchKey = "foo"
+
+var reqWatcher tarantool.Watcher = &watcherMock{}
+
+type newWatcherMock struct {
+	Pooler
+	key      string
+	callback tarantool.WatchCallback
+	called   int
+	mode     Mode
+}
+
+func (m *newWatcherMock) NewWatcher(key string,
+	callback tarantool.WatchCallback, mode Mode) (tarantool.Watcher, error) {
+	m.called++
+	m.key = key
+	m.callback = callback
+	m.mode = mode
+	return reqWatcher, reqErr
+}
+
+func TestConnectorNewWatcher(t *testing.T) {
+	m := &newWatcherMock{}
+	c := NewConnectorAdapter(m, testMode)
+
+	w, err := c.NewWatcher(reqWatchKey, func(event tarantool.WatchEvent) {})
+
+	require.Equalf(t, reqWatcher, w, "unexpected watcher")
+	require.Equalf(t, reqErr, err, "unexpected error")
+	require.Equalf(t, 1, m.called, "should be called only once")
+	require.Equalf(t, reqWatchKey, m.key, "unexpected key")
+	require.NotNilf(t, m.callback, "callback must be set")
+	require.Equalf(t, testMode, m.mode, "unexpected proxy mode")
+}
+
 var reqRequest tarantool.Request = tarantool.NewPingRequest()

 type doMock struct {
diff --git a/connection_pool/pooler.go b/connection_pool/pooler.go
index a9dbe09f9..856f5d5be 100644
--- a/connection_pool/pooler.go
+++ b/connection_pool/pooler.go
@@ -84,6 +84,8 @@ type Pooler interface {

	NewPrepared(expr string, mode Mode) (*tarantool.Prepared, error)
	NewStream(mode Mode) (*tarantool.Stream, error)
+	NewWatcher(key string, callback tarantool.WatchCallback,
+		mode Mode) (tarantool.Watcher, error)
	Do(req tarantool.Request, mode Mode) (fut *tarantool.Future)
 }
diff --git a/connection_pool/round_robin.go b/connection_pool/round_robin.go
index b83d877d9..a7fb73e18 100644
--- a/connection_pool/round_robin.go
+++ b/connection_pool/round_robin.go
@@ -14,6 +14,15 @@ type RoundRobinStrategy struct {
	current uint
 }

+func NewEmptyRoundRobin(size int) *RoundRobinStrategy {
+	return &RoundRobinStrategy{
+		conns:       make([]*tarantool.Connection, 0, size),
+		indexByAddr: make(map[string]uint),
+		size:        0,
+		current:     0,
+	}
+}
+
 func (r *RoundRobinStrategy) GetConnByAddr(addr string) *tarantool.Connection {
	r.mutex.RLock()
	defer r.mutex.RUnlock()
@@ -71,13 +80,14 @@ func (r *RoundRobinStrategy) GetNextConnection() *tarantool.Connection {
	return r.conns[r.nextIndex()]
 }

-func NewEmptyRoundRobin(size int) *RoundRobinStrategy {
-	return &RoundRobinStrategy{
-		conns:       make([]*tarantool.Connection, 0, size),
-		indexByAddr: make(map[string]uint),
-		size:        0,
-		current:     0,
-	}
+func (r *RoundRobinStrategy) GetConnections() []*tarantool.Connection {
+	r.mutex.RLock()
+	defer r.mutex.RUnlock()
+
+	ret := make([]*tarantool.Connection, len(r.conns))
+	copy(ret, r.conns)
+
+	return ret
 }

 func (r *RoundRobinStrategy) AddConn(addr string, conn *tarantool.Connection) {
diff --git a/connection_pool/round_robin_test.go b/connection_pool/round_robin_test.go
index 6b54ecfd8..03038eada 100644
--- a/connection_pool/round_robin_test.go
+++ b/connection_pool/round_robin_test.go
@@ -69,3 +69,22 @@ func TestRoundRobinGetNextConnection(t *testing.T) {
		}
	}
 }
+
+func TestRoundRobinStrategy_GetConnections(t *testing.T) {
+	rr := NewEmptyRoundRobin(10)
+
+	addrs := []string{validAddr1, validAddr2}
+	conns := []*tarantool.Connection{&tarantool.Connection{}, &tarantool.Connection{}}
+
+	for i, addr := range addrs {
+		rr.AddConn(addr, conns[i])
+	}
+
+	rr.GetConnections()[1] = conns[0] // GetConnections() returns a copy.
+	rrConns := rr.GetConnections()
+	for i, expected := range conns {
+		if expected != rrConns[i] {
+			t.Errorf("Unexpected connection at index %d", i)
+		}
+	}
+}
diff --git a/connector.go b/connector.go
index d6c44c8dd..d93c69ec8 100644
--- a/connector.go
+++ b/connector.go
@@ -46,6 +46,7 @@ type Connector interface {

	NewPrepared(expr string) (*Prepared, error)
	NewStream() (*Stream, error)
+	NewWatcher(key string, callback WatchCallback) (Watcher, error)
	Do(req Request) (fut *Future)
 }
diff --git a/const.go b/const.go
index 4a3cb6833..95a0d366d 100644
--- a/const.go
+++ b/const.go
@@ -18,6 +18,8 @@ const (
	RollbackRequestCode  = 16
	PingRequestCode      = 64
	SubscribeRequestCode = 66
+	WatchRequestCode     = 74
+	UnwatchRequestCode   = 75

	KeyCode         = 0x00
	KeySync         = 0x01
@@ -42,6 +44,8 @@ const (
	KeySQLInfo      = 0x42
	KeyStmtID       = 0x43
	KeyTimeout      = 0x56
+	KeyEvent        = 0x57
+	KeyEventData    = 0x58
	KeyTxnIsolation = 0x59

	KeyFieldName = 0x00
@@ -70,6 +74,7 @@ const (
	RLimitWait = 2

	OkCode            = uint32(0)
+	EventCode         = uint32(0x4c)
	PushCode          = uint32(0x80)
	ErrorCodeBit      = 0x8000
	PacketLengthBytes = 5
diff --git a/multi/multi.go b/multi/multi.go
index 67f450c5c..9d3828dd7 100644
--- a/multi/multi.go
+++ b/multi/multi.go
@@ -507,6 +507,16 @@ func (connMulti *ConnectionMulti) NewStream() (*tarantool.Stream, error) {
	return connMulti.getCurrentConnection().NewStream()
 }

+// NewWatcher is not supported by the ConnectionMulti. The ConnectionMulti is
+// deprecated: use ConnectionPool instead.
+//
+// Since 1.10.0
+func (connMulti *ConnectionMulti) NewWatcher(key string,
+	callback tarantool.WatchCallback) (tarantool.Watcher, error) {
+	return nil, errors.New("ConnectionMulti is deprecated " +
+		"use ConnectionPool")
+}
+
 // Do sends the request and returns a future.
 func (connMulti *ConnectionMulti) Do(req tarantool.Request) *tarantool.Future {
	if connectedReq, ok := req.(tarantool.ConnectedRequest); ok {
diff --git a/multi/multi_test.go b/multi/multi_test.go
index 2d43bb179..ef07d629b 100644
--- a/multi/multi_test.go
+++ b/multi/multi_test.go
@@ -548,6 +548,30 @@ func TestStream_Rollback(t *testing.T) {
	}
 }

+func TestConnectionMulti_NewWatcher(t *testing.T) {
+	test_helpers.SkipIfStreamsUnsupported(t)
+
+	multiConn, err := Connect([]string{server1, server2}, connOpts)
+	if err != nil {
+		t.Fatalf("Failed to connect: %s", err.Error())
+	}
+	if multiConn == nil {
+		t.Fatalf("conn is nil after Connect")
+	}
+	defer multiConn.Close()
+
+	watcher, err := multiConn.NewWatcher("foo", func(event tarantool.WatchEvent) {})
+	if watcher != nil {
+		t.Errorf("Unexpected watcher")
+	}
+	if err == nil {
+		t.Fatalf("Unexpected success")
+	}
+	if err.Error() != "ConnectionMulti is deprecated use ConnectionPool" {
+		t.Fatalf("Unexpected error: %s", err)
+	}
+}
+
 // runTestMain is a body of TestMain function
 // (see https://pkg.go.dev/testing#hdr-Main).
 // Using defer + os.Exit is not works so TestMain body
diff --git a/request.go b/request.go
index cfa40e522..66eb4be41 100644
--- a/request.go
+++ b/request.go
@@ -538,6 +538,8 @@ type Request interface {
	Body(resolver SchemaResolver, enc *encoder) error
	// Ctx returns a context of the request.
	Ctx() context.Context
+	// Async returns true if the request does not expect a response.
+	Async() bool
 }

 // ConnectedRequest is an interface that provides the info about a Connection
@@ -550,6 +552,7 @@ type ConnectedRequest interface {

 type baseRequest struct {
	requestCode int32
+	async       bool
	ctx         context.Context
 }

@@ -558,6 +561,11 @@ func (req *baseRequest) Code() int32 {
	return req.requestCode
 }

+// Async returns true if the request does not expect a response.
+func (req *baseRequest) Async() bool {
+	return req.async
+}
+
 // Ctx returns a context of the request.
 func (req *baseRequest) Ctx() context.Context {
	return req.ctx
diff --git a/request_test.go b/request_test.go
index 89d1d8884..a680cdcbb 100644
--- a/request_test.go
+++ b/request_test.go
@@ -19,6 +19,7 @@ const invalidIndex = 2
 const validSpace = 1   // Any valid value != default.
 const validIndex = 3   // Any valid value != default.
 const validExpr = "any string" // We don't check the value here.
+const validKey = "foo" // Any string.
 const defaultSpace = 0 // And valid too.
 const defaultIndex = 0 // And valid too.
@@ -183,6 +184,7 @@ func TestRequestsCodes(t *testing.T) {
		{req: NewBeginRequest(), code: BeginRequestCode},
		{req: NewCommitRequest(), code: CommitRequestCode},
		{req: NewRollbackRequest(), code: RollbackRequestCode},
+		{req: NewBroadcastRequest(validKey), code: EvalRequestCode},
	}

	for _, test := range tests {
@@ -192,6 +194,38 @@
	}
 }

+func TestRequestsAsync(t *testing.T) {
+	tests := []struct {
+		req   Request
+		async bool
+	}{
+		{req: NewSelectRequest(validSpace), async: false},
+		{req: NewUpdateRequest(validSpace), async: false},
+		{req: NewUpsertRequest(validSpace), async: false},
+		{req: NewInsertRequest(validSpace), async: false},
+		{req: NewReplaceRequest(validSpace), async: false},
+		{req: NewDeleteRequest(validSpace), async: false},
+		{req: NewCall16Request(validExpr), async: false},
+		{req: NewCall17Request(validExpr), async: false},
+		{req: NewEvalRequest(validExpr), async: false},
+		{req: NewExecuteRequest(validExpr), async: false},
+		{req: NewPingRequest(), async: false},
+		{req: NewPrepareRequest(validExpr), async: false},
+		{req: NewUnprepareRequest(validStmt), async: false},
+		{req: NewExecutePreparedRequest(validStmt), async: false},
+		{req: NewBeginRequest(), async: false},
+		{req: NewCommitRequest(), async: false},
+		{req: NewRollbackRequest(), async: false},
+		{req: NewBroadcastRequest(validKey), async: false},
+	}
+
+	for _, test := range tests {
+		if async := test.req.Async(); async != test.async {
+			t.Errorf("Unexpected Async() %t, expected %t", async, test.async)
+		}
+	}
+}
+
 func TestPingRequestDefaultValues(t *testing.T) {
	var refBuf bytes.Buffer

@@ -649,3 +683,34 @@ func TestRollbackRequestDefaultValues(t *testing.T) {
	req := NewRollbackRequest()
	assertBodyEqual(t, refBuf.Bytes(), req)
 }
+
+func TestBroadcastRequestDefaultValues(t *testing.T) {
+	var refBuf bytes.Buffer
+
+	refEnc := NewEncoder(&refBuf)
+	expectedArgs := []interface{}{validKey}
+	err := RefImplEvalBody(refEnc, "box.broadcast(...)", expectedArgs)
+	if err != nil {
+		t.Errorf("An unexpected RefImplEvalBody() error: %q", err.Error())
+		return
+	}
+
+	req := NewBroadcastRequest(validKey)
+	assertBodyEqual(t, refBuf.Bytes(), req)
+}
+
+func TestBroadcastRequestSetters(t *testing.T) {
+	value := []interface{}{uint(34), int(12)}
+	var refBuf bytes.Buffer
+
+	refEnc := NewEncoder(&refBuf)
+	expectedArgs := []interface{}{validKey, value}
+	err := RefImplEvalBody(refEnc, "box.broadcast(...)", expectedArgs)
+	if err != nil {
+		t.Errorf("An unexpected RefImplEvalBody() error: %q", err.Error())
+		return
+	}
+
+	req := NewBroadcastRequest(validKey).Value(value)
+	assertBodyEqual(t, refBuf.Bytes(), req)
+}
diff --git a/tarantool_test.go b/tarantool_test.go
index 1350390f9..b12d6f9c2 100644
--- a/tarantool_test.go
+++ b/tarantool_test.go
@@ -2830,6 +2830,330 @@ func TestStream_DoWithClosedConn(t *testing.T) {
	}
 }

+func TestConnection_NewWatcher(t *testing.T) {
+	test_helpers.SkipIfWatchersUnsupported(t)
+
+	const key = "TestConnection_NewWatcher"
+	conn := test_helpers.ConnectWithValidation(t, server, opts)
+	defer conn.Close()
+
+	events := make(chan WatchEvent)
+	defer close(events)
+
+	watcher, err := conn.NewWatcher(key, func(event WatchEvent) {
+		events <- event
+	})
+	if err != nil {
+		t.Fatalf("Failed to create a watch: %s", err)
+	}
+	defer watcher.Unregister()
+
+	select {
+	case event := <-events:
+		if event.Conn != conn {
+			t.Errorf("Unexpected event connection: %v", event.Conn)
+		}
+		if event.Key != key {
+			t.Errorf("Unexpected event key: %s", event.Key)
+		}
+		if event.Value != nil {
+			t.Errorf("Unexpected event value: %v", event.Value)
+		}
+	case <-time.After(time.Second):
+		t.Fatalf("Failed to get watch event.")
+	}
+}
+
+func TestConnection_NewWatcher_reconnect(t *testing.T) {
+	test_helpers.SkipIfWatchersUnsupported(t)
+
+	const key = "TestConnection_NewWatcher_reconnect"
+	const server = "127.0.0.1:3014"
+
+	inst, err := test_helpers.StartTarantool(test_helpers.StartOpts{
+		InitScript:   "config.lua",
+		Listen:       server,
+		WorkDir:      "work_dir",
+		User:         opts.User,
+		Pass:         opts.Pass,
+		WaitStart:    100 * time.Millisecond,
+		ConnectRetry: 3,
+		RetryTimeout: 500 * time.Millisecond,
+	})
+	defer test_helpers.StopTarantoolWithCleanup(inst)
+	if err != nil {
+		t.Fatalf("Unable to start Tarantool: %s", err)
+	}
+
+	reconnectOpts := opts
+	reconnectOpts.Reconnect = 100 * time.Millisecond
+	reconnectOpts.MaxReconnects = 10
+	conn := test_helpers.ConnectWithValidation(t, server, reconnectOpts)
+	defer conn.Close()
+
+	events := make(chan WatchEvent)
+	defer close(events)
+	watcher, err := conn.NewWatcher(key, func(event WatchEvent) {
+		events <- event
+	})
+	if err != nil {
+		t.Fatalf("Failed to create a watch: %s", err)
+	}
+	defer watcher.Unregister()
+
+	<-events
+
+	test_helpers.StopTarantool(inst)
+	if err := test_helpers.RestartTarantool(&inst); err != nil {
+		t.Fatalf("Unable to restart Tarantool: %s", err)
+	}
+
+	maxTime := reconnectOpts.Reconnect * time.Duration(reconnectOpts.MaxReconnects)
+	select {
+	case <-events:
+	case <-time.After(maxTime):
+		t.Fatalf("Failed to get watch event.")
+	}
+}
+
+func TestBroadcastRequest(t *testing.T) {
+	test_helpers.SkipIfWatchersUnsupported(t)
+
+	const key = "TestBroadcastRequest"
+	const value = "bar"
+
+	conn := test_helpers.ConnectWithValidation(t, server, opts)
+	defer conn.Close()
+
+	resp, err := conn.Do(NewBroadcastRequest(key).Value(value)).Get()
+	if err != nil {
+		t.Fatalf("Got broadcast error: %s", err)
+	}
+	if resp.Code != OkCode {
+		t.Errorf("Got unexpected broadcast response code: %d", resp.Code)
+	}
+	if !reflect.DeepEqual(resp.Data, []interface{}{}) {
+		t.Errorf("Got unexpected broadcast response data: %v", resp.Data)
+	}
+
+	events := make(chan WatchEvent)
+	defer close(events)
+
+	watcher, err := conn.NewWatcher(key, func(event WatchEvent) {
+		events <- event
+	})
+	if err != nil {
+		t.Fatalf("Failed to create a watch: %s", err)
+	}
+	defer watcher.Unregister()
+
+	select {
+	case event := <-events:
+		if event.Conn != conn {
+			t.Errorf("Unexpected event connection: %v", event.Conn)
+		}
+		if event.Key != key {
+			t.Errorf("Unexpected event key: %s", event.Key)
+		}
+		if event.Value != value {
+			t.Errorf("Unexpected event value: %v", event.Value)
+		}
+	case <-time.After(time.Second):
+		t.Fatalf("Failed to get watch event.")
+	}
+}
+
+func TestBroadcastRequest_multi(t *testing.T) {
+	test_helpers.SkipIfWatchersUnsupported(t)
+
+	const key = "TestBroadcastRequest_multi"
+
+	conn := test_helpers.ConnectWithValidation(t, server, opts)
+	defer conn.Close()
+
+	events := make(chan WatchEvent)
+	defer close(events)
+
+	watcher, err := conn.NewWatcher(key, func(event WatchEvent) {
+		events <- event
+	})
+	if err != nil {
+		t.Fatalf("Failed to create a watch: %s", err)
+	}
+	defer watcher.Unregister()
+
+	<-events // Skip an initial event.
+	for i := 0; i < 10; i++ {
+		val := fmt.Sprintf("%d", i)
+		_, err := conn.Do(NewBroadcastRequest(key).Value(val)).Get()
+		if err != nil {
+			t.Fatalf("Failed to send a broadcast request: %s", err)
+		}
+		select {
+		case event := <-events:
+			if event.Conn != conn {
+				t.Errorf("Unexpected event connection: %v", event.Conn)
+			}
+			if event.Key != key {
+				t.Errorf("Unexpected event key: %s", event.Key)
+			}
+			if event.Value.(string) != val {
+				t.Errorf("Unexpected event value: %v", event.Value)
+			}
+		case <-time.After(time.Second):
+			t.Fatalf("Failed to get watch event %d", i)
+		}
+	}
+}
+
+func TestConnection_NewWatcher_multiOnKey(t *testing.T) {
+	test_helpers.SkipIfWatchersUnsupported(t)
+
+	const key = "TestConnection_NewWatcher_multiOnKey"
+	const value = "bar"
+
+	conn := test_helpers.ConnectWithValidation(t, server, opts)
+	defer conn.Close()
+
+	events := []chan WatchEvent{
+		make(chan WatchEvent),
+		make(chan WatchEvent),
+	}
+	for _, ch := range events {
+		defer close(ch)
+	}
+
+	for _, ch := range events {
+		channel := ch
+		watcher, err := conn.NewWatcher(key, func(event WatchEvent) {
+			channel <- event
+		})
+		if err != nil {
+			t.Fatalf("Failed to create a watch: %s", err)
+		}
+		defer watcher.Unregister()
+	}
+
+	for i, ch := range events {
+		select {
+		case <-ch: // Skip an initial event.
+		case <-time.After(2 * time.Second):
+			t.Fatalf("Failed to skip watch event for %d callback", i)
+		}
+	}
+
+	_, err := conn.Do(NewBroadcastRequest(key).Value(value)).Get()
+	if err != nil {
+		t.Fatalf("Failed to send a broadcast request: %s", err)
+	}
+
+	for i, ch := range events {
+		select {
+		case event := <-ch:
+			if event.Conn != conn {
+				t.Errorf("Unexpected event connection: %v", event.Conn)
+			}
+			if event.Key != key {
+				t.Errorf("Unexpected event key: %s", event.Key)
+			}
+			if event.Value.(string) != value {
+				t.Errorf("Unexpected event value: %v", event.Value)
+			}
+		case <-time.After(2 * time.Second):
+			t.Fatalf("Failed to get watch event from callback %d", i)
+		}
+	}
+}
+
+func TestWatcher_Unregister(t *testing.T) {
+	test_helpers.SkipIfWatchersUnsupported(t)
+
+	const key = "TestWatcher_Unregister"
+	const value = "bar"
+
+	conn := test_helpers.ConnectWithValidation(t, server, opts)
+	defer conn.Close()
+
+	events := make(chan WatchEvent)
+	defer close(events)
+	watcher, err := conn.NewWatcher(key, func(event WatchEvent) {
+		events <- event
+	})
+	if err != nil {
+		t.Fatalf("Failed to create a watch: %s", err)
+	}
+
+	<-events
+	watcher.Unregister()
+
+	_, err = conn.Do(NewBroadcastRequest(key).Value(value)).Get()
+	if err != nil {
+		t.Fatalf("Got broadcast error: %s", err)
+	}
+
+	select {
+	case event := <-events:
+		t.Fatalf("Got unexpected events: %v", event)
+	case <-time.After(time.Second):
+	}
+}
+
+func TestConnection_NewWatcher_concurrent(t *testing.T) {
+	test_helpers.SkipIfWatchersUnsupported(t)
+
+	const testConcurrency = 1000
+	const key = "TestConnection_NewWatcher_concurrent"
+	conn := test_helpers.ConnectWithValidation(t, server, opts)
+	defer conn.Close()
+
+	var wg sync.WaitGroup
+	wg.Add(testConcurrency)
+
+	var ret error
+	for i := 0; i < testConcurrency; i++ {
+		go func() {
+			defer wg.Done()
+
+			watcher, err := conn.NewWatcher(key, func(event WatchEvent) {})
+			if err != nil {
+				ret = err
+			} else {
+				watcher.Unregister()
+			}
+		}()
+	}
+	wg.Wait()
+
+	if ret != nil {
+		t.Fatalf("Unable to create a watcher: %s", ret)
+	}
+}
+
+func TestWatcher_Unregister_concurrent(t *testing.T) {
+	test_helpers.SkipIfWatchersUnsupported(t)
+
+	const testConcurrency = 1000
+	const key = "TestWatcher_Unregister_concurrent"
+	conn := test_helpers.ConnectWithValidation(t, server, opts)
+	defer conn.Close()
+
+	watcher, err := conn.NewWatcher(key, func(event WatchEvent) {})
+	if err != nil {
+		t.Fatalf("Failed to create a watch: %s", err)
+	}
+
+	var wg sync.WaitGroup
+	wg.Add(testConcurrency)
+
+	for i := 0; i < testConcurrency; i++ {
+		go func() {
+			defer wg.Done()
+			watcher.Unregister()
+		}()
+	}
+	wg.Wait()
+}
+
 // runTestMain is a body of TestMain function
 // (see https://pkg.go.dev/testing#hdr-Main).
 // Using defer + os.Exit is not works so TestMain body
diff --git a/test_helpers/request_mock.go b/test_helpers/request_mock.go
index 93551e34a..19c18545e 100644
--- a/test_helpers/request_mock.go
+++ b/test_helpers/request_mock.go
@@ -17,6 +17,10 @@ func (sr *StrangerRequest) Code() int32 {
	return 0
 }

+func (sr *StrangerRequest) Async() bool {
+	return false
+}
+
 func (sr *StrangerRequest) Body(resolver tarantool.SchemaResolver, enc *encoder) error {
	return nil
 }
diff --git a/test_helpers/utils.go b/test_helpers/utils.go
index c936e90b3..4672a7de8 100644
--- a/test_helpers/utils.go
+++ b/test_helpers/utils.go
@@ -53,10 +53,9 @@ func SkipIfSQLUnsupported(t testing.TB) {
	}
 }

-func SkipIfStreamsUnsupported(t *testing.T) {
+func skipIfLess2_10(t *testing.T) {
	t.Helper()

-	// Tarantool supports streams and interactive transactions since version 2.10.0
	isLess, err := IsTarantoolVersionLess(2, 10, 0)
	if err != nil {
		t.Fatalf("Could not check the Tarantool version")
@@ -66,3 +65,15 @@ func SkipIfStreamsUnsupported(t *testing.T) {
		t.Skip("Skipping test for Tarantool without streams support")
	}
 }
+
+func SkipIfStreamsUnsupported(t *testing.T) {
+	t.Helper()
+
+	skipIfLess2_10(t)
+}
+
+func SkipIfWatchersUnsupported(t *testing.T) {
+	t.Helper()
+
+	skipIfLess2_10(t)
+}
diff --git a/watch.go b/watch.go
new file mode 100644
index 000000000..2bd91b4bf
--- /dev/null
+++ b/watch.go
@@ -0,0 +1,138 @@
+package tarantool
+
+import (
+	"context"
+)
+
+// BroadcastRequest helps to send broadcast messages. See:
+// https://www.tarantool.io/en/doc/latest/reference/reference_lua/box_events/broadcast/
+type BroadcastRequest struct {
+	eval *EvalRequest
+	key  string
+}
+
+// NewBroadcastRequest returns a new broadcast request for a specified key.
+func NewBroadcastRequest(key string) *BroadcastRequest {
+	req := new(BroadcastRequest)
+	req.key = key
+	req.eval = NewEvalRequest("box.broadcast(...)").Args([]interface{}{key})
+	return req
+}
+
+// Value sets the value for the broadcast request.
+// Note: the default value is nil.
+func (req *BroadcastRequest) Value(value interface{}) *BroadcastRequest {
+	req.eval = req.eval.Args([]interface{}{req.key, value})
+	return req
+}
+
+// Context sets a passed context to the broadcast request.
+func (req *BroadcastRequest) Context(ctx context.Context) *BroadcastRequest {
+	req.eval = req.eval.Context(ctx)
+	return req
+}
+
+// Code returns the IPROTO code for the broadcast request.
+func (req *BroadcastRequest) Code() int32 {
+	return req.eval.Code()
+}
+
+// Body fills an encoder with the broadcast request body.
+func (req *BroadcastRequest) Body(res SchemaResolver, enc *encoder) error {
+	return req.eval.Body(res, enc)
+}
+
+// Ctx returns a context of the broadcast request.
+func (req *BroadcastRequest) Ctx() context.Context {
+	return req.eval.Ctx()
+}
+
+// Async returns true if the broadcast request does not expect a response.
+func (req *BroadcastRequest) Async() bool {
+	return req.eval.Async()
+}
+
+// watchRequest subscribes to the updates of a specified key defined on the
+// server. After receiving the notification, you should send a new
+// watchRequest to acknowledge the notification.
+type watchRequest struct {
+	baseRequest
+	key string
+	ctx context.Context
+}
+
+// newWatchRequest returns a new watchRequest.
+func newWatchRequest(key string) *watchRequest {
+	req := new(watchRequest)
+	req.requestCode = WatchRequestCode
+	req.async = true
+	req.key = key
+	return req
+}
+
+// Body fills an encoder with the watch request body.
+func (req *watchRequest) Body(res SchemaResolver, enc *encoder) error {
+	if err := enc.EncodeMapLen(1); err != nil {
+		return err
+	}
+	if err := encodeUint(enc, KeyEvent); err != nil {
+		return err
+	}
+	return enc.EncodeString(req.key)
+}
+
+// Context sets a passed context to the request.
+func (req *watchRequest) Context(ctx context.Context) *watchRequest {
+	req.ctx = ctx
+	return req
+}
+
+// unwatchRequest unregisters a watcher subscribed to the given notification
+// key.
+type unwatchRequest struct {
+	baseRequest
+	key string
+	ctx context.Context
+}
+
+// newUnwatchRequest returns a new unwatchRequest.
+func newUnwatchRequest(key string) *unwatchRequest {
+	req := new(unwatchRequest)
+	req.requestCode = UnwatchRequestCode
+	req.async = true
+	req.key = key
+	return req
+}
+
+// Body fills an encoder with the unwatch request body.
+func (req *unwatchRequest) Body(res SchemaResolver, enc *encoder) error {
+	if err := enc.EncodeMapLen(1); err != nil {
+		return err
+	}
+	if err := encodeUint(enc, KeyEvent); err != nil {
+		return err
+	}
+	return enc.EncodeString(req.key)
+}
+
+// Context sets a passed context to the request.
+func (req *unwatchRequest) Context(ctx context.Context) *unwatchRequest {
+	req.ctx = ctx
+	return req
+}
+
+// WatchEvent is a watch notification event received from a server.
+type WatchEvent struct {
+	Conn  *Connection // A source connection.
+	Key   string      // A key.
+	Value interface{} // A value.
+}
+
+// Watcher is a subscription to broadcast events.
+type Watcher interface {
+	// Unregister unregisters the watcher.
+	Unregister()
+}
+
+// WatchCallback is a callback to invoke when the key value is updated.
+type WatchCallback func(event WatchEvent)
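
A minimal usage sketch of the watcher API added above, mirroring the patterns in tarantool_test.go; the key name and log calls are illustrative and assume a connected *tarantool.Connection named conn:

	watcher, err := conn.NewWatcher("foo", func(event tarantool.WatchEvent) {
		// The callback runs in its own goroutine and is never executed
		// in parallel with itself.
		log.Printf("key %q updated: %v", event.Key, event.Value)
	})
	if err != nil {
		log.Fatalf("failed to register a watcher: %s", err)
	}
	// Garbage collection of the handle does not unregister the watcher:
	// Unregister() must be called explicitly (and not from the callback,
	// which would deadlock).
	defer watcher.Unregister()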
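
And the matching broadcast side: BroadcastRequest wraps box.broadcast(), so updating a key is an ordinary request; the key and value are illustrative and assume the same conn as above:

	if _, err := conn.Do(tarantool.NewBroadcastRequest("foo").Value("bar")).Get(); err != nil {
		log.Printf("broadcast failed: %s", err)
	}

Every watcher registered for "foo" on this and other connections is then invoked with the new value.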