From 94d4f2bff6e965743e4a0e9db269c1f87a924571 Mon Sep 17 00:00:00 2001 From: Ritchie Martori Date: Sun, 22 Sep 2024 17:32:45 -0700 Subject: [PATCH] exec store cleanup and fixes --- fiftyone/factory/repos/execution_store.py | 50 ++++++++++++++++++----- tests/unittests/execution_store_tests.py | 47 +++++++-------------- 2 files changed, 54 insertions(+), 43 deletions(-) diff --git a/fiftyone/factory/repos/execution_store.py b/fiftyone/factory/repos/execution_store.py index 7f2c4aa117..69fa78b077 100644 --- a/fiftyone/factory/repos/execution_store.py +++ b/fiftyone/factory/repos/execution_store.py @@ -2,6 +2,7 @@ Execution store repository. """ +import datetime from pymongo.collection import Collection from fiftyone.operators.store.models import StoreDocument, KeyDocument @@ -35,18 +36,36 @@ def list_stores(self) -> list[str]: return self._collection.distinct("store_name") def set_key(self, store_name, key, value, ttl=None) -> KeyDocument: - """Sets or updates a key in the specified store.""" + """Sets or updates a key in the specified store""" + now = datetime.datetime.now() expiration = KeyDocument.get_expiration(ttl) key_doc = KeyDocument( - store_name=store_name, - key=key, - value=value, - expires_at=expiration if ttl else None, + store_name=store_name, key=key, value=value, updated_at=now ) - # Update or insert the key - self._collection.update_one( - _where(store_name, key), {"$set": key_doc.dict()}, upsert=True + + # Prepare the update operations + update_fields = { + "$set": key_doc.dict( + exclude={"created_at", "expires_at", "store_name", "key"} + ), + "$setOnInsert": { + "store_name": store_name, + "key": key, + "created_at": now, + "expires_at": expiration if ttl else None, + }, + } + + # Perform the upsert operation + result = self._collection.update_one( + _where(store_name, key), update_fields, upsert=True ) + + if result.upserted_id: + key_doc.created_at = now + else: + key_doc.updated_at = now + return key_doc def get_key(self, store_name, key) -> KeyDocument: @@ -57,8 +76,7 @@ def get_key(self, store_name, key) -> KeyDocument: def list_keys(self, store_name) -> list[str]: """Lists all keys in the specified store.""" - keys = self._collection.find(_where(store_name)) - # TODO: redact non-key fields + keys = self._collection.find(_where(store_name), {"key": 1}) return [key["key"] for key in keys] def update_ttl(self, store_name, key, ttl) -> bool: @@ -92,7 +110,19 @@ def __init__(self, collection: Collection): def _create_indexes(self): indices = self._collection.list_indexes() expires_at_name = "expires_at" + store_name_name = "store_name" + key_name = "key" + full_key_name = "store_name_and_key" if expires_at_name not in indices: self._collection.create_index( expires_at_name, name=expires_at_name, expireAfterSeconds=0 ) + if full_key_name not in indices: + self._collection.create_index( + [(store_name_name, 1), (key_name, 1)], + name=full_key_name, + unique=True, + ) + for name in [store_name_name, key_name]: + if name not in indices: + self._collection.create_index(name, name=name) diff --git a/tests/unittests/execution_store_tests.py b/tests/unittests/execution_store_tests.py index d5b145ba63..470d399026 100644 --- a/tests/unittests/execution_store_tests.py +++ b/tests/unittests/execution_store_tests.py @@ -64,13 +64,15 @@ def test_set_key(self): {"store_name": "widgets", "key": "widget_1"}, { "$set": { + "value": {"name": "Widget One", "value": 100}, + "updated_at": IsDateTime(), + }, + "$setOnInsert": { "store_name": "widgets", "key": "widget_1", - "value": {"name": "Widget One", "value": 100}, "created_at": IsDateTime(), - "updated_at": None, "expires_at": IsDateTime(), - } + }, }, upsert=True, ) @@ -140,7 +142,9 @@ def test_list_keys(self): keys = self.store_repo.list_keys("widgets") assert keys == ["widget_1", "widget_2"] self.mock_collection.find.assert_called_once() - self.mock_collection.find.assert_called_with({"store_name": "widgets"}) + self.mock_collection.find.assert_called_with( + {"store_name": "widgets"}, {"key": 1} + ) def test_list_stores(self): self.mock_collection.distinct.return_value = ["widgets", "gadgets"] @@ -166,42 +170,19 @@ def test_set(self): {"store_name": "mock_store", "key": "widget_1"}, { "$set": { + "updated_at": IsDateTime(), + "value": {"name": "Widget One", "value": 100}, + }, + "$setOnInsert": { "store_name": "mock_store", "key": "widget_1", - "value": {"name": "Widget One", "value": 100}, "created_at": IsDateTime(), - "updated_at": None, "expires_at": IsDateTime(), - } + }, }, upsert=True, ) - # def test_update(self): - # self.mock_collection.find_one.return_value = { - # "store_name": "mock_store", - # "key": "widget_1", - # "value": {"name": "Widget One", "value": 100}, - # "created_at": time.time(), - # "updated_at": time.time(), - # "expires_at": time.time() + 60000 - # } - # self.store.update_key("widget_1", {"name": "Widget One", "value": 200}) - # self.mock_collection.update_one.assert_called_once() - # self.mock_collection.update_one.assert_called_with( - # {"store_name": "mock_store", "key": "widget_1"}, - # { - # "$set": { - # "store_name": "mock_store", - # "key": "widget_1", - # "value": {"name": "Widget One", "value": 200}, - # "created_at": IsDateTime(), - # "updated_at": IsDateTime(), - # "expires_at": IsDateTime() - # } - # } - # ) - def test_get(self): self.mock_collection.find_one.return_value = { "store_name": "mock_store", @@ -227,7 +208,7 @@ def test_list_keys(self): assert keys == ["widget_1", "widget_2"] self.mock_collection.find.assert_called_once() self.mock_collection.find.assert_called_with( - {"store_name": "mock_store"} + {"store_name": "mock_store"}, {"key": 1} ) def test_delete(self):