From 30c72e1e88b14a41fcbd73898c80249e4a46d0c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Simon?= Date: Sun, 11 Feb 2024 15:37:26 +0100 Subject: [PATCH 1/3] Narrow individual items when matching a tuple to a sequence pattern --- mypy/checker.py | 17 +++++++++++++ test-data/unit/check-python310.test | 38 +++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/mypy/checker.py b/mypy/checker.py index 391f28e93b1d..a32851ed63c5 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5086,6 +5086,9 @@ def visit_match_stmt(self, s: MatchStmt) -> None: ) self.remove_capture_conflicts(pattern_type.captures, inferred_types) self.push_type_map(pattern_map) + if pattern_map: + for expr, typ in pattern_map.items(): + self.push_type_map(self._get_recursive_sub_patterns_map(expr, typ)) self.push_type_map(pattern_type.captures) if g is not None: with self.binder.frame_context(can_skip=False, fall_through=3): @@ -5123,6 +5126,20 @@ def visit_match_stmt(self, s: MatchStmt) -> None: with self.binder.frame_context(can_skip=False, fall_through=2): pass + def _get_recursive_sub_patterns_map( + self, expr: Expression, typ: Type + ) -> dict[Expression, Type]: + sub_patterns_map = dict[Expression, Type]() + typ_ = get_proper_type(typ) + if isinstance(expr, TupleExpr) and isinstance(typ_, TupleType): + # When matching a tuple expression with a sequence pattern, narrow individual tuple items + assert len(expr.items) == len(typ_.items) + for item_expr, item_typ in zip(expr.items, typ_.items): + sub_patterns_map[item_expr] = item_typ + sub_patterns_map.update(self._get_recursive_sub_patterns_map(item_expr, item_typ)) + + return sub_patterns_map + def infer_variable_types_from_type_maps(self, type_maps: list[TypeMap]) -> dict[Var, Type]: all_captures: dict[Var, list[tuple[NameExpr, Type]]] = defaultdict(list) for tm in type_maps: diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index cbb26a130738..892ab7e28297 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -341,6 +341,44 @@ match m: reveal_type(m) # N: Revealed type is "builtins.list[builtins.list[builtins.str]]" [builtins fixtures/list.pyi] +[case testMatchSequencePatternNarrowSubjectItems] +m: int +n: str +o: bool + +match m, n, o: + case [3, "foo", True]: + reveal_type(m) # N: Revealed type is "Literal[3]" + reveal_type(n) # N: Revealed type is "Literal['foo']" + reveal_type(o) # N: Revealed type is "Literal[True]" + case [a, b, c]: + reveal_type(m) # N: Revealed type is "builtins.int" + reveal_type(n) # N: Revealed type is "builtins.str" + reveal_type(o) # N: Revealed type is "builtins.bool" + +reveal_type(m) # N: Revealed type is "builtins.int" +reveal_type(n) # N: Revealed type is "builtins.str" +reveal_type(o) # N: Revealed type is "builtins.bool" +[builtins fixtures/tuple.pyi] + +[case testMatchSequencePatternNarrowSubjectItemsRecursive] +m: int +n: int +o: int +p: int +q: int +r: int + +match m, (n, o), (p, (q, r)): + case [0, [1, 2], [3, [4, 5]]]: + reveal_type(m) # N: Revealed type is "Literal[0]" + reveal_type(n) # N: Revealed type is "Literal[1]" + reveal_type(o) # N: Revealed type is "Literal[2]" + reveal_type(p) # N: Revealed type is "Literal[3]" + reveal_type(q) # N: Revealed type is "Literal[4]" + reveal_type(r) # N: Revealed type is "Literal[5]" +[builtins fixtures/tuple.pyi] + -- Mapping Pattern -- [case testMatchMappingPatternCaptures] From 1c33addc565ea37d58abc9e97daebbb70814954d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Simon?= Date: Sun, 11 Feb 2024 16:04:12 +0100 Subject: [PATCH 2/3] Fix annotation for Python < 3.9 --- mypy/checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/checker.py b/mypy/checker.py index a32851ed63c5..fa6215399fc1 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5129,7 +5129,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None: def _get_recursive_sub_patterns_map( self, expr: Expression, typ: Type ) -> dict[Expression, Type]: - sub_patterns_map = dict[Expression, Type]() + sub_patterns_map: dict[Expression, Type] = {} typ_ = get_proper_type(typ) if isinstance(expr, TupleExpr) and isinstance(typ_, TupleType): # When matching a tuple expression with a sequence pattern, narrow individual tuple items From 92cb2c54b777381a43027aee157b6503cb35b9c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Simon?= Date: Tue, 2 Apr 2024 20:48:26 +0200 Subject: [PATCH 3/3] Add test cases with mismatching subject and target lengths --- test-data/unit/check-python310.test | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 892ab7e28297..3586c192275e 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -379,6 +379,34 @@ match m, (n, o), (p, (q, r)): reveal_type(r) # N: Revealed type is "Literal[5]" [builtins fixtures/tuple.pyi] +[case testMatchSequencePatternSequencesLengthMismatchNoNarrowing] +m: int +n: str +o: bool + +match m, n, o: + case [3, "foo"]: + pass + case [3, "foo", True, True]: + pass +[builtins fixtures/tuple.pyi] + +[case testMatchSequencePatternSequencesLengthMismatchNoNarrowingRecursive] +m: int +n: int +o: int + +match m, (n, o): + case [0]: + pass + case [0, 1, [2]]: + pass + case [0, [1]]: + pass + case [0, [1, 2, 3]]: + pass +[builtins fixtures/tuple.pyi] + -- Mapping Pattern -- [case testMatchMappingPatternCaptures]