Skip to content

Commit

Permalink
Scope provider changes for type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
zsol committed Sep 23, 2023
1 parent 37277e5 commit 0c34d0f
Show file tree
Hide file tree
Showing 3 changed files with 326 additions and 38 deletions.
6 changes: 3 additions & 3 deletions libcst/helpers/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
from typing import Type
from typing import Type, TypeVar

from libcst._types import CSTNodeT
T = TypeVar("T")


def ensure_type(node: object, nodetype: Type[CSTNodeT]) -> CSTNodeT:
def ensure_type(node: object, nodetype: Type[T]) -> T:
"""
Takes any python object, and a LibCST :class:`~libcst.CSTNode` subclass and
refines the type of the python object. This is most useful when you already
Expand Down
146 changes: 111 additions & 35 deletions libcst/metadata/scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import abc
import builtins
from collections import defaultdict
from contextlib import contextmanager
from contextlib import contextmanager, ExitStack
from dataclasses import dataclass
from enum import auto, Enum
from typing import (
Expand Down Expand Up @@ -51,6 +51,10 @@
cst.Nonlocal,
cst.Parameters,
cst.WithItem,
cst.TypeVar,
cst.TypeAlias,
cst.TypeVarTuple,
cst.ParamSpec,
)


Expand Down Expand Up @@ -116,15 +120,17 @@ def record_assignment(self, assignment: "BaseAssignment") -> None:
self.__assignments.add(assignment)

def record_assignments(self, name: str) -> None:
assignments = self.scope[name]
assignments = self.scope._resolve_scope_for_access(name, self.scope)
# filter out assignments that happened later than this access
previous_assignments = {
assignment
for assignment in assignments
if assignment.scope != self.scope or assignment._index < self.__index
}
if not previous_assignments and assignments and self.scope.parent != self.scope:
previous_assignments = self.scope.parent[name]
previous_assignments = self.scope.parent._resolve_scope_for_access(
name, self.scope
)
self.__assignments |= previous_assignments


Expand Down Expand Up @@ -440,7 +446,7 @@ def record_access(self, name: str, access: Access) -> None:
self._accesses_by_name[name].add(access)
self._accesses_by_node[access.node].add(access)

def _is_visible_from_children(self) -> bool:
def _is_visible_from_children(self, from_scope: "Scope") -> bool:
"""Returns if the assignments in this scope can be accessed from children.
This is normally True, except for class scopes::
Expand All @@ -459,9 +465,11 @@ def inner_fn():
"""
return True

def _next_visible_parent(self, first: Optional["Scope"] = None) -> "Scope":
def _next_visible_parent(
self, from_scope: "Scope", first: Optional["Scope"] = None
) -> "Scope":
parent = first if first is not None else self.parent
while not parent._is_visible_from_children():
while not parent._is_visible_from_children(from_scope):
parent = parent.parent
return parent

Expand All @@ -470,7 +478,6 @@ def __contains__(self, name: str) -> bool:
"""Check if the name str exist in current scope by ``name in scope``."""
...

@abc.abstractmethod
def __getitem__(self, name: str) -> Set[BaseAssignment]:
"""
Get assignments given a name str by ``scope[name]``.
Expand Down Expand Up @@ -508,6 +515,12 @@ def __getitem__(self, name: str) -> Set[BaseAssignment]:
defined a given name by the time a piece of code is executed.
For the above example, value would resolve to a set of both assignments.
"""
return self._resolve_scope_for_access(name, self)

@abc.abstractmethod
def _resolve_scope_for_access(
self, name: str, from_scope: "Scope"
) -> Set[BaseAssignment]:
...

def __hash__(self) -> int:
Expand Down Expand Up @@ -612,7 +625,9 @@ def __init__(self, globals: Scope) -> None:
def __contains__(self, name: str) -> bool:
return hasattr(builtins, name)

def __getitem__(self, name: str) -> Set[BaseAssignment]:
def _resolve_scope_for_access(
self, name: str, from_scope: "Scope"
) -> Set[BaseAssignment]:
if name in self._assignments:
return self._assignments[name]
if hasattr(builtins, name):
Expand Down Expand Up @@ -644,13 +659,15 @@ def __init__(self) -> None:
def __contains__(self, name: str) -> bool:
if name in self._assignments:
return len(self._assignments[name]) > 0
return name in self._next_visible_parent()
return name in self._next_visible_parent(self)

def __getitem__(self, name: str) -> Set[BaseAssignment]:
def _resolve_scope_for_access(
self, name: str, from_scope: "Scope"
) -> Set[BaseAssignment]:
if name in self._assignments:
return self._assignments[name]

parent = self._next_visible_parent()
parent = self._next_visible_parent(from_scope)
return parent[name]

def record_global_overwrite(self, name: str) -> None:
Expand Down Expand Up @@ -688,7 +705,7 @@ def record_nonlocal_overwrite(self, name: str) -> None:
def _find_assignment_target(self, name: str) -> "Scope":
if name in self._scope_overwrites:
scope = self._scope_overwrites[name]
return self._next_visible_parent(scope)._find_assignment_target(name)
return self._next_visible_parent(self, scope)._find_assignment_target(name)
else:
return super()._find_assignment_target(name)

Expand All @@ -697,16 +714,22 @@ def __contains__(self, name: str) -> bool:
return name in self._scope_overwrites[name]
if name in self._assignments:
return len(self._assignments[name]) > 0
return name in self._next_visible_parent()
return name in self._next_visible_parent(self)

def __getitem__(self, name: str) -> Set[BaseAssignment]:
def _resolve_scope_for_access(
self, name: str, from_scope: "Scope"
) -> Set[BaseAssignment]:
if name in self._scope_overwrites:
scope = self._scope_overwrites[name]
return self._next_visible_parent(scope)[name]
return self._next_visible_parent(
from_scope, scope
)._resolve_scope_for_access(name, from_scope)
if name in self._assignments:
return self._assignments[name]
else:
return self._next_visible_parent()[name]
return self._next_visible_parent(from_scope)._resolve_scope_for_access(
name, from_scope
)

def _make_name_prefix(self) -> str:
# filter falsey strings out
Expand All @@ -728,8 +751,8 @@ class ClassScope(LocalScope):
When a class is defined, it creates a ClassScope.
"""

