From 45e1bf7a83686a5b933eb009447e89e5d1c41ca9 Mon Sep 17 00:00:00 2001 From: Valentin Stanciu <250871+svalentin@users.noreply.github.com> Date: Fri, 14 Jul 2023 14:34:10 +0300 Subject: [PATCH] Typeshed cherry-pick: Fix @patch when `new` is missing (#10459) (#15673) For release-1.5 --- mypy/typeshed/stdlib/unittest/mock.pyi | 27 ++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/mypy/typeshed/stdlib/unittest/mock.pyi b/mypy/typeshed/stdlib/unittest/mock.pyi index 0ed0701cc450..db1cc7d9bfc9 100644 --- a/mypy/typeshed/stdlib/unittest/mock.pyi +++ b/mypy/typeshed/stdlib/unittest/mock.pyi @@ -234,6 +234,8 @@ class _patch(Generic[_T]): def copy(self) -> _patch[_T]: ... @overload def __call__(self, func: _TT) -> _TT: ... + # If new==DEFAULT, this should add a MagicMock parameter to the function + # arguments. See the _patch_default_new class below for this functionality. @overload def __call__(self, func: Callable[_P, _R]) -> Callable[_P, _R]: ... if sys.version_info >= (3, 8): @@ -257,6 +259,22 @@ class _patch(Generic[_T]): def start(self) -> _T: ... def stop(self) -> None: ... +if sys.version_info >= (3, 8): + _Mock: TypeAlias = MagicMock | AsyncMock +else: + _Mock: TypeAlias = MagicMock + +# This class does not exist at runtime, it's a hack to make this work: +# @patch("foo") +# def bar(..., mock: MagicMock) -> None: ... +class _patch_default_new(_patch[_Mock]): + @overload + def __call__(self, func: _TT) -> _TT: ... + # Can't use the following as ParamSpec is only allowed as last parameter: + # def __call__(self, func: Callable[_P, _R]) -> Callable[Concatenate[_P, MagicMock], _R]: ... + @overload + def __call__(self, func: Callable[..., _R]) -> Callable[..., _R]: ... + class _patch_dict: in_dict: Any values: Any @@ -273,11 +291,8 @@ class _patch_dict: start: Any stop: Any -if sys.version_info >= (3, 8): - _Mock: TypeAlias = MagicMock | AsyncMock -else: - _Mock: TypeAlias = MagicMock - +# This class does not exist at runtime, it's a hack to add methods to the +# patch() function. class _patcher: TEST_PREFIX: str dict: type[_patch_dict] @@ -307,7 +322,7 @@ class _patcher: autospec: Any | None = ..., new_callable: Any | None = ..., **kwargs: Any, - ) -> _patch[_Mock]: ... + ) -> _patch_default_new: ... @overload @staticmethod def object( # type: ignore[misc]