Skip to content

Commit

Permalink
isolate ODFV list type feature test to smaller code changes
Browse files Browse the repository at this point in the history
Signed-off-by: Jeff <jeffxl@apple.com>
  • Loading branch information
Agent007 committed Nov 12, 2021
1 parent 464e717 commit c655d1b
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 18 deletions.
3 changes: 0 additions & 3 deletions sdk/python/feast/driver_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,6 @@ def create_driver_hourly_stats_df(drivers, start_date, end_date) -> pd.DataFrame
df_all_drivers["avg_daily_trips"] = np.random.randint(0, 1000, size=rows).astype(
np.int32
)
dummy_vector = [1.0, 0.0]
df_all_drivers["embedding_double"] = [dummy_vector] * rows
df_all_drivers["embedding_float"] = df_all_drivers["embedding_double"]
df_all_drivers["created"] = pd.to_datetime(pd.Timestamp.now(tz=None).round("ms"))

# Create duplicate rows that should be filtered by created timestamp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,7 @@ def customer():

def location():
    """Return the ``location_id`` entity, declared with an INT64 value type."""
    location_entity = Entity(name="location_id", value_type=ValueType.INT64)
    return location_entity


def item():
    """Return the ``item_id`` entity, declared with an INT64 value type."""
    item_entity = Entity(name="item_id", value_type=ValueType.INT64)
    return item_entity
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,22 @@ def create_similarity_request_data_source():
)


def create_item_embeddings_feature_view(source, infer_features: bool = False):
    """Build the ``item_embeddings`` feature view on top of ``source``.

    Args:
        source: Data source forwarded unchanged as the view's ``batch_source``.
        infer_features: When true, ``features=None`` is passed to the view
            instead of the explicit embedding schema (presumably so the
            schema is inferred from the source -- confirm against FeatureView).

    Returns:
        A ``FeatureView`` named ``item_embeddings`` with a two-hour TTL.
    """
    if infer_features:
        features = None
    else:
        features = [
            Feature(name="embedding_double", dtype=ValueType.DOUBLE_LIST),
            Feature(name="embedding_float", dtype=ValueType.FLOAT_LIST),
        ]

    return FeatureView(
        name="item_embeddings",
        # Refer to the entity by the name it is registered under: the item
        # entity is declared as Entity(name="item_id"), so "item" would point
        # at an entity that does not exist.
        entities=["item_id"],
        features=features,
        batch_source=source,
        ttl=timedelta(hours=2),
    )


def create_driver_hourly_stats_feature_view(source, infer_features: bool = False):
driver_stats_feature_view = FeatureView(
name="driver_stats",
Expand All @@ -138,8 +154,6 @@ def create_driver_hourly_stats_feature_view(source, infer_features: bool = False
Feature(name="conv_rate", dtype=ValueType.FLOAT),
Feature(name="acc_rate", dtype=ValueType.FLOAT),
Feature(name="avg_daily_trips", dtype=ValueType.INT32),
Feature(name="embedding_double", dtype=ValueType.DOUBLE_LIST),
Feature(name="embedding_float", dtype=ValueType.FLOAT_LIST),
],
batch_source=source,
ttl=timedelta(hours=2),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,14 +143,11 @@ def test_write_to_online_store(environment, universal_data_sources):
fs.apply([driver_hourly_stats, driver_entity])

# fake data to ingest into Online Store
dummy_vector = [1.0, 0.0]
data = {
"driver_id": [123],
"conv_rate": [0.85],
"acc_rate": [0.91],
"avg_daily_trips": [14],
"embedding_double": [dummy_vector],
"embedding_float": [dummy_vector],
"event_timestamp": [pd.Timestamp(datetime.datetime.utcnow()).round("ms")],
"created": [pd.Timestamp(datetime.datetime.utcnow()).round("ms")],
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from datetime import datetime

import pandas as pd
import pytest

from feast import Feature, ValueType
from feast.errors import SpecifiedFeaturesNotPresentError
from tests.integration.feature_repos.universal.entities import customer, driver
from feast.infra.offline_stores.file_source import FileSource
from tests.integration.feature_repos.universal.entities import customer, driver, item
from tests.integration.feature_repos.universal.feature_views import (
conv_rate_plus_100_feature_view,
create_conv_rate_request_data_source,
create_driver_hourly_stats_feature_view,
create_item_embeddings_feature_view,
create_similarity_request_data_source,
similarity_feature_view,
)
Expand All @@ -29,19 +34,39 @@ def test_infer_odfv_features(environment, universal_data_sources, infer_features
infer_features=infer_features,
)

sim_odfv = similarity_feature_view(
{
"driver": driver_hourly_stats,
"input_request": create_similarity_request_data_source(),
},
infer_features=infer_features,
)

feast_objects = [driver_hourly_stats, driver_odfv, sim_odfv, driver(), customer()]
feast_objects = [driver_hourly_stats, driver_odfv, driver(), customer()]
store.apply(feast_objects)
odfv = store.get_on_demand_feature_view("conv_rate_plus_100")
assert len(odfv.features) == 3


@pytest.mark.integration
@pytest.mark.parametrize("infer_features", [True, False], ids=str)
def test_infer_odfv_list_features(environment, infer_features, tmp_path):
    """Apply the similarity ODFV over list-typed item embeddings.

    A single-row parquet file holding float/double list embeddings is written
    under ``tmp_path`` and used as a file source for the item embeddings
    feature view. The similarity on-demand feature view is built on top of it
    (with and without ``infer_features``), applied to the store, and expected
    to expose exactly two features.
    """
    embedding = [1.0, 1.0]
    parquet_path = f"{tmp_path}/items.parquet"

    # Materialise the fake item row on disk so FileSource can point at it.
    pd.DataFrame(
        data={
            "item_id": [0],
            "embedding_float": [embedding],
            "embedding_double": [embedding],
            "event_timestamp": [pd.Timestamp(datetime.utcnow())],
            "created": [pd.Timestamp(datetime.utcnow())],
        }
    ).to_parquet(parquet_path)

    item_embeddings_view = create_item_embeddings_feature_view(
        FileSource(
            path=parquet_path,
            event_timestamp_column="event_timestamp",
            created_timestamp_column="created",
        )
    )
    odfv_inputs = {
        "items": item_embeddings_view,
        "input_request": create_similarity_request_data_source(),
    }
    similarity_view = similarity_feature_view(
        odfv_inputs, infer_features=infer_features
    )

    feature_store = environment.feature_store
    feature_store.apply([item(), item_embeddings_view, similarity_view])

    applied_odfv = feature_store.get_on_demand_feature_view("similarity")
    assert len(applied_odfv.features) == 2

Expand Down

0 comments on commit c655d1b

Please sign in to comment.