diff --git a/python/cog/code_xforms.py b/python/cog/code_xforms.py index 1f8357937..06ea45499 100644 --- a/python/cog/code_xforms.py +++ b/python/cog/code_xforms.py @@ -240,7 +240,9 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None: # pylint: disable=inv def _extract_globals(source_code: Union[str, ast.AST]) -> list[ast.Assign]: tree = source_code if isinstance(source_code, ast.AST) else ast.parse(source_code) - return [x for x in tree.body if isinstance(x, ast.Assign)] + if isinstance(tree, ast.Module): + return [x for x in tree.body if isinstance(x, ast.Assign)] + return [] def _render_globals(globals: list[ast.Assign]) -> str: