Skip to content

Commit

Permalink
Improve in handling (#588)
Browse files Browse the repository at this point in the history
  • Loading branch information
JelleZijlstra authored Jan 16, 2023
1 parent 78ee166 commit b689b7a
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 71 deletions.
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Unreleased

- Support `in` on objects with only `__iter__` (#588)
- Do not call `.mro()` method on non-types (#587)
- Add `class_attribute_transformers` hook (#585)
- Support for PEP 702 (`@typing.deprecated`) (#578)
Expand Down
18 changes: 8 additions & 10 deletions pyanalyze/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from .options import Options, PyObjectSequenceOption
from .signature import ParameterKind, Signature, SigParameter
from .stacked_scopes import Composite
from .typevar import resolve_bounds_map
from .value import (
AnySource,
AnyValue,
Expand All @@ -33,6 +32,7 @@
GenericValue,
get_tv_map,
KnownValue,
is_iterable,
make_coro_type,
SubclassValue,
TypedValue,
Expand All @@ -49,7 +49,6 @@
YieldT = TypeVar("YieldT")
SendT = TypeVar("SendT")
ReturnT = TypeVar("ReturnT")
IterableValue = GenericValue(collections.abc.Iterable, [TypeVarValue(YieldT)])
GeneratorValue = GenericValue(
collections.abc.Generator,
[TypeVarValue(YieldT), TypeVarValue(SendT), TypeVarValue(ReturnT)],
Expand Down Expand Up @@ -93,11 +92,10 @@ class FunctionInfo:
def get_generator_yield_type(self, ctx: CanAssignContext) -> Value:
if self.return_annotation is None:
return AnyValue(AnySource.unannotated)
can_assign = IterableValue.can_assign(self.return_annotation, ctx)
if isinstance(can_assign, CanAssignError):
iterable_val = is_iterable(self.return_annotation, ctx)
if isinstance(iterable_val, CanAssignError):
return AnyValue(AnySource.error)
tv_map, _ = resolve_bounds_map(can_assign, ctx)
return tv_map.get(YieldT, AnyValue(AnySource.generic_argument))
return iterable_val

def get_generator_send_type(self, ctx: CanAssignContext) -> Value:
if self.return_annotation is None:
Expand All @@ -107,8 +105,8 @@ def get_generator_send_type(self, ctx: CanAssignContext) -> Value:
return tv_map.get(SendT, AnyValue(AnySource.generic_argument))
# If the return annotation is a non-Generator Iterable, assume the send
# type is None.
can_assign = IterableValue.can_assign(self.return_annotation, ctx)
if isinstance(can_assign, CanAssignError):
iterable_val = is_iterable(self.return_annotation, ctx)
if isinstance(iterable_val, CanAssignError):
return AnyValue(AnySource.error)
return KnownValue(None)

Expand All @@ -120,8 +118,8 @@ def get_generator_return_type(self, ctx: CanAssignContext) -> Value:
return tv_map.get(ReturnT, AnyValue(AnySource.generic_argument))
# If the return annotation is a non-Generator Iterable, assume the return
# type is None.
can_assign = IterableValue.can_assign(self.return_annotation, ctx)
if isinstance(can_assign, CanAssignError):
iterable_val = is_iterable(self.return_annotation, ctx)
if isinstance(iterable_val, CanAssignError):
return AnyValue(AnySource.error)
return KnownValue(None)

Expand Down
2 changes: 1 addition & 1 deletion pyanalyze/implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def inner(key: Value) -> Value:
return AnyValue(AnySource.error) # shouldn't happen
key = replace_known_sequence_value(key)
if not TypedValue(slice).is_assignable(key, ctx.visitor):
key = ctx.visitor._check_dunder_call(
key, _ = ctx.visitor._check_dunder_call(
ctx.ast_for_arg("obj"), Composite(key), "__index__", [], allow_call=True
)

Expand Down
141 changes: 97 additions & 44 deletions pyanalyze/name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@
FunctionResult,
GeneratorValue,
IMPLICIT_CLASSMETHODS,
IterableValue,
ReturnT,
SendT,
YieldT,
Expand Down Expand Up @@ -173,6 +172,7 @@
GenericBases,
GenericValue,
get_tv_map,
is_iterable,
is_union,
KnownValue,
kv_pairs_from_mapping,
Expand Down Expand Up @@ -3234,7 +3234,7 @@ def visit_UnaryOp(self, node: ast.UnaryOp) -> Value:
else:
operand = self.composite_from_node(node.operand)
_, method = UNARY_OPERATION_TO_DESCRIPTION_AND_METHOD[type(node.op)]
val = self._check_dunder_call(node, operand, method, [], allow_call=True)
val, _ = self._check_dunder_call(node, operand, method, [], allow_call=True)
return val

def visit_BinOp(self, node: ast.BinOp) -> Value:
Expand Down Expand Up @@ -3299,7 +3299,10 @@ def _visit_binop_internal(
if is_inplace:
assert imethod is not None, f"no inplace method available for {op}"
with self.catch_errors() as inplace_errors:
inplace_result = self._check_dunder_call(
# Not _check_dunder_call_or_catch because if the call doesn't
# typecheck it normally returns NotImplemented and we try the
# non-inplace method next.
inplace_result, _ = self._check_dunder_call(
source_node,
left_composite,
imethod,
Expand Down Expand Up @@ -3335,42 +3338,55 @@ def _visit_binop_no_mvv(
type(op)
]
if rmethod is None:
# "in" falls back to __getitem__ if __contains__ is not defined
# "in" falls back to __iter__ and then to __getitem__ if __contains__ is not defined
if method == "__contains__":
with self.catch_errors() as contains_errors:
contains_result = self._check_dunder_call(
source_node,
left_composite,
method,
[right_composite],
allow_call=allow_call,
)
if not contains_errors:
return contains_result

with self.catch_errors() as getitem_errors:
self._check_dunder_call(
source_node,
left_composite,
"__getitem__",
[right_composite],
allow_call=allow_call,
)
if not getitem_errors:
contains_result_or_errors = self._check_dunder_call_or_catch(
source_node,
left_composite,
method,
[right_composite],
allow_call=allow_call,
)
if isinstance(contains_result_or_errors, Value):
return contains_result_or_errors

iterable_type = is_iterable(left, self)
if isinstance(iterable_type, Value):
can_assign = iterable_type.can_assign(right, self)
if isinstance(can_assign, CanAssignError):
self._show_error_if_checking(
source_node,
"Unsupported operand for 'in'",
ErrorCode.incompatible_argument,
detail=str(can_assign),
)
return TypedValue(bool)
else:
return TypedValue(bool)

getitem_result = self._check_dunder_call_or_catch(
source_node,
left_composite,
"__getitem__",
[right_composite],
allow_call=allow_call,
)
if isinstance(getitem_result, Value):
return TypedValue(bool) # Always returns a bool
self.show_caught_errors(contains_errors)
self.show_caught_errors(contains_result_or_errors)
return TypedValue(bool)

return self._check_dunder_call(
result, _ = self._check_dunder_call(
source_node,
left_composite,
method,
[right_composite],
allow_call=allow_call,
)
return result

with self.catch_errors() as left_errors:
left_result = self._check_dunder_call(
left_result, _ = self._check_dunder_call(
source_node,
left_composite,
method,
Expand All @@ -3379,7 +3395,7 @@ def _visit_binop_no_mvv(
)

with self.catch_errors() as right_errors:
right_result = self._check_dunder_call(
right_result, _ = self._check_dunder_call(
source_node,
right_composite,
rmethod,
Expand Down Expand Up @@ -3458,7 +3474,8 @@ def visit_Await(self, node: ast.Await) -> Value:
def unpack_awaitable(self, composite: Composite, node: ast.AST) -> Value:
tv_map = get_tv_map(AwaitableValue, composite.value, self)
if isinstance(tv_map, CanAssignError):
return self._check_dunder_call(node, composite, "__await__", [])
result, _ = self._check_dunder_call(node, composite, "__await__", [])
return result
else:
return tv_map.get(T, AnyValue(AnySource.generic_argument))

Expand All @@ -3473,8 +3490,8 @@ def visit_YieldFrom(self, node: ast.YieldFrom) -> Value:
ReturnT: can_assign.get(T, AnyValue(AnySource.generic_argument))
}
else:
can_assign = get_tv_map(IterableValue, value, self)
if isinstance(can_assign, CanAssignError):
iterable_type = is_iterable(value, self)
if isinstance(iterable_type, CanAssignError):
self._show_error_if_checking(
node,
f"Cannot use {value} in yield from",
Expand All @@ -3483,9 +3500,7 @@ def visit_YieldFrom(self, node: ast.YieldFrom) -> Value:
)
tv_map = {ReturnT: AnyValue(AnySource.error)}
else:
tv_map = {
YieldT: can_assign.get(T, AnyValue(AnySource.generic_argument))
}
tv_map = {YieldT: iterable_type}

if self.current_function_info is not None:
expected_yield = self.current_function_info.get_generator_yield_type(self)
Expand Down Expand Up @@ -3838,8 +3853,8 @@ def _member_value_of_iterator(
"""
composite = self.composite_from_node(node)
if is_async:
iterator = self._check_dunder_call(node, composite, "__aiter__", [])
anext = self._check_dunder_call(
iterator, _ = self._check_dunder_call(node, composite, "__aiter__", [])
anext, _ = self._check_dunder_call(
node, Composite(iterator, None, node), "__anext__", []
)
return self.unpack_awaitable(Composite(anext), node)
Expand Down Expand Up @@ -4468,9 +4483,10 @@ def _composite_from_subscript_no_mvv(
return_value = local_value
return return_value
elif isinstance(node.ctx, ast.Del):
return self._check_dunder_call(
result, _ = self._check_dunder_call(
node.value, root_composite, "__delitem__", [index_composite]
)
return result
else:
self.show_error(
node,
Expand All @@ -4495,28 +4511,65 @@ def _get_dunder(self, node: ast.AST, callee_val: Value, method_name: str) -> Val
)
return method_object

def _check_dunder_call_or_catch(
self,
node: ast.AST,
callee_composite: Composite,
method_name: str,
args: Iterable[Composite],
allow_call: bool = False,
) -> Union[Value, List[node_visitor.Error]]:
"""Use this for checking a dunder call that may fall back to another.
There are three cases:
- The dunder does not exist. We want to defer the error, in case the fallback
exists.
- The dunder exists and the call typechecks. We want to return its result.
- The dunder exists, but the call doesn't typecheck. We want to show the error
immediately and return Any.
"""
with self.catch_errors() as errors:
result, exists = self._check_dunder_call(
node, callee_composite, method_name, args, allow_call=allow_call
)
if not errors:
return result
elif exists:
# Inplace method exists, but it doesn't accept these arguments
self.show_caught_errors(errors)
return result
else:
return errors

def _check_dunder_call(
self,
node: ast.AST,
callee_composite: Composite,
method_name: str,
args: Iterable[Composite],
allow_call: bool = False,
) -> Value:
) -> Tuple[Value, bool]:
if isinstance(callee_composite.value, MultiValuedValue):
composites = [
Composite(val, callee_composite.varname, callee_composite.node)
for val in callee_composite.value.vals
]
with qcore.override(self, "in_union_decomposition", True):
values = [
values_and_exists = [
self._check_dunder_call_no_mvv(
node, composite, method_name, args, allow_call
)
for composite in composites
]
return unite_and_simplify(
*values, limit=self.options.get_value_for(UnionSimplificationLimit)
values = [value for value, _ in values_and_exists]
# TODO: We should do something more complex when unions are involved.
exists = all(exists for _, exists in values_and_exists)
return (
unite_and_simplify(
*values, limit=self.options.get_value_for(UnionSimplificationLimit)
),
exists,
)
return self._check_dunder_call_no_mvv(
node, callee_composite, method_name, args, allow_call
Expand All @@ -4529,14 +4582,14 @@ def _check_dunder_call_no_mvv(
method_name: str,
args: Iterable[Composite],
allow_call: bool = False,
) -> Value:
) -> Tuple[Value, bool]:
method_object = self._get_dunder(node, callee_composite.value, method_name)
if method_object is UNINITIALIZED_VALUE:
return AnyValue(AnySource.error)
return AnyValue(AnySource.error), False
return_value = self.check_call(
node, method_object, [callee_composite, *args], allow_call=allow_call
)
return return_value
return return_value, True

def _get_composite(self, composite: Varname, node: ast.AST, value: Value) -> Value:
local_value, _ = self.scopes.current_scope().get_local(
Expand Down
6 changes: 4 additions & 2 deletions pyanalyze/node_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@

from . import analysis_lib

Error = Dict[str, Any]


@dataclass(frozen=True)
class _FakeNode:
Expand Down Expand Up @@ -541,12 +543,12 @@ def _get_default_settings(cls) -> Optional[Dict[Enum, bool]]:
}

@contextmanager
def catch_errors(self) -> Iterator[List[Dict[str, Any]]]:
def catch_errors(self) -> Iterator[List[Error]]:
caught_errors = []
with qcore.override(self, "caught_errors", caught_errors):
yield caught_errors

def show_caught_errors(self, errors: Iterable[Dict[str, Any]]) -> None:
def show_caught_errors(self, errors: Iterable[Error]) -> None:
for error in errors:
self.show_error(**error)

Expand Down
Loading

0 comments on commit b689b7a

Please sign in to comment.