Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
whitphx committed Jul 29, 2024
1 parent 47929f9 commit 7222fd9
Showing 1 changed file with 69 additions and 62 deletions.
131 changes: 69 additions & 62 deletions packages/kernel/py/stlite-server/stlite_server/codemod.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
import ast
from contextlib import contextmanager
from enum import Enum
from typing import Self, cast
from typing import NamedTuple, Self, cast

# These units defines "code block" in Python. See https://docs.python.org/3/reference/executionmodel.html#structure-of-a-program
CodeBlockNode = ast.Module | ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef
ChildCodeBlockNode = ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef


class MethodCallReplacementRule(NamedTuple):
queried_module: str
queried_func: str
replaced_module: str
replaced_func: str
replaced_module_alias_for_new_import: str


def patch(code: str | ast.Module, script_path: str) -> ast.Module:
if isinstance(code, str):
tree = ast.parse(code, script_path, "exec")
Expand All @@ -16,13 +24,25 @@ def patch(code: str | ast.Module, script_path: str) -> ast.Module:
else:
raise ValueError("code must be a string or an ast.Module")

targets = {
("time", "sleep"),
("streamlit", "write_stream"),
rules = {
MethodCallReplacementRule(
queried_module="time",
queried_func="sleep",
replaced_module="asyncio",
replaced_func="sleep",
replaced_module_alias_for_new_import="__asyncio__",
),
MethodCallReplacementRule(
queried_module="streamlit",
queried_func="write_stream",
replaced_module="streamlit",
replaced_func="write_stream",
replaced_module_alias_for_new_import="__streamlit__", # This shouldn't be used because the queried module is the same as the replaced module
),
}
scanner = CodeBlockStaticScanner("__main__", None, targets)
scanner = CodeBlockStaticScanner("__main__", None, rules)
node_scanner_map = scanner.process(tree)
transformer = CodeBlockTransformer("__main__", None, targets, node_scanner_map)
transformer = CodeBlockTransformer("__main__", None, rules, node_scanner_map)
new_tree = transformer.process(tree)
new_tree = ast.fix_missing_locations(new_tree)

Expand All @@ -49,7 +69,7 @@ def __init__(
self,
code_block_name: str,
parent_scanner: Self | None,
wildcard_import_targets: set[tuple[str, str]],
rules: set[MethodCallReplacementRule],
) -> None:
"""Scan a code block.
The `process()` method recursively instantiates this class and scans the child code blocks
Expand All @@ -60,7 +80,7 @@ def __init__(

self.parent_scanner = parent_scanner

self.wildcard_import_targets = wildcard_import_targets
self.rules = rules

self.code_block_full_name: str
if parent_scanner:
Expand Down Expand Up @@ -96,9 +116,7 @@ def process(

node_scanner_map[tree] = self
for child_block in self.child_code_blocks:
scanner = CodeBlockStaticScanner(
child_block.name, self, self.wildcard_import_targets
)
scanner = CodeBlockStaticScanner(child_block.name, self, self.rules)
node_scanner_map[child_block] = scanner

child_node_scanner_map = scanner.process(child_block)
Expand Down Expand Up @@ -235,10 +253,13 @@ def visit(self, node: ast.AST) -> None:
# `from time import sleep as ts`: node.module = "time", alias.name = "sleep", alias.asname = "ts"
if node.module:
if alias.name == "*":
# For a wild-card import, add a binding for a target whose module name is matched.
for module, name in self.wildcard_import_targets:
if node.module == module:
self._bind_name(name, module + "." + name)
# For a wild-card import, add a binding for the queried object whose module name is matched.
for rule in self.rules:
if node.module == rule.queried_module:
self._bind_name(
rule.queried_func,
rule.queried_module + "." + rule.queried_func,
)
else:
name = alias.asname or alias.name
self._bind_name(name, node.module + "." + alias.name)
Expand Down Expand Up @@ -281,7 +302,7 @@ def __init__(
self,
code_block_name: str,
parent_transformer: Self | None,
targets: set[tuple[str, str]],
rules: set[MethodCallReplacementRule],
node_scanner_map: dict[CodeBlockNode, CodeBlockStaticScanner],
) -> None:
super().__init__()
Expand All @@ -292,7 +313,7 @@ def __init__(

self.parent_transformer = parent_transformer

self.targets = targets
self.rules = rules

self.code_block_full_name: str
if parent_transformer:
Expand Down Expand Up @@ -442,54 +463,37 @@ def handle_Call(self, node: ast.Call) -> ast.AST:
if func_fully_qual_name is None:
return node

for target in self.targets:
for rule in self.rules:
queried_full_qual_name = rule.queried_module + "." + rule.queried_func
if (
".".join(target) == func_fully_qual_name
queried_full_qual_name == func_fully_qual_name
and func_fully_qual_name not in self.invalidated_names
):
return self._handle_target_call(node, target)
return self._handle_queried_call(node, rule)

return node

def _handle_target_call(self, node: ast.Call, target: tuple[str, str]) -> ast.AST:
if target == ("time", "sleep"):
# Convert the node to `await asyncio.sleep(...)`
if "asyncio" in self.imported_modules:
asyncio_as_name = self.imported_modules["asyncio"]
else:
asyncio_as_name = "__asyncio__"
self.required_imports.add(("asyncio", asyncio_as_name))
return ast.Await(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id=asyncio_as_name, ctx=ast.Load()),
attr="sleep",
ctx=ast.Load(),
),
args=node.args,
keywords=node.keywords,
)
def _handle_queried_call(
self, node: ast.Call, rule: MethodCallReplacementRule
) -> ast.AST:
if rule.replaced_module in self.imported_modules:
replaced_module_imported_name = self.imported_modules[rule.replaced_module]
else:
replaced_module_imported_name = rule.replaced_module_alias_for_new_import
self.required_imports.add(
(rule.replaced_module, replaced_module_imported_name)
)
elif target == ("streamlit", "write_stream"):
# Convert the node to `await st.write_stream(...)`
if "streamlit" in self.imported_modules:
streamlit_as_name = self.imported_modules["streamlit"]
else:
streamlit_as_name = "streamlit"
self.required_imports.add(("streamlit", streamlit_as_name))
return ast.Await(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id=streamlit_as_name, ctx=ast.Load()),
attr="write_stream",
ctx=ast.Load(),
),
args=node.args,
keywords=node.keywords,
)
return ast.Await(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id=replaced_module_imported_name, ctx=ast.Load()),
attr=rule.replaced_func,
ctx=ast.Load(),
),
args=node.args,
keywords=node.keywords,
)

return node
)

def visit(self, node: ast.AST) -> ast.AST:
is_control_flow = isinstance(
Expand Down Expand Up @@ -517,7 +521,7 @@ def visit(self, node: ast.AST) -> ast.AST:
return node

transformer = CodeBlockTransformer(
node.name, self, self.targets, self._node_scanner_map
node.name, self, self.rules, self._node_scanner_map
)
return transformer.process(node)

Expand Down Expand Up @@ -548,10 +552,13 @@ def visit(self, node: ast.AST) -> ast.AST:
# `from time import sleep as ts`: node.module = "time", alias.name = "sleep", alias.asname = "ts"
if node.module:
if alias.name == "*":
# For a wild-card import, add a binding for a target whose module name is matched.
for module, name in self.targets:
if node.module == module:
self._bind_name(name, module + "." + name)
# For a wild-card import, add a binding for the queried object whose module name is matched.
for rule in self.rules:
if node.module == rule.queried_module:
self._bind_name(
rule.queried_func,
rule.queried_module + "." + rule.queried_func,
)
else:
name = alias.asname or alias.name
self._bind_name(name, node.module + "." + alias.name)
Expand Down

0 comments on commit 7222fd9

Please sign in to comment.