diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 37baacef70..a110f0bd53 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -7,7 +7,6 @@ import base64 import hashlib -import importlib import os import pathlib import tempfile @@ -34,7 +33,6 @@ from flytekit.core.launch_plan import LaunchPlan from flytekit.core.python_auto_container import PythonAutoContainerTask from flytekit.core.reference_entity import ReferenceSpec -from flytekit.core.tracker import get_full_module_path from flytekit.core.type_engine import LiteralsResolver, TypeEngine from flytekit.core.workflow import WorkflowBase from flytekit.exceptions import user as user_exceptions @@ -70,7 +68,7 @@ from flytekit.remote.lazy_entity import LazyEntity from flytekit.remote.remote_callable import RemoteEntity from flytekit.tools.fast_registration import fast_package -from flytekit.tools.script_mode import compress_single_script, hash_file +from flytekit.tools.script_mode import compress_scripts, hash_file from flytekit.tools.translator import ( FlyteControlPlaneEntity, FlyteLocalEntity, @@ -821,8 +819,7 @@ def register_script( with tempfile.TemporaryDirectory() as tmp_dir: archive_fname = pathlib.Path(os.path.join(tmp_dir, "script_mode.tar.gz")) - mod = importlib.import_module(module_name) - compress_single_script(source_path, str(archive_fname), get_full_module_path(mod, mod.__name__)) + compress_scripts(source_path, str(archive_fname), module_name) md5_bytes, upload_native_url = self._upload_file( archive_fname, project or self.default_project, domain or self.default_domain ) diff --git a/flytekit/tools/script_mode.py b/flytekit/tools/script_mode.py index 1f3e31a382..1b494925ae 100644 --- a/flytekit/tools/script_mode.py +++ b/flytekit/tools/script_mode.py @@ -1,5 +1,6 @@ import gzip import hashlib +import importlib import os import shutil import tarfile @@ -7,8 +8,12 @@ import typing from pathlib import Path +from flytekit import PythonFunctionTask +from flytekit.core.tracker import get_full_module_path +from flytekit.core.workflow import WorkflowBase -def compress_single_script(source_path: str, destination: str, full_module_name: str): + +def compress_scripts(source_path: str, destination: str, module_name: str): """ Compresses the single script while maintaining the folder structure for that file. @@ -33,33 +38,14 @@ def compress_single_script(source_path: str, destination: str, full_module_name: │   ├── example.py │   └── __init__.py - Note how `another_example.py` and `yet_another_example.py` were not copied to the destination. + Note: If `example.py` didn't import tasks or workflows from `another_example.py` and `yet_another_example.py`, these files were not copied to the destination.. + """ with tempfile.TemporaryDirectory() as tmp_dir: destination_path = os.path.join(tmp_dir, "code") - # This is the script relative path to the root of the project - script_relative_path = Path() - # For each package in pkgs, create a directory and copy the __init__.py in it. - # Skip the last package as that is the script file. - pkgs = full_module_name.split(".") - for p in pkgs[:-1]: - os.makedirs(os.path.join(destination_path, p)) - source_path = os.path.join(source_path, p) - destination_path = os.path.join(destination_path, p) - script_relative_path = Path(script_relative_path, p) - init_file = Path(os.path.join(source_path, "__init__.py")) - if init_file.exists(): - shutil.copy(init_file, Path(os.path.join(tmp_dir, "code", script_relative_path, "__init__.py"))) - - # Ensure destination path exists to cover the case of a single file and no modules. - os.makedirs(destination_path, exist_ok=True) - script_file = Path(source_path, f"{pkgs[-1]}.py") - script_file_destination = Path(destination_path, f"{pkgs[-1]}.py") - # Build the final script relative path and copy it to a known place. - shutil.copy( - script_file, - script_file_destination, - ) + + visited: typing.List[str] = [] + copy_module_to_destination(source_path, destination_path, module_name, visited) tar_path = os.path.join(tmp_dir, "tmp.tar") with tarfile.open(tar_path, "w") as tar: tar.add(os.path.join(tmp_dir, "code"), arcname="", filter=tar_strip_file_attributes) @@ -68,6 +54,50 @@ def compress_single_script(source_path: str, destination: str, full_module_name: gzipped.write(tar_file.read()) +def copy_module_to_destination( + original_source_path: str, original_destination_path: str, module_name: str, visited: typing.List[str] +): + """ + Copy the module (file) to the destination directory. If the module relative imports other modules, flytekit will + recursively copy them as well. + """ + mod = importlib.import_module(module_name) + full_module_name = get_full_module_path(mod, mod.__name__) + if full_module_name in visited: + return + visited.append(full_module_name) + + source_path = original_source_path + destination_path = original_destination_path + pkgs = full_module_name.split(".") + + for p in pkgs[:-1]: + os.makedirs(os.path.join(destination_path, p), exist_ok=True) + destination_path = os.path.join(destination_path, p) + source_path = os.path.join(source_path, p) + init_file = Path(os.path.join(source_path, "__init__.py")) + if init_file.exists(): + shutil.copy(init_file, Path(os.path.join(destination_path, "__init__.py"))) + + # Ensure destination path exists to cover the case of a single file and no modules. + os.makedirs(destination_path, exist_ok=True) + script_file = Path(source_path, f"{pkgs[-1]}.py") + script_file_destination = Path(destination_path, f"{pkgs[-1]}.py") + # Build the final script relative path and copy it to a known place. + shutil.copy( + script_file, + script_file_destination, + ) + + # Try to copy other files to destination if tasks or workflows aren't in the same file + for flyte_entity_name in mod.__dict__: + flyte_entity = mod.__dict__[flyte_entity_name] + if isinstance(flyte_entity, (PythonFunctionTask, WorkflowBase)) and flyte_entity.instantiated_in: + copy_module_to_destination( + original_source_path, original_destination_path, flyte_entity.instantiated_in, visited + ) + + # Takes in a TarInfo and returns the modified TarInfo: # https://docs.python.org/3/library/tarfile.html#tarinfo-objects # intented to be passed as a filter to tarfile.add diff --git a/tests/flytekit/unit/tools/test_script_mode.py b/tests/flytekit/unit/tools/test_script_mode.py index a433769075..c597e9bdc2 100644 --- a/tests/flytekit/unit/tools/test_script_mode.py +++ b/tests/flytekit/unit/tools/test_script_mode.py @@ -1,13 +1,38 @@ import os +import subprocess +import sys -from flytekit.tools.script_mode import compress_single_script, hash_file +from flytekit.tools.script_mode import compress_scripts, hash_file + +MAIN_WORKFLOW = """ +from flytekit import task, workflow +from wf1.test import t1 -WORKFLOW = """ @workflow def my_wf() -> str: return "hello world" """ +T1_TASK = """ +from flytekit import task +from wf2.test import t2 + + +@task() +def t1() -> str: + print("hello") + return "hello" +""" + +T2_TASK = """ +from flytekit import task + +@task() +def t2() -> str: + print("hello") + return "hello" +""" + def test_deterministic_hash(tmp_path): workflows_dir = tmp_path / "workflows" @@ -17,19 +42,42 @@ def test_deterministic_hash(tmp_path): open(workflows_dir / "__init__.py", "a").close() # Write a dummy workflow workflow_file = workflows_dir / "hello_world.py" - workflow_file.write_text(WORKFLOW) + workflow_file.write_text(MAIN_WORKFLOW) + + t1_dir = tmp_path / "wf1" + t1_dir.mkdir() + open(t1_dir / "__init__.py", "a").close() + t1_file = t1_dir / "test.py" + t1_file.write_text(T1_TASK) + + t2_dir = tmp_path / "wf2" + t2_dir.mkdir() + open(t2_dir / "__init__.py", "a").close() + t2_file = t2_dir / "test.py" + t2_file.write_text(T2_TASK) destination = tmp_path / "destination" - compress_single_script(workflows_dir, destination, "hello_world") - print(f"{os.listdir(tmp_path)}") + print(workflows_dir) + sys.path.append(str(workflows_dir.parent)) + compress_scripts(str(workflows_dir.parent), str(destination), "workflows.hello_world") digest, hex_digest = hash_file(destination) # Try again to assert digest determinism destination2 = tmp_path / "destination2" - compress_single_script(workflows_dir, destination2, "hello_world") + compress_scripts(str(workflows_dir.parent), str(destination2), "workflows.hello_world") digest2, hex_digest2 = hash_file(destination) assert digest == digest2 assert hex_digest == hex_digest2 + + test_dir = tmp_path / "test" + test_dir.mkdir() + + result = subprocess.run( + ["tar", "-xvf", destination, "-C", test_dir], + stdout=subprocess.PIPE, + ) + result.check_returncode() + assert len(next(os.walk(test_dir))[1]) == 3