Skip to content

Commit

Permalink
Add online feature retrieval integration test using the universal repo (
Browse files Browse the repository at this point in the history
#1783)

* Add online feature retrieval integration test using the universal repo

Signed-off-by: Achal Shah <achals@gmail.com>

* Comments

Signed-off-by: Achal Shah <achals@gmail.com>

* Comments

Signed-off-by: Achal Shah <achals@gmail.com>

* meaty online tests

Signed-off-by: Achal Shah <achals@gmail.com>

* remove unused feature

Signed-off-by: Achal Shah <achals@gmail.com>
  • Loading branch information
achals authored Aug 17, 2021
1 parent 26054aa commit da436b5
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,9 @@ def vary_providers_for_offline_stores(

@contextmanager
def construct_test_environment(
test_repo_config: TestRepoConfig, create_and_apply: bool = False
test_repo_config: TestRepoConfig,
create_and_apply: bool = False,
materialize: bool = False,
) -> Environment:
"""
This method should take in the parameters from the test repo config and created a feature repo, apply it,
Expand Down Expand Up @@ -256,6 +258,9 @@ def construct_test_environment(
)
fs.apply(fvs + entities)

if materialize:
fs.materialize(environment.start_date, environment.end_date)

yield environment
finally:
offline_creator.teardown()
Expand Down Expand Up @@ -286,13 +291,14 @@ def inner_test(config):

def parametrize_offline_retrieval_test(offline_retrieval_test):
"""
This decorator should be used for end-to-end tests. These tests are expected to be parameterized,
and receive an empty feature repo created for all supported configurations.
This decorator should be used by tests that rely on the offline store. These tests are expected to be parameterized,
and receive an Environment object that contains a reference to a Feature Store with pre-applied
entities and feature views.
The decorator also ensures that sample data needed for the test is available in the relevant offline store.
Decorated tests should create and apply the objects needed by the tests, and perform any operations needed
(such as materialization and looking up feature values).
Decorated tests should interact with the offline store, via the FeatureStore.get_historical_features method. They
may perform more operations as needed.
The decorator takes care of tearing down the feature store, as well as the sample data.
"""
Expand All @@ -308,3 +314,30 @@ def inner_test(config):
offline_retrieval_test(environment)

return inner_test


def parametrize_online_test(online_test):
"""
This decorator should be used by tests that rely on the offline store. These tests are expected to be parameterized,
and receive an Environment object that contains a reference to a Feature Store with pre-applied
entities and feature views.
The decorator also ensures that sample data needed for the test is available in the relevant offline store. This
data is also materialized into the online store.
The decorator takes care of tearing down the feature store, as well as the sample data.
"""

configs = vary_providers_for_offline_stores(FULL_REPO_CONFIGS)
configs = vary_full_feature_names(configs)
configs = vary_infer_event_timestamp_col(configs)

@pytest.mark.integration
@pytest.mark.parametrize("config", configs, ids=lambda v: str(v))
def inner_test(config):
with construct_test_environment(
config, create_and_apply=True, materialize=True
) as environment:
online_test(environment)

return inner_test
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def create_data_sources(
event_timestamp_column=event_timestamp_column,
created_timestamp_column=created_timestamp_column,
date_partition_column="",
field_mapping=field_mapping or {"ts_1": "ts", "id": "driver_id"},
field_mapping=field_mapping or {"ts_1": "ts"},
)

def get_prefixed_table_name(self, name: str, suffix: str) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def create_customer_daily_profile_feature_view(source):
Feature(name="current_balance", dtype=ValueType.FLOAT),
Feature(name="avg_passenger_count", dtype=ValueType.FLOAT),
Feature(name="lifetime_trip_count", dtype=ValueType.INT32),
Feature(name="avg_daily_trips", dtype=ValueType.INT32),
],
batch_source=source,
ttl=timedelta(days=2),
Expand Down
115 changes: 115 additions & 0 deletions sdk/python/tests/integration/online_store/test_universal_online.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import random
import unittest

import pandas as pd

from tests.integration.feature_repos.test_repo_configuration import (
Environment,
parametrize_online_test,
)


@parametrize_online_test
def test_online_retrieval(environment: Environment):
fs = environment.feature_store
full_feature_names = environment.test_repo_config.full_feature_names

sample_drivers = random.sample(environment.driver_entities, 10)
drivers_df = environment.driver_df[
environment.driver_df["driver_id"].isin(sample_drivers)
]

sample_customers = random.sample(environment.customer_entities, 10)
customers_df = environment.customer_df[
environment.customer_df["customer_id"].isin(sample_customers)
]

entity_rows = [
{"driver": d, "customer_id": c}
for (d, c) in zip(sample_drivers, sample_customers)
]

feature_refs = [
"driver_stats:conv_rate",
"driver_stats:avg_daily_trips",
"customer_profile:current_balance",
"customer_profile:avg_passenger_count",
"customer_profile:lifetime_trip_count",
]
unprefixed_feature_refs = [f.rsplit(":", 1)[-1] for f in feature_refs]

online_features = fs.get_online_features(
features=feature_refs,
entity_rows=entity_rows,
full_feature_names=full_feature_names,
)
assert online_features is not None

keys = online_features.to_dict().keys()
assert (
len(keys) == len(feature_refs) + 2
) # Add two for the driver id and the customer id entity keys.
for feature in feature_refs:
if full_feature_names:
assert feature.replace(":", "__") in keys
else:
assert feature.rsplit(":", 1)[-1] in keys
assert "driver_stats" not in keys and "customer_profile" not in keys

online_features_dict = online_features.to_dict()
tc = unittest.TestCase()
for i, entity_row in enumerate(entity_rows):
df_features = get_latest_feature_values_from_dataframes(
drivers_df, customers_df, entity_row
)

assert df_features["customer_id"] == online_features_dict["customer_id"][i]
assert df_features["driver_id"] == online_features_dict["driver_id"][i]
for unprefixed_feature_ref in unprefixed_feature_refs:
tc.assertEqual(
df_features[unprefixed_feature_ref],
online_features_dict[
response_feature_name(unprefixed_feature_ref, full_feature_names)
][i],
)

# Check what happens for missing values
missing_responses_dict = fs.get_online_features(
features=feature_refs,
entity_rows=[{"driver": 0, "customer_id": 0}],
full_feature_names=full_feature_names,
).to_dict()
assert missing_responses_dict is not None
for unprefixed_feature_ref in unprefixed_feature_refs:
tc.assertIsNone(
missing_responses_dict[
response_feature_name(unprefixed_feature_ref, full_feature_names)
][0]
)


def response_feature_name(feature: str, full_feature_names: bool) -> str:
if (
feature in {"current_balance", "avg_passenger_count", "lifetime_trip_count"}
and full_feature_names
):
return f"customer_profile__{feature}"

if feature in {"conv_rate", "avg_daily_trips"} and full_feature_names:
return f"driver_stats__{feature}"

return feature


def get_latest_feature_values_from_dataframes(driver_df, customer_df, entity_row):
driver_rows = driver_df[driver_df["driver_id"] == entity_row["driver"]]
latest_driver_row: pd.DataFrame = driver_rows.loc[
driver_rows["event_timestamp"].idxmax()
].to_dict()
customer_rows = customer_df[customer_df["customer_id"] == entity_row["customer_id"]]
latest_customer_row = customer_rows.loc[
customer_rows["event_timestamp"].idxmax()
].to_dict()

latest_customer_row.update(latest_driver_row)
return latest_customer_row

0 comments on commit da436b5

Please sign in to comment.