Skip to content

Commit

Permalink
Add logic to keep referenced globals
Browse files Browse the repository at this point in the history
* Keep a global if it is referenced as a default
from a functions arguments.
  • Loading branch information
8W9aG committed Sep 25, 2024
1 parent 8fc79b0 commit 6b15565
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 7 deletions.
50 changes: 45 additions & 5 deletions python/cog/code_xforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None: # pylint: disable=i
return extractor.function_source if extractor.function_source else ""


def make_class_methods_empty(source_code: Union[str, ast.AST], class_name: str) -> str:
def make_class_methods_empty(
source_code: Union[str, ast.AST], class_name: str, globals: list[ast.Assign]
) -> tuple[str, list[ast.Assign]]:
"""
Transforms the source code of a specified class to remove the bodies of all its methods
and replace them with 'return None'.
Expand All @@ -79,6 +81,15 @@ def make_class_methods_empty(source_code: Union[str, ast.AST], class_name: str)
"""

class MethodBodyTransformer(ast.NodeTransformer):
def __init__(self, globals: list[ast.Assign]) -> None:
self.used_globals = set()
self._targets = {
target.id: global_name
for global_name in globals
for target in global_name.targets
if isinstance(target, ast.Name)
}

def visit_ClassDef(self, node: ast.ClassDef) -> Optional[ast.AST]: # pylint: disable=invalid-name
if node.name == class_name:
for body_item in node.body:
Expand All @@ -87,15 +98,25 @@ def visit_ClassDef(self, node: ast.ClassDef) -> Optional[ast.AST]: # pylint: di
body_item.body = [ast.Return(value=ast.Constant(value=None))]
# Remove decorators from the function
body_item.decorator_list = []
# Determine if one our globals is referenced by the function.
for default in body_item.args.defaults:
if isinstance(default, ast.Call):
for keyword in default.keywords:
if isinstance(keyword.value, ast.Name):
corresponding_global = self._targets.get(
keyword.value.id
)
if corresponding_global is not None:
self.used_globals.add(corresponding_global)
return node

return None

tree = source_code if isinstance(source_code, ast.AST) else ast.parse(source_code)
transformer = MethodBodyTransformer()
transformer = MethodBodyTransformer(globals)
transformed_tree = transformer.visit(tree)
class_code = ast.unparse(transformed_tree)
return class_code
return class_code, transformer.used_globals


def extract_method_return_type(
Expand Down Expand Up @@ -217,6 +238,15 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None: # pylint: disable=inv
return "\n".join(extractor.imports)


def _extract_globals(source_code: Union[str, ast.AST]) -> list[ast.Assign]:
tree = source_code if isinstance(source_code, ast.AST) else ast.parse(source_code)
return [x for x in tree.body if isinstance(x, ast.Assign)]


def _render_globals(globals: list[ast.Assign]) -> str:
return "\n".join([ast.unparse(x) for x in globals])


def strip_model_source_code(
source_code: str, class_name: str, method_name: str
) -> Optional[str]:
Expand All @@ -236,13 +266,23 @@ def strip_model_source_code(
class_source = (
None if not class_name else extract_class_source(source_code, class_name)
)
globals = _extract_globals(source_code)
if class_source:
class_source = make_class_methods_empty(class_source, class_name)
class_source, globals = make_class_methods_empty(
class_source, class_name, globals
)
return_type = extract_method_return_type(class_source, class_name, method_name)
return_class_source = (
extract_class_source(source_code, return_type) if return_type else ""
)
model_source = "\n".join([imports, return_class_source, class_source])
rendered_globals = _render_globals(globals)
model_source = "\n".join(
[
x
for x in [imports, rendered_globals, return_class_source, class_source]
if x
]
)
else:
# use class_name specified in cog.yaml as method_name
method_name = class_name
Expand Down
50 changes: 48 additions & 2 deletions python/tests/test_code_xforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def test_predict_many_inputs():
new_source = strip_model_source_code(source_code, "Predictor", "predict")
expected_source = """
from cog import BasePredictor, Input, Path
class Predictor(BasePredictor):
def predict(self, no_default: str, default_without_input: str='default', input_with_default: int=Input(default=10), path: Path=Input(description='Some path'), image: Path=Input(description='Some path'), choices: str=Input(choices=['foo', 'bar']), int_choices: int=Input(description='hello', choices=[3, 4, 5])) -> str:
Expand All @@ -59,7 +58,6 @@ def test_predict_output_path_model():
expected_source = """
import os
from cog import BasePredictor, Path
class Predictor(BasePredictor):
def predict(self) -> Path:
Expand Down Expand Up @@ -153,3 +151,51 @@ class Predictor(BasePredictor):
def predict(self, msg: str) -> ModelOutput:
return None"""
), "Stripped code needs to equal the minimum viable type inference."


@pytest.mark.skipif(sys.version_info < (3, 9), reason="Requires Python 3.9 or newer")
def test_strip_model_source_code_keeps_referenced_globals():
stripped_code = strip_model_source_code(
"""
import io
from cog import BasePredictor, Path
from typing import Optional
from pydantic import BaseModel
import torch
import numpy as np
INPUT_DIMS = list(np.arange(32, 64, 32))
class ModelOutput(BaseModel):
success: bool
error: Optional[str]
segmentedImage: Optional[Path]
class Predictor(BasePredictor):
# setup code
def predict(self, height: int=Input(description='Height of image', default=128, choices=INPUT_DIMS)) -> ModelOutput:
return ModelOutput(success=False, error=msg, segmentedImage=None)
""",
"Predictor",
"predict",
)
assert (
stripped_code
== """from cog import BasePredictor, Path
from typing import Optional
from pydantic import BaseModel
import numpy as np
INPUT_DIMS = list(np.arange(32, 64, 32))
class ModelOutput(BaseModel):
success: bool
error: Optional[str]
segmentedImage: Optional[Path]
class Predictor(BasePredictor):
def predict(self, height: int=Input(description='Height of image', default=128, choices=INPUT_DIMS)) -> ModelOutput:
return None"""
), "Stripped code needs to equal the minimum viable type inference."

0 comments on commit 6b15565

Please sign in to comment.