From 174c8cf018a96e28f099fd2f21107cd8a15c5c0c Mon Sep 17 00:00:00 2001
From: "john.lemmon"
Date: Wed, 14 Feb 2024 15:41:48 -0600
Subject: [PATCH] Fix for entityless feature views in Snowflake

Signed-off-by: john.lemmon
---
 .../infra/materialization/snowflake_engine.py | 12 ++--
 .../materialization/test_snowflake.py         | 64 +++++++++++++++++++
 2 files changed, 70 insertions(+), 6 deletions(-)

diff --git a/sdk/python/feast/infra/materialization/snowflake_engine.py b/sdk/python/feast/infra/materialization/snowflake_engine.py
index 62b23dfade..28bec198a5 100644
--- a/sdk/python/feast/infra/materialization/snowflake_engine.py
+++ b/sdk/python/feast/infra/materialization/snowflake_engine.py
@@ -14,7 +14,7 @@
 import feast
 from feast.batch_feature_view import BatchFeatureView
 from feast.entity import Entity
-from feast.feature_view import FeatureView
+from feast.feature_view import DUMMY_ENTITY_ID, FeatureView
 from feast.infra.materialization.batch_materialization_engine import (
     BatchMaterializationEngine,
     MaterializationJob,
@@ -274,7 +274,11 @@ def _materialize_one(
 
         fv_latest_values_sql = offline_job.to_sql()
 
-        if feature_view.entity_columns:
+        if (
+            feature_view.entity_columns[0].name == DUMMY_ENTITY_ID
+        ):  # entityless Feature View's placeholder entity
+            entities_to_write = 1
+        else:
             join_keys = [entity.name for entity in feature_view.entity_columns]
             unique_entities = '"' + '", "'.join(join_keys) + '"'
 
@@ -287,10 +291,6 @@ def _materialize_one(
 
         with GetSnowflakeConnection(self.repo_config.offline_store) as conn:
             entities_to_write = conn.cursor().execute(query).fetchall()[0][0]
-        else:
-            entities_to_write = (
-                1  # entityless feature view has a placeholder entity
-            )
 
         if feature_view.batch_source.field_mapping is not None:
             fv_latest_mapped_values_sql = _run_snowflake_field_mapping(
diff --git a/sdk/python/tests/integration/materialization/test_snowflake.py b/sdk/python/tests/integration/materialization/test_snowflake.py
index daa96a87c9..60fa9b30aa 100644
--- a/sdk/python/tests/integration/materialization/test_snowflake.py
+++ b/sdk/python/tests/integration/materialization/test_snowflake.py
@@ -185,3 +185,67 @@ def test_snowflake_materialization_consistency_internal_with_lists(
     finally:
         fs.teardown()
         snowflake_environment.data_source_creator.teardown()
+
+
+@pytest.mark.integration
+def test_snowflake_materialization_entityless_fv():
+    snowflake_config = IntegrationTestRepoConfig(
+        online_store=SNOWFLAKE_ONLINE_CONFIG,
+        offline_store_creator=SnowflakeDataSourceCreator,
+        batch_engine=SNOWFLAKE_ENGINE_CONFIG,
+    )
+    snowflake_environment = construct_test_environment(snowflake_config, None)
+
+    df = create_basic_driver_dataset()
+    entityless_df = df.drop("driver_id", axis=1)
+    ds = snowflake_environment.data_source_creator.create_data_source(
+        entityless_df,
+        snowflake_environment.feature_store.project,
+        field_mapping={"ts_1": "ts"},
+    )
+
+    fs = snowflake_environment.feature_store
+
+    # We include the driver entity so we can provide an entity ID when fetching features
+    driver = Entity(
+        name="driver_id",
+        join_keys=["driver_id"],
+    )
+
+    overall_stats_fv = FeatureView(
+        name="overall_hourly_stats",
+        entities=[],
+        ttl=timedelta(weeks=52),
+        source=ds,
+    )
+
+    try:
+        fs.apply([overall_stats_fv, driver])
+
+        # use a timestamp from the generated dataframe as the cut-off
+        # point for the materialization window
+        split_dt = df["ts_1"][4].to_pydatetime() - timedelta(seconds=1)
+
+        print(f"Split datetime: {split_dt}")
+
+        now = datetime.utcnow()
+
+        start_date = (now - timedelta(hours=5)).replace(tzinfo=utc)
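+        # materialize a single window, from five hours ago up to the cut-off;
+        # rows stamped later than split_dt should stay out of the online store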
+        end_date = split_dt
+        fs.materialize(
+            feature_views=[overall_stats_fv.name],
+            start_date=start_date,
+            end_date=end_date,
+        )
+
+        response_dict = fs.get_online_features(
+            [f"{overall_stats_fv.name}:value"],
+            [{"driver_id": 1}],  # Included because we need an entity
+        ).to_dict()
+        assert response_dict["value"] == [0.3]
+
+    finally:
+        fs.teardown()
+        snowflake_environment.data_source_creator.teardown()
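
A minimal sketch of the entityless case for context (the FeatureView, source,
and path below are hypothetical; DUMMY_ENTITY_ID is the placeholder join key
feast assigns to entityless feature views):

    from datetime import timedelta

    from feast import FeatureView, FileSource
    from feast.feature_view import DUMMY_ENTITY_ID

    # entities=[] declares an entityless FeatureView; any batch source
    # behaves the same way as the Snowflake one used in the test above.
    stats_fv = FeatureView(
        name="overall_stats",
        entities=[],
        ttl=timedelta(weeks=52),
        source=FileSource(path="stats.parquet", timestamp_field="ts"),
    )

    # FeatureStore.apply() fills in entity_columns, and for an entityless
    # view it inserts a placeholder entity column named DUMMY_ENTITY_ID.
    # The old `if feature_view.entity_columns:` check was therefore truthy
    # even for entityless views, so they never reached the branch that
    # writes a single placeholder row; the engine now detects the
    # placeholder by name, as in the hunk above:
    if stats_fv.entity_columns and (
        stats_fv.entity_columns[0].name == DUMMY_ENTITY_ID
    ):
        entities_to_write = 1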