Skip to content

Commit

Permalink
Support set and dict comprehensions
Browse files Browse the repository at this point in the history
  • Loading branch information
knutwannheden committed Aug 22, 2024
1 parent 60908a1 commit 05b95f8
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 3 deletions.
35 changes: 35 additions & 0 deletions rewrite/rewrite/python/_parser_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,26 @@ def visit_Dict(self, node):
self.__skip('}')
return dict

def visit_DictComp(self, node):
self.__skip('for')
return py.ComprehensionExpression(
random_id(),
self.__source_before('{'),
Markers.EMPTY,
py.ComprehensionExpression.Kind.DICT,
py.KeyValue(
random_id(),
self.__whitespace(),
Markers.EMPTY,
self.__pad_right(self.__convert(node.key), self.__source_before(':')),
self.__convert(node.value),
self.__map_type(node.value)
),
cast(List[py.ComprehensionExpression.Clause], [self.__convert(g) for g in node.generators]),
self.__source_before('}'),
self.__map_type(node)
)

def _map_dict_entry(self, key: Optional[ast.expr], value: ast.expr, last: bool) -> JRightPadded[J]:
if key is None:
element = py.StarExpression(
Expand Down Expand Up @@ -496,6 +516,21 @@ def visit_Set(self, node):
self.__map_type(node)
)

def visit_SetComp(self, node):
prefix = self.__source_before('{')
result = self.__convert(node.elt)
self.__skip('for')
return py.ComprehensionExpression(
random_id(),
prefix,
Markers.EMPTY,
py.ComprehensionExpression.Kind.SET,
result,
cast(List[py.ComprehensionExpression.Clause], [self.__convert(g) for g in node.generators]),
self.__source_before('}'),
self.__map_type(node)
)

def visit_Slice(self, node):
prefix = self.__whitespace()
if node.lower:
Expand Down
36 changes: 33 additions & 3 deletions rewrite/tests/python/comprehension_test.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,46 @@
from rewrite.test import rewrite_run, python


def test_basic_comprehension():
def test_basic_list_comprehension():
# language=python
rewrite_run(python("a = [ e+1 for e in [1, 2, ]]"))


def test_comprehension_with_if():
def test_list_comprehension_with_if():
# language=python
rewrite_run(python("a = [ e+1 for e in [1, 2, ] if e > 1]"))


def test_comprehension_with_multiple_ifs():
def test_list_comprehension_with_multiple_ifs():
# language=python
rewrite_run(python("a = [ e+1 for e in [1, 2, ] if e > 1 if e < 10]"))


def test_basic_set_comprehension():
# language=python
rewrite_run(python("a = { e for e in range(10)}"))


def test_set_comprehension_with_if():
# language=python
rewrite_run(python("a = { e for e in range(10) if e > 1}"))


def test_set_comprehension_with_multiple_ifs():
# language=python
rewrite_run(python("a = { e for e in range(10) if e > 1 if e < 10}"))


def test_basic_dict_comprehension():
# language=python
rewrite_run(python("a = {n: n * 2 for n in range(10)}"))


def test_dict_comprehension_with_if():
# language=python
rewrite_run(python("a = {e:e for e in range(10) if e > 1}"))


def test_dict_comprehension_with_multiple_ifs():
# language=python
rewrite_run(python("a = {e:None for e in range(10) if e > 1 if e < 10}"))

0 comments on commit 05b95f8

Please sign in to comment.