Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

get_online_features on demand transform bug fixes + local integration test mode #2004

Merged
merged 7 commits into from
Nov 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ test-python:
test-python-integration:
FEAST_USAGE=False IS_TEST=True python -m pytest -n 8 --integration sdk/python/tests

test-python-universal-local:
FEAST_USAGE=False IS_TEST=True FEAST_IS_LOCAL_TEST=True python -m pytest -n 8 --integration --universal sdk/python/tests

test-python-universal:
FEAST_USAGE=False IS_TEST=True python -m pytest -n 8 --integration --universal sdk/python/tests

Expand Down
131 changes: 93 additions & 38 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,9 +880,9 @@ def get_online_features(
"""
_feature_refs = self._get_features(features, feature_refs)
(
all_feature_views,
all_request_feature_views,
all_on_demand_feature_views,
requested_feature_views,
requested_request_feature_views,
requested_on_demand_feature_views,
) = self._get_feature_views_to_use(
features=features, allow_cache=True, hide_dummy_entity=False
)
Expand All @@ -895,9 +895,9 @@ def get_online_features(
_,
) = _group_feature_refs(
_feature_refs,
all_feature_views,
all_request_feature_views,
all_on_demand_feature_views,
requested_feature_views,
requested_request_feature_views,
requested_on_demand_feature_views,
)
if len(grouped_odfv_refs) > 0:
log_event(UsageEvent.GET_ONLINE_FEATURES_WITH_ODFV)
Expand All @@ -913,10 +913,10 @@ def get_online_features(

provider = self._get_provider()
entities = self._list_entities(allow_cache=True, hide_dummy_entity=False)
entity_name_to_join_key_map = {}
entity_name_to_join_key_map: Dict[str, str] = {}
for entity in entities:
entity_name_to_join_key_map[entity.name] = entity.join_key
for feature_view in all_feature_views:
for feature_view in requested_feature_views:
for entity_name in feature_view.entities:
entity = self._registry.get_entity(
entity_name, self.project, allow_cache=True
Expand Down Expand Up @@ -976,17 +976,6 @@ def get_online_features(
# Also create entity values to append to the result
result_rows.append(_entity_row_to_field_values(entity_row_proto))

# Add more feature values to the existing result rows for the request data features
for feature_name, feature_values in request_data_features.items():
for row_idx, feature_value in enumerate(feature_values):
result_row = result_rows[row_idx]
result_row.fields[feature_name].CopyFrom(
python_value_to_proto_value(feature_value)
)
result_row.statuses[
feature_name
] = GetOnlineFeaturesResponse.FieldStatus.PRESENT

for table, requested_features in grouped_refs:
table_join_keys = [
entity_name_to_join_key_map[entity_name]
Expand All @@ -1002,17 +991,85 @@ def get_online_features(
union_of_entity_keys,
)

requested_result_row_names = self._get_requested_result_fields(
result_rows, needed_request_fv_features
)
self._populate_odfv_dependencies(
entity_name_to_join_key_map,
full_feature_names,
grouped_odfv_refs,
provider,
request_data_features,
result_rows,
union_of_entity_keys,
)

initial_response = OnlineResponse(
GetOnlineFeaturesResponse(field_values=result_rows)
)
return self._augment_response_with_on_demand_transforms(
_feature_refs,
all_on_demand_feature_views,
requested_result_row_names,
requested_on_demand_feature_views,
full_feature_names,
initial_response,
result_rows,
)

def _get_requested_result_fields(
    self,
    result_rows: List[GetOnlineFeaturesResponse.FieldValues],
    needed_request_fv_features: Set[str],
) -> Set[str]:
    """Collect the names of the fields that belong in the final response.

    The returned set is used later to drop on demand feature view
    dependencies that were fetched but not requested, so it has to be
    computed before those dependencies are written into ``result_rows``.

    Args:
        result_rows: Result rows as populated so far; every key of each
            row's ``fields`` mapping is treated as requested.
        needed_request_fv_features: Names of request feature view features.
            They are request data features that should also be part of the
            final output. This set is not modified.

    Returns:
        A new set holding the union of all field names found in
        ``result_rows`` and ``needed_request_fv_features``.
    """
    # Start from a copy so the caller's set is left untouched.
    requested_result_row_names: Set[str] = set(needed_request_fv_features)
    for result_row in result_rows:
        requested_result_row_names.update(result_row.fields.keys())
    return requested_result_row_names

def _populate_odfv_dependencies(
    self,
    entity_name_to_join_key_map: Dict[str, str],
    full_feature_names: bool,
    grouped_odfv_refs: List[Tuple[OnDemandFeatureView, List[str]]],
    provider: Provider,
    request_data_features: Dict[str, List[Any]],
    result_rows: List[GetOnlineFeaturesResponse.FieldValues],
    union_of_entity_keys: List[EntityKeyProto],
) -> None:
    """Add the inputs of on demand feature views to ``result_rows`` in place.

    Two kinds of inputs are written: the request data values supplied by
    the caller, and every feature of each feature view that a requested on
    demand feature view lists in ``input_feature_views``.

    Args:
        entity_name_to_join_key_map: Maps an entity name to its join key;
            used to translate each feature view's entities into join keys.
        full_feature_names: Forwarded unchanged to
            ``_populate_result_rows_from_feature_view``.
        grouped_odfv_refs: ``(on demand feature view, feature names)``
            pairs; only the view of each pair is used here.
        provider: Forwarded unchanged to
            ``_populate_result_rows_from_feature_view``.
        request_data_features: Maps a request data feature name to its
            values, one per result row and in row order.
        result_rows: Rows to mutate. Indexed by position, so a value list
            longer than ``result_rows`` raises ``IndexError``.
        union_of_entity_keys: Forwarded unchanged to
            ``_populate_result_rows_from_feature_view``.
    """
    # Add more feature values to the existing result rows for the request data features
    for feature_name, feature_values in request_data_features.items():
        for row_idx, feature_value in enumerate(feature_values):
            result_row = result_rows[row_idx]
            result_row.fields[feature_name].CopyFrom(
                python_value_to_proto_value(feature_value)
            )
            result_row.statuses[
                feature_name
            ] = GetOnlineFeaturesResponse.FieldStatus.PRESENT

    # Add data if odfv requests specific feature views as dependencies.
    # Iterating an empty list is already a no-op, so no length guard is needed.
    for odfv, _ in grouped_odfv_refs:
        for fv in odfv.input_feature_views.values():
            table_join_keys = [
                entity_name_to_join_key_map[entity_name]
                for entity_name in fv.entities
            ]
            # All features of the dependency are fetched, whether or not
            # they were individually requested.
            self._populate_result_rows_from_feature_view(
                table_join_keys,
                full_feature_names,
                provider,
                [feature.name for feature in fv.features],
                result_rows,
                fv,
                union_of_entity_keys,
            )

def get_needed_request_data(
self,
grouped_odfv_refs: List[Tuple[OnDemandFeatureView, List[str]]],
Expand Down Expand Up @@ -1097,27 +1154,10 @@ def _populate_result_rows_from_feature_view(
feature_ref
] = GetOnlineFeaturesResponse.FieldStatus.PRESENT

def _get_needed_request_data_features(
    self,
    grouped_odfv_refs: List[Tuple[OnDemandFeatureView, List[str]]],
    grouped_request_fv_refs: List[Tuple[RequestFeatureView, List[str]]],
) -> Set[str]:
    """Return the names of all request data features the given views need.

    Args:
        grouped_odfv_refs: ``(on demand feature view, feature names)``
            pairs. Every key of each view's ``get_request_data_schema()``
            result is needed, regardless of the paired feature names.
        grouped_request_fv_refs: ``(request feature view, feature names)``
            pairs. The name of every entry in each view's ``features`` is
            needed, regardless of the paired feature names.

    Returns:
        The set of needed request data feature names.
    """
    needed_request_data_features: Set[str] = set()
    for odfv, _ in grouped_odfv_refs:
        needed_request_data_features.update(odfv.get_request_data_schema().keys())
    for request_fv, _ in grouped_request_fv_refs:
        needed_request_data_features.update(
            feature.name for feature in request_fv.features
        )
    return needed_request_data_features

# TODO(adchia): remove request data, which isn't part of the feature_refs
def _augment_response_with_on_demand_transforms(
self,
feature_refs: List[str],
requested_result_row_names: Set[str],
odfvs: List[OnDemandFeatureView],
full_feature_names: bool,
initial_response: OnlineResponse,
Expand All @@ -1137,6 +1177,7 @@ def _augment_response_with_on_demand_transforms(
odfv_feature_refs[view_name].append(feature_name)

# Apply on demand transformations
odfv_result_names = set()
for odfv_name, _feature_refs in odfv_feature_refs.items():
odfv = all_on_demand_feature_views[odfv_name]
transformed_features_df = odfv.get_transformed_features_df(
Expand All @@ -1155,13 +1196,27 @@ def _augment_response_with_on_demand_transforms(
if full_feature_names
else transformed_feature
)
odfv_result_names.add(transformed_feature_name)
proto_value = python_value_to_proto_value(
transformed_features_df[transformed_feature].values[row_idx]
)
result_row.fields[transformed_feature_name].CopyFrom(proto_value)
result_row.statuses[
transformed_feature_name
] = GetOnlineFeaturesResponse.FieldStatus.PRESENT

# Drop values that aren't needed
unneeded_features = [
val
for val in result_rows[0].fields
if val not in requested_result_row_names and val not in odfv_result_names
]
for row_idx in range(len(result_rows)):
result_row = result_rows[row_idx]
for unneeded_feature in unneeded_features:
result_row.fields.pop(unneeded_feature)
result_row.statuses.pop(unneeded_feature)

return OnlineResponse(GetOnlineFeaturesResponse(field_values=result_rows))

def _get_feature_views_to_use(
Expand Down
49 changes: 27 additions & 22 deletions sdk/python/tests/integration/feature_repos/repo_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,30 +49,35 @@
DEFAULT_FULL_REPO_CONFIGS: List[IntegrationTestRepoConfig] = [
# Local configurations
IntegrationTestRepoConfig(),
]
if os.getenv("FEAST_IS_LOCAL_TEST", "False") != "True":
IntegrationTestRepoConfig(online_store=REDIS_CONFIG),
# GCP configurations
IntegrationTestRepoConfig(
provider="gcp",
offline_store_creator=BigQueryDataSourceCreator,
online_store="datastore",
),
IntegrationTestRepoConfig(
provider="gcp",
offline_store_creator=BigQueryDataSourceCreator,
online_store=REDIS_CONFIG,
),
# AWS configurations
IntegrationTestRepoConfig(
provider="aws",
offline_store_creator=RedshiftDataSourceCreator,
online_store=DYNAMO_CONFIG,
),
IntegrationTestRepoConfig(
provider="aws",
offline_store_creator=RedshiftDataSourceCreator,
online_store=REDIS_CONFIG,
),
]
DEFAULT_FULL_REPO_CONFIGS.extend(
[
IntegrationTestRepoConfig(
provider="gcp",
offline_store_creator=BigQueryDataSourceCreator,
online_store="datastore",
),
IntegrationTestRepoConfig(
provider="gcp",
offline_store_creator=BigQueryDataSourceCreator,
online_store=REDIS_CONFIG,
),
# AWS configurations
IntegrationTestRepoConfig(
provider="aws",
offline_store_creator=RedshiftDataSourceCreator,
online_store=DYNAMO_CONFIG,
),
IntegrationTestRepoConfig(
provider="aws",
offline_store_creator=RedshiftDataSourceCreator,
online_store=REDIS_CONFIG,
),
]
)
full_repo_configs_module = os.environ.get(FULL_REPO_CONFIGS_MODULE_ENV_NAME)
if full_repo_configs_module is not None:
try:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import itertools
import os
import unittest
from datetime import timedelta

Expand Down Expand Up @@ -29,6 +30,8 @@
# TODO: make this work with all universal (all online store types)
@pytest.mark.integration
def test_write_to_online_store_event_check(local_redis_environment):
if os.getenv("FEAST_IS_LOCAL_TEST", "False") == "True":
return
fs = local_redis_environment.feature_store

# write same data points 3 with different timestamps
Expand Down Expand Up @@ -274,11 +277,20 @@ def test_online_retrieval(environment, universal_data_sources, full_feature_name
)
assert online_features is not None

# Test that the on demand feature views compute properly even if the dependent conv_rate
# feature isn't requested.
online_features_no_conv_rate = fs.get_online_features(
features=[ref for ref in feature_refs if ref != "driver_stats:conv_rate"],
entity_rows=entity_rows,
full_feature_names=full_feature_names,
)
assert online_features_no_conv_rate is not None

online_features_dict = online_features.to_dict()
keys = online_features_dict.keys()
assert (
len(keys) == len(feature_refs) + 3
) # Add three for the driver id and the customer id entity keys + val_to_add request data.
len(keys) == len(feature_refs) + 2
) # Add two for the driver id and the customer id entity keys
for feature in feature_refs:
# full_feature_names does not apply to request feature views
if full_feature_names and feature != "driver_age:driver_age":
Expand Down Expand Up @@ -526,8 +538,8 @@ def assert_feature_service_correctness(
for projection in feature_service.feature_view_projections
]
)
+ 3
) # Add two for the driver id and the customer id entity keys and val_to_add request data
+ 2
) # Add two for the driver id and the customer id entity keys

tc = unittest.TestCase()
for i, entity_row in enumerate(entity_rows):
Expand Down