Skip to content

Commit

Permalink
Improve debugability of deepcopy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
ZipFile committed Dec 14, 2024
1 parent aa56b70 commit 9950093
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 12 deletions.
21 changes: 21 additions & 0 deletions src/dependency_injector/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,24 @@ class Error(Exception):

class NoSuchProviderError(Error, AttributeError):
"""Error that is raised when provider lookup is failed."""


class NonCopyableArgumentError(Error):
"""Error that is raised when provider argument is not deep-copyable."""

index: int
keyword: str
provider: object

def __init__(self, provider: object, index: int = -1, keyword: str = "") -> None:
self.provider = provider
self.index = index
self.keyword = keyword

def __str__(self) -> str:
s = (
f"keyword argument {self.keyword}"
if self.keyword
else f"argument at index {self.index}"
)
return f"Couldn't copy {s} for provider {self.provider!r}"
16 changes: 15 additions & 1 deletion src/dependency_injector/providers.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,21 @@ def is_delegated(instance: Any) -> bool: ...
def represent_provider(provider: Provider, provides: Any) -> str: ...


def deepcopy(instance: Any, memo: Optional[_Dict[Any, Any]] = None): Any: ...
def deepcopy(instance: Any, memo: Optional[_Dict[Any, Any]] = None) -> Any: ...


def deepcopy_args(
provider: Provider[Any],
args: Tuple[Any, ...],
memo: Optional[_Dict[int, Any]] = None,
) -> Tuple[Any, ...]: ...


def deepcopy_kwargs(
provider: Provider[Any],
kwargs: _Dict[str, Any],
memo: Optional[_Dict[int, Any]] = None,
) -> Dict[str, Any]: ...


def merge_dicts(dict1: _Dict[Any, Any], dict2: _Dict[Any, Any]) -> _Dict[Any, Any]: ...
Expand Down
65 changes: 54 additions & 11 deletions src/dependency_injector/providers.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ except ImportError:
from .errors import (
Error,
NoSuchProviderError,
NonCopyableArgumentError,
)

cimport cython
Expand Down Expand Up @@ -1252,8 +1253,8 @@ cdef class Callable(Provider):

copied = _memorized_duplicate(self, memo)
copied.set_provides(_copy_if_provider(self.provides, memo))
copied.set_args(*deepcopy(self.args, memo))
copied.set_kwargs(**deepcopy(self.kwargs, memo))
copied.set_args(*deepcopy_args(self, self.args, memo))
copied.set_kwargs(**deepcopy_kwargs(self, self.kwargs, memo))
self._copy_overridings(copied, memo)
return copied

Expand Down Expand Up @@ -2539,8 +2540,8 @@ cdef class Factory(Provider):

copied = _memorized_duplicate(self, memo)
copied.set_provides(_copy_if_provider(self.provides, memo))
copied.set_args(*deepcopy(self.args, memo))
copied.set_kwargs(**deepcopy(self.kwargs, memo))
copied.set_args(*deepcopy_args(self, self.args, memo))
copied.set_kwargs(**deepcopy_kwargs(self, self.kwargs, memo))
copied.set_attributes(**deepcopy(self.attributes, memo))
self._copy_overridings(copied, memo)
return copied
Expand Down Expand Up @@ -2838,8 +2839,8 @@ cdef class BaseSingleton(Provider):

copied = _memorized_duplicate(self, memo)
copied.set_provides(_copy_if_provider(self.provides, memo))
copied.set_args(*deepcopy(self.args, memo))
copied.set_kwargs(**deepcopy(self.kwargs, memo))
copied.set_args(*deepcopy_args(self, self.args, memo))
copied.set_kwargs(**deepcopy_kwargs(self, self.kwargs, memo))
copied.set_attributes(**deepcopy(self.attributes, memo))
self._copy_overridings(copied, memo)
return copied
Expand Down Expand Up @@ -3451,7 +3452,7 @@ cdef class List(Provider):
return copied

copied = _memorized_duplicate(self, memo)
copied.set_args(*deepcopy(self.args, memo))
copied.set_args(*deepcopy_args(self, self.args, memo))
self._copy_overridings(copied, memo)
return copied

Expand Down Expand Up @@ -3674,8 +3675,8 @@ cdef class Resource(Provider):

