diff --git a/sdk/python/feast/feature_server.py b/sdk/python/feast/feature_server.py index bf20e51df9..6c26c147e4 100644 --- a/sdk/python/feast/feature_server.py +++ b/sdk/python/feast/feature_server.py @@ -17,7 +17,9 @@ from feast import proto_json, utils from feast.constants import DEFAULT_FEATURE_SERVER_REGISTRY_TTL from feast.data_source import PushMode -from feast.errors import PushSourceNotFoundException +from feast.errors import FeatureViewNotFoundException, PushSourceNotFoundException +from feast.permissions.action import WRITE, AuthzedAction +from feast.permissions.security_manager import assert_permissions # TODO: deprecate this in favor of push features @@ -86,19 +88,40 @@ async def get_body(request: Request): def get_online_features(body=Depends(get_body)): try: body = json.loads(body) + full_feature_names = body.get("full_feature_names", False) + entity_rows = body["entities"] # Initialize parameters for FeatureStore.get_online_features(...) call if "feature_service" in body: - features = store.get_feature_service( + feature_service = store.get_feature_service( body["feature_service"], allow_cache=True ) + assert_permissions( + resource=feature_service, actions=[AuthzedAction.QUERY_ONLINE] + ) + features = feature_service else: features = body["features"] - - full_feature_names = body.get("full_feature_names", False) + all_feature_views, all_on_demand_feature_views = ( + utils._get_feature_views_to_use( + store.registry, + store.project, + features, + allow_cache=True, + hide_dummy_entity=False, + ) + ) + for feature_view in all_feature_views: + assert_permissions( + resource=feature_view, actions=[AuthzedAction.QUERY_ONLINE] + ) + for od_feature_view in all_on_demand_feature_views: + assert_permissions( + resource=od_feature_view, actions=[AuthzedAction.QUERY_ONLINE] + ) response_proto = store.get_online_features( features=features, - entity_rows=body["entities"], + entity_rows=entity_rows, full_feature_names=full_feature_names, ).proto @@ -117,16 +140,41 @@ def push(body=Depends(get_body)): try: request = PushFeaturesRequest(**json.loads(body)) df = pd.DataFrame(request.df) + actions = [] if request.to == "offline": to = PushMode.OFFLINE + actions = [AuthzedAction.WRITE_OFFLINE] elif request.to == "online": to = PushMode.ONLINE + actions = [AuthzedAction.WRITE_ONLINE] elif request.to == "online_and_offline": to = PushMode.ONLINE_AND_OFFLINE + actions = WRITE else: raise ValueError( f"{request.to} is not a supported push format. Please specify one of these ['online', 'offline', 'online_and_offline']." ) + + from feast.data_source import PushSource + + all_fvs = store.list_feature_views( + allow_cache=request.allow_registry_cache + ) + store.list_stream_feature_views( + allow_cache=request.allow_registry_cache + ) + fvs_with_push_sources = { + fv + for fv in all_fvs + if ( + fv.stream_source is not None + and isinstance(fv.stream_source, PushSource) + and fv.stream_source.name == request.push_source_name + ) + } + + for feature_view in fvs_with_push_sources: + assert_permissions(resource=feature_view, actions=actions) + store.push( push_source_name=request.push_source_name, df=df, @@ -149,10 +197,24 @@ def write_to_online_store(body=Depends(get_body)): try: request = WriteToFeatureStoreRequest(**json.loads(body)) df = pd.DataFrame(request.df) + feature_view_name = request.feature_view_name + allow_registry_cache = request.allow_registry_cache + try: + feature_view = store.get_stream_feature_view( + feature_view_name, allow_registry_cache=allow_registry_cache + ) + except FeatureViewNotFoundException: + feature_view = store.get_feature_view( + feature_view_name, allow_registry_cache=allow_registry_cache + ) + + assert_permissions( + resource=feature_view, actions=[AuthzedAction.WRITE_ONLINE] + ) store.write_to_online_store( - feature_view_name=request.feature_view_name, + feature_view_name=feature_view_name, df=df, - allow_registry_cache=request.allow_registry_cache, + allow_registry_cache=allow_registry_cache, ) except Exception as e: # Print the original exception on the server side @@ -168,6 +230,10 @@ def health(): def materialize(body=Depends(get_body)): try: request = MaterializeRequest(**json.loads(body)) + for feature_view in request.feature_views: + assert_permissions( + resource=feature_view, actions=[AuthzedAction.WRITE_ONLINE] + ) store.materialize( utils.make_tzaware(parser.parse(request.start_ts)), utils.make_tzaware(parser.parse(request.end_ts)), @@ -183,6 +249,10 @@ def materialize(body=Depends(get_body)): def materialize_incremental(body=Depends(get_body)): try: request = MaterializeIncrementalRequest(**json.loads(body)) + for feature_view in request.feature_views: + assert_permissions( + resource=feature_view, actions=[AuthzedAction.WRITE_ONLINE] + ) store.materialize_incremental( utils.make_tzaware(parser.parse(request.end_ts)), request.feature_views ) diff --git a/sdk/python/feast/offline_server.py b/sdk/python/feast/offline_server.py index be92620d68..a16dcb8932 100644 --- a/sdk/python/feast/offline_server.py +++ b/sdk/python/feast/offline_server.py @@ -3,7 +3,7 @@ import logging import traceback from datetime import datetime -from typing import Any, Dict, List +from typing import Any, Dict, List, cast import pyarrow as pa import pyarrow.flight as fl @@ -12,6 +12,8 @@ from feast.feature_logging import FeatureServiceLoggingSource from feast.feature_view import DUMMY_ENTITY_NAME from feast.infra.offline_stores.offline_utils import get_offline_store_from_config +from feast.permissions.action import AuthzedAction +from feast.permissions.security_manager import assert_permissions from feast.saved_dataset import SavedDatasetStorage logger = logging.getLogger(__name__) @@ -217,7 +219,15 @@ def offline_write_batch(self, command: dict, key: str): assert len(feature_views) == 1, "incorrect feature view" table = self.flights[key] self.offline_store.offline_write_batch( - self.store.config, feature_views[0], table, command["progress"] + self.store.config, + cast( + FeatureView, + assert_permissions( + feature_views[0], actions=[AuthzedAction.WRITE_OFFLINE] + ), + ), + table, + command["progress"], ) def _validate_write_logged_features_parameters(self, command: dict): @@ -234,6 +244,10 @@ def write_logged_features(self, command: dict, key: str): feature_service.logging_config is not None ), "feature service must have logging_config set" + assert_permissions( + resource=feature_service, + actions=[AuthzedAction.WRITE_OFFLINE], + ) self.offline_store.write_logged_features( config=self.store.config, data=table, @@ -260,10 +274,12 @@ def _validate_pull_all_from_table_or_query_parameters(self, command: dict): def pull_all_from_table_or_query(self, command: dict): self._validate_pull_all_from_table_or_query_parameters(command) + data_source = self.store.get_data_source(command["data_source_name"]) + assert_permissions(data_source, actions=[AuthzedAction.QUERY_OFFLINE]) return self.offline_store.pull_all_from_table_or_query( self.store.config, - self.store.get_data_source(command["data_source_name"]), + data_source, command["join_key_columns"], command["feature_name_columns"], command["timestamp_field"], @@ -287,10 +303,11 @@ def _validate_pull_latest_from_table_or_query_parameters(self, command: dict): def pull_latest_from_table_or_query(self, command: dict): self._validate_pull_latest_from_table_or_query_parameters(command) - + data_source = self.store.get_data_source(command["data_source_name"]) + assert_permissions(resource=data_source, actions=[AuthzedAction.QUERY_OFFLINE]) return self.offline_store.pull_latest_from_table_or_query( self.store.config, - self.store.get_data_source(command["data_source_name"]), + data_source, command["join_key_columns"], command["feature_name_columns"], command["timestamp_field"], @@ -343,6 +360,11 @@ def get_historical_features(self, command: dict, key: str): project=project, ) + for feature_view in feature_views: + assert_permissions( + resource=feature_view, actions=[AuthzedAction.QUERY_OFFLINE] + ) + retJob = self.offline_store.get_historical_features( config=self.store.config, feature_views=feature_views, @@ -377,6 +399,10 @@ def persist(self, command: dict, key: str): raise NotImplementedError data_source = self.store.get_data_source(command["data_source_name"]) + assert_permissions( + resource=data_source, + actions=[AuthzedAction.WRITE_OFFLINE], + ) storage = SavedDatasetStorage.from_data_source(data_source) ret_job.persist(storage, command["allow_overwrite"], command["timeout"]) except Exception as e: diff --git a/sdk/python/feast/permissions/action.py b/sdk/python/feast/permissions/action.py index 82125848a3..09bce94511 100644 --- a/sdk/python/feast/permissions/action.py +++ b/sdk/python/feast/permissions/action.py @@ -29,3 +29,12 @@ class AuthzedAction(enum.Enum): AuthzedAction.WRITE_OFFLINE, AuthzedAction.WRITE_ONLINE, ] + + +# Alias for CRUD actions +CRUD = [ + AuthzedAction.CREATE, + AuthzedAction.READ, + AuthzedAction.UPDATE, + AuthzedAction.DELETE, +] diff --git a/sdk/python/feast/registry_server.py b/sdk/python/feast/registry_server.py index 17a4c2b5fe..d9cecdff15 100644 --- a/sdk/python/feast/registry_server.py +++ b/sdk/python/feast/registry_server.py @@ -1,19 +1,22 @@ from concurrent import futures from datetime import datetime +from typing import cast import grpc from google.protobuf.empty_pb2 import Empty from pytz import utc -from feast import FeatureStore +from feast import FeatureService, FeatureStore from feast.data_source import DataSource from feast.entity import Entity -from feast.feature_service import FeatureService +from feast.feast_object import FeastObject from feast.feature_view import FeatureView from feast.infra.infra_object import Infra from feast.infra.registry.base_registry import BaseRegistry from feast.on_demand_feature_view import OnDemandFeatureView +from feast.permissions.action import CRUD, AuthzedAction from feast.permissions.permission import Permission +from feast.permissions.security_manager import assert_permissions, permitted_resources from feast.protos.feast.registry import RegistryServer_pb2, RegistryServer_pb2_grpc from feast.saved_dataset import SavedDataset, ValidationReference from feast.stream_feature_view import StreamFeatureView @@ -25,31 +28,52 @@ def __init__(self, registry: BaseRegistry) -> None: self.proxied_registry = registry def ApplyEntity(self, request: RegistryServer_pb2.ApplyEntityRequest, context): - self.proxied_registry.apply_entity( - entity=Entity.from_proto(request.entity), - project=request.project, - commit=request.commit, + assert_permissions( + resource=self.proxied_registry.apply_entity( + entity=Entity.from_proto(request.entity), + project=request.project, + commit=request.commit, + ), + actions=CRUD, ) return Empty() def GetEntity(self, request: RegistryServer_pb2.GetEntityRequest, context): - return self.proxied_registry.get_entity( - name=request.name, project=request.project, allow_cache=request.allow_cache + return assert_permissions( + self.proxied_registry.get_entity( + name=request.name, + project=request.project, + allow_cache=request.allow_cache, + ), + actions=[AuthzedAction.READ], ).to_proto() def ListEntities(self, request: RegistryServer_pb2.ListEntitiesRequest, context): return RegistryServer_pb2.ListEntitiesResponse( entities=[ entity.to_proto() - for entity in self.proxied_registry.list_entities( - project=request.project, - allow_cache=request.allow_cache, - tags=dict(request.tags), + for entity in permitted_resources( + resources=cast( + list[FeastObject], + self.proxied_registry.list_entities( + project=request.project, + allow_cache=request.allow_cache, + tags=dict(request.tags), + ), + ), + actions=AuthzedAction.READ, ) ] ) def DeleteEntity(self, request: RegistryServer_pb2.DeleteEntityRequest, context): + assert_permissions( + resource=self.proxied_registry.get_entity( + name=request.name, project=request.project + ), + actions=AuthzedAction.DELETE, + ) + self.proxied_registry.delete_entity( name=request.name, project=request.project, commit=request.commit ) @@ -58,16 +82,24 @@ def DeleteEntity(self, request: RegistryServer_pb2.DeleteEntityRequest, context) def ApplyDataSource( self, request: RegistryServer_pb2.ApplyDataSourceRequest, context ): - self.proxied_registry.apply_data_source( - data_source=DataSource.from_proto(request.data_source), - project=request.project, - commit=request.commit, + assert_permissions( + resource=self.proxied_registry.apply_data_source( + data_source=DataSource.from_proto(request.data_source), + project=request.project, + commit=request.commit, + ), + actions=CRUD, ) return Empty() def GetDataSource(self, request: RegistryServer_pb2.GetDataSourceRequest, context): - return self.proxied_registry.get_data_source( - name=request.name, project=request.project, allow_cache=request.allow_cache + return assert_permissions( + resource=self.proxied_registry.get_data_source( + name=request.name, + project=request.project, + allow_cache=request.allow_cache, + ), + actions=AuthzedAction.READ, ).to_proto() def ListDataSources( @@ -76,10 +108,16 @@ def ListDataSources( return RegistryServer_pb2.ListDataSourcesResponse( data_sources=[ data_source.to_proto() - for data_source in self.proxied_registry.list_data_sources( - project=request.project, - allow_cache=request.allow_cache, - tags=dict(request.tags), + for data_source in permitted_resources( + resources=cast( + list[FeastObject], + self.proxied_registry.list_data_sources( + project=request.project, + allow_cache=request.allow_cache, + tags=dict(request.tags), + ), + ), + actions=AuthzedAction.READ, ) ] ) @@ -87,6 +125,14 @@ def ListDataSources( def DeleteDataSource( self, request: RegistryServer_pb2.DeleteDataSourceRequest, context ): + assert_permissions( + resource=self.proxied_registry.get_data_source( + name=request.name, + project=request.project, + ), + actions=AuthzedAction.DELETE, + ) + self.proxied_registry.delete_data_source( name=request.name, project=request.project, commit=request.commit ) @@ -95,8 +141,13 @@ def DeleteDataSource( def GetFeatureView( self, request: RegistryServer_pb2.GetFeatureViewRequest, context ): - return self.proxied_registry.get_feature_view( - name=request.name, project=request.project, allow_cache=request.allow_cache + return assert_permissions( + self.proxied_registry.get_feature_view( + name=request.name, + project=request.project, + allow_cache=request.allow_cache, + ), + actions=[AuthzedAction.READ], ).to_proto() def ApplyFeatureView( @@ -112,8 +163,13 @@ def ApplyFeatureView( elif feature_view_type == "stream_feature_view": feature_view = StreamFeatureView.from_proto(request.stream_feature_view) - self.proxied_registry.apply_feature_view( - feature_view=feature_view, project=request.project, commit=request.commit + assert_permissions( + resource=self.proxied_registry.apply_feature_view( + feature_view=feature_view, + project=request.project, + commit=request.commit, + ), + actions=CRUD, ) return Empty() @@ -123,10 +179,16 @@ def ListFeatureViews( return RegistryServer_pb2.ListFeatureViewsResponse( feature_views=[ feature_view.to_proto() - for feature_view in self.proxied_registry.list_feature_views( - project=request.project, - allow_cache=request.allow_cache, - tags=dict(request.tags), + for feature_view in permitted_resources( + resources=cast( + list[FeastObject], + self.proxied_registry.list_feature_views( + project=request.project, + allow_cache=request.allow_cache, + tags=dict(request.tags), + ), + ), + actions=AuthzedAction.READ, ) ] ) @@ -134,6 +196,12 @@ def ListFeatureViews( def DeleteFeatureView( self, request: RegistryServer_pb2.DeleteFeatureViewRequest, context ): + assert_permissions( + resource=self.proxied_registry.get_feature_view( + name=request.name, project=request.project + ), + actions=[AuthzedAction.DELETE], + ) self.proxied_registry.delete_feature_view( name=request.name, project=request.project, commit=request.commit ) @@ -142,8 +210,13 @@ def DeleteFeatureView( def GetStreamFeatureView( self, request: RegistryServer_pb2.GetStreamFeatureViewRequest, context ): - return self.proxied_registry.get_stream_feature_view( - name=request.name, project=request.project, allow_cache=request.allow_cache + return assert_permissions( + resource=self.proxied_registry.get_stream_feature_view( + name=request.name, + project=request.project, + allow_cache=request.allow_cache, + ), + actions=[AuthzedAction.READ], ).to_proto() def ListStreamFeatureViews( @@ -152,10 +225,16 @@ def ListStreamFeatureViews( return RegistryServer_pb2.ListStreamFeatureViewsResponse( stream_feature_views=[ stream_feature_view.to_proto() - for stream_feature_view in self.proxied_registry.list_stream_feature_views( - project=request.project, - allow_cache=request.allow_cache, - tags=dict(request.tags), + for stream_feature_view in permitted_resources( + resources=cast( + list[FeastObject], + self.proxied_registry.list_stream_feature_views( + project=request.project, + allow_cache=request.allow_cache, + tags=dict(request.tags), + ), + ), + actions=AuthzedAction.READ, ) ] ) @@ -163,8 +242,13 @@ def ListStreamFeatureViews( def GetOnDemandFeatureView( self, request: RegistryServer_pb2.GetOnDemandFeatureViewRequest, context ): - return self.proxied_registry.get_on_demand_feature_view( - name=request.name, project=request.project, allow_cache=request.allow_cache + return assert_permissions( + resource=self.proxied_registry.get_on_demand_feature_view( + name=request.name, + project=request.project, + allow_cache=request.allow_cache, + ), + actions=[AuthzedAction.READ], ).to_proto() def ListOnDemandFeatureViews( @@ -173,10 +257,16 @@ def ListOnDemandFeatureViews( return RegistryServer_pb2.ListOnDemandFeatureViewsResponse( on_demand_feature_views=[ on_demand_feature_view.to_proto() - for on_demand_feature_view in self.proxied_registry.list_on_demand_feature_views( - project=request.project, - allow_cache=request.allow_cache, - tags=dict(request.tags), + for on_demand_feature_view in permitted_resources( + resources=cast( + list[FeastObject], + self.proxied_registry.list_on_demand_feature_views( + project=request.project, + allow_cache=request.allow_cache, + tags=dict(request.tags), + ), + ), + actions=AuthzedAction.READ, ) ] ) @@ -184,18 +274,27 @@ def ListOnDemandFeatureViews( def ApplyFeatureService( self, request: RegistryServer_pb2.ApplyFeatureServiceRequest, context ): - self.proxied_registry.apply_feature_service( - feature_service=FeatureService.from_proto(request.feature_service), - project=request.project, - commit=request.commit, + assert_permissions( + resource=self.proxied_registry.apply_feature_service( + feature_service=FeatureService.from_proto(request.feature_service), + project=request.project, + commit=request.commit, + ), + actions=CRUD, ) + return Empty() def GetFeatureService( self, request: RegistryServer_pb2.GetFeatureServiceRequest, context ): - return self.proxied_registry.get_feature_service( - name=request.name, project=request.project, allow_cache=request.allow_cache + return assert_permissions( + resource=self.proxied_registry.get_feature_service( + name=request.name, + project=request.project, + allow_cache=request.allow_cache, + ), + actions=[AuthzedAction.READ], ).to_proto() def ListFeatureServices( @@ -204,10 +303,16 @@ def ListFeatureServices( return RegistryServer_pb2.ListFeatureServicesResponse( feature_services=[ feature_service.to_proto() - for feature_service in self.proxied_registry.list_feature_services( - project=request.project, - allow_cache=request.allow_cache, - tags=dict(request.tags), + for feature_service in permitted_resources( + resources=cast( + list[FeastObject], + self.proxied_registry.list_feature_services( + project=request.project, + allow_cache=request.allow_cache, + tags=dict(request.tags), + ), + ), + actions=AuthzedAction.READ, ) ] ) @@ -215,6 +320,15 @@ def ListFeatureServices( def DeleteFeatureService( self, request: RegistryServer_pb2.DeleteFeatureServiceRequest, context ): + ( + assert_permissions( + resource=self.proxied_registry.get_feature_service( + name=request.name, project=request.project + ), + actions=[AuthzedAction.DELETE], + ), + ) + self.proxied_registry.delete_feature_service( name=request.name, project=request.project, commit=request.commit ) @@ -223,18 +337,27 @@ def DeleteFeatureService( def ApplySavedDataset( self, request: RegistryServer_pb2.ApplySavedDatasetRequest, context ): - self.proxied_registry.apply_saved_dataset( - saved_dataset=SavedDataset.from_proto(request.saved_dataset), - project=request.project, - commit=request.commit, + assert_permissions( + resource=self.proxied_registry.apply_saved_dataset( + saved_dataset=SavedDataset.from_proto(request.saved_dataset), + project=request.project, + commit=request.commit, + ), + actions=CRUD, ) + return Empty() def GetSavedDataset( self, request: RegistryServer_pb2.GetSavedDatasetRequest, context ): - return self.proxied_registry.get_saved_dataset( - name=request.name, project=request.project, allow_cache=request.allow_cache + return assert_permissions( + self.proxied_registry.get_saved_dataset( + name=request.name, + project=request.project, + allow_cache=request.allow_cache, + ), + actions=[AuthzedAction.READ], ).to_proto() def ListSavedDatasets( @@ -243,8 +366,14 @@ def ListSavedDatasets( return RegistryServer_pb2.ListSavedDatasetsResponse( saved_datasets=[ saved_dataset.to_proto() - for saved_dataset in self.proxied_registry.list_saved_datasets( - project=request.project, allow_cache=request.allow_cache + for saved_dataset in permitted_resources( + resources=cast( + list[FeastObject], + self.proxied_registry.list_saved_datasets( + project=request.project, allow_cache=request.allow_cache + ), + ), + actions=AuthzedAction.READ, ) ] ) @@ -252,6 +381,13 @@ def ListSavedDatasets( def DeleteSavedDataset( self, request: RegistryServer_pb2.DeleteSavedDatasetRequest, context ): + assert_permissions( + resource=self.proxied_registry.get_saved_dataset( + name=request.name, project=request.project + ), + actions=[AuthzedAction.DELETE], + ) + self.proxied_registry.delete_saved_dataset( name=request.name, project=request.project, commit=request.commit ) @@ -260,20 +396,29 @@ def DeleteSavedDataset( def ApplyValidationReference( self, request: RegistryServer_pb2.ApplyValidationReferenceRequest, context ): - self.proxied_registry.apply_validation_reference( - validation_reference=ValidationReference.from_proto( - request.validation_reference + assert_permissions( + resource=self.proxied_registry.apply_validation_reference( + validation_reference=ValidationReference.from_proto( + request.validation_reference + ), + project=request.project, + commit=request.commit, ), - project=request.project, - commit=request.commit, + actions=CRUD, ) + return Empty() def GetValidationReference( self, request: RegistryServer_pb2.GetValidationReferenceRequest, context ): - return self.proxied_registry.get_validation_reference( - name=request.name, project=request.project, allow_cache=request.allow_cache + return assert_permissions( + self.proxied_registry.get_validation_reference( + name=request.name, + project=request.project, + allow_cache=request.allow_cache, + ), + actions=[AuthzedAction.READ], ).to_proto() def ListValidationReferences( @@ -282,8 +427,15 @@ def ListValidationReferences( return RegistryServer_pb2.ListValidationReferencesResponse( validation_references=[ validation_reference.to_proto() - for validation_reference in self.proxied_registry.list_validation_references( - project=request.project, allow_cache=request.allow_cache + for validation_reference in permitted_resources( + resources=cast( + list[FeastObject], + self.proxied_registry.list_validation_references( + project=request.project, + allow_cache=request.allow_cache, + ), + ), + actions=AuthzedAction.READ, ) ] ) @@ -291,6 +443,12 @@ def ListValidationReferences( def DeleteValidationReference( self, request: RegistryServer_pb2.DeleteValidationReferenceRequest, context ): + assert_permissions( + resource=self.proxied_registry.get_validation_reference( + name=request.name, project=request.project + ), + actions=[AuthzedAction.DELETE], + ) self.proxied_registry.delete_validation_reference( name=request.name, project=request.project, commit=request.commit ) @@ -311,6 +469,11 @@ def ListProjectMetadata( def ApplyMaterialization( self, request: RegistryServer_pb2.ApplyMaterializationRequest, context ): + assert_permissions( + resource=FeatureView.from_proto(request.feature_view), + actions=[AuthzedAction.WRITE_ONLINE], + ) + self.proxied_registry.apply_materialization( feature_view=FeatureView.from_proto(request.feature_view), project=request.project, @@ -340,19 +503,26 @@ def GetInfra(self, request: RegistryServer_pb2.GetInfraRequest, context): def ApplyPermission( self, request: RegistryServer_pb2.ApplyPermissionRequest, context ): - self.proxied_registry.apply_permission( - permission=Permission.from_proto(request.permission), - project=request.project, - commit=request.commit, + assert_permissions( + self.proxied_registry.apply_permission( + permission=Permission.from_proto(request.permission), + project=request.project, + commit=request.commit, + ), + actions=CRUD, ) return Empty() def GetPermission(self, request: RegistryServer_pb2.GetPermissionRequest, context): permission = self.proxied_registry.get_permission( name=request.name, project=request.project, allow_cache=request.allow_cache - ).to_proto() + ) + assert_permissions( + resource=permission, + actions=[AuthzedAction.READ], + ) + permission.to_proto().project = request.project - permission.project = request.project return permission def ListPermissions( @@ -361,9 +531,14 @@ def ListPermissions( return RegistryServer_pb2.ListPermissionsResponse( permissions=[ permission.to_proto() - for permission in self.proxied_registry.list_permissions( - project=request.project, - allow_cache=request.allow_cache, + for permission in permitted_resources( + resources=cast( + list[FeastObject], + self.proxied_registry.list_permissions( + project=request.project, allow_cache=request.allow_cache + ), + ), + actions=AuthzedAction.READ, ) ] ) @@ -371,6 +546,14 @@ def ListPermissions( def DeletePermission( self, request: RegistryServer_pb2.DeletePermissionRequest, context ): + assert_permissions( + resource=self.proxied_registry.get_permission( + name=request.name, + project=request.project, + ), + actions=[AuthzedAction.DELETE], + ) + self.proxied_registry.delete_permission( name=request.name, project=request.project, commit=request.commit ) diff --git a/sdk/python/tests/unit/diff/test_registry_diff.py b/sdk/python/tests/unit/diff/test_registry_diff.py index 4f01cebe45..08d33c2366 100644 --- a/sdk/python/tests/unit/diff/test_registry_diff.py +++ b/sdk/python/tests/unit/diff/test_registry_diff.py @@ -6,14 +6,14 @@ tag_objects_for_keep_delete_update_add, ) from feast.entity import Entity +from feast.feast_object import ALL_RESOURCE_TYPES from feast.feature_view import FeatureView from feast.on_demand_feature_view import on_demand_feature_view -from feast.types import String -from tests.utils.data_source_test_creator import prep_file_source from feast.permissions.action import AuthzedAction from feast.permissions.permission import Permission -from feast.feast_object import ALL_RESOURCE_TYPES from feast.permissions.policy import RoleBasedPolicy +from feast.types import String +from tests.utils.data_source_test_creator import prep_file_source def test_tag_objects_for_keep_delete_update_add(simple_dataset_1): @@ -175,6 +175,7 @@ def test_diff_registry_objects_batch_to_push_source(simple_dataset_1): == "stream_source" ) + def test_diff_registry_objects_permissions(): pre_changed = Permission( name="reader", @@ -191,11 +192,6 @@ def test_diff_registry_objects_permissions(): actions=[AuthzedAction.CREATE], ) - feast_object_diffs = diff_registry_objects( - pre_changed, post_changed, "permission" - ) + feast_object_diffs = diff_registry_objects(pre_changed, post_changed, "permission") assert len(feast_object_diffs.feast_object_property_diffs) == 1 - assert ( - feast_object_diffs.feast_object_property_diffs[0].property_name - == "actions" - ) + assert feast_object_diffs.feast_object_property_diffs[0].property_name == "actions"