diff --git a/crates/red_knot_python_semantic/resources/mdtest/comparison/intersections.md b/crates/red_knot_python_semantic/resources/mdtest/comparison/intersections.md new file mode 100644 index 0000000000000..39efa500fa1fe --- /dev/null +++ b/crates/red_knot_python_semantic/resources/mdtest/comparison/intersections.md @@ -0,0 +1,155 @@ +# Comparison: Intersections + +## Positive contributions + +If we have an intersection type `A & B` and we get a definitive true/false answer for one of the +types, we can infer that the result for the intersection type is also true/false: + +```py +class Base: ... + +class Child1(Base): + def __eq__(self, other) -> Literal[True]: + return True + +class Child2(Base): ... + +def get_base() -> Base: ... + +x = get_base() +c1 = Child1() + +# Create an intersection type through narrowing: +if isinstance(x, Child1): + if isinstance(x, Child2): + reveal_type(x) # revealed: Child1 & Child2 + + reveal_type(x == 1) # revealed: Literal[True] + + # Other comparison operators fall back to the base type: + reveal_type(x > 1) # revealed: bool + reveal_type(x is c1) # revealed: bool +``` + +## Negative contributions + +Negative contributions to the intersection type only allow simplifications in a few special cases +(equality and identity comparisons). + +### Equality comparisons + +#### Literal strings + +```py +x = "x" * 1_000_000_000 +y = "y" * 1_000_000_000 +reveal_type(x) # revealed: LiteralString + +if x != "abc": + reveal_type(x) # revealed: LiteralString & ~Literal["abc"] + + reveal_type(x == "abc") # revealed: Literal[False] + reveal_type("abc" == x) # revealed: Literal[False] + reveal_type(x == "something else") # revealed: bool + reveal_type("something else" == x) # revealed: bool + + reveal_type(x != "abc") # revealed: Literal[True] + reveal_type("abc" != x) # revealed: Literal[True] + reveal_type(x != "something else") # revealed: bool + reveal_type("something else" != x) # revealed: bool + + reveal_type(x == y) # revealed: bool + reveal_type(y == x) # revealed: bool + reveal_type(x != y) # revealed: bool + reveal_type(y != x) # revealed: bool + + reveal_type(x >= "abc") # revealed: bool + reveal_type("abc" >= x) # revealed: bool + + reveal_type(x in "abc") # revealed: bool + reveal_type("abc" in x) # revealed: bool +``` + +#### Integers + +```py +def get_int() -> int: ... + +x = get_int() + +if x != 1: + reveal_type(x) # revealed: int & ~Literal[1] + + reveal_type(x != 1) # revealed: Literal[True] + reveal_type(x != 2) # revealed: bool + + reveal_type(x == 1) # revealed: Literal[False] + reveal_type(x == 2) # revealed: bool +``` + +### Identity comparisons + +```py +class A: ... + +def get_object() -> object: ... + +o = object() + +a = A() +n = None + +if o is not None: + reveal_type(o) # revealed: object & ~None + + reveal_type(o is n) # revealed: Literal[False] + reveal_type(o is not n) # revealed: Literal[True] +``` + +## Diagnostics + +### Unsupported operators for positive contributions + +Raise an error if any of the positive contributions to the intersection type are unsupported for the +given operator: + +```py +class Container: + def __contains__(self, x) -> bool: ... + +class NonContainer: ... + +def get_object() -> object: ... + +x = get_object() + +if isinstance(x, Container): + if isinstance(x, NonContainer): + reveal_type(x) # revealed: Container & NonContainer + + # error: [unsupported-operator] "Operator `in` is not supported for types `int` and `NonContainer`" + reveal_type(2 in x) # revealed: bool +``` + +### Unsupported operators for negative contributions + +Do *not* raise an error if any of the negative contributions to the intersection type are +unsupported for the given operator: + +```py +class Container: + def __contains__(self, x) -> bool: ... + +class NonContainer: ... + +def get_object() -> object: ... + +x = get_object() + +if isinstance(x, Container): + if not isinstance(x, NonContainer): + reveal_type(x) # revealed: Container & ~NonContainer + + # No error here! + reveal_type(2 in x) # revealed: bool +``` diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index aae55c04c6991..ed26412c87a01 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -57,9 +57,9 @@ use crate::types::unpacker::{UnpackResult, Unpacker}; use crate::types::{ bindings_ty, builtins_symbol, declarations_ty, global_symbol, symbol, typing_extensions_symbol, Boundness, BytesLiteralType, Class, ClassLiteralType, FunctionType, InstanceType, - IterationOutcome, KnownClass, KnownFunction, KnownInstance, MetaclassErrorKind, - SliceLiteralType, StringLiteralType, Symbol, Truthiness, TupleType, Type, TypeArrayDisplay, - UnionBuilder, UnionType, + IntersectionBuilder, IntersectionType, IterationOutcome, KnownClass, KnownFunction, + KnownInstance, MetaclassErrorKind, SliceLiteralType, StringLiteralType, Symbol, Truthiness, + TupleType, Type, TypeArrayDisplay, UnionBuilder, UnionType, }; use crate::unpack::Unpack; use crate::util::subscript::{PyIndex, PySlice}; @@ -266,6 +266,13 @@ impl<'db> TypeInference<'db> { } } +/// Whether the intersection type is on the left or right side of the comparison. +#[derive(Debug, Clone, Copy)] +enum IntersectionOn { + Left, + Right, +} + /// Builder to infer all types in a region. /// /// A builder is used by creating it with [`new()`](TypeInferenceBuilder::new), and then calling @@ -3086,7 +3093,7 @@ impl<'db> TypeInferenceBuilder<'db> { // https://docs.python.org/3/reference/expressions.html#comparisons // > Formally, if `a, b, c, …, y, z` are expressions and `op1, op2, …, opN` are comparison - // > operators, then `a op1 b op2 c ... y opN z` is equivalent to a `op1 b and b op2 c and + // > operators, then `a op1 b op2 c ... y opN z` is equivalent to `a op1 b and b op2 c and // ... > y opN z`, except that each expression is evaluated at most once. // // As some operators (==, !=, <, <=, >, >=) *can* return an arbitrary type, the logic below @@ -3140,6 +3147,87 @@ impl<'db> TypeInferenceBuilder<'db> { ) } + fn infer_binary_intersection_type_comparison( + &mut self, + intersection: IntersectionType<'db>, + op: ast::CmpOp, + other: Type<'db>, + intersection_on: IntersectionOn, + ) -> Result, CompareUnsupportedError<'db>> { + // If a comparison yields a definitive true/false answer on a (positive) part + // of an intersection type, it will also yield a definitive answer on the full + // intersection type, which is even more specific. + for pos in intersection.positive(self.db) { + let result = match intersection_on { + IntersectionOn::Left => self.infer_binary_type_comparison(*pos, op, other)?, + IntersectionOn::Right => self.infer_binary_type_comparison(other, op, *pos)?, + }; + if let Type::BooleanLiteral(b) = result { + return Ok(Type::BooleanLiteral(b)); + } + } + + // For negative contributions to the intersection type, there are only a few + // special cases that allow us to narrow down the result type of the comparison. + for neg in intersection.negative(self.db) { + let result = match intersection_on { + IntersectionOn::Left => self.infer_binary_type_comparison(*neg, op, other).ok(), + IntersectionOn::Right => self.infer_binary_type_comparison(other, op, *neg).ok(), + }; + + match (op, result) { + (ast::CmpOp::Eq, Some(Type::BooleanLiteral(true))) => { + return Ok(Type::BooleanLiteral(false)); + } + (ast::CmpOp::NotEq, Some(Type::BooleanLiteral(false))) => { + return Ok(Type::BooleanLiteral(true)); + } + (ast::CmpOp::Is, Some(Type::BooleanLiteral(true))) => { + return Ok(Type::BooleanLiteral(false)); + } + (ast::CmpOp::IsNot, Some(Type::BooleanLiteral(false))) => { + return Ok(Type::BooleanLiteral(true)); + } + _ => {} + } + } + + // If none of the simplifications above apply, we still need to return *some* + // result type for the comparison 'T_inter `op` T_other' (or reversed), where + // + // T_inter = P1 & P2 & ... & Pn & ~N1 & ~N2 & ... & ~Nm + // + // is the intersection type. If f(T) is the function that computes the result + // type of a `op`-comparison with `T_other`, we are interested in f(T_inter). + // Since we can't compute it exactly, we return the following approximation: + // + // f(T_inter) = f(P1) & f(P2) & ... & f(Pn) + // + // The reason for this is the following: In general, for any function 'f', the + // set f(A) & f(B) can be *larger than* the set f(A & B). This means that we + // will return a type that is too wide, which is not necessarily problematic. + // + // However, we do have to leave out the negative contributions. If we were to + // add a contribution like ~f(N1), we would potentially infer result types + // that are too narrow, since ~f(A) can be larger than f(~A). + // + // As an example for this, consider the intersection type `int & ~Literal[1]`. + // If 'f' would be the `==`-comparison with 2, we obviously can't tell if that + // answer would be true or false, so we need to return `bool`. However, if we + // compute f(int) & ~f(Literal[1]), we get `bool & ~Literal[False]`, which can + // be simplified to `Literal[True]` -- a type that is too narrow. + let mut builder = IntersectionBuilder::new(self.db); + for pos in intersection.positive(self.db) { + let result = match intersection_on { + IntersectionOn::Left => self.infer_binary_type_comparison(*pos, op, other)?, + IntersectionOn::Right => self.infer_binary_type_comparison(other, op, *pos)?, + }; + builder = builder.add_positive(result); + } + + Ok(builder.build()) + } + /// Infers the type of a binary comparison (e.g. 'left == right'). See /// `infer_compare_expression` for the higher level logic dealing with multi-comparison /// expressions. @@ -3172,6 +3260,21 @@ impl<'db> TypeInferenceBuilder<'db> { Ok(builder.build()) } + (Type::Intersection(intersection), right) => self + .infer_binary_intersection_type_comparison( + intersection, + op, + right, + IntersectionOn::Left, + ), + (left, Type::Intersection(intersection)) => self + .infer_binary_intersection_type_comparison( + intersection, + op, + left, + IntersectionOn::Right, + ), + (Type::IntLiteral(n), Type::IntLiteral(m)) => match op { ast::CmpOp::Eq => Ok(Type::BooleanLiteral(n == m)), ast::CmpOp::NotEq => Ok(Type::BooleanLiteral(n != m)),