diff --git a/mongow/mixins/base.py b/mongow/mixins/base.py index 20106e9..d0b8237 100644 --- a/mongow/mixins/base.py +++ b/mongow/mixins/base.py @@ -1,10 +1,13 @@ from typing import ( + Any, Generic, + Optional, TypeVar, Union ) import pymongo +from bson import ObjectId from pydantic import ( BaseModel, Field @@ -14,13 +17,11 @@ from ..utils import ( Direction, Indice, - ObjectId, PyObjectId ) T = TypeVar("T", bound="BaseMixin") - class BaseMixin( BaseModel, Generic[T] @@ -37,7 +38,7 @@ class Config: ObjectId: str } - def __new__(cls): + def __new__(cls, *_, **__): self_dict = __class__.Config.__dict__ subclass_dict = cls.Config.__dict__ for key in self_dict: @@ -73,14 +74,10 @@ async def create_indices(cls): def build_indice(cls, data: Union[Indice, tuple]) -> Indice: if isinstance(data, tuple): data = Indice( - keys=[(data[0], data[1])], - name=data[2] if len(data) > 2 else None, - unique=data[3] if len(data) > 3 else None, - background=data[4] if len(data) > 4 else None, - sparse=data[5] if len(data) > 5 else None, - bucket_size=data[6] if len(data) > 6 else None, - min=data[7] if len(data) > 7 else None, - max=data[8] if len(data) > 8 else None + keys=[( + data[0], + data[1] if len(data) > 1 else Direction.ASCENDING + )] ) if not isinstance(data, Indice): @@ -116,3 +113,14 @@ def build_pymongo_index(indice: Indice) -> pymongo.IndexModel: kwargs["name"] = "_".join([key[0] for key in indice.keys]) + "_index" return pymongo.IndexModel(**kwargs) + + @classmethod + def instantiate_obj(cls, key: str, value: Union[str, Any]) -> tuple[str, Any]: + if isinstance(value, str): + if key == "_id": + return key, ObjectId(value) + # TODO: get attr name by alias name + # assume it is an _id and pop off its underscore + # cls.schema(by_alias=True).get("properties").keys() + # return key, typing.get_type_hints(cls)[search_key](value) + return key, value \ No newline at end of file diff --git a/mongow/mixins/model.py b/mongow/mixins/model.py index 748e9cf..ec7ff39 100644 --- a/mongow/mixins/model.py +++ b/mongow/mixins/model.py @@ -5,7 +5,8 @@ List, Optional, Tuple, - Union + Type, + Union, ) from bson import ObjectId @@ -19,19 +20,10 @@ class ModelMixin(BaseMixin): - @classmethod - def instantiate_obj(cls, key: str, value: Union[str, Any]) -> tuple[str, Any]: - if isinstance(value, str): - if key == "_id": - return key, ObjectId(value) - # TODO: get attr name by alias name - # assume it is an _id and pop off its underscore - # cls.schema(by_alias=True).get("properties").keys() - # return key, typing.get_type_hints(cls)[search_key](value) - return key, value - @classmethod async def create(cls, data: T) -> ObjectId: + if not data.id: + data.id = ObjectId() result = await database.database[cls.__collection__].insert_one( data.dict(by_alias=True) ) @@ -39,19 +31,19 @@ async def create(cls, data: T) -> ObjectId: @classmethod async def read( - cls, - fields: Iterable[str] = tuple(), - order: Optional[Tuple[str, bool]] = None, - offset: int = 0, - limit: int = 100, - filters: Optional[Union[T, dict]] = None, - construct_object: bool = False + cls: Type[T], + fields: Iterable[str] = tuple(), + order: Optional[Tuple[str, bool]] = None, + offset: int = 0, + limit: int = 100, + filters: Optional[Union[T, dict]] = None, + construct_object: bool = False ) -> List[T]: if filters is None: filters = {} elif isinstance(filters, cls): filters = filters.dict(exclude_unset=True, by_alias=True) - else: + elif isinstance(filters, dict): filters = dict(starmap(cls.instantiate_obj, filters.items())) cursor = database.database[cls.__collection__].find( @@ -78,10 +70,10 @@ async def aggregate(cls, pipeline, construct_object: bool = False) -> Any: @classmethod async def update( - cls, - data: T, - filters: Union[T, dict], - operator: str = "$set" + cls, + data: T, + filters: Union[T, dict], + operator: str = "$set" ) -> int: if not isinstance(filters, dict): filters = filters.dict( @@ -119,8 +111,7 @@ async def upsert(cls, data: T, filters: Union[T, dict]) -> int: result = await database.database[cls.__collection__].update_one( filters, {"$set": dict_data}, - upsert=True, - return_document=True + upsert=True ) return result.modified_count diff --git a/tests/func/test_aliasing.py b/tests/func/test_aliasing.py index 2d00c34..f482d10 100644 --- a/tests/func/test_aliasing.py +++ b/tests/func/test_aliasing.py @@ -1,4 +1,5 @@ from bson import ObjectId +from pymongo import errors import pytest from tests.schemas import Fruit @@ -8,7 +9,8 @@ async def oid(): f = Fruit(name="Banana", taste="sweet") oid = await Fruit.create(f) - return oid + yield oid + await Fruit.delete(filters={}) class TestRetrieveAndInstantiate: @@ -23,11 +25,13 @@ async def test_alias_str(self, oid: ObjectId): async def test_name_oid(self, oid: ObjectId): filters = dict(id=oid) - await self.check_response(filters) + with pytest.raises(AssertionError): + await self.check_response(filters) async def test_name_str(self, oid: ObjectId): filters = dict(id=str(oid)) - await self.check_response(filters) + with pytest.raises(AssertionError): + await self.check_response(filters) @staticmethod async def check_response(filters): diff --git a/tests/func/test_object_creation.py b/tests/func/test_object_creation.py index 22b0624..3b1dc3d 100644 --- a/tests/func/test_object_creation.py +++ b/tests/func/test_object_creation.py @@ -9,3 +9,4 @@ async def test_create_obj(db): oid = await Fruit.create(f) assert oid is not None assert isinstance(oid, ObjectId) + await Fruit.delete(filters={}) diff --git a/tests/func/test_update.py b/tests/func/test_update.py index d4ecb85..311bf1a 100644 --- a/tests/func/test_update.py +++ b/tests/func/test_update.py @@ -10,20 +10,20 @@ async def oids() -> list[ObjectId]: # type: ignore Fruit(name="Banana", taste="sweet"), Fruit(name="Strawberry", taste="bitter"), ] - oids = await Fruit.create_many(fruits) + oids = await Fruit.batch_create(fruits) yield oids await Fruit.delete(filters={}) class TestUpdateMany: async def test_update_many(self, oids): - data = dict(density=666) + data = Fruit(density=666) filters = {"_id": {"$in": oids}} fruits_updated = await Fruit.batch_update(data=data, filters=filters) assert fruits_updated == 2 async def test_upsert_single(self, oids): - data = dict(density=666) + data = Fruit(density=666) for oid in oids: filters = {"_id": oid} upserted_count = await Fruit.upsert(data=data, filters=filters) diff --git a/tests/schemas.py b/tests/schemas.py index 15baa51..4ed0ecf 100644 --- a/tests/schemas.py +++ b/tests/schemas.py @@ -1,10 +1,10 @@ from typing import Literal -import mongow import pydantic +import mongow -class Fruit(mongow.BaseMixin): +class Fruit(mongow.DocumentMixin): density: float = 1.0 name: str needs_peeling: bool = pydantic.Field(