Skip to content

Commit

Permalink
ensure float list types in ODFV UDFs can be appied
Browse files Browse the repository at this point in the history
Signed-off-by: Jeff <jeffxl@apple.com>
  • Loading branch information
Agent007 committed Nov 12, 2021
1 parent af3dc6b commit 464e717
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 10 deletions.
3 changes: 2 additions & 1 deletion sdk/python/feast/driver_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ def create_driver_hourly_stats_df(drivers, start_date, end_date) -> pd.DataFrame
np.int32
)
dummy_vector = [1.0, 0.0]
df_all_drivers["embedding"] = [dummy_vector] * rows
df_all_drivers["embedding_double"] = [dummy_vector] * rows
df_all_drivers["embedding_float"] = df_all_drivers["embedding_double"]
df_all_drivers["created"] = pd.to_datetime(pd.Timestamp.now(tz=None).round("ms"))

# Create duplicate rows that should be filtered by created timestamp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,18 @@ def conv_rate_plus_100_feature_view(

def similarity(features_df: pd.DataFrame) -> pd.DataFrame:
if features_df.size == 0:
return pd.DataFrame({"cos": [0.0]}) # give hint to Feast about return type
vectors_a = features_df["embedding"].apply(np.array)
vectors_b = features_df["vector"].apply(np.array)
# give hint to Feast about return type
df = pd.DataFrame({"cos_double": [0.0]})
df["cos_float"] = df["cos_double"].astype(np.float32)
return df
vectors_a = features_df["embedding_double"].apply(np.array)
vectors_b = features_df["vector_double"].apply(np.array)
dot_products = vectors_a.mul(vectors_b).apply(sum)
norms_q = vectors_a.apply(np.linalg.norm)
norms_doc = vectors_b.apply(np.linalg.norm)
df = pd.DataFrame()
df["cos"] = dot_products / (norms_q * norms_doc)
df["cos_double"] = dot_products / (norms_q * norms_doc)
df["cos_float"] = df["cos_double"].astype(np.float32)
return df


Expand All @@ -88,7 +92,8 @@ def similarity_feature_view(
features: Optional[List[Feature]] = None,
) -> OnDemandFeatureView:
_features = features or [
Feature("cos", ValueType.DOUBLE),
Feature("cos_double", ValueType.DOUBLE),
Feature("cos_float", ValueType.FLOAT),
]
return OnDemandFeatureView(
name=similarity.__name__,
Expand All @@ -115,7 +120,11 @@ def create_conv_rate_request_data_source():

def create_similarity_request_data_source():
return RequestDataSource(
name="similarity_input", schema={"vector": ValueType.DOUBLE_LIST}
name="similarity_input",
schema={
"vector_double": ValueType.DOUBLE_LIST,
"vector_float": ValueType.FLOAT_LIST,
},
)


Expand All @@ -129,7 +138,8 @@ def create_driver_hourly_stats_feature_view(source, infer_features: bool = False
Feature(name="conv_rate", dtype=ValueType.FLOAT),
Feature(name="acc_rate", dtype=ValueType.FLOAT),
Feature(name="avg_daily_trips", dtype=ValueType.INT32),
Feature(name="embedding", dtype=ValueType.DOUBLE_LIST),
Feature(name="embedding_double", dtype=ValueType.DOUBLE_LIST),
Feature(name="embedding_float", dtype=ValueType.FLOAT_LIST),
],
batch_source=source,
ttl=timedelta(hours=2),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,8 @@ def test_write_to_online_store(environment, universal_data_sources):
"conv_rate": [0.85],
"acc_rate": [0.91],
"avg_daily_trips": [14],
"embedding": [dummy_vector],
"embedding_double": [dummy_vector],
"embedding_float": [dummy_vector],
"event_timestamp": [pd.Timestamp(datetime.datetime.utcnow()).round("ms")],
"created": [pd.Timestamp(datetime.datetime.utcnow()).round("ms")],
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_infer_odfv_features(environment, universal_data_sources, infer_features
assert len(odfv.features) == 3

odfv = store.get_on_demand_feature_view("similarity")
assert len(odfv.features) == 1
assert len(odfv.features) == 2


@pytest.mark.integration
Expand Down

0 comments on commit 464e717

Please sign in to comment.