From 16317d0e10376f1c282d71dc18404b44baf20364 Mon Sep 17 00:00:00 2001 From: Tommy Hughes Date: Wed, 29 May 2024 12:07:35 -0500 Subject: [PATCH] add tags filtering capability to 'list' for objects Signed-off-by: Tommy Hughes --- protos/feast/registry/RegistryServer.proto | 6 ++ sdk/python/feast/cli.py | 37 +++++--- sdk/python/feast/feature_store.py | 86 ++++++++++++++----- .../feast/infra/registry/base_registry.py | 41 +++++++-- .../feast/infra/registry/caching_registry.py | 80 +++++++++++------ .../infra/registry/proto_registry_utils.py | 77 +++++++++++++---- sdk/python/feast/infra/registry/registry.py | 46 +++++++--- sdk/python/feast/infra/registry/remote.py | 44 +++++++--- sdk/python/feast/infra/registry/snowflake.py | 66 ++++++++++---- sdk/python/feast/infra/registry/sql.py | 51 ++++++++--- sdk/python/feast/registry_server.py | 42 ++++++--- sdk/python/feast/utils.py | 24 +++++- .../example_repos/example_feature_repo_1.py | 5 +- ...ple_feature_repo_with_feature_service_2.py | 2 +- .../feature_repos/universal/feature_views.py | 6 ++ .../online_store/test_universal_online.py | 16 ++++ .../local_feast_tests/test_feature_service.py | 5 ++ .../test_local_feature_store.py | 66 ++++++++++++-- .../online_store/test_online_retrieval.py | 4 + 19 files changed, 553 insertions(+), 151 deletions(-) diff --git a/protos/feast/registry/RegistryServer.proto b/protos/feast/registry/RegistryServer.proto index 3ca7398fdc1..44529f5409c 100644 --- a/protos/feast/registry/RegistryServer.proto +++ b/protos/feast/registry/RegistryServer.proto @@ -117,6 +117,7 @@ message GetEntityRequest { message ListEntitiesRequest { string project = 1; bool allow_cache = 2; + map<string, string> tags = 3; } message ListEntitiesResponse { @@ -146,6 +147,7 @@ message GetDataSourceRequest { message ListDataSourcesRequest { string project = 1; bool allow_cache = 2; + map<string, string> tags = 3; } message ListDataSourcesResponse { @@ -179,6 +181,7 @@ message GetFeatureViewRequest { message ListFeatureViewsRequest { string project = 1; 
bool allow_cache = 2; + map<string, string> tags = 3; } message ListFeatureViewsResponse { @@ -202,6 +205,7 @@ message GetStreamFeatureViewRequest { message ListStreamFeatureViewsRequest { string project = 1; bool allow_cache = 2; + map<string, string> tags = 3; } message ListStreamFeatureViewsResponse { @@ -219,6 +223,7 @@ message GetOnDemandFeatureViewRequest { message ListOnDemandFeatureViewsRequest { string project = 1; bool allow_cache = 2; + map<string, string> tags = 3; } message ListOnDemandFeatureViewsResponse { @@ -242,6 +247,7 @@ message GetFeatureServiceRequest { message ListFeatureServicesRequest { string project = 1; bool allow_cache = 2; + map<string, string> tags = 3; } message ListFeatureServicesResponse { diff --git a/sdk/python/feast/cli.py b/sdk/python/feast/cli.py index eeffc29fab0..76bcefa9ac6 100644 --- a/sdk/python/feast/cli.py +++ b/sdk/python/feast/cli.py @@ -47,6 +47,11 @@ from feast.utils import maybe_local_tz _logger = logging.getLogger(__name__) +tagsOption = click.option( + "--tags", + help="Filter by tags (e.g. 'key:value, key:value, ...')", + multiple=True, +) class NoOptionDefaultFormat(click.Command): @@ -226,14 +231,16 @@ def data_source_describe(ctx: click.Context, name: str): @data_sources_cmd.command(name="list") +@tagsOption @click.pass_context -def data_source_list(ctx: click.Context): +def data_source_list(ctx: click.Context, tags: Optional[str]): """ List all data sources """ store = create_feature_store(ctx) table = [] - for datasource in store.list_data_sources(): + tags_filter = utils.tags_str_to_dict(tags) + for datasource in store.list_data_sources(tags=tags_filter): table.append([datasource.name, datasource.__class__]) from tabulate import tabulate @@ -272,14 +279,16 @@ def entity_describe(ctx: click.Context, name: str): @entities_cmd.command(name="list") +@tagsOption @click.pass_context -def entity_list(ctx: click.Context): +def entity_list(ctx: click.Context, tags: Optional[str]): """ List all entities """ store = create_feature_store(ctx) table = [] - for entity in 
store.list_entities(): + tags_filter = utils.tags_str_to_dict(tags) + for entity in store.list_entities(tags=tags_filter): table.append([entity.name, entity.description, entity.value_type]) from tabulate import tabulate @@ -320,14 +329,16 @@ def feature_service_describe(ctx: click.Context, name: str): @feature_services_cmd.command(name="list") +@tagsOption @click.pass_context -def feature_service_list(ctx: click.Context): +def feature_service_list(ctx: click.Context, tags: Optional[str]): """ List all feature services """ store = create_feature_store(ctx) feature_services = [] - for feature_service in store.list_feature_services(): + tags_filter = utils.tags_str_to_dict(tags) + for feature_service in store.list_feature_services(tags=tags_filter): feature_names = [] for projection in feature_service.feature_view_projections: feature_names.extend( @@ -371,16 +382,18 @@ def feature_view_describe(ctx: click.Context, name: str): @feature_views_cmd.command(name="list") +@tagsOption @click.pass_context -def feature_view_list(ctx: click.Context): +def feature_view_list(ctx: click.Context, tags: Optional[str]): """ List all feature views """ store = create_feature_store(ctx) table = [] + tags_filter = utils.tags_str_to_dict(tags) for feature_view in [ - *store.list_feature_views(), - *store.list_on_demand_feature_views(), + *store.list_batch_feature_views(tags=tags_filter), + *store.list_on_demand_feature_views(tags=tags_filter), ]: entities = set() if isinstance(feature_view, FeatureView): @@ -434,14 +447,16 @@ def on_demand_feature_view_describe(ctx: click.Context, name: str): @on_demand_feature_views_cmd.command(name="list") +@tagsOption @click.pass_context -def on_demand_feature_view_list(ctx: click.Context): +def on_demand_feature_view_list(ctx: click.Context, tags: Optional[str]): """ [Experimental] List all on demand feature views """ store = create_feature_store(ctx) table = [] - for on_demand_feature_view in store.list_on_demand_feature_views(): + tags_filter = 
utils.tags_str_to_dict(tags) + for on_demand_feature_view in store.list_on_demand_feature_views(tags=tags_filter): table.append([on_demand_feature_view.name]) from tabulate import tabulate diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 716e706ebe5..93c792eb05e 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -215,23 +215,29 @@ def refresh_registry(self): self._registry = registry - def list_entities(self, allow_cache: bool = False) -> List[Entity]: + def list_entities( + self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None + ) -> List[Entity]: """ Retrieves the list of entities from the registry. Args: allow_cache: Whether to allow returning entities from a cached registry. + tags: Filter by tags. Returns: A list of entities. """ - return self._list_entities(allow_cache) + return self._list_entities(allow_cache, tags=tags) def _list_entities( - self, allow_cache: bool = False, hide_dummy_entity: bool = True + self, + allow_cache: bool = False, + hide_dummy_entity: bool = True, + tags: Optional[dict[str, str]] = None, ) -> List[Entity]: all_entities = self._registry.list_entities( - self.project, allow_cache=allow_cache + self.project, allow_cache=allow_cache, tags=tags ) return [ entity @@ -239,17 +245,22 @@ def _list_entities( if entity.name != DUMMY_ENTITY_NAME or not hide_dummy_entity ] - def list_feature_services(self) -> List[FeatureService]: + def list_feature_services( + self, tags: Optional[dict[str, str]] = None + ) -> List[FeatureService]: """ Retrieves the list of feature services from the registry. + Args: + tags: Filter by tags. + Returns: A list of feature services. 
""" - return self._registry.list_feature_services(self.project) + return self._registry.list_feature_services(self.project, tags=tags) def list_all_feature_views( - self, allow_cache: bool = False + self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None ) -> List[Union[FeatureView, StreamFeatureView, OnDemandFeatureView]]: """ Retrieves the list of feature views from the registry. @@ -260,14 +271,17 @@ def list_all_feature_views( Returns: A list of feature views. """ - return self._list_all_feature_views(allow_cache) + return self._list_all_feature_views(allow_cache, tags=tags) - def list_feature_views(self, allow_cache: bool = False) -> List[FeatureView]: + def list_feature_views( + self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None + ) -> List[FeatureView]: """ Retrieves the list of feature views from the registry. Args: allow_cache: Whether to allow returning entities from a cached registry. + tags: Filter by tags. Returns: A list of feature views. @@ -276,16 +290,32 @@ def list_feature_views(self, allow_cache: bool = False) -> List[FeatureView]: "list_feature_views will make breaking changes. Please use list_batch_feature_views instead. " "list_feature_views will behave like list_all_feature_views in the future." ) - return self._list_feature_views(allow_cache) + return self._list_feature_views(allow_cache=allow_cache, tags=tags) + + def list_batch_feature_views( + self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None + ) -> List[FeatureView]: + """ + Retrieves the list of feature views from the registry. + + Args: + allow_cache: Whether to allow returning entities from a cached registry. + tags: Filter by tags. + + Returns: + A list of feature views. 
+ """ + return self._list_batch_feature_views(allow_cache=allow_cache, tags=tags) def _list_all_feature_views( self, allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, ) -> List[Union[FeatureView, StreamFeatureView, OnDemandFeatureView]]: all_feature_views = ( - self._list_feature_views(allow_cache) - + self._list_stream_feature_views(allow_cache) - + self.list_on_demand_feature_views(allow_cache) + self._list_feature_views(allow_cache, tags=tags) + + self._list_stream_feature_views(allow_cache, tags=tags) + + self.list_on_demand_feature_views(allow_cache, tags=tags) ) return all_feature_views @@ -293,6 +323,7 @@ def _list_feature_views( self, allow_cache: bool = False, hide_dummy_entity: bool = True, + tags: Optional[dict[str, str]] = None, ) -> List[FeatureView]: logging.warning( "_list_feature_views will make breaking changes. Please use _list_batch_feature_views instead. " @@ -300,7 +331,7 @@ def _list_feature_views( ) feature_views = [] for fv in self._registry.list_feature_views( - self.project, allow_cache=allow_cache + self.project, allow_cache=allow_cache, tags=tags ): if ( hide_dummy_entity @@ -316,10 +347,11 @@ def _list_batch_feature_views( self, allow_cache: bool = False, hide_dummy_entity: bool = True, + tags: Optional[dict[str, str]] = None, ) -> List[FeatureView]: feature_views = [] for fv in self._registry.list_feature_views( - self.project, allow_cache=allow_cache + self.project, allow_cache=allow_cache, tags=tags ): if ( hide_dummy_entity @@ -335,10 +367,11 @@ def _list_stream_feature_views( self, allow_cache: bool = False, hide_dummy_entity: bool = True, + tags: Optional[dict[str, str]] = None, ) -> List[StreamFeatureView]: stream_feature_views = [] for sfv in self._registry.list_stream_feature_views( - self.project, allow_cache=allow_cache + self.project, allow_cache=allow_cache, tags=tags ): if hide_dummy_entity and sfv.entities[0] == DUMMY_ENTITY_NAME: sfv.entities = [] @@ -347,20 +380,24 @@ def _list_stream_feature_views( 
return stream_feature_views def list_on_demand_feature_views( - self, allow_cache: bool = False + self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None ) -> List[OnDemandFeatureView]: """ Retrieves the list of on demand feature views from the registry. + Args: + allow_cache: Whether to allow returning entities from a cached registry. + tags: Filter by tags. + Returns: A list of on demand feature views. """ return self._registry.list_on_demand_feature_views( - self.project, allow_cache=allow_cache + self.project, allow_cache=allow_cache, tags=tags ) def list_stream_feature_views( - self, allow_cache: bool = False + self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None ) -> List[StreamFeatureView]: """ Retrieves the list of stream feature views from the registry. @@ -368,19 +405,24 @@ def list_stream_feature_views( Returns: A list of stream feature views. """ - return self._list_stream_feature_views(allow_cache) + return self._list_stream_feature_views(allow_cache, tags=tags) - def list_data_sources(self, allow_cache: bool = False) -> List[DataSource]: + def list_data_sources( + self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None + ) -> List[DataSource]: """ Retrieves the list of data sources from the registry. Args: allow_cache: Whether to allow returning data sources from a cached registry. + tags: Filter by tags. Returns: A list of data sources. 
""" - return self._registry.list_data_sources(self.project, allow_cache=allow_cache) + return self._registry.list_data_sources( + self.project, allow_cache=allow_cache, tags=tags + ) def get_entity(self, name: str, allow_registry_cache: bool = False) -> Entity: """ diff --git a/sdk/python/feast/infra/registry/base_registry.py b/sdk/python/feast/infra/registry/base_registry.py index b52749a9b2f..bc08796e39d 100644 --- a/sdk/python/feast/infra/registry/base_registry.py +++ b/sdk/python/feast/infra/registry/base_registry.py @@ -84,13 +84,19 @@ def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Enti raise NotImplementedError @abstractmethod - def list_entities(self, project: str, allow_cache: bool = False) -> List[Entity]: + def list_entities( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[Entity]: """ Retrieve a list of entities from the registry Args: allow_cache: Whether to allow returning entities from a cached registry project: Filter entities based on project name + tags: Filter by tags Returns: List of entities @@ -143,7 +149,10 @@ def get_data_source( @abstractmethod def list_data_sources( - self, project: str, allow_cache: bool = False + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, ) -> List[DataSource]: """ Retrieve a list of data sources from the registry @@ -151,6 +160,7 @@ def list_data_sources( Args: project: Filter data source based on project name allow_cache: Whether to allow returning data sources from a cached registry + tags: Filter by tags Returns: List of data sources @@ -203,7 +213,10 @@ def get_feature_service( @abstractmethod def list_feature_services( - self, project: str, allow_cache: bool = False + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, ) -> List[FeatureService]: """ Retrieve a list of feature services from the registry @@ -211,6 +224,7 @@ def 
list_feature_services( Args: allow_cache: Whether to allow returning entities from a cached registry project: Filter entities based on project name + tags: Filter by tags Returns: List of feature services @@ -265,7 +279,10 @@ def get_stream_feature_view( @abstractmethod def list_stream_feature_views( - self, project: str, allow_cache: bool = False + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, ) -> List[StreamFeatureView]: """ Retrieve a list of stream feature views from the registry @@ -273,6 +290,7 @@ def list_stream_feature_views( Args: project: Filter stream feature views based on project name allow_cache: Whether to allow returning stream feature views from a cached registry + tags: Filter by tags Returns: List of stream feature views @@ -300,7 +318,10 @@ def get_on_demand_feature_view( @abstractmethod def list_on_demand_feature_views( - self, project: str, allow_cache: bool = False + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, ) -> List[OnDemandFeatureView]: """ Retrieve a list of on demand feature views from the registry @@ -308,6 +329,7 @@ def list_on_demand_feature_views( Args: project: Filter on demand feature views based on project name allow_cache: Whether to allow returning on demand feature views from a cached registry + tags: Filter by tags Returns: List of on demand feature views @@ -335,7 +357,10 @@ def get_feature_view( @abstractmethod def list_feature_views( - self, project: str, allow_cache: bool = False + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, ) -> List[FeatureView]: """ Retrieve a list of feature views from the registry @@ -343,6 +368,7 @@ def list_feature_views( Args: allow_cache: Allow returning feature views from the cached registry project: Filter feature views based on project name + tags: Filter by tags Returns: List of feature views @@ -598,7 +624,8 @@ def to_dict(self, project: str) -> 
Dict[str, List[Any]]: self._message_to_sorted_dict(data_source.to_proto()) ) for entity in sorted( - self.list_entities(project=project), key=lambda entity: entity.name + self.list_entities(project=project), + key=lambda entity: entity.name, ): registry_dict["entities"].append( self._message_to_sorted_dict(entity.to_proto()) diff --git a/sdk/python/feast/infra/registry/caching_registry.py b/sdk/python/feast/infra/registry/caching_registry.py index 0f660128086..6336dd7fee5 100644 --- a/sdk/python/feast/infra/registry/caching_registry.py +++ b/sdk/python/feast/infra/registry/caching_registry.py @@ -48,18 +48,23 @@ def get_data_source( return self._get_data_source(name, project) @abstractmethod - def _list_data_sources(self, project: str) -> List[DataSource]: + def _list_data_sources( + self, project: str, tags: Optional[dict[str, str]] + ) -> List[DataSource]: pass def list_data_sources( - self, project: str, allow_cache: bool = False + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, ) -> List[DataSource]: if allow_cache: self._refresh_cached_registry_if_necessary() return proto_registry_utils.list_data_sources( - self.cached_registry_proto, project + self.cached_registry_proto, project, tags ) - return self._list_data_sources(project) + return self._list_data_sources(project, tags) @abstractmethod def _get_entity(self, name: str, project: str) -> Entity: @@ -74,16 +79,23 @@ def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Enti return self._get_entity(name, project) @abstractmethod - def _list_entities(self, project: str) -> List[Entity]: + def _list_entities( + self, project: str, tags: Optional[dict[str, str]] + ) -> List[Entity]: pass - def list_entities(self, project: str, allow_cache: bool = False) -> List[Entity]: + def list_entities( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[Entity]: if allow_cache: 
self._refresh_cached_registry_if_necessary() return proto_registry_utils.list_entities( - self.cached_registry_proto, project + self.cached_registry_proto, project, tags ) - return self._list_entities(project) + return self._list_entities(project, tags) @abstractmethod def _get_feature_view(self, name: str, project: str) -> FeatureView: @@ -100,18 +112,23 @@ def get_feature_view( return self._get_feature_view(name, project) @abstractmethod - def _list_feature_views(self, project: str) -> List[FeatureView]: + def _list_feature_views( + self, project: str, tags: Optional[dict[str, str]] + ) -> List[FeatureView]: pass def list_feature_views( - self, project: str, allow_cache: bool = False + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, ) -> List[FeatureView]: if allow_cache: self._refresh_cached_registry_if_necessary() return proto_registry_utils.list_feature_views( - self.cached_registry_proto, project + self.cached_registry_proto, project, tags ) - return self._list_feature_views(project) + return self._list_feature_views(project, tags) @abstractmethod def _get_on_demand_feature_view( @@ -130,18 +147,23 @@ def get_on_demand_feature_view( return self._get_on_demand_feature_view(name, project) @abstractmethod - def _list_on_demand_feature_views(self, project: str) -> List[OnDemandFeatureView]: + def _list_on_demand_feature_views( + self, project: str, tags: Optional[dict[str, str]] + ) -> List[OnDemandFeatureView]: pass def list_on_demand_feature_views( - self, project: str, allow_cache: bool = False + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, ) -> List[OnDemandFeatureView]: if allow_cache: self._refresh_cached_registry_if_necessary() return proto_registry_utils.list_on_demand_feature_views( - self.cached_registry_proto, project + self.cached_registry_proto, project, tags ) - return self._list_on_demand_feature_views(project) + return 
self._list_on_demand_feature_views(project, tags) @abstractmethod def _get_stream_feature_view(self, name: str, project: str) -> StreamFeatureView: @@ -158,18 +180,23 @@ def get_stream_feature_view( return self._get_stream_feature_view(name, project) @abstractmethod - def _list_stream_feature_views(self, project: str) -> List[StreamFeatureView]: + def _list_stream_feature_views( + self, project: str, tags: Optional[dict[str, str]] + ) -> List[StreamFeatureView]: pass def list_stream_feature_views( - self, project: str, allow_cache: bool = False + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, ) -> List[StreamFeatureView]: if allow_cache: self._refresh_cached_registry_if_necessary() return proto_registry_utils.list_stream_feature_views( - self.cached_registry_proto, project + self.cached_registry_proto, project, tags ) - return self._list_stream_feature_views(project) + return self._list_stream_feature_views(project, tags) @abstractmethod def _get_feature_service(self, name: str, project: str) -> FeatureService: @@ -186,18 +213,23 @@ def get_feature_service( return self._get_feature_service(name, project) @abstractmethod - def _list_feature_services(self, project: str) -> List[FeatureService]: + def _list_feature_services( + self, project: str, tags: Optional[dict[str, str]] + ) -> List[FeatureService]: pass def list_feature_services( - self, project: str, allow_cache: bool = False + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, ) -> List[FeatureService]: if allow_cache: self._refresh_cached_registry_if_necessary() return proto_registry_utils.list_feature_services( - self.cached_registry_proto, project + self.cached_registry_proto, project, tags ) - return self._list_feature_services(project) + return self._list_feature_services(project, tags) @abstractmethod def _get_saved_dataset(self, name: str, project: str) -> SavedDataset: diff --git 
a/sdk/python/feast/infra/registry/proto_registry_utils.py b/sdk/python/feast/infra/registry/proto_registry_utils.py index 60e9cfa3abc..0e85f5b0a9f 100644 --- a/sdk/python/feast/infra/registry/proto_registry_utils.py +++ b/sdk/python/feast/infra/registry/proto_registry_utils.py @@ -2,6 +2,7 @@ from functools import wraps from typing import List, Optional +from feast import utils from feast.data_source import DataSource from feast.entity import Entity from feast.errors import ( @@ -42,6 +43,30 @@ def wrapper(registry_proto: RegistryProto, project: str): return wrapper +def registry_proto_cache_with_tags(func): + cache_key = None + cache_value = None + + @wraps(func) + def wrapper( + registry_proto: RegistryProto, + project: str, + tags: Optional[dict[str, str]], + ): + nonlocal cache_key, cache_value + + key = tuple([id(registry_proto), registry_proto.version_id, project, tags]) + + if key == cache_key: + return cache_value + else: + cache_value = func(registry_proto, project, tags) + cache_key = key + return cache_value + + return wrapper + + def init_project_metadata(cached_registry_proto: RegistryProto, project: str): new_project_uuid = f"{uuid.uuid4()}" cached_registry_proto.project_metadata.append( @@ -145,68 +170,84 @@ def get_validation_reference( raise ValidationReferenceNotFound(name, project=project) -@registry_proto_cache +@registry_proto_cache_with_tags def list_feature_services( - registry_proto: RegistryProto, project: str + registry_proto: RegistryProto, project: str, tags: Optional[dict[str, str]] ) -> List[FeatureService]: feature_services = [] for feature_service_proto in registry_proto.feature_services: - if feature_service_proto.spec.project == project: + if feature_service_proto.spec.project == project and utils.has_all_tags( + feature_service_proto.spec.tags, tags + ): feature_services.append(FeatureService.from_proto(feature_service_proto)) return feature_services -@registry_proto_cache +@registry_proto_cache_with_tags def list_feature_views( - 
registry_proto: RegistryProto, project: str + registry_proto: RegistryProto, project: str, tags: Optional[dict[str, str]] ) -> List[FeatureView]: feature_views: List[FeatureView] = [] for feature_view_proto in registry_proto.feature_views: - if feature_view_proto.spec.project == project: + if feature_view_proto.spec.project == project and utils.has_all_tags( + feature_view_proto.spec.tags, tags + ): feature_views.append(FeatureView.from_proto(feature_view_proto)) return feature_views -@registry_proto_cache +@registry_proto_cache_with_tags def list_stream_feature_views( - registry_proto: RegistryProto, project: str + registry_proto: RegistryProto, project: str, tags: Optional[dict[str, str]] ) -> List[StreamFeatureView]: stream_feature_views = [] for stream_feature_view in registry_proto.stream_feature_views: - if stream_feature_view.spec.project == project: + if stream_feature_view.spec.project == project and utils.has_all_tags( + stream_feature_view.spec.tags, tags + ): stream_feature_views.append( StreamFeatureView.from_proto(stream_feature_view) ) return stream_feature_views -@registry_proto_cache +@registry_proto_cache_with_tags def list_on_demand_feature_views( - registry_proto: RegistryProto, project: str + registry_proto: RegistryProto, project: str, tags: Optional[dict[str, str]] ) -> List[OnDemandFeatureView]: on_demand_feature_views = [] for on_demand_feature_view in registry_proto.on_demand_feature_views: - if on_demand_feature_view.spec.project == project: + if on_demand_feature_view.spec.project == project and utils.has_all_tags( + on_demand_feature_view.spec.tags, tags + ): on_demand_feature_views.append( OnDemandFeatureView.from_proto(on_demand_feature_view) ) return on_demand_feature_views -@registry_proto_cache -def list_entities(registry_proto: RegistryProto, project: str) -> List[Entity]: +@registry_proto_cache_with_tags +def list_entities( + registry_proto: RegistryProto, project: str, tags: Optional[dict[str, str]] +) -> List[Entity]: entities 
= [] for entity_proto in registry_proto.entities: - if entity_proto.spec.project == project: + if entity_proto.spec.project == project and utils.has_all_tags( + entity_proto.spec.tags, tags + ): entities.append(Entity.from_proto(entity_proto)) return entities -@registry_proto_cache -def list_data_sources(registry_proto: RegistryProto, project: str) -> List[DataSource]: +@registry_proto_cache_with_tags +def list_data_sources( + registry_proto: RegistryProto, project: str, tags: Optional[dict[str, str]] +) -> List[DataSource]: data_sources = [] for data_source_proto in registry_proto.data_sources: - if data_source_proto.project == project: + if data_source_proto.project == project and utils.has_all_tags( + data_source_proto.tags, tags + ): data_sources.append(DataSource.from_proto(data_source_proto)) return data_sources diff --git a/sdk/python/feast/infra/registry/registry.py b/sdk/python/feast/infra/registry/registry.py index df1a419ccf7..39cdedb4906 100644 --- a/sdk/python/feast/infra/registry/registry.py +++ b/sdk/python/feast/infra/registry/registry.py @@ -272,19 +272,27 @@ def apply_entity(self, entity: Entity, project: str, commit: bool = True): if commit: self.commit() - def list_entities(self, project: str, allow_cache: bool = False) -> List[Entity]: + def list_entities( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[Entity]: registry_proto = self._get_registry_proto( project=project, allow_cache=allow_cache ) - return proto_registry_utils.list_entities(registry_proto, project) + return proto_registry_utils.list_entities(registry_proto, project, tags) def list_data_sources( - self, project: str, allow_cache: bool = False + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, ) -> List[DataSource]: registry_proto = self._get_registry_proto( project=project, allow_cache=allow_cache ) - return proto_registry_utils.list_data_sources(registry_proto, project) + return 
proto_registry_utils.list_data_sources(registry_proto, project, tags) def apply_data_source( self, data_source: DataSource, project: str, commit: bool = True @@ -344,12 +352,15 @@ def apply_feature_service( self.commit() def list_feature_services( - self, project: str, allow_cache: bool = False + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, ) -> List[FeatureService]: registry_proto = self._get_registry_proto( project=project, allow_cache=allow_cache ) - return proto_registry_utils.list_feature_services(registry_proto, project) + return proto_registry_utils.list_feature_services(registry_proto, project, tags) def get_feature_service( self, name: str, project: str, allow_cache: bool = False @@ -418,21 +429,29 @@ def apply_feature_view( self.commit() def list_stream_feature_views( - self, project: str, allow_cache: bool = False + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, ) -> List[StreamFeatureView]: registry_proto = self._get_registry_proto( project=project, allow_cache=allow_cache ) - return proto_registry_utils.list_stream_feature_views(registry_proto, project) + return proto_registry_utils.list_stream_feature_views( + registry_proto, project, tags + ) def list_on_demand_feature_views( - self, project: str, allow_cache: bool = False + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, ) -> List[OnDemandFeatureView]: registry_proto = self._get_registry_proto( project=project, allow_cache=allow_cache ) return proto_registry_utils.list_on_demand_feature_views( - registry_proto, project + registry_proto, project, tags ) def get_on_demand_feature_view( @@ -513,12 +532,15 @@ def apply_materialization( raise FeatureViewNotFoundException(feature_view.name, project) def list_feature_views( - self, project: str, allow_cache: bool = False + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, ) -> 
List[FeatureView]: registry_proto = self._get_registry_proto( project=project, allow_cache=allow_cache ) - return proto_registry_utils.list_feature_views(registry_proto, project) + return proto_registry_utils.list_feature_views(registry_proto, project, tags) def get_feature_view( self, name: str, project: str, allow_cache: bool = False diff --git a/sdk/python/feast/infra/registry/remote.py b/sdk/python/feast/infra/registry/remote.py index 4336db232fb..0eddf03cf64 100644 --- a/sdk/python/feast/infra/registry/remote.py +++ b/sdk/python/feast/infra/registry/remote.py @@ -65,9 +65,14 @@ def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Enti return Entity.from_proto(response) - def list_entities(self, project: str, allow_cache: bool = False) -> List[Entity]: + def list_entities( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[Entity]: request = RegistryServer_pb2.ListEntitiesRequest( - project=project, allow_cache=allow_cache + project=project, allow_cache=allow_cache, tags=tags ) response = self.stub.ListEntities(request) @@ -102,10 +107,13 @@ def get_data_source( return DataSource.from_proto(response) def list_data_sources( - self, project: str, allow_cache: bool = False + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, ) -> List[DataSource]: request = RegistryServer_pb2.ListDataSourcesRequest( - project=project, allow_cache=allow_cache + project=project, allow_cache=allow_cache, tags=tags ) response = self.stub.ListDataSources(request) @@ -142,10 +150,13 @@ def get_feature_service( return FeatureService.from_proto(response) def list_feature_services( - self, project: str, allow_cache: bool = False + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, ) -> List[FeatureService]: request = RegistryServer_pb2.ListFeatureServicesRequest( - project=project, allow_cache=allow_cache + project=project, 
allow_cache=allow_cache, tags=tags ) response = self.stub.ListFeatureServices(request) @@ -200,10 +211,13 @@ def get_stream_feature_view( return StreamFeatureView.from_proto(response) def list_stream_feature_views( - self, project: str, allow_cache: bool = False + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, ) -> List[StreamFeatureView]: request = RegistryServer_pb2.ListStreamFeatureViewsRequest( - project=project, allow_cache=allow_cache + project=project, allow_cache=allow_cache, tags=tags ) response = self.stub.ListStreamFeatureViews(request) @@ -225,10 +239,13 @@ def get_on_demand_feature_view( return OnDemandFeatureView.from_proto(response) def list_on_demand_feature_views( - self, project: str, allow_cache: bool = False + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, ) -> List[OnDemandFeatureView]: request = RegistryServer_pb2.ListOnDemandFeatureViewsRequest( - project=project, allow_cache=allow_cache + project=project, allow_cache=allow_cache, tags=tags ) response = self.stub.ListOnDemandFeatureViews(request) @@ -250,10 +267,13 @@ def get_feature_view( return FeatureView.from_proto(response) def list_feature_views( - self, project: str, allow_cache: bool = False + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, ) -> List[FeatureView]: request = RegistryServer_pb2.ListFeatureViewsRequest( - project=project, allow_cache=allow_cache + project=project, allow_cache=allow_cache, tags=tags ) response = self.stub.ListFeatureViews(request) diff --git a/sdk/python/feast/infra/registry/snowflake.py b/sdk/python/feast/infra/registry/snowflake.py index 87d89af9c87..8f110322a56 100644 --- a/sdk/python/feast/infra/registry/snowflake.py +++ b/sdk/python/feast/infra/registry/snowflake.py @@ -10,6 +10,7 @@ from pydantic import ConfigDict, Field, StrictStr import feast +from feast import utils from feast.base_feature_view import 
BaseFeatureView from feast.data_source import DataSource from feast.entity import Entity @@ -619,34 +620,50 @@ def _get_object( # list operations def list_data_sources( - self, project: str, allow_cache: bool = False + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, ) -> List[DataSource]: if allow_cache: self._refresh_cached_registry_if_necessary() return proto_registry_utils.list_data_sources( - self.cached_registry_proto, project + self.cached_registry_proto, project, tags ) return self._list_objects( - "DATA_SOURCES", project, DataSourceProto, DataSource, "DATA_SOURCE_PROTO" + "DATA_SOURCES", + project, + DataSourceProto, + DataSource, + "DATA_SOURCE_PROTO", + tags=tags, ) - def list_entities(self, project: str, allow_cache: bool = False) -> List[Entity]: + def list_entities( + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[Entity]: if allow_cache: self._refresh_cached_registry_if_necessary() return proto_registry_utils.list_entities( - self.cached_registry_proto, project + self.cached_registry_proto, project, tags ) return self._list_objects( - "ENTITIES", project, EntityProto, Entity, "ENTITY_PROTO" + "ENTITIES", project, EntityProto, Entity, "ENTITY_PROTO", tags=tags ) def list_feature_services( - self, project: str, allow_cache: bool = False + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, ) -> List[FeatureService]: if allow_cache: self._refresh_cached_registry_if_necessary() return proto_registry_utils.list_feature_services( - self.cached_registry_proto, project + self.cached_registry_proto, project, tags ) return self._list_objects( "FEATURE_SERVICES", @@ -654,15 +671,19 @@ def list_feature_services( FeatureServiceProto, FeatureService, "FEATURE_SERVICE_PROTO", + tags=tags, ) def list_feature_views( - self, project: str, allow_cache: bool = False + self, + project: str, + allow_cache: bool = False, + tags: 
Optional[dict[str, str]] = None, ) -> List[FeatureView]: if allow_cache: self._refresh_cached_registry_if_necessary() return proto_registry_utils.list_feature_views( - self.cached_registry_proto, project + self.cached_registry_proto, project, tags ) return self._list_objects( "FEATURE_VIEWS", @@ -670,15 +691,19 @@ def list_feature_views( FeatureViewProto, FeatureView, "FEATURE_VIEW_PROTO", + tags=tags, ) def list_on_demand_feature_views( - self, project: str, allow_cache: bool = False + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, ) -> List[OnDemandFeatureView]: if allow_cache: self._refresh_cached_registry_if_necessary() return proto_registry_utils.list_on_demand_feature_views( - self.cached_registry_proto, project + self.cached_registry_proto, project, tags ) return self._list_objects( "ON_DEMAND_FEATURE_VIEWS", @@ -686,6 +711,7 @@ def list_on_demand_feature_views( OnDemandFeatureViewProto, OnDemandFeatureView, "ON_DEMAND_FEATURE_VIEW_PROTO", + tags=tags, ) def list_saved_datasets( @@ -705,12 +731,15 @@ def list_saved_datasets( ) def list_stream_feature_views( - self, project: str, allow_cache: bool = False + self, + project: str, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, ) -> List[StreamFeatureView]: if allow_cache: self._refresh_cached_registry_if_necessary() return proto_registry_utils.list_stream_feature_views( - self.cached_registry_proto, project + self.cached_registry_proto, project, tags ) return self._list_objects( "STREAM_FEATURE_VIEWS", @@ -718,6 +747,7 @@ def list_stream_feature_views( StreamFeatureViewProto, StreamFeatureView, "STREAM_FEATURE_VIEW_PROTO", + tags=tags, ) def list_validation_references( @@ -738,6 +768,7 @@ def _list_objects( proto_class: Any, python_class: Any, proto_field_name: str, + tags: Optional[dict[str, str]] = None, ): self._maybe_init_project_metadata(project) with GetSnowflakeConnection(self.registry_config) as conn: @@ -750,14 +781,17 @@ def 
_list_objects( project_id = '{project}' """ df = execute_snowflake_statement(conn, query).fetch_pandas_all() - if not df.empty: - return [ + objects = [ python_class.from_proto( proto_class.FromString(row[1][proto_field_name]) ) for row in df.iterrows() ] + for obj in objects: + if not utils.has_all_tags(obj.tags, tags): + objects.remove(obj) + return objects return [] def apply_materialization( diff --git a/sdk/python/feast/infra/registry/sql.py b/sdk/python/feast/infra/registry/sql.py index 26f9da19e18..b10040ae014 100644 --- a/sdk/python/feast/infra/registry/sql.py +++ b/sdk/python/feast/infra/registry/sql.py @@ -21,6 +21,7 @@ ) from sqlalchemy.engine import Engine +from feast import utils from feast.base_feature_view import BaseFeatureView from feast.data_source import DataSource from feast.entity import Entity @@ -220,13 +221,16 @@ def _get_stream_feature_view(self, name: str, project: str): not_found_exception=FeatureViewNotFoundException, ) - def _list_stream_feature_views(self, project: str) -> List[StreamFeatureView]: + def _list_stream_feature_views( + self, project: str, tags: Optional[dict[str, str]] + ) -> List[StreamFeatureView]: return self._list_objects( stream_feature_views, project, StreamFeatureViewProto, StreamFeatureView, "feature_view_proto", + tags=tags, ) def apply_entity(self, entity: Entity, project: str, commit: bool = True): @@ -321,9 +325,11 @@ def _list_validation_references(self, project: str) -> List[ValidationReference] proto_field_name="validation_reference_proto", ) - def _list_entities(self, project: str) -> List[Entity]: + def _list_entities( + self, project: str, tags: Optional[dict[str, str]] + ) -> List[Entity]: return self._list_objects( - entities, project, EntityProto, Entity, "entity_proto" + entities, project, EntityProto, Entity, "entity_proto", tags=tags ) def delete_entity(self, name: str, project: str, commit: bool = True): @@ -365,9 +371,16 @@ def _get_data_source(self, name: str, project: str) -> DataSource: 
not_found_exception=DataSourceObjectNotFoundException, ) - def _list_data_sources(self, project: str) -> List[DataSource]: + def _list_data_sources( + self, project: str, tags: Optional[dict[str, str]] + ) -> List[DataSource]: return self._list_objects( - data_sources, project, DataSourceProto, DataSource, "data_source_proto" + data_sources, + project, + DataSourceProto, + DataSource, + "data_source_proto", + tags=tags, ) def apply_data_source( @@ -407,18 +420,28 @@ def delete_data_source(self, name: str, project: str, commit: bool = True): if rows.rowcount < 1: raise DataSourceObjectNotFoundException(name, project) - def _list_feature_services(self, project: str) -> List[FeatureService]: + def _list_feature_services( + self, project: str, tags: Optional[dict[str, str]] + ) -> List[FeatureService]: return self._list_objects( feature_services, project, FeatureServiceProto, FeatureService, "feature_service_proto", + tags=tags, ) - def _list_feature_views(self, project: str) -> List[FeatureView]: + def _list_feature_views( + self, project: str, tags: Optional[dict[str, str]] + ) -> List[FeatureView]: return self._list_objects( - feature_views, project, FeatureViewProto, FeatureView, "feature_view_proto" + feature_views, + project, + FeatureViewProto, + FeatureView, + "feature_view_proto", + tags=tags, ) def _list_saved_datasets(self, project: str) -> List[SavedDataset]: @@ -430,13 +453,16 @@ def _list_saved_datasets(self, project: str) -> List[SavedDataset]: "saved_dataset_proto", ) - def _list_on_demand_feature_views(self, project: str) -> List[OnDemandFeatureView]: + def _list_on_demand_feature_views( + self, project: str, tags: Optional[dict[str, str]] + ) -> List[OnDemandFeatureView]: return self._list_objects( on_demand_feature_views, project, OnDemandFeatureViewProto, OnDemandFeatureView, "feature_view_proto", + tags=tags, ) def _list_project_metadata(self, project: str) -> List[ProjectMetadata]: @@ -796,18 +822,23 @@ def _list_objects( proto_class: Any, 
python_class: Any, proto_field_name: str, + tags: Optional[dict[str, str]] = None, ): self._maybe_init_project_metadata(project) with self.engine.begin() as conn: stmt = select(table).where(table.c.project_id == project) rows = conn.execute(stmt).all() if rows: - return [ + objects = [ python_class.from_proto( proto_class.FromString(row._mapping[proto_field_name]) ) for row in rows ] + for obj in objects: + if not utils.has_all_tags(obj.tags, tags): + objects.remove(obj) + return objects return [] def _set_last_updated_metadata(self, last_updated: datetime, project: str): diff --git a/sdk/python/feast/registry_server.py b/sdk/python/feast/registry_server.py index 85038ad6ff3..1b6798b022c 100644 --- a/sdk/python/feast/registry_server.py +++ b/sdk/python/feast/registry_server.py @@ -35,12 +35,14 @@ def GetEntity(self, request: RegistryServer_pb2.GetEntityRequest, context): name=request.name, project=request.project, allow_cache=request.allow_cache ).to_proto() - def ListEntities(self, request, context): + def ListEntities(self, request: RegistryServer_pb2.ListEntitiesRequest, context): return RegistryServer_pb2.ListEntitiesResponse( entities=[ entity.to_proto() for entity in self.proxied_registry.list_entities( - project=request.project, allow_cache=request.allow_cache + project=request.project, + allow_cache=request.allow_cache, + tags=dict(request.tags), ) ] ) @@ -66,12 +68,16 @@ def GetDataSource(self, request: RegistryServer_pb2.GetDataSourceRequest, contex name=request.name, project=request.project, allow_cache=request.allow_cache ).to_proto() - def ListDataSources(self, request, context): + def ListDataSources( + self, request: RegistryServer_pb2.ListDataSourcesRequest, context + ): return RegistryServer_pb2.ListDataSourcesResponse( data_sources=[ data_source.to_proto() for data_source in self.proxied_registry.list_data_sources( - project=request.project, allow_cache=request.allow_cache + project=request.project, + allow_cache=request.allow_cache, + 
tags=dict(request.tags), ) ] ) @@ -109,12 +115,16 @@ def ApplyFeatureView( ) return Empty() - def ListFeatureViews(self, request, context): + def ListFeatureViews( + self, request: RegistryServer_pb2.ListFeatureViewsRequest, context + ): return RegistryServer_pb2.ListFeatureViewsResponse( feature_views=[ feature_view.to_proto() for feature_view in self.proxied_registry.list_feature_views( - project=request.project, allow_cache=request.allow_cache + project=request.project, + allow_cache=request.allow_cache, + tags=dict(request.tags), ) ] ) @@ -134,12 +144,16 @@ def GetStreamFeatureView( name=request.name, project=request.project, allow_cache=request.allow_cache ).to_proto() - def ListStreamFeatureViews(self, request, context): + def ListStreamFeatureViews( + self, request: RegistryServer_pb2.ListStreamFeatureViewsRequest, context + ): return RegistryServer_pb2.ListStreamFeatureViewsResponse( stream_feature_views=[ stream_feature_view.to_proto() for stream_feature_view in self.proxied_registry.list_stream_feature_views( - project=request.project, allow_cache=request.allow_cache + project=request.project, + allow_cache=request.allow_cache, + tags=dict(request.tags), ) ] ) @@ -151,12 +165,16 @@ def GetOnDemandFeatureView( name=request.name, project=request.project, allow_cache=request.allow_cache ).to_proto() - def ListOnDemandFeatureViews(self, request, context): + def ListOnDemandFeatureViews( + self, request: RegistryServer_pb2.ListOnDemandFeatureViewsRequest, context + ): return RegistryServer_pb2.ListOnDemandFeatureViewsResponse( on_demand_feature_views=[ on_demand_feature_view.to_proto() for on_demand_feature_view in self.proxied_registry.list_on_demand_feature_views( - project=request.project, allow_cache=request.allow_cache + project=request.project, + allow_cache=request.allow_cache, + tags=dict(request.tags), ) ] ) @@ -185,7 +203,9 @@ def ListFeatureServices( feature_services=[ feature_service.to_proto() for feature_service in 
self.proxied_registry.list_feature_services( - project=request.project, allow_cache=request.allow_cache + project=request.project, + allow_cache=request.allow_cache, + tags=dict(request.tags), ) ] ) diff --git a/sdk/python/feast/utils.py b/sdk/python/feast/utils.py index 47faa7d8c48..dcd8b1d9273 100644 --- a/sdk/python/feast/utils.py +++ b/sdk/python/feast/utils.py @@ -3,7 +3,7 @@ from collections import defaultdict from datetime import datetime from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union, cast import pandas as pd import pyarrow @@ -256,3 +256,25 @@ def _convert_arrow_to_proto( created_timestamps = [None] * table.num_rows return list(zip(entity_keys, features, event_timestamps, created_timestamps)) + + +def has_all_tags( + object_tags: dict[str, str], requested_tags: Optional[dict[str, str]] = None +) -> bool: + if requested_tags is None: + return True + return all(object_tags.get(key, None) == val for key, val in requested_tags.items()) + + +def tags_str_to_dict(tags: Optional[str] = None) -> Optional[dict[str, str]]: + if tags is None: + return None + tags_list = ( + str(tags).strip().strip("()").replace('"', "").replace("'", "").split(",") + ) + return { + key.strip(): value.strip() + for key, value in dict( + cast(tuple[str, str], tag.split(":", 1)) for tag in tags_list if ":" in tag + ).items() + } diff --git a/sdk/python/tests/example_repos/example_feature_repo_1.py b/sdk/python/tests/example_repos/example_feature_repo_1.py index 20a8ad7bd86..daf7b7e7e6f 100644 --- a/sdk/python/tests/example_repos/example_feature_repo_1.py +++ b/sdk/python/tests/example_repos/example_feature_repo_1.py @@ -5,6 +5,7 @@ from feast import Entity, FeatureService, FeatureView, Field, FileSource, PushSource from feast.on_demand_feature_view import on_demand_feature_view from feast.types import Array, Float32, Int64, String +from tests.integration.feature_repos.universal.feature_views import 
TAGS # Note that file source paths are not validated, so there doesn't actually need to be any data # at the paths for these file sources. Since these paths are effectively fake, this example @@ -42,11 +43,13 @@ name="driver", # The name is derived from this argument, not object name. join_keys=["driver_id"], description="driver id", + tags=TAGS, ) customer = Entity( name="customer", # The name is derived from this argument, not object name. join_keys=["customer_id"], + tags=TAGS, ) item = Entity( @@ -137,5 +140,5 @@ def customer_profile_pandas_odfv(inputs: pd.DataFrame) -> pd.DataFrame: all_drivers_feature_service = FeatureService( name="driver_locations_service", features=[driver_locations], - tags={"release": "production"}, + tags=TAGS, ) diff --git a/sdk/python/tests/example_repos/example_feature_repo_with_feature_service_2.py b/sdk/python/tests/example_repos/example_feature_repo_with_feature_service_2.py index 3547c3de86a..49f5bbaf054 100644 --- a/sdk/python/tests/example_repos/example_feature_repo_with_feature_service_2.py +++ b/sdk/python/tests/example_repos/example_feature_repo_with_feature_service_2.py @@ -59,5 +59,5 @@ driver_hourly_stats_view[["conv_rate"]], global_stats_feature_view[["num_rides"]], ], - tags={"release": "production"}, + tags={"release": "qa"}, ) diff --git a/sdk/python/tests/integration/feature_repos/universal/feature_views.py b/sdk/python/tests/integration/feature_repos/universal/feature_views.py index 32649fe5bf0..11ddcb0ecc6 100644 --- a/sdk/python/tests/integration/feature_repos/universal/feature_views.py +++ b/sdk/python/tests/integration/feature_repos/universal/feature_views.py @@ -25,6 +25,8 @@ location, ) +TAGS = {"release": "production"} + def driver_feature_view( data_source: DataSource, @@ -202,6 +204,7 @@ def create_driver_hourly_stats_feature_view(source, infer_features: bool = False ], source=source, ttl=timedelta(hours=2), + tags=TAGS, ) return driver_stats_feature_view @@ -221,6 +224,7 @@ def 
create_driver_hourly_stats_batch_feature_view( ], source=source, ttl=timedelta(hours=2), + tags=TAGS, ) return driver_stats_feature_view @@ -238,6 +242,7 @@ def create_customer_daily_profile_feature_view(source, infer_features: bool = Fa ], source=source, ttl=timedelta(days=2), + tags=TAGS, ) return customer_profile_feature_view @@ -254,6 +259,7 @@ def create_global_stats_feature_view(source, infer_features: bool = False): ], source=source, ttl=timedelta(days=2), + tags=TAGS, ) return global_stats_feature_view diff --git a/sdk/python/tests/integration/online_store/test_universal_online.py b/sdk/python/tests/integration/online_store/test_universal_online.py index 4cb474d2f1a..e78c1053bf8 100644 --- a/sdk/python/tests/integration/online_store/test_universal_online.py +++ b/sdk/python/tests/integration/online_store/test_universal_online.py @@ -29,6 +29,7 @@ ) from tests.integration.feature_repos.universal.entities import driver, item from tests.integration.feature_repos.universal.feature_views import ( + TAGS, create_driver_hourly_stats_feature_view, create_item_embeddings_feature_view, driver_feature_view, @@ -150,9 +151,13 @@ def test_write_to_online_store_event_check(environment): entities=[e], source=file_source, ttl=timedelta(minutes=5), + tags=TAGS, ) # Register Feature View and Entity fs.apply([fv1, e]) + assert len(fs.list_all_feature_views(tags=TAGS)) == 1 + assert len(fs.list_feature_views(tags=TAGS)) == 1 + assert len(fs.list_batch_feature_views(tags=TAGS)) == 1 # data to ingest into Online Store (recent) data = { @@ -410,6 +415,7 @@ def setup_feature_store_universal_feature_views( feature_views = construct_universal_feature_views(data_sources) fs.apply([driver(), feature_views.driver, feature_views.global_fv]) + assert len(fs.list_batch_feature_views(TAGS)) == 2 data = { "driver_id": [1, 2], @@ -499,6 +505,16 @@ def test_async_online_retrieval_with_event_timestamps( assert_feature_store_universal_feature_views_response(df) +@pytest.mark.integration 
+@pytest.mark.universal_online_stores +def test_online_list_retrieval(environment, universal_data_sources): + fs = setup_feature_store_universal_feature_views( + environment, universal_data_sources + ) + + assert len(fs.list_batch_feature_views(tags=TAGS)) == 2 + + @pytest.mark.integration @pytest.mark.universal_online_stores(only=["redis"]) def test_online_store_cleanup(environment, universal_data_sources): diff --git a/sdk/python/tests/unit/local_feast_tests/test_feature_service.py b/sdk/python/tests/unit/local_feast_tests/test_feature_service.py index 82c1dd2a1d9..75ceb463085 100644 --- a/sdk/python/tests/unit/local_feast_tests/test_feature_service.py +++ b/sdk/python/tests/unit/local_feast_tests/test_feature_service.py @@ -6,6 +6,7 @@ create_driver_hourly_stats_df, create_global_daily_stats_df, ) +from tests.integration.feature_repos.universal.feature_views import TAGS from tests.utils.basic_read_write_test import basic_rw_test from tests.utils.cli_repo_creator import CliRunner, get_example_repo @@ -19,6 +20,9 @@ def test_apply_without_fv_inference() -> None: get_example_repo("example_feature_repo_with_feature_service_2.py"), "file" ) as store: assert len(store.list_feature_services()) == 2 + assert len(store.list_feature_services(tags={"release": "qa"})) == 1 + assert len(store.list_feature_services(tags=TAGS)) == 1 + assert len(store.list_feature_services(tags={"wrong": "tag"})) == 0 fs = store.get_feature_service("all_stats") assert len(fs.feature_view_projections) == 2 @@ -35,6 +39,7 @@ def test_apply_without_fv_inference() -> None: assert len(fs.feature_view_projections[0].desired_features) == 0 assert len(fs.feature_view_projections[0].features) == 1 assert len(fs.feature_view_projections[0].desired_features) == 0 + assert fs.tags["release"] == "qa" def test_apply_with_fv_inference() -> None: diff --git a/sdk/python/tests/unit/local_feast_tests/test_local_feature_store.py b/sdk/python/tests/unit/local_feast_tests/test_local_feature_store.py index 
b3e6762c17d..63eafe6fc9a 100644 --- a/sdk/python/tests/unit/local_feast_tests/test_local_feature_store.py +++ b/sdk/python/tests/unit/local_feast_tests/test_local_feature_store.py @@ -4,7 +4,7 @@ import pytest from pytest_lazyfixture import lazy_fixture -from feast import BatchFeatureView +from feast import BatchFeatureView, utils from feast.aggregation import Aggregation from feast.data_format import AvroFormat, ParquetFormat from feast.data_source import KafkaSource @@ -17,6 +17,7 @@ from feast.repo_config import RepoConfig from feast.stream_feature_view import stream_feature_view from feast.types import Array, Bytes, Float32, Int64, String +from tests.integration.feature_repos.universal.feature_views import TAGS from tests.utils.cli_repo_creator import CliRunner, get_example_repo from tests.utils.data_source_test_creator import prep_file_source @@ -89,7 +90,7 @@ def test_apply_feature_view(test_feature_store): Field(name="entity_id", dtype=Int64), ], entities=[entity], - tags={"team": "matchmaking"}, + tags={"team": "matchmaking", "tag": "two"}, source=batch_source, ttl=timedelta(minutes=5), ) @@ -97,11 +98,36 @@ def test_apply_feature_view(test_feature_store): # Register Feature View test_feature_store.apply([entity, fv1, bfv]) - feature_views = test_feature_store.list_feature_views() + # List Feature Views + assert len(test_feature_store.list_batch_feature_views({})) == 2 + feature_views = test_feature_store.list_batch_feature_views() + assert ( + len(feature_views) == 2 + and feature_views[0].name == "my_feature_view_1" + and feature_views[0].features[0].name == "fs1_my_feature_1" + and feature_views[0].features[0].dtype == Int64 + and feature_views[0].features[1].name == "fs1_my_feature_2" + and feature_views[0].features[1].dtype == String + and feature_views[0].features[2].name == "fs1_my_feature_3" + and feature_views[0].features[2].dtype == Array(String) + and feature_views[0].features[3].name == "fs1_my_feature_4" + and feature_views[0].features[3].dtype 
== Array(Bytes) + and feature_views[0].entities[0] == "fs1_my_entity_1" + ) + + assert utils.tags_str_to_dict() is None + assert utils.has_all_tags({}) + + tags_dict = {"team": "matchmaking"} + tags_filter = utils.tags_str_to_dict("('team:matchmaking',)") + assert tags_filter == tags_dict # List Feature Views + feature_views = test_feature_store.list_batch_feature_views(tags=tags_filter) assert ( len(feature_views) == 2 + and utils.has_all_tags(feature_views[0].tags, tags_filter) + and utils.has_all_tags(feature_views[1].tags, tags_filter) and feature_views[0].name == "my_feature_view_1" and feature_views[0].features[0].name == "fs1_my_feature_1" and feature_views[0].features[0].dtype == Int64 @@ -114,6 +140,34 @@ def test_apply_feature_view(test_feature_store): and feature_views[0].entities[0] == "fs1_my_entity_1" ) + tags_dict = {"team": "matchmaking", "tag": "two"} + tags_filter = utils.tags_str_to_dict("(' team :matchmaking, tag: two ',)") + assert tags_filter == tags_dict + + # List Feature Views + feature_views = test_feature_store.list_batch_feature_views(tags=tags_filter) + assert ( + len(feature_views) == 1 + and utils.has_all_tags(feature_views[0].tags, tags_filter) + and feature_views[0].name == "batch_feature_view" + and feature_views[0].features[0].name == "fs1_my_feature_1" + and feature_views[0].features[0].dtype == Int64 + and feature_views[0].features[1].name == "fs1_my_feature_2" + and feature_views[0].features[1].dtype == String + and feature_views[0].features[2].name == "fs1_my_feature_3" + and feature_views[0].features[2].dtype == Array(String) + and feature_views[0].features[3].name == "fs1_my_feature_4" + and feature_views[0].features[3].dtype == Array(Bytes) + and feature_views[0].entities[0] == "fs1_my_entity_1" + ) + + tags_dict = {"missing": "tag"} + tags_filter = utils.tags_str_to_dict("('missing:tag,fdsa',fdas)") + assert tags_filter == tags_dict + + # List Feature Views + assert 
len(test_feature_store.list_batch_feature_views(tags=tags_filter)) == 0 + test_feature_store.teardown() @@ -136,7 +190,7 @@ def test_apply_feature_view_with_inline_batch_source( test_feature_store.apply([entity, driver_fv]) - fvs = test_feature_store.list_feature_views() + fvs = test_feature_store.list_batch_feature_views() assert len(fvs) == 1 assert fvs[0] == driver_fv @@ -185,7 +239,7 @@ def test_apply_feature_view_with_inline_stream_source( test_feature_store.apply([entity, driver_fv]) - fvs = test_feature_store.list_feature_views() + fvs = test_feature_store.list_batch_feature_views() assert len(fvs) == 1 assert fvs[0] == driver_fv @@ -525,10 +579,12 @@ def test_apply_stream_source(test_feature_store, simple_dataset_1) -> None: topic="topic", batch_source=file_source, watermark_delay_threshold=timedelta(days=1), + tags=TAGS, ) test_feature_store.apply([stream_source]) + assert len(test_feature_store.list_data_sources(tags=TAGS)) == 1 ds = test_feature_store.list_data_sources() assert len(ds) == 2 if isinstance(ds[0], FileSource): diff --git a/sdk/python/tests/unit/online_store/test_online_retrieval.py b/sdk/python/tests/unit/online_store/test_online_retrieval.py index 13b220fbb97..1e8cf45dcc6 100644 --- a/sdk/python/tests/unit/online_store/test_online_retrieval.py +++ b/sdk/python/tests/unit/online_store/test_online_retrieval.py @@ -17,6 +17,7 @@ from feast.protos.feast.types.Value_pb2 import FloatList as FloatListProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.repo_config import RegistryConfig +from tests.integration.feature_repos.universal.feature_views import TAGS from tests.utils.cli_repo_creator import CliRunner, get_example_repo @@ -96,6 +97,9 @@ def test_get_online_features() -> None: progress=None, ) + assert len(store.list_entities()) == 3 + assert len(store.list_entities(tags=TAGS)) == 2 + # Retrieve two features using two keys, one valid one non-existing result = store.get_online_features( features=[