def _is_visible_from_children(self) -> bool:
return False
def _is_visible_from_children(self, from_scope: "Scope") -> bool:
return from_scope.parent is self and isinstance(from_scope, AnnotationScope)

def _make_name_prefix(self) -> str:
# filter falsey strings out
Expand All @@ -755,6 +778,19 @@ def _make_name_prefix(self) -> str:
return ".".join(filter(None, [self.parent._name_prefix, "<comprehension>"]))


class AnnotationScope(LocalScope):
"""
Scopes used for type aliases and type parameters as defined by PEP-695.
These scopes are created for type parameters using the special syntax, as well as
type aliases. See https://peps.python.org/pep-0695/#scoping-behavior for more.
"""

def _make_name_prefix(self) -> str:
# these scopes are transparent for the purposes of qualified names
return self.parent._name_prefix


# Generates dotted names from an Attribute or Name node:
# Attribute(value=Name(value="a"), attr=Name(value="b")) -> ("a.b", "a")
# each string has the corresponding CSTNode attached to it
Expand Down Expand Up @@ -822,6 +858,7 @@ class DeferredAccess:
class ScopeVisitor(cst.CSTVisitor):
# since it's probably not useful. That can makes this visitor cleaner.
def __init__(self, provider: "ScopeProvider") -> None:
super().__init__()
self.provider: ScopeProvider = provider
self.scope: Scope = GlobalScope()
self.__deferred_accesses: List[DeferredAccess] = []
Expand Down Expand Up @@ -992,15 +1029,22 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
self.scope.record_assignment(node.name.value, node)
self.provider.set_metadata(node.name, self.scope)

with self._new_scope(FunctionScope, node, get_full_name_for_node(node.name)):
node.params.visit(self)
node.body.visit(self)
with ExitStack() as stack:
if node.type_parameters:
stack.enter_context(self._new_scope(AnnotationScope, node, None))
node.type_parameters.visit(self)

for decorator in node.decorators:
decorator.visit(self)
returns = node.returns
if returns:
returns.visit(self)
with self._new_scope(
FunctionScope, node, get_full_name_for_node(node.name)
):
node.params.visit(self)
node.body.visit(self)

for decorator in node.decorators:
decorator.visit(self)
returns = node.returns
if returns:
returns.visit(self)

return False

Expand Down Expand Up @@ -1032,14 +1076,20 @@ def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
self.provider.set_metadata(node.name, self.scope)
for decorator in node.decorators:
decorator.visit(self)
for base in node.bases:
base.visit(self)
for keyword in node.keywords:
keyword.visit(self)

with self._new_scope(ClassScope, node, get_full_name_for_node(node.name)):
for statement in node.body.body:
statement.visit(self)

with ExitStack() as stack:
if node.type_parameters:
stack.enter_context(self._new_scope(AnnotationScope, node, None))
node.type_parameters.visit(self)

for base in node.bases:
base.visit(self)
for keyword in node.keywords:
keyword.visit(self)

with self._new_scope(ClassScope, node, get_full_name_for_node(node.name)):
for statement in node.body.body:
statement.visit(self)
return False

def visit_ClassDef_bases(self, node: cst.ClassDef) -> None:
Expand Down Expand Up @@ -1163,7 +1213,7 @@ def infer_accesses(self) -> None:
access.scope.record_access(name, access)

for (scope, name), accesses in scope_name_accesses.items():
for assignment in scope[name]:
for assignment in scope._resolve_scope_for_access(name, scope):
assignment.record_accesses(accesses)

self.__deferred_accesses = []
Expand All @@ -1174,6 +1224,32 @@ def on_leave(self, original_node: cst.CSTNode) -> None:
self.scope._assignment_count += 1
super().on_leave(original_node)

def visit_TypeAlias(self, node: cst.TypeAlias) -> Optional[bool]:
self.scope.record_assignment(node.name.value, node)

with self._new_scope(AnnotationScope, node, None):
if node.type_parameters is not None:
node.type_parameters.visit(self)
node.value.visit(self)

return False

def visit_TypeVar(self, node: cst.TypeVar) -> Optional[bool]:
self.scope.record_assignment(node.name.value, node)

if node.bound is not None:
node.bound.visit(self)

return False

def visit_TypeVarTuple(self, node: cst.TypeVarTuple) -> Optional[bool]:
self.scope.record_assignment(node.name.value, node)
return False

def visit_ParamSpec(self, node: cst.ParamSpec) -> Optional[bool]:
self.scope.record_assignment(node.name.value, node)
return False


class ScopeProvider(BatchableMetadataProvider[Optional[Scope]]):
"""
Expand Down
Loading

0 comments on commit 0c34d0f

Please sign in to comment.