diff --git a/astroid/builder.py b/astroid/builder.py index ddc1aee71..b80090b51 100644 --- a/astroid/builder.py +++ b/astroid/builder.py @@ -12,6 +12,7 @@ import ast import os +import re import textwrap import types import warnings @@ -33,7 +34,6 @@ # The comment used to select a statement to be extracted # when calling extract_node. _STATEMENT_SELECTOR = "#@" -MISPLACED_TYPE_ANNOTATION_ERROR = "misplaced type annotation" if PY312_PLUS: warnings.filterwarnings("ignore", "invalid escape sequence", SyntaxWarning) @@ -478,9 +478,11 @@ def _parse_string( ) except SyntaxError as exc: # If the type annotations are misplaced for some reason, we do not want - # to fail the entire parsing of the file, so we need to retry the parsing without - # type comment support. - if exc.args[0] != MISPLACED_TYPE_ANNOTATION_ERROR or not type_comments: + # to fail the entire parsing of the file, so we need to retry the + # parsing without type comment support. We use a heuristic for + # determining if the error is due to type annotations. + type_annot_related = re.search(r"#\s+type:", exc.text or "") + if not (type_annot_related and type_comments): raise parser_module = get_parser_module(type_comments=False) diff --git a/tests/test_builder.py b/tests/test_builder.py index 0cc6fb3c4..9de7f16ba 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -883,12 +883,6 @@ def test_module_build_dunder_file() -> None: assert module.path[0] == collections.__file__ -@pytest.mark.xfail( - reason=( - "The builtin ast module does not fail with a specific error " - "for syntax error caused by invalid type comments." - ), -) def test_parse_module_with_invalid_type_comments_does_not_crash(): node = builder.parse( """ diff --git a/tests/test_nodes.py b/tests/test_nodes.py index cc2589396..644ceb150 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -1304,6 +1304,10 @@ def test_type_comments_invalid_expression() -> None: def test_type_comments_invalid_function_comments() -> None: module = builder.parse( """ + def func( + # type: () -> int # inside parentheses + ): + pass def func(): # type: something completely invalid pass