Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add online feature retrieval integration test using the universal repo #1783

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,9 @@ def vary_providers_for_offline_stores(

@contextmanager
def construct_test_environment(
test_repo_config: TestRepoConfig, create_and_apply: bool = False
test_repo_config: TestRepoConfig,
create_and_apply: bool = False,
materialize: bool = False,
) -> Environment:
"""
This method should take in the parameters from the test repo config and created a feature repo, apply it,
Expand Down Expand Up @@ -236,6 +238,9 @@ def construct_test_environment(
)
fs.apply(fvs + entities)

if materialize:
fs.materialize(environment.start_date, environment.end_date)

yield environment
finally:
offline_creator.teardown()
Expand Down Expand Up @@ -266,13 +271,14 @@ def inner_test(config):

def parametrize_offline_retrieval_test(offline_retrieval_test):
"""
This decorator should be used for end-to-end tests. These tests are expected to be parameterized,
and receive an empty feature repo created for all supported configurations.
This decorator should be used by tests that rely on the offline store. These tests are expected to be parameterized,
and receive an Environment object that contains a reference to a Feature Store with pre-applied
entities and feature views.

The decorator also ensures that sample data needed for the test is available in the relevant offline store.

Decorated tests should create and apply the objects needed by the tests, and perform any operations needed
(such as materialization and looking up feature values).
Decorated tests should interact with the offline store, via the FeatureStore.get_historical_features method. They
may perform more operations as needed.

The decorator takes care of tearing down the feature store, as well as the sample data.
"""
Expand All @@ -288,3 +294,30 @@ def inner_test(config):
offline_retrieval_test(environment)

return inner_test


def parametrize_online_test(online_test):
"""
This decorator should be used by tests that rely on the offline store. These tests are expected to be parameterized,
and receive an Environment object that contains a reference to a Feature Store with pre-applied
entities and feature views.

The decorator also ensures that sample data needed for the test is available in the relevant offline store. This
data is also materialized into the online store.

The decorator takes care of tearing down the feature store, as well as the sample data.
"""

configs = vary_providers_for_offline_stores(FULL_REPO_CONFIGS)
configs = vary_full_feature_names(configs)
configs = vary_infer_event_timestamp_col(configs)

@pytest.mark.integration
@pytest.mark.parametrize("config", configs, ids=lambda v: str(v))
def inner_test(config):
with construct_test_environment(
config, create_and_apply=True, materialize=True
) as environment:
online_test(environment)

return inner_test
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def create_data_sources(
event_timestamp_column=event_timestamp_column,
created_timestamp_column=created_timestamp_column,
date_partition_column="",
field_mapping=field_mapping or {"ts_1": "ts", "id": "driver_id"},
adchia marked this conversation as resolved.
Show resolved Hide resolved
field_mapping=field_mapping or {"ts_1": "ts"},
)

def get_prefixed_table_name(self, name: str, suffix: str) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def create_customer_daily_profile_feature_view(source):
Feature(name="current_balance", dtype=ValueType.FLOAT),
Feature(name="avg_passenger_count", dtype=ValueType.FLOAT),
Feature(name="lifetime_trip_count", dtype=ValueType.INT32),
Feature(name="avg_daily_trips", dtype=ValueType.INT32),
],
batch_source=source,
ttl=timedelta(days=2),
Expand Down
115 changes: 115 additions & 0 deletions sdk/python/tests/integration/online_store/test_universal_online.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import random
import unittest

import pandas as pd

from tests.integration.feature_repos.test_repo_configuration import (
Environment,
parametrize_online_test,
)


@parametrize_online_test
achals marked this conversation as resolved.
Show resolved Hide resolved
def test_online_retrieval(environment: Environment):
fs = environment.feature_store
full_feature_names = environment.test_repo_config.full_feature_names

sample_drivers = random.sample(environment.driver_entities, 10)
drivers_df = environment.driver_df[
environment.driver_df["driver_id"].isin(sample_drivers)
]

sample_customers = random.sample(environment.customer_entities, 10)
customers_df = environment.customer_df[
environment.customer_df["customer_id"].isin(sample_customers)
]

entity_rows = [
{"driver": d, "customer_id": c}
for (d, c) in zip(sample_drivers, sample_customers)
]

feature_refs = [
"driver_stats:conv_rate",
"driver_stats:avg_daily_trips",
"customer_profile:current_balance",
"customer_profile:avg_passenger_count",
"customer_profile:lifetime_trip_count",
]
unprefixed_feature_refs = [f.rsplit(":", 1)[-1] for f in feature_refs]

online_features = fs.get_online_features(
features=feature_refs,
entity_rows=entity_rows,
full_feature_names=full_feature_names,
)
assert online_features is not None

keys = online_features.to_dict().keys()
assert (
len(keys) == len(feature_refs) + 2
) # Add two for the driver id and the customer id entity keys.
for feature in feature_refs:
if full_feature_names:
assert feature.replace(":", "__") in keys
else:
assert feature.rsplit(":", 1)[-1] in keys
assert "driver_stats" not in keys and "customer_profile" not in keys

online_features_dict = online_features.to_dict()
tc = unittest.TestCase()
for i, entity_row in enumerate(entity_rows):
df_features = get_latest_feature_values_from_dataframes(
drivers_df, customers_df, entity_row
)

assert df_features["customer_id"] == online_features_dict["customer_id"][i]
assert df_features["driver_id"] == online_features_dict["driver_id"][i]
for unprefixed_feature_ref in unprefixed_feature_refs:
tc.assertEqual(
df_features[unprefixed_feature_ref],
online_features_dict[
response_feature_name(unprefixed_feature_ref, full_feature_names)
][i],
)

# Check what happens for missing values
missing_responses_dict = fs.get_online_features(
features=feature_refs,
entity_rows=[{"driver": 0, "customer_id": 0}],
full_feature_names=full_feature_names,
).to_dict()
assert missing_responses_dict is not None
for unprefixed_feature_ref in unprefixed_feature_refs:
tc.assertIsNone(
missing_responses_dict[
response_feature_name(unprefixed_feature_ref, full_feature_names)
][0]
)


def response_feature_name(feature: str, full_feature_names: bool) -> str:
if (
feature in {"current_balance", "avg_passenger_count", "lifetime_trip_count"}
and full_feature_names
):
return f"customer_profile__{feature}"

if feature in {"conv_rate", "avg_daily_trips"} and full_feature_names:
return f"driver_stats__{feature}"

return feature


def get_latest_feature_values_from_dataframes(driver_df, customer_df, entity_row):
driver_rows = driver_df[driver_df["driver_id"] == entity_row["driver"]]
latest_driver_row: pd.DataFrame = driver_rows.loc[
driver_rows["event_timestamp"].idxmax()
].to_dict()
customer_rows = customer_df[customer_df["customer_id"] == entity_row["customer_id"]]
latest_customer_row = customer_rows.loc[
customer_rows["event_timestamp"].idxmax()
].to_dict()

latest_customer_row.update(latest_driver_row)
adchia marked this conversation as resolved.
Show resolved Hide resolved
return latest_customer_row