Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make allow_call work with Annotated literals #540

Merged
merged 3 commits into from
Aug 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

## Unreleased

- `allow_call` callables are now also called if the arguments
are literals wrapped in `Annotated` (#540)
- Support Python 3.11 (#537)
- Fix type checking of binary operators involving unions (#531)
- Improve `TypeVar` solution heuristic for constrained
Expand Down
49 changes: 30 additions & 19 deletions pyanalyze/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from typing_extensions import assert_never, Literal, Protocol, Self

from .error_code import ErrorCode
from .safe import all_of_type
from .stacked_scopes import (
AbstractConstraint,
AndConstraint,
Expand Down Expand Up @@ -1254,41 +1253,45 @@ def _maybe_perform_call(
args = []
kwargs = {}
for definitely_present, composite in actual_args.positionals:
if definitely_present and isinstance(composite.value, KnownValue):
args.append(composite.value.val)
else:
if not definitely_present:
return None
arg = _extract_known_value(composite.value)
if arg is None:
return None
args.append(arg.val)
if actual_args.star_args is not None:
values = concrete_values_from_iterable(
actual_args.star_args, ctx.can_assign_ctx
)
if isinstance(values, collections.abc.Sequence) and all_of_type(
values, KnownValue
):
args += [val.val for val in values]
else:
if not isinstance(values, collections.abc.Sequence):
return None
for args_val in values:
arg = _extract_known_value(args_val)
if arg is None:
return None
args.append(arg.val)
for kwarg, (required, composite) in actual_args.keywords.items():
if not required:
return None
if isinstance(composite.value, KnownValue):
kwargs[kwarg] = composite.value.val
else:
kwarg_value = _extract_known_value(composite.value)
if kwarg_value is None:
return None
kwargs[kwarg] = kwarg_value.val
if actual_args.star_kwargs is not None:
value = replace_known_sequence_value(actual_args.star_kwargs)
if isinstance(value, DictIncompleteValue):
for pair in value.kv_pairs:
if pair.is_many or not pair.is_required:
return None
key_val = _extract_known_value(pair.key)
value_val = _extract_known_value(pair.value)
if (
pair.is_required
and not pair.is_many
and isinstance(pair.key, KnownValue)
and isinstance(pair.key.val, str)
and isinstance(pair.value, KnownValue)
key_val is None
or value_val is None
or not isinstance(key_val.val, str)
):
kwargs[pair.key.val] = pair.value.val
else:
return None
kwargs[key_val.val] = value_val.val
else:
return None

Expand Down Expand Up @@ -2518,3 +2521,11 @@ def decompose_union(
), f"all union members matched between {expected_type} and {parent_value}"
return bounds_map, union_used_any, unite_values(*remaining_values)
return None


def _extract_known_value(val: Value) -> Optional[KnownValue]:
if isinstance(val, AnnotatedValue):
val = val.value
if isinstance(val, KnownValue):
return val
return None
11 changes: 11 additions & 0 deletions pyanalyze/test_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,17 @@ def capybara():

s.encode("not an encoding") # E: incompatible_call

@assert_passes()
def test_annotated_known(self):
from typing_extensions import Literal, Annotated
from pyanalyze.extensions import LiteralOnly

def capybara():
encoding: Annotated[Literal["ascii"], LiteralOnly()] = "ascii"

s = "x"
assert_is_value(s.encode(encoding), KnownValue(b"x"))


class TestOverload(TestNameCheckVisitorBase):
@assert_passes()
Expand Down