Skip to content

Commit

Permalink
fix certain matchers breaking under multiprocessing by initializing t…
Browse files Browse the repository at this point in the history
…hem late (#1204)

* Add is_property check

Skip properties to prevent exceptions

* Delayed initialization of matchers

To support multiprocessing on Windows/macOS
Issue #1181

* Add a test for matcher decorators with multiprocessing
  • Loading branch information
kiri11 committed Sep 25, 2024
1 parent 6a059be commit 9fd67bc
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 43 deletions.
31 changes: 31 additions & 0 deletions libcst/codemod/tests/test_codemod_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,34 @@ def baz() -> str:
"- 3 warnings were generated.",
output.stderr,
)

def test_matcher_decorators_multiprocessing(self) -> None:
file_count = 5
code = """
def baz(): # type: int
return 5
"""
with tempfile.TemporaryDirectory() as tmpdir:
p = Path(tmpdir)
# Using more than chunksize=4 files to trigger multiprocessing
for i in range(file_count):
(p / f"mod{i}.py").write_text(CodemodTest.make_fixture_data(code))
output = subprocess.run(
[
sys.executable,
"-m",
"libcst.tool",
"codemod",
# Good candidate since it uses matcher decorators
"convert_type_comments.ConvertTypeComments",
str(p),
"--jobs",
str(file_count),
],
encoding="utf-8",
stderr=subprocess.PIPE,
)
self.assertIn(
f"Transformed {file_count} files successfully.",
output.stderr,
)
95 changes: 52 additions & 43 deletions libcst/matchers/_visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ class UnionType:
}


def is_property(obj: object, attr_name: str) -> bool:
"""Check if obj.attr is a property without evaluating it."""
return isinstance(getattr(type(obj), attr_name, None), property)


# pyre-ignore We don't care about Any here, its not exposed.
def _match_decorator_unpickler(kwargs: Any) -> "MatchDecoratorMismatch":
return MatchDecoratorMismatch(**kwargs)
Expand Down Expand Up @@ -265,20 +270,22 @@ def _check_types(
)


def _gather_matchers(obj: object) -> Set[BaseMatcherNode]:
visit_matchers: Set[BaseMatcherNode] = set()
def _gather_matchers(obj: object) -> Dict[BaseMatcherNode, Optional[cst.CSTNode]]:
"""
Set of gating matchers that we need to track and evaluate. We use these
in conjunction with the call_if_inside and call_if_not_inside decorators
to determine whether to call a visit/leave function.
"""

for func in dir(obj):
try:
for matcher in getattr(getattr(obj, func), VISIT_POSITIVE_MATCHER_ATTR, []):
visit_matchers.add(cast(BaseMatcherNode, matcher))
for matcher in getattr(getattr(obj, func), VISIT_NEGATIVE_MATCHER_ATTR, []):
visit_matchers.add(cast(BaseMatcherNode, matcher))
except Exception:
# This could be a caculated property, and calling getattr() evaluates it.
# We have no control over the implementation detail, so if it raises, we
# should not crash.
pass
visit_matchers: Dict[BaseMatcherNode, Optional[cst.CSTNode]] = {}

for attr_name in dir(obj):
if not is_property(obj, attr_name):
func = getattr(obj, attr_name)
for matcher in getattr(func, VISIT_POSITIVE_MATCHER_ATTR, []):
visit_matchers[cast(BaseMatcherNode, matcher)] = None
for matcher in getattr(func, VISIT_NEGATIVE_MATCHER_ATTR, []):
visit_matchers[cast(BaseMatcherNode, matcher)] = None

return visit_matchers

Expand All @@ -302,16 +309,12 @@ def _gather_constructed_visit_funcs(
] = {}

