diff --git a/client/client.go b/client/client.go index 747d958f0..5ab3e10e6 100644 --- a/client/client.go +++ b/client/client.go @@ -444,6 +444,26 @@ func (c *Client) Subscribe( return attachment.rch, attachment.closeWatchStream, nil } +// Watch watches events on a given document. It is not necessary to be called +// outside of this package, but it is exposed for testing purposes. +func (c *Client) Watch(ctx context.Context, doc *document.Document) ( + *connect.ServerStreamForClient[api.WatchDocumentResponse], + error, +) { + attachment, ok := c.attachments[doc.Key()] + if !ok { + return nil, ErrDocumentNotAttached + } + + return c.client.WatchDocument( + ctx, + withShardKey(connect.NewRequest(&api.WatchDocumentRequest{ + ClientId: c.id.String(), + DocumentId: attachment.docID.String(), + }), c.options.APIKey, doc.Key().String()), + ) +} + // runWatchLoop subscribes to events on a given documentIDs. // If an error occurs before stream initialization, the second response, error, // is returned. If the context "watchCtx" is canceled or timed out, returned channel @@ -458,13 +478,7 @@ func (c *Client) runWatchLoop( return ErrDocumentNotAttached } - stream, err := c.client.WatchDocument( - ctx, - withShardKey(connect.NewRequest(&api.WatchDocumentRequest{ - ClientId: c.id.String(), - DocumentId: attachment.docID.String(), - }, - ), c.options.APIKey, doc.Key().String())) + stream, err := c.Watch(ctx, doc) if err != nil { return err } diff --git a/server/backend/database/client_info.go b/server/backend/database/client_info.go index 0c9e4f3e8..eed6ecee1 100644 --- a/server/backend/database/client_info.go +++ b/server/backend/database/client_info.go @@ -193,6 +193,15 @@ func (i *ClientInfo) UpdateCheckpoint( return nil } +// EnsureActivated ensures the client is activated. +func (i *ClientInfo) EnsureActivated() error { + if i.Status != ClientActivated { + return fmt.Errorf("ensure activated client(%s): %w", i.ID, ErrClientNotActivated) + } + + return nil +} + // EnsureDocumentAttached ensures the given document is attached. func (i *ClientInfo) EnsureDocumentAttached(docID types.ID) error { if i.Status != ClientActivated { diff --git a/server/clients/clients.go b/server/clients/clients.go index 9260147a9..817ea52f0 100644 --- a/server/clients/clients.go +++ b/server/clients/clients.go @@ -88,11 +88,20 @@ func Deactivate( return db.DeactivateClient(ctx, refKey) } -// FindClientInfo finds the client with the given refKey. -func FindClientInfo( +// FindActiveClientInfo find the active client info by the given ref key. +func FindActiveClientInfo( ctx context.Context, db database.Database, refKey types.ClientRefKey, ) (*database.ClientInfo, error) { - return db.FindClientInfoByRefKey(ctx, refKey) + info, err := db.FindClientInfoByRefKey(ctx, refKey) + if err != nil { + return nil, err + } + + if err := info.EnsureActivated(); err != nil { + return nil, err + } + + return info, nil } diff --git a/server/rpc/yorkie_server.go b/server/rpc/yorkie_server.go index c2a47310e..10b562e21 100644 --- a/server/rpc/yorkie_server.go +++ b/server/rpc/yorkie_server.go @@ -144,7 +144,7 @@ func (s *yorkieServer) AttachDocument( } }() - clientInfo, err := clients.FindClientInfo(ctx, s.backend.DB, types.ClientRefKey{ + clientInfo, err := clients.FindActiveClientInfo(ctx, s.backend.DB, types.ClientRefKey{ ProjectID: project.ID, ClientID: types.IDFromActorID(actorID), }) @@ -217,7 +217,7 @@ func (s *yorkieServer) DetachDocument( } }() - clientInfo, err := clients.FindClientInfo(ctx, s.backend.DB, types.ClientRefKey{ + clientInfo, err := clients.FindActiveClientInfo(ctx, s.backend.DB, types.ClientRefKey{ ProjectID: project.ID, ClientID: types.IDFromActorID(actorID), }) @@ -321,7 +321,7 @@ func (s *yorkieServer) PushPullChanges( syncMode = types.SyncModePushOnly } - clientInfo, err := clients.FindClientInfo(ctx, s.backend.DB, types.ClientRefKey{ + clientInfo, err := clients.FindActiveClientInfo(ctx, s.backend.DB, types.ClientRefKey{ ProjectID: project.ID, ClientID: types.IDFromActorID(actorID), }) @@ -395,7 +395,7 @@ func (s *yorkieServer) WatchDocument( return err } - if _, err = clients.FindClientInfo(ctx, s.backend.DB, types.ClientRefKey{ + if _, err = clients.FindActiveClientInfo(ctx, s.backend.DB, types.ClientRefKey{ ProjectID: project.ID, ClientID: types.IDFromActorID(clientID), }); err != nil { @@ -514,7 +514,7 @@ func (s *yorkieServer) RemoveDocument( }() } - clientInfo, err := clients.FindClientInfo(ctx, s.backend.DB, types.ClientRefKey{ + clientInfo, err := clients.FindActiveClientInfo(ctx, s.backend.DB, types.ClientRefKey{ ProjectID: project.ID, ClientID: types.IDFromActorID(actorID), }) @@ -627,7 +627,7 @@ func (s *yorkieServer) Broadcast( return nil, err } - if _, err = clients.FindClientInfo(ctx, s.backend.DB, types.ClientRefKey{ + if _, err = clients.FindActiveClientInfo(ctx, s.backend.DB, types.ClientRefKey{ ProjectID: project.ID, ClientID: types.IDFromActorID(clientID), }); err != nil { diff --git a/server/server.go b/server/server.go index 6b5280094..764f3fe52 100644 --- a/server/server.go +++ b/server/server.go @@ -20,11 +20,16 @@ package server import ( + "context" gosync "sync" + "github.com/yorkie-team/yorkie/api/types" + "github.com/yorkie-team/yorkie/client" "github.com/yorkie-team/yorkie/server/backend" + "github.com/yorkie-team/yorkie/server/clients" "github.com/yorkie-team/yorkie/server/profiling" "github.com/yorkie-team/yorkie/server/profiling/prometheus" + "github.com/yorkie-team/yorkie/server/projects" "github.com/yorkie-team/yorkie/server/rpc" ) @@ -128,3 +133,17 @@ func (r *Yorkie) ShutdownCh() <-chan struct{} { func (r *Yorkie) RPCAddr() string { return r.conf.RPCAddr() } + +// DeactivateClient deactivates the given client. It is used for testing. +func (r *Yorkie) DeactivateClient(ctx context.Context, c1 *client.Client) error { + project, err := projects.GetProjectFromAPIKey(ctx, r.backend, "") + if err != nil { + return err + } + + _, err = clients.Deactivate(ctx, r.backend.DB, types.ClientRefKey{ + ProjectID: project.ID, + ClientID: types.IDFromActorID(c1.ID()), + }) + return err +} diff --git a/test/integration/client_test.go b/test/integration/client_test.go index e37fe4413..3e9593d7e 100644 --- a/test/integration/client_test.go +++ b/test/integration/client_test.go @@ -20,8 +20,10 @@ package integration import ( "context" + "sync" "testing" + "connectrpc.com/connect" "github.com/stretchr/testify/assert" "github.com/yorkie-team/yorkie/client" @@ -172,4 +174,36 @@ func TestClient(t *testing.T) { assert.Equal(t, doc.Checkpoint(), change.Checkpoint{ClientSeq: 4, ServerSeq: 4}) assert.Equal(t, "2", doc.Root().GetCounter("counter").Marshal()) }) + + t.Run("deactivated client's stream test", func(t *testing.T) { + ctx := context.Background() + + c1, err := client.Dial(defaultServer.RPCAddr()) + assert.NoError(t, err) + assert.NoError(t, c1.Activate(ctx)) + + d1 := document.New(helper.TestDocKey(t)) + + // 01. Attach the document and subscribe. + assert.NoError(t, c1.Attach(ctx, d1)) + + // 02. Deactivate the client and try to watch. + assert.NoError(t, defaultServer.DeactivateClient(ctx, c1)) + + wg := sync.WaitGroup{} + wg.Add(1) + stream, _ := c1.Watch(ctx, d1) + + go func() { + defer wg.Done() + + stream.Receive() + if err = stream.Err(); err != nil { + assert.Equal(t, connect.CodeFailedPrecondition, connect.CodeOf(err)) + return + } + }() + + wg.Wait() + }) }