firestore perf improvements + benchmark script (#1411)
Signed-off-by: Oleg Avdeev <oleg.v.avdeev@gmail.com>
oavdeev authored Mar 26, 2021
1 parent 960f1ed commit 95c50ad
Showing 7 changed files with 202 additions and 44 deletions.
4 changes: 3 additions & 1 deletion sdk/python/feast/feature_store.py
@@ -304,7 +304,9 @@ def _materialize_single_feature_view(
rows_to_write = _convert_arrow_to_proto(table, feature_view)

provider = self._get_provider()
provider.online_write_batch(self.config.project, feature_view, rows_to_write)
provider.online_write_batch(
self.config.project, feature_view, rows_to_write, None
)

feature_view.materialization_intervals.append((start_date, end_date))
self.apply([feature_view])
136 changes: 95 additions & 41 deletions sdk/python/feast/infra/gcp.py
@@ -1,5 +1,7 @@
import itertools
from datetime import datetime
from typing import Dict, List, Optional, Sequence, Tuple, Union
from multiprocessing.pool import ThreadPool
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union

import mmh3
from pytz import utc
@@ -24,7 +26,6 @@ def _delete_all_values(client, key) -> None:
return

for entity in entities:
print("Deleting: {}".format(entity))
client.delete(entity.key)


@@ -110,48 +111,15 @@ def online_write_batch(
data: List[
Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
],
progress: Optional[Callable[[int], Any]],
) -> None:
from google.cloud import datastore

client = self._initialize_client()

for entity_key, features, timestamp, created_ts in data:
document_id = compute_datastore_entity_id(entity_key)

key = client.key(
"Project", project, "Table", table.name, "Row", document_id,
)
with client.transaction():
entity = client.get(key)
if entity is not None:
if entity["event_ts"] > _make_tzaware(timestamp):
# Do not overwrite feature values computed from fresher data
continue
elif (
entity["event_ts"] == _make_tzaware(timestamp)
and created_ts is not None
and entity["created_ts"] is not None
and entity["created_ts"] > _make_tzaware(created_ts)
):
# Do not overwrite feature values computed from the same data, but
# computed later than this one
continue
else:
entity = datastore.Entity(key=key)

entity.update(
dict(
key=entity_key.SerializeToString(),
values={k: v.SerializeToString() for k, v in features.items()},
event_ts=_make_tzaware(timestamp),
created_ts=(
_make_tzaware(created_ts)
if created_ts is not None
else None
),
)
)
client.put(entity)
pool = ThreadPool(processes=10)
pool.map(
lambda b: _write_minibatch(client, project, table, b, progress),
_to_minibatches(data),
)

def online_read(
self,
@@ -178,3 +146,89 @@ def online_read(
else:
result.append((None, None))
return result


ProtoBatch = Sequence[
Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
]


def _to_minibatches(data: ProtoBatch, batch_size=50) -> Iterator[ProtoBatch]:
"""
Split data into minibatches, making sure we stay under GCP datastore transaction size
limits.
"""
iterable = iter(data)

while True:
batch = list(itertools.islice(iterable, batch_size))
if len(batch) > 0:
yield batch
else:
break


def _write_minibatch(
client,
project: str,
table: Union[FeatureTable, FeatureView],
data: Sequence[
Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
],
progress: Optional[Callable[[int], Any]],
):
from google.api_core.exceptions import Conflict
from google.cloud import datastore

num_retries_on_conflict = 3
row_count = 0
for retry_number in range(num_retries_on_conflict):
try:
row_count = 0
with client.transaction():
for entity_key, features, timestamp, created_ts in data:
document_id = compute_datastore_entity_id(entity_key)

key = client.key(
"Project", project, "Table", table.name, "Row", document_id,
)

entity = client.get(key)
if entity is not None:
if entity["event_ts"] > _make_tzaware(timestamp):
# Do not overwrite feature values computed from fresher data
continue
elif (
entity["event_ts"] == _make_tzaware(timestamp)
and created_ts is not None
and entity["created_ts"] is not None
and entity["created_ts"] > _make_tzaware(created_ts)
):
# Do not overwrite feature values computed from the same data, but
# computed later than this one
continue
else:
entity = datastore.Entity(key=key)

entity.update(
dict(
key=entity_key.SerializeToString(),
values={
k: v.SerializeToString() for k, v in features.items()
},
event_ts=_make_tzaware(timestamp),
created_ts=(
_make_tzaware(created_ts)
if created_ts is not None
else None
),
)
)
client.put(entity)
row_count += 1

if progress:
progress(1)
except Conflict:
if retry_number == num_retries_on_conflict - 1:
raise
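
The rewritten write path above is the core of the performance change: instead of writing one row per Datastore transaction, online_write_batch now splits the rows into minibatches via _to_minibatches and writes each minibatch in a single transaction from a ThreadPool of 10 workers. Below is a minimal, Feast-independent sketch of that fan-out pattern; write_one_minibatch stands in for the transactional Datastore write, and the function names, row type, and batch size of 50 are illustrative only.

import itertools
from multiprocessing.pool import ThreadPool
from typing import Iterator, List, Sequence, Tuple

Row = Tuple[str, int]


def to_minibatches(data: Sequence[Row], batch_size: int = 50) -> Iterator[List[Row]]:
    # Yield successive fixed-size chunks of `data`, mirroring _to_minibatches.
    iterable = iter(data)
    while True:
        batch = list(itertools.islice(iterable, batch_size))
        if not batch:
            break
        yield batch


def write_one_minibatch(batch: List[Row]) -> int:
    # Stand-in for _write_minibatch: treat each chunk as one transactional write.
    return len(batch)


def write_batch(data: Sequence[Row], num_threads: int = 10) -> int:
    # Fan the minibatches out over a thread pool, as online_write_batch now does.
    with ThreadPool(processes=num_threads) as pool:
        return sum(pool.map(write_one_minibatch, to_minibatches(data)))


if __name__ == "__main__":
    rows = [(f"driver_{i}", i) for i in range(137)]
    print(write_batch(rows))  # prints 137

ThreadPool from multiprocessing.pool gives thread-based rather than process-based workers, which is sufficient here because the per-minibatch work is I/O-bound.
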
3 changes: 2 additions & 1 deletion sdk/python/feast/infra/local_sqlite.py
@@ -1,7 +1,7 @@
import os
import sqlite3
from datetime import datetime
from typing import Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import pytz

@@ -66,6 +66,7 @@ def online_write_batch(
data: List[
Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
],
progress: Optional[Callable[[int], Any]],
) -> None:
conn = self._get_conn()

5 changes: 4 additions & 1 deletion sdk/python/feast/infra/provider.py
@@ -1,6 +1,6 @@
import abc
from datetime import datetime
from typing import Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

from feast import FeatureTable, FeatureView
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
@@ -50,6 +50,7 @@ def online_write_batch(
data: List[
Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
],
progress: Optional[Callable[[int], Any]],
) -> None:
"""
Write a batch of feature rows to the online store. This is a low level interface, not
@@ -63,6 +64,8 @@ def online_write_batch(
data: a list of quadruplets containing Feature data. Each quadruplet contains an Entity Key,
a dict containing feature values, an event timestamp for the row, and
the created timestamp for the row if it exists.
progress: Optional function to be called once every mini-batch of rows is written to
the online store. Can be used to display progress.
"""
...

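From the caller's side, the new progress argument is just a callable that receives an integer count; the benchmark script later in this diff wires it to tqdm's update method. A minimal caller-side sketch under that assumption follows; write_rows is a hypothetical stand-in for a provider's online_write_batch and only illustrates the callback contract described in the docstring above.

from typing import Callable, List, Optional

from tqdm import tqdm

MINIBATCH_SIZE = 50  # assumed to mirror the default batch size in _to_minibatches


def write_rows(rows: List[dict], progress: Optional[Callable[[int], None]] = None) -> None:
    # Stand-in writer: invokes the callback once per mini-batch, as the docstring describes.
    for start in range(0, len(rows), MINIBATCH_SIZE):
        batch = rows[start : start + MINIBATCH_SIZE]
        # ... persist `batch` to the online store here ...
        if progress:
            progress(len(batch))  # advance the caller's progress tracker


rows = [{"driver_id": i, "conv_rate": 0.5} for i in range(230)]
with tqdm(total=len(rows)) as bar:
    write_rows(rows, progress=bar.update)
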
1 change: 1 addition & 0 deletions sdk/python/tests/cli/online_read_write_test.py
@@ -37,6 +37,7 @@ def _driver_rw_test(event_ts, created_ts, write, expect_read):
created_ts,
)
],
progress=None,
)

read_rows = provider.online_read(
2 changes: 2 additions & 0 deletions sdk/python/tests/cli/test_online_retrieval.py
@@ -40,6 +40,7 @@ def test_basic(self) -> None:
datetime.utcnow(),
)
],
progress=None,
)

provider.online_write_batch(
@@ -56,6 +57,7 @@ def test_basic(self) -> None:
datetime.utcnow(),
)
],
progress=None,
)

# Retrieve two features using two keys, one valid one non-existing
95 changes: 95 additions & 0 deletions sdk/python/tests/online_write_benchmark.py
@@ -0,0 +1,95 @@
import os
import random
import string
import tempfile
from datetime import datetime, timedelta

import click
import pyarrow as pa
from tqdm import tqdm

from feast.data_source import FileSource
from feast.entity import Entity
from feast.feature import Feature
from feast.feature_store import FeatureStore, _convert_arrow_to_proto
from feast.feature_view import FeatureView
from feast.repo_config import RepoConfig
from feast.value_type import ValueType
from tests.driver_test_data import create_driver_hourly_stats_df


def create_driver_hourly_stats_feature_view(source):
driver_stats_feature_view = FeatureView(
name="driver_stats",
entities=["driver_id"],
features=[
Feature(name="conv_rate", dtype=ValueType.FLOAT),
Feature(name="acc_rate", dtype=ValueType.FLOAT),
Feature(name="avg_daily_trips", dtype=ValueType.INT32),
],
input=source,
ttl=timedelta(hours=2),
)
return driver_stats_feature_view


def create_driver_hourly_stats_source(parquet_path):
return FileSource(
path=parquet_path,
event_timestamp_column="datetime",
created_timestamp_column="created",
)


@click.command(name="run")
def benchmark_writes():
project_id = "test" + "".join(
random.choice(string.ascii_lowercase + string.digits) for _ in range(10)
)

with tempfile.TemporaryDirectory() as temp_dir:
store = FeatureStore(
config=RepoConfig(
metadata_store=os.path.join(temp_dir, "metadata.db"),
project=project_id,
provider="gcp",
)
)

# This is just to set the data source to something; we aren't actually reading from the parquet source here.
parquet_path = os.path.join(temp_dir, "data.parquet")

driver = Entity(name="driver_id", value_type=ValueType.INT64, description="")
table = create_driver_hourly_stats_feature_view(
create_driver_hourly_stats_source(parquet_path=parquet_path)
)
store.apply([table, driver])

provider = store._get_provider()

end_date = datetime.utcnow()
start_date = end_date - timedelta(days=14)
customers = list(range(100))
data = create_driver_hourly_stats_df(customers, start_date, end_date)

# Show the data for reference
print(data)
proto_data = _convert_arrow_to_proto(pa.Table.from_pandas(data), table)

# Write it
with tqdm(total=len(proto_data)) as progress:
provider.online_write_batch(
project=store.project,
table=table,
data=proto_data,
progress=progress.update,
)

registry_tables = store._get_registry().list_feature_views(
project=store.project
)
provider.teardown_infra(store.project, tables=registry_tables)


if __name__ == "__main__":
benchmark_writes()