for funcname in dir(obj):
try:
possible_func = getattr(obj, funcname)
if not ismethod(possible_func):
continue
func = cast(Callable[[cst.CSTNode], None], possible_func)
except Exception:
# This could be a caculated property, and calling getattr() evaluates it.
# We have no control over the implementation detail, so if it raises, we
# should not crash.
if is_property(obj, funcname):
continue
possible_func = getattr(obj, funcname)
if not ismethod(possible_func):
continue
func = cast(Callable[[cst.CSTNode], None], possible_func)
matchers = getattr(func, CONSTRUCTED_VISIT_MATCHER_ATTR, [])
if matchers:
# Make sure that we aren't accidentally putting a @visit on a visit_Node.
Expand All @@ -337,16 +340,12 @@ def _gather_constructed_leave_funcs(
] = {}

for funcname in dir(obj):
try:
possible_func = getattr(obj, funcname)
if not ismethod(possible_func):
continue
func = cast(Callable[[cst.CSTNode], None], possible_func)
except Exception:
# This could be a caculated property, and calling getattr() evaluates it.
# We have no control over the implementation detail, so if it raises, we
# should not crash.
if is_property(obj, funcname):
continue
possible_func = getattr(obj, funcname)
if not ismethod(possible_func):
continue
func = cast(Callable[[cst.CSTNode], None], possible_func)
matchers = getattr(func, CONSTRUCTED_LEAVE_MATCHER_ATTR, [])
if matchers:
# Make sure that we aren't accidentally putting a @leave on a leave_Node.
Expand Down Expand Up @@ -448,12 +447,7 @@ class MatcherDecoratableTransformer(CSTTransformer):

def __init__(self) -> None:
CSTTransformer.__init__(self)
# List of gating matchers that we need to track and evaluate. We use these
# in conjuction with the call_if_inside and call_if_not_inside decorators
# to determine whether or not to call a visit/leave function.
self._matchers: Dict[BaseMatcherNode, Optional[cst.CSTNode]] = {
m: None for m in _gather_matchers(self)
}
self.__matchers: Optional[Dict[BaseMatcherNode, Optional[cst.CSTNode]]] = None
# Mapping of matchers to functions. If in the course of visiting the tree,
# a node matches one of these matchers, the corresponding function will be
# called as if it was a visit_* method.
Expand Down Expand Up @@ -486,6 +480,16 @@ def __init__(self) -> None:
expected_none_return=False,
)

@property
def _matchers(self) -> Dict[BaseMatcherNode, Optional[cst.CSTNode]]:
if self.__matchers is None:
self.__matchers = _gather_matchers(self)
return self.__matchers

@_matchers.setter
def _matchers(self, value: Dict[BaseMatcherNode, Optional[cst.CSTNode]]) -> None:
self.__matchers = value

def on_visit(self, node: cst.CSTNode) -> bool:
# First, evaluate any matchers that we have which we are not inside already.
self._matchers = _visit_matchers(self._matchers, node, self)
Expand Down Expand Up @@ -660,12 +664,7 @@ class MatcherDecoratableVisitor(CSTVisitor):

def __init__(self) -> None:
CSTVisitor.__init__(self)
# List of gating matchers that we need to track and evaluate. We use these
# in conjuction with the call_if_inside and call_if_not_inside decorators
# to determine whether or not to call a visit/leave function.
self._matchers: Dict[BaseMatcherNode, Optional[cst.CSTNode]] = {
m: None for m in _gather_matchers(self)
}
self.__matchers: Optional[Dict[BaseMatcherNode, Optional[cst.CSTNode]]] = None
# Mapping of matchers to functions. If in the course of visiting the tree,
# a node matches one of these matchers, the corresponding function will be
# called as if it was a visit_* method.
Expand Down Expand Up @@ -693,6 +692,16 @@ def __init__(self) -> None:
expected_none_return=True,
)

@property
def _matchers(self) -> Dict[BaseMatcherNode, Optional[cst.CSTNode]]:
if self.__matchers is None:
self.__matchers = _gather_matchers(self)
return self.__matchers

@_matchers.setter
def _matchers(self, value: Dict[BaseMatcherNode, Optional[cst.CSTNode]]) -> None:
self.__matchers = value

def on_visit(self, node: cst.CSTNode) -> bool:
# First, evaluate any matchers that we have which we are not inside already.
self._matchers = _visit_matchers(self._matchers, node, self)
Expand Down

0 comments on commit 9fd67bc

Please sign in to comment.