From 83bb6ce0a23409cfb4f740159d9b463aa2da1fc7 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Mon, 16 Jan 2023 12:41:48 -0800 Subject: [PATCH] Support typing_extensions.get_overloads (#589) --- docs/changelog.md | 1 + pyanalyze/arg_spec.py | 66 ++++++++++++++++++++++++++++++------- pyanalyze/test_signature.py | 32 ++++++++++++++++-- 3 files changed, 86 insertions(+), 13 deletions(-) diff --git a/docs/changelog.md b/docs/changelog.md index 7a181852..60b9e9a4 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,7 @@ ## Unreleased +- Support `typing_extensions.get_overloads` and `typing.get_overloads` (#589) - Support `in` on objects with only `__iter__` (#588) - Do not call `.mro()` method on non-types (#587) - Add `class_attribute_transformers` hook (#585) diff --git a/pyanalyze/arg_spec.py b/pyanalyze/arg_spec.py index b0fab9c5..f48401be 100644 --- a/pyanalyze/arg_spec.py +++ b/pyanalyze/arg_spec.py @@ -13,6 +13,7 @@ import textwrap from dataclasses import dataclass, replace from types import FunctionType, MethodType, ModuleType +import typing from typing import ( Any, Callable, @@ -36,7 +37,12 @@ from . import implementation from .analysis_lib import is_positional_only_arg_name from .annotations import Context, RuntimeEvaluator, type_from_runtime -from .extensions import CustomCheck, get_overloads, get_type_evaluations, TypeGuard +from .extensions import ( + CustomCheck, + get_overloads as pyanalyze_get_overloads, + get_type_evaluations, + TypeGuard, +) from .find_unused import used from .functions import translate_vararg_type from .options import Options, PyObjectSequenceOption @@ -47,6 +53,8 @@ is_newtype, is_typing_name, safe_equals, + safe_getattr, + safe_hasattr, safe_isinstance, safe_issubclass, ) @@ -83,6 +91,19 @@ Value, ) +_GET_OVERLOADS = [] + +try: + from typing_extensions import get_overloads +except ImportError: + pass +else: + _GET_OVERLOADS.append(get_overloads) +if sys.version_info >= (3, 11): + # TODO: support version checks + # static analysis: ignore[undefined_attribute] + _GET_OVERLOADS.append(typing.get_overloads) + # types.MethodWrapperType in 3.7+ MethodWrapperType = type(object().__str__) @@ -578,18 +599,23 @@ def _uncached_get_argspec( # Must be after the check for bound methods, because otherwise we # won't bind self correctly. if not in_overload_resolution: + for get_overloads_func in _GET_OVERLOADS: + inner_obj = safe_getattr(obj, "__func__", obj) + if safe_hasattr(inner_obj, "__module__") and safe_hasattr( + inner_obj, "__qualname__" + ): + sig = self._maybe_make_overloaded_signature( + get_overloads_func(inner_obj), impl, is_asynq + ) + if sig is not None: + return sig fq_name = get_fully_qualified_name(obj) if fq_name is not None: - overloads = get_overloads(fq_name) - if overloads: - sigs = [ - self._cached_get_argspec( - overload, impl, is_asynq, in_overload_resolution=True - ) - for overload in overloads - ] - if all_of_type(sigs, Signature): - return OverloadedSignature(sigs) + sig = self._maybe_make_overloaded_signature( + pyanalyze_get_overloads(fq_name), impl, is_asynq + ) + if sig is not None: + return sig evaluator_sig = self._maybe_make_evaluator_sig(obj, impl, is_asynq) if evaluator_sig is not None: return evaluator_sig @@ -790,6 +816,24 @@ def _uncached_get_argspec( return None + def _maybe_make_overloaded_signature( + self, + overloads: Sequence[Callable[..., Any]], + impl: Optional[Impl], + is_asynq: bool, + ) -> Optional[OverloadedSignature]: + if not overloads: + return None + sigs = [ + self._cached_get_argspec( + overload, impl, is_asynq, in_overload_resolution=True + ) + for overload in overloads + ] + if not all_of_type(sigs, Signature): + return None + return OverloadedSignature(sigs) + def _make_any_sig(self, obj: object) -> Signature: if FunctionsSafeToCall.contains(obj, self.options): return Signature.make( diff --git a/pyanalyze/test_signature.py b/pyanalyze/test_signature.py index c2c66afa..aa515d4d 100644 --- a/pyanalyze/test_signature.py +++ b/pyanalyze/test_signature.py @@ -912,7 +912,7 @@ def capybara(): overloaded("x", "y") # E: incompatible_call @assert_passes() - def test_runtime(self): + def test_pyanalyze_extensions(self): from typing import Union from pyanalyze.extensions import overload @@ -936,7 +936,35 @@ def overloaded(*args: str) -> Union[int, str]: def capybara(): assert_is_value(overloaded(), TypedValue(int)) assert_is_value(overloaded("x"), TypedValue(str)) - overloaded(1) # E: incompatible_call + overloaded(1) # E: incompatible_argument + overloaded("a", "b") # E: incompatible_call + + @assert_passes() + def test_typing_extensions(self): + from typing import Union + + from typing_extensions import overload + + @overload + def overloaded() -> int: + raise NotImplementedError + + @overload + def overloaded(x: str) -> str: + raise NotImplementedError + + def overloaded(*args: str) -> Union[int, str]: + if not args: + return 0 + elif len(args) == 1: + return args[0] + else: + raise TypeError("too many arguments") + + def capybara(): + assert_is_value(overloaded(), TypedValue(int)) + assert_is_value(overloaded("x"), TypedValue(str)) + overloaded(1) # E: incompatible_argument overloaded("a", "b") # E: incompatible_call @assert_passes()