diff --git a/sdk/python/tests/integration/feature_repos/test_repo_configuration.py b/sdk/python/tests/integration/feature_repos/test_repo_configuration.py
index c7a2046dca..f8f3d42cdf 100644
--- a/sdk/python/tests/integration/feature_repos/test_repo_configuration.py
+++ b/sdk/python/tests/integration/feature_repos/test_repo_configuration.py
@@ -176,7 +176,9 @@ def vary_providers_for_offline_stores(
 
 @contextmanager
 def construct_test_environment(
-    test_repo_config: TestRepoConfig, create_and_apply: bool = False
+    test_repo_config: TestRepoConfig,
+    create_and_apply: bool = False,
+    materialize: bool = False,
 ) -> Environment:
     """
     This method should take in the parameters from the test repo config and created a feature repo, apply it,
@@ -236,6 +238,9 @@ def construct_test_environment(
                 )
                 fs.apply(fvs + entities)
 
+            if materialize:
+                fs.materialize(environment.start_date, environment.end_date)
+
             yield environment
         finally:
             offline_creator.teardown()
@@ -266,13 +271,14 @@ def inner_test(config):
 
 def parametrize_offline_retrieval_test(offline_retrieval_test):
     """
-    This decorator should be used for end-to-end tests. These tests are expected to be parameterized,
-    and receive an empty feature repo created for all supported configurations.
+    This decorator should be used by tests that rely on the offline store. These tests are expected to be parameterized,
+    and receive an Environment object that contains a reference to a Feature Store with pre-applied
+    entities and feature views.
 
     The decorator also ensures that sample data needed for the test is available in the relevant offline store.
 
-    Decorated tests should create and apply the objects needed by the tests, and perform any operations needed
-    (such as materialization and looking up feature values).
+    Decorated tests should interact with the offline store, via the FeatureStore.get_historical_features method. They
+    may perform more operations as needed.
 
     The decorator takes care of tearing down the feature store, as well as the sample data.
     """
@@ -288,3 +294,30 @@ def inner_test(config):
             offline_retrieval_test(environment)
 
     return inner_test
+
+
+def parametrize_online_test(online_test):
+    """
+    This decorator should be used by tests that rely on the online store. These tests are expected to be parameterized,
+    and receive an Environment object that contains a reference to a Feature Store with pre-applied
+    entities and feature views.
+
+    The decorator also ensures that sample data needed for the test is available in the relevant offline store. This
+    data is also materialized into the online store.
+
+    The decorator takes care of tearing down the feature store, as well as the sample data.
+    """
+
+    configs = vary_providers_for_offline_stores(FULL_REPO_CONFIGS)
+    configs = vary_full_feature_names(configs)
+    configs = vary_infer_event_timestamp_col(configs)
+
+    @pytest.mark.integration
+    @pytest.mark.parametrize("config", configs, ids=lambda v: str(v))
+    def inner_test(config):
+        with construct_test_environment(
+            config, create_and_apply=True, materialize=True
+        ) as environment:
+            online_test(environment)
+
+    return inner_test
diff --git a/sdk/python/tests/integration/feature_repos/universal/data_sources/file.py b/sdk/python/tests/integration/feature_repos/universal/data_sources/file.py
index bb1957eb4a..b386c399dc 100644
--- a/sdk/python/tests/integration/feature_repos/universal/data_sources/file.py
+++ b/sdk/python/tests/integration/feature_repos/universal/data_sources/file.py
@@ -35,7 +35,7 @@ def create_data_sources(
             event_timestamp_column=event_timestamp_column,
             created_timestamp_column=created_timestamp_column,
             date_partition_column="",
-            field_mapping=field_mapping or {"ts_1": "ts", "id": "driver_id"},
+            field_mapping=field_mapping or {"ts_1": "ts"},
         )
 
     def get_prefixed_table_name(self, name: str, suffix: str) -> str:
diff --git a/sdk/python/tests/integration/feature_repos/universal/feature_views.py b/sdk/python/tests/integration/feature_repos/universal/feature_views.py
index b5a120f453..0306044ecd 100644
--- a/sdk/python/tests/integration/feature_repos/universal/feature_views.py
+++ b/sdk/python/tests/integration/feature_repos/universal/feature_views.py
@@ -39,7 +39,6 @@ def create_customer_daily_profile_feature_view(source):
             Feature(name="current_balance", dtype=ValueType.FLOAT),
             Feature(name="avg_passenger_count", dtype=ValueType.FLOAT),
             Feature(name="lifetime_trip_count", dtype=ValueType.INT32),
-            Feature(name="avg_daily_trips", dtype=ValueType.INT32),
         ],
         batch_source=source,
         ttl=timedelta(days=2),
diff --git a/sdk/python/tests/integration/online_store/test_universal_online.py b/sdk/python/tests/integration/online_store/test_universal_online.py
new file mode 100644
index 0000000000..a4337a305c
--- /dev/null
+++ b/sdk/python/tests/integration/online_store/test_universal_online.py
@@ -0,0 +1,115 @@
+import random
+import unittest
+
+import pandas as pd
+
+from tests.integration.feature_repos.test_repo_configuration import (
+    Environment,
+    parametrize_online_test,
+)
+
+
+@parametrize_online_test
+def test_online_retrieval(environment: Environment):
+    fs = environment.feature_store
+    full_feature_names = environment.test_repo_config.full_feature_names
+
+    sample_drivers = random.sample(environment.driver_entities, 10)
+    drivers_df = environment.driver_df[
+        environment.driver_df["driver_id"].isin(sample_drivers)
+    ]
+
+    sample_customers = random.sample(environment.customer_entities, 10)
+    customers_df = environment.customer_df[
+        environment.customer_df["customer_id"].isin(sample_customers)
+    ]
+
+    entity_rows = [
+        {"driver": d, "customer_id": c}
+        for (d, c) in zip(sample_drivers, sample_customers)
+    ]
+
+    feature_refs = [
+        "driver_stats:conv_rate",
+        "driver_stats:avg_daily_trips",
+        "customer_profile:current_balance",
+        "customer_profile:avg_passenger_count",
+        "customer_profile:lifetime_trip_count",
+    ]
+    unprefixed_feature_refs = [f.rsplit(":", 1)[-1] for f in feature_refs]
+
+    online_features = fs.get_online_features(
+        features=feature_refs,
+        entity_rows=entity_rows,
+        full_feature_names=full_feature_names,
+    )
+    assert online_features is not None
+
+    keys = online_features.to_dict().keys()
+    assert (
+        len(keys) == len(feature_refs) + 2
+    )  # Add two for the driver id and the customer id entity keys.
+    for feature in feature_refs:
+        if full_feature_names:
+            assert feature.replace(":", "__") in keys
+        else:
+            assert feature.rsplit(":", 1)[-1] in keys
+            assert "driver_stats" not in keys and "customer_profile" not in keys
+
+    online_features_dict = online_features.to_dict()
+    tc = unittest.TestCase()
+    for i, entity_row in enumerate(entity_rows):
+        df_features = get_latest_feature_values_from_dataframes(
+            drivers_df, customers_df, entity_row
+        )
+
+        assert df_features["customer_id"] == online_features_dict["customer_id"][i]
+        assert df_features["driver_id"] == online_features_dict["driver_id"][i]
+        for unprefixed_feature_ref in unprefixed_feature_refs:
+            tc.assertEqual(
+                df_features[unprefixed_feature_ref],
+                online_features_dict[
+                    response_feature_name(unprefixed_feature_ref, full_feature_names)
+                ][i],
+            )
+
+    # Check what happens for missing values
+    missing_responses_dict = fs.get_online_features(
+        features=feature_refs,
+        entity_rows=[{"driver": 0, "customer_id": 0}],
+        full_feature_names=full_feature_names,
+    ).to_dict()
+    assert missing_responses_dict is not None
+    for unprefixed_feature_ref in unprefixed_feature_refs:
+        tc.assertIsNone(
+            missing_responses_dict[
+                response_feature_name(unprefixed_feature_ref, full_feature_names)
+            ][0]
+        )
+
+
+def response_feature_name(feature: str, full_feature_names: bool) -> str:
+    if (
+        feature in {"current_balance", "avg_passenger_count", "lifetime_trip_count"}
+        and full_feature_names
+    ):
+        return f"customer_profile__{feature}"
+
+    if feature in {"conv_rate", "avg_daily_trips"} and full_feature_names:
+        return f"driver_stats__{feature}"
+
+    return feature
+
+
+def get_latest_feature_values_from_dataframes(driver_df, customer_df, entity_row):
+    driver_rows = driver_df[driver_df["driver_id"] == entity_row["driver"]]
+    latest_driver_row: pd.DataFrame = driver_rows.loc[
+        driver_rows["event_timestamp"].idxmax()
+    ].to_dict()
+    customer_rows = customer_df[customer_df["customer_id"] == entity_row["customer_id"]]
+    latest_customer_row = customer_rows.loc[
+        customer_rows["event_timestamp"].idxmax()
+    ].to_dict()
+
+    latest_customer_row.update(latest_driver_row)
+    return latest_customer_row