From 353028062e16ebb4038f04b10ee9dea824cde7ae Mon Sep 17 00:00:00 2001 From: Abhin Chhabra Date: Tue, 10 May 2022 17:56:37 -0400 Subject: [PATCH] Fixes #2576 Reject undefined features when using `get_historical_features` or `get_online_features`. Signed-off-by: Abhin Chhabra --- sdk/python/feast/feature_store.py | 9 ++- sdk/python/feast/feature_view_projection.py | 8 +++ .../test_universal_historical_retrieval.py | 42 +++++++++++++- .../online_store/test_universal_online.py | 56 ++++++++++++++++++- 4 files changed, 112 insertions(+), 3 deletions(-) diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 60517021e1..086479236f 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -875,7 +875,7 @@ def get_historical_features( DeprecationWarning, ) - # TODO(achal): _group_feature_refs returns the on demand feature views, but it's no passed into the provider. + # TODO(achal): _group_feature_refs returns the on demand feature views, but it's not passed into the provider. # This is a weird interface quirk - we should revisit the `get_historical_features` to # pass in the on demand feature views as well. fvs, odfvs, request_fvs, request_fv_refs = _group_feature_refs( @@ -2125,8 +2125,12 @@ def _group_feature_refs( for ref in features: view_name, feat_name = ref.split(":") if view_name in view_index: + view_index[view_name].projection.get_feature(feat_name) # For validation views_features[view_name].add(feat_name) elif view_name in on_demand_view_index: + on_demand_view_index[view_name].projection.get_feature( + feat_name + ) # For validation on_demand_view_features[view_name].add(feat_name) # Let's also add in any FV Feature dependencies here. for input_fv_projection in on_demand_view_index[ @@ -2135,6 +2139,9 @@ def _group_feature_refs( for input_feat in input_fv_projection.features: views_features[input_fv_projection.name].add(input_feat.name) elif view_name in request_view_index: + request_view_index[view_name].projection.get_feature( + feat_name + ) # For validation request_views_features[view_name].add(feat_name) request_view_refs.add(ref) else: diff --git a/sdk/python/feast/feature_view_projection.py b/sdk/python/feast/feature_view_projection.py index a8e0e8cfe5..fbf0db5ccd 100644 --- a/sdk/python/feast/feature_view_projection.py +++ b/sdk/python/feast/feature_view_projection.py @@ -64,3 +64,11 @@ def from_definition(base_feature_view: "BaseFeatureView"): name_alias=None, features=base_feature_view.features, ) + + def get_feature(self, feature_name: str) -> Field: + try: + return next(field for field in self.features if field.name == feature_name) + except StopIteration: + raise KeyError( + f"Feature {feature_name} not found in projection {self.name_to_use()}" + ) diff --git a/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py b/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py index 1b7dab2110..d5f49a1f95 100644 --- a/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py +++ b/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py @@ -21,7 +21,7 @@ from feast.infra.offline_stores.offline_utils import ( DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL, ) -from feast.types import Int32 +from feast.types import Float32, Int32 from feast.value_type import ValueType from tests.integration.feature_repos.repo_configuration import ( construct_universal_feature_views, @@ -410,6 +410,46 @@ def test_historical_features(environment, universal_data_sources, full_feature_n ) +@pytest.mark.integration +@pytest.mark.universal +@pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: str(v)) +def test_historical_features_with_shared_batch_source( + environment, universal_data_sources, full_feature_names +): + # Addresses https://github.com/feast-dev/feast/issues/2576 + + store = environment.feature_store + + entities, datasets, data_sources = universal_data_sources + driver_stats_v1 = FeatureView( + name="driver_stats_v1", + entities=["driver"], + schema=[Field(name="avg_daily_trips", dtype=Int32)], + source=data_sources.driver, + ) + driver_stats_v2 = FeatureView( + name="driver_stats_v2", + entities=["driver"], + schema=[ + Field(name="avg_daily_trips", dtype=Int32), + Field(name="conv_rate", dtype=Float32), + ], + source=data_sources.driver, + ) + + store.apply([driver(), driver_stats_v1, driver_stats_v2]) + + with pytest.raises(KeyError): + store.get_historical_features( + entity_df=datasets.entity_df, + features=[ + # `driver_stats_v1` does not have `conv_rate` + "driver_stats_v1:conv_rate", + ], + full_feature_names=full_feature_names, + ).to_df() + + @pytest.mark.integration @pytest.mark.universal_offline_stores def test_historical_features_with_missing_request_data( diff --git a/sdk/python/tests/integration/online_store/test_universal_online.py b/sdk/python/tests/integration/online_store/test_universal_online.py index ababb25c39..259a094426 100644 --- a/sdk/python/tests/integration/online_store/test_universal_online.py +++ b/sdk/python/tests/integration/online_store/test_universal_online.py @@ -19,7 +19,7 @@ RequestDataNotFoundInEntityRowsException, ) from feast.online_response import TIMESTAMP_POSTFIX -from feast.types import String +from feast.types import Float32, Int32, String from feast.wait import wait_retry_backoff from tests.integration.feature_repos.repo_configuration import ( Environment, @@ -324,6 +324,60 @@ def get_online_features_dict( return dict1 +@pytest.mark.integration +@pytest.mark.universal +def test_online_retrieval_with_shared_batch_source(environment, universal_data_sources): + # Addresses https://github.com/feast-dev/feast/issues/2576 + + fs = environment.feature_store + + entities, datasets, data_sources = universal_data_sources + driver_stats_v1 = FeatureView( + name="driver_stats_v1", + entities=["driver"], + schema=[Field(name="avg_daily_trips", dtype=Int32)], + source=data_sources.driver, + ) + driver_stats_v2 = FeatureView( + name="driver_stats_v2", + entities=["driver"], + schema=[ + Field(name="avg_daily_trips", dtype=Int32), + Field(name="conv_rate", dtype=Float32), + ], + source=data_sources.driver, + ) + + fs.apply([driver(), driver_stats_v1, driver_stats_v2]) + + data = pd.DataFrame( + { + "driver_id": [1, 2], + "avg_daily_trips": [4, 5], + "conv_rate": [0.5, 0.3], + "event_timestamp": [ + pd.to_datetime(1646263500, utc=True, unit="s"), + pd.to_datetime(1646263600, utc=True, unit="s"), + ], + "created": [ + pd.to_datetime(1646263500, unit="s"), + pd.to_datetime(1646263600, unit="s"), + ], + } + ) + fs.write_to_online_store("driver_stats_v1", data.drop("conv_rate", axis=1)) + fs.write_to_online_store("driver_stats_v2", data) + + with pytest.raises(KeyError): + fs.get_online_features( + features=[ + # `driver_stats_v1` does not have `conv_rate` + "driver_stats_v1:conv_rate", + ], + entity_rows=[{"driver_id": 1}, {"driver_id": 2}], + ) + + @pytest.mark.integration @pytest.mark.universal_online_stores @pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: str(v))