Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Reject undefined features when using get_historical_features or get_online_features #2665

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,7 +875,7 @@ def get_historical_features(
DeprecationWarning,
)

# TODO(achal): _group_feature_refs returns the on demand feature views, but it's no passed into the provider.
# TODO(achal): _group_feature_refs returns the on demand feature views, but it's not passed into the provider.
# This is a weird interface quirk - we should revisit the `get_historical_features` to
# pass in the on demand feature views as well.
fvs, odfvs, request_fvs, request_fv_refs = _group_feature_refs(
Expand Down Expand Up @@ -2125,8 +2125,12 @@ def _group_feature_refs(
for ref in features:
view_name, feat_name = ref.split(":")
if view_name in view_index:
view_index[view_name].projection.get_feature(feat_name) # For validation
views_features[view_name].add(feat_name)
elif view_name in on_demand_view_index:
on_demand_view_index[view_name].projection.get_feature(
feat_name
) # For validation
on_demand_view_features[view_name].add(feat_name)
# Let's also add in any FV Feature dependencies here.
for input_fv_projection in on_demand_view_index[
Expand All @@ -2135,6 +2139,9 @@ def _group_feature_refs(
for input_feat in input_fv_projection.features:
views_features[input_fv_projection.name].add(input_feat.name)
elif view_name in request_view_index:
request_view_index[view_name].projection.get_feature(
feat_name
) # For validation
request_views_features[view_name].add(feat_name)
request_view_refs.add(ref)
else:
Expand Down
8 changes: 8 additions & 0 deletions sdk/python/feast/feature_view_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,11 @@ def from_definition(base_feature_view: "BaseFeatureView"):
name_alias=None,
features=base_feature_view.features,
)

def get_feature(self, feature_name: str) -> Field:
    """Return the field of this projection named ``feature_name``.

    Args:
        feature_name: Name of the feature to look up in ``self.features``.

    Returns:
        The first entry of ``self.features`` whose ``name`` equals
        ``feature_name``.

    Raises:
        KeyError: If no entry of ``self.features`` has that name.
    """
    # A plain loop (instead of ``next()`` wrapped in ``try``) means the
    # KeyError is raised outside any ``except`` block, so its traceback is not
    # polluted with an internal ``StopIteration`` as implicit context.
    for field in self.features:
        if field.name == feature_name:
            return field
    raise KeyError(
        f"Feature {feature_name} not found in projection {self.name_to_use()}"
    )
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from feast.infra.offline_stores.offline_utils import (
DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL,
)
from feast.types import Int32
from feast.types import Float32, Int32
from feast.value_type import ValueType
from tests.integration.feature_repos.repo_configuration import (
construct_universal_feature_views,
Expand Down Expand Up @@ -410,6 +410,46 @@ def test_historical_features(environment, universal_data_sources, full_feature_n
)


@pytest.mark.integration
@pytest.mark.universal
@pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: str(v))
def test_historical_features_with_shared_batch_source(
    environment, universal_data_sources, full_feature_names
):
    """Historical retrieval must reject a feature its view does not declare.

    Two feature views are built on the same driver source:
    ``driver_stats_v1`` declares only ``avg_daily_trips`` while
    ``driver_stats_v2`` also declares ``conv_rate``. Requesting
    ``driver_stats_v1:conv_rate`` is expected to raise ``KeyError``.
    """
    # Addresses https://github.com/feast-dev/feast/issues/2576

    store = environment.feature_store

    _, datasets, data_sources = universal_data_sources
    driver_stats_v1 = FeatureView(
        name="driver_stats_v1",
        entities=["driver"],
        schema=[Field(name="avg_daily_trips", dtype=Int32)],
        source=data_sources.driver,
    )
    driver_stats_v2 = FeatureView(
        name="driver_stats_v2",
        entities=["driver"],
        schema=[
            Field(name="avg_daily_trips", dtype=Int32),
            Field(name="conv_rate", dtype=Float32),
        ],
        source=data_sources.driver,
    )

    store.apply([driver(), driver_stats_v1, driver_stats_v2])

    # Pin the error message so an unrelated KeyError (e.g. a missing dataframe
    # column) cannot make this test pass by accident.
    with pytest.raises(KeyError, match="Feature conv_rate not found in projection"):
        store.get_historical_features(
            entity_df=datasets.entity_df,
            features=[
                # `driver_stats_v1` does not have `conv_rate`
                "driver_stats_v1:conv_rate",
            ],
            full_feature_names=full_feature_names,
        ).to_df()


@pytest.mark.integration
@pytest.mark.universal_offline_stores
def test_historical_features_with_missing_request_data(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
RequestDataNotFoundInEntityRowsException,
)
from feast.online_response import TIMESTAMP_POSTFIX
from feast.types import String
from feast.types import Float32, Int32, String
from feast.wait import wait_retry_backoff
from tests.integration.feature_repos.repo_configuration import (
Environment,
Expand Down Expand Up @@ -324,6 +324,60 @@ def get_online_features_dict(
return dict1


@pytest.mark.integration
@pytest.mark.universal
def test_online_retrieval_with_shared_batch_source(environment, universal_data_sources):
    """Online retrieval must reject a feature its view does not declare.

    Two feature views are built on the same driver source:
    ``driver_stats_v1`` declares only ``avg_daily_trips`` while
    ``driver_stats_v2`` also declares ``conv_rate``. After writing rows for
    both views to the online store, requesting ``driver_stats_v1:conv_rate``
    is expected to raise ``KeyError``.
    """
    # Addresses https://github.com/feast-dev/feast/issues/2576

    fs = environment.feature_store

    _, _, data_sources = universal_data_sources
    driver_stats_v1 = FeatureView(
        name="driver_stats_v1",
        entities=["driver"],
        schema=[Field(name="avg_daily_trips", dtype=Int32)],
        source=data_sources.driver,
    )
    driver_stats_v2 = FeatureView(
        name="driver_stats_v2",
        entities=["driver"],
        schema=[
            Field(name="avg_daily_trips", dtype=Int32),
            Field(name="conv_rate", dtype=Float32),
        ],
        source=data_sources.driver,
    )

    fs.apply([driver(), driver_stats_v1, driver_stats_v2])

    data = pd.DataFrame(
        {
            "driver_id": [1, 2],
            "avg_daily_trips": [4, 5],
            "conv_rate": [0.5, 0.3],
            "event_timestamp": [
                pd.to_datetime(1646263500, utc=True, unit="s"),
                pd.to_datetime(1646263600, utc=True, unit="s"),
            ],
            # NOTE(review): these are timezone-naive while `event_timestamp`
            # is UTC-aware -- confirm the mismatch is intentional.
            "created": [
                pd.to_datetime(1646263500, unit="s"),
                pd.to_datetime(1646263600, unit="s"),
            ],
        }
    )
    # v1 has no `conv_rate` in its schema, so that column is dropped before writing.
    fs.write_to_online_store("driver_stats_v1", data.drop("conv_rate", axis=1))
    fs.write_to_online_store("driver_stats_v2", data)

    # Pin the error message so an unrelated KeyError cannot make this test
    # pass by accident.
    with pytest.raises(KeyError, match="Feature conv_rate not found in projection"):
        fs.get_online_features(
            features=[
                # `driver_stats_v1` does not have `conv_rate`
                "driver_stats_v1:conv_rate",
            ],
            entity_rows=[{"driver_id": 1}, {"driver_id": 2}],
        )


@pytest.mark.integration
@pytest.mark.universal_online_stores
@pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: str(v))
Expand Down