feat: Rewrite ibis point-in-time-join w/o feast abstractions (#4023)
* feat: refactor ibis point-in-time-join

Signed-off-by: tokoko <togurg14@freeuni.edu.ge>

* fix formatting, linting

Signed-off-by: tokoko <togurg14@freeuni.edu.ge>

---------

Signed-off-by: tokoko <togurg14@freeuni.edu.ge>
tokoko authored Mar 30, 2024
1 parent afd52b8 commit 3980e0c
Showing 5 changed files with 295 additions and 118 deletions.
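The heart of the change: the per-feature-view _get_historical_features_one method is deleted and replaced by a standalone point_in_time_join function that operates on plain ibis Tables. A minimal sketch of the point-in-time pattern being implemented — match each entity row against feature rows at or before its event timestamp, then keep only the newest match. Table and column names here are illustrative, not from the commit:

import ibis
import pandas as pd

# Illustrative inputs: one entity frame and one feature frame.
entities = ibis.memtable(
    pd.DataFrame(
        {
            "driver_id": [1, 2],
            "event_timestamp": pd.to_datetime(["2024-03-01", "2024-03-02"]),
        }
    )
)
features = ibis.memtable(
    pd.DataFrame(
        {
            "driver_id": [1, 1, 2],
            "created": pd.to_datetime(["2024-02-01", "2024-02-28", "2024-03-03"]),
            "conv_rate": [0.1, 0.2, 0.3],
        }
    )
)

# Keep only feature rows at or before each entity's event timestamp.
joined = features.inner_join(
    entities,
    [
        features["driver_id"] == entities["driver_id"],
        features["created"] <= entities["event_timestamp"],
    ],
)

# Rank matches newest-first per entity and keep the top row, mirroring the
# group_by/order_by/row_number pattern used in the diff below.
ranked = (
    joined.group_by(by="driver_id")
    .order_by(ibis.desc(joined["created"]))
    .mutate(rn=ibis.row_number())
)
latest = ranked.filter(ranked["rn"] == 0).drop("rn")

# driver 1 resolves to conv_rate=0.2 (2024-02-28); driver 2 has no feature
# row at or before its timestamp, so it drops out of the inner join.
print(latest.to_pandas())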
248 changes: 132 additions & 116 deletions sdk/python/feast/infra/offline_stores/contrib/ibis_offline_store/ibis.py
@@ -72,112 +72,6 @@ def _get_entity_df_event_timestamp_range(

return entity_df_event_timestamp_range

@staticmethod
def _get_historical_features_one(
feature_view: FeatureView,
entity_table: Table,
feature_refs: List[str],
full_feature_names: bool,
timestamp_range: Tuple,
acc_table: Table,
event_timestamp_col: str,
) -> Table:
fv_table: Table = ibis.read_parquet(feature_view.batch_source.name)

for old_name, new_name in feature_view.batch_source.field_mapping.items():
if old_name in fv_table.columns:
fv_table = fv_table.rename({new_name: old_name})

timestamp_field = feature_view.batch_source.timestamp_field

# TODO mutate only if tz-naive
fv_table = fv_table.mutate(
**{
timestamp_field: fv_table[timestamp_field].cast(
dt.Timestamp(timezone="UTC")
)
}
)

full_name_prefix = feature_view.projection.name_alias or feature_view.name

feature_refs = [
fr.split(":")[1]
for fr in feature_refs
if fr.startswith(f"{full_name_prefix}:")
]

timestamp_range_start_minus_ttl = (
timestamp_range[0] - feature_view.ttl
if feature_view.ttl and feature_view.ttl > timedelta(0, 0, 0, 0, 0, 0, 0)
else timestamp_range[0]
)

timestamp_range_start_minus_ttl = ibis.literal(
timestamp_range_start_minus_ttl.strftime("%Y-%m-%d %H:%M:%S.%f")
).cast(dt.Timestamp(timezone="UTC"))

timestamp_range_end = ibis.literal(
timestamp_range[1].strftime("%Y-%m-%d %H:%M:%S.%f")
).cast(dt.Timestamp(timezone="UTC"))

fv_table = fv_table.filter(
ibis.and_(
fv_table[timestamp_field] <= timestamp_range_end,
fv_table[timestamp_field] >= timestamp_range_start_minus_ttl,
)
)

# join_key_map = feature_view.projection.join_key_map or {e.name: e.name for e in feature_view.entity_columns}
# predicates = [fv_table[k] == entity_table[v] for k, v in join_key_map.items()]

if feature_view.projection.join_key_map:
predicates = [
fv_table[k] == entity_table[v]
for k, v in feature_view.projection.join_key_map.items()
]
else:
predicates = [
fv_table[e.name] == entity_table[e.name]
for e in feature_view.entity_columns
]

predicates.append(
fv_table[timestamp_field] <= entity_table[event_timestamp_col]
)

fv_table = fv_table.inner_join(
entity_table, predicates, lname="", rname="{name}_y"
)

fv_table = (
fv_table.group_by(by="entity_row_id")
.order_by(ibis.desc(fv_table[timestamp_field]))
.mutate(rn=ibis.row_number())
)

fv_table = fv_table.filter(fv_table["rn"] == ibis.literal(0))

select_cols = ["entity_row_id"]
select_cols.extend(feature_refs)
fv_table = fv_table.select(select_cols)

if full_feature_names:
fv_table = fv_table.rename(
{f"{full_name_prefix}__{feature}": feature for feature in feature_refs}
)

acc_table = acc_table.left_join(
fv_table,
predicates=[fv_table.entity_row_id == acc_table.entity_row_id],
lname="",
rname="{name}_yyyy",
)

acc_table = acc_table.drop(s.endswith("_yyyy"))

return acc_table

@staticmethod
def _to_utc(entity_df: pd.DataFrame, event_timestamp_col):
entity_df_event_timestamp = entity_df.loc[
@@ -228,30 +122,73 @@ def get_historical_features(
            entity_schema=entity_schema,
        )

        # TODO get range with ibis
        timestamp_range = IbisOfflineStore._get_entity_df_event_timestamp_range(
            entity_df, event_timestamp_col
        )

        entity_df = IbisOfflineStore._to_utc(entity_df, event_timestamp_col)

        entity_table = ibis.memtable(entity_df)
        entity_table = IbisOfflineStore._generate_row_id(
            entity_table, feature_views, event_timestamp_col
        )

        res: Table = entity_table

        for fv in feature_views:
            res = IbisOfflineStore._get_historical_features_one(
                fv,
                entity_table,
                feature_refs,
                full_feature_names,
                timestamp_range,
                res,
                event_timestamp_col,
            )

        res = res.drop("entity_row_id")

        def read_fv(feature_view, feature_refs, full_feature_names):
            fv_table: Table = ibis.read_parquet(feature_view.batch_source.name)

            for old_name, new_name in feature_view.batch_source.field_mapping.items():
                if old_name in fv_table.columns:
                    fv_table = fv_table.rename({new_name: old_name})

            timestamp_field = feature_view.batch_source.timestamp_field

            # TODO mutate only if tz-naive
            fv_table = fv_table.mutate(
                **{
                    timestamp_field: fv_table[timestamp_field].cast(
                        dt.Timestamp(timezone="UTC")
                    )
                }
            )

            full_name_prefix = feature_view.projection.name_alias or feature_view.name

            feature_refs = [
                fr.split(":")[1]
                for fr in feature_refs
                if fr.startswith(f"{full_name_prefix}:")
            ]

            if full_feature_names:
                fv_table = fv_table.rename(
                    {
                        f"{full_name_prefix}__{feature}": feature
                        for feature in feature_refs
                    }
                )

                feature_refs = [
                    f"{full_name_prefix}__{feature}" for feature in feature_refs
                ]

            return (
                fv_table,
                feature_view.batch_source.timestamp_field,
                feature_view.projection.join_key_map
                or {e.name: e.name for e in feature_view.entity_columns},
                feature_refs,
                feature_view.ttl,
            )

        res = point_in_time_join(
            entity_table=entity_table,
            feature_tables=[
                read_fv(feature_view, feature_refs, full_feature_names)
                for feature_view in feature_views
            ],
            event_timestamp_col=event_timestamp_col,
        )

return IbisRetrievalJob(
res,
@@ -285,6 +222,10 @@ def pull_all_from_table_or_query(

table = table.select(*fields)

# TODO get rid of this fix
if "__log_date" in table.columns:
table = table.drop("__log_date")

table = table.filter(
ibis.and_(
table[timestamp_field] >= ibis.literal(start_date),
@@ -320,6 +261,7 @@ def write_logged_features(
else:
kwargs = {}

# TODO always write to directory
table.to_parquet(
f"{destination.path}/{uuid.uuid4().hex}-{{i}}.parquet", **kwargs
)
@@ -405,3 +347,77 @@ def persist(
@property
def metadata(self) -> Optional[RetrievalMetadata]:
return self._metadata


def point_in_time_join(
entity_table: Table,
feature_tables: List[Tuple[Table, str, Dict[str, str], List[str], timedelta]],
event_timestamp_col="event_timestamp",
):
# TODO handle ttl
all_entities = [event_timestamp_col]
for feature_table, timestamp_field, join_key_map, _, _ in feature_tables:
all_entities.extend(join_key_map.values())

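    # Build a single string surrogate key from the event timestamp and every
    # join key, so each entity row can be tracked through the joins below and
    # deduplicated afterwards.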
r = ibis.literal("")

for e in set(all_entities):
r = r.concat(entity_table[e].cast("string")) # type: ignore

entity_table = entity_table.mutate(entity_row_id=r)

acc_table = entity_table

for (
feature_table,
timestamp_field,
join_key_map,
feature_refs,
ttl,
) in feature_tables:
predicates = [
feature_table[k] == entity_table[v] for k, v in join_key_map.items()
]

predicates.append(
feature_table[timestamp_field] <= entity_table[event_timestamp_col],
)

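        # When the feature view has a TTL, also require the feature row to fall
        # inside the [event_timestamp - ttl, event_timestamp] window.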
if ttl:
predicates.append(
feature_table[timestamp_field]
>= entity_table[event_timestamp_col] - ibis.literal(ttl)
)

feature_table = feature_table.inner_join(
entity_table, predicates, lname="", rname="{name}_y"
)

feature_table = feature_table.drop(s.endswith("_y"))

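        # Rank the surviving feature rows per entity row, newest first; rn == 0
        # below is the latest feature row at or before the entity's timestamp.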
feature_table = (
feature_table.group_by(by="entity_row_id")
.order_by(ibis.desc(feature_table[timestamp_field]))
.mutate(rn=ibis.row_number())
)

feature_table = feature_table.filter(
feature_table["rn"] == ibis.literal(0)
).drop("rn")

select_cols = ["entity_row_id"]
select_cols.extend(feature_refs)
feature_table = feature_table.select(select_cols)

acc_table = acc_table.left_join(
feature_table,
predicates=[feature_table.entity_row_id == acc_table.entity_row_id],
lname="",
rname="{name}_yyyy",
)

acc_table = acc_table.drop(s.endswith("_yyyy"))

acc_table = acc_table.drop("entity_row_id")

return acc_table
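
For reference, a hypothetical call (not part of the commit) showing the five-element contract each entry of feature_tables must satisfy — the same shape read_fv returns above. The parquet path and feature column names are invented for illustration:

from datetime import timedelta

import ibis
import pandas as pd

# point_in_time_join is a module-level function as of this commit:
from feast.infra.offline_stores.contrib.ibis_offline_store.ibis import (
    point_in_time_join,
)

entity_table = ibis.memtable(
    pd.DataFrame(
        {
            "driver_id": [1, 2],
            "event_timestamp": pd.to_datetime(["2024-03-01", "2024-03-02"]),
        }
    )
)

driver_stats = ibis.read_parquet("data/driver_stats.parquet")  # invented path

result = point_in_time_join(
    entity_table=entity_table,
    feature_tables=[
        (
            driver_stats,                # feature Table
            "created",                   # timestamp_field of the feature table
            {"driver_id": "driver_id"},  # join_key_map: feature col -> entity col
            ["conv_rate"],               # feature columns to keep
            timedelta(days=7),           # ttl; a falsy value skips the window filter
        )
    ],
    event_timestamp_col="event_timestamp",
)
print(result.to_pandas())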
13 changes: 12 additions & 1 deletion sdk/python/requirements/py3.10-ci-requirements.txt
@@ -164,6 +164,12 @@ docker==7.0.0
# testcontainers
docutils==0.19
# via sphinx
duckdb==0.10.1
# via
# duckdb-engine
# ibis-framework
duckdb-engine==0.11.2
# via ibis-framework
entrypoints==0.4
# via altair
exceptiongroup==1.2.0
@@ -310,7 +316,7 @@ httpx==0.27.0
# via
# feast (setup.py)
# jupyterlab
ibis-framework==8.0.0
ibis-framework[duckdb]==8.0.0
# via
# feast (setup.py)
# ibis-substrait
@@ -848,8 +854,13 @@ sphinxcontrib-serializinghtml==1.1.10
# via sphinx
sqlalchemy[mypy]==1.4.52
# via
# duckdb-engine
# feast (setup.py)
# ibis-framework
# sqlalchemy
# sqlalchemy-views
sqlalchemy-views==0.3.2
# via ibis-framework
sqlalchemy2-stubs==0.0.2a38
# via sqlalchemy
sqlglot==20.11.0
13 changes: 12 additions & 1 deletion sdk/python/requirements/py3.9-ci-requirements.txt
@@ -164,6 +164,12 @@ docker==7.0.0
# testcontainers
docutils==0.19
# via sphinx
duckdb==0.10.1
# via
# duckdb-engine
# ibis-framework
duckdb-engine==0.11.2
# via ibis-framework
entrypoints==0.4
# via altair
exceptiongroup==1.2.0
@@ -310,7 +316,7 @@ httpx==0.27.0
# via
# feast (setup.py)
# jupyterlab
ibis-framework==8.0.0
ibis-framework[duckdb]==8.0.0
# via
# feast (setup.py)
# ibis-substrait
@@ -858,8 +864,13 @@ sphinxcontrib-serializinghtml==1.1.10
# via sphinx
sqlalchemy[mypy]==1.4.52
# via
# duckdb-engine
# feast (setup.py)
# ibis-framework
# sqlalchemy
# sqlalchemy-views
sqlalchemy-views==0.3.2
# via ibis-framework
sqlalchemy2-stubs==0.0.2a38
# via sqlalchemy
sqlglot==20.11.0
