Skip to content

Commit

Permalink
find nonlocal
Browse files Browse the repository at this point in the history
  • Loading branch information
yueyinqiu committed Oct 19, 2024
1 parent 56b6c64 commit 0331862
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 34 deletions.
29 changes: 24 additions & 5 deletions astroid/rebuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,11 @@ def __init__(
self._manager = manager
self._data = data.split("\n") if data else None
self._global_names: list[dict[str, list[nodes.Global]]] = []
self._nonlocal_names: list[dict[str, list[nodes.Nonlocal]]] = []
# In _nonlocal_names,
# what we save is the function where the variable is created,
# rather than nodes.Nonlocal.
# We don't really need the Nonlocal statement.
self._nonlocal_names: list[dict[str, nodes.FunctionDef]] = []
self._import_from_nodes: list[nodes.ImportFrom] = []
self._delayed_assattr: list[nodes.AssignAttr] = []
self._visit_meths: dict[type[ast.AST], Callable[[ast.AST, NodeNG], NodeNG]] = {}
Expand Down Expand Up @@ -453,7 +457,8 @@ def _save_assignment(self, node: nodes.AssignName | nodes.DelName) -> None:
if self._global_names and node.name in self._global_names[-1]:
node.root().set_local(node.name, node)
elif self._nonlocal_names and node.name in self._nonlocal_names[-1]:
node.root().set_local(node.name, node)
function_def = self._nonlocal_names[-1][node.name]
function_def.set_local(node.name, node)
else:
assert node.parent
assert node.name
Expand Down Expand Up @@ -1396,9 +1401,23 @@ def visit_nonlocal(self, node: ast.Nonlocal, parent: NodeNG) -> nodes.Nonlocal:
end_col_offset=node.end_col_offset,
parent=parent,
)
if self._nonlocal_names:
for name in node.names:
self._nonlocal_names[-1].setdefault(name, []).append(newnode)
names = set(newnode.names)
# Go through the tree and find where those names are created
scope = newnode
while len(names) != 0:
scope = scope.parent
if not scope:
# It's not inside a nested function or there are no variables with that name.
# Just ignore it as visit_global does when global is used in module scope.
break
if isinstance(scope, nodes.FunctionDef):
found = []
for name in names:
if name in scope.locals:

Check failure on line 1416 in astroid/rebuilder.py

View workflow job for this annotation

GitHub Actions / Checks

E1101

Instance of 'Nonlocal' has no 'locals' member
found.append(name)
self._nonlocal_names[-1][name] = scope
for name in found:
names.remove(name)
return newnode

def visit_constant(self, node: ast.Constant, parent: NodeNG) -> nodes.Const:
Expand Down
33 changes: 33 additions & 0 deletions tests/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,39 @@ def test_type_comments_without_content(self) -> None:
)
assert node

def test_locals_with_global_and_nonlocal(self) -> None:
module = builder.parse(
"""
x1 = 1 # Line 2
def f1(): # Line 3
x2 = 2 # Line 4
def f2(): # Line 5
global x1 # Line 6
nonlocal x2 # Line 7
x1 = 1 # Line 8
x2 = 2 # Line 9
x3 = 3 # Line 10
"""
)
self.assertSetEqual(set(module.locals), {"x1", "f1"})
x1 = module.locals["x1"]
f1 = module.locals["f1"][0]
self.assertEqual(len(x1), 2)
self.assertEqual(x1[0].lineno, 2)
self.assertEqual(x1[1].lineno, 8)

self.assertSetEqual(set(f1.locals), {"x2", "f2"})
x2 = f1.locals["x2"]
f2 = f1.locals["f2"][0]
self.assertEqual(len(x2), 2)
self.assertEqual(x2[0].lineno, 4)
self.assertEqual(x2[1].lineno, 9)

self.assertSetEqual(set(f2.locals), {"x3"})
x3 = f2.locals["x3"]
self.assertEqual(len(x3), 1)
self.assertEqual(x3[0].lineno, 10)


class FileBuildTest(unittest.TestCase):
def setUp(self) -> None:
Expand Down
29 changes: 0 additions & 29 deletions tests/test_locals.py

This file was deleted.

0 comments on commit 0331862

Please sign in to comment.