diff --git a/client/v3/watch.go b/client/v3/watch.go index b73925ba128..5bd2d4cd0cd 100644 --- a/client/v3/watch.go +++ b/client/v3/watch.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "strings" "sync" "time" @@ -580,6 +581,26 @@ func (w *watchGrpcStream) run() { switch { case pbresp.Created: + cancelReasonError := v3rpc.Error(errors.New(pbresp.CancelReason)) + if shouldRetryWatch(cancelReasonError) { + var newErr error + if wc, newErr = w.newWatchClient(); newErr != nil { + w.lg.Error("failed to create a new watch client", zap.Error(newErr)) + return + } + + if len(w.resuming) != 0 { + if ws := w.resuming[0]; ws != nil { + if err := wc.Send(ws.initReq.toPB()); err != nil { + w.lg.Debug("error when sending request", zap.Error(err)) + } + } + } + + cur = nil + continue + } + // response to head of queue creation if len(w.resuming) != 0 { if ws := w.resuming[0]; ws != nil { @@ -688,6 +709,11 @@ func (w *watchGrpcStream) run() { } } +func shouldRetryWatch(cancelReasonError error) bool { + return (strings.Compare(cancelReasonError.Error(), v3rpc.ErrGRPCInvalidAuthToken.Error()) == 0) || + (strings.Compare(cancelReasonError.Error(), v3rpc.ErrGRPCAuthOldRevision.Error()) == 0) +} + // nextResume chooses the next resuming to register with the grpc stream. Abandoned // streams are marked as nil in the queue since the head must wait for its inflight registration. func (w *watchGrpcStream) nextResume() *watcherStream { diff --git a/server/etcdserver/api/v3rpc/watch.go b/server/etcdserver/api/v3rpc/watch.go index 1a3cff539f6..c206f7002e9 100644 --- a/server/etcdserver/api/v3rpc/watch.go +++ b/server/etcdserver/api/v3rpc/watch.go @@ -224,16 +224,16 @@ func (ws *watchServer) Watch(stream pb.Watch_WatchServer) (err error) { return err } -func (sws *serverWatchStream) isWatchPermitted(wcr *pb.WatchCreateRequest) bool { +func (sws *serverWatchStream) isWatchPermitted(wcr *pb.WatchCreateRequest) error { authInfo, err := sws.ag.AuthInfoFromCtx(sws.gRPCStream.Context()) if err != nil { - return false + return err } if authInfo == nil { // if auth is enabled, IsRangePermitted() can cause an error authInfo = &auth.AuthInfo{} } - return sws.ag.AuthStore().IsRangePermitted(authInfo, wcr.Key, wcr.RangeEnd) == nil + return sws.ag.AuthStore().IsRangePermitted(authInfo, wcr.Key, wcr.RangeEnd) } func (sws *serverWatchStream) recvLoop() error { @@ -267,13 +267,29 @@ func (sws *serverWatchStream) recvLoop() error { creq.RangeEnd = []byte{} } - if !sws.isWatchPermitted(creq) { + err := sws.isWatchPermitted(creq) + if err != nil { + var cancelReason string + switch err { + case auth.ErrInvalidAuthToken: + cancelReason = rpctypes.ErrGRPCInvalidAuthToken.Error() + case auth.ErrAuthOldRevision: + cancelReason = rpctypes.ErrGRPCAuthOldRevision.Error() + case auth.ErrUserEmpty: + cancelReason = rpctypes.ErrGRPCUserEmpty.Error() + default: + if err != auth.ErrPermissionDenied { + sws.lg.Error("unexpected error code", zap.Error(err)) + } + cancelReason = rpctypes.ErrGRPCPermissionDenied.Error() + } + wr := &pb.WatchResponse{ Header: sws.newResponseHeader(sws.watchStream.Rev()), WatchId: creq.WatchId, Canceled: true, Created: true, - CancelReason: rpctypes.ErrGRPCPermissionDenied.Error(), + CancelReason: cancelReason, } select { diff --git a/tests/framework/integration/cluster.go b/tests/framework/integration/cluster.go index 82d4a7b41bb..a9b0a098f49 100644 --- a/tests/framework/integration/cluster.go +++ b/tests/framework/integration/cluster.go @@ -139,7 +139,8 @@ type ClusterConfig struct { DiscoveryURL string - AuthToken string + AuthToken string + AuthTokenTTL uint QuotaBackendBytes int64 @@ -263,6 +264,7 @@ func (c *Cluster) mustNewMember(t testutil.TB) *Member { Name: fmt.Sprintf("m%v", memberNumber), MemberNumber: memberNumber, AuthToken: c.Cfg.AuthToken, + AuthTokenTTL: c.Cfg.AuthTokenTTL, PeerTLS: c.Cfg.PeerTLS, ClientTLS: c.Cfg.ClientTLS, QuotaBackendBytes: c.Cfg.QuotaBackendBytes, @@ -586,6 +588,7 @@ type MemberConfig struct { PeerTLS *transport.TLSInfo ClientTLS *transport.TLSInfo AuthToken string + AuthTokenTTL uint QuotaBackendBytes int64 MaxTxnOps uint MaxRequestBytes uint @@ -679,6 +682,9 @@ func MustNewMember(t testutil.TB, mcfg MemberConfig) *Member { if mcfg.AuthToken != "" { m.AuthToken = mcfg.AuthToken } + if mcfg.AuthTokenTTL != 0 { + m.TokenTTL = mcfg.AuthTokenTTL + } m.BcryptCost = uint(bcrypt.MinCost) // use min bcrypt cost to speedy up integration testing diff --git a/tests/integration/v3_auth_test.go b/tests/integration/v3_auth_test.go index fd300ba78ef..2169b6f36f2 100644 --- a/tests/integration/v3_auth_test.go +++ b/tests/integration/v3_auth_test.go @@ -497,3 +497,36 @@ func TestV3AuthRestartMember(t *testing.T) { _, err = c2.Put(context.TODO(), "foo", "bar2") testutil.AssertNil(t, err) } + +func TestV3AuthWatchAndTokenExpire(t *testing.T) { + integration.BeforeTest(t) + clus := integration.NewCluster(t, &integration.ClusterConfig{Size: 1, AuthTokenTTL: 3}) + defer clus.Terminate(t) + + ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second) + defer cancel() + + authSetupRoot(t, integration.ToGRPC(clus.Client(0)).Auth) + + c, cerr := integration.NewClient(t, clientv3.Config{Endpoints: clus.Client(0).Endpoints(), Username: "root", Password: "123"}) + if cerr != nil { + t.Fatal(cerr) + } + defer c.Close() + + _, err := c.Put(ctx, "key", "val") + if err != nil { + t.Fatalf("Unexpected error from Put: %v", err) + } + + // The first watch gets a valid auth token through watcher.newWatcherGrpcStream() + // We should discard the first one by waiting TTL after the first watch. + wChan := c.Watch(ctx, "key", clientv3.WithRev(1)) + watchResponse := <-wChan + + time.Sleep(5 * time.Second) + + wChan = c.Watch(ctx, "key", clientv3.WithRev(1)) + watchResponse = <-wChan + testutil.AssertNil(t, watchResponse.Err()) +}