Skip to content

Commit

Permalink
Support typing_extensions.get_overloads (#589)
Browse files Browse the repository at this point in the history
  • Loading branch information
JelleZijlstra authored Jan 16, 2023
1 parent b689b7a commit 83bb6ce
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 13 deletions.
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Unreleased

- Support `typing_extensions.get_overloads` and `typing.get_overloads` (#589)
- Support `in` on objects with only `__iter__` (#588)
- Do not call `.mro()` method on non-types (#587)
- Add `class_attribute_transformers` hook (#585)
Expand Down
66 changes: 55 additions & 11 deletions pyanalyze/arg_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import textwrap
from dataclasses import dataclass, replace
from types import FunctionType, MethodType, ModuleType
import typing
from typing import (
Any,
Callable,
Expand All @@ -36,7 +37,12 @@
from . import implementation
from .analysis_lib import is_positional_only_arg_name
from .annotations import Context, RuntimeEvaluator, type_from_runtime
from .extensions import CustomCheck, get_overloads, get_type_evaluations, TypeGuard
from .extensions import (
CustomCheck,
get_overloads as pyanalyze_get_overloads,
get_type_evaluations,
TypeGuard,
)
from .find_unused import used
from .functions import translate_vararg_type
from .options import Options, PyObjectSequenceOption
Expand All @@ -47,6 +53,8 @@
is_newtype,
is_typing_name,
safe_equals,
safe_getattr,
safe_hasattr,
safe_isinstance,
safe_issubclass,
)
Expand Down Expand Up @@ -83,6 +91,19 @@
Value,
)

_GET_OVERLOADS = []

try:
from typing_extensions import get_overloads
except ImportError:
pass
else:
_GET_OVERLOADS.append(get_overloads)
if sys.version_info >= (3, 11):
# TODO: support version checks
# static analysis: ignore[undefined_attribute]
_GET_OVERLOADS.append(typing.get_overloads)

# types.MethodWrapperType in 3.7+
MethodWrapperType = type(object().__str__)

Expand Down Expand Up @@ -578,18 +599,23 @@ def _uncached_get_argspec(
# Must be after the check for bound methods, because otherwise we
# won't bind self correctly.
if not in_overload_resolution:
for get_overloads_func in _GET_OVERLOADS:
inner_obj = safe_getattr(obj, "__func__", obj)
if safe_hasattr(inner_obj, "__module__") and safe_hasattr(
inner_obj, "__qualname__"
):
sig = self._maybe_make_overloaded_signature(
get_overloads_func(inner_obj), impl, is_asynq
)
if sig is not None:
return sig
fq_name = get_fully_qualified_name(obj)
if fq_name is not None:
overloads = get_overloads(fq_name)
if overloads:
sigs = [
self._cached_get_argspec(
overload, impl, is_asynq, in_overload_resolution=True
)
for overload in overloads
]
if all_of_type(sigs, Signature):
return OverloadedSignature(sigs)
sig = self._maybe_make_overloaded_signature(
pyanalyze_get_overloads(fq_name), impl, is_asynq
)
if sig is not None:
return sig
evaluator_sig = self._maybe_make_evaluator_sig(obj, impl, is_asynq)
if evaluator_sig is not None:
return evaluator_sig
Expand Down Expand Up @@ -790,6 +816,24 @@ def _uncached_get_argspec(

return None

def _maybe_make_overloaded_signature(
self,
overloads: Sequence[Callable[..., Any]],
impl: Optional[Impl],
is_asynq: bool,
) -> Optional[OverloadedSignature]:
if not overloads:
return None
sigs = [
self._cached_get_argspec(
overload, impl, is_asynq, in_overload_resolution=True
)
for overload in overloads
]
if not all_of_type(sigs, Signature):
return None
return OverloadedSignature(sigs)

def _make_any_sig(self, obj: object) -> Signature:
if FunctionsSafeToCall.contains(obj, self.options):
return Signature.make(
Expand Down
32 changes: 30 additions & 2 deletions pyanalyze/test_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,7 +912,7 @@ def capybara():
overloaded("x", "y") # E: incompatible_call

@assert_passes()
def test_runtime(self):
def test_pyanalyze_extensions(self):
from typing import Union

from pyanalyze.extensions import overload
Expand All @@ -936,7 +936,35 @@ def overloaded(*args: str) -> Union[int, str]:
def capybara():
assert_is_value(overloaded(), TypedValue(int))
assert_is_value(overloaded("x"), TypedValue(str))
overloaded(1) # E: incompatible_call
overloaded(1) # E: incompatible_argument
overloaded("a", "b") # E: incompatible_call

@assert_passes()
def test_typing_extensions(self):
from typing import Union

from typing_extensions import overload

@overload
def overloaded() -> int:
raise NotImplementedError

@overload
def overloaded(x: str) -> str:
raise NotImplementedError

def overloaded(*args: str) -> Union[int, str]:
if not args:
return 0
elif len(args) == 1:
return args[0]
else:
raise TypeError("too many arguments")

def capybara():
assert_is_value(overloaded(), TypedValue(int))
assert_is_value(overloaded("x"), TypedValue(str))
overloaded(1) # E: incompatible_argument
overloaded("a", "b") # E: incompatible_call

@assert_passes()
Expand Down

0 comments on commit 83bb6ce

Please sign in to comment.