From 70c682161d4ba98932b9fc1eb013b7173bf3976a Mon Sep 17 00:00:00 2001 From: Knut Wannheden Date: Thu, 29 Aug 2024 08:42:43 +0200 Subject: [PATCH] Fix a few more parser bugs --- rewrite/rewrite/python/_parser_visitor.py | 67 ++++++++++++++++--- rewrite/tests/python/all/class_test.py | 12 ++++ .../python/all/method_declaration_test.py | 4 ++ rewrite/tests/python/all/try_test.py | 18 +++++ rewrite/tests/python/all/with_test.py | 16 +++++ 5 files changed, 108 insertions(+), 9 deletions(-) create mode 100644 rewrite/tests/python/all/with_test.py diff --git a/rewrite/rewrite/python/_parser_visitor.py b/rewrite/rewrite/python/_parser_visitor.py index 8faf928..56681b4 100644 --- a/rewrite/rewrite/python/_parser_visitor.py +++ b/rewrite/rewrite/python/_parser_visitor.py @@ -56,7 +56,7 @@ def visit_arguments(self, node) -> JContainer[j.VariableDeclarations]: args = [self.__pad_list_element( self.map_arg(a, node.defaults[i - len(node.defaults)] if i >= first_with_default else None), i == len(node.args) - 1, - end_delim=')') for i, n in enumerate(node.args)] + end_delim=')') for i, a in enumerate(node.args)] return JContainer(prefix, args, Markers.EMPTY) def map_arg(self, node, default=None): @@ -137,6 +137,28 @@ def visit_AsyncFunctionDef(self, node): def visit_ClassDef(self, node): prefix = self.__whitespace() if node.decorator_list else self.__source_before('class') + name = self.__convert_name(node.name) + save_cursor = self._cursor + interfaces_prefix = self.__whitespace() + if self._source[self._cursor] == '(' and node.bases: + self.__skip('(') + interfaces = JContainer( + interfaces_prefix, + [ + self.__pad_list_element(self.__convert(n), i == len(node.bases) - 1, end_delim=')') for i, n in + enumerate(node.bases)], + Markers.EMPTY + ) + elif self._source[self._cursor] == '(': + self.__skip('(') + interfaces = JContainer( + interfaces_prefix, + [self.__pad_right(j.Empty(random_id(), self.__source_before(')'), Markers.EMPTY), Space.EMPTY)], + Markers.EMPTY + ) + else: + interfaces = None + self._cursor = save_cursor return j.ClassDeclaration( random_id(), prefix, @@ -150,16 +172,11 @@ def visit_ClassDef(self, node): [], j.ClassDeclaration.Kind.Type.Class ), - self.__convert_name(node.name), + name, None, None, None, # no `extends`, all in `implements` - None if not node.bases else JContainer( - self.__source_before('('), - [self.__pad_list_element(self.__convert(n), i == len(node.bases) - 1, end_delim=')') for i, n in - enumerate(node.bases)], - Markers.EMPTY, - ), + interfaces, None, self.__convert_block(node.body), self.__map_type(node) @@ -238,7 +255,39 @@ def visit_If(self, node): ) def visit_With(self, node): - raise NotImplementedError("Implement visit_With!") + return j.Try( + random_id(), + self.__source_before('with'), + Markers.EMPTY, + JContainer( + self.__whitespace(), + [self.__pad_list_element(self.__convert(r), i == len(node.items) - 1) for i, r in enumerate(node.items)], + Markers.EMPTY + ), + self.__convert_block(node.body), + [], + None + ) + + def visit_withitem(self, node): + prefix = self.__whitespace() + expr = self.__convert(node.context_expr) + value = self.__pad_left(self.__source_before('as'), expr) + name = self.__convert(node.optional_vars) + return j.Try.Resource( + random_id(), + prefix, + Markers.EMPTY, + j.Assignment( + random_id(), + Space.EMPTY, + Markers.EMPTY, + name, + value, + self.__map_type(node.context_expr) + ), + False + ) def visit_AsyncWith(self, node): raise NotImplementedError("Implement visit_AsyncWith!") diff --git a/rewrite/tests/python/all/class_test.py b/rewrite/tests/python/all/class_test.py index 3a75975..20163ce 100644 --- a/rewrite/tests/python/all/class_test.py +++ b/rewrite/tests/python/all/class_test.py @@ -49,3 +49,15 @@ class Foo(abc.ABC, abc.ABC,): """ ) ) + + +def test_empty_parens(): + # language=python + rewrite_run( + python( + """\ + class Foo ( ): + pass + """ + ) + ) diff --git a/rewrite/tests/python/all/method_declaration_test.py b/rewrite/tests/python/all/method_declaration_test.py index 7390f4b..a194b15 100644 --- a/rewrite/tests/python/all/method_declaration_test.py +++ b/rewrite/tests/python/all/method_declaration_test.py @@ -13,3 +13,7 @@ def foo() : """ ) ) + +def test_one_line(): + # language=python + rewrite_run(python("def f(x): x = x + 1; return x")) diff --git a/rewrite/tests/python/all/try_test.py b/rewrite/tests/python/all/try_test.py index 5b70d7f..a4c0dc4 100644 --- a/rewrite/tests/python/all/try_test.py +++ b/rewrite/tests/python/all/try_test.py @@ -37,3 +37,21 @@ def test(): """ ) ) + + +@pytest.mark.xfail(reason="Implementation still not quite correct", strict=True) +def test_try_else(): + # language=python + rewrite_run( + python( + """\ + def test(): + try: + result = 1 / 1 + except ZeroDivisionError: + print("Caught a division by zero error!") + else: + print("No error occurred, result is:", result) + """ + ) + ) diff --git a/rewrite/tests/python/all/with_test.py b/rewrite/tests/python/all/with_test.py new file mode 100644 index 0000000..b6dfe76 --- /dev/null +++ b/rewrite/tests/python/all/with_test.py @@ -0,0 +1,16 @@ +import pytest + +from rewrite.test import rewrite_run, python + + +def test_with(): + # language=python + rewrite_run( + python( + """\ + def test(i): + with len([]) as x: + pass + """ + ) + )