-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[red-knot] Type inference for comparisons involving intersection types (
#14138) ## Summary This adds type inference for comparison expressions involving intersection types. For example: ```py x = get_random_int() if x != 42: reveal_type(x == 42) # revealed: Literal[False] reveal_type(x == 43) # bool ``` closes #13854 ## Test Plan New Markdown-based tests. --------- Co-authored-by: Carl Meyer <carl@astral.sh>
- Loading branch information
Showing
2 changed files
with
262 additions
and
4 deletions.
There are no files selected for viewing
155 changes: 155 additions & 0 deletions
155
crates/red_knot_python_semantic/resources/mdtest/comparison/intersections.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
# Comparison: Intersections | ||
|
||
## Positive contributions | ||
|
||
If we have an intersection type `A & B` and we get a definitive true/false answer for one of the | ||
types, we can infer that the result for the intersection type is also true/false: | ||
|
||
```py | ||
class Base: ... | ||
|
||
class Child1(Base): | ||
def __eq__(self, other) -> Literal[True]: | ||
return True | ||
|
||
class Child2(Base): ... | ||
|
||
def get_base() -> Base: ... | ||
|
||
x = get_base() | ||
c1 = Child1() | ||
|
||
# Create an intersection type through narrowing: | ||
if isinstance(x, Child1): | ||
if isinstance(x, Child2): | ||
reveal_type(x) # revealed: Child1 & Child2 | ||
|
||
reveal_type(x == 1) # revealed: Literal[True] | ||
|
||
# Other comparison operators fall back to the base type: | ||
reveal_type(x > 1) # revealed: bool | ||
reveal_type(x is c1) # revealed: bool | ||
``` | ||
|
||
## Negative contributions | ||
|
||
Negative contributions to the intersection type only allow simplifications in a few special cases | ||
(equality and identity comparisons). | ||
|
||
### Equality comparisons | ||
|
||
#### Literal strings | ||
|
||
```py | ||
x = "x" * 1_000_000_000 | ||
y = "y" * 1_000_000_000 | ||
reveal_type(x) # revealed: LiteralString | ||
|
||
if x != "abc": | ||
reveal_type(x) # revealed: LiteralString & ~Literal["abc"] | ||
|
||
reveal_type(x == "abc") # revealed: Literal[False] | ||
reveal_type("abc" == x) # revealed: Literal[False] | ||
reveal_type(x == "something else") # revealed: bool | ||
reveal_type("something else" == x) # revealed: bool | ||
|
||
reveal_type(x != "abc") # revealed: Literal[True] | ||
reveal_type("abc" != x) # revealed: Literal[True] | ||
reveal_type(x != "something else") # revealed: bool | ||
reveal_type("something else" != x) # revealed: bool | ||
|
||
reveal_type(x == y) # revealed: bool | ||
reveal_type(y == x) # revealed: bool | ||
reveal_type(x != y) # revealed: bool | ||
reveal_type(y != x) # revealed: bool | ||
|
||
reveal_type(x >= "abc") # revealed: bool | ||
reveal_type("abc" >= x) # revealed: bool | ||
|
||
reveal_type(x in "abc") # revealed: bool | ||
reveal_type("abc" in x) # revealed: bool | ||
``` | ||
|
||
#### Integers | ||
|
||
```py | ||
def get_int() -> int: ... | ||
|
||
x = get_int() | ||
|
||
if x != 1: | ||
reveal_type(x) # revealed: int & ~Literal[1] | ||
|
||
reveal_type(x != 1) # revealed: Literal[True] | ||
reveal_type(x != 2) # revealed: bool | ||
|
||
reveal_type(x == 1) # revealed: Literal[False] | ||
reveal_type(x == 2) # revealed: bool | ||
``` | ||
|
||
### Identity comparisons | ||
|
||
```py | ||
class A: ... | ||
|
||
def get_object() -> object: ... | ||
|
||
o = object() | ||
|
||
a = A() | ||
n = None | ||
|
||
if o is not None: | ||
reveal_type(o) # revealed: object & ~None | ||
|
||
reveal_type(o is n) # revealed: Literal[False] | ||
reveal_type(o is not n) # revealed: Literal[True] | ||
``` | ||
|
||
## Diagnostics | ||
|
||
### Unsupported operators for positive contributions | ||
|
||
Raise an error if any of the positive contributions to the intersection type are unsupported for the | ||
given operator: | ||
|
||
```py | ||
class Container: | ||
def __contains__(self, x) -> bool: ... | ||
|
||
class NonContainer: ... | ||
|
||
def get_object() -> object: ... | ||
|
||
x = get_object() | ||
|
||
if isinstance(x, Container): | ||
if isinstance(x, NonContainer): | ||
reveal_type(x) # revealed: Container & NonContainer | ||
|
||
# error: [unsupported-operator] "Operator `in` is not supported for types `int` and `NonContainer`" | ||
reveal_type(2 in x) # revealed: bool | ||
``` | ||
|
||
### Unsupported operators for negative contributions | ||
|
||
Do *not* raise an error if any of the negative contributions to the intersection type are | ||
unsupported for the given operator: | ||
|
||
```py | ||
class Container: | ||
def __contains__(self, x) -> bool: ... | ||
|
||
class NonContainer: ... | ||
|
||
def get_object() -> object: ... | ||
|
||
x = get_object() | ||
|
||
if isinstance(x, Container): | ||
if not isinstance(x, NonContainer): | ||
reveal_type(x) # revealed: Container & ~NonContainer | ||
|
||
# No error here! | ||
reveal_type(2 in x) # revealed: bool | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters