Skip to content

Commit

Permalink
Add the privilege and db name param for the operate privilege api (mi…
Browse files Browse the repository at this point in the history
…lvus-io#795)

Signed-off-by: SimFG <bang.fu@zilliz.com>
  • Loading branch information
SimFG committed Aug 2, 2024
1 parent 78139c2 commit 0c8670f
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 13 deletions.
4 changes: 2 additions & 2 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,9 @@ type Client interface {
// ListGrants lists all assigned privileges and objects for the role.
ListGrants(ctx context.Context, role string, dbName string) ([]entity.RoleGrants, error)
// Grant adds privilege for role.
Grant(ctx context.Context, role string, objectType entity.PriviledgeObjectType, object string, privilege string) error
Grant(ctx context.Context, role string, objectType entity.PriviledgeObjectType, object string, privilege string, options ...entity.OperatePrivilegeOption) error
// Revoke removes privilege from role.
Revoke(ctx context.Context, role string, objectType entity.PriviledgeObjectType, object string) error
Revoke(ctx context.Context, role string, objectType entity.PriviledgeObjectType, object string, privilege string, options ...entity.OperatePrivilegeOption) error

// GetLoadingProgress get the collection or partitions loading progress
GetLoadingProgress(ctx context.Context, collectionName string, partitionNames []string) (int64, error)
Expand Down
22 changes: 20 additions & 2 deletions client/rbac.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,12 +320,18 @@ func (c *GrpcClient) ListGrant(ctx context.Context, role string, object string,
}

// Grant adds object privileged for role.
func (c *GrpcClient) Grant(ctx context.Context, role string, objectType entity.PriviledgeObjectType, object string, privilege string) error {
func (c *GrpcClient) Grant(ctx context.Context, role string, objectType entity.PriviledgeObjectType, object string, privilege string, options ...entity.OperatePrivilegeOption) error {
if c.Service == nil {
return ErrClientNotReady
}

grantOpt := &entity.OperatePrivilegeOpt{}
for _, opt := range options {
opt(grantOpt)
}

req := &milvuspb.OperatePrivilegeRequest{
Base: grantOpt.Base,
Entity: &milvuspb.GrantEntity{
Role: &milvuspb.RoleEntity{
Name: role,
Expand All @@ -339,6 +345,7 @@ func (c *GrpcClient) Grant(ctx context.Context, role string, objectType entity.P
},
},
ObjectName: object,
DbName: grantOpt.Database,
},
Type: milvuspb.OperatePrivilegeType_Grant,
}
Expand All @@ -352,12 +359,17 @@ func (c *GrpcClient) Grant(ctx context.Context, role string, objectType entity.P
}

// Revoke removes privilege from role.
func (c *GrpcClient) Revoke(ctx context.Context, role string, objectType entity.PriviledgeObjectType, object string) error {
func (c *GrpcClient) Revoke(ctx context.Context, role string, objectType entity.PriviledgeObjectType, object string, privilege string, options ...entity.OperatePrivilegeOption) error {
if c.Service == nil {
return ErrClientNotReady
}
revokeOpt := &entity.OperatePrivilegeOpt{}
for _, opt := range options {
opt(revokeOpt)
}

req := &milvuspb.OperatePrivilegeRequest{
Base: revokeOpt.Base,
Entity: &milvuspb.GrantEntity{
Role: &milvuspb.RoleEntity{
Name: role,
Expand All @@ -366,6 +378,12 @@ func (c *GrpcClient) Revoke(ctx context.Context, role string, objectType entity.
Name: commonpb.ObjectType_name[int32(objectType)],
},
ObjectName: object,
Grantor: &milvuspb.GrantorEntity{
Privilege: &milvuspb.PrivilegeEntity{
Name: privilege,
},
},
DbName: revokeOpt.Database,
},
Type: milvuspb.OperatePrivilegeType_Revoke,
}
Expand Down
25 changes: 16 additions & 9 deletions client/rbac_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,8 @@ func (s *RBACSuite) TestGrant() {
roleName := "testRole"
objectName := testCollectionName
objectType := entity.PriviledegeObjectTypeCollection
privilegeName := "testPrivilege"
dbName := "testDB"
privilege := "testPrivilege"

s.Run("normal run", func() {
ctx, cancel := context.WithCancel(ctx)
Expand All @@ -640,9 +641,11 @@ func (s *RBACSuite) TestGrant() {
s.Equal(objectName, req.GetEntity().GetObjectName())
s.Equal(commonpb.ObjectType_name[int32(objectType)], req.GetEntity().GetObject().GetName())
s.Equal(milvuspb.OperatePrivilegeType_Grant, req.GetType())
s.Equal(privilege, req.GetEntity().GetGrantor().GetPrivilege().GetName())
s.Equal(dbName, req.GetEntity().GetDbName())
}).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil)

err := s.client.Grant(ctx, roleName, objectType, objectName, privilegeName)
err := s.client.Grant(ctx, roleName, objectType, objectName, privilege, entity.WithOperatePrivilegeDatabase(dbName))

s.NoError(err)
})
Expand All @@ -653,7 +656,7 @@ func (s *RBACSuite) TestGrant() {
defer s.resetMock()
s.mock.EXPECT().OperatePrivilege(mock.Anything, mock.Anything).Return(nil, errors.New("mock error"))

err := s.client.Grant(ctx, roleName, objectType, objectName, privilegeName)
err := s.client.Grant(ctx, roleName, objectType, objectName, privilege, entity.WithOperatePrivilegeDatabase(dbName))
s.Error(err)
})

Expand All @@ -663,7 +666,7 @@ func (s *RBACSuite) TestGrant() {
defer s.resetMock()
s.mock.EXPECT().OperatePrivilege(mock.Anything, mock.Anything).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, nil)

err := s.client.Grant(ctx, roleName, objectType, objectName, privilegeName)
err := s.client.Grant(ctx, roleName, objectType, objectName, privilege, entity.WithOperatePrivilegeDatabase(dbName))
s.Error(err)
})

Expand All @@ -672,7 +675,7 @@ func (s *RBACSuite) TestGrant() {
defer cancel()

c := &GrpcClient{}
err := c.Grant(ctx, roleName, objectType, objectName, privilegeName)
err := c.Grant(ctx, roleName, objectType, objectName, privilege, entity.WithOperatePrivilegeDatabase(dbName))
s.Error(err)
s.ErrorIs(err, ErrClientNotReady)
})
Expand All @@ -684,6 +687,8 @@ func (s *RBACSuite) TestRevoke() {
roleName := "testRole"
objectName := testCollectionName
objectType := entity.PriviledegeObjectTypeCollection
dbName := "testDB"
privilege := "testPrivilege"

s.Run("normal run", func() {
ctx, cancel := context.WithCancel(ctx)
Expand All @@ -694,9 +699,11 @@ func (s *RBACSuite) TestRevoke() {
s.Equal(objectName, req.GetEntity().GetObjectName())
s.Equal(commonpb.ObjectType_name[int32(objectType)], req.GetEntity().GetObject().GetName())
s.Equal(milvuspb.OperatePrivilegeType_Revoke, req.GetType())
s.Equal(privilege, req.GetEntity().GetGrantor().GetPrivilege().GetName())
s.Equal(dbName, req.GetEntity().GetDbName())
}).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil)

err := s.client.Revoke(ctx, roleName, objectType, objectName)
err := s.client.Revoke(ctx, roleName, objectType, objectName, privilege, entity.WithOperatePrivilegeDatabase(dbName))

s.NoError(err)
})
Expand All @@ -707,7 +714,7 @@ func (s *RBACSuite) TestRevoke() {
defer s.resetMock()
s.mock.EXPECT().OperatePrivilege(mock.Anything, mock.Anything).Return(nil, errors.New("mock error"))

err := s.client.Revoke(ctx, roleName, objectType, objectName)
err := s.client.Revoke(ctx, roleName, objectType, objectName, privilege, entity.WithOperatePrivilegeDatabase(dbName))
s.Error(err)
})

