diff --git a/odmantic/__init__.py b/odmantic/__init__.py index 2bd83c4a..cb7f6f0c 100644 --- a/odmantic/__init__.py +++ b/odmantic/__init__.py @@ -1,10 +1,18 @@ from .bson import ObjectId -from .engine import AIOEngine +from .engine import AIOEngine, SyncEngine from .field import Field from .model import EmbeddedModel, Model from .reference import Reference -__all__ = ["AIOEngine", "Model", "EmbeddedModel", "Field", "Reference", "ObjectId"] +__all__ = [ + "AIOEngine", + "Model", + "EmbeddedModel", + "Field", + "Reference", + "ObjectId", + "SyncEngine", +] # Cleanest way to handle version changes with poetry while not hardcoding the version # https://github.com/python-poetry/poetry/pull/2366#issuecomment-652418094 diff --git a/odmantic/engine.py b/odmantic/engine.py index 5261009f..0f8e87fc 100644 --- a/odmantic/engine.py +++ b/odmantic/engine.py @@ -8,6 +8,8 @@ Dict, Generator, Generic, + Iterable, + Iterator, List, Optional, Sequence, @@ -18,27 +20,37 @@ cast, ) -from motor.motor_asyncio import ( - AsyncIOMotorClient, - AsyncIOMotorClientSession, - AsyncIOMotorCollection, - AsyncIOMotorCursor, -) from pydantic.utils import lenient_issubclass +from pymongo import MongoClient +from pymongo.client_session import ClientSession +from pymongo.collection import Collection +from pymongo.command_cursor import CommandCursor +from pymongo.database import Database from odmantic.exceptions import DocumentNotFoundError from odmantic.field import FieldProxy, ODMReference from odmantic.model import Model from odmantic.query import QueryExpression, SortExpression, and_ +try: + import motor + from motor.motor_asyncio import ( + AsyncIOMotorClient, + AsyncIOMotorClientSession, + AsyncIOMotorCollection, + AsyncIOMotorCursor, + AsyncIOMotorDatabase, + ) +except ImportError: # pragma: no cover + motor = None + + ModelType = TypeVar("ModelType", bound=Model) SortExpressionType = Optional[Union[FieldProxy, Tuple[FieldProxy]]] -class AIOCursor( - Generic[ModelType], AsyncIterable[ModelType], Awaitable[List[ModelType]] -): +class BaseCursor(Generic[ModelType]): """This object has to be built from the [odmantic.engine.AIOEngine.find][] method. An AIOCursor object support multiple async operations: @@ -47,9 +59,13 @@ class AIOCursor( - **await** : when awaited it will return a list of the fetched models """ - def __init__(self, model: Type[ModelType], motor_cursor: AsyncIOMotorCursor): + def __init__( + self, + model: Type[ModelType], + cursor: Union["AsyncIOMotorCursor", "CommandCursor"], + ): self._model = model - self._motor_cursor = motor_cursor + self._cursor = cursor self._results: Optional[List[ModelType]] = None def _parse_document(self, raw_doc: Dict) -> ModelType: @@ -57,10 +73,27 @@ def _parse_document(self, raw_doc: Dict) -> ModelType: object.__setattr__(instance, "__fields_modified__", set()) return instance + +class AIOCursor( + BaseCursor[ModelType], AsyncIterable[ModelType], Awaitable[List[ModelType]] +): + """This object has to be built from the [odmantic.engine.AIOEngine.find][] method. + + An AIOCursor object support multiple async operations: + + - **async for**: asynchronously iterate over the query results + - **await** : when awaited it will return a list of the fetched models + """ + + _cursor: "AsyncIOMotorCursor" + + def __init__(self, model: Type[ModelType], cursor: "AsyncIOMotorCursor"): + super().__init__(model=model, cursor=cursor) + def __await__(self) -> Generator[None, None, List[ModelType]]: if self._results is not None: return self._results - raw_docs = yield from self._motor_cursor.to_list(length=None).__await__() + raw_docs = yield from self._cursor.to_list(length=None).__await__() instances = [] for raw_doc in raw_docs: instances.append(self._parse_document(raw_doc)) @@ -74,33 +107,53 @@ async def __aiter__(self) -> AsyncGenerator[ModelType, None]: yield res return results = [] - async for raw_doc in self._motor_cursor: + async for raw_doc in self._cursor: instance = self._parse_document(raw_doc) results.append(instance) yield instance self._results = results -_FORBIDDEN_DATABASE_CHARACTERS = set(("/", "\\", ".", '"', "$")) +class SyncCursor(BaseCursor[ModelType], Iterable[ModelType]): + """This object has to be built from the [odmantic.engine.SyncEngine.find][] method. + A SyncCursor object supports iterating over the query results using **`for`**. -class AIOEngine: - """The AIOEngine object is responsible for handling database operations with MongoDB - in an asynchronous way using motor. + To get a list of all the results you can wrap it with `list`, as in `list(cursor)`. """ - def __init__(self, motor_client: AsyncIOMotorClient = None, database: str = "test"): - """Engine constructor. + _cursor: "CommandCursor" - Args: - motor_client: instance of an AsyncIO motor client. If None, a default one - will be created - database: name of the database to use + def __init__(self, model: Type[ModelType], cursor: "CommandCursor"): + super().__init__(model=model, cursor=cursor) - - """ + def __iter__(self) -> Iterator[ModelType]: + if self._results is not None: + for res in self._results: + yield res + return + results = [] + for raw_doc in self._cursor: + instance = self._parse_document(raw_doc) + results.append(instance) + yield instance + self._results = results + + +_FORBIDDEN_DATABASE_CHARACTERS = set(("/", "\\", ".", '"', "$")) + + +class BaseEngine: + """The BaseEngine is the base class for the async and sync engines. It holds the + common functionality, like generating the MongoDB queries, that is then used by the + two engines. + """ + + def __init__( + self, + client: Union["AsyncIOMotorClient", "MongoClient"], + database: str = "test", + ): # https://docs.mongodb.com/manual/reference/limits/#naming-restrictions forbidden_characters = _FORBIDDEN_DATABASE_CHARACTERS.intersection( set(database) @@ -109,22 +162,9 @@ def __init__(self, motor_client: AsyncIOMotorClient = None, database: str = "tes raise ValueError( f"database name cannot contain: {' '.join(forbidden_characters)}" ) - if motor_client is None: - motor_client = AsyncIOMotorClient() - self.client = motor_client + self.client = client self.database_name = database - self.database = motor_client[self.database_name] - - def get_collection(self, model: Type[ModelType]) -> AsyncIOMotorCollection: - """Get the motor collection associated to a Model. - - Args: - model: model class - - Returns: - the AsyncIO motor collection object - """ - return self.database[model.__collection__] + self.database = client[self.database_name] @staticmethod def _build_query(*queries: Union[QueryExpression, Dict, bool]) -> QueryExpression: @@ -158,7 +198,7 @@ def _cascade_find_pipeline( "$expr": {"$eq": ["$_id", "$$foreign_id"]} } }, - *AIOEngine._cascade_find_pipeline( + *BaseEngine._cascade_find_pipeline( odm_reference.model, doc_namespace=f"{doc_namespace}{ref_field_name}.", ), @@ -212,6 +252,79 @@ def _validate_sort_argument(cls, sort: Any) -> Optional[SortExpression]: return cls._build_sort_expression(sort) + def _prepare_find_pipeline( + self, + model: Type[ModelType], + *queries: Union[ + QueryExpression, Dict, bool + ], # bool: allow using binary operators with mypy + sort: Optional[Any] = None, + skip: int = 0, + limit: Optional[int] = None, + ) -> List[Dict[str, Any]]: + if not lenient_issubclass(model, Model): + raise TypeError("Can only call find with a Model class") + sort_expression = self._validate_sort_argument(sort) + if limit is not None and limit <= 0: + raise ValueError("limit has to be a strict positive value or None") + if skip < 0: + raise ValueError("skip has to be a positive integer") + query = BaseEngine._build_query(*queries) + pipeline: List[Dict[str, Any]] = [{"$match": query}] + if sort_expression is not None: + pipeline.append({"$sort": sort_expression}) + if skip > 0: + pipeline.append({"$skip": skip}) + if limit is not None and limit > 0: + pipeline.append({"$limit": limit}) + pipeline.extend(BaseEngine._cascade_find_pipeline(model)) + return pipeline + + +class AIOEngine(BaseEngine): + """The AIOEngine object is responsible for handling database operations with MongoDB + in an asynchronous way using motor. + """ + + client: "AsyncIOMotorClient" + database: "AsyncIOMotorDatabase" + + def __init__( + self, + client: Union["AsyncIOMotorClient", None] = None, + database: str = "test", + ): + """Engine constructor. + + Args: + client: instance of an AsyncIO motor client. If None, a default one + will be created + database: name of the database to use + + + """ + if not motor: + raise RuntimeError( + "motor is required to use AIOEngine, install it with:\n\n" + + 'pip install "odmantic[motor]"' + ) + if client is None: + client = AsyncIOMotorClient() + super().__init__(client=client, database=database) + + def get_collection(self, model: Type[ModelType]) -> "AsyncIOMotorCollection": + """Get the motor collection associated to a Model. + + Args: + model: model class + + Returns: + the AsyncIO motor collection object + """ + return self.database[model.__collection__] + def find( self, model: Type[ModelType], @@ -243,23 +356,14 @@ def find( #noqa: DAR402 DocumentParsingError --> """ - if not lenient_issubclass(model, Model): - raise TypeError("Can only call find with a Model class") - sort_expression = self._validate_sort_argument(sort) - if limit is not None and limit <= 0: - raise ValueError("limit has to be a strict positive value or None") - if skip < 0: - raise ValueError("skip has to be a positive integer") - query = AIOEngine._build_query(*queries) + pipeline = self._prepare_find_pipeline( + model, + *queries, + sort=sort, + skip=skip, + limit=limit, + ) collection = self.get_collection(model) - pipeline: List[Dict] = [{"$match": query}] - if sort_expression is not None: - pipeline.append({"$sort": sort_expression}) - if skip > 0: - pipeline.append({"$skip": skip}) - if limit is not None and limit > 0: - pipeline.append({"$limit": limit}) - pipeline.extend(AIOEngine._cascade_find_pipeline(model)) motor_cursor = collection.aggregate(pipeline) return AIOCursor(model, motor_cursor) @@ -268,7 +372,7 @@ async def find_one( model: Type[ModelType], *queries: Union[ QueryExpression, Dict, bool - ], # bool: allow using binary operators w/o plugin, + ], # bool: allow using binary operators w/o plugin sort: Optional[Any] = None, ) -> Optional[ModelType]: """Search for a Model instance matching the query filter provided @@ -297,7 +401,7 @@ async def find_one( return results[0] async def _save( - self, instance: ModelType, session: AsyncIOMotorClientSession + self, instance: ModelType, session: "AsyncIOMotorClientSession" ) -> ModelType: """Perform an atomic save operation in the specified session""" save_tasks = [] @@ -412,7 +516,234 @@ async def count( """ if not lenient_issubclass(model, Model): raise TypeError("Can only call count with a Model class") - query = AIOEngine._build_query(*queries) + query = BaseEngine._build_query(*queries) collection = self.database[model.__collection__] count = await collection.count_documents(query) return int(count) + + +class SyncEngine(BaseEngine): + """The SyncEngine object is responsible for handling database operations with + MongoDB in an synchronous way using pymongo. + """ + + client: "MongoClient" + database: "Database" + + def __init__( + self, + client: "Union[MongoClient, None]" = None, + database: str = "test", + ): + """Engine constructor. + + Args: + client: instance of a PyMongo client. If None, a default one + will be created + database: name of the database to use + """ + if client is None: + client = MongoClient() + super().__init__(client=client, database=database) + + def get_collection(self, model: Type[ModelType]) -> "Collection": + """Get the pymongo collection associated to a Model. + + Args: + model: model class + + Returns: + the pymongo collection object + """ + collection = self.database[model.__collection__] + return collection + + def find( + self, + model: Type[ModelType], + *queries: Union[ + QueryExpression, Dict, bool + ], # bool: allow using binary operators with mypy + sort: Optional[Any] = None, + skip: int = 0, + limit: Optional[int] = None, + ) -> SyncCursor[ModelType]: + """Search for Model instances matching the query filter provided + + Args: + model: model to perform the operation on + *queries: query filter to apply + sort: sort expression + skip: number of document to skip + limit: maximum number of instance fetched + + Raises: + DocumentParsingError: unable to parse one of the resulting documents + + Returns: + [odmantic.engine.SyncCursor][] of the query + + + """ + pipeline = self._prepare_find_pipeline( + model, + *queries, + sort=sort, + skip=skip, + limit=limit, + ) + collection = self.get_collection(model) + cursor = collection.aggregate(pipeline) + return SyncCursor(model, cursor) + + def find_one( + self, + model: Type[ModelType], + *queries: Union[ + QueryExpression, Dict, bool + ], # bool: allow using binary operators w/o plugin + sort: Optional[Any] = None, + ) -> Optional[ModelType]: + """Search for a Model instance matching the query filter provided + + Args: + model: model to perform the operation on + *queries: query filter to apply + sort: sort expression + + Raises: + DocumentParsingError: unable to parse the resulting document + + Returns: + the fetched instance if found otherwise None + + + """ + if not lenient_issubclass(model, Model): + raise TypeError("Can only call find_one with a Model class") + results = list(self.find(model, *queries, sort=sort, limit=1)) + if len(results) == 0: + return None + return results[0] + + def _save(self, instance: ModelType, session: "ClientSession") -> ModelType: + """Perform an atomic save operation in the specified session""" + for ref_field_name in instance.__references__: + sub_instance = cast(Model, getattr(instance, ref_field_name)) + self._save(sub_instance, session) + + fields_to_update = ( + instance.__fields_modified__ | instance.__mutable_fields__ + ) - set([instance.__primary_field__]) + if len(fields_to_update) > 0: + doc = instance.doc(include=fields_to_update) + collection = self.get_collection(type(instance)) + collection.update_one( + {"_id": getattr(instance, instance.__primary_field__)}, + {"$set": doc}, + upsert=True, + ) + object.__setattr__(instance, "__fields_modified__", set()) + return instance + + def save(self, instance: ModelType) -> ModelType: + """Persist an instance to the database + + This method behaves as an 'upsert' operation. If a document already exists + with the same primary key, it will be overwritten. + + All the other models referenced by this instance will be saved as well. + + Args: + instance: instance to persist + + Returns: + the saved instance + + NOTE: + The save operation actually modify the instance argument in place. However, + the instance is still returned for convenience. + + + """ + if not isinstance(instance, Model): + raise TypeError("Can only call find_one with a Model class") + + with self.client.start_session() as s: + with s.start_transaction(): + self._save(instance, s) + return instance + + def save_all(self, instances: Sequence[ModelType]) -> List[ModelType]: + """Persist instances to the database + + This method behaves as multiple 'upsert' operations. If one of the document + already exists with the same primary key, it will be overwritten. + + All the other models referenced by this instance will be recursively saved as + well. + + Args: + instances: instances to persist + + Returns: + the saved instances + + NOTE: + The save_all operation actually modify the arguments in place. However, the + instances are still returned for convenience. + """ + with self.client.start_session() as s: + with s.start_transaction(): + added_instances = [self._save(instance, s) for instance in instances] + return added_instances + + def delete(self, instance: ModelType) -> None: + """Delete an instance from the database + + Args: + instance: the instance to delete + + Raises: + DocumentNotFoundError: the instance has not been persisted to the database + + """ + # TODO handle cascade deletion + collection = self.database[instance.__collection__] + pk_name = instance.__primary_field__ + result = collection.delete_many({"_id": getattr(instance, pk_name)}) + count = int(result.deleted_count) + if count == 0: + raise DocumentNotFoundError(instance) + + def count( + self, model: Type[ModelType], *queries: Union[QueryExpression, Dict, bool] + ) -> int: + """Get the count of documents matching a query + + Args: + model: model to perform the operation on + *queries: query filters to apply + + Returns: + number of document matching the query + + + """ + if not lenient_issubclass(model, Model): + raise TypeError("Can only call count with a Model class") + query = BaseEngine._build_query(*queries) + collection = self.database[model.__collection__] + count = collection.count_documents(query) + return int(count) diff --git a/pyproject.toml b/pyproject.toml index c058e01e..49deceb3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,13 +36,15 @@ classifiers = [ requires-python = ">=3.6.1" dependencies = [ "pydantic >=1.6.2,!=1.7,!=1.7.1,!=1.7.2,!=1.7.3,!=1.8,!=1.8.1,<1.10.0", - "motor >=2.1.0,<3.1.0", "importlib-metadata >=1,<5; python_version<'3.8'", "typing-extensions >= 3.7.4.3; python_version<'3.8'", + "pymongo >=3.11.0,<5.0.0", ] [project.optional-dependencies] fastapi = ["fastapi >=0.61.1,<0.69.0"] test = [ + "motor >=2.1.0,<3.1.0", + "pymongo >=3.11.0,<5.0.0", "black ~= 22.3.0", "isort ~=5.8.0", "flake8 ~= 4.0.1", @@ -68,6 +70,7 @@ doc = [ "mkdocstrings[python] ~= 0.19.0", ] dev = ["ipython ~= 7.16.1"] +motor = ["motor >=2.1.0,<3.1.0"] [project.urls] Documentation = "https://art049.github.io/odmantic" diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index fc598a20..cd8f37c3 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -6,8 +6,9 @@ import pytest from motor.motor_asyncio import AsyncIOMotorClient +from pymongo import MongoClient -from odmantic.engine import AIOEngine +from odmantic.engine import AIOEngine, SyncEngine try: from unittest.mock import AsyncMock @@ -44,6 +45,14 @@ def motor_client(event_loop): client.close() +@pytest.fixture(scope="session") +def pymongo_client(): + mongo_uri = TEST_MONGO_URI + client = MongoClient(mongo_uri) + yield client + client.close() + + @pytest.fixture(scope="function") def database_name(): return f"odmantic-test-{uuid4()}" @@ -51,7 +60,7 @@ def database_name(): @pytest.mark.asyncio @pytest.fixture(scope="function") -async def engine(motor_client, database_name): +async def aio_engine(motor_client: AsyncIOMotorClient, database_name: str): sess = AIOEngine(motor_client, database_name) yield sess if os.getenv("TEST_DEBUG") is None: @@ -59,17 +68,42 @@ async def engine(motor_client, database_name): @pytest.fixture(scope="function") -def motor_database(database_name, motor_client): +def sync_engine(pymongo_client: MongoClient, database_name: str): + sess = SyncEngine(pymongo_client, database_name) + yield sess + if os.getenv("TEST_DEBUG") is None: + pymongo_client.drop_database(database_name) + + +@pytest.fixture(scope="function") +def motor_database(database_name: str, motor_client: AsyncIOMotorClient): return motor_client[database_name] @pytest.fixture(scope="function") -def mock_collection(engine: AIOEngine, monkeypatch): +def pymongo_database(database_name: str, pymongo_client: MongoClient): + return pymongo_client[database_name] + + +@pytest.fixture(scope="function") +def aio_mock_collection(aio_engine: AIOEngine, monkeypatch): def f(): collection = Mock() collection.update_one = AsyncMock() collection.aggregate = AsyncMock() - monkeypatch.setattr(engine, "get_collection", lambda _: collection) + monkeypatch.setattr(aio_engine, "get_collection", lambda _: collection) + return collection + + return f + + +@pytest.fixture(scope="function") +def sync_mock_collection(sync_engine: SyncEngine, monkeypatch): + def f(): + collection = Mock() + collection.update_one = Mock() + collection.aggregate = Mock() + monkeypatch.setattr(sync_engine, "get_collection", lambda _: collection) return collection return f diff --git a/tests/integration/fastapi/test_doc_example.py b/tests/integration/fastapi/test_doc_example.py index 58ed6e2f..aafbedeb 100644 --- a/tests/integration/fastapi/test_doc_example.py +++ b/tests/integration/fastapi/test_doc_example.py @@ -11,8 +11,8 @@ @pytest.fixture -async def base_example_client(engine: AIOEngine) -> TestClient: - with patch("docs.examples_src.usage_fastapi.base_example.engine", engine): +async def base_example_client(aio_engine: AIOEngine) -> TestClient: + with patch("docs.examples_src.usage_fastapi.base_example.engine", aio_engine): from docs.examples_src.usage_fastapi.base_example import app async with TestClient(app) as client: @@ -22,10 +22,10 @@ async def base_example_client(engine: AIOEngine) -> TestClient: EXAMPLE_TREE_BODY = dict(name="MyTree", average_size=152, discovery_year=1992) -async def test_create_tree(base_example_client: TestClient, engine: AIOEngine): +async def test_create_tree(base_example_client: TestClient, aio_engine: AIOEngine): response = await base_example_client.put("/trees/", json=EXAMPLE_TREE_BODY) assert response.status_code == 200 - assert await engine.find_one(Tree, EXAMPLE_TREE_BODY) is not None + assert await aio_engine.find_one(Tree, EXAMPLE_TREE_BODY) is not None def is_sub_dict(a: Dict, b: Dict) -> bool: @@ -34,20 +34,20 @@ def is_sub_dict(a: Dict, b: Dict) -> bool: @pytest.mark.parametrize("count", [2, 10]) async def test_create_trees_count_get( - base_example_client: TestClient, engine: AIOEngine, count: int + base_example_client: TestClient, aio_engine: AIOEngine, count: int ): for _ in range(count): response = await base_example_client.put("/trees/", json=EXAMPLE_TREE_BODY) assert response.status_code == 200 - assert await engine.count(Tree) == count - async for tree in engine.find(Tree): + assert await aio_engine.count(Tree) == count + async for tree in aio_engine.find(Tree): assert is_sub_dict(EXAMPLE_TREE_BODY, tree.dict()) -async def test_get_tree_by_id(base_example_client: TestClient, engine: AIOEngine): +async def test_get_tree_by_id(base_example_client: TestClient, aio_engine: AIOEngine): tree = Tree(**EXAMPLE_TREE_BODY) - await engine.save(tree) + await aio_engine.save(tree) response = await base_example_client.get( f"/trees/{tree.id}", json=EXAMPLE_TREE_BODY ) @@ -56,8 +56,8 @@ async def test_get_tree_by_id(base_example_client: TestClient, engine: AIOEngine @pytest.fixture -async def example_update_client(engine: AIOEngine) -> TestClient: - with patch("docs.examples_src.usage_fastapi.example_update.engine", engine): +async def example_update_client(aio_engine: AIOEngine) -> TestClient: + with patch("docs.examples_src.usage_fastapi.example_update.engine", aio_engine): from docs.examples_src.usage_fastapi.example_update import app async with TestClient(app) as client: @@ -68,33 +68,35 @@ async def example_update_client(engine: AIOEngine) -> TestClient: async def test_update_tree_name_by_id( - example_update_client: TestClient, engine: AIOEngine + example_update_client: TestClient, aio_engine: AIOEngine ): tree = Tree(**EXAMPLE_TREE_BODY) - await engine.save(tree) + await aio_engine.save(tree) response = await example_update_client.patch( f"/trees/{tree.id}", json=dict(name=PATCHED_NAME) ) assert response.status_code == 200 assert response.json()["name"] == PATCHED_NAME - assert await engine.find_one(Tree, {"name": PATCHED_NAME}) is not None + assert await aio_engine.find_one(Tree, {"name": PATCHED_NAME}) is not None @pytest.fixture -async def example_delete_client(engine: AIOEngine) -> TestClient: - with patch("docs.examples_src.usage_fastapi.example_delete.engine", engine): +async def example_delete_client(aio_engine: AIOEngine) -> TestClient: + with patch("docs.examples_src.usage_fastapi.example_delete.engine", aio_engine): from docs.examples_src.usage_fastapi.example_delete import app async with TestClient(app) as client: yield client -async def test_delete_tree_by_id(example_delete_client: TestClient, engine: AIOEngine): +async def test_delete_tree_by_id( + example_delete_client: TestClient, aio_engine: AIOEngine +): tree = Tree(**EXAMPLE_TREE_BODY) - await engine.save(tree) + await aio_engine.save(tree) # Create other trees not affected by the delete to come for _ in range(10): - await engine.save(Tree(**EXAMPLE_TREE_BODY)) + await aio_engine.save(Tree(**EXAMPLE_TREE_BODY)) response = await example_delete_client.delete(f"/trees/{tree.id}") assert response.status_code == 200 - assert await engine.find_one(Tree, Tree.id == tree.id) is None + assert await aio_engine.find_one(Tree, Tree.id == tree.id) is None diff --git a/tests/integration/test_embedded_model.py b/tests/integration/test_embedded_model.py index 14e46fbd..8c5903f1 100644 --- a/tests/integration/test_embedded_model.py +++ b/tests/integration/test_embedded_model.py @@ -2,7 +2,7 @@ import pytest -from odmantic.engine import AIOEngine +from odmantic.engine import AIOEngine, SyncEngine from odmantic.model import EmbeddedModel, Model from ..zoo.book_embedded import Book, Publisher @@ -11,20 +11,51 @@ pytestmark = pytest.mark.asyncio -async def test_add_fetch_single(engine: AIOEngine): +async def test_add_fetch_single(aio_engine: AIOEngine): publisher = Publisher(name="O'Reilly Media", founded=1980, location="CA") book = Book(title="MongoDB: The Definitive Guide", pages=216, publisher=publisher) - instance = await engine.save(book) + instance = await aio_engine.save(book) assert instance.id is not None assert isinstance(instance.publisher, Publisher) assert instance.publisher == publisher - fetched_instance = await engine.find_one(Book, Book.id == instance.id) + fetched_instance = await aio_engine.find_one(Book, Book.id == instance.id) assert fetched_instance is not None assert fetched_instance.publisher == publisher -async def test_add_multiple(engine: AIOEngine): +def test_sync_add_fetch_single(sync_engine: SyncEngine): + publisher = Publisher(name="O'Reilly Media", founded=1980, location="CA") + book = Book(title="MongoDB: The Definitive Guide", pages=216, publisher=publisher) + instance = sync_engine.save(book) + assert instance.id is not None + assert isinstance(instance.publisher, Publisher) + assert instance.publisher == publisher + + fetched_instance = sync_engine.find_one(Book, Book.id == instance.id) + assert fetched_instance is not None + assert fetched_instance.publisher == publisher + + +async def test_add_multiple(aio_engine: AIOEngine): + addresses = [ + Address(street="81 Lafayette St.", city="Brownsburg", state="IN", zip="46112"), + Address( + street="862 West Euclid St.", city="Indian Trail", state="NC", zip="28079" + ), + ] + patron = Patron(name="The Princess Royal", addresses=addresses) + instance = await aio_engine.save(patron) + assert instance.id is not None + assert isinstance(instance.addresses, list) + assert instance.addresses == addresses + + fetched_instance = await aio_engine.find_one(Patron) + assert fetched_instance is not None + assert fetched_instance.addresses == addresses + + +def test_sync_add_multiple(sync_engine: SyncEngine): addresses = [ Address(street="81 Lafayette St.", city="Brownsburg", state="IN", zip="46112"), Address( @@ -32,46 +63,64 @@ async def test_add_multiple(engine: AIOEngine): ), ] patron = Patron(name="The Princess Royal", addresses=addresses) - instance = await engine.save(patron) + instance = sync_engine.save(patron) assert instance.id is not None assert isinstance(instance.addresses, list) assert instance.addresses == addresses - fetched_instance = await engine.find_one(Patron) + fetched_instance = sync_engine.find_one(Patron) assert fetched_instance is not None assert fetched_instance.addresses == addresses @pytest.fixture -async def books_with_embedded_publisher(engine: AIOEngine): +async def books_with_embedded_publisher(aio_engine: AIOEngine): publisher_1 = Publisher(name="O'Reilly Media", founded=1980, location="CA") book_1 = Book( title="MongoDB: The Definitive Guide", pages=216, publisher=publisher_1 ) publisher_2 = Publisher(name="O'Reilly Media", founded=2020, location="EU") book_2 = Book(title="MySQL: The Definitive Guide", pages=516, publisher=publisher_2) - return await engine.save_all([book_1, book_2]) + return await aio_engine.save_all([book_1, book_2]) async def test_query_filter_on_embedded_doc( - engine: AIOEngine, books_with_embedded_publisher: Tuple[Book, Book] + aio_engine: AIOEngine, books_with_embedded_publisher: Tuple[Book, Book] ): _, book_2 = books_with_embedded_publisher - fetched_instances = await engine.find(Book, Book.publisher == book_2.publisher) + fetched_instances = await aio_engine.find(Book, Book.publisher == book_2.publisher) + assert len(fetched_instances) == 1 + assert fetched_instances[0] == book_2 + + +def test_sync_query_filter_on_embedded_doc( + sync_engine: SyncEngine, books_with_embedded_publisher: Tuple[Book, Book] +): + _, book_2 = books_with_embedded_publisher + fetched_instances = list(sync_engine.find(Book, Book.publisher == book_2.publisher)) assert len(fetched_instances) == 1 assert fetched_instances[0] == book_2 async def test_query_filter_on_embedded_field( - engine: AIOEngine, books_with_embedded_publisher: Tuple[Book, Book] + aio_engine: AIOEngine, books_with_embedded_publisher: Tuple[Book, Book] +): + _, book_2 = books_with_embedded_publisher + fetched_instances = await aio_engine.find(Book, Book.publisher.location == "EU") + assert len(fetched_instances) == 1 + assert fetched_instances[0] == book_2 + + +def test_sync_query_filter_on_embedded_field( + sync_engine: SyncEngine, books_with_embedded_publisher: Tuple[Book, Book] ): _, book_2 = books_with_embedded_publisher - fetched_instances = await engine.find(Book, Book.publisher.location == "EU") + fetched_instances = list(sync_engine.find(Book, Book.publisher.location == "EU")) assert len(fetched_instances) == 1 assert fetched_instances[0] == book_2 -async def test_query_filter_on_embedded_nested(engine: AIOEngine): +async def test_query_filter_on_embedded_nested(aio_engine: AIOEngine): class ThirdModel(EmbeddedModel): field: int @@ -83,9 +132,9 @@ class TopModel(Model): instance_0 = TopModel(nested_0=SecondaryModel(nested_1=ThirdModel(field=12))) instance_1 = TopModel(nested_0=SecondaryModel(nested_1=ThirdModel(field=0))) - await engine.save_all([instance_0, instance_1]) + await aio_engine.save_all([instance_0, instance_1]) - fetched_instances = await engine.find( + fetched_instances = await aio_engine.find( TopModel, TopModel.nested_0.nested_1.field == 12 ) @@ -93,7 +142,46 @@ class TopModel(Model): assert fetched_instances[0] == instance_0 -async def test_fields_modified_embedded_model_modification(engine: AIOEngine): +def test_sync_query_filter_on_embedded_nested(sync_engine: SyncEngine): + class ThirdModel(EmbeddedModel): + field: int + + class SecondaryModel(EmbeddedModel): + nested_1: ThirdModel + + class TopModel(Model): + nested_0: SecondaryModel + + instance_0 = TopModel(nested_0=SecondaryModel(nested_1=ThirdModel(field=12))) + instance_1 = TopModel(nested_0=SecondaryModel(nested_1=ThirdModel(field=0))) + sync_engine.save_all([instance_0, instance_1]) + + fetched_instances = list( + sync_engine.find(TopModel, TopModel.nested_0.nested_1.field == 12) + ) + + assert len(fetched_instances) == 1 + assert fetched_instances[0] == instance_0 + + +async def test_fields_modified_embedded_model_modification(aio_engine: AIOEngine): + class E(EmbeddedModel): + f: int + + class M(Model): + e: E + + e = E(f=0) + m = M(e=e) + await aio_engine.save(m) + e.f = 1 + await aio_engine.save(m) + fetched = await aio_engine.find_one(M) + assert fetched is not None + assert fetched.e.f == 1 + + +def test_sync_fields_modified_embedded_model_modification(sync_engine: SyncEngine): class E(EmbeddedModel): f: int @@ -102,9 +190,9 @@ class M(Model): e = E(f=0) m = M(e=e) - await engine.save(m) + sync_engine.save(m) e.f = 1 - await engine.save(m) - fetched = await engine.find_one(M) + sync_engine.save(m) + fetched = sync_engine.find_one(M) assert fetched is not None assert fetched.e.f == 1 diff --git a/tests/integration/test_engine.py b/tests/integration/test_engine.py index 07166e60..9c2562d7 100644 --- a/tests/integration/test_engine.py +++ b/tests/integration/test_engine.py @@ -2,9 +2,10 @@ import pytest from motor.motor_asyncio import AsyncIOMotorClient +from pymongo import MongoClient from odmantic.bson import ObjectId -from odmantic.engine import AIOEngine +from odmantic.engine import AIOEngine, SyncEngine from odmantic.exceptions import DocumentNotFoundError, DocumentParsingError from odmantic.field import Field from odmantic.model import EmbeddedModel, Model @@ -21,28 +22,42 @@ async def test_default_motor_client_creation(): assert isinstance(engine.client, AsyncIOMotorClient) +def test_default_pymongo_client_creation(): + engine = SyncEngine() + assert isinstance(engine.client, MongoClient) + + @pytest.mark.parametrize("illegal_character", ("/", "\\", ".", '"', "$")) def test_invalid_database_name(illegal_character: str): with pytest.raises(ValueError, match="database name cannot contain"): AIOEngine(database=f"prefix{illegal_character}suffix") + with pytest.raises(ValueError, match="database name cannot contain"): + SyncEngine(database=f"prefix{illegal_character}suffix") + + +async def test_save(aio_engine: AIOEngine): + instance = await aio_engine.save( + PersonModel(first_name="Jean-Pierre", last_name="Pernaud") + ) + assert isinstance(instance.id, ObjectId) -async def test_save(engine: AIOEngine): - instance = await engine.save( +def test_sync_save(sync_engine: SyncEngine): + instance = sync_engine.save( PersonModel(first_name="Jean-Pierre", last_name="Pernaud") ) assert isinstance(instance.id, ObjectId) -async def test_save_find_find_one(engine: AIOEngine): +async def test_save_find_find_one(aio_engine: AIOEngine): initial_instance = PersonModel(first_name="Jean-Pierre", last_name="Pernaud") - await engine.save(initial_instance) - found_instances = await engine.find(PersonModel) + await aio_engine.save(initial_instance) + found_instances = await aio_engine.find(PersonModel) assert len(found_instances) == 1 assert found_instances[0].first_name == initial_instance.first_name assert found_instances[0].last_name == initial_instance.last_name - single_fetched_instance = await engine.find_one( + single_fetched_instance = await aio_engine.find_one( PersonModel, PersonModel.first_name == "Jean-Pierre" ) assert single_fetched_instance is not None @@ -50,37 +65,85 @@ async def test_save_find_find_one(engine: AIOEngine): assert single_fetched_instance.last_name == initial_instance.last_name -async def test_find_one_not_existing(engine: AIOEngine): - fetched = await engine.find_one(PersonModel) +def test_sync_save_find_find_one(sync_engine: SyncEngine): + initial_instance = PersonModel(first_name="Jean-Pierre", last_name="Pernaud") + sync_engine.save(initial_instance) + found_instances = list(sync_engine.find(PersonModel)) + assert len(found_instances) == 1 + assert found_instances[0].first_name == initial_instance.first_name + assert found_instances[0].last_name == initial_instance.last_name + + single_fetched_instance = sync_engine.find_one( + PersonModel, PersonModel.first_name == "Jean-Pierre" + ) + assert single_fetched_instance is not None + assert single_fetched_instance.first_name == initial_instance.first_name + assert single_fetched_instance.last_name == initial_instance.last_name + + +async def test_find_one_not_existing(aio_engine: AIOEngine): + fetched = await aio_engine.find_one(PersonModel) + assert fetched is None + + +def test_sync_find_one_not_existing(sync_engine: SyncEngine): + fetched = sync_engine.find_one(PersonModel) assert fetched is None @pytest.fixture(scope="function") -async def person_persisted(engine: AIOEngine): +async def person_persisted(aio_engine: AIOEngine): initial_instances = [ PersonModel(first_name="Jean-Pierre", last_name="Pernaud"), PersonModel(first_name="Jean-Pierre", last_name="Castaldi"), PersonModel(first_name="Michel", last_name="Drucker"), ] - return await engine.save_all(initial_instances) + return await aio_engine.save_all(initial_instances) async def test_save_multiple_simple_find_find_one( - engine: AIOEngine, person_persisted: List[PersonModel] + aio_engine: AIOEngine, person_persisted: List[PersonModel] ): - found_instances = await engine.find(PersonModel, PersonModel.first_name == "Michel") + found_instances = await aio_engine.find( + PersonModel, PersonModel.first_name == "Michel" + ) assert len(found_instances) == 1 assert found_instances[0].first_name == person_persisted[2].first_name assert found_instances[0].last_name == person_persisted[2].last_name - found_instances = await engine.find( + found_instances = await aio_engine.find( PersonModel, PersonModel.first_name == "Jean-Pierre" ) assert len(found_instances) == 2 assert found_instances[0].id != found_instances[1].id - single_retrieved = await engine.find_one( + single_retrieved = await aio_engine.find_one( + PersonModel, PersonModel.first_name == "Jean-Pierre" + ) + + assert single_retrieved is not None + assert single_retrieved in person_persisted + + +def test_sync_save_multiple_simple_find_find_one( + sync_engine: SyncEngine, person_persisted: List[PersonModel] +): + + found_instances = list( + sync_engine.find(PersonModel, PersonModel.first_name == "Michel") + ) + assert len(found_instances) == 1 + assert found_instances[0].first_name == person_persisted[2].first_name + assert found_instances[0].last_name == person_persisted[2].last_name + + found_instances = list( + sync_engine.find(PersonModel, PersonModel.first_name == "Jean-Pierre") + ) + assert len(found_instances) == 2 + assert found_instances[0].id != found_instances[1].id + + single_retrieved = sync_engine.find_one( PersonModel, PersonModel.first_name == "Jean-Pierre" ) @@ -89,32 +152,52 @@ async def test_save_multiple_simple_find_find_one( async def test_find_sync_iteration( - engine: AIOEngine, person_persisted: List[PersonModel] + aio_engine: AIOEngine, person_persisted: List[PersonModel] ): fetched = set() - for inst in await engine.find(PersonModel): + for inst in await aio_engine.find(PersonModel): + fetched.add(inst.id) + + assert set(i.id for i in person_persisted) == fetched + + +def test_sync_find_sync_iteration( + sync_engine: SyncEngine, person_persisted: List[PersonModel] +): + fetched = set() + for inst in sync_engine.find(PersonModel): fetched.add(inst.id) assert set(i.id for i in person_persisted) == fetched @pytest.mark.usefixtures("person_persisted") -async def test_find_sync_iteration_cached(engine: AIOEngine, mock_collection): - cursor = engine.find(PersonModel) +async def test_find_sync_iteration_cached(aio_engine: AIOEngine, aio_mock_collection): + cursor = aio_engine.find(PersonModel) initial = await cursor - collection = mock_collection() + collection = aio_mock_collection() cached = await cursor collection.aggregate.assert_not_awaited() assert cached == initial @pytest.mark.usefixtures("person_persisted") -async def test_find_async_iteration_cached(engine: AIOEngine, mock_collection): - cursor = engine.find(PersonModel) +def test_sync_find_sync_iteration_cached(sync_engine: SyncEngine, sync_mock_collection): + cursor = sync_engine.find(PersonModel) + initial = list(cursor) + collection = sync_mock_collection() + cached = list(cursor) + collection.aggregate.assert_not_called() + assert cached == initial + + +@pytest.mark.usefixtures("person_persisted") +async def test_find_async_iteration_cached(aio_engine: AIOEngine, aio_mock_collection): + cursor = aio_engine.find(PersonModel) initial = [] async for inst in cursor: initial.append(inst) - collection = mock_collection() + collection = aio_mock_collection() cached = [] async for inst in cursor: cached.append(inst) @@ -122,110 +205,237 @@ async def test_find_async_iteration_cached(engine: AIOEngine, mock_collection): assert cached == initial -async def test_find_skip(engine: AIOEngine, person_persisted: List[PersonModel]): - results = await engine.find(PersonModel, skip=1) +@pytest.mark.usefixtures("person_persisted") +def test_sync_find_async_iteration_cached( + sync_engine: SyncEngine, sync_mock_collection +): + cursor = sync_engine.find(PersonModel) + initial = [] + for inst in cursor: + initial.append(inst) + collection = sync_mock_collection() + cached = [] + for inst in cursor: + cached.append(inst) + collection.aggregate.assert_not_called() + assert cached == initial + + +async def test_find_skip(aio_engine: AIOEngine, person_persisted: List[PersonModel]): + results = await aio_engine.find(PersonModel, skip=1) assert len(results) == 2 for instance in results: assert instance in person_persisted -async def test_find_one_bad_query(engine: AIOEngine): +def test_sync_find_skip(sync_engine: SyncEngine, person_persisted: List[PersonModel]): + results = list(sync_engine.find(PersonModel, skip=1)) + assert len(results) == 2 + for instance in results: + assert instance in person_persisted + + +async def test_find_one_bad_query(aio_engine: AIOEngine): with pytest.raises(TypeError): - await engine.find_one(PersonModel, True, False) + await aio_engine.find_one(PersonModel, True, False) -async def test_find_one_on_non_model(engine: AIOEngine): +def test_sync_find_one_bad_query(sync_engine: SyncEngine): + with pytest.raises(TypeError): + sync_engine.find_one(PersonModel, True, False) + + +async def test_find_one_on_non_model(aio_engine: AIOEngine): class BadModel: pass with pytest.raises(TypeError): - await engine.find_one(BadModel) # type: ignore + await aio_engine.find_one(BadModel) # type: ignore + + +def test_sync_find_one_on_non_model(sync_engine: SyncEngine): + class BadModel: + pass + + with pytest.raises(TypeError): + sync_engine.find_one(BadModel) # type: ignore + + +async def test_find_invalid_limit(aio_engine: AIOEngine): + with pytest.raises(ValueError): + await aio_engine.find(PersonModel, limit=0) + with pytest.raises(ValueError): + await aio_engine.find(PersonModel, limit=-12) -async def test_find_invalid_limit(engine: AIOEngine): +def test_sync_find_invalid_limit(sync_engine: SyncEngine): with pytest.raises(ValueError): - await engine.find(PersonModel, limit=0) + sync_engine.find(PersonModel, limit=0) with pytest.raises(ValueError): - await engine.find(PersonModel, limit=-12) + sync_engine.find(PersonModel, limit=-12) -async def test_find_invalid_skip(engine: AIOEngine): +async def test_find_invalid_skip(aio_engine: AIOEngine): with pytest.raises(ValueError): - await engine.find(PersonModel, skip=-1) + await aio_engine.find(PersonModel, skip=-1) with pytest.raises(ValueError): - await engine.find(PersonModel, limit=-12) + await aio_engine.find(PersonModel, limit=-12) + + +def test_sync_find_invalid_skip(sync_engine: SyncEngine): + with pytest.raises(ValueError): + sync_engine.find(PersonModel, skip=-1) + with pytest.raises(ValueError): + sync_engine.find(PersonModel, limit=-12) @pytest.mark.usefixtures("person_persisted") -async def test_skip(engine: AIOEngine): - p = await engine.find(PersonModel, skip=1) +async def test_skip(aio_engine: AIOEngine): + p = await aio_engine.find(PersonModel, skip=1) assert len(p) == 2 @pytest.mark.usefixtures("person_persisted") -async def test_limit(engine: AIOEngine): - p = await engine.find(PersonModel, limit=1) +def test_sync_skip(sync_engine: SyncEngine): + p = list(sync_engine.find(PersonModel, skip=1)) + assert len(p) == 2 + + +@pytest.mark.usefixtures("person_persisted") +async def test_limit(aio_engine: AIOEngine): + p = await aio_engine.find(PersonModel, limit=1) assert len(p) == 1 @pytest.mark.usefixtures("person_persisted") -async def test_skip_limit(engine: AIOEngine): - p = await engine.find(PersonModel, skip=1, limit=1) +def test_sync_limit(sync_engine: SyncEngine): + p = list(sync_engine.find(PersonModel, limit=1)) assert len(p) == 1 -async def test_save_multiple_time_same_document(engine: AIOEngine): +@pytest.mark.usefixtures("person_persisted") +async def test_skip_limit(aio_engine: AIOEngine): + p = await aio_engine.find(PersonModel, skip=1, limit=1) + assert len(p) == 1 + + +@pytest.mark.usefixtures("person_persisted") +def test_sync_skip_limit(sync_engine: SyncEngine): + p = list(sync_engine.find(PersonModel, skip=1, limit=1)) + assert len(p) == 1 + + +async def test_save_multiple_time_same_document(aio_engine: AIOEngine): fixed_id = ObjectId() instance = PersonModel(first_name="Jean-Pierre", last_name="Pernaud", id=fixed_id) - await engine.save(instance) + await aio_engine.save(instance) + + instance = PersonModel(first_name="Jean-Pierre", last_name="Pernaud", id=fixed_id) + await aio_engine.save(instance) + + assert await aio_engine.count(PersonModel, PersonModel.id == fixed_id) == 1 + + +def test_sync_save_multiple_time_same_document(sync_engine: SyncEngine): + fixed_id = ObjectId() instance = PersonModel(first_name="Jean-Pierre", last_name="Pernaud", id=fixed_id) - await engine.save(instance) + sync_engine.save(instance) - assert await engine.count(PersonModel, PersonModel.id == fixed_id) == 1 + instance = PersonModel(first_name="Jean-Pierre", last_name="Pernaud", id=fixed_id) + sync_engine.save(instance) + + assert sync_engine.count(PersonModel, PersonModel.id == fixed_id) == 1 @pytest.mark.usefixtures("person_persisted") -async def test_count(engine: AIOEngine): - count = await engine.count(PersonModel) +async def test_count(aio_engine: AIOEngine): + count = await aio_engine.count(PersonModel) assert count == 3 - count = await engine.count(PersonModel, PersonModel.first_name == "Michel") + count = await aio_engine.count(PersonModel, PersonModel.first_name == "Michel") assert count == 1 - count = await engine.count(PersonModel, PersonModel.first_name == "Gérard") + count = await aio_engine.count(PersonModel, PersonModel.first_name == "Gérard") assert count == 0 -async def test_count_on_non_model_fails(engine: AIOEngine): +@pytest.mark.usefixtures("person_persisted") +def test_sync_count(sync_engine: SyncEngine): + count = sync_engine.count(PersonModel) + assert count == 3 + + count = sync_engine.count(PersonModel, PersonModel.first_name == "Michel") + assert count == 1 + + count = sync_engine.count(PersonModel, PersonModel.first_name == "Gérard") + assert count == 0 + + +async def test_count_on_non_model_fails(aio_engine: AIOEngine): class BadModel: pass with pytest.raises(TypeError): - await engine.count(BadModel) # type: ignore + await aio_engine.count(BadModel) # type: ignore -async def test_find_on_embedded_raises(engine: AIOEngine): +def test_sync_count_on_non_model_fails(sync_engine: SyncEngine): + class BadModel: + pass + + with pytest.raises(TypeError): + sync_engine.count(BadModel) # type: ignore + + +async def test_find_on_embedded_raises(aio_engine: AIOEngine): + class BadModel(EmbeddedModel): + field: int + + with pytest.raises(TypeError): + await aio_engine.find(BadModel) # type: ignore + + +def test_sync_find_on_embedded_raises(sync_engine: SyncEngine): + class BadModel(EmbeddedModel): + field: int + + with pytest.raises(TypeError): + sync_engine.find(BadModel) # type: ignore + + +async def test_save_on_embedded(aio_engine: AIOEngine): class BadModel(EmbeddedModel): field: int + instance = BadModel(field=12) with pytest.raises(TypeError): - await engine.find(BadModel) # type: ignore + await aio_engine.save(instance) # type: ignore -async def test_save_on_embedded(engine: AIOEngine): +def test_sync_save_on_embedded(sync_engine: SyncEngine): class BadModel(EmbeddedModel): field: int instance = BadModel(field=12) with pytest.raises(TypeError): - await engine.save(instance) # type: ignore + sync_engine.save(instance) # type: ignore + + +@pytest.mark.usefixtures("person_persisted") +async def test_implicit_and(aio_engine: AIOEngine): + count = await aio_engine.count( + PersonModel, + PersonModel.first_name == "Michel", + PersonModel.last_name == "Drucker", + ) + assert count == 1 @pytest.mark.usefixtures("person_persisted") -async def test_implicit_and(engine: AIOEngine): - count = await engine.count( +def test_sync_implicit_and(sync_engine: SyncEngine): + count = sync_engine.count( PersonModel, PersonModel.first_name == "Michel", PersonModel.last_name == "Drucker", @@ -233,62 +443,124 @@ async def test_implicit_and(engine: AIOEngine): assert count == 1 -async def test_save_update(engine: AIOEngine): +async def test_save_update(aio_engine: AIOEngine): + instance = PersonModel(first_name="Jean-Pierre", last_name="Pernaud") + await aio_engine.save(instance) + assert await aio_engine.count(PersonModel, PersonModel.last_name == "Pernaud") == 1 + instance.last_name = "Dupuis" + await aio_engine.save(instance) + assert await aio_engine.count(PersonModel, PersonModel.last_name == "Pernaud") == 0 + assert await aio_engine.count(PersonModel, PersonModel.last_name == "Dupuis") == 1 + + +def test_sync_save_update(sync_engine: SyncEngine): instance = PersonModel(first_name="Jean-Pierre", last_name="Pernaud") - await engine.save(instance) - assert await engine.count(PersonModel, PersonModel.last_name == "Pernaud") == 1 + sync_engine.save(instance) + assert sync_engine.count(PersonModel, PersonModel.last_name == "Pernaud") == 1 instance.last_name = "Dupuis" - await engine.save(instance) - assert await engine.count(PersonModel, PersonModel.last_name == "Pernaud") == 0 - assert await engine.count(PersonModel, PersonModel.last_name == "Dupuis") == 1 + sync_engine.save(instance) + assert sync_engine.count(PersonModel, PersonModel.last_name == "Pernaud") == 0 + assert sync_engine.count(PersonModel, PersonModel.last_name == "Dupuis") == 1 + + +async def test_delete_and_count( + aio_engine: AIOEngine, person_persisted: List[PersonModel] +): + await aio_engine.delete(person_persisted[0]) + assert await aio_engine.count(PersonModel) == 2 + await aio_engine.delete(person_persisted[1]) + assert await aio_engine.count(PersonModel) == 1 + await aio_engine.delete(person_persisted[2]) + assert await aio_engine.count(PersonModel) == 0 -async def test_delete_and_count(engine: AIOEngine, person_persisted: List[PersonModel]): - await engine.delete(person_persisted[0]) - assert await engine.count(PersonModel) == 2 - await engine.delete(person_persisted[1]) - assert await engine.count(PersonModel) == 1 - await engine.delete(person_persisted[2]) - assert await engine.count(PersonModel) == 0 +def test_sync_delete_and_count( + sync_engine: SyncEngine, person_persisted: List[PersonModel] +): + sync_engine.delete(person_persisted[0]) + assert sync_engine.count(PersonModel) == 2 + sync_engine.delete(person_persisted[1]) + assert sync_engine.count(PersonModel) == 1 + sync_engine.delete(person_persisted[2]) + assert sync_engine.count(PersonModel) == 0 -async def test_delete_not_existing(engine: AIOEngine): +async def test_delete_not_existing(aio_engine: AIOEngine): non_persisted_instance = PersonModel(first_name="Jean", last_name="Paul") with pytest.raises(DocumentNotFoundError) as exc: - await engine.delete(non_persisted_instance) + await aio_engine.delete(non_persisted_instance) assert exc.value.instance == non_persisted_instance -async def test_modified_fields_cleared_on_document_saved(engine: AIOEngine): +def test_sync_delete_not_existing(sync_engine: SyncEngine): + non_persisted_instance = PersonModel(first_name="Jean", last_name="Paul") + with pytest.raises(DocumentNotFoundError) as exc: + sync_engine.delete(non_persisted_instance) + assert exc.value.instance == non_persisted_instance + + +async def test_modified_fields_cleared_on_document_saved(aio_engine: AIOEngine): instance = PersonModel(first_name="Jean-Pierre", last_name="Pernaud") assert len(instance.__fields_modified__) > 0 - await engine.save(instance) + await aio_engine.save(instance) assert len(instance.__fields_modified__) == 0 -async def test_modified_fields_cleared_on_nested_document_saved(engine: AIOEngine): +def test_sync_modified_fields_cleared_on_document_saved(sync_engine: SyncEngine): + instance = PersonModel(first_name="Jean-Pierre", last_name="Pernaud") + assert len(instance.__fields_modified__) > 0 + sync_engine.save(instance) + assert len(instance.__fields_modified__) == 0 + + +async def test_modified_fields_cleared_on_nested_document_saved(aio_engine: AIOEngine): hachette = Publisher(name="Hachette Livre", founded=1826, location="FR") book = Book(title="They Didn't See Us Coming", pages=304, publisher=hachette) assert len(hachette.__fields_modified__) > 0 - await engine.save(book) + await aio_engine.save(book) + assert len(hachette.__fields_modified__) == 0 + + +def test_sync_modified_fields_cleared_on_nested_document_saved(sync_engine: SyncEngine): + hachette = Publisher(name="Hachette Livre", founded=1826, location="FR") + book = Book(title="They Didn't See Us Coming", pages=304, publisher=hachette) + assert len(hachette.__fields_modified__) > 0 + sync_engine.save(book) assert len(hachette.__fields_modified__) == 0 @pytest.fixture() -async def engine_one_person(engine: AIOEngine): - await engine.save(PersonModel(first_name="Jean-Pierre", last_name="Pernaud")) +async def engine_one_person(aio_engine: AIOEngine): + await aio_engine.save(PersonModel(first_name="Jean-Pierre", last_name="Pernaud")) @pytest.mark.usefixtures("engine_one_person") -async def test_modified_fields_on_find(engine: AIOEngine): - instance = await engine.find_one(PersonModel) +async def test_modified_fields_on_find(aio_engine: AIOEngine): + instance = await aio_engine.find_one(PersonModel) assert instance is not None assert len(instance.__fields_modified__) == 0 @pytest.mark.usefixtures("engine_one_person") -async def test_modified_fields_on_document_change(engine: AIOEngine): - instance = await engine.find_one(PersonModel) +def test_sync_modified_fields_on_find(sync_engine: SyncEngine): + instance = sync_engine.find_one(PersonModel) + assert instance is not None + assert len(instance.__fields_modified__) == 0 + + +@pytest.mark.usefixtures("engine_one_person") +async def test_modified_fields_on_document_change(aio_engine: AIOEngine): + instance = await aio_engine.find_one(PersonModel) + assert instance is not None + instance.first_name = "Jackie" + assert len(instance.__fields_modified__) == 1 + instance.last_name = "Chan" + assert len(instance.__fields_modified__) == 2 + + +@pytest.mark.usefixtures("engine_one_person") +def test_sync_modified_fields_on_document_change(sync_engine: SyncEngine): + instance = sync_engine.find_one(PersonModel) assert instance is not None instance.first_name = "Jackie" assert len(instance.__fields_modified__) == 1 @@ -297,46 +569,93 @@ async def test_modified_fields_on_document_change(engine: AIOEngine): @pytest.mark.usefixtures("engine_one_person") -async def test_no_set_on_save_fetched_document(engine: AIOEngine, mock_collection): - instance = await engine.find_one(PersonModel) +async def test_no_set_on_save_fetched_document( + aio_engine: AIOEngine, sync_mock_collection +): + instance = await aio_engine.find_one(PersonModel) + assert instance is not None + + collection = sync_mock_collection() + await aio_engine.save(instance) + collection.update_one.assert_not_called() + + +@pytest.mark.usefixtures("engine_one_person") +def test_sync_no_set_on_save_fetched_document( + sync_engine: SyncEngine, sync_mock_collection +): + instance = sync_engine.find_one(PersonModel) assert instance is not None - collection = mock_collection() - await engine.save(instance) - collection.update_one.assert_not_awaited() + collection = sync_mock_collection() + sync_engine.save(instance) + collection.update_one.assert_not_called() @pytest.mark.usefixtures("engine_one_person") -async def test_only_modified_set_on_save(engine: AIOEngine, mock_collection): - instance = await engine.find_one(PersonModel) +async def test_only_modified_set_on_save(aio_engine: AIOEngine, aio_mock_collection): + instance = await aio_engine.find_one(PersonModel) assert instance is not None instance.first_name = "John" - collection = mock_collection() - await engine.save(instance) + collection = aio_mock_collection() + await aio_engine.save(instance) collection.update_one.assert_awaited_once() (_, set_arg), _ = collection.update_one.await_args assert set_arg == {"$set": {"first_name": "John"}} -async def test_only_mutable_list_set_on_save(engine: AIOEngine, mock_collection): +@pytest.mark.usefixtures("engine_one_person") +def test_sync_only_modified_set_on_save(sync_engine: SyncEngine, sync_mock_collection): + instance = sync_engine.find_one(PersonModel) + assert instance is not None + + instance.first_name = "John" + collection = sync_mock_collection() + sync_engine.save(instance) + collection.update_one.assert_called_once() + (_, set_arg), _ = collection.update_one.call_args + assert set_arg == {"$set": {"first_name": "John"}} + + +async def test_only_mutable_list_set_on_save( + aio_engine: AIOEngine, aio_mock_collection +): class M(Model): field: List[str] immutable_field: int instance = M(field=["hello"], immutable_field=12) - await engine.save(instance) + await aio_engine.save(instance) - collection = mock_collection() - await engine.save(instance) + collection = aio_mock_collection() + await aio_engine.save(instance) collection.update_one.assert_awaited_once() (_, set_arg), _ = collection.update_one.await_args set_dict = set_arg["$set"] assert list(set_dict.keys()) == ["field"] +def test_sync_only_mutable_list_set_on_save( + sync_engine: SyncEngine, sync_mock_collection +): + class M(Model): + field: List[str] + immutable_field: int + + instance = M(field=["hello"], immutable_field=12) + sync_engine.save(instance) + + collection = sync_mock_collection() + sync_engine.save(instance) + collection.update_one.assert_called_once() + (_, set_arg), _ = collection.update_one.call_args + set_dict = set_arg["$set"] + assert list(set_dict.keys()) == ["field"] + + async def test_only_mutable_list_of_embedded_set_on_save( - engine: AIOEngine, mock_collection + aio_engine: AIOEngine, aio_mock_collection ): class E(EmbeddedModel): a: str @@ -345,18 +664,38 @@ class M(Model): field: List[E] instance = M(field=[E(a="hello")]) - await engine.save(instance) + await aio_engine.save(instance) - collection = mock_collection() - await engine.save(instance) + collection = aio_mock_collection() + await aio_engine.save(instance) collection.update_one.assert_awaited_once() (_, set_arg), _ = collection.update_one.await_args set_dict = set_arg["$set"] assert set_dict == {"field": [{"a": "hello"}]} +def test_sync_only_mutable_list_of_embedded_set_on_save( + sync_engine: SyncEngine, sync_mock_collection +): + class E(EmbeddedModel): + a: str + + class M(Model): + field: List[E] + + instance = M(field=[E(a="hello")]) + sync_engine.save(instance) + + collection = sync_mock_collection() + sync_engine.save(instance) + collection.update_one.assert_called_once() + (_, set_arg), _ = collection.update_one.call_args + set_dict = set_arg["$set"] + assert set_dict == {"field": [{"a": "hello"}]} + + async def test_only_mutable_dict_of_embedded_set_on_save( - engine: AIOEngine, mock_collection + aio_engine: AIOEngine, aio_mock_collection ): class E(EmbeddedModel): a: str @@ -365,17 +704,39 @@ class M(Model): field: Dict[str, E] instance = M(field={"hello": E(a="world")}) - await engine.save(instance) + await aio_engine.save(instance) - collection = mock_collection() - await engine.save(instance) + collection = aio_mock_collection() + await aio_engine.save(instance) collection.update_one.assert_awaited_once() (_, set_arg), _ = collection.update_one.await_args set_dict = set_arg["$set"] assert set_dict == {"field": {"hello": {"a": "world"}}} -async def test_only_tuple_of_embedded_set_on_save(engine: AIOEngine, mock_collection): +def test_sync_only_mutable_dict_of_embedded_set_on_save( + sync_engine: SyncEngine, sync_mock_collection +): + class E(EmbeddedModel): + a: str + + class M(Model): + field: Dict[str, E] + + instance = M(field={"hello": E(a="world")}) + sync_engine.save(instance) + + collection = sync_mock_collection() + sync_engine.save(instance) + collection.update_one.assert_called_once() + (_, set_arg), _ = collection.update_one.call_args + set_dict = set_arg["$set"] + assert set_dict == {"field": {"hello": {"a": "world"}}} + + +async def test_only_tuple_of_embedded_set_on_save( + aio_engine: AIOEngine, aio_mock_collection +): class E(EmbeddedModel): a: str @@ -383,23 +744,54 @@ class M(Model): field: Tuple[E] instance = M(field=(E(a="world"),)) - await engine.save(instance) + await aio_engine.save(instance) - collection = mock_collection() - await engine.save(instance) + collection = aio_mock_collection() + await aio_engine.save(instance) collection.update_one.assert_awaited_once() (_, set_arg), _ = collection.update_one.await_args set_dict = set_arg["$set"] assert set_dict == {"field": ({"a": "world"},)} -async def test_find_sort_asc(engine: AIOEngine, person_persisted: List[PersonModel]): - results = await engine.find(PersonModel, sort=PersonModel.last_name) +def test_sync_only_tuple_of_embedded_set_on_save( + sync_engine: SyncEngine, sync_mock_collection +): + class E(EmbeddedModel): + a: str + + class M(Model): + field: Tuple[E] + + instance = M(field=(E(a="world"),)) + sync_engine.save(instance) + + collection = sync_mock_collection() + sync_engine.save(instance) + collection.update_one.assert_called_once() + (_, set_arg), _ = collection.update_one.call_args + set_dict = set_arg["$set"] + assert set_dict == {"field": ({"a": "world"},)} + + +async def test_find_sort_asc( + aio_engine: AIOEngine, person_persisted: List[PersonModel] +): + results = await aio_engine.find(PersonModel, sort=PersonModel.last_name) assert results == sorted(person_persisted, key=lambda person: person.last_name) -async def test_find_sort_list(engine: AIOEngine, person_persisted: List[PersonModel]): - results = await engine.find( +def test_sync_find_sort_asc( + sync_engine: SyncEngine, person_persisted: List[PersonModel] +): + results = list(sync_engine.find(PersonModel, sort=PersonModel.last_name)) + assert results == sorted(person_persisted, key=lambda person: person.last_name) + + +async def test_find_sort_list( + aio_engine: AIOEngine, person_persisted: List[PersonModel] +): + results = await aio_engine.find( PersonModel, sort=(PersonModel.first_name, PersonModel.last_name) ) assert results == sorted( @@ -407,7 +799,31 @@ async def test_find_sort_list(engine: AIOEngine, person_persisted: List[PersonMo ) -async def test_find_sort_wrong_argument(engine: AIOEngine): +def test_sync_find_sort_list( + sync_engine: SyncEngine, person_persisted: List[PersonModel] +): + results = list( + sync_engine.find( + PersonModel, sort=(PersonModel.first_name, PersonModel.last_name) + ) + ) + assert results == sorted( + person_persisted, key=lambda person: (person.first_name, person.last_name) + ) + + +async def test_find_sort_wrong_argument(aio_engine: AIOEngine): + with pytest.raises( + TypeError, + match=( + "sort has to be a Model field or " + "asc, desc descriptors or a tuple of these" + ), + ): + await aio_engine.find(PersonModel, sort="first_name") + + +def test_sync_find_sort_wrong_argument(sync_engine: SyncEngine): with pytest.raises( TypeError, match=( @@ -415,19 +831,29 @@ async def test_find_sort_wrong_argument(engine: AIOEngine): "asc, desc descriptors or a tuple of these" ), ): - await engine.find(PersonModel, sort="first_name") + sync_engine.find(PersonModel, sort="first_name") -async def test_find_sort_wrong_tuple_argument(engine: AIOEngine): +async def test_find_sort_wrong_tuple_argument(aio_engine: AIOEngine): with pytest.raises( TypeError, match="sort elements have to be Model fields or asc, desc descriptors", ): - await engine.find(PersonModel, sort=("first_name",)) + await aio_engine.find(PersonModel, sort=("first_name",)) -async def test_find_sort_desc(engine: AIOEngine, person_persisted: List[PersonModel]): - results = await engine.find( +def test_sync_find_sort_wrong_tuple_argument(sync_engine: SyncEngine): + with pytest.raises( + TypeError, + match="sort elements have to be Model fields or asc, desc descriptors", + ): + sync_engine.find(PersonModel, sort=("first_name",)) + + +async def test_find_sort_desc( + aio_engine: AIOEngine, person_persisted: List[PersonModel] +): + results = await aio_engine.find( PersonModel, sort=PersonModel.last_name.desc() # type: ignore ) assert results == list( @@ -435,14 +861,32 @@ async def test_find_sort_desc(engine: AIOEngine, person_persisted: List[PersonMo ) +def test_sync_find_sort_desc( + sync_engine: SyncEngine, person_persisted: List[PersonModel] +): + results = list( + sync_engine.find(PersonModel, sort=PersonModel.last_name.desc()) # type: ignore + ) + assert results == list( + reversed(sorted(person_persisted, key=lambda person: person.last_name)) + ) + + async def test_find_sort_asc_function( - engine: AIOEngine, person_persisted: List[PersonModel] + aio_engine: AIOEngine, person_persisted: List[PersonModel] ): - results = await engine.find(PersonModel, sort=asc(PersonModel.last_name)) + results = await aio_engine.find(PersonModel, sort=asc(PersonModel.last_name)) assert results == sorted(person_persisted, key=lambda person: person.last_name) -async def test_find_sort_multiple_descriptors(engine: AIOEngine): +def test_sync_find_sort_asc_function( + sync_engine: SyncEngine, person_persisted: List[PersonModel] +): + results = list(sync_engine.find(PersonModel, sort=asc(PersonModel.last_name))) + assert results == sorted(person_persisted, key=lambda person: person.last_name) + + +async def test_find_sort_multiple_descriptors(aio_engine: AIOEngine): class TestModel(Model): a: int b: int @@ -453,8 +897,8 @@ class TestModel(Model): TestModel(a=2, b=2, c=3), TestModel(a=3, b=3, c=2), ] - await engine.save_all(persisted_models) - results = await engine.find( + await aio_engine.save_all(persisted_models) + results = await aio_engine.find( TestModel, sort=( desc(TestModel.a), @@ -468,7 +912,48 @@ class TestModel(Model): ) -async def test_sort_embedded_field(engine: AIOEngine): +def test_sync_find_sort_multiple_descriptors(sync_engine: SyncEngine): + class TestModel(Model): + a: int + b: int + c: int + + persisted_models = [ + TestModel(a=1, b=2, c=3), + TestModel(a=2, b=2, c=3), + TestModel(a=3, b=3, c=2), + ] + sync_engine.save_all(persisted_models) + results = list( + sync_engine.find( + TestModel, + sort=( + desc(TestModel.a), + TestModel.b, + TestModel.c.asc(), # type: ignore + ), + ) + ) + assert results == sorted( + persisted_models, + key=lambda test_model: (-test_model.a, test_model.b, test_model.c), + ) + + +async def test_sort_embedded_field(aio_engine: AIOEngine): + class E(EmbeddedModel): + field: int + + class M(Model): + e: E + + instances = [M(e=E(field=0)), M(e=E(field=1)), M(e=E(field=2))] + await aio_engine.save_all(instances) + results = await aio_engine.find(M, sort=desc(M.e.field)) + assert results == sorted(instances, key=lambda instance: -instance.e.field) + + +def test_sync_sort_embedded_field(sync_engine: SyncEngine): class E(EmbeddedModel): field: int @@ -476,48 +961,97 @@ class M(Model): e: E instances = [M(e=E(field=0)), M(e=E(field=1)), M(e=E(field=2))] - await engine.save_all(instances) - results = await engine.find(M, sort=desc(M.e.field)) + sync_engine.save_all(instances) + results = list(sync_engine.find(M, sort=desc(M.e.field))) assert results == sorted(instances, key=lambda instance: -instance.e.field) -async def test_find_one_sort(engine: AIOEngine, person_persisted: List[PersonModel]): - person = await engine.find_one(PersonModel, sort=PersonModel.last_name) +async def test_find_one_sort( + aio_engine: AIOEngine, person_persisted: List[PersonModel] +): + person = await aio_engine.find_one(PersonModel, sort=PersonModel.last_name) assert person is not None assert person.last_name == "Castaldi" -async def test_find_document_field_not_set_with_default(engine: AIOEngine): +def test_sync_find_one_sort( + sync_engine: SyncEngine, person_persisted: List[PersonModel] +): + person = sync_engine.find_one(PersonModel, sort=PersonModel.last_name) + assert person is not None + assert person.last_name == "Castaldi" + + +async def test_find_document_field_not_set_with_default(aio_engine: AIOEngine): + class M(Model): + field: Optional[str] = None + + await aio_engine.get_collection(M).insert_one({"_id": ObjectId()}) + gathered = await aio_engine.find_one(M) + assert gathered is not None + assert gathered.field is None + + +def test_sync_find_document_field_not_set_with_default(sync_engine: SyncEngine): class M(Model): field: Optional[str] = None - await engine.get_collection(M).insert_one({"_id": ObjectId()}) - gathered = await engine.find_one(M) + sync_engine.get_collection(M).insert_one({"_id": ObjectId()}) + gathered = sync_engine.find_one(M) assert gathered is not None assert gathered.field is None async def test_find_document_field_not_set_with_default_field_descriptor( - engine: AIOEngine, + aio_engine: AIOEngine, ): class M(Model): field: str = Field(default="hello world") - await engine.get_collection(M).insert_one({"_id": ObjectId()}) - gathered = await engine.find_one(M) + await aio_engine.get_collection(M).insert_one({"_id": ObjectId()}) + gathered = await aio_engine.find_one(M) assert gathered is not None assert gathered.field == "hello world" -async def test_find_document_field_not_set_with_no_default(engine: AIOEngine): +def test_sync_find_document_field_not_set_with_default_field_descriptor( + sync_engine: SyncEngine, +): + class M(Model): + field: str = Field(default="hello world") + + sync_engine.get_collection(M).insert_one({"_id": ObjectId()}) + gathered = sync_engine.find_one(M) + assert gathered is not None + assert gathered.field == "hello world" + + +async def test_find_document_field_not_set_with_no_default(aio_engine: AIOEngine): class M(Model): field: str - await engine.get_collection(M).insert_one({"_id": ObjectId()}) + await aio_engine.get_collection(M).insert_one({"_id": ObjectId()}) with pytest.raises( DocumentParsingError, match="key not found in document" ) as exc_info: - await engine.find_one(M) + await aio_engine.find_one(M) + assert ( + "1 validation error for M\n" + "field\n" + " key not found in document " + "(type=value_error.keynotfoundindocument; key_name='field')" + ) in str(exc_info.value) + + +def test_sync_find_document_field_not_set_with_no_default(sync_engine: SyncEngine): + class M(Model): + field: str + + sync_engine.get_collection(M).insert_one({"_id": ObjectId()}) + with pytest.raises( + DocumentParsingError, match="key not found in document" + ) as exc_info: + sync_engine.find_one(M) assert ( "1 validation error for M\n" "field\n" @@ -527,18 +1061,44 @@ class M(Model): async def test_find_document_field_not_set_with_default_factory_disabled( - engine: AIOEngine, + aio_engine: AIOEngine, +): + class M(Model): + field: str = Field(default_factory=lambda: "hello") # pragma: no cover + + await aio_engine.get_collection(M).insert_one({"_id": ObjectId()}) + with pytest.raises(DocumentParsingError, match="key not found in document"): + await aio_engine.find_one(M) + + +def test_sync_find_document_field_not_set_with_default_factory_disabled( + sync_engine: SyncEngine, ): class M(Model): field: str = Field(default_factory=lambda: "hello") # pragma: no cover - await engine.get_collection(M).insert_one({"_id": ObjectId()}) + sync_engine.get_collection(M).insert_one({"_id": ObjectId()}) with pytest.raises(DocumentParsingError, match="key not found in document"): - await engine.find_one(M) + sync_engine.find_one(M) async def test_find_document_field_not_set_with_default_factory_enabled( - engine: AIOEngine, + aio_engine: AIOEngine, +): + class M(Model): + field: str = Field(default_factory=lambda: "hello") + + class Config: + parse_doc_with_default_factories = True + + await aio_engine.get_collection(M).insert_one({"_id": ObjectId()}) + instance = await aio_engine.find_one(M) + assert instance is not None + assert instance.field == "hello" + + +def test_sync_find_document_field_not_set_with_default_factory_enabled( + sync_engine: SyncEngine, ): class M(Model): field: str = Field(default_factory=lambda: "hello") @@ -546,7 +1106,7 @@ class M(Model): class Config: parse_doc_with_default_factories = True - await engine.get_collection(M).insert_one({"_id": ObjectId()}) - instance = await engine.find_one(M) + sync_engine.get_collection(M).insert_one({"_id": ObjectId()}) + instance = sync_engine.find_one(M) assert instance is not None assert instance.field == "hello" diff --git a/tests/integration/test_engine_reference.py b/tests/integration/test_engine_reference.py index 6e122d6d..23546138 100644 --- a/tests/integration/test_engine_reference.py +++ b/tests/integration/test_engine_reference.py @@ -1,7 +1,7 @@ import pytest from odmantic.bson import ObjectId -from odmantic.engine import AIOEngine +from odmantic.engine import AIOEngine, SyncEngine from odmantic.exceptions import DocumentParsingError from odmantic.model import Model from odmantic.reference import Reference @@ -12,11 +12,21 @@ pytestmark = pytest.mark.asyncio -async def test_add_with_references(engine: AIOEngine): +async def test_add_with_references(aio_engine: AIOEngine): publisher = Publisher(name="O'Reilly Media", founded=1980, location="CA") book = Book(title="MongoDB: The Definitive Guide", pages=216, publisher=publisher) - instance = await engine.save(book) - fetched_subinstance = await engine.find_one( + instance = await aio_engine.save(book) + fetched_subinstance = await aio_engine.find_one( + Publisher, Publisher.id == instance.publisher.id + ) + assert fetched_subinstance == publisher + + +def test_sync_add_with_references(sync_engine: SyncEngine): + publisher = Publisher(name="O'Reilly Media", founded=1980, location="CA") + book = Book(title="MongoDB: The Definitive Guide", pages=216, publisher=publisher) + instance = sync_engine.save(book) + fetched_subinstance = sync_engine.find_one( Publisher, Publisher.id == instance.publisher.id ) assert fetched_subinstance == publisher @@ -27,48 +37,105 @@ async def test_add_with_references(engine: AIOEngine): # TODO test add with duplicated reference id -async def test_save_deeply_nested(engine: AIOEngine): +async def test_save_deeply_nested(aio_engine: AIOEngine): + instance = NestedLevel1(next_=NestedLevel2(next_=NestedLevel3())) + await aio_engine.save(instance) + assert await aio_engine.count(NestedLevel3) == 1 + assert await aio_engine.count(NestedLevel2) == 1 + assert await aio_engine.count(NestedLevel1) == 1 + + +def test_sync_save_deeply_nested(sync_engine: SyncEngine): instance = NestedLevel1(next_=NestedLevel2(next_=NestedLevel3())) - await engine.save(instance) - assert await engine.count(NestedLevel3) == 1 - assert await engine.count(NestedLevel2) == 1 - assert await engine.count(NestedLevel1) == 1 + sync_engine.save(instance) + assert sync_engine.count(NestedLevel3) == 1 + assert sync_engine.count(NestedLevel2) == 1 + assert sync_engine.count(NestedLevel1) == 1 -async def test_update_deeply_nested(engine: AIOEngine): +async def test_update_deeply_nested(aio_engine: AIOEngine): inst3 = NestedLevel3( field=0 ) # Isolate inst3 to make sure it's not internaly copied instance = NestedLevel1(next_=NestedLevel2(next_=inst3)) - await engine.save(instance) - assert await engine.count(NestedLevel3, NestedLevel3.field == 42) == 0 + await aio_engine.save(instance) + assert await aio_engine.count(NestedLevel3, NestedLevel3.field == 42) == 0 inst3.field = 42 - await engine.save(instance) - assert await engine.count(NestedLevel3, NestedLevel3.field == 42) == 1 + await aio_engine.save(instance) + assert await aio_engine.count(NestedLevel3, NestedLevel3.field == 42) == 1 + + +def test_sync_update_deeply_nested(sync_engine: SyncEngine): + inst3 = NestedLevel3( + field=0 + ) # Isolate inst3 to make sure it's not internaly copied + instance = NestedLevel1(next_=NestedLevel2(next_=inst3)) + sync_engine.save(instance) + assert sync_engine.count(NestedLevel3, NestedLevel3.field == 42) == 0 + inst3.field = 42 + sync_engine.save(instance) + assert sync_engine.count(NestedLevel3, NestedLevel3.field == 42) == 1 + + +async def test_save_deeply_nested_and_fetch(aio_engine: AIOEngine): + instance = NestedLevel1(next_=NestedLevel2(next_=NestedLevel3(field=0))) + await aio_engine.save(instance) + + fetched = await aio_engine.find_one(NestedLevel1) + assert fetched == instance -async def test_save_deeply_nested_and_fetch(engine: AIOEngine): +def test_sync_save_deeply_nested_and_fetch(sync_engine: SyncEngine): instance = NestedLevel1(next_=NestedLevel2(next_=NestedLevel3(field=0))) - await engine.save(instance) + sync_engine.save(instance) - fetched = await engine.find_one(NestedLevel1) + fetched = sync_engine.find_one(NestedLevel1) assert fetched == instance -async def test_multiple_save_deeply_nested_and_fetch(engine: AIOEngine): +async def test_multiple_save_deeply_nested_and_fetch(aio_engine: AIOEngine): + instances = [ + NestedLevel1(field=1, next_=NestedLevel2(field=2, next_=NestedLevel3(field=3))), + NestedLevel1(field=4, next_=NestedLevel2(field=5, next_=NestedLevel3(field=6))), + ] + await aio_engine.save_all(instances) + + fetched = await aio_engine.find(NestedLevel1) + assert len(fetched) == 2 + assert fetched[0] in instances + assert fetched[1] in instances + + +def test_sync_multiple_save_deeply_nested_and_fetch(sync_engine: SyncEngine): instances = [ NestedLevel1(field=1, next_=NestedLevel2(field=2, next_=NestedLevel3(field=3))), NestedLevel1(field=4, next_=NestedLevel2(field=5, next_=NestedLevel3(field=6))), ] - await engine.save_all(instances) + sync_engine.save_all(instances) - fetched = await engine.find(NestedLevel1) + fetched = list(sync_engine.find(NestedLevel1)) assert len(fetched) == 2 assert fetched[0] in instances assert fetched[1] in instances -async def test_reference_with_key_name(engine: AIOEngine): +async def test_reference_with_key_name(aio_engine: AIOEngine): + class R(Model): + field: int + + class M(Model): + r: R = Reference(key_name="fancy_key_name") + + instance = M(r=R(field=3)) + assert "fancy_key_name" in instance.doc() + await aio_engine.save(instance) + + fetched = await aio_engine.find_one(M) + assert fetched is not None + assert fetched.r.field == 3 + + +def test_sync_reference_with_key_name(sync_engine: SyncEngine): class R(Model): field: int @@ -77,23 +144,41 @@ class M(Model): instance = M(r=R(field=3)) assert "fancy_key_name" in instance.doc() - await engine.save(instance) + sync_engine.save(instance) - fetched = await engine.find_one(M) + fetched = sync_engine.find_one(M) assert fetched is not None assert fetched.r.field == 3 -async def test_reference_not_set_in_database(engine: AIOEngine): +async def test_reference_not_set_in_database(aio_engine: AIOEngine): + class R(Model): + field: int + + class M(Model): + r: R = Reference() + + await aio_engine.get_collection(M).insert_one({"_id": ObjectId()}) + with pytest.raises(DocumentParsingError) as exc_info: + await aio_engine.find_one(M) + assert ( + "1 validation error for M\n" + "r\n" + " referenced document not found " + "(type=value_error.referenceddocumentnotfound; foreign_key_name='r')" + ) in str(exc_info.value) + + +def test_sync_reference_not_set_in_database(sync_engine: SyncEngine): class R(Model): field: int class M(Model): r: R = Reference() - await engine.get_collection(M).insert_one({"_id": ObjectId()}) + sync_engine.get_collection(M).insert_one({"_id": ObjectId()}) with pytest.raises(DocumentParsingError) as exc_info: - await engine.find_one(M) + sync_engine.find_one(M) assert ( "1 validation error for M\n" "r\n" @@ -102,7 +187,31 @@ class M(Model): ) in str(exc_info.value) -async def test_reference_incorect_reference_structure(engine: AIOEngine): +async def test_reference_incorect_reference_structure(aio_engine: AIOEngine): + class R(Model): + field: int + + class M(Model): + r: R = Reference() + + r = R(field=12) + r_doc = r.doc() + del r_doc["field"] + m = M(r=r) + await aio_engine.get_collection(R).insert_one(r_doc) + await aio_engine.get_collection(M).insert_one(m.doc()) + + with pytest.raises(DocumentParsingError) as exc_info: + await aio_engine.find_one(M) + assert ( + "1 validation error for M\n" + "r -> field\n" + " key not found in document " + "(type=value_error.keynotfoundindocument; key_name='field')" + ) in str(exc_info.value) + + +def test_sync_reference_incorect_reference_structure(sync_engine: SyncEngine): class R(Model): field: int @@ -113,11 +222,11 @@ class M(Model): r_doc = r.doc() del r_doc["field"] m = M(r=r) - await engine.get_collection(R).insert_one(r_doc) - await engine.get_collection(M).insert_one(m.doc()) + sync_engine.get_collection(R).insert_one(r_doc) + sync_engine.get_collection(M).insert_one(m.doc()) with pytest.raises(DocumentParsingError) as exc_info: - await engine.find_one(M) + sync_engine.find_one(M) assert ( "1 validation error for M\n" "r -> field\n" diff --git a/tests/integration/test_query.py b/tests/integration/test_query.py index 668126c5..c74ae52b 100644 --- a/tests/integration/test_query.py +++ b/tests/integration/test_query.py @@ -4,7 +4,7 @@ import pytest from odmantic import Model -from odmantic.engine import AIOEngine +from odmantic.engine import AIOEngine, SyncEngine from odmantic.query import ( QueryExpression, and_, @@ -27,38 +27,58 @@ @pytest.fixture(scope="function") -async def person_persisted(engine: AIOEngine): +async def person_persisted(aio_engine: AIOEngine): initial_instances = [ PersonModel(first_name="Jean-Pierre", last_name="Pernaud"), PersonModel(first_name="Jean-Pierre", last_name="Castaldi"), PersonModel(first_name="Michel", last_name="Drucker"), ] - return await engine.save_all(initial_instances) + return await aio_engine.save_all(initial_instances) @pytest.mark.usefixtures("person_persisted") -async def test_and(engine: AIOEngine): +async def test_and(aio_engine: AIOEngine): query = (PersonModel.first_name == "Michel") & (PersonModel.last_name == "Drucker") assert query == and_( PersonModel.first_name == "Michel", PersonModel.last_name == "Drucker" ) - count = await engine.count(PersonModel, query) + count = await aio_engine.count(PersonModel, query) assert count == 1 @pytest.mark.usefixtures("person_persisted") -async def test_or(engine: AIOEngine): +def test_sync_and(sync_engine: SyncEngine): + query = (PersonModel.first_name == "Michel") & (PersonModel.last_name == "Drucker") + assert query == and_( + PersonModel.first_name == "Michel", PersonModel.last_name == "Drucker" + ) + count = sync_engine.count(PersonModel, query) + assert count == 1 + + +@pytest.mark.usefixtures("person_persisted") +async def test_or(aio_engine: AIOEngine): query = (PersonModel.first_name == "Michel") | (PersonModel.last_name == "Castaldi") assert query == or_( PersonModel.first_name == "Michel", PersonModel.last_name == "Castaldi" ) - count = await engine.count(PersonModel, query) + count = await aio_engine.count(PersonModel, query) assert count == 2 @pytest.mark.usefixtures("person_persisted") -async def test_nor(engine: AIOEngine): - count = await engine.count( +def test_sync_or(sync_engine: SyncEngine): + query = (PersonModel.first_name == "Michel") | (PersonModel.last_name == "Castaldi") + assert query == or_( + PersonModel.first_name == "Michel", PersonModel.last_name == "Castaldi" + ) + count = sync_engine.count(PersonModel, query) + assert count == 2 + + +@pytest.mark.usefixtures("person_persisted") +async def test_nor(aio_engine: AIOEngine): + count = await aio_engine.count( PersonModel, nor_(PersonModel.first_name == "Michel", PersonModel.last_name == "Castaldi"), ) @@ -66,40 +86,78 @@ async def test_nor(engine: AIOEngine): @pytest.mark.usefixtures("person_persisted") -async def test_eq(engine: AIOEngine): +async def test_eq(aio_engine: AIOEngine): query = cast(QueryExpression, PersonModel.first_name == "Michel") assert query == eq(PersonModel.first_name, "Michel") - count = await engine.count(PersonModel, query) + count = await aio_engine.count(PersonModel, query) assert count == 1 @pytest.mark.usefixtures("person_persisted") -async def test_ne(engine: AIOEngine): +def test_sync_eq(sync_engine: SyncEngine): + query = cast(QueryExpression, PersonModel.first_name == "Michel") + assert query == eq(PersonModel.first_name, "Michel") + count = sync_engine.count(PersonModel, query) + assert count == 1 + + +@pytest.mark.usefixtures("person_persisted") +async def test_ne(aio_engine: AIOEngine): + query = PersonModel.first_name != "Michel" + assert query == ne(PersonModel.first_name, "Michel") + count = await aio_engine.count(PersonModel, query) + assert count == 2 + + +@pytest.mark.usefixtures("person_persisted") +def test_sync_ne(sync_engine: SyncEngine): query = PersonModel.first_name != "Michel" assert query == ne(PersonModel.first_name, "Michel") - count = await engine.count(PersonModel, query) + count = sync_engine.count(PersonModel, query) assert count == 2 @pytest.mark.usefixtures("person_persisted") -async def test_in_(engine: AIOEngine): +async def test_in_(aio_engine: AIOEngine): + query = in_(PersonModel.first_name, ["Michel", "Jean-Pierre"]) + # TODO allow this with a mypy plugin + assert query == PersonModel.first_name.in_( # type: ignore + ["Michel", "Jean-Pierre"] + ) + count = await aio_engine.count(PersonModel, query) + assert count == 3 + + +@pytest.mark.usefixtures("person_persisted") +def test_sync_in_(sync_engine: SyncEngine): query = in_(PersonModel.first_name, ["Michel", "Jean-Pierre"]) # TODO allow this with a mypy plugin assert query == PersonModel.first_name.in_( # type: ignore ["Michel", "Jean-Pierre"] ) - count = await engine.count(PersonModel, query) + count = sync_engine.count(PersonModel, query) assert count == 3 @pytest.mark.usefixtures("person_persisted") -async def test_not_in(engine: AIOEngine): +async def test_not_in(aio_engine: AIOEngine): + query = not_in(PersonModel.first_name, ["Michel", "Jean-Pierre"]) + # TODO allow this with a mypy plugin + assert query == PersonModel.first_name.not_in( # type: ignore + ["Michel", "Jean-Pierre"] + ) + count = await aio_engine.count(PersonModel, query) + assert count == 0 + + +@pytest.mark.usefixtures("person_persisted") +def test_sync_not_in(sync_engine: SyncEngine): query = not_in(PersonModel.first_name, ["Michel", "Jean-Pierre"]) # TODO allow this with a mypy plugin assert query == PersonModel.first_name.not_in( # type: ignore ["Michel", "Jean-Pierre"] ) - count = await engine.count(PersonModel, query) + count = sync_engine.count(PersonModel, query) assert count == 0 @@ -109,65 +167,111 @@ class AgedPerson(Model): @pytest.fixture(scope="function") -async def aged_person_persisted(engine: AIOEngine): +async def aged_person_persisted(aio_engine: AIOEngine): initial_instances = [ AgedPerson(name="Jean-Pierre", age=25), AgedPerson(name="Jean-Paul", age=40), AgedPerson(name="Michel", age=70), ] - return await engine.save_all(initial_instances) + return await aio_engine.save_all(initial_instances) @pytest.mark.usefixtures("aged_person_persisted") -async def test_gt(engine: AIOEngine): +async def test_gt(aio_engine: AIOEngine): query = AgedPerson.age > 40 assert query == AgedPerson.age.gt(40) # type: ignore assert query == gt(AgedPerson.age, 40) - count = await engine.count(AgedPerson, query) + count = await aio_engine.count(AgedPerson, query) assert count == 1 @pytest.mark.usefixtures("aged_person_persisted") -async def test_gte(engine: AIOEngine): +def test_sync_gt(sync_engine: SyncEngine): + query = AgedPerson.age > 40 + assert query == AgedPerson.age.gt(40) # type: ignore + assert query == gt(AgedPerson.age, 40) + count = sync_engine.count(AgedPerson, query) + assert count == 1 + + +@pytest.mark.usefixtures("aged_person_persisted") +async def test_gte(aio_engine: AIOEngine): query = AgedPerson.age >= 40 assert query == AgedPerson.age.gte(40) # type: ignore assert query == gte(AgedPerson.age, 40) - count = await engine.count(AgedPerson, query) + count = await aio_engine.count(AgedPerson, query) assert count == 2 @pytest.mark.usefixtures("aged_person_persisted") -async def test_lt(engine: AIOEngine): +async def test_lt(aio_engine: AIOEngine): + query = AgedPerson.age < 40 + assert query == AgedPerson.age.lt(40) # type: ignore + assert query == lt(AgedPerson.age, 40) + count = await aio_engine.count(AgedPerson, query) + assert count == 1 + + +@pytest.mark.usefixtures("aged_person_persisted") +def test_sync_lt(sync_engine: SyncEngine): query = AgedPerson.age < 40 assert query == AgedPerson.age.lt(40) # type: ignore assert query == lt(AgedPerson.age, 40) - count = await engine.count(AgedPerson, query) + count = sync_engine.count(AgedPerson, query) assert count == 1 @pytest.mark.usefixtures("aged_person_persisted") -async def test_lte(engine: AIOEngine): +async def test_lte(aio_engine: AIOEngine): query = AgedPerson.age <= 40 assert query == AgedPerson.age.lte(40) # type: ignore assert query == lte(AgedPerson.age, 40) - count = await engine.count(AgedPerson, query) + count = await aio_engine.count(AgedPerson, query) + assert count == 2 + + +@pytest.mark.usefixtures("aged_person_persisted") +def test_sync_lte(sync_engine: SyncEngine): + query = AgedPerson.age <= 40 + assert query == AgedPerson.age.lte(40) # type: ignore + assert query == lte(AgedPerson.age, 40) + count = sync_engine.count(AgedPerson, query) + assert count == 2 + + +@pytest.mark.usefixtures("person_persisted") +async def test_match_pattern_string(aio_engine: AIOEngine): + # TODO allow this with a mypy plugin + query = PersonModel.first_name.match(r"^Jean-.*") # type: ignore + assert query == match(PersonModel.first_name, "^Jean-.*") + count = await aio_engine.count(PersonModel, query) assert count == 2 @pytest.mark.usefixtures("person_persisted") -async def test_match_pattern_string(engine: AIOEngine): +def test_sync_match_pattern_string(sync_engine: SyncEngine): # TODO allow this with a mypy plugin query = PersonModel.first_name.match(r"^Jean-.*") # type: ignore assert query == match(PersonModel.first_name, "^Jean-.*") - count = await engine.count(PersonModel, query) + count = sync_engine.count(PersonModel, query) + assert count == 2 + + +@pytest.mark.usefixtures("person_persisted") +async def test_match_pattern_compiled(aio_engine: AIOEngine): + # TODO allow this with a mypy plugin + r = re.compile(r"^Jean-.*") + query = PersonModel.first_name.match(r) # type: ignore + assert query == match(PersonModel.first_name, r) + count = await aio_engine.count(PersonModel, query) assert count == 2 @pytest.mark.usefixtures("person_persisted") -async def test_match_pattern_compiled(engine: AIOEngine): +def test_sync_match_pattern_compiled(sync_engine: SyncEngine): # TODO allow this with a mypy plugin r = re.compile(r"^Jean-.*") query = PersonModel.first_name.match(r) # type: ignore assert query == match(PersonModel.first_name, r) - count = await engine.count(PersonModel, query) + count = sync_engine.count(PersonModel, query) assert count == 2 diff --git a/tests/integration/test_types.py b/tests/integration/test_types.py index 667256ec..afcb0c1a 100644 --- a/tests/integration/test_types.py +++ b/tests/integration/test_types.py @@ -7,8 +7,9 @@ import pytest from bson import Binary, Decimal128, Int64, ObjectId, Regex from motor.motor_asyncio import AsyncIOMotorDatabase +from pymongo.database import Database -from odmantic.engine import AIOEngine +from odmantic.engine import AIOEngine, SyncEngine from odmantic.model import Model pytestmark = pytest.mark.asyncio @@ -72,13 +73,13 @@ class TypeTestCase(Generic[T]): @pytest.mark.parametrize("case", type_test_data) async def test_bson_type_inference( - motor_database: AsyncIOMotorDatabase, engine: AIOEngine, case: TypeTestCase + motor_database: AsyncIOMotorDatabase, aio_engine: AIOEngine, case: TypeTestCase ): class ModelWithTypedField(Model): field: case.python_type # type: ignore # TODO: Fix objectid optional (type: ignore) - instance = await engine.save(ModelWithTypedField(field=case.sample_value)) + instance = await aio_engine.save(ModelWithTypedField(field=case.sample_value)) document = await motor_database[ModelWithTypedField.__collection__].find_one( { +ModelWithTypedField.id: instance.id, @@ -93,8 +94,31 @@ class ModelWithTypedField(Model): assert recovered_instance.field == instance.field +@pytest.mark.parametrize("case", type_test_data) +def test_sync_bson_type_inference( + pymongo_database: Database, sync_engine: SyncEngine, case: TypeTestCase +): + class ModelWithTypedField(Model): + field: case.python_type # type: ignore + + # TODO: Fix objectid optional (type: ignore) + instance = sync_engine.save(ModelWithTypedField(field=case.sample_value)) + document = pymongo_database[ModelWithTypedField.__collection__].find_one( + { + +ModelWithTypedField.id: instance.id, + +ModelWithTypedField.field: {"$type": case.bson_type}, + } + ) + assert document is not None, ( + f"Type inference error: {case.python_type} -> {case.bson_type}" + f" ({case.sample_value})" + ) + recovered_instance = ModelWithTypedField(field=document["field"]) + assert recovered_instance.field == instance.field + + async def test_custom_bson_serializable( - motor_database: AsyncIOMotorDatabase, engine: AIOEngine + motor_database: AsyncIOMotorDatabase, aio_engine ): class FancyFloat: @classmethod @@ -113,7 +137,7 @@ def __bson__(cls, v): class ModelWithCustomField(Model): field: FancyFloat - instance = await engine.save(ModelWithCustomField(field=3.14)) + instance = await aio_engine.save(ModelWithCustomField(field=3.14)) document = await motor_database[ModelWithCustomField.__collection__].find_one( { +ModelWithCustomField.id: instance.id, @@ -123,3 +147,35 @@ class ModelWithCustomField(Model): assert document is not None, "Couldn't retrieve the document with it's string value" recovered_instance = ModelWithCustomField.parse_doc(document) assert recovered_instance.field == instance.field + + +def test_sync_custom_bson_serializable( + pymongo_database: Database, sync_engine: SyncEngine +): + class FancyFloat: + @classmethod + def __get_validators__(cls): + yield cls.validate + + @classmethod + def validate(cls, v): + return float(v) + + @classmethod + def __bson__(cls, v): + # We store the float as a string in the DB + return str(v) + + class ModelWithCustomField(Model): + field: FancyFloat + + instance = sync_engine.save(ModelWithCustomField(field=3.14)) + document = pymongo_database[ModelWithCustomField.__collection__].find_one( + { + +ModelWithCustomField.id: instance.id, + +ModelWithCustomField.field: {"$type": "string"}, # type: ignore + } + ) + assert document is not None, "Couldn't retrieve the document with it's string value" + recovered_instance = ModelWithCustomField.parse_doc(document) + assert recovered_instance.field == instance.field diff --git a/tests/integration/test_zoo.py b/tests/integration/test_zoo.py index 07092fcd..2093f967 100644 --- a/tests/integration/test_zoo.py +++ b/tests/integration/test_zoo.py @@ -1,29 +1,52 @@ import pytest from odmantic import AIOEngine +from odmantic.engine import SyncEngine from tests.zoo.player import Player from tests.zoo.twitter_user import TwitterUser pytestmark = pytest.mark.asyncio -async def test_twitter_user(engine: AIOEngine): +async def test_twitter_user(aio_engine: AIOEngine): main = TwitterUser() - await engine.save(main) + await aio_engine.save(main) friends = [TwitterUser() for _ in range(25)] - await engine.save_all(friends) + await aio_engine.save_all(friends) friend_ids = [f.id for f in friends] main.following = friend_ids - await engine.save(main) + await aio_engine.save(main) - fetched_main = await engine.find_one(TwitterUser, TwitterUser.id == main.id) + fetched_main = await aio_engine.find_one(TwitterUser, TwitterUser.id == main.id) assert fetched_main is not None assert fetched_main == main assert set(friend_ids) == set(fetched_main.following) -async def test_player(engine: AIOEngine): +def test_sync_twitter_user(sync_engine: SyncEngine): + main = TwitterUser() + sync_engine.save(main) + friends = [TwitterUser() for _ in range(25)] + sync_engine.save_all(friends) + friend_ids = [f.id for f in friends] + main.following = friend_ids + sync_engine.save(main) + + fetched_main = sync_engine.find_one(TwitterUser, TwitterUser.id == main.id) + assert fetched_main is not None + assert fetched_main == main + assert set(friend_ids) == set(fetched_main.following) + + +async def test_player(aio_engine: AIOEngine): + leeroy = Player(name="Leeroy Jenkins") + await aio_engine.save(leeroy) + fetched = await aio_engine.find_one(Player) + assert fetched == leeroy + + +def test_sync_player(sync_engine: SyncEngine): leeroy = Player(name="Leeroy Jenkins") - await engine.save(leeroy) - fetched = await engine.find_one(Player) + sync_engine.save(leeroy) + fetched = sync_engine.find_one(Player) assert fetched == leeroy diff --git a/tox.ini b/tox.ini index d84cd045..28aeb7d5 100644 --- a/tox.ini +++ b/tox.ini @@ -3,6 +3,8 @@ isolated_build = true envlist = py{36}-motor{23,24,25}-pydantic{16,17,18} py{37,38,39}-motor{21,22,23,24,25,30}-pydantic{17,18,19} + py{36}-motor{24}-pymongo{3_11,3_12}-pydantic{16,17,18} + py{37,38,39}-motor{24}-pymongo{3_11,3_12}-pydantic{17,18,19} skip_missing_interpreters=false [testenv] extras = test @@ -13,6 +15,12 @@ deps = motor24: motor ~= 2.4.0 motor25: motor ~= 2.5.0 motor30: motor ~= 3.0.0 + pymongo3_11: pymongo ~= 3.11.0 + pymongo3_12: pymongo ~= 3.12.0 + # pymongo 4.0.0 is not supported by any version of motor + # pymongo4_0: pymongo ~= 4.0.0 + # pymongo 4.1.0 is the the only version supported by motor 3.0.0, it's already covered + # pymongo4_1: pymongo ~= 4.1.0 pydantic16: pydantic ~= 1.6.2 pydantic17: pydantic ~= 1.7.4 pydantic18: pydantic ~= 1.8.2