diff --git a/docs/changelog.md b/docs/changelog.md index 787241f4..3bdb390f 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,7 @@ ## Unreleased +- Allow storing type narrowing constraints in variables (#343) - The first argument to `__new__` and `__init_subclass` does not need to be `self` (#342) - Drop dependencies on `attrs` and `mypy_extensions` (#341) diff --git a/pyanalyze/annotations.py b/pyanalyze/annotations.py index b742a2ef..e4d75c99 100644 --- a/pyanalyze/annotations.py +++ b/pyanalyze/annotations.py @@ -647,11 +647,12 @@ def show_error( def get_name(self, node: ast.Name) -> Value: if self.visitor is not None: - return self.visitor.resolve_name( + val, _ = self.visitor.resolve_name( node, error_node=self.node, suppress_errors=self.should_suppress_undefined_names, ) + return val elif self.globals is not None: if node.id in self.globals: return KnownValue(self.globals[node.id]) diff --git a/pyanalyze/implementation.py b/pyanalyze/implementation.py index cdbde82f..b4aa993a 100644 --- a/pyanalyze/implementation.py +++ b/pyanalyze/implementation.py @@ -11,7 +11,7 @@ ConstraintType, PredicateProvider, OrConstraint, - Varname, + VarnameWithOrigin, ) from .signature import ANY_SIGNATURE, SigParameter, Signature, ImplReturn, CallContext from .value import ( @@ -50,7 +50,6 @@ unpack_values, ) -from functools import reduce import collections.abc from itertools import product import qcore @@ -116,7 +115,7 @@ def _isinstance_impl(ctx: CallContext) -> ImplReturn: def _constraint_from_isinstance( - varname: Optional[Varname], class_or_tuple: Value + varname: Optional[VarnameWithOrigin], class_or_tuple: Value ) -> AbstractConstraint: if varname is None: return NULL_CONSTRAINT @@ -132,7 +131,7 @@ def _constraint_from_isinstance( Constraint(varname, ConstraintType.is_instance, True, elt) for elt in class_or_tuple.val ] - return reduce(OrConstraint, constraints) + return OrConstraint.make(constraints) else: return NULL_CONSTRAINT @@ -313,18 +312,26 @@ def _list_append_impl(ctx: CallContext) -> ImplReturn: lst = replace_known_sequence_value(ctx.vars["self"]) element = ctx.vars["object"] varname = ctx.visitor.varname_for_self_constraint(ctx.node) - if isinstance(lst, SequenceIncompleteValue): - no_return_unless = Constraint( - varname, - ConstraintType.is_value_object, - True, - SequenceIncompleteValue.make_or_known(list, (*lst.members, element)), - ) - return ImplReturn(KnownValue(None), no_return_unless=no_return_unless) - elif isinstance(lst, GenericValue): - return _maybe_broaden_weak_type( - "list.append", "object", ctx.vars["self"], lst, element, ctx, list, varname - ) + if varname is not None: + if isinstance(lst, SequenceIncompleteValue): + no_return_unless = Constraint( + varname, + ConstraintType.is_value_object, + True, + SequenceIncompleteValue.make_or_known(list, (*lst.members, element)), + ) + return ImplReturn(KnownValue(None), no_return_unless=no_return_unless) + elif isinstance(lst, GenericValue): + return _maybe_broaden_weak_type( + "list.append", + "object", + ctx.vars["self"], + lst, + element, + ctx, + list, + varname, + ) return ImplReturn(KnownValue(None)) @@ -537,9 +544,12 @@ def _dict_setdefault_impl(ctx: CallContext) -> ImplReturn: self_value.typ, [*self_value.kv_pairs, KVPair(key, default, is_required=not is_present)], ) - no_return_unless = Constraint( - varname, ConstraintType.is_value_object, True, new_value - ) + if varname is not None: + no_return_unless = Constraint( + varname, ConstraintType.is_value_object, True, new_value + ) + else: + 
no_return_unless = NULL_CONSTRAINT if not is_present: return ImplReturn(default, no_return_unless=no_return_unless) return ImplReturn( @@ -554,9 +564,12 @@ def _dict_setdefault_impl(ctx: CallContext) -> ImplReturn: new_type = make_weak( GenericValue(self_value.typ, [new_key_type, new_value_type]) ) - no_return_unless = Constraint( - varname, ConstraintType.is_value_object, True, new_type - ) + if varname is not None: + no_return_unless = Constraint( + varname, ConstraintType.is_value_object, True, new_type + ) + else: + no_return_unless = NULL_CONSTRAINT return ImplReturn(new_value_type, no_return_unless=no_return_unless) else: tv_map = key_type.can_assign(key, ctx.visitor) @@ -596,7 +609,7 @@ def _weak_dict_update( self_val: Value, pairs: Sequence[KVPair], ctx: CallContext, - varname: Optional[Varname], + varname: Optional[VarnameWithOrigin], ) -> ImplReturn: self_pairs = kv_pairs_from_mapping(self_val, ctx.visitor) if isinstance(self_pairs, CanAssignError): @@ -622,7 +635,7 @@ def _add_pairs_to_dict( self_val: Value, pairs: Sequence[KVPair], ctx: CallContext, - varname: Optional[Varname], + varname: Optional[VarnameWithOrigin], ) -> ImplReturn: if _is_weak(self_val): return _weak_dict_update(self_val, pairs, ctx, varname) @@ -766,11 +779,16 @@ def inner(lst: Value, iterable: Value) -> ImplReturn: constrained_value = make_weak(GenericValue(list, [generic_arg])) if return_container: return ImplReturn(constrained_value) - no_return_unless = Constraint( - varname, ConstraintType.is_value_object, True, constrained_value - ) - return ImplReturn(KnownValue(None), no_return_unless=no_return_unless) - elif isinstance(cleaned_lst, GenericValue) and isinstance(iterable, TypedValue): + if varname is not None: + no_return_unless = Constraint( + varname, ConstraintType.is_value_object, True, constrained_value + ) + return ImplReturn(KnownValue(None), no_return_unless=no_return_unless) + elif ( + varname is not None + and isinstance(cleaned_lst, GenericValue) + and isinstance(iterable, TypedValue) + ): actual_type = iterable.get_generic_arg_for_type( collections.abc.Iterable, ctx.visitor, 0 ) @@ -810,7 +828,7 @@ def _maybe_broaden_weak_type( actual_type: Value, ctx: CallContext, typ: type, - varname: Varname, + varname: VarnameWithOrigin, *, return_container: bool = False, ) -> ImplReturn: @@ -842,21 +860,39 @@ def _set_add_impl(ctx: CallContext) -> ImplReturn: set_value = replace_known_sequence_value(ctx.vars["self"]) element = ctx.vars["object"] varname = ctx.visitor.varname_for_self_constraint(ctx.node) - if isinstance(set_value, SequenceIncompleteValue): - no_return_unless = Constraint( - varname, - ConstraintType.is_value_object, - True, - SequenceIncompleteValue.make_or_known(set, (*set_value.members, element)), - ) - return ImplReturn(KnownValue(None), no_return_unless=no_return_unless) - elif isinstance(set_value, GenericValue): - return _maybe_broaden_weak_type( - "set.add", "object", ctx.vars["self"], set_value, element, ctx, set, varname - ) + if varname is not None: + if isinstance(set_value, SequenceIncompleteValue): + no_return_unless = Constraint( + varname, + ConstraintType.is_value_object, + True, + SequenceIncompleteValue.make_or_known( + set, (*set_value.members, element) + ), + ) + return ImplReturn(KnownValue(None), no_return_unless=no_return_unless) + elif isinstance(set_value, GenericValue): + return _maybe_broaden_weak_type( + "set.add", + "object", + ctx.vars["self"], + set_value, + element, + ctx, + set, + varname, + ) return ImplReturn(KnownValue(None)) +def 
_remove_annotated(val: Value) -> Value: + if isinstance(val, AnnotatedValue): + return val.value + elif isinstance(val, MultiValuedValue): + return unite_values(*[_remove_annotated(subval) for subval in val.vals]) + return val + + def _assert_is_value_impl(ctx: CallContext) -> Value: if not ctx.visitor._is_checking(): return KnownValue(None) @@ -870,6 +906,8 @@ def _assert_is_value_impl(ctx: CallContext) -> Value: arg="value", ) else: + if _remove_annotated(ctx.vars["skip_annotated"]) == KnownValue(True): + obj = _remove_annotated(obj) if obj != expected_value.val: ctx.show_error( f"Bad value inference: expected {expected_value.val}, got {obj}", @@ -1061,7 +1099,16 @@ def get_default_argspecs() -> Dict[object, Signature]: signatures = [ # pyanalyze helpers Signature.make( - [SigParameter("obj"), SigParameter("value", annotation=TypedValue(Value))], + [ + SigParameter("obj"), + SigParameter("value", annotation=TypedValue(Value)), + SigParameter( + "skip_annotated", + SigParameter.KEYWORD_ONLY, + default=KnownValue(False), + annotation=TypedValue(bool), + ), + ], KnownValue(None), impl=_assert_is_value_impl, callable=assert_is_value, diff --git a/pyanalyze/name_check_visitor.py b/pyanalyze/name_check_visitor.py index bf5bc138..c51cc1b5 100644 --- a/pyanalyze/name_check_visitor.py +++ b/pyanalyze/name_check_visitor.py @@ -17,7 +17,6 @@ import collections.abc import contextlib from dataclasses import dataclass -from functools import reduce import inspect from itertools import chain import logging @@ -74,8 +73,8 @@ from .safe import safe_getattr, is_hashable, safe_in, all_of_type from .stacked_scopes import ( AbstractConstraint, - CompositeVariable, Composite, + CompositeIndex, FunctionScope, Varname, Constraint, @@ -87,12 +86,16 @@ ConstraintType, ScopeType, StackedScopes, + VarnameOrigin, + VarnameWithOrigin, VisitorState, PredicateProvider, LEAVES_LOOP, LEAVES_SCOPE, + annotate_with_constraint, constrain_value, SubScope, + extract_constraints, ) from .signature import ( ANY_SIGNATURE, @@ -1037,7 +1040,8 @@ def _set_name_in_scope( self.module.__name__, varname ) if varname in current_scope: - return current_scope.get_local(varname, node, self.state) + value, _ = current_scope.get_local(varname, node, self.state) + return value if scope_type == ScopeType.class_scope and isinstance(node, ast.AST): self._check_for_class_variable_redefinition(varname, node) current_scope.set(varname, value, node, self.state) @@ -1071,7 +1075,7 @@ def resolve_name( node: ast.Name, error_node: Optional[ast.AST] = None, suppress_errors: bool = False, - ) -> Value: + ) -> Tuple[Value, VarnameOrigin]: """Resolves a Name node to a value. 
:param node: Node to resolve the name from @@ -1088,7 +1092,9 @@ def resolve_name( """ if error_node is None: error_node = node - value, defining_scope = self.scopes.get_with_scope(node.id, node, self.state) + value, defining_scope, origin = self.scopes.get_with_scope( + node.id, node, self.state + ) if defining_scope is not None: if defining_scope.scope_type in ( ScopeType.module_scope, @@ -1106,7 +1112,7 @@ def resolve_name( self._show_error_if_checking( error_node, f"Undefined name: {node.id}", ErrorCode.undefined_name ) - return AnyValue(AnySource.error) + return AnyValue(AnySource.error), origin if isinstance(value, MultiValuedValue): subvals = value.vals elif isinstance(value, AnnotatedValue) and isinstance( @@ -1132,10 +1138,10 @@ def resolve_name( ] ) if isinstance(value, AnnotatedValue): - return AnnotatedValue(new_mvv, value.metadata) + return AnnotatedValue(new_mvv, value.metadata), origin else: - return new_mvv - return value + return new_mvv, origin + return value, origin def _maybe_show_missing_import_error(self, node: ast.Name) -> None: """Shows errors that suggest adding an import statement in the semi-right place. @@ -1594,7 +1600,7 @@ def _visit_function_body( ), qcore.override(self, "return_values", []): self._generic_visit_list(body) return_values = self.return_values - return_set = scope.get_local(LEAVES_SCOPE, node, self.state) + return_set, _ = scope.get_local(LEAVES_SCOPE, node, self.state) self._check_function_unused_vars(scope) return self._compute_return_type( @@ -2445,28 +2451,16 @@ def _maybe_make_sequence( else: values.append(elt) if has_unknown_value: - return make_weak( - GenericValue( - typ, - [ - unite_and_simplify( - *values, limit=self.config.UNION_SIMPLIFICATION_LIMIT - ) - ], - ) + arg = unite_and_simplify( + *values, limit=self.config.UNION_SIMPLIFICATION_LIMIT ) + return make_weak(GenericValue(typ, [arg])) else: return SequenceIncompleteValue(typ, values) # Operations def visit_BoolOp(self, node: ast.BoolOp) -> Value: - val, _ = self.constraint_from_bool_op(node) - return val - - def constraint_from_bool_op( - self, node: ast.BoolOp - ) -> Tuple[Value, AbstractConstraint]: # Visit an AND or OR expression. 
# We want to show an error if the left operand in a BoolOp is always true, @@ -2495,31 +2489,29 @@ def constraint_from_bool_op( values.append(right_value) out_constraints.append(constraint) constraint_cls = AndConstraint if is_and else OrConstraint - constraint = reduce(constraint_cls, reversed(out_constraints)) - return unite_values(*values), constraint + constraint = constraint_cls.make(reversed(out_constraints)) + return annotate_with_constraint(unite_values(*values), constraint) def visit_Compare(self, node: ast.Compare) -> Value: - val, _ = self.constraint_from_compare(node) - return val - - def constraint_from_compare( - self, node: ast.Compare - ) -> Tuple[Value, AbstractConstraint]: if len(node.ops) != 1: # TODO handle multi-comparison properly self.generic_visit(node) - return AnyValue(AnySource.inference), NULL_CONSTRAINT + return AnyValue(AnySource.inference) op = node.ops[0] lhs, lhs_constraint = self._visit_possible_constraint(node.left) rhs, rhs_constraint = self._visit_possible_constraint(node.comparators[0]) + if isinstance(lhs, AnnotatedValue): + lhs = lhs.value + if isinstance(rhs, AnnotatedValue): + rhs = rhs.value if isinstance(lhs_constraint, PredicateProvider) and isinstance( rhs, KnownValue ): - return self._constraint_from_predicate_provider(lhs_constraint, rhs.val, op) + return self._value_from_predicate_provider(lhs_constraint, rhs.val, op) elif isinstance(rhs_constraint, PredicateProvider) and isinstance( lhs, KnownValue ): - return self._constraint_from_predicate_provider(rhs_constraint, lhs.val, op) + return self._value_from_predicate_provider(rhs_constraint, lhs.val, op) elif isinstance(rhs, KnownValue): constraint = self._constraint_from_compare_op( node.left, rhs.val, op, is_right=True @@ -2534,7 +2526,7 @@ def constraint_from_compare( val = TypedValue(bool) else: val = AnyValue(AnySource.inference) - return val, constraint + return annotate_with_constraint(val, constraint) def _constraint_from_compare_op( self, constrained_node: ast.AST, other_val: Any, op: ast.AST, *, is_right: bool @@ -2623,9 +2615,9 @@ def predicate_func(value: Value, positive: bool) -> Optional[Value]: return Constraint(varname, ConstraintType.predicate, True, predicate_func) - def _constraint_from_predicate_provider( + def _value_from_predicate_provider( self, pred: PredicateProvider, other_val: Any, op: ast.AST - ) -> Tuple[Value, AbstractConstraint]: + ) -> Value: positive_operator, negative_operator = COMPARATOR_TO_OPERATOR[type(op)] def predicate_func(value: Value, positive: bool) -> Optional[Value]: @@ -2644,15 +2636,9 @@ def predicate_func(value: Value, positive: bool) -> Optional[Value]: constraint = Constraint( pred.varname, ConstraintType.predicate, True, predicate_func ) - return AnyValue(AnySource.inference), constraint + return annotate_with_constraint(AnyValue(AnySource.inference), constraint) def visit_UnaryOp(self, node: ast.UnaryOp) -> Value: - val, _ = self.constraint_from_unary_op(node) - return val - - def constraint_from_unary_op( - self, node: ast.UnaryOp - ) -> Tuple[Value, AbstractConstraint]: if isinstance(node.op, ast.Not): # not doesn't have its own special method val, constraint = self.constraint_from_condition(node.operand) @@ -2663,12 +2649,12 @@ def constraint_from_unary_op( val = KnownValue(True) else: val = TypedValue(bool) - return val, constraint.invert() + return annotate_with_constraint(val, constraint.invert()) else: operand = self.composite_from_node(node.operand) _, method = UNARY_OPERATION_TO_DESCRIPTION_AND_METHOD[type(node.op)] val = 
self._check_dunder_call(node, operand, method, [], allow_call=True) - return val, NULL_CONSTRAINT + return val def visit_BinOp(self, node: ast.BinOp) -> Value: left = self.composite_from_node(node.left) @@ -3034,33 +3020,25 @@ def visit_Assert(self, node: ast.Assert) -> None: self._check_boolability(test, node, disabled={ErrorCode.value_always_true}) def add_constraint(self, node: object, constraint: AbstractConstraint) -> None: + if constraint is NULL_CONSTRAINT: + return # save some work self.scopes.current_scope().add_constraint(constraint, node, self.state) def _visit_possible_constraint( self, node: ast.AST ) -> Tuple[Value, AbstractConstraint]: - if isinstance(node, ast.Compare): - pair = self.constraint_from_compare(node) - elif isinstance(node, (ast.Name, ast.Attribute, ast.Subscript)): + if isinstance(node, (ast.Name, ast.Attribute, ast.Subscript)): composite = self.composite_from_node(node) if composite.varname is not None: constraint = Constraint( composite.varname, ConstraintType.is_truthy, True, None ) - pair = composite.value, constraint + val = annotate_with_constraint(composite.value, constraint) else: - pair = composite.value, NULL_CONSTRAINT - elif isinstance(node, ast.Call): - pair = self.constraint_from_call(node) - elif isinstance(node, ast.UnaryOp): - pair = self.constraint_from_unary_op(node) - elif isinstance(node, ast.BoolOp): - pair = self.constraint_from_bool_op(node) + val = composite.value else: - pair = self.visit(node), NULL_CONSTRAINT - if self.annotate: - node.inferred_value = pair[0] - return pair + val = self.visit(node) + return val, extract_constraints(val) def visit_Break(self, node: ast.Break) -> None: self._set_name_in_scope(LEAVES_LOOP, node, AnyValue(AnySource.marker)) @@ -3077,7 +3055,9 @@ def visit_For(self, node: ast.For) -> None: else: always_entered = len(iterated_value) > 0 if not isinstance(iterated_value, Value): - iterated_value = unite_values(*iterated_value) + iterated_value = unite_and_simplify( + *iterated_value, limit=self.config.UNION_SIMPLIFICATION_LIMIT + ) with self.scopes.subscope() as body_scope: with self.scopes.loop_scope(): with qcore.override(self, "being_assigned", iterated_value): @@ -3119,7 +3099,9 @@ def visit_While(self, node: ast.While) -> None: self._handle_loop_else(node.orelse, body_scope, always_entered) if self.state == VisitorState.collect_names: - self.visit(node.test) + test, constraint = self.constraint_from_condition( + node.test, check_boolability=False + ) with self.scopes.subscope(): self.add_constraint((node, 2), constraint) self._generic_visit_list(node.body) @@ -3449,14 +3431,14 @@ def composite_from_name( ) -> Composite: if force_read or self._is_read_ctx(node.ctx): self.yield_checker.record_usage(node.id, node) - value = self.resolve_name(node) + value, origin = self.resolve_name(node) varname_value = VariableNameValue.from_varname( node.id, self.config.varname_value_map() ) if varname_value is not None and self._should_use_varname_value(value): value = varname_value value = self._maybe_use_hardcoded_type(value, node.id) - return Composite(value, node.id, node) + return Composite(value, VarnameWithOrigin(node.id, origin), node) elif self._is_write_ctx(node.ctx): if self._name_node_to_statement is not None: statement = self.node_context.nearest_enclosing( @@ -3474,7 +3456,7 @@ def composite_from_name( if not is_ann_assign: self.yield_checker.record_assignment(node.id) value = self._set_name_in_scope(node.id, node, value=value) - return Composite(value, node.id, node) + return Composite(value, 
VarnameWithOrigin(node.id), node) else: # not sure when (if ever) the other contexts can happen self.show_error(node, f"Bad context: {node.ctx}", ErrorCode.unexpected_node) @@ -3524,32 +3506,32 @@ def composite_from_subscript(self, node: ast.Subscript) -> Composite: and isinstance(index, KnownValue) and is_hashable(index.val) ): - composite_var = root_composite.get_extended_varname(index) + varname = self._extend_composite(root_composite, index, node) else: - composite_var = None + varname = None if isinstance(root_composite.value, MultiValuedValue): values = [ self._composite_from_subscript_no_mvv( node, Composite(val, root_composite.varname, root_composite.node), index_composite, - composite_var, + varname, ) for val in root_composite.value.vals ] return_value = unite_values(*values) else: return_value = self._composite_from_subscript_no_mvv( - node, root_composite, index_composite, composite_var + node, root_composite, index_composite, varname ) - return Composite(return_value, composite_var, node) + return Composite(return_value, varname, node) def _composite_from_subscript_no_mvv( self, node: ast.Subscript, root_composite: Composite, index_composite: Composite, - composite_var: Optional[CompositeVariable], + composite_var: Optional[VarnameWithOrigin], ) -> Value: value = root_composite.value index = index_composite.value @@ -3559,7 +3541,9 @@ def _composite_from_subscript_no_mvv( composite_var is not None and self.scopes.scope_type() == ScopeType.function_scope ): - self.scopes.set(composite_var, self.being_assigned, node, self.state) + self.scopes.set( + composite_var.get_varname(), self.being_assigned, node, self.state + ) self._check_dunder_call( node.value, root_composite, @@ -3591,7 +3575,7 @@ def _composite_from_subscript_no_mvv( with self.catch_errors(): getitem = self._get_dunder(node.value, value, "__getitem__") if getitem is not UNINITIALIZED_VALUE: - return_value, _ = self.check_call( + return_value = self.check_call( node.value, getitem, [root_composite, index_composite], @@ -3610,7 +3594,7 @@ def _composite_from_subscript_no_mvv( ) return_value = AnyValue(AnySource.error) else: - return_value, _ = self.check_call( + return_value = self.check_call( node.value, cgi, [index_composite], allow_call=True ) else: @@ -3636,7 +3620,9 @@ def _composite_from_subscript_no_mvv( composite_var is not None and self.scopes.scope_type() == ScopeType.function_scope ): - local_value = self._get_composite(composite_var, node, return_value) + local_value = self._get_composite( + composite_var.get_varname(), node, return_value + ) if local_value is not UNINITIALIZED_VALUE: return_value = local_value return return_value @@ -3706,13 +3692,13 @@ def _check_dunder_call_no_mvv( method_object = self._get_dunder(node, callee_composite.value, method_name) if method_object is UNINITIALIZED_VALUE: return AnyValue(AnySource.error) - return_value, _ = self.check_call( + return_value = self.check_call( node, method_object, [callee_composite, *args], allow_call=allow_call ) return return_value def _get_composite(self, composite: Varname, node: ast.AST, value: Value) -> Value: - local_value = self.scopes.current_scope().get_local( + local_value, _ = self.scopes.current_scope().get_local( composite, node, self.state, fallback_value=value ) if isinstance(local_value, MultiValuedValue): @@ -3726,6 +3712,15 @@ def _get_composite(self, composite: Varname, node: ast.AST, value: Value) -> Val def visit_Attribute(self, node: ast.Attribute) -> Value: return self.composite_from_attribute(node).value + def 
_extend_composite( + self, root_composite: Composite, index: CompositeIndex, node: ast.AST + ) -> Optional[VarnameWithOrigin]: + varname = root_composite.get_extended_varname(index) + if varname is None: + return None + origin = self.scopes.current_scope().get_origin(varname, node, self.state) + return root_composite.get_extended_varname_with_origin(index, origin) + def composite_from_attribute(self, node: ast.Attribute) -> Composite: if isinstance(node.value, ast.Name): attr_str = f"{node.value.id}.{node.attr}" @@ -3735,13 +3730,15 @@ def composite_from_attribute(self, node: ast.Attribute) -> Composite: self.yield_checker.record_usage(attr_str, node) root_composite = self.composite_from_node(node.value) - composite = root_composite.get_extended_varname(node.attr) + composite = self._extend_composite(root_composite, node.attr, node) if self._is_write_ctx(node.ctx): if ( composite is not None and self.scopes.scope_type() == ScopeType.function_scope ): - self.scopes.set(composite, self.being_assigned, node, self.state) + self.scopes.set( + composite.get_varname(), self.being_assigned, node, self.state + ) if isinstance(root_composite.value, TypedValue): typ = root_composite.value.typ @@ -3774,7 +3771,7 @@ def composite_from_attribute(self, node: ast.Attribute) -> Composite: composite is not None and self.scopes.scope_type() == ScopeType.function_scope ): - local_value = self._get_composite(composite, node, value) + local_value = self._get_composite(composite.get_varname(), node, value) if local_value is not UNINITIALIZED_VALUE: value = local_value value = self._maybe_use_hardcoded_type(value, node.attr) @@ -3960,23 +3957,12 @@ def composite_from_node(self, node: ast.AST) -> Composite: node.inferred_value = composite.value return composite - def varname_for_constraint(self, node: ast.AST) -> Optional[Varname]: + def varname_for_constraint(self, node: ast.AST) -> Optional[VarnameWithOrigin]: """Given a node, returns a variable name that could be used in a local scope.""" - # TODO replace with composite_from_node(). This is currently used only by - # implementation functions. - if isinstance(node, ast.Attribute): - attribute_path = self._get_attribute_path(node) - if attribute_path: - attributes = tuple(attribute_path[1:]) - return CompositeVariable(attribute_path[0], attributes) - else: - return None - elif isinstance(node, ast.Name): - return node.id - else: - return None + composite = self.composite_from_node(node) + return composite.varname - def varname_for_self_constraint(self, node: ast.AST) -> Optional[Varname]: + def varname_for_self_constraint(self, node: ast.AST) -> Optional[VarnameWithOrigin]: """Helper for constraints on self from method calls. 
Given an ``ast.Call`` node representing a method call, return the variable name @@ -4012,10 +3998,6 @@ def visit_keyword(self, node: ast.keyword) -> Tuple[Optional[str], Composite]: return (node.arg, self.composite_from_node(node.value)) def visit_Call(self, node: ast.Call) -> Value: - val, _ = self.constraint_from_call(node) - return val - - def constraint_from_call(self, node: ast.Call) -> Tuple[Value, AbstractConstraint]: callee_wrapped = self.visit(node.func) args = [self.composite_from_node(arg) for arg in node.args] if node.keywords: @@ -4023,7 +4005,7 @@ def constraint_from_call(self, node: ast.Call) -> Tuple[Value, AbstractConstrain else: keywords = [] - return_value, constraint = self.check_call( + return_value = self.check_call( node, callee_wrapped, args, keywords, allow_call=self.in_annotation ) @@ -4051,7 +4033,7 @@ def constraint_from_call(self, node: ast.Call) -> Tuple[Value, AbstractConstrain if caller is not None: self.collector.record_call(caller, callee_val) - return return_value, constraint + return return_value def _can_perform_call( self, args: Iterable[Value], keywords: Iterable[Tuple[Optional[str], Value]] @@ -4098,18 +4080,17 @@ def check_call( keywords: Iterable[Tuple[Optional[str], Composite]] = (), *, allow_call: bool = False, - ) -> Tuple[Value, AbstractConstraint]: + ) -> Value: if isinstance(callee, MultiValuedValue): with qcore.override(self, "in_union_decomposition", True): - values, constraints = zip( - *[ - self._check_call_no_mvv( - node, val, args, keywords, allow_call=allow_call - ) - for val in callee.vals - ] - ) - return unite_values(*values), reduce(OrConstraint, constraints) + values = [ + self._check_call_no_mvv( + node, val, args, keywords, allow_call=allow_call + ) + for val in callee.vals + ] + + return unite_values(*values) return self._check_call_no_mvv( node, callee, args, keywords, allow_call=allow_call ) @@ -4122,12 +4103,12 @@ def _check_call_no_mvv( keywords: Iterable[Tuple[Optional[str], Composite]] = (), *, allow_call: bool = False, - ) -> Tuple[Value, AbstractConstraint]: + ) -> Value: if isinstance(callee_wrapped, KnownValue) and any( callee_wrapped.val is ignored for ignored in self.config.IGNORED_CALLEES ): self.log(logging.INFO, "Ignoring callee", callee_wrapped) - return AnyValue(AnySource.error), NULL_CONSTRAINT + return AnyValue(AnySource.error) extended_argspec = self.signature_from_value(callee_wrapped, node) if extended_argspec is ANY_SIGNATURE: @@ -4201,22 +4182,22 @@ def _check_call_no_mvv( callee_wrapped.val ): async_fn = callee_wrapped.val.__self__ - return ( - AsyncTaskIncompleteValue(_get_task_cls(async_fn), return_value), - constraint, + return AsyncTaskIncompleteValue( + _get_task_cls(async_fn), + annotate_with_constraint(return_value, constraint), ) elif isinstance( callee_wrapped, UnboundMethodValue ) and callee_wrapped.secondary_attr_name in ("async", "asynq"): async_fn = callee_wrapped.get_method() - return ( - AsyncTaskIncompleteValue(_get_task_cls(async_fn), return_value), - constraint, + return AsyncTaskIncompleteValue( + _get_task_cls(async_fn), + annotate_with_constraint(return_value, constraint), ) elif isinstance(callee_wrapped, UnboundMethodValue) and asynq.is_pure_async_fn( callee_wrapped.get_method() ): - return return_value, constraint + return annotate_with_constraint(return_value, constraint) else: if ( isinstance(return_value, AnyValue) @@ -4225,8 +4206,8 @@ def _check_call_no_mvv( ): task_cls = _get_task_cls(callee_wrapped.val) if isinstance(task_cls, type): - return TypedValue(task_cls), 
constraint - return return_value, constraint + return TypedValue(task_cls) + return annotate_with_constraint(return_value, constraint) def signature_from_value( self, value: Value, node: Optional[ast.AST] = None diff --git a/pyanalyze/signature.py b/pyanalyze/signature.py index ca60c564..427639f4 100644 --- a/pyanalyze/signature.py +++ b/pyanalyze/signature.py @@ -6,7 +6,6 @@ """ - from .error_code import ErrorCode from .safe import all_of_type from .stacked_scopes import ( @@ -17,7 +16,7 @@ ConstraintType, NULL_CONSTRAINT, AbstractConstraint, - Varname, + VarnameWithOrigin, ) from .value import ( AnnotatedValue, @@ -57,7 +56,6 @@ from collections import defaultdict, OrderedDict import collections.abc from dataclasses import dataclass, field, replace -from functools import reduce import itertools from types import MethodType, FunctionType import inspect @@ -128,14 +126,6 @@ class ActualArguments: kwargs_required: bool -def _maybe_or_constraint( - left: AbstractConstraint, right: AbstractConstraint -) -> AbstractConstraint: - if left is NULL_CONSTRAINT or right is NULL_CONSTRAINT: - return NULL_CONSTRAINT - return OrConstraint(left, right) - - class ImplReturn(NamedTuple): """Return value of :term:`impl` functions. @@ -165,8 +155,8 @@ def unite_impl_rets(cls, rets: Sequence["ImplReturn"]) -> "ImplReturn": return ImplReturn(NO_RETURN_VALUE) return ImplReturn( unite_values(*[r.return_value for r in rets]), - reduce(_maybe_or_constraint, [r.constraint for r in rets]), - reduce(_maybe_or_constraint, [r.no_return_unless for r in rets]), + OrConstraint.make([r.constraint for r in rets]), + OrConstraint.make([r.no_return_unless for r in rets]), ) @@ -190,7 +180,7 @@ def ast_for_arg(self, arg: str) -> Optional[ast.AST]: return composite.node return None - def varname_for_arg(self, arg: str) -> Optional[Varname]: + def varname_for_arg(self, arg: str) -> Optional[VarnameWithOrigin]: """Return a :term:`varname` corresponding to the given function argument. This is useful for creating a :class:`pyanalyze.stacked_scopes.Constraint` @@ -425,9 +415,7 @@ def _apply_annotated_constraints( ret = ImplReturn(raw_return) else: ret = raw_return - constraints = [] - if ret.constraint is not NULL_CONSTRAINT: - constraints.append(ret.constraint) + constraints = [ret.constraint] if isinstance(ret.return_value, AnnotatedValue): for guard in ret.return_value.get_metadata_of_type( ParameterTypeGuardExtension @@ -480,10 +468,7 @@ def _apply_annotated_constraints( ), ) constraints.append(constraint) - if constraints: - constraint = reduce(AndConstraint, constraints) - else: - constraint = NULL_CONSTRAINT + constraint = AndConstraint.make(constraints) return ImplReturn(ret.return_value, constraint, ret.no_return_unless) def bind_arguments( diff --git a/pyanalyze/stacked_scopes.py b/pyanalyze/stacked_scopes.py index 7b7468f4..61f74675 100644 --- a/pyanalyze/stacked_scopes.py +++ b/pyanalyze/stacked_scopes.py @@ -33,6 +33,7 @@ Callable, ContextManager, Dict, + FrozenSet, Iterable, Iterator, List, @@ -49,10 +50,13 @@ from .extensions import reveal_type from .safe import safe_equals, safe_issubclass from .value import ( + NO_RETURN_VALUE, AnnotatedValue, AnySource, AnyValue, + ConstraintExtension, KnownValue, + MultiValuedValue, ReferencingValue, SubclassValue, TypeVarMap, @@ -87,6 +91,15 @@ class ScopeType(enum.Enum): function_scope = 4 +# Nodes as used in scopes can be any object, as long as they are hashable. +Node = object +# Tag for a Varname that changes when the variable is assigned to. 
+VarnameOrigin = FrozenSet[Optional[Node]] +CompositeIndex = Union[str, KnownValue] + +EMPTY_ORIGIN = frozenset((None,)) + + @dataclass(frozen=True) class CompositeVariable: """:term:`varname` used to implement constraints on instance variables. @@ -103,15 +116,44 @@ class CompositeVariable: """ varname: str - attributes: Sequence[Union[str, KnownValue]] + attributes: Sequence[CompositeIndex] - def extend_with(self, index: Union[str, KnownValue]) -> "CompositeVariable": + def extend_with(self, index: CompositeIndex) -> "CompositeVariable": return CompositeVariable(self.varname, (*self.attributes, index)) Varname = Union[str, CompositeVariable] -# Nodes as used in scopes can be any object, as long as they are hashable. -Node = object + + +@dataclass(frozen=True) +class VarnameWithOrigin: + varname: str + origin: VarnameOrigin = EMPTY_ORIGIN + indices: Sequence[Tuple[CompositeIndex, VarnameOrigin]] = () + + def extend_with( + self, index: CompositeIndex, origin: VarnameOrigin + ) -> "VarnameWithOrigin": + return VarnameWithOrigin( + self.varname, self.origin, (*self.indices, (index, origin)) + ) + + def get_all_varnames(self) -> Iterable[Tuple[Varname, VarnameOrigin]]: + yield self.varname, self.origin + for i, (_, origin) in enumerate(self.indices): + varname = CompositeVariable( + self.varname, tuple(index for index, _ in self.indices[: i + 1]) + ) + yield varname, origin + + def get_varname(self) -> Varname: + if self.indices: + return CompositeVariable( + self.varname, tuple(index for index, _ in self.indices) + ) + return self.varname + + SubScope = Dict[Varname, List[Node]] # Type for Constraint.value if constraint type is predicate @@ -123,18 +165,23 @@ class Composite(NamedTuple): origin. This is useful for setting constraints.""" value: Value - varname: Optional[Varname] = None + varname: Optional[VarnameWithOrigin] = None node: Optional[AST] = None - def get_extended_varname( - self, index: Union[str, KnownValue] - ) -> Optional[CompositeVariable]: + def get_extended_varname(self, index: CompositeIndex) -> Optional[Varname]: if self.varname is None: return None - if isinstance(self.varname, str): - return CompositeVariable(self.varname, (index,)) - else: - return self.varname.extend_with(index) + base = self.varname.get_varname() + if isinstance(base, CompositeVariable): + return CompositeVariable(base.varname, (*base.attributes, index)) + return CompositeVariable(base, (index,)) + + def get_extended_varname_with_origin( + self, index: CompositeIndex, origin: VarnameOrigin + ) -> Optional[VarnameWithOrigin]: + if self.varname is None: + return None + return self.varname.extend_with(index, origin) def substitute_typevars(self, typevars: TypeVarMap) -> "Composite": return Composite( @@ -230,7 +277,7 @@ def f(x: Optional[int]) -> None: """ - varname: Varname + varname: VarnameWithOrigin """The :term:`varname` that the constraint applies to.""" constraint_type: ConstraintType """Type of constraint. 
Determines the meaning of :attr:`value`.""" @@ -242,14 +289,24 @@ def f(x: Optional[int]) -> None: """Type for an ``is_instance`` constraint; value identical to the variable for ``is_value``; unused for is_truthy; :class:`pyanalyze.value.Value` object for `is_value_object`.""" + inverted: Optional["Constraint"] = field( + compare=False, repr=False, hash=False, default=None + ) + + def __post_init__(self) -> None: + assert isinstance(self.varname, VarnameWithOrigin), self.varname def apply(self) -> Iterable["Constraint"]: yield self def invert(self) -> "Constraint": - return Constraint( + if self.inverted is not None: + return self.inverted + inverted = Constraint( self.varname, self.constraint_type, not self.positive, self.value ) + object.__setattr__(self, "inverted", inverted) + return inverted def apply_to_values(self, values: Iterable[Value]) -> Iterable[Value]: for value in values: @@ -377,9 +434,21 @@ def apply_to_value(self, value: Value) -> Iterable[Value]: else: assert False, f"unknown constraint type {self.constraint_type}" + def __str__(self) -> str: + sign = "+" if self.positive else "-" + if isinstance(self.value, list): + value = str(list(map(str, self.value))) + else: + value = str(self.value) + return f"<{sign}{self.varname} {self.constraint_type.name} {value}>" + -TRUTHY_CONSTRAINT = Constraint("%unused", ConstraintType.is_truthy, True, None) -FALSY_CONSTRAINT = Constraint("%unused", ConstraintType.is_truthy, False, None) +TRUTHY_CONSTRAINT = Constraint( + VarnameWithOrigin("%unused"), ConstraintType.is_truthy, True, None +) +FALSY_CONSTRAINT = Constraint( + VarnameWithOrigin("%unused"), ConstraintType.is_truthy, False, None +) @dataclass(frozen=True) @@ -425,7 +494,7 @@ def two_lengths(tpl: Union[Tuple[int], Tuple[str, int]]) -> int: """ - varname: Varname + varname: VarnameWithOrigin provider: Callable[[Value], Value] def apply(self) -> Iterable[Constraint]: @@ -440,46 +509,72 @@ def invert(self) -> AbstractConstraint: class AndConstraint(AbstractConstraint): """Represents the AND of two constraints.""" - left: AbstractConstraint - right: AbstractConstraint + constraints: Tuple[AbstractConstraint, ...] def apply(self) -> Iterable["Constraint"]: - for constraint in self.left.apply(): - yield constraint - for constraint in self.right.apply(): - yield constraint + for cons in self.constraints: + yield from cons.apply() def invert(self) -> "OrConstraint": # ~(A and B) -> ~A or ~B - return OrConstraint(self.left.invert(), self.right.invert()) + return OrConstraint(tuple([cons.invert() for cons in self.constraints])) + + @classmethod + def make(cls, constraints: Iterable[AbstractConstraint]) -> AbstractConstraint: + processed = {} + for cons in constraints: + if isinstance(cons, AndConstraint): + for subcons in cons.constraints: + processed[id(subcons)] = subcons + continue + processed[id(cons)] = cons + + final = [] + for constraint in processed.values(): + if isinstance(constraint, OrConstraint): + # A AND (A OR B) reduces to a + if any(id(subcons) in processed for subcons in constraint.constraints): + continue + final.append(constraint) + + if not final: + return NULL_CONSTRAINT + if len(final) == 1: + (cons,) = final + return cons + return cls(tuple(final)) + + def __str__(self) -> str: + children = " AND ".join(map(str, self.constraints)) + return f"({children})" @dataclass(frozen=True) class OrConstraint(AbstractConstraint): """Represents the OR of two constraints.""" - left: AbstractConstraint - right: AbstractConstraint + constraints: Tuple[AbstractConstraint, ...] 
def apply(self) -> Iterable[Constraint]: - left = self._group_constraints(self.left) - right = self._group_constraints(self.right) + grouped = [self._group_constraints(cons) for cons in self.constraints] + left, *rest = grouped for varname, constraints in left.items(): # Produce one_of constraints if the same variable name # applies on both the left and the right side. - if varname in right: - yield Constraint( - varname, - ConstraintType.one_of, - True, - [ - self._constraint_from_list(varname, constraints), - self._constraint_from_list(varname, right[varname]), + if all(varname in group for group in rest): + constraints = [ + self._constraint_from_list(varname, constraints), + *[ + self._constraint_from_list(varname, group[varname]) + for group in rest ], + ] + yield Constraint( + varname, ConstraintType.one_of, True, list(set(constraints)) ) def _constraint_from_list( - self, varname: Varname, constraints: Sequence[Constraint] + self, varname: VarnameWithOrigin, constraints: Sequence[Constraint] ) -> Constraint: if len(constraints) == 1: return constraints[0] @@ -488,7 +583,7 @@ def _constraint_from_list( def _group_constraints( self, abstract_constraint: AbstractConstraint - ) -> Dict[str, List[Constraint]]: + ) -> Dict[VarnameWithOrigin, List[Constraint]]: by_varname = defaultdict(list) for constraint in abstract_constraint.apply(): by_varname[constraint.varname].append(constraint) @@ -496,21 +591,54 @@ def _group_constraints( def invert(self) -> AndConstraint: # ~(A or B) -> ~A and ~B - return AndConstraint(self.left.invert(), self.right.invert()) + return AndConstraint(tuple([cons.invert() for cons in self.constraints])) + + @classmethod + def make(cls, constraints: Iterable[AbstractConstraint]) -> AbstractConstraint: + processed = {} + for cons in constraints: + if isinstance(cons, OrConstraint): + for subcons in cons.constraints: + processed[id(subcons)] = subcons + continue + processed[id(cons)] = cons + + final = [] + for constraint in processed.values(): + if isinstance(constraint, AndConstraint): + # A OR (A AND B) reduces to a + if any(id(subcons) in processed for subcons in constraint.constraints): + continue + elif isinstance(constraint, Constraint): + inverted = id(constraint.invert()) + if inverted in processed: + continue + final.append(constraint) + + if not final: + return NULL_CONSTRAINT + if len(final) == 1: + (cons,) = final + return cons + return cls(tuple(final)) + + def __str__(self) -> str: + children = " OR ".join(map(str, self.constraints)) + return f"({children})" +@dataclass(frozen=True) class _ConstrainedValue(Value): """Helper class, only used within a FunctionScope.""" - def __init__( - self, definition_nodes: Set[Node], constraints: Sequence[Constraint] - ) -> None: - self.definition_nodes = definition_nodes - self.constraints = constraints - self.resolution_cache = {} + definition_nodes: FrozenSet[Node] + constraints: Sequence[Constraint] + resolution_cache: Dict[_LookupContext, Value] = field( + default_factory=dict, init=False, compare=False, hash=False, repr=False + ) -_empty_constrained = _ConstrainedValue(set(), []) +_empty_constrained = _ConstrainedValue(frozenset(), []) @dataclass @@ -544,22 +672,25 @@ def get( node: object, state: VisitorState, from_parent_scope: bool = False, - ) -> Tuple[Value, Optional["Scope"]]: - local_value = self.get_local( + ) -> Tuple[Value, Optional["Scope"], VarnameOrigin]: + local_value, origin = self.get_local( varname, node, state, from_parent_scope=from_parent_scope ) if local_value is not UNINITIALIZED_VALUE: - 
return self.resolve_reference(local_value, state), self + return self.resolve_reference(local_value, state), self, origin elif self.parent_scope is not None: # Parent scopes don't get the node to help local lookup. parent_node = ( (varname, self.scope_node) if self.scope_node is not None else None ) - return self.parent_scope.get( + val, scope, _ = self.parent_scope.get( varname, parent_node, state, from_parent_scope=True ) + # Tag lookups in the parent scope with this scope node, so we + # don't carry over constraints across scopes. + return val, scope, EMPTY_ORIGIN else: - return UNINITIALIZED_VALUE, None + return UNINITIALIZED_VALUE, None, EMPTY_ORIGIN def get_local( self, @@ -568,11 +699,16 @@ def get_local( state: VisitorState, from_parent_scope: bool = False, fallback_value: Optional[Value] = None, - ) -> Value: + ) -> Tuple[Value, VarnameOrigin]: if varname in self.variables: - return self.variables[varname] + return self.variables[varname], EMPTY_ORIGIN else: - return UNINITIALIZED_VALUE + return UNINITIALIZED_VALUE, EMPTY_ORIGIN + + def get_origin( + self, varname: Varname, node: Node, state: VisitorState + ) -> VarnameOrigin: + return EMPTY_ORIGIN def set( self, varname: Varname, value: Value, node: Node, state: VisitorState @@ -620,7 +756,7 @@ def combine_subscopes( def resolve_reference(self, value: Value, state: VisitorState) -> Value: if isinstance(value, ReferencingValue): - referenced, _ = value.scope.get(value.name, None, state) + referenced, _, _ = value.scope.get(value.name, None, state) # globals that are None are probably set to something else later if safe_equals(referenced, KnownValue(None)): return AnyValue(AnySource.inference) @@ -785,7 +921,7 @@ class FunctionScope(Scope): """ name_to_current_definition_nodes: SubScope - usage_to_definition_nodes: Dict[Node, List[Node]] + usage_to_definition_nodes: Dict[Tuple[Node, Varname], List[Node]] definition_node_to_value: Dict[Node, Value] name_to_all_definition_nodes: Dict[str, Set[Node]] name_to_composites: Dict[str, Set[CompositeVariable]] @@ -832,20 +968,51 @@ def add_constraint( """ for constraint in abstract_constraint.apply(): - def_nodes = set(self.name_to_current_definition_nodes[constraint.varname]) - # We set both a constraint and its inverse using the same node as the definition - # node, so cheat and include the constraint itself in the key. If you write constraints - # to the same key in definition_node_to_value multiple times, you're likely to get - # infinite recursion. 
- node = (node, constraint) - assert ( - node not in self.definition_node_to_value - ), "duplicate constraint for {}".format(node) - self.definition_node_to_value[node] = _ConstrainedValue( - def_nodes, [constraint] - ) - self.name_to_current_definition_nodes[constraint.varname] = [node] - self._add_composite(constraint.varname) + self._add_single_constraint(constraint, node, state) + + def _add_single_constraint( + self, constraint: Constraint, node: Node, state: VisitorState + ) -> None: + for parent_varname, constraint_origin in constraint.varname.get_all_varnames(): + current_origin = self.get_origin(parent_varname, node, state) + current_set = self._resolve_origin(current_origin) + constraint_set = self._resolve_origin(constraint_origin) + if current_set - constraint_set: + return + + varname = constraint.varname.get_varname() + def_nodes = frozenset(self.name_to_current_definition_nodes[varname]) + # We set both a constraint and its inverse using the same node as the definition + # node, so cheat and include the constraint itself in the key. + node = (node, constraint) + val = _ConstrainedValue(def_nodes, [constraint]) + self.definition_node_to_value[node] = val + self.name_to_current_definition_nodes[varname] = [node] + self._add_composite(varname) + + def _resolve_origin(self, definers: Iterable[Node]) -> FrozenSet[Node]: + seen = set() + pending = set(definers) + out = set() + while pending: + definer = pending.pop() + if definer in seen: + continue + seen.add(definer) + if definer is None: + out.add(None) + elif definer not in self.definition_node_to_value: + # maybe from a different scope + return EMPTY_ORIGIN + else: + val = self.definition_node_to_value[definer] + if isinstance(val, _ConstrainedValue): + pending |= val.definition_nodes + else: + out.add(definer) + if not out: + return EMPTY_ORIGIN + return frozenset(out) def set( self, varname: Varname, value: Value, node: Node, state: VisitorState @@ -872,38 +1039,61 @@ def set( def get_local( self, - varname: str, + varname: Varname, node: Node, state: VisitorState, from_parent_scope: bool = False, fallback_value: Optional[Value] = None, - ) -> Value: + ) -> Tuple[Value, VarnameOrigin]: self._add_composite(varname) ctx = _LookupContext(varname, fallback_value, node, state) if from_parent_scope: self.accessed_from_special_nodes.add(varname) + key = (node, varname) if node is None: self.accessed_from_special_nodes.add(varname) # this indicates that we're not looking at a normal local variable reference, but # something special like a nested function if varname in self.name_to_all_definition_nodes: - return self._get_value_from_nodes( - self.name_to_all_definition_nodes[varname], ctx - ) + definers = self.name_to_all_definition_nodes[varname] else: - return self.referencing_value_vars[varname] - if state is VisitorState.check_names: - if node not in self.usage_to_definition_nodes: - return self.referencing_value_vars[varname] + return self.referencing_value_vars[varname], EMPTY_ORIGIN + elif state is VisitorState.check_names: + if key not in self.usage_to_definition_nodes: + return self.referencing_value_vars[varname], EMPTY_ORIGIN else: - definers = self.usage_to_definition_nodes[node] + definers = self.usage_to_definition_nodes[key] else: if varname in self.name_to_current_definition_nodes: definers = self.name_to_current_definition_nodes[varname] - self.usage_to_definition_nodes[node] += definers + self.usage_to_definition_nodes[key] += definers else: - return self.referencing_value_vars[varname] - return 
self._get_value_from_nodes(definers, ctx) + return self.referencing_value_vars[varname], EMPTY_ORIGIN + return self._get_value_from_nodes(definers, ctx), self._resolve_origin(definers) + + def get_origin( + self, varname: Varname, node: Node, state: VisitorState + ) -> VarnameOrigin: + key = (node, varname) + if node is None: + # this indicates that we're not looking at a normal local variable reference, but + # something special like a nested function + if varname in self.name_to_all_definition_nodes: + definers = self.name_to_all_definition_nodes[varname] + else: + return EMPTY_ORIGIN + elif state is VisitorState.check_names: + if key not in self.usage_to_definition_nodes: + return EMPTY_ORIGIN + else: + definers = self.usage_to_definition_nodes[key] + else: + if varname in self.name_to_current_definition_nodes: + definers = self.name_to_current_definition_nodes[varname] + self.usage_to_definition_nodes[key] += definers + else: + return EMPTY_ORIGIN + return self._resolve_origin(definers) @contextlib.contextmanager def subscope(self) -> Iterable[SubScope]: @@ -971,6 +1161,9 @@ def _resolve_value(self, val: Value, ctx: _LookupContext) -> Value: key = replace(ctx, fallback_value=None) if key in val.resolution_cache: return val.resolution_cache[key] + # Guard against recursion. This happens in the test_len_condition test. + # Perhaps we should do something smarter to prevent recursion. + val.resolution_cache[key] = NO_RETURN_VALUE if val.definition_nodes or ctx.fallback_value: resolved = self._get_value_from_nodes( val.definition_nodes, ctx, val.constraints @@ -979,7 +1172,7 @@ def _resolve_value(self, val: Value, ctx: _LookupContext) -> Value: assert ( self.parent_scope ), "constrained value must have definition nodes or parent scope" - parent_val, _ = self.parent_scope.get(ctx.varname, None, ctx.state) + parent_val, _, _ = self.parent_scope.get(ctx.varname, None, ctx.state) resolved = _constrain_value( [parent_val], val.constraints, @@ -1124,15 +1317,16 @@ def get(self, varname: Varname, node: Node, state: VisitorState) -> Value: Returns :data:`pyanalyze.value.UNINITIALIZED_VALUE` if the name is not defined in any known scope. """ - value, _ = self.get_with_scope(varname, node, state) + value, _, _ = self.get_with_scope(varname, node, state) return value def get_with_scope( self, varname: Varname, node: Node, state: VisitorState - ) -> Tuple[Value, Optional[Scope]]: + ) -> Tuple[Value, Optional[Scope], VarnameOrigin]: """Like :meth:`get`, but also returns the scope object the name was found in. - Returns a (:class:`pyanalyze.value.Value`, :class:`Scope`) tuple. The :class:`Scope` is ``None`` if the name was not found. + Returns a (:class:`pyanalyze.value.Value`, :class:`Scope`, origin) tuple. The :class:`Scope` + is ``None`` if the name was not found. """ return self.scopes[-1].get(varname, node, state) @@ -1233,9 +1427,30 @@ def _constrain_value( for constraint in constraints: values = list(constraint.apply_to_values(values)) if not values: - # TODO: maybe show an error here? This branch should mean the code is - # unreachable. 
return AnyValue(AnySource.unreachable) if simplification_limit is not None: return unite_and_simplify(*values, limit=simplification_limit) return unite_values(*values) + + +def annotate_with_constraint(value: Value, constraint: AbstractConstraint) -> Value: + if constraint is NULL_CONSTRAINT: + return value + return annotate_value(value, [ConstraintExtension(constraint)]) + + +def extract_constraints(value: Value) -> AbstractConstraint: + if isinstance(value, AnnotatedValue): + extensions = list(value.get_metadata_of_type(ConstraintExtension)) + constraints = [ext.constraint for ext in extensions] + base = extract_constraints(value.value) + constraints = [ + cons for cons in [*constraints, base] if cons is not NULL_CONSTRAINT + ] + return AndConstraint.make(constraints) + elif isinstance(value, MultiValuedValue): + constraints = [extract_constraints(subval) for subval in value.vals] + if not constraints: + return NULL_CONSTRAINT + return OrConstraint.make(constraints) + return NULL_CONSTRAINT diff --git a/pyanalyze/test_name_check_visitor.py b/pyanalyze/test_name_check_visitor.py index c4e87572..d5a3cbac 100644 --- a/pyanalyze/test_name_check_visitor.py +++ b/pyanalyze/test_name_check_visitor.py @@ -15,7 +15,7 @@ ) from .implementation import assert_is_value, dump_value from .error_code import DISABLED_IN_TESTS, ErrorCode -from .stacked_scopes import Composite +from .stacked_scopes import Composite, Varname from .test_config import TestConfig from .value import ( AnnotatedValue, @@ -140,7 +140,6 @@ def _make_module(code_str: str) -> types.ModuleType: make_weak=make_weak, UNINITIALIZED_VALUE=UNINITIALIZED_VALUE, NO_RETURN_VALUE=NO_RETURN_VALUE, - Composite=Composite, ) return make_module(code_str, extra_scope) @@ -763,22 +762,35 @@ def capybara(x): assert_is_value( cond and 1, MultiValuedValue([TypedValue(str), KnownValue(None), KnownValue(1)]), + skip_annotated=True, ) assert_is_value( - cond2 and 1, MultiValuedValue([KnownValue(None), KnownValue(1)]) + cond2 and 1, + MultiValuedValue([KnownValue(None), KnownValue(1)]), + skip_annotated=True, ) assert_is_value( - cond or 1, MultiValuedValue([TypedValue(str), KnownValue(1)]) + cond or 1, + MultiValuedValue([TypedValue(str), KnownValue(1)]), + skip_annotated=True, ) assert_is_value( - cond2 or 1, MultiValuedValue([KnownValue(True), KnownValue(1)]) + cond2 or 1, + MultiValuedValue([KnownValue(True), KnownValue(1)]), + skip_annotated=True, ) def hutia(x=None): assert_is_value(x, AnyValue(AnySource.unannotated) | KnownValue(None)) - assert_is_value(x or 1, AnyValue(AnySource.unannotated) | KnownValue(1)) + assert_is_value( + x or 1, + AnyValue(AnySource.unannotated) | KnownValue(1), + skip_annotated=True, + ) y = x or 1 - assert_is_value(y, AnyValue(AnySource.unannotated) | KnownValue(1)) + assert_is_value( + y, AnyValue(AnySource.unannotated) | KnownValue(1), skip_annotated=True + ) assert_is_value( (True if x else False) or None, KnownValue(True) | KnownValue(None) ) @@ -1596,6 +1608,7 @@ class TestUnboundMethodValue(TestNameCheckVisitorBase): @assert_passes() def test_inference(self): from pyanalyze.tests import PropertyObject, ClassWithAsync + from pyanalyze.stacked_scopes import Composite def capybara(oid): assert_is_value( @@ -1630,6 +1643,12 @@ def capybara(oid): @assert_passes() def test_metaclass_super(self): + from pyanalyze.stacked_scopes import Composite, VarnameWithOrigin + from qcore.testing import Anything + from typing import Any, cast + + varname = VarnameWithOrigin("self", cast(Any, Anything)) + class Metaclass(type): def 
__init__(self, name, bases, attrs): super(Metaclass, self).__init__(name, bases, attrs) @@ -1640,7 +1659,7 @@ def __init__(self, name, bases, attrs): assert_is_value( self.__init__, UnboundMethodValue( - "__init__", Composite(TypedValue(Metaclass), "self") + "__init__", Composite(TypedValue(Metaclass), varname) ), ) @@ -1727,7 +1746,7 @@ class TestOperators(TestNameCheckVisitorBase): @assert_passes(settings={ErrorCode.value_always_true: False}) def test_not(self): def capybara(x): - assert_is_value(not x, TypedValue(bool)) + assert_is_value(not x, TypedValue(bool), skip_annotated=True) assert_is_value(not True, KnownValue(False)) @assert_passes() diff --git a/pyanalyze/test_signature.py b/pyanalyze/test_signature.py index 17742cb7..b59414db 100644 --- a/pyanalyze/test_signature.py +++ b/pyanalyze/test_signature.py @@ -7,6 +7,7 @@ AnySource, AnyValue, CanAssignError, + ConstraintExtension, GenericValue, KnownValue, MultiValuedValue, @@ -499,20 +500,22 @@ def fn(): @assert_passes() def test_return_value(self): from pyanalyze.value import HasAttrGuardExtension + from qcore.testing import Anything + from typing import Any, cast + + val = AnnotatedValue( + TypedValue(bool), + [ + HasAttrGuardExtension( + "object", KnownValue("foo"), AnyValue(AnySource.inference) + ), + cast(Any, Anything), + ], + ) def capybara(x): l = hasattr(x, "foo") - assert_is_value( - l, - AnnotatedValue( - TypedValue(bool), - [ - HasAttrGuardExtension( - "object", KnownValue("foo"), AnyValue(AnySource.inference) - ) - ], - ), - ) + assert_is_value(l, val) @assert_passes() def test_required_kwonly_args(self): @@ -601,6 +604,8 @@ def run(): @assert_passes() def test_hasattr(self): from pyanalyze.value import HasAttrGuardExtension + from typing import Any, cast + from qcore.testing import Anything class Quemisia(object): def gravis(self): @@ -613,21 +618,19 @@ def wrong_args(): def mistyped_args(): hasattr(True, False) # E: incompatible_argument + inferred = AnnotatedValue( + TypedValue(bool), + [ + HasAttrGuardExtension( + "object", KnownValue("__qualname__"), AnyValue(AnySource.inference) + ), + cast(Any, Anything), + ], + ) + def only_on_class(o: object): val = hasattr(o, "__qualname__") - assert_is_value( - val, - AnnotatedValue( - TypedValue(bool), - [ - HasAttrGuardExtension( - "object", - KnownValue("__qualname__"), - AnyValue(AnySource.inference), - ) - ], - ), - ) + assert_is_value(val, inferred) @assert_fails(ErrorCode.incompatible_call) def test_keyword_only_args(self): diff --git a/pyanalyze/test_stacked_scopes.py b/pyanalyze/test_stacked_scopes.py index e6b403fa..e34c3ad1 100644 --- a/pyanalyze/test_stacked_scopes.py +++ b/pyanalyze/test_stacked_scopes.py @@ -929,9 +929,49 @@ def paca(cond1, cond2): else: assert_is_value(x, KnownValue(False)) + @assert_passes() + def test_double_index(self): + from typing import Union, Optional + + class A: + attr: Union[int, str] + + class B: + attr: Optional[A] + + def capybara(b: B): + assert_is_value(b, TypedValue(B)) + assert_is_value(b.attr, TypedValue(A) | KnownValue(None)) + if b.attr is not None: + assert_is_value(b.attr, TypedValue(A)) + assert_is_value(b.attr.attr, TypedValue(int) | TypedValue(str)) + if isinstance(b.attr.attr, int): + assert_is_value(b.attr.attr, TypedValue(int)) + + @assert_passes() + def test_nested_scope(self): + from pyanalyze.value import WeakExtension + + class A: + pass + + class B(A): + pass + + def capybara(a: A, iterable): + if isinstance(a, B): + assert_is_value(a, TypedValue(B)) + lst = [a for _ in iterable] + assert_is_value( + lst, + 
diff --git a/pyanalyze/test_stacked_scopes.py b/pyanalyze/test_stacked_scopes.py
index e6b403fa..e34c3ad1 100644
--- a/pyanalyze/test_stacked_scopes.py
+++ b/pyanalyze/test_stacked_scopes.py
@@ -929,9 +929,49 @@ def paca(cond1, cond2):
             else:
                 assert_is_value(x, KnownValue(False))
 
+    @assert_passes()
+    def test_double_index(self):
+        from typing import Union, Optional
+
+        class A:
+            attr: Union[int, str]
+
+        class B:
+            attr: Optional[A]
+
+        def capybara(b: B):
+            assert_is_value(b, TypedValue(B))
+            assert_is_value(b.attr, TypedValue(A) | KnownValue(None))
+            if b.attr is not None:
+                assert_is_value(b.attr, TypedValue(A))
+                assert_is_value(b.attr.attr, TypedValue(int) | TypedValue(str))
+                if isinstance(b.attr.attr, int):
+                    assert_is_value(b.attr.attr, TypedValue(int))
+
+    @assert_passes()
+    def test_nested_scope(self):
+        from pyanalyze.value import WeakExtension
+
+        class A:
+            pass
+
+        class B(A):
+            pass
+
+        def capybara(a: A, iterable):
+            if isinstance(a, B):
+                assert_is_value(a, TypedValue(B))
+                lst = [a for _ in iterable]
+                assert_is_value(
+                    lst,
+                    AnnotatedValue(
+                        GenericValue(list, [TypedValue(B)]), [WeakExtension()]
+                    ),
+                )
+
     @assert_passes()
     def test_qcore_asserts(self):
-        from qcore.asserts import assert_is, assert_is_not, assert_is_instance
+        from qcore.asserts import assert_is_instance
 
         def capybara(cond):
             if cond:
@@ -1608,7 +1648,7 @@ def eat_no_assign(self):
     def test_subscript(self):
         from typing import Any, Dict
 
-        def capybara(x: Dict[str, Any]) -> None:
+        def capybara(x: Dict[str, Any], y) -> None:
             assert_is_value(x["a"], AnyValue(AnySource.explicit))
             x["a"] = 1
             assert_is_value(x["a"], KnownValue(1))
@@ -1629,3 +1669,78 @@ def test_uniq_chain():
     assert [] == uniq_chain([])
     assert list(range(3)) == uniq_chain(range(3) for _ in range(3))
     assert [1] == uniq_chain([1, 1, 1] for _ in range(3))
+
+
+class TestInvalidation(TestNameCheckVisitorBase):
+    @assert_passes()
+    def test_still_valid(self) -> None:
+        def capybara(x, y):
+            condition = isinstance(x, int)
+            assert_is_value(x, AnyValue(AnySource.unannotated))
+            if condition:
+                assert_is_value(x, TypedValue(int))
+
+            condition = isinstance(y, int) if x else isinstance(y, str)
+            assert_is_value(y, AnyValue(AnySource.unannotated))
+            if condition:
+                assert_is_value(y, TypedValue(int) | TypedValue(str))
+
+    @assert_passes()
+    def test_invalidated(self) -> None:
+        def capybara(x, y):
+            condition = isinstance(x, int)
+            assert_is_value(x, AnyValue(AnySource.unannotated))
+            x = y
+            if condition:
+                assert_is_value(x, AnyValue(AnySource.unannotated))
+
+    @assert_passes()
+    def test_other_scope(self) -> None:
+        def callee(x):
+            return isinstance(x, int)
+
+        def capybara(x, y):
+            if callee(y):
+                assert_is_value(x, AnyValue(AnySource.unannotated))
+                assert_is_value(y, AnyValue(AnySource.unannotated))
+
+    @assert_passes()
+    def test_while(self) -> None:
+        from typing import Optional
+
+        def make_optional() -> Optional[str]:
+            return "x"
+
+        def capybara():
+            x = make_optional()
+            while x:
+                assert_is_value(x, TypedValue(str))
+                x = make_optional()
+
+    @assert_passes()
+    def test_len_condition(self) -> None:
+        def capybara(file_list, key, ids):
+            has_bias = len(key) > 0
+            data = []
+            for _ in file_list:
+                assert_is_value(key, AnyValue(AnySource.unannotated))
+                if has_bias:
+                    assert_is_value(key, AnyValue(AnySource.unannotated))
+                    data = [ids, data[key]]
+                else:
+                    data = [ids]
+
+    @assert_passes()
+    def test_len_condition_with_type(self) -> None:
+        from typing import Optional
+
+        def capybara(file_list, key: Optional[int], ids):
+            has_bias = key is not None
+            data = []
+            for _ in file_list:
+                assert_is_value(key, TypedValue(int) | KnownValue(None))
+                if has_bias:
+                    assert_is_value(key, TypedValue(int))
+                    data = [ids, data[key]]
+                else:
+                    data = [ids]
diff --git a/pyanalyze/test_value.py b/pyanalyze/test_value.py
index dbf087c5..75ba882d 100644
--- a/pyanalyze/test_value.py
+++ b/pyanalyze/test_value.py
@@ -33,6 +33,7 @@
     SequenceIncompleteValue,
     TypeVarMap,
     concrete_values_from_iterable,
+    unite_and_simplify,
 )
 
 _checker = Checker(TestConfig())
@@ -562,3 +563,10 @@ def test_pickling() -> None:
     _assert_pickling_roundtrip(KnownValue(1))
     _assert_pickling_roundtrip(TypedValue(int))
     _assert_pickling_roundtrip(KnownValue(None) | TypedValue(str))
+
+
+def test_unite_and_simplify() -> None:
+    vals = [GenericValue(list, [TypedValue(int)]), KnownValue([])]
+    assert unite_and_simplify(*vals, limit=2) == GenericValue(
+        list, [TypedValue(int)]
+    ) | GenericValue(list, [AnyValue(AnySource.unreachable)])
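For reference, a sketch of what the new test_unite_and_simplify test above exercises: when a union of list values is simplified, KnownValue([]) contributes no element type, so it widens to a list parameterized with the new AnySource.unreachable Any (see the simplify() changes in value.py below) rather than NoReturn:

    from pyanalyze.value import (
        AnySource,
        AnyValue,
        GenericValue,
        KnownValue,
        TypedValue,
        unite_and_simplify,
    )

    result = unite_and_simplify(
        GenericValue(list, [TypedValue(int)]), KnownValue([]), limit=2
    )
    # Per the test above, result equals
    # GenericValue(list, [TypedValue(int)])
    #     | GenericValue(list, [AnyValue(AnySource.unreachable)])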
diff --git a/pyanalyze/value.py b/pyanalyze/value.py
index afed0a51..3427a099 100644
--- a/pyanalyze/value.py
+++ b/pyanalyze/value.py
@@ -56,6 +56,7 @@ def function(x: int, y: list[int], z: Any):
 # __builtin__ in Python 2 and builtins in Python 3
 BUILTIN_MODULE = str.__module__
 KNOWN_MUTABLE_TYPES = (list, set, dict, deque)
+ITERATION_LIMIT = 1000
 
 TypeVarMap = Mapping["TypeVar", "Value"]
 GenericBases = Mapping[Union[type, str], TypeVarMap]
@@ -271,7 +272,7 @@ def __str__(self) -> str:
 CanAssign = Union[TypeVarMap, CanAssignError]
 
 
-def assert_is_value(obj: object, value: Value) -> None:
+def assert_is_value(obj: object, value: Value, *, skip_annotated: bool = False) -> None:
     """Used to test pyanalyze's value inference.
 
     Takes two arguments: a Python object and a :class:`Value` object. At runtime
@@ -283,6 +284,8 @@ def assert_is_value(obj: object, value: Value) -> None:
         assert_is_value(1, KnownValue(1))  # passes
         assert_is_value(1, TypedValue(int))  # shows an error
 
+    If skip_annotated is True, unwraps any :class:`AnnotatedValue` in the input.
+
     """
     pass
 
@@ -908,7 +911,10 @@ def simplify(self) -> GenericValue:
                 tuple, [member.simplify() for member in self.members]
             )
         members = [member.simplify() for member in self.members]
-        return GenericValue(self.typ, [unite_values(*members)])
+        arg = unite_values(*members)
+        if arg is NO_RETURN_VALUE:
+            arg = AnyValue(AnySource.unreachable)
+        return GenericValue(self.typ, [arg])
 
 
 @dataclass(frozen=True)
@@ -980,7 +986,13 @@ def substitute_typevars(self, typevars: TypeVarMap) -> Value:
     def simplify(self) -> GenericValue:
         keys = [pair.key.simplify() for pair in self.kv_pairs]
         values = [pair.value.simplify() for pair in self.kv_pairs]
-        return GenericValue(self.typ, [unite_values(*keys), unite_values(*values)])
+        key = unite_values(*keys)
+        value = unite_values(*values)
+        if key is NO_RETURN_VALUE:
+            key = AnyValue(AnySource.unreachable)
+        if value is NO_RETURN_VALUE:
+            value = AnyValue(AnySource.unreachable)
+        return GenericValue(self.typ, [key, value])
 
     @property
     def items(self) -> Sequence[Tuple[Value, Value]]:
@@ -1644,6 +1656,18 @@ def walk_values(self) -> Iterable[Value]:
             yield from self.attribute_type.walk_values()
 
 
+@dataclass(frozen=True, eq=False)
+class ConstraintExtension(Extension):
+    """Encapsulates a Constraint. If the value is evaluated and is truthy, the
+    constraint must be True."""
+
+    constraint: "pyanalyze.stacked_scopes.AbstractConstraint"
+
+    # Comparing them can get too expensive
+    def __hash__(self) -> int:
+        return id(self)
+
+
 @dataclass(frozen=True)
 class WeakExtension(Extension):
     """Used to indicate that a generic argument to a container may be widened.
@@ -1955,7 +1979,10 @@ def concrete_values_from_iterable(
         if all(pair.is_required and not pair.is_many for pair in value.kv_pairs):
             return [pair.key for pair in value.kv_pairs]
     elif isinstance(value, KnownValue):
-        if isinstance(value.val, (str, bytes, range)):
+        if (
+            isinstance(value.val, (str, bytes, range))
+            and len(value.val) < ITERATION_LIMIT
+        ):
             return [KnownValue(c) for c in value.val]
     elif value is NO_RETURN_VALUE:
         return NO_RETURN_VALUE
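The ITERATION_LIMIT guard above stops concrete_values_from_iterable() from unrolling an enormous known str/bytes/range into one KnownValue per element. A hedged sketch of the intended effect (ctx stands for a CanAssignContext; the over-limit behavior is assumed to fall through to the generic iterable handling, which this hunk does not show):

    concrete_values_from_iterable(KnownValue(range(3)), ctx)
    # -> [KnownValue(0), KnownValue(1), KnownValue(2)]

    concrete_values_from_iterable(KnownValue(range(10**6)), ctx)
    # len() is at or above ITERATION_LIMIT, so the range is no longer expanded
    # element by element.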
@@ -1983,19 +2010,19 @@ def kv_pairs_from_mapping(
     value_val = replace_known_sequence_value(value_val)
     # Special case: if we have a Union including an empty dict, just get the
     # pairs from the rest of the union and make them all non-required.
-    if isinstance(value_val, MultiValuedValue) and any(
-        subval in EMPTY_DICTS for subval in value_val.vals
-    ):
-        other_val = unite_values(
-            *[subval for subval in value_val.vals if subval not in EMPTY_DICTS]
-        )
-        pairs = kv_pairs_from_mapping(other_val, ctx)
-        if isinstance(pairs, CanAssignError):
-            return pairs
-        return [
-            KVPair(pair.key, pair.value, pair.is_many, is_required=False)
-            for pair in pairs
-        ]
+    if isinstance(value_val, MultiValuedValue):
+        subvals = [replace_known_sequence_value(subval) for subval in value_val.vals]
+        if any(subval in EMPTY_DICTS for subval in subvals):
+            other_val = unite_values(
+                *[subval for subval in subvals if subval not in EMPTY_DICTS]
+            )
+            pairs = kv_pairs_from_mapping(other_val, ctx)
+            if isinstance(pairs, CanAssignError):
+                return pairs
+            return [
+                KVPair(pair.key, pair.value, pair.is_many, is_required=False)
+                for pair in pairs
+            ]
     if isinstance(value_val, DictIncompleteValue):
         return value_val.kv_pairs
     elif isinstance(value_val, TypedDictValue):