diff --git a/docs/changelog.md b/docs/changelog.md index 49c3d187..58b9c749 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,7 @@ ## Unreleased +- Correctly check the `self` argument to `@property` getters (#506) - Correctly track assignments of variables inside `try` blocks and inside `with` blocks that may suppress exceptions (#504) - Support mappings that do not inherit from `collections.abc.Mapping` diff --git a/pyanalyze/arg_spec.py b/pyanalyze/arg_spec.py index 13227590..968c3af0 100644 --- a/pyanalyze/arg_spec.py +++ b/pyanalyze/arg_spec.py @@ -28,7 +28,6 @@ Impl, MaybeSignature, OverloadedSignature, - PropertyArgSpec, make_bound_method, SigParameter, Signature, @@ -792,16 +791,6 @@ def _uncached_get_argspec( # these with inspect, so just give up. return self._make_any_sig(obj) - if isinstance(obj, property): - # If we know the getter, inherit its return value. - if obj.fget: - fget_argspec = self._cached_get_argspec( - obj.fget, impl, is_asynq, in_overload_resolution - ) - if fget_argspec is not None and fget_argspec.has_return_value(): - return PropertyArgSpec(obj, return_value=fget_argspec.return_value) - return PropertyArgSpec(obj) - return None def _make_any_sig(self, obj: object) -> Signature: diff --git a/pyanalyze/attributes.py b/pyanalyze/attributes.py index e0de731a..f32e9d07 100644 --- a/pyanalyze/attributes.py +++ b/pyanalyze/attributes.py @@ -64,7 +64,7 @@ def record_usage(self, obj: Any, val: Value) -> None: def record_attr_read(self, obj: Any) -> None: pass - def get_property_type_from_argspec(self, obj: Any) -> Value: + def get_property_type_from_argspec(self, obj: property) -> Value: return AnyValue(AnySource.inference) def get_attribute_from_typeshed(self, typ: type, *, on_class: bool) -> Value: diff --git a/pyanalyze/name_check_visitor.py b/pyanalyze/name_check_visitor.py index 2b83a579..37cdcf79 100644 --- a/pyanalyze/name_check_visitor.py +++ b/pyanalyze/name_check_visitor.py @@ -173,6 +173,7 @@ is_union, kv_pairs_from_mapping, make_weak, + set_self, unannotate_value, unite_and_simplify, unite_values, @@ -324,6 +325,7 @@ class _AttrContext(CheckerAttrContext): visitor: "NameCheckVisitor" node: Optional[ast.AST] ignore_none: bool = False + record_reads: bool = True # Needs to be implemented explicitly to work around Cython limitations def __init__( @@ -332,11 +334,12 @@ def __init__( attr: str, visitor: "NameCheckVisitor", *, - node: Optional[ast.AST] = None, + node: Optional[ast.AST], ignore_none: bool = False, skip_mro: bool = False, skip_unwrap: bool = False, prefer_typeshed: bool = False, + record_reads: bool = True, ) -> None: super().__init__( root_composite, @@ -350,25 +353,21 @@ def __init__( self.node = node self.visitor = visitor self.ignore_none = ignore_none + self.record_reads = record_reads def record_usage(self, obj: object, val: Value) -> None: self.visitor._maybe_record_usage(obj, self.attr, val) def record_attr_read(self, obj: type) -> None: - if self.node is not None: + if self.record_reads and self.node is not None: self.visitor._record_type_attr_read(obj, self.attr, self.node) - def get_property_type_from_argspec(self, obj: object) -> Value: - argspec = self.visitor.arg_spec_cache.get_argspec(obj) - if argspec is not None: - if argspec.has_return_value(): - return argspec.return_value - # If we visited the property and inferred a return value, - # use it. - local = self.visitor.get_local_return_value(argspec) - if local is not None: - return local - return AnyValue(AnySource.inference) + def get_property_type_from_argspec(self, obj: property) -> Value: + if obj.fget is None: + return UNINITIALIZED_VALUE + + getter = set_self(KnownValue(obj.fget), self.root_composite.value) + return self.visitor.check_call(self.node, getter, [self.root_composite]) def should_ignore_none_attributes(self) -> bool: return self.ignore_none @@ -1372,8 +1371,10 @@ def _check_for_incompatible_overrides( Composite(base_class_value), varname, self, + node=node, skip_mro=True, skip_unwrap=True, + record_reads=False, ) base_value = attributes.get_attribute(ctx) can_assign = self._can_assign_to_base(base_value, value) @@ -1723,6 +1724,12 @@ def _set_argspec_to_retval( if isinstance(info.node, ast.AsyncFunctionDef) or info.is_decorated_coroutine: return_value = GenericValue(collections.abc.Awaitable, [return_value]) + if isinstance(val, KnownValue) and isinstance(val.val, property): + fget = val.val.fget + if fget is None: + return + val = KnownValue(fget) + sig = self.signature_from_value(val) if sig is None or sig.has_return_value(): return @@ -4407,7 +4414,7 @@ def _can_perform_call( def check_call( self, - node: ast.AST, + node: Optional[ast.AST], callee: Value, args: Iterable[Composite], keywords: Iterable[Tuple[Optional[str], Composite]] = (), @@ -4430,7 +4437,7 @@ def check_call( def _check_call_no_mvv( self, - node: ast.AST, + node: Optional[ast.AST], callee_wrapped: Value, args: Iterable[Composite], keywords: Iterable[Tuple[Optional[str], Composite]] = (), @@ -4451,11 +4458,12 @@ def _check_call_no_mvv( return_value = AnyValue(AnySource.from_another) elif extended_argspec is None: - self._show_error_if_checking( - node, - f"{callee_wrapped} is not callable", - error_code=ErrorCode.not_callable, - ) + if node is not None: + self._show_error_if_checking( + node, + f"{callee_wrapped} is not callable", + error_code=ErrorCode.not_callable, + ) return_value = AnyValue(AnySource.error) else: diff --git a/pyanalyze/signature.py b/pyanalyze/signature.py index 20913067..22b94dce 100644 --- a/pyanalyze/signature.py +++ b/pyanalyze/signature.py @@ -94,7 +94,7 @@ Tuple, TYPE_CHECKING, ) -from typing_extensions import Literal, Protocol, Self +from typing_extensions import Literal, Protocol, Self, assert_never if TYPE_CHECKING: from .name_check_visitor import NameCheckVisitor @@ -183,7 +183,7 @@ def on_error( @dataclass class _VisitorBasedContext: visitor: "NameCheckVisitor" - node: ast.AST + node: Optional[ast.AST] @property def can_assign_ctx(self) -> CanAssignContext: @@ -197,7 +197,11 @@ def on_error( node: Optional[ast.AST] = None, detail: Optional[str] = ..., ) -> None: - self.visitor.show_error(node or self.node, message, code, detail=detail) + if node is None: + node = self.node + if node is None: + return + self.visitor.show_error(node, message, code, detail=detail) @dataclass @@ -278,7 +282,7 @@ class CallContext: """Using the visitor can allow various kinds of advanced logic in impl functions.""" composites: Dict[str, Composite] - node: ast.AST + node: Optional[ast.AST] """AST node corresponding to the function call. Useful for showing errors.""" @@ -1040,7 +1044,10 @@ def get_default_return(self, source: AnySource = AnySource.error) -> CallReturn: return CallReturn(return_value, is_error=True) def check_call( - self, args: Iterable[Argument], visitor: "NameCheckVisitor", node: ast.AST + self, + args: Iterable[Argument], + visitor: "NameCheckVisitor", + node: Optional[ast.AST], ) -> Value: """Type check a call to this Signature with the given arguments. @@ -1586,9 +1593,15 @@ def substitute_typevars(self, typevars: TypeVarMap) -> "Signature": params.append((name, param)) else: params.append((name, param.substitute_typevars(typevars))) + params_dict = dict(params) + return_value = self.return_value.substitute_typevars(typevars) + # Returning the same object helps the local return value check, which relies + # on identity of signature objects. + if return_value == self.return_value and params_dict == self.parameters: + return self return Signature( - dict(params), - self.return_value.substitute_typevars(typevars), + params_dict, + return_value, impl=self.impl, callable=self.callable, is_asynq=self.is_asynq, @@ -1998,7 +2011,10 @@ def __init__(self, sigs: Sequence[Signature]) -> None: object.__setattr__(self, "signatures", tuple(sigs)) def check_call( - self, args: Iterable[Argument], visitor: "NameCheckVisitor", node: ast.AST + self, + args: Iterable[Argument], + visitor: "NameCheckVisitor", + node: Optional[ast.AST], ) -> Value: """Check a call to an overloaded function. @@ -2179,9 +2195,10 @@ def _make_detail( return CanAssignError(children=details) def substitute_typevars(self, typevars: TypeVarMap) -> "OverloadedSignature": - return OverloadedSignature( - [sig.substitute_typevars(typevars) for sig in self.signatures] - ) + new_sigs = [sig.substitute_typevars(typevars) for sig in self.signatures] + if all(sig1 is sig2 for sig1, sig2 in zip(self.signatures, new_sigs)): + return self + return OverloadedSignature(new_sigs) def bind_self( self, @@ -2255,7 +2272,10 @@ class BoundMethodSignature: return_override: Optional[Value] = None def check_call( - self, args: Iterable[Argument], visitor: "NameCheckVisitor", node: ast.AST + self, + args: Iterable[Argument], + visitor: "NameCheckVisitor", + node: Optional[ast.AST], ) -> Value: ret = self.signature.check_call( [(self.self_composite, None), *args], visitor, node @@ -2299,30 +2319,7 @@ def __str__(self) -> str: return f"{self.signature} bound to {self.self_composite.value}" -@dataclass(frozen=True) -class PropertyArgSpec: - """Pseudo-argspec for properties.""" - - obj: object - return_value: Value = AnyValue(AnySource.unannotated) - - def check_call( - self, args: Iterable[Argument], visitor: "NameCheckVisitor", node: ast.AST - ) -> Value: - raise TypeError("property object is not callable") - - def has_return_value(self) -> bool: - return not isinstance(self.return_value, AnyValue) - - def substitute_typevars(self, typevars: TypeVarMap) -> "PropertyArgSpec": - return PropertyArgSpec( - self.obj, self.return_value.substitute_typevars(typevars) - ) - - -MaybeSignature = Union[ - None, Signature, BoundMethodSignature, PropertyArgSpec, OverloadedSignature -] +MaybeSignature = Union[None, Signature, BoundMethodSignature, OverloadedSignature] def make_bound_method( @@ -2339,7 +2336,7 @@ def make_bound_method( return_override = argspec.return_override return BoundMethodSignature(argspec.signature, self_composite, return_override) else: - assert False, f"invalid argspec {argspec}" + assert_never(argspec) T = TypeVar("T") diff --git a/pyanalyze/test_signature.py b/pyanalyze/test_signature.py index 137c43da..0260e633 100644 --- a/pyanalyze/test_signature.py +++ b/pyanalyze/test_signature.py @@ -404,6 +404,16 @@ def test_property(self): def capybara(uid): assert_is_value(PropertyObject(uid).string_property, TypedValue(str)) + @assert_passes() + def test_local_return(self): + class X: + @property + def foo(self): + return str(1) + + def capybara() -> None: + assert_is_value(X().foo, TypedValue(str)) + class TestShadowing(TestNameCheckVisitorBase): @assert_passes() @@ -1108,3 +1118,34 @@ def wrapper(): assert_type(func(1), int) assert_type(func(1, 1), int) assert_type(func("x"), float) + + +class TestSelfAnnotation(TestNameCheckVisitorBase): + @assert_passes() + def test_method(self): + from typing import Generic, TypeVar + + T = TypeVar("T") + + class Capybara(Generic[T]): + def method(self: "Capybara[int]") -> int: + return 1 + + def caller(ci: Capybara[int], cs: Capybara[str]): + assert_is_value(ci.method(), TypedValue(int)) + cs.method() # E: incompatible_argument + + @assert_passes() + def test_property(self): + from typing import Generic, TypeVar + + T = TypeVar("T") + + class Capybara(Generic[T]): + @property + def prop(self: "Capybara[int]") -> int: + return 1 + + def caller(ci: Capybara[int], cs: Capybara[str]): + assert_is_value(ci.prop, TypedValue(int)) + cs.prop # E: incompatible_argument diff --git a/pyanalyze/value.py b/pyanalyze/value.py index eadd127e..d4fb5b3e 100644 --- a/pyanalyze/value.py +++ b/pyanalyze/value.py @@ -570,8 +570,6 @@ def get_signature( return None if isinstance(signature, pyanalyze.signature.BoundMethodSignature): signature = signature.get_signature(ctx=ctx) - if isinstance(signature, pyanalyze.signature.PropertyArgSpec): - return None return signature def substitute_typevars(self, typevars: TypeVarMap) -> "Value":