Skip to content

Commit

Permalink
Support join keys in historical feature retrieval (#1440)
Browse files (browse the repository at this point in the history)
* Support join keys in historical feature retrieval

Signed-off-by: Willem Pienaar <git@willem.co>

* Rebase join key support

Signed-off-by: Willem Pienaar <git@willem.co>

* Remove unused methods

Signed-off-by: Willem Pienaar <git@willem.co>

Co-authored-by: Willem Pienaar <git@willem.co>
  • Loading branch information
jklegar and woop authored Apr 11, 2021
1 parent 22bf06c commit b60f6e4
Show file tree
Hide file tree
Showing 10 changed files with 49 additions and 13 deletions.
7 changes: 6 additions & 1 deletion sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,12 @@ def get_historical_features(
feature_views = _get_requested_feature_views(feature_refs, all_feature_views)
provider = self._get_provider()
job = provider.get_historical_features(
self.config, feature_views, feature_refs, entity_df
self.config,
feature_views,
feature_refs,
entity_df,
self._registry,
self.project,
)
return job

Expand Down
4 changes: 4 additions & 0 deletions sdk/python/feast/infra/gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ def get_historical_features(
feature_views: List[FeatureView],
feature_refs: List[str],
entity_df: Union[pandas.DataFrame, str],
registry: Registry,
project: str,
) -> RetrievalJob:
offline_store = get_offline_store_from_sources(
[feature_view.input for feature_view in feature_views]
Expand All @@ -194,6 +196,8 @@ def get_historical_features(
feature_views=feature_views,
feature_refs=feature_refs,
entity_df=entity_df,
registry=registry,
project=project,
)
return job

Expand Down
4 changes: 4 additions & 0 deletions sdk/python/feast/infra/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,8 @@ def get_historical_features(
feature_views: List[FeatureView],
feature_refs: List[str],
entity_df: Union[pd.DataFrame, str],
registry: Registry,
project: str,
) -> RetrievalJob:
offline_store = get_offline_store_from_sources(
[feature_view.input for feature_view in feature_views]
Expand All @@ -205,6 +207,8 @@ def get_historical_features(
feature_views=feature_views,
feature_refs=feature_refs,
entity_df=entity_df,
registry=registry,
project=project,
)


Expand Down
19 changes: 15 additions & 4 deletions sdk/python/feast/infra/offline_stores/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
RetrievalJob,
_get_requested_feature_views_to_features_dict,
)
from feast.registry import Registry
from feast.repo_config import RepoConfig


Expand Down Expand Up @@ -70,6 +71,8 @@ def get_historical_features(
feature_views: List[FeatureView],
feature_refs: List[str],
entity_df: Union[pandas.DataFrame, str],
registry: Registry,
project: str,
) -> RetrievalJob:
# TODO: Add entity_df validation in order to fail before interacting with BigQuery

Expand All @@ -85,7 +88,9 @@ def get_historical_features(
)

# Build a query context containing all information required to template the BigQuery SQL query
query_context = get_feature_view_query_context(feature_refs, feature_views)
query_context = get_feature_view_query_context(
feature_refs, feature_views, registry, project
)

# TODO: Infer min_timestamp and max_timestamp from entity_df
# Generate the BigQuery SQL query from the query context
Expand Down Expand Up @@ -155,7 +160,10 @@ def _upload_entity_df_into_bigquery(project, entity_df) -> str:


def get_feature_view_query_context(
feature_refs: List[str], feature_views: List[FeatureView]
feature_refs: List[str],
feature_views: List[FeatureView],
registry: Registry,
project: str,
) -> List[FeatureViewQueryContext]:
"""Build a query context containing all information required to template a BigQuery point-in-time SQL query"""

Expand All @@ -165,7 +173,10 @@ def get_feature_view_query_context(

query_context = []
for feature_view, features in feature_views_to_feature_map.items():
entity_names = [entity for entity in feature_view.entities]
join_keys = []
for entity_name in feature_view.entities:
entity = registry.get_entity(entity_name, project)
join_keys.append(entity.join_key)

if isinstance(feature_view.ttl, timedelta):
ttl_seconds = int(feature_view.ttl.total_seconds())
Expand All @@ -177,7 +188,7 @@ def get_feature_view_query_context(
context = FeatureViewQueryContext(
name=feature_view.name,
ttl=ttl_seconds,
entities=entity_names,
entities=join_keys,
features=features,
table_ref=feature_view.input.table_ref,
event_timestamp_column=feature_view.input.event_timestamp_column,
Expand Down
9 changes: 8 additions & 1 deletion sdk/python/feast/infra/offline_stores/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
ENTITY_DF_EVENT_TIMESTAMP_COL,
_get_requested_feature_views_to_features_dict,
)
from feast.registry import Registry
from feast.repo_config import RepoConfig


Expand All @@ -35,6 +36,8 @@ def get_historical_features(
feature_views: List[FeatureView],
feature_refs: List[str],
entity_df: Union[pd.DataFrame, str],
registry: Registry,
project: str,
) -> FileRetrievalJob:
if not isinstance(entity_df, pd.DataFrame):
raise ValueError(
Expand Down Expand Up @@ -80,7 +83,11 @@ def evaluate_historical_retrieval():
)

# Build a list of entity columns to join on (from the right table)
right_entity_columns = [entity for entity in feature_view.entities]
join_keys = []
for entity_name in feature_view.entities:
entity = registry.get_entity(entity_name, project)
join_keys.append(entity.join_key)
right_entity_columns = join_keys
right_entity_key_columns = [
event_timestamp_column
] + right_entity_columns
Expand Down
3 changes: 3 additions & 0 deletions sdk/python/feast/infra/offline_stores/offline_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from feast.data_source import DataSource
from feast.feature_view import FeatureView
from feast.registry import Registry
from feast.repo_config import RepoConfig


Expand Down Expand Up @@ -63,5 +64,7 @@ def get_historical_features(
feature_views: List[FeatureView],
feature_refs: List[str],
entity_df: Union[pd.DataFrame, str],
registry: Registry,
project: str,
) -> RetrievalJob:
pass
2 changes: 2 additions & 0 deletions sdk/python/feast/infra/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ def get_historical_features(
feature_views: List[FeatureView],
feature_refs: List[str],
entity_df: Union[pandas.DataFrame, str],
registry: Registry,
project: str,
) -> RetrievalJob:
pass

Expand Down
4 changes: 2 additions & 2 deletions sdk/python/feast/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def list_entities(self, project: str, allow_cache: bool = False) -> List[Entity]
entities.append(Entity.from_proto(entity_proto))
return entities

def get_entity(self, name: str, project: str) -> Entity:
def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Entity:
"""
Retrieves an entity.
Expand All @@ -117,7 +117,7 @@ def get_entity(self, name: str, project: str) -> Entity:
Returns either the specified entity, or raises an exception if
none is found
"""
registry_proto = self._get_registry_proto()
registry_proto = self._get_registry_proto(allow_cache=allow_cache)
for entity_proto in registry_proto.entities:
if entity_proto.spec.name == name and entity_proto.spec.project == project:
return Entity.from_proto(entity_proto)
Expand Down
10 changes: 5 additions & 5 deletions sdk/python/tests/test_historical_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def stage_driver_hourly_stats_bigquery_source(df, table_id):
def create_driver_hourly_stats_feature_view(source):
driver_stats_feature_view = FeatureView(
name="driver_stats",
entities=["driver_id"],
entities=["driver"],
features=[
Feature(name="conv_rate", dtype=ValueType.FLOAT),
Feature(name="acc_rate", dtype=ValueType.FLOAT),
Expand Down Expand Up @@ -226,8 +226,8 @@ def test_historical_features_from_parquet_sources():
temp_dir, customer_df
)
customer_fv = create_customer_daily_profile_feature_view(customer_source)
driver = Entity(name="driver", value_type=ValueType.INT64)
customer = Entity(name="customer", value_type=ValueType.INT64)
driver = Entity(name="driver", join_key="driver_id", value_type=ValueType.INT64)
customer = Entity(name="customer_id", value_type=ValueType.INT64)

store = FeatureStore(
config=RepoConfig(
Expand Down Expand Up @@ -331,8 +331,8 @@ def test_historical_features_from_bigquery_sources(provider_type):
)
customer_fv = create_customer_daily_profile_feature_view(customer_source)

driver = Entity(name="driver", value_type=ValueType.INT64)
customer = Entity(name="customer", value_type=ValueType.INT64)
driver = Entity(name="driver", join_key="driver_id", value_type=ValueType.INT64)
customer = Entity(name="customer_id", value_type=ValueType.INT64)

if provider_type == "local":
store = FeatureStore(
Expand Down
Empty file.

0 comments on commit b60f6e4

Please sign in to comment.