Skip to content

Commit

Permalink
Try to better support quoted type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
knutwannheden committed Sep 5, 2024
1 parent 8b69a31 commit 6d5615b
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 33 deletions.
62 changes: 33 additions & 29 deletions rewrite/rewrite/python/_parser_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1534,35 +1534,19 @@ def visit_Starred(self, node):


def visit_Subscript(self, node):
if isinstance(node.slice, (ast.Constant, ast.Slice)):
return j.ArrayAccess(
random_id(),
self.__whitespace(),
Markers.EMPTY,
self.__convert(node.value),
j.ArrayDimension(
random_id(),
self.__source_before('['),
Markers.EMPTY,
self.__pad_right(self.__convert(node.slice), self.__source_before(']'))
),
self.__map_type(node)
)
else:
slices = node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice]
return j.ParameterizedType(
return j.ArrayAccess(
random_id(),
self.__whitespace(),
Markers.EMPTY,
self.__convert(node.value),
j.ArrayDimension(
random_id(),
self.__whitespace(),
self.__source_before('['),
Markers.EMPTY,
self.__convert(node.value),
JContainer(
self.__source_before('['),
[self.__pad_list_element(self.__convert(s), last=i == len(slices) - 1, end_delim=']') for i, s in enumerate(slices)],
Markers.EMPTY
),
None,
None
)
self.__pad_right(self.__convert(node.slice), self.__source_before(']'))
),
self.__map_type(node)
)


def visit_Tuple(self, node):
Expand Down Expand Up @@ -1616,10 +1600,30 @@ def __convert_type_hint(self, node) -> Optional[TypeTree]:
self.__map_type(node),
None
)
return self.__convert(node)
elif isinstance(node, ast.Subscript):
slices = node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice]
return j.ParameterizedType(
random_id(),
self.__whitespace(),
Markers.EMPTY,
self.__convert(node.value),
JContainer(
self.__source_before('['),
[self.__pad_list_element(self.__convert_type_hint(s), last=i == len(slices) - 1, end_delim=']') for i, s in
enumerate(slices)],
Markers.EMPTY
),
None,
None
)
return self.__convert_internal(node, self.__convert_type_hint)


def __convert(self, node) -> Optional[J]:
return self.__convert_internal(node, self.__convert)


def __convert_internal(self, node, recursion) -> Optional[J]:
if node:
if isinstance(node, ast.expr) and not isinstance(node, (ast.Tuple, ast.GeneratorExp)):
save_cursor = self._cursor
Expand All @@ -1638,7 +1642,7 @@ def __convert(self, node) -> Optional[J]:
self.__pad_right(e.with_prefix(expr_prefix), r)
), self._cursor))
# handle nested parens
result = self.__convert(node)
result = recursion(node)
else:
self._cursor = save_cursor
result = self.visit(cast(ast.AST, node))
Expand Down
4 changes: 0 additions & 4 deletions rewrite/tests/python/all/array_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import pytest

from rewrite.test import rewrite_run, python


Expand Down Expand Up @@ -43,13 +41,11 @@ def test_array_slice_full():
rewrite_run(python("a = [1, 2][0:1:1]"))


@pytest.mark.xfail(reason="Need to differentiate from parameterized types", strict=True)
def test_array_slice_tuple_index_1():
# language=python
rewrite_run(python("a = [1, 2][0,1]"))


@pytest.mark.xfail(reason="Need to differentiate from parameterized types", strict=True)
def test_array_slice_tuple_index_2():
# language=python
rewrite_run(python("a = [1, 2][(0,1)]"))
5 changes: 5 additions & 0 deletions rewrite/tests/python/all/type_hint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ def test_variable_with_parameterized_type_hint():
rewrite_run(python("""foo: Union[None, ...] = None"""))


def test_variable_with_parameterized_type_hint_in_quotes():
# language=python
rewrite_run(python("""foo: Dict["Foo", str] = None"""))


def test_variable_with_quoted_type_hint():
# language=python
rewrite_run(python("""foo: 'Foo' = None"""))
Expand Down

0 comments on commit 6d5615b

Please sign in to comment.