copied = _memorized_duplicate(self, memo)
copied.set_provides(_copy_if_provider(self.provides, memo))
copied.set_args(*deepcopy(self.args, memo))
copied.set_kwargs(**deepcopy(self.kwargs, memo))
copied.set_args(*deepcopy_args(self, self.args, memo))
copied.set_kwargs(**deepcopy_kwargs(self, self.kwargs, memo))

self._copy_overridings(copied, memo)

Expand Down Expand Up @@ -4525,8 +4526,8 @@ cdef class MethodCaller(Provider):

copied = _memorized_duplicate(self, memo)
copied.set_provides(_copy_if_provider(self.provides, memo))
copied.set_args(*deepcopy(self.args, memo))
copied.set_kwargs(**deepcopy(self.kwargs, memo))
copied.set_args(*deepcopy_args(self, self.args, memo))
copied.set_kwargs(**deepcopy_kwargs(self, self.kwargs, memo))
self._copy_overridings(copied, memo)
return copied

Expand Down Expand Up @@ -4927,6 +4928,48 @@ cpdef object deepcopy(object instance, dict memo=None):
return copy.deepcopy(instance, memo)


cpdef tuple deepcopy_args(
Provider provider,
tuple args,
dict[int, object] memo = None,
):
"""A wrapper for deepcopy for positional arguments.
Used to improve debugability of objects that cannot be deep-copied.
"""

cdef list[object] out = []

for i, arg in enumerate(args):
try:
out.append(copy.deepcopy(arg, memo))
except Exception as e:
raise NonCopyableArgumentError(provider, index=i) from e

return tuple(out)


cpdef dict[str, object] deepcopy_kwargs(
Provider provider,
dict[str, object] kwargs,
dict[int, object] memo = None,
):
"""A wrapper for deepcopy for keyword arguments.
Used to improve debugability of objects that cannot be deep-copied.
"""

cdef dict[str, object] out = {}

for name, arg in kwargs.items():
try:
out[name] = copy.deepcopy(arg, memo)
except Exception as e:
raise NonCopyableArgumentError(provider, keyword=name) from e

return out


def __add_sys_streams(memo):
"""Add system streams to memo dictionary.
Expand Down
65 changes: 65 additions & 0 deletions tests/unit/providers/utils/test_deepcopy_py3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import sys
from typing import Any, Dict, NoReturn

from pytest import raises

from dependency_injector.errors import NonCopyableArgumentError
from dependency_injector.providers import (
Provider,
deepcopy,
deepcopy_args,
deepcopy_kwargs,
)


class NonCopiable:
def __deepcopy__(self, memo: Dict[int, Any]) -> NoReturn:
raise NotImplementedError


def test_deepcopy_streams_not_copied() -> None:
l = [sys.stdin, sys.stdout, sys.stderr]
assert deepcopy(l) == l


def test_deepcopy_args() -> None:
provider = Provider[None]()
copiable = NonCopiable()
memo: Dict[int, Any] = {id(copiable): copiable}

assert deepcopy_args(provider, (1, copiable), memo) == (1, copiable)


def test_deepcopy_args_non_copiable() -> None:
provider = Provider[None]()
copiable = NonCopiable()
memo: Dict[int, Any] = {id(copiable): copiable}

with raises(
NonCopyableArgumentError,
match=r"^Couldn't copy argument at index 3 for provider ",
):
deepcopy_args(provider, (1, copiable, object(), NonCopiable()), memo)


def test_deepcopy_kwargs() -> None:
provider = Provider[None]()
copiable = NonCopiable()
memo: Dict[int, Any] = {id(copiable): copiable}

assert deepcopy_kwargs(provider, {"x": 1, "y": copiable}, memo) == {
"x": 1,
"y": copiable,
}


def test_deepcopy_kwargs_non_copiable() -> None:
provider = Provider[None]()
copiable = NonCopiable()
memo: Dict[int, Any] = {id(copiable): copiable}

with raises(
NonCopyableArgumentError,
match=r"^Couldn't copy keyword argument z for provider ",
):
deepcopy_kwargs(provider, {"x": 1, "y": copiable, "z": NonCopiable()}, memo)

0 comments on commit 9950093

Please sign in to comment.