Skip to content

Commit

Permalink
Fix a few more parser bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
knutwannheden committed Aug 29, 2024
1 parent fc20f56 commit 70c6821
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 9 deletions.
67 changes: 58 additions & 9 deletions rewrite/rewrite/python/_parser_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def visit_arguments(self, node) -> JContainer[j.VariableDeclarations]:
args = [self.__pad_list_element(
self.map_arg(a, node.defaults[i - len(node.defaults)] if i >= first_with_default else None),
i == len(node.args) - 1,
end_delim=')') for i, n in enumerate(node.args)]
end_delim=')') for i, a in enumerate(node.args)]
return JContainer(prefix, args, Markers.EMPTY)

def map_arg(self, node, default=None):
Expand Down Expand Up @@ -137,6 +137,28 @@ def visit_AsyncFunctionDef(self, node):

def visit_ClassDef(self, node):
prefix = self.__whitespace() if node.decorator_list else self.__source_before('class')
name = self.__convert_name(node.name)
save_cursor = self._cursor
interfaces_prefix = self.__whitespace()
if self._source[self._cursor] == '(' and node.bases:
self.__skip('(')
interfaces = JContainer(
interfaces_prefix,
[
self.__pad_list_element(self.__convert(n), i == len(node.bases) - 1, end_delim=')') for i, n in
enumerate(node.bases)],
Markers.EMPTY
)
elif self._source[self._cursor] == '(':
self.__skip('(')
interfaces = JContainer(
interfaces_prefix,
[self.__pad_right(j.Empty(random_id(), self.__source_before(')'), Markers.EMPTY), Space.EMPTY)],
Markers.EMPTY
)
else:
interfaces = None
self._cursor = save_cursor
return j.ClassDeclaration(
random_id(),
prefix,
Expand All @@ -150,16 +172,11 @@ def visit_ClassDef(self, node):
[],
j.ClassDeclaration.Kind.Type.Class
),
self.__convert_name(node.name),
name,
None,
None,
None, # no `extends`, all in `implements`
None if not node.bases else JContainer(
self.__source_before('('),
[self.__pad_list_element(self.__convert(n), i == len(node.bases) - 1, end_delim=')') for i, n in
enumerate(node.bases)],
Markers.EMPTY,
),
interfaces,
None,
self.__convert_block(node.body),
self.__map_type(node)
Expand Down Expand Up @@ -238,7 +255,39 @@ def visit_If(self, node):
)

def visit_With(self, node):
raise NotImplementedError("Implement visit_With!")
return j.Try(
random_id(),
self.__source_before('with'),
Markers.EMPTY,
JContainer(
self.__whitespace(),
[self.__pad_list_element(self.__convert(r), i == len(node.items) - 1) for i, r in enumerate(node.items)],
Markers.EMPTY
),
self.__convert_block(node.body),
[],
None
)

def visit_withitem(self, node):
prefix = self.__whitespace()
expr = self.__convert(node.context_expr)
value = self.__pad_left(self.__source_before('as'), expr)
name = self.__convert(node.optional_vars)
return j.Try.Resource(
random_id(),
prefix,
Markers.EMPTY,
j.Assignment(
random_id(),
Space.EMPTY,
Markers.EMPTY,
name,
value,
self.__map_type(node.context_expr)
),
False
)

def visit_AsyncWith(self, node):
raise NotImplementedError("Implement visit_AsyncWith!")
Expand Down
12 changes: 12 additions & 0 deletions rewrite/tests/python/all/class_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,15 @@ class Foo(abc.ABC, abc.ABC,):
"""
)
)


def test_empty_parens():
# language=python
rewrite_run(
python(
"""\
class Foo ( ):
pass
"""
)
)
4 changes: 4 additions & 0 deletions rewrite/tests/python/all/method_declaration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,7 @@ def foo() :
"""
)
)

def test_one_line():
# language=python
rewrite_run(python("def f(x): x = x + 1; return x"))
18 changes: 18 additions & 0 deletions rewrite/tests/python/all/try_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,21 @@ def test():
"""
)
)


@pytest.mark.xfail(reason="Implementation still not quite correct", strict=True)
def test_try_else():
# language=python
rewrite_run(
python(
"""\
def test():
try:
result = 1 / 1
except ZeroDivisionError:
print("Caught a division by zero error!")
else:
print("No error occurred, result is:", result)
"""
)
)
16 changes: 16 additions & 0 deletions rewrite/tests/python/all/with_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pytest

from rewrite.test import rewrite_run, python


def test_with():
# language=python
rewrite_run(
python(
"""\
def test(i):
with len([]) as x:
pass
"""
)
)

0 comments on commit 70c6821

Please sign in to comment.