diff --git a/packages/kernel/py/stlite-server/stlite_server/codemod.py b/packages/kernel/py/stlite-server/stlite_server/codemod.py index 21bf97c20..24f5d3853 100644 --- a/packages/kernel/py/stlite-server/stlite_server/codemod.py +++ b/packages/kernel/py/stlite-server/stlite_server/codemod.py @@ -1,13 +1,21 @@ import ast from contextlib import contextmanager from enum import Enum -from typing import Self, cast +from typing import NamedTuple, Self, cast # These units defines "code block" in Python. See https://docs.python.org/3/reference/executionmodel.html#structure-of-a-program CodeBlockNode = ast.Module | ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef ChildCodeBlockNode = ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef +class MethodCallReplacementRule(NamedTuple): + queried_module: str + queried_func: str + replaced_module: str + replaced_func: str + replaced_module_alias_for_new_import: str + + def patch(code: str | ast.Module, script_path: str) -> ast.Module: if isinstance(code, str): tree = ast.parse(code, script_path, "exec") @@ -16,13 +24,25 @@ def patch(code: str | ast.Module, script_path: str) -> ast.Module: else: raise ValueError("code must be a string or an ast.Module") - targets = { - ("time", "sleep"), - ("streamlit", "write_stream"), + rules = { + MethodCallReplacementRule( + queried_module="time", + queried_func="sleep", + replaced_module="asyncio", + replaced_func="sleep", + replaced_module_alias_for_new_import="__asyncio__", + ), + MethodCallReplacementRule( + queried_module="streamlit", + queried_func="write_stream", + replaced_module="streamlit", + replaced_func="write_stream", + replaced_module_alias_for_new_import="__streamlit__", # This shouldn't be used because the queried module is the same as the replaced module + ), } - scanner = CodeBlockStaticScanner("__main__", None, targets) + scanner = CodeBlockStaticScanner("__main__", None, rules) node_scanner_map = scanner.process(tree) - transformer = CodeBlockTransformer("__main__", None, targets, node_scanner_map) + transformer = CodeBlockTransformer("__main__", None, rules, node_scanner_map) new_tree = transformer.process(tree) new_tree = ast.fix_missing_locations(new_tree) @@ -49,7 +69,7 @@ def __init__( self, code_block_name: str, parent_scanner: Self | None, - wildcard_import_targets: set[tuple[str, str]], + rules: set[MethodCallReplacementRule], ) -> None: """Scan a code block. The `process()` method recursively instantiates this class and scans the child code blocks @@ -60,7 +80,7 @@ def __init__( self.parent_scanner = parent_scanner - self.wildcard_import_targets = wildcard_import_targets + self.rules = rules self.code_block_full_name: str if parent_scanner: @@ -96,9 +116,7 @@ def process( node_scanner_map[tree] = self for child_block in self.child_code_blocks: - scanner = CodeBlockStaticScanner( - child_block.name, self, self.wildcard_import_targets - ) + scanner = CodeBlockStaticScanner(child_block.name, self, self.rules) node_scanner_map[child_block] = scanner child_node_scanner_map = scanner.process(child_block) @@ -235,10 +253,13 @@ def visit(self, node: ast.AST) -> None: # `from time import sleep as ts`: node.module = "time", alias.name = "sleep", alias.asname = "ts" if node.module: if alias.name == "*": - # For a wild-card import, add a binding for a target whose module name is matched. - for module, name in self.wildcard_import_targets: - if node.module == module: - self._bind_name(name, module + "." + name) + # For a wild-card import, add a binding for the queried object whose module name is matched. + for rule in self.rules: + if node.module == rule.queried_module: + self._bind_name( + rule.queried_func, + rule.queried_module + "." + rule.queried_func, + ) else: name = alias.asname or alias.name self._bind_name(name, node.module + "." + alias.name) @@ -281,7 +302,7 @@ def __init__( self, code_block_name: str, parent_transformer: Self | None, - targets: set[tuple[str, str]], + rules: set[MethodCallReplacementRule], node_scanner_map: dict[CodeBlockNode, CodeBlockStaticScanner], ) -> None: super().__init__() @@ -292,7 +313,7 @@ def __init__( self.parent_transformer = parent_transformer - self.targets = targets + self.rules = rules self.code_block_full_name: str if parent_transformer: @@ -442,54 +463,37 @@ def handle_Call(self, node: ast.Call) -> ast.AST: if func_fully_qual_name is None: return node - for target in self.targets: + for rule in self.rules: + queried_full_qual_name = rule.queried_module + "." + rule.queried_func if ( - ".".join(target) == func_fully_qual_name + queried_full_qual_name == func_fully_qual_name and func_fully_qual_name not in self.invalidated_names ): - return self._handle_target_call(node, target) + return self._handle_queried_call(node, rule) return node - def _handle_target_call(self, node: ast.Call, target: tuple[str, str]) -> ast.AST: - if target == ("time", "sleep"): - # Convert the node to `await asyncio.sleep(...)` - if "asyncio" in self.imported_modules: - asyncio_as_name = self.imported_modules["asyncio"] - else: - asyncio_as_name = "__asyncio__" - self.required_imports.add(("asyncio", asyncio_as_name)) - return ast.Await( - value=ast.Call( - func=ast.Attribute( - value=ast.Name(id=asyncio_as_name, ctx=ast.Load()), - attr="sleep", - ctx=ast.Load(), - ), - args=node.args, - keywords=node.keywords, - ) + def _handle_queried_call( + self, node: ast.Call, rule: MethodCallReplacementRule + ) -> ast.AST: + if rule.replaced_module in self.imported_modules: + replaced_module_imported_name = self.imported_modules[rule.replaced_module] + else: + replaced_module_imported_name = rule.replaced_module_alias_for_new_import + self.required_imports.add( + (rule.replaced_module, replaced_module_imported_name) ) - elif target == ("streamlit", "write_stream"): - # Convert the node to `await st.write_stream(...)` - if "streamlit" in self.imported_modules: - streamlit_as_name = self.imported_modules["streamlit"] - else: - streamlit_as_name = "streamlit" - self.required_imports.add(("streamlit", streamlit_as_name)) - return ast.Await( - value=ast.Call( - func=ast.Attribute( - value=ast.Name(id=streamlit_as_name, ctx=ast.Load()), - attr="write_stream", - ctx=ast.Load(), - ), - args=node.args, - keywords=node.keywords, - ) + return ast.Await( + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id=replaced_module_imported_name, ctx=ast.Load()), + attr=rule.replaced_func, + ctx=ast.Load(), + ), + args=node.args, + keywords=node.keywords, ) - - return node + ) def visit(self, node: ast.AST) -> ast.AST: is_control_flow = isinstance( @@ -517,7 +521,7 @@ def visit(self, node: ast.AST) -> ast.AST: return node transformer = CodeBlockTransformer( - node.name, self, self.targets, self._node_scanner_map + node.name, self, self.rules, self._node_scanner_map ) return transformer.process(node) @@ -548,10 +552,13 @@ def visit(self, node: ast.AST) -> ast.AST: # `from time import sleep as ts`: node.module = "time", alias.name = "sleep", alias.asname = "ts" if node.module: if alias.name == "*": - # For a wild-card import, add a binding for a target whose module name is matched. - for module, name in self.targets: - if node.module == module: - self._bind_name(name, module + "." + name) + # For a wild-card import, add a binding for the queried object whose module name is matched. + for rule in self.rules: + if node.module == rule.queried_module: + self._bind_name( + rule.queried_func, + rule.queried_module + "." + rule.queried_func, + ) else: name = alias.asname or alias.name self._bind_name(name, node.module + "." + alias.name)