From 0c8670f12532fd9233e3c73728480d2efd68f5aa Mon Sep 17 00:00:00 2001 From: SimFG Date: Fri, 2 Aug 2024 14:16:08 +0800 Subject: [PATCH] Add the privilege and db name param for the operate privilege api (#795) Signed-off-by: SimFG --- client/client.go | 4 ++-- client/rbac.go | 22 ++++++++++++++++++++-- client/rbac_test.go | 25 ++++++++++++++++--------- entity/rbac.go | 19 +++++++++++++++++++ 4 files changed, 57 insertions(+), 13 deletions(-) diff --git a/client/client.go b/client/client.go index 1728926e..472c91bb 100644 --- a/client/client.go +++ b/client/client.go @@ -202,9 +202,9 @@ type Client interface { // ListGrants lists all assigned privileges and objects for the role. ListGrants(ctx context.Context, role string, dbName string) ([]entity.RoleGrants, error) // Grant adds privilege for role. - Grant(ctx context.Context, role string, objectType entity.PriviledgeObjectType, object string, privilege string) error + Grant(ctx context.Context, role string, objectType entity.PriviledgeObjectType, object string, privilege string, options ...entity.OperatePrivilegeOption) error // Revoke removes privilege from role. - Revoke(ctx context.Context, role string, objectType entity.PriviledgeObjectType, object string) error + Revoke(ctx context.Context, role string, objectType entity.PriviledgeObjectType, object string, privilege string, options ...entity.OperatePrivilegeOption) error // GetLoadingProgress get the collection or partitions loading progress GetLoadingProgress(ctx context.Context, collectionName string, partitionNames []string) (int64, error) diff --git a/client/rbac.go b/client/rbac.go index a94bfc05..8b71eb3a 100644 --- a/client/rbac.go +++ b/client/rbac.go @@ -320,12 +320,18 @@ func (c *GrpcClient) ListGrant(ctx context.Context, role string, object string, } // Grant adds object privileged for role. -func (c *GrpcClient) Grant(ctx context.Context, role string, objectType entity.PriviledgeObjectType, object string, privilege string) error { +func (c *GrpcClient) Grant(ctx context.Context, role string, objectType entity.PriviledgeObjectType, object string, privilege string, options ...entity.OperatePrivilegeOption) error { if c.Service == nil { return ErrClientNotReady } + grantOpt := &entity.OperatePrivilegeOpt{} + for _, opt := range options { + opt(grantOpt) + } + req := &milvuspb.OperatePrivilegeRequest{ + Base: grantOpt.Base, Entity: &milvuspb.GrantEntity{ Role: &milvuspb.RoleEntity{ Name: role, @@ -339,6 +345,7 @@ func (c *GrpcClient) Grant(ctx context.Context, role string, objectType entity.P }, }, ObjectName: object, + DbName: grantOpt.Database, }, Type: milvuspb.OperatePrivilegeType_Grant, } @@ -352,12 +359,17 @@ func (c *GrpcClient) Grant(ctx context.Context, role string, objectType entity.P } // Revoke removes privilege from role. -func (c *GrpcClient) Revoke(ctx context.Context, role string, objectType entity.PriviledgeObjectType, object string) error { +func (c *GrpcClient) Revoke(ctx context.Context, role string, objectType entity.PriviledgeObjectType, object string, privilege string, options ...entity.OperatePrivilegeOption) error { if c.Service == nil { return ErrClientNotReady } + revokeOpt := &entity.OperatePrivilegeOpt{} + for _, opt := range options { + opt(revokeOpt) + } req := &milvuspb.OperatePrivilegeRequest{ + Base: revokeOpt.Base, Entity: &milvuspb.GrantEntity{ Role: &milvuspb.RoleEntity{ Name: role, @@ -366,6 +378,12 @@ func (c *GrpcClient) Revoke(ctx context.Context, role string, objectType entity. Name: commonpb.ObjectType_name[int32(objectType)], }, ObjectName: object, + Grantor: &milvuspb.GrantorEntity{ + Privilege: &milvuspb.PrivilegeEntity{ + Name: privilege, + }, + }, + DbName: revokeOpt.Database, }, Type: milvuspb.OperatePrivilegeType_Revoke, } diff --git a/client/rbac_test.go b/client/rbac_test.go index 2e6673d9..666406f4 100644 --- a/client/rbac_test.go +++ b/client/rbac_test.go @@ -629,7 +629,8 @@ func (s *RBACSuite) TestGrant() { roleName := "testRole" objectName := testCollectionName objectType := entity.PriviledegeObjectTypeCollection - privilegeName := "testPrivilege" + dbName := "testDB" + privilege := "testPrivilege" s.Run("normal run", func() { ctx, cancel := context.WithCancel(ctx) @@ -640,9 +641,11 @@ func (s *RBACSuite) TestGrant() { s.Equal(objectName, req.GetEntity().GetObjectName()) s.Equal(commonpb.ObjectType_name[int32(objectType)], req.GetEntity().GetObject().GetName()) s.Equal(milvuspb.OperatePrivilegeType_Grant, req.GetType()) + s.Equal(privilege, req.GetEntity().GetGrantor().GetPrivilege().GetName()) + s.Equal(dbName, req.GetEntity().GetDbName()) }).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil) - err := s.client.Grant(ctx, roleName, objectType, objectName, privilegeName) + err := s.client.Grant(ctx, roleName, objectType, objectName, privilege, entity.WithOperatePrivilegeDatabase(dbName)) s.NoError(err) }) @@ -653,7 +656,7 @@ func (s *RBACSuite) TestGrant() { defer s.resetMock() s.mock.EXPECT().OperatePrivilege(mock.Anything, mock.Anything).Return(nil, errors.New("mock error")) - err := s.client.Grant(ctx, roleName, objectType, objectName, privilegeName) + err := s.client.Grant(ctx, roleName, objectType, objectName, privilege, entity.WithOperatePrivilegeDatabase(dbName)) s.Error(err) }) @@ -663,7 +666,7 @@ func (s *RBACSuite) TestGrant() { defer s.resetMock() s.mock.EXPECT().OperatePrivilege(mock.Anything, mock.Anything).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, nil) - err := s.client.Grant(ctx, roleName, objectType, objectName, privilegeName) + err := s.client.Grant(ctx, roleName, objectType, objectName, privilege, entity.WithOperatePrivilegeDatabase(dbName)) s.Error(err) }) @@ -672,7 +675,7 @@ func (s *RBACSuite) TestGrant() { defer cancel() c := &GrpcClient{} - err := c.Grant(ctx, roleName, objectType, objectName, privilegeName) + err := c.Grant(ctx, roleName, objectType, objectName, privilege, entity.WithOperatePrivilegeDatabase(dbName)) s.Error(err) s.ErrorIs(err, ErrClientNotReady) }) @@ -684,6 +687,8 @@ func (s *RBACSuite) TestRevoke() { roleName := "testRole" objectName := testCollectionName objectType := entity.PriviledegeObjectTypeCollection + dbName := "testDB" + privilege := "testPrivilege" s.Run("normal run", func() { ctx, cancel := context.WithCancel(ctx) @@ -694,9 +699,11 @@ func (s *RBACSuite) TestRevoke() { s.Equal(objectName, req.GetEntity().GetObjectName()) s.Equal(commonpb.ObjectType_name[int32(objectType)], req.GetEntity().GetObject().GetName()) s.Equal(milvuspb.OperatePrivilegeType_Revoke, req.GetType()) + s.Equal(privilege, req.GetEntity().GetGrantor().GetPrivilege().GetName()) + s.Equal(dbName, req.GetEntity().GetDbName()) }).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil) - err := s.client.Revoke(ctx, roleName, objectType, objectName) + err := s.client.Revoke(ctx, roleName, objectType, objectName, privilege, entity.WithOperatePrivilegeDatabase(dbName)) s.NoError(err) }) @@ -707,7 +714,7 @@ func (s *RBACSuite) TestRevoke() { defer s.resetMock() s.mock.EXPECT().OperatePrivilege(mock.Anything, mock.Anything).Return(nil, errors.New("mock error")) - err := s.client.Revoke(ctx, roleName, objectType, objectName) + err := s.client.Revoke(ctx, roleName, objectType, objectName, privilege, entity.WithOperatePrivilegeDatabase(dbName)) s.Error(err) }) @@ -717,7 +724,7 @@ func (s *RBACSuite) TestRevoke() { defer s.resetMock() s.mock.EXPECT().OperatePrivilege(mock.Anything, mock.Anything).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, nil) - err := s.client.Revoke(ctx, roleName, objectType, objectName) + err := s.client.Revoke(ctx, roleName, objectType, objectName, privilege, entity.WithOperatePrivilegeDatabase(dbName)) s.Error(err) }) @@ -727,7 +734,7 @@ func (s *RBACSuite) TestRevoke() { defer s.resetMock() c := &GrpcClient{} - err := c.Revoke(ctx, roleName, objectType, objectName) + err := c.Revoke(ctx, roleName, objectType, objectName, privilege, entity.WithOperatePrivilegeDatabase(dbName)) s.Error(err) s.ErrorIs(err, ErrClientNotReady) }) diff --git a/entity/rbac.go b/entity/rbac.go index ef4e97d9..f215f288 100644 --- a/entity/rbac.go +++ b/entity/rbac.go @@ -42,3 +42,22 @@ const ( // PriviledegeObjectTypeGlobal const value for Global. PriviledegeObjectTypeGlobal PriviledgeObjectType = PriviledgeObjectType(common.ObjectType_Global) ) + +type OperatePrivilegeOpt struct { + Base *common.MsgBase + Database string +} + +type OperatePrivilegeOption func(o *OperatePrivilegeOpt) + +func WithOperatePrivilegeBase(base *common.MsgBase) OperatePrivilegeOption { + return func(o *OperatePrivilegeOpt) { + o.Base = base + } +} + +func WithOperatePrivilegeDatabase(database string) OperatePrivilegeOption { + return func(o *OperatePrivilegeOpt) { + o.Database = database + } +}