Add support for multiple entities in Redshift #1850

Merged
merged 7 commits on Sep 15, 2021
4 changes: 2 additions & 2 deletions sdk/python/feast/infra/offline_stores/redshift.py
@@ -377,9 +377,9 @@ def _upload_entity_df_and_get_entity_schema(
 {{entity_df_event_timestamp_col}} AS entity_timestamp
 {% for featureview in featureviews %}
 {% if featureview.entities %}
-,CONCAT(
+,(
 {% for entity in featureview.entities %}
-CAST({{entity}} AS VARCHAR),
+CAST({{entity}} as VARCHAR) ||
 {% endfor %}
 CAST({{entity_df_event_timestamp_col}} AS VARCHAR)
 ) AS {{featureview.name}}__entity_row_unique_id
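Redshift's CONCAT() accepts exactly two arguments, so the old template could not build a row id once a feature view had more than one entity column; chaining the || operator handles any number of entities. Below is a minimal rendering sketch, not Feast's actual code path, showing what the updated Jinja fragment produces for a hypothetical two-entity feature view (the FakeFeatureView class, entity names, and timestamp column are illustrative):

from jinja2 import Template

# Mirrors the fragment above; whitespace control is omitted for brevity.
FRAGMENT = Template(
    ",("
    "{% for entity in featureview.entities %}"
    "CAST({{entity}} as VARCHAR) || "
    "{% endfor %}"
    "CAST({{entity_df_event_timestamp_col}} AS VARCHAR)"
    ") AS {{featureview.name}}__entity_row_unique_id"
)


class FakeFeatureView:
    # Illustrative stand-in for the feature view object passed into the template.
    name = "order"
    entities = ["driver_id", "customer_id"]


print(
    FRAGMENT.render(
        featureview=FakeFeatureView(),
        entity_df_event_timestamp_col="event_timestamp",
    )
)
# ,(CAST(driver_id as VARCHAR) || CAST(customer_id as VARCHAR) || CAST(event_timestamp AS VARCHAR)) AS order__entity_row_unique_id

The || emitted inside the loop is always followed by the timestamp cast, so the generated expression stays valid no matter how many entities the loop visits.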
13 changes: 4 additions & 9 deletions sdk/python/feast/infra/online_stores/helpers.py
@@ -1,13 +1,12 @@
 import importlib
 import struct
-from typing import Any
+from typing import Any, List

 import mmh3

 from feast import errors
 from feast.infra.key_encoding_utils import serialize_entity_key
 from feast.infra.online_stores.online_store import OnlineStore
-from feast.protos.feast.storage.Redis_pb2 import RedisKeyV2 as RedisKeyProto
 from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto

@@ -37,13 +36,9 @@ def get_online_store_from_config(online_store_config: Any,) -> OnlineStore:
     return online_store_class()


-def _redis_key(project: str, entity_key: EntityKeyProto):
-    redis_key = RedisKeyProto(
-        project=project,
-        entity_names=entity_key.join_keys,
-        entity_values=entity_key.entity_values,
-    )
-    return redis_key.SerializeToString()
+def _redis_key(project: str, entity_key: EntityKeyProto) -> bytes:
+    key: List[bytes] = [serialize_entity_key(entity_key), project.encode("utf-8")]
+    return b"".join(key)


 def _mmh3(key: str):
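With this change the Redis key is no longer a serialized RedisKeyProto; it is the byte string produced by serialize_entity_key followed by the UTF-8 encoded project name, presumably so that keys with several join keys are encoded consistently with the other online stores. A hedged usage sketch follows; _redis_key, EntityKeyProto, and Value are real symbols touched by or imported in this file, while the project name and entity values are made up:

from feast.infra.online_stores.helpers import _redis_key
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value

# A row identified by two entities: the customer and the driver.
entity_key = EntityKeyProto(
    join_keys=["customer_id", "driver_id"],
    entity_values=[Value(int64_val=2001), Value(int64_val=1001)],
)

# Resulting key layout: serialize_entity_key(entity_key) + b"my_project"
key = _redis_key("my_project", entity_key)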
@@ -28,6 +28,7 @@
create_customer_daily_profile_feature_view,
create_driver_hourly_stats_feature_view,
create_global_stats_feature_view,
create_order_feature_view,
)


@@ -94,17 +95,19 @@ def construct_universal_datasets(
orders_df = driver_test_data.create_orders_df(
customers=entities["customer"],
drivers=entities["driver"],
start_date=end_time - timedelta(days=3),
end_date=end_time + timedelta(days=3),
start_date=start_time,
end_date=end_time,
order_count=20,
)
global_df = driver_test_data.create_global_daily_stats_df(start_time, end_time)
entity_df = orders_df[["customer_id", "driver_id", "order_id", "event_timestamp"]]

return {
"customer": customer_df,
"driver": driver_df,
"orders": orders_df,
"global": global_df,
"entity": entity_df,
}


@@ -127,7 +130,7 @@ def construct_universal_data_sources(
datasets["orders"],
destination_name="orders",
event_timestamp_column="event_timestamp",
created_timestamp_column="created",
created_timestamp_column=None,
)
global_ds = data_source_creator.create_data_source(
datasets["global"],
@@ -161,6 +164,7 @@ def construct_universal_feature_views(
"input_request": create_conv_rate_request_data_source(),
}
),
"order": create_order_feature_view(data_sources["orders"]),
}


@@ -117,3 +117,15 @@ def create_global_stats_feature_view(source, infer_features: bool = False):
ttl=timedelta(days=2),
)
return global_stats_feature_view


def create_order_feature_view(source, infer_features: bool = False):
return FeatureView(
name="order",
entities=["driver", "customer_id"],
features=None
if infer_features
else [Feature(name="order_is_success", dtype=ValueType.INT32)],
batch_source=source,
ttl=timedelta(days=2),
)
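The new fixture above is what exercises multi-entity retrieval end to end: the order feature view is keyed on both the driver and customer entities, so every lookup needs both join keys in the entity dataframe. A hedged retrieval sketch (the repo path, ids, and timestamps are illustrative; the feature reference order:order_is_success mirrors the test changes below):

from datetime import datetime, timezone

import pandas as pd
from feast import FeatureStore

# Assumes a feature repo with the "order" feature view already applied.
store = FeatureStore(repo_path=".")

# Both join keys must be present for the two-entity feature view.
entity_df = pd.DataFrame(
    {
        "driver_id": [1001, 1002],
        "customer_id": [2001, 2002],
        "event_timestamp": [
            datetime(2021, 9, 1, tzinfo=timezone.utc),
            datetime(2021, 9, 2, tzinfo=timezone.utc),
        ],
    }
)

training_df = store.get_historical_features(
    entity_df=entity_df,
    features=["order:order_is_success"],
).to_df()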
@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd
@@ -37,14 +37,23 @@ def find_asof_record(
ts_key: str,
ts_start: datetime,
ts_end: datetime,
filter_key: str = "",
filter_value: Any = None,
filter_keys: Optional[List[str]] = None,
filter_values: Optional[List[Any]] = None,
) -> Dict[str, Any]:
filter_keys = filter_keys or []
filter_values = filter_values or []
assert len(filter_keys) == len(filter_values)
found_record = {}
for record in records:
if (
not filter_key or record[filter_key] == filter_value
) and ts_start <= record[ts_key] <= ts_end:
all(
[
record[filter_key] == filter_value
for filter_key, filter_value in zip(filter_keys, filter_values)
]
)
and ts_start <= record[ts_key] <= ts_end
):
if not found_record or found_record[ts_key] < record[ts_key]:
found_record = record
return found_record
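find_asof_record now takes parallel lists of filter keys and values instead of a single key/value pair, which lets the expected-result builder match an order row on driver and customer at the same time. A small usage sketch, assuming the helper above is in scope; the records and timestamps are made up:

from datetime import datetime, timezone

records = [
    {"driver_id": 1, "customer_id": 7, "ts": datetime(2021, 9, 1, tzinfo=timezone.utc)},
    {"driver_id": 1, "customer_id": 7, "ts": datetime(2021, 9, 3, tzinfo=timezone.utc)},
    {"driver_id": 2, "customer_id": 7, "ts": datetime(2021, 9, 4, tzinfo=timezone.utc)},
]

# Latest record for (driver_id=1, customer_id=7) inside the window:
match = find_asof_record(
    records,
    ts_key="ts",
    ts_start=datetime(2021, 8, 30, tzinfo=timezone.utc),
    ts_end=datetime(2021, 9, 5, tzinfo=timezone.utc),
    filter_keys=["driver_id", "customer_id"],
    filter_values=[1, 7],
)
# match is the 2021-09-03 record: both filters pass and it is the newest one in range.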
@@ -55,43 +64,57 @@ def get_expected_training_df(
customer_fv: FeatureView,
driver_df: pd.DataFrame,
driver_fv: FeatureView,
orders_df: pd.DataFrame,
order_fv: FeatureView,
global_df: pd.DataFrame,
global_fv: FeatureView,
orders_df: pd.DataFrame,
entity_df: pd.DataFrame,
event_timestamp: str,
full_feature_names: bool = False,
):
# Convert all pandas dataframes into records with UTC timestamps
order_records = convert_timestamp_records_to_utc(
orders_df.to_dict("records"), event_timestamp
customer_records = convert_timestamp_records_to_utc(
customer_df.to_dict("records"), customer_fv.batch_source.event_timestamp_column
)
driver_records = convert_timestamp_records_to_utc(
driver_df.to_dict("records"), driver_fv.batch_source.event_timestamp_column
)
customer_records = convert_timestamp_records_to_utc(
customer_df.to_dict("records"), customer_fv.batch_source.event_timestamp_column
order_records = convert_timestamp_records_to_utc(
orders_df.to_dict("records"), event_timestamp
)
global_records = convert_timestamp_records_to_utc(
global_df.to_dict("records"), global_fv.batch_source.event_timestamp_column
)
entity_rows = convert_timestamp_records_to_utc(
entity_df.to_dict("records"), event_timestamp
)

# Manually do point-in-time join of orders to drivers and customers records
for order_record in order_records:
# Manually do point-in-time join of driver, customer, and order records against
# the entity df
for entity_row in entity_rows:
customer_record = find_asof_record(
customer_records,
ts_key=customer_fv.batch_source.event_timestamp_column,
ts_start=entity_row[event_timestamp] - customer_fv.ttl,
ts_end=entity_row[event_timestamp],
filter_keys=["customer_id"],
filter_values=[entity_row["customer_id"]],
)
driver_record = find_asof_record(
driver_records,
ts_key=driver_fv.batch_source.event_timestamp_column,
ts_start=order_record[event_timestamp] - driver_fv.ttl,
ts_end=order_record[event_timestamp],
filter_key="driver_id",
filter_value=order_record["driver_id"],
ts_start=entity_row[event_timestamp] - driver_fv.ttl,
ts_end=entity_row[event_timestamp],
filter_keys=["driver_id"],
filter_values=[entity_row["driver_id"]],
)
customer_record = find_asof_record(
customer_records,
order_record = find_asof_record(
order_records,
ts_key=customer_fv.batch_source.event_timestamp_column,
ts_start=order_record[event_timestamp] - customer_fv.ttl,
ts_end=order_record[event_timestamp],
filter_key="customer_id",
filter_value=order_record["customer_id"],
ts_start=entity_row[event_timestamp] - order_fv.ttl,
ts_end=entity_row[event_timestamp],
filter_keys=["customer_id", "driver_id"],
filter_values=[entity_row["customer_id"], entity_row["driver_id"]],
)
global_record = find_asof_record(
global_records,
@@ -100,15 +123,7 @@
ts_end=order_record[event_timestamp],
)

order_record.update(
{
(f"driver_stats__{k}" if full_feature_names else k): driver_record.get(
k, None
)
for k in ("conv_rate", "avg_daily_trips")
}
)
order_record.update(
entity_row.update(
{
(
f"customer_profile__{k}" if full_feature_names else k
@@ -120,7 +135,21 @@
)
}
)
order_record.update(
entity_row.update(
{
(f"driver_stats__{k}" if full_feature_names else k): driver_record.get(
k, None
)
for k in ("conv_rate", "avg_daily_trips")
}
)
entity_row.update(
{
(f"order__{k}" if full_feature_names else k): order_record.get(k, None)
for k in ("order_is_success",)
}
)
entity_row.update(
{
(f"global_stats__{k}" if full_feature_names else k): global_record.get(
k, None
@@ -130,7 +159,7 @@
)

# Convert records back to pandas dataframe
expected_df = pd.DataFrame(order_records)
expected_df = pd.DataFrame(entity_rows)

# Move "event_timestamp" column to front
current_cols = expected_df.columns.tolist()
@@ -140,7 +169,7 @@
# Cast some columns to expected types, since we lose information when converting pandas DFs into Python objects.
if full_feature_names:
expected_column_types = {
"order_is_success": "int32",
"order__order_is_success": "int32",
"driver_stats__conv_rate": "float32",
"customer_profile__current_balance": "float32",
"customer_profile__avg_passenger_count": "float32",
@@ -175,20 +204,23 @@ def test_historical_features(environment, universal_data_sources, full_feature_names):
(entities, datasets, data_sources) = universal_data_sources
feature_views = construct_universal_feature_views(data_sources)

customer_df, driver_df, orders_df, global_df = (
customer_df, driver_df, orders_df, global_df, entity_df = (
datasets["customer"],
datasets["driver"],
datasets["orders"],
datasets["global"],
datasets["entity"],
)
orders_df_with_request_data = orders_df.copy(deep=True)
orders_df_with_request_data["val_to_add"] = [
i for i in range(len(orders_df_with_request_data))
entity_df_with_request_data = entity_df.copy(deep=True)
entity_df_with_request_data["val_to_add"] = [
i for i in range(len(entity_df_with_request_data))
]
customer_fv, driver_fv, driver_odfv, global_fv = (

customer_fv, driver_fv, driver_odfv, order_fv, global_fv = (
feature_views["customer"],
feature_views["driver"],
feature_views["driver_odfv"],
feature_views["order"],
feature_views["global"],
)

@@ -203,6 +235,7 @@ def test_historical_features(environment, universal_data_sources, full_feature_names):
customer_fv,
driver_fv,
driver_odfv,
order_fv,
global_fv,
driver(),
customer(),
@@ -214,7 +247,7 @@ def test_historical_features(environment, universal_data_sources, full_feature_names):
entity_df_query = None
orders_table = table_name_from_data_source(data_sources["orders"])
if orders_table:
entity_df_query = f"SELECT * FROM {orders_table}"
entity_df_query = f"SELECT customer_id, driver_id, order_id, event_timestamp FROM {orders_table}"

event_timestamp = (
DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
@@ -226,9 +259,11 @@ def test_historical_features(environment, universal_data_sources, full_feature_names):
customer_fv,
driver_df,
driver_fv,
orders_df,
order_fv,
global_df,
global_fv,
orders_df_with_request_data,
entity_df_with_request_data,
event_timestamp,
full_feature_names,
)
@@ -242,6 +277,7 @@ def test_historical_features(environment, universal_data_sources, full_feature_names):
"customer_profile:current_balance",
"customer_profile:avg_passenger_count",
"customer_profile:lifetime_trip_count",
"order:order_is_success",
"global_stats:num_rides",
"global_stats:avg_ride_length",
],
@@ -297,7 +333,7 @@ def test_historical_features(environment, universal_data_sources, full_feature_names):
assert_frame_equal(expected_df_query, df_from_sql_entities)

job_from_df = store.get_historical_features(
entity_df=orders_df_with_request_data,
entity_df=entity_df_with_request_data,
features=[
"driver_stats:conv_rate",
"driver_stats:avg_daily_trips",
@@ -306,6 +342,7 @@ def test_historical_features(environment, universal_data_sources, full_feature_names):
"customer_profile:lifetime_trip_count",
"conv_rate_plus_100:conv_rate_plus_100",
"conv_rate_plus_100:conv_rate_plus_val_to_add",
"order:order_is_success",
"global_stats:num_rides",
"global_stats:avg_ride_length",
],
@@ -341,7 +378,7 @@ def test_historical_features(environment, universal_data_sources, full_feature_names):
store,
feature_service,
full_feature_names,
orders_df_with_request_data,
entity_df_with_request_data,
expected_df,
event_timestamp,
)
@@ -361,7 +398,7 @@ def test_historical_features(environment, universal_data_sources, full_feature_names):
# If request data is missing that's needed for on demand transform, throw an error
with pytest.raises(RequestDataNotFoundInEntityDfException):
store.get_historical_features(
entity_df=orders_df,
entity_df=entity_df,
features=[
"driver_stats:conv_rate",
"driver_stats:avg_daily_trips",
@@ -388,11 +425,11 @@ def response_feature_name(feature: str, full_feature_names: bool) -> str:


def assert_feature_service_correctness(
store, feature_service, full_feature_names, orders_df, expected_df, event_timestamp
store, feature_service, full_feature_names, entity_df, expected_df, event_timestamp
):

job_from_df = store.get_historical_features(
entity_df=orders_df,
entity_df=entity_df,
features=feature_service,
full_feature_names=full_feature_names,
)