diff --git a/CHANGES.md b/CHANGES.md index 514fd14036b..405b71a6c2f 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -19,6 +19,9 @@ +- Fix crashes involving comments in parenthesised return types or `X | Y` style unions. + (#4453) + ### Preview style diff --git a/src/black/linegen.py b/src/black/linegen.py index 46945ca2a14..ba6e906a388 100644 --- a/src/black/linegen.py +++ b/src/black/linegen.py @@ -1079,6 +1079,47 @@ def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None ) +def _ensure_trailing_comma( + leaves: List[Leaf], original: Line, opening_bracket: Leaf +) -> bool: + if not leaves: + return False + # Ensure a trailing comma for imports + if original.is_import: + return True + # ...and standalone function arguments + if not original.is_def: + return False + if opening_bracket.value != "(": + return False + # Don't add commas if we already have any commas + if any( + leaf.type == token.COMMA + and ( + Preview.typed_params_trailing_comma not in original.mode + or not is_part_of_annotation(leaf) + ) + for leaf in leaves + ): + return False + + # Find a leaf with a parent (comments don't have parents) + leaf_with_parent = next((leaf for leaf in leaves if leaf.parent), None) + if leaf_with_parent is None: + return True + # Don't add commas inside parenthesized return annotations + if get_annotation_type(leaf_with_parent) == "return": + return False + # Don't add commas inside PEP 604 unions + if ( + leaf_with_parent.parent + and leaf_with_parent.parent.next_sibling + and leaf_with_parent.parent.next_sibling.type == token.VBAR + ): + return False + return True + + def bracket_split_build_line( leaves: List[Leaf], original: Line, @@ -1099,40 +1140,15 @@ def bracket_split_build_line( if component is _BracketSplitComponent.body: result.inside_brackets = True result.depth += 1 - if leaves: - no_commas = ( - # Ensure a trailing comma for imports and standalone function arguments - original.is_def - # Don't add one after any comments or within type annotations - and opening_bracket.value == "(" - # Don't add one if there's already one there - and not any( - leaf.type == token.COMMA - and ( - Preview.typed_params_trailing_comma not in original.mode - or not is_part_of_annotation(leaf) - ) - for leaf in leaves - ) - # Don't add one inside parenthesized return annotations - and get_annotation_type(leaves[0]) != "return" - # Don't add one inside PEP 604 unions - and not ( - leaves[0].parent - and leaves[0].parent.next_sibling - and leaves[0].parent.next_sibling.type == token.VBAR - ) - ) - - if original.is_import or no_commas: - for i in range(len(leaves) - 1, -1, -1): - if leaves[i].type == STANDALONE_COMMENT: - continue + if _ensure_trailing_comma(leaves, original, opening_bracket): + for i in range(len(leaves) - 1, -1, -1): + if leaves[i].type == STANDALONE_COMMENT: + continue - if leaves[i].type != token.COMMA: - new_comma = Leaf(token.COMMA, ",") - leaves.insert(i + 1, new_comma) - break + if leaves[i].type != token.COMMA: + new_comma = Leaf(token.COMMA, ",") + leaves.insert(i + 1, new_comma) + break leaves_to_track: Set[LeafID] = set() if component is _BracketSplitComponent.head: diff --git a/src/black/nodes.py b/src/black/nodes.py index dae787939ea..bf8e9e1a36a 100644 --- a/src/black/nodes.py +++ b/src/black/nodes.py @@ -1012,6 +1012,7 @@ def get_annotation_type(leaf: Leaf) -> Literal["return", "param", None]: def is_part_of_annotation(leaf: Leaf) -> bool: """Returns whether this leaf is part of a type annotation.""" + assert leaf.parent is not None return get_annotation_type(leaf) is not None diff --git a/src/black/trans.py b/src/black/trans.py index 29a978c6b71..1853584108d 100644 --- a/src/black/trans.py +++ b/src/black/trans.py @@ -488,7 +488,7 @@ def do_match(self, line: Line) -> TMatchResult: break i += 1 - if not is_part_of_annotation(leaf) and not contains_comment: + if not contains_comment and not is_part_of_annotation(leaf): string_indices.append(idx) # Advance to the next non-STRING leaf. diff --git a/tests/data/cases/funcdef_return_type_trailing_comma.py b/tests/data/cases/funcdef_return_type_trailing_comma.py index 9b9b9c673de..14fd763d9d1 100644 --- a/tests/data/cases/funcdef_return_type_trailing_comma.py +++ b/tests/data/cases/funcdef_return_type_trailing_comma.py @@ -142,6 +142,7 @@ def SimplePyFn( Buffer[UInt8, 2], Buffer[UInt8, 2], ]: ... + # output # normal, short, function definition def foo(a, b) -> tuple[int, float]: ... diff --git a/tests/data/cases/function_trailing_comma.py b/tests/data/cases/function_trailing_comma.py index 92f46e27516..63cf3999c2e 100644 --- a/tests/data/cases/function_trailing_comma.py +++ b/tests/data/cases/function_trailing_comma.py @@ -60,6 +60,64 @@ def func() -> ((also_super_long_type_annotation_that_may_cause_an_AST_related_cr argument1, (one, two,), argument4, argument5, argument6 ) +def foo() -> ( + # comment inside parenthesised return type + int +): + ... + +def foo() -> ( + # comment inside parenthesised return type + # more + int + # another +): + ... + +def foo() -> ( + # comment inside parenthesised new union return type + int | str | bytes +): + ... + +def foo() -> ( + # comment inside plain tuple +): + pass + +def foo(arg: (# comment with non-return annotation + int + # comment with non-return annotation +)): + pass + +def foo(arg: (# comment with non-return annotation + int | range | memoryview + # comment with non-return annotation +)): + pass + +def foo(arg: (# only before + int +)): + pass + +def foo(arg: ( + int + # only after +)): + pass + +variable: ( # annotation + because + # why not +) + +variable: ( + because + # why not +) + # output def f( @@ -176,3 +234,75 @@ def func() -> ( argument5, argument6, ) + + +def foo() -> ( + # comment inside parenthesised return type + int +): ... + + +def foo() -> ( + # comment inside parenthesised return type + # more + int + # another +): ... + + +def foo() -> ( + # comment inside parenthesised new union return type + int + | str + | bytes +): ... + + +def foo() -> ( + # comment inside plain tuple +): + pass + + +def foo( + arg: ( # comment with non-return annotation + int + # comment with non-return annotation + ), +): + pass + + +def foo( + arg: ( # comment with non-return annotation + int + | range + | memoryview + # comment with non-return annotation + ), +): + pass + + +def foo(arg: int): # only before + pass + + +def foo( + arg: ( + int + # only after + ), +): + pass + + +variable: ( # annotation + because + # why not +) + +variable: ( + because + # why not +)