Commit

Optimize write rate in GCP Firestore
Signed-off-by: Tsotne Tabidze <tsotne@tecton.ai>
Tsotne Tabidze committed Apr 13, 2021
1 parent a65800c commit eec8f5b
Showing 4 changed files with 40 additions and 71 deletions.
80 changes: 25 additions & 55 deletions sdk/python/feast/infra/gcp.py
@@ -93,7 +93,7 @@ def online_write_batch(
) -> None:
client = self._initialize_client()

pool = ThreadPool(processes=10)
pool = ThreadPool(processes=40)
pool.map(
lambda b: _write_minibatch(client, project, table, b, progress),
_to_minibatches(data),
@@ -231,62 +231,32 @@ def _write_minibatch(
],
progress: Optional[Callable[[int], Any]],
):
from google.api_core.exceptions import Conflict
from google.cloud import datastore

num_retries_on_conflict = 3
row_count = 0
for retry_number in range(num_retries_on_conflict):
try:
row_count = 0
with client.transaction():
for entity_key, features, timestamp, created_ts in data:
document_id = compute_datastore_entity_id(entity_key)

key = client.key(
"Project", project, "Table", table.name, "Row", document_id,
)

entity = client.get(key)
if entity is not None:
if entity["event_ts"] > utils.make_tzaware(timestamp):
# Do not overwrite feature values computed from fresher data
continue
elif (
entity["event_ts"] == utils.make_tzaware(timestamp)
and created_ts is not None
and entity["created_ts"] is not None
and entity["created_ts"] > utils.make_tzaware(created_ts)
):
# Do not overwrite feature values computed from the same data, but
# computed later than this one
continue
else:
entity = datastore.Entity(key=key)

entity.update(
dict(
key=entity_key.SerializeToString(),
values={
k: v.SerializeToString() for k, v in features.items()
},
event_ts=utils.make_tzaware(timestamp),
created_ts=(
utils.make_tzaware(created_ts)
if created_ts is not None
else None
),
)
)
client.put(entity)
row_count += 1

if progress:
progress(1)
break # make sure to break out of retry loop if all went well
except Conflict:
if retry_number == num_retries_on_conflict - 1:
raise
entities = []
for entity_key, features, timestamp, created_ts in data:
document_id = compute_datastore_entity_id(entity_key)

key = client.key("Project", project, "Table", table.name, "Row", document_id,)

entity = datastore.Entity(key=key)

entity.update(
dict(
key=entity_key.SerializeToString(),
values={k: v.SerializeToString() for k, v in features.items()},
event_ts=utils.make_tzaware(timestamp),
created_ts=(
utils.make_tzaware(created_ts) if created_ts is not None else None
),
)
)
entities.append(entity)
with client.transaction():
client.put_multi(entities)

if progress:
progress(len(entities))


def _delete_all_values(client, key) -> None:
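The heart of this change is in _write_minibatch: the old per-row get()/put() loop with Conflict retries is replaced by building all entities up front and committing them with a single put_multi inside one transaction, trading the per-row freshness check for far fewer Datastore round trips. A minimal standalone sketch of the same pattern, assuming a google-cloud-datastore client and an iterable of (document_id, payload) pairs (the helper name and key shape here are illustrative, not Feast's exact code):

# Sketch: batch Cloud Datastore writes with put_multi inside a transaction,
# instead of one get()/put() round trip per row.
from google.cloud import datastore

def write_minibatch(client: datastore.Client, rows) -> int:
    entities = []
    for document_id, payload in rows:  # rows: iterable of (str, dict)
        entity = datastore.Entity(key=client.key("Row", document_id))
        entity.update(payload)
        entities.append(entity)
    # One commit for the whole minibatch; Datastore limits the number of
    # writes per commit (500 is the documented batch limit), so the
    # minibatch size must stay within it.
    with client.transaction():
        client.put_multi(entities)
    return len(entities)

Because the per-row get() is gone, so is the event_ts/created_ts freshness guard, which is why the tests below now expect last-write-wins behavior.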
7 changes: 1 addition & 6 deletions sdk/python/feast/infra/local.py
@@ -81,19 +81,14 @@ def online_write_batch(
f"""
UPDATE {_table_id(project, table)}
SET value = ?, event_ts = ?, created_ts = ?
WHERE (event_ts < ? OR (event_ts = ? AND (created_ts IS NULL OR ? IS NULL OR created_ts < ?)))
AND (entity_key = ? AND feature_name = ?)
WHERE (entity_key = ? AND feature_name = ?)
""",
(
# SET
val.SerializeToString(),
timestamp,
created_ts,
# WHERE
timestamp,
timestamp,
created_ts,
created_ts,
entity_key_bin,
feature_name,
),
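The local SQLite provider gets the matching simplification: the UPDATE no longer guards on event_ts/created_ts, so the most recent write always wins. A self-contained sketch of that last-write-wins upsert (the column names follow the snippet above, but the INSERT fallback is an assumption about the surrounding Feast code, not part of this diff):

# Sketch: unconditional (last-write-wins) upsert into a SQLite online table.
import sqlite3

def write_row(conn: sqlite3.Connection, table_id: str, entity_key: bytes,
              feature_name: str, value: bytes, event_ts, created_ts) -> None:
    cur = conn.cursor()
    cur.execute(
        f"UPDATE {table_id} SET value = ?, event_ts = ?, created_ts = ? "
        f"WHERE entity_key = ? AND feature_name = ?",
        (value, event_ts, created_ts, entity_key, feature_name),
    )
    if cur.rowcount == 0:  # no existing row; fall back to an insert
        cur.execute(
            f"INSERT INTO {table_id} "
            f"(entity_key, feature_name, value, event_ts, created_ts) "
            f"VALUES (?, ?, ?, ?, ?)",
            (entity_key, feature_name, value, event_ts, created_ts),
        )
    conn.commit()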
14 changes: 9 additions & 5 deletions sdk/python/tests/online_read_write_test.py
@@ -54,13 +54,15 @@ def _driver_rw_test(event_ts, created_ts, write, expect_read):
event_ts=time_1, created_ts=time_1, write=(1.1, "3.1"), expect_read=(1.1, "3.1")
)

""" Values with an older event_ts should not overwrite newer ones """
# Note: This behavior has changed for performance. We should test that an
# older value can't overwrite a newer one once we add the respective flag.
""" Values with an older event_ts should overwrite newer ones """
time_2 = datetime.utcnow()
_driver_rw_test(
event_ts=time_1 - timedelta(hours=1),
created_ts=time_2,
write=(-1000, "OLD"),
expect_read=(1.1, "3.1"),
expect_read=(-1000, "OLD"),
)

""" Values with an new event_ts should overwrite older ones """
@@ -72,15 +74,17 @@ def _driver_rw_test(event_ts, created_ts, write, expect_read):
expect_read=(1123, "NEWER"),
)

""" created_ts is used as a tie breaker, using older created_ts here so no overwrite """
# Note: This behavior has changed for performance. We should test that an
# older value can't overwrite a newer one once we add the respective flag.
""" created_ts is used as a tie breaker, using older created_ts here, but we still overwrite """
_driver_rw_test(
event_ts=time_1 + timedelta(hours=1),
created_ts=time_3 - timedelta(hours=1),
write=(54321, "I HAVE AN OLDER created_ts SO I LOSE"),
expect_read=(1123, "NEWER"),
expect_read=(54321, "I HAVE AN OLDER created_ts SO I LOSE"),
)

""" created_ts is used as a tie breaker, using older created_ts here so no overwrite """
""" created_ts is used as a tie breaker, using newer created_ts here so we should overwrite """
_driver_rw_test(
event_ts=time_1 + timedelta(hours=1),
created_ts=time_3 + timedelta(hours=1),
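Both notes above point at a follow-up: reintroducing the freshness guard behind a flag. A hedged sketch of what that guard could look like, mirroring the comparison logic removed from gcp.py (the flag name and helper are hypothetical and not part of this commit):

# Hypothetical sketch of the freshness guard the test notes refer to.
# `guard_stale_writes` is an assumed flag name; this commit does not add it.
def should_write(existing, new_event_ts, new_created_ts,
                 guard_stale_writes: bool) -> bool:
    if not guard_stale_writes or existing is None:
        return True
    if existing["event_ts"] > new_event_ts:
        return False  # never overwrite fresher data
    if (
        existing["event_ts"] == new_event_ts
        and new_created_ts is not None
        and existing["created_ts"] is not None
        and existing["created_ts"] > new_created_ts
    ):
        return False  # same event_ts: created_ts breaks the tie
    return True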
10 changes: 5 additions & 5 deletions sdk/python/tests/online_write_benchmark.py
@@ -9,14 +9,14 @@
from tqdm import tqdm

from feast.data_source import FileSource
from feast.driver_test_data import create_driver_hourly_stats_df
from feast.entity import Entity
from feast.feature import Feature
from feast.feature_store import FeatureStore
from feast.feature_view import FeatureView
from feast.infra.provider import _convert_arrow_to_proto
from feast.repo_config import RepoConfig
from feast.value_type import ValueType
from tests.driver_test_data import create_driver_hourly_stats_df


def create_driver_hourly_stats_feature_view(source):
@@ -75,7 +75,9 @@ def benchmark_writes():

# Show the data for reference
print(data)
proto_data = _convert_arrow_to_proto(pa.Table.from_pandas(data), table)
proto_data = _convert_arrow_to_proto(
pa.Table.from_pandas(data), table, ["driver_id"]
)

# Write it
with tqdm(total=len(proto_data)) as progress:
@@ -86,9 +88,7 @@
progress=progress.update,
)

registry_tables = store._get_registry().list_feature_views(
project=store.project
)
registry_tables = store.list_feature_views()
provider.teardown_infra(store.project, tables=registry_tables)


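For context on the ThreadPool change in gcp.py (10 workers to 40), the write path fans minibatches out over a thread pool, so throughput of the I/O-bound Datastore commits scales roughly with the number of concurrent workers until quota or network limits kick in. A toy harness for comparing pool sizes (the write function is a sleep-based stand-in, not Feast's benchmark):

# Sketch: fan minibatch writes out over a thread pool and time the run.
import time
from multiprocessing.pool import ThreadPool

def benchmark_pool(minibatches, write_minibatch, processes: int) -> float:
    start = time.perf_counter()
    with ThreadPool(processes=processes) as pool:
        pool.map(write_minibatch, minibatches)
    return time.perf_counter() - start

if __name__ == "__main__":
    fake_batches = [list(range(50)) for _ in range(200)]

    def io_bound_write(batch):
        time.sleep(0.01)  # simulate a network round trip per minibatch

    for workers in (10, 40):
        elapsed = benchmark_pool(fake_batches, io_bound_write, workers)
        print(f"{workers} workers: {elapsed:.2f}s")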