Expand All @@ -717,7 +724,7 @@ func (s *RBACSuite) TestRevoke() {
defer s.resetMock()
s.mock.EXPECT().OperatePrivilege(mock.Anything, mock.Anything).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, nil)

err := s.client.Revoke(ctx, roleName, objectType, objectName)
err := s.client.Revoke(ctx, roleName, objectType, objectName, privilege, entity.WithOperatePrivilegeDatabase(dbName))
s.Error(err)
})

Expand All @@ -727,7 +734,7 @@ func (s *RBACSuite) TestRevoke() {
defer s.resetMock()

c := &GrpcClient{}
err := c.Revoke(ctx, roleName, objectType, objectName)
err := c.Revoke(ctx, roleName, objectType, objectName, privilege, entity.WithOperatePrivilegeDatabase(dbName))
s.Error(err)
s.ErrorIs(err, ErrClientNotReady)
})
Expand Down
19 changes: 19 additions & 0 deletions entity/rbac.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,22 @@ const (
// PriviledegeObjectTypeGlobal const value for Global.
PriviledegeObjectTypeGlobal PriviledgeObjectType = PriviledgeObjectType(common.ObjectType_Global)
)

type OperatePrivilegeOpt struct {
Base *common.MsgBase
Database string
}

type OperatePrivilegeOption func(o *OperatePrivilegeOpt)

func WithOperatePrivilegeBase(base *common.MsgBase) OperatePrivilegeOption {
return func(o *OperatePrivilegeOpt) {
o.Base = base
}
}

func WithOperatePrivilegeDatabase(database string) OperatePrivilegeOption {
return func(o *OperatePrivilegeOpt) {
o.Database = database
}
}

0 comments on commit 0c8670f

Please sign in to comment.