Skip to content

Commit

Permalink
Add ModelDump as alternative (#53)
Browse files Browse the repository at this point in the history
  • Loading branch information
tarsil authored Jul 22, 2024
1 parent fca0d63 commit b154e4f
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 44 deletions.
61 changes: 61 additions & 0 deletions mongoz/core/db/documents/_internal.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
from decimal import Decimal
from typing import Any, Dict

import bson
from bson.decimal128 import Decimal128
from pydantic import BaseModel, ConfigDict
from pydantic_core._pydantic_core import SchemaValidator as SchemaValidator

from mongoz.core.signals.signal import Signal


class DescriptiveMeta:
"""
The `Meta` class used to configure each metadata of the model.
Expand All @@ -18,3 +29,53 @@ class Meta:
"""

... # pragma: no cover


class ModelDump(BaseModel):
"""
Definition for a model dump. This is used to generate the model fields and their
respective values.
"""

model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True,
json_encoders={bson.ObjectId: str, Signal: str},
validate_assignment=True,
)

def convert_decimal(self, model_dump_dict: Dict[str, Any]) -> Dict[str, Any]:
"""
Recursively converts Decimal values in the model_dump_dict to Decimal128.
Args:
model_dump_dict (Dict[str, Any]): The dictionary to convert.
Returns:
Dict[str, Any]: The converted dictionary.
"""

if not model_dump_dict:
return model_dump_dict

for key, value in model_dump_dict.items():
if isinstance(value, dict):
self.convert_decimal(value)
elif isinstance(value, list):
for item in value:
if isinstance(item, dict):
self.convert_decimal(item)
elif isinstance(value, Decimal):
model_dump_dict[key] = Decimal128(str(value))
return model_dump_dict

def model_dump(self, show_id: bool = False, **kwargs: Any) -> Dict[str, Any]:
"""
Args:
show_pk: bool - Enforces showing the id in the model_dump.
"""
model = super().model_dump(**kwargs)
if "id" not in model and show_id:
model = {**{"id": self.id}, **model}
model_dump = self.convert_decimal(model)
return model_dump
44 changes: 4 additions & 40 deletions mongoz/core/db/documents/base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import copy
from decimal import Decimal
from functools import cached_property
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Mapping, Type, Union

import bson
import pydantic
from bson.decimal128 import Decimal128

# from bson.decimal128 import Decimal128
from pydantic import BaseModel, ConfigDict
from pydantic_core._pydantic_core import SchemaValidator as SchemaValidator

from mongoz.core.db.documents._internal import DescriptiveMeta
from mongoz.core.db.documents._internal import DescriptiveMeta, ModelDump
from mongoz.core.db.documents.document_proxy import ProxyDocument
from mongoz.core.db.documents.metaclasses import BaseModelMeta, MetaInfo
from mongoz.core.db.fields.base import MongozField
Expand Down Expand Up @@ -154,42 +154,6 @@ def __str__(self) -> str:
return f"{self.__class__.__name__}(id={self.id})"


class MongozBaseModel(BaseMongoz):
class MongozBaseModel(BaseMongoz, ModelDump):
__mongoz_fields__: ClassVar[Mapping[str, Type["MongozField"]]]
id: Union[ObjectId, None] = pydantic.Field(alias="_id")

def convert_decimal(self, model_dump_dict: Dict[str, Any]) -> Dict[str, Any]:
"""
Recursively converts Decimal values in the model_dump_dict to Decimal128.
Args:
model_dump_dict (Dict[str, Any]): The dictionary to convert.
Returns:
Dict[str, Any]: The converted dictionary.
"""

if not model_dump_dict:
return model_dump_dict

for key, value in model_dump_dict.items():
if isinstance(value, dict):
self.convert_decimal(value)
elif isinstance(value, list):
for item in value:
if isinstance(item, dict):
self.convert_decimal(item)
elif isinstance(value, Decimal):
model_dump_dict[key] = Decimal128(str(value))
return model_dump_dict

def model_dump(self, show_id: bool = False, **kwargs: Any) -> Dict[str, Any]:
"""
Args:
show_pk: bool - Enforces showing the id in the model_dump.
"""
model = super().model_dump(**kwargs)
if "id" not in model and show_id:
model = {**{"id": self.id}, **model}
model_dump = self.convert_decimal(model)
return model_dump
4 changes: 2 additions & 2 deletions mongoz/core/db/documents/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pydantic import BaseModel

from mongoz.core.connection.collections import Collection
from mongoz.core.db.documents._internal import ModelDump
from mongoz.core.db.documents.document_row import DocumentRow
from mongoz.core.db.documents.metaclasses import EmbeddedModelMetaClass
from mongoz.core.db.fields.base import MongozField
Expand Down Expand Up @@ -54,11 +55,10 @@ async def update(
for name, annotations in self.__annotations__.items()
if name in kwargs
}

if field_definitions:
pydantic_model: Type[BaseModel] = pydantic.create_model(
self.__class__.__name__,
__config__=self.model_config,
__base__=ModelDump,
**field_definitions,
)
model = pydantic_model.model_validate(kwargs)
Expand Down
4 changes: 3 additions & 1 deletion mongoz/core/db/querysets/core/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,8 @@ async def update_many(self, **kwargs: Any) -> List[T]:
"""
Updates many documents (bulk update)
"""
from mongoz.core.db.documents._internal import ModelDump

manager: "Manager" = self.clone()

field_definitions = {
Expand All @@ -518,7 +520,7 @@ async def update_many(self, **kwargs: Any) -> List[T]:
if field_definitions:
pydantic_model: Type[pydantic.BaseModel] = pydantic.create_model(
manager.model_class.__name__,
__config__=manager.model_class.model_config,
__base__=ModelDump,
**field_definitions,
)
model = pydantic_model.model_validate(kwargs)
Expand Down
4 changes: 3 additions & 1 deletion mongoz/core/db/querysets/core/queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,8 @@ async def bulk_update(self, **kwargs: Any) -> List[T]:
return await self.update_many(**kwargs)

async def update_many(self, **kwargs: Any) -> List[T]:
from mongoz.core.db.documents._internal import ModelDump

field_definitions = {
name: (annotations, ...)
for name, annotations in self.model_class.__annotations__.items()
Expand All @@ -287,7 +289,7 @@ async def update_many(self, **kwargs: Any) -> List[T]:
if field_definitions:
pydantic_model: Type[pydantic.BaseModel] = pydantic.create_model(
self.model_class.__name__,
__config__=self.model_class.model_config,
__base__=ModelDump,
**field_definitions,
)
model = pydantic_model.model_validate(kwargs)
Expand Down
26 changes: 26 additions & 0 deletions tests/models/manager/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pydantic
import pytest
from bson import Decimal128

import mongoz
from mongoz import Document
Expand Down Expand Up @@ -40,3 +41,28 @@ async def test_decimal_128_two() -> None:

arch = await Archive.objects.last()
assert float(str(arch.price)) == 22.246


async def test_decimal_on_update() -> None:
await Archive.objects.create(name="Batman", price="22.246")

arch = await Archive.objects.last()

arch.price = Decimal("28")
await arch.save()

arch = await Archive.objects.last()

assert arch.price == Decimal128("28")

await arch.update(price=Decimal("30"))

arch = await Archive.objects.last()

assert arch.price == Decimal128("30")

await Archive.objects.filter().update(price=Decimal("40"))

arch = await Archive.objects.last()

assert arch.price == Decimal128("40")

0 comments on commit b154e4f

Please sign in to comment.