Skip to content
This repository has been archived by the owner on Aug 10, 2023. It is now read-only.

Commit

Permalink
Expect error on tests filtering on attr names
Browse files Browse the repository at this point in the history
  • Loading branch information
cardoso-neto committed Jul 25, 2022
1 parent 8693dab commit e5cbbc1
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 45 deletions.
30 changes: 19 additions & 11 deletions mongow/mixins/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import (
Any,
Generic,
Optional,
TypeVar,
Union
)

import pymongo
from bson import ObjectId
from pydantic import (
BaseModel,
Field
Expand All @@ -14,13 +17,11 @@
from ..utils import (
Direction,
Indice,
ObjectId,
PyObjectId
)

T = TypeVar("T", bound="BaseMixin")


class BaseMixin(
BaseModel,
Generic[T]
Expand All @@ -37,7 +38,7 @@ class Config:
ObjectId: str
}

def __new__(cls):
def __new__(cls, *_, **__):
self_dict = __class__.Config.__dict__
subclass_dict = cls.Config.__dict__
for key in self_dict:
Expand Down Expand Up @@ -73,14 +74,10 @@ async def create_indices(cls):
def build_indice(cls, data: Union[Indice, tuple]) -> Indice:
if isinstance(data, tuple):
data = Indice(
keys=[(data[0], data[1])],
name=data[2] if len(data) > 2 else None,
unique=data[3] if len(data) > 3 else None,
background=data[4] if len(data) > 4 else None,
sparse=data[5] if len(data) > 5 else None,
bucket_size=data[6] if len(data) > 6 else None,
min=data[7] if len(data) > 7 else None,
max=data[8] if len(data) > 8 else None
keys=[(
data[0],
data[1] if len(data) > 1 else Direction.ASCENDING
)]
)

if not isinstance(data, Indice):
Expand Down Expand Up @@ -116,3 +113,14 @@ def build_pymongo_index(indice: Indice) -> pymongo.IndexModel:
kwargs["name"] = "_".join([key[0] for key in indice.keys]) + "_index"

return pymongo.IndexModel(**kwargs)

@classmethod
def instantiate_obj(cls, key: str, value: Union[str, Any]) -> tuple[str, Any]:
if isinstance(value, str):
if key == "_id":
return key, ObjectId(value)
# TODO: get attr name by alias name
# assume it is an _id and pop off its underscore
# cls.schema(by_alias=True).get("properties").keys()
# return key, typing.get_type_hints(cls)[search_key](value)
return key, value
43 changes: 17 additions & 26 deletions mongow/mixins/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
List,
Optional,
Tuple,
Union
Type,
Union,
)

from bson import ObjectId
Expand All @@ -19,39 +20,30 @@

class ModelMixin(BaseMixin):

@classmethod
def instantiate_obj(cls, key: str, value: Union[str, Any]) -> tuple[str, Any]:
if isinstance(value, str):
if key == "_id":
return key, ObjectId(value)
# TODO: get attr name by alias name
# assume it is an _id and pop off its underscore
# cls.schema(by_alias=True).get("properties").keys()
# return key, typing.get_type_hints(cls)[search_key](value)
return key, value

@classmethod
async def create(cls, data: T) -> ObjectId:
if not data.id:
data.id = ObjectId()
result = await database.database[cls.__collection__].insert_one(
data.dict(by_alias=True)
)
return result.inserted_id

@classmethod
async def read(
cls,
fields: Iterable[str] = tuple(),
order: Optional[Tuple[str, bool]] = None,
offset: int = 0,
limit: int = 100,
filters: Optional[Union[T, dict]] = None,
construct_object: bool = False
cls: Type[T],
fields: Iterable[str] = tuple(),
order: Optional[Tuple[str, bool]] = None,
offset: int = 0,
limit: int = 100,
filters: Optional[Union[T, dict]] = None,
construct_object: bool = False
) -> List[T]:
if filters is None:
filters = {}
elif isinstance(filters, cls):
filters = filters.dict(exclude_unset=True, by_alias=True)
else:
elif isinstance(filters, dict):
filters = dict(starmap(cls.instantiate_obj, filters.items()))

cursor = database.database[cls.__collection__].find(
Expand All @@ -78,10 +70,10 @@ async def aggregate(cls, pipeline, construct_object: bool = False) -> Any:

@classmethod
async def update(
cls,
data: T,
filters: Union[T, dict],
operator: str = "$set"
cls,
data: T,
filters: Union[T, dict],
operator: str = "$set"
) -> int:
if not isinstance(filters, dict):
filters = filters.dict(
Expand Down Expand Up @@ -119,8 +111,7 @@ async def upsert(cls, data: T, filters: Union[T, dict]) -> int:
result = await database.database[cls.__collection__].update_one(
filters,
{"$set": dict_data},
upsert=True,
return_document=True
upsert=True
)
return result.modified_count

Expand Down
10 changes: 7 additions & 3 deletions tests/func/test_aliasing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from bson import ObjectId
from pymongo import errors
import pytest

from tests.schemas import Fruit
Expand All @@ -8,7 +9,8 @@
async def oid():
f = Fruit(name="Banana", taste="sweet")
oid = await Fruit.create(f)
return oid
yield oid
await Fruit.delete(filters={})


class TestRetrieveAndInstantiate:
Expand All @@ -23,11 +25,13 @@ async def test_alias_str(self, oid: ObjectId):

async def test_name_oid(self, oid: ObjectId):
filters = dict(id=oid)
await self.check_response(filters)
with pytest.raises(AssertionError):
await self.check_response(filters)

async def test_name_str(self, oid: ObjectId):
filters = dict(id=str(oid))
await self.check_response(filters)
with pytest.raises(AssertionError):
await self.check_response(filters)

@staticmethod
async def check_response(filters):
Expand Down
1 change: 1 addition & 0 deletions tests/func/test_object_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ async def test_create_obj(db):
oid = await Fruit.create(f)
assert oid is not None
assert isinstance(oid, ObjectId)
await Fruit.delete(filters={})
6 changes: 3 additions & 3 deletions tests/func/test_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,20 @@ async def oids() -> list[ObjectId]: # type: ignore
Fruit(name="Banana", taste="sweet"),
Fruit(name="Strawberry", taste="bitter"),
]
oids = await Fruit.create_many(fruits)
oids = await Fruit.batch_create(fruits)
yield oids
await Fruit.delete(filters={})


class TestUpdateMany:
async def test_update_many(self, oids):
data = dict(density=666)
data = Fruit(density=666)
filters = {"_id": {"$in": oids}}
fruits_updated = await Fruit.batch_update(data=data, filters=filters)
assert fruits_updated == 2

async def test_upsert_single(self, oids):
data = dict(density=666)
data = Fruit(density=666)
for oid in oids:
filters = {"_id": oid}
upserted_count = await Fruit.upsert(data=data, filters=filters)
Expand Down
4 changes: 2 additions & 2 deletions tests/schemas.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Literal

import mongow
import pydantic

import mongow

class Fruit(mongow.BaseMixin):
class Fruit(mongow.DocumentMixin):
density: float = 1.0
name: str
needs_peeling: bool = pydantic.Field(
Expand Down

0 comments on commit e5cbbc1

Please sign in to comment.