diff --git a/flytekit/clis/sdk_in_container/helpers.py b/flytekit/clis/sdk_in_container/helpers.py index a9a9c4900d..72246bcba4 100644 --- a/flytekit/clis/sdk_in_container/helpers.py +++ b/flytekit/clis/sdk_in_container/helpers.py @@ -1,7 +1,10 @@ +from dataclasses import replace +from typing import Optional + import click from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE -from flytekit.configuration import Config +from flytekit.configuration import Config, ImageConfig from flytekit.loggers import cli_logger from flytekit.remote.remote import FlyteRemote @@ -30,3 +33,28 @@ def get_and_save_remote_with_click_context( if save: ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] = r return r + + +def patch_image_config(config_file: Optional[str], image_config: ImageConfig) -> ImageConfig: + """ + Merge ImageConfig object with images defined in config file + """ + # Images come from three places: + # * The default flytekit images, which are already supplied by the base run_level_params. + # * The images provided by the user on the command line. + # * The images provided by the user via the config file, if there is one. (Images on the command line should + # override all). + # + # However, the run_level_params already contains both the default flytekit images (lowest priority), as well + # as the images from the command line (highest priority). So when we read from the config file, we only + # want to add in the images that are missing, including the default, if that's also missing. + additional_image_names = set([v.name for v in image_config.images]) + new_additional_images = [v for v in image_config.images] + new_default = image_config.default_image + if config_file: + cfg_ic = ImageConfig.auto(config_file=config_file) + new_default = new_default or cfg_ic.default_image + for addl in cfg_ic.images: + if addl.name not in additional_image_names: + new_additional_images.append(addl) + return replace(image_config, default_image=new_default, images=new_additional_images) diff --git a/flytekit/clis/sdk_in_container/register.py b/flytekit/clis/sdk_in_container/register.py index c0bdcd2416..9f4ba1a81c 100644 --- a/flytekit/clis/sdk_in_container/register.py +++ b/flytekit/clis/sdk_in_container/register.py @@ -6,7 +6,7 @@ from flytekit.clis.helpers import display_help_with_error from flytekit.clis.sdk_in_container import constants -from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context +from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context, patch_image_config from flytekit.configuration import FastSerializationSettings, ImageConfig, SerializationSettings from flytekit.configuration.default_images import DefaultImages from flytekit.loggers import cli_logger @@ -137,12 +137,20 @@ def register( if pkgs: raise ValueError("Unimplemented, just specify pkgs like folder/files as args at the end of the command") + if non_fast and not version: + raise ValueError("Version is a required parameter in case --non-fast is specified.") + if len(package_or_module) == 0: display_help_with_error( ctx, "Missing argument 'PACKAGE_OR_MODULE...', at least one PACKAGE_OR_MODULE is required but multiple can be passed", ) + # Use extra images in the config file if that file exists + config_file = ctx.obj.get(constants.CTX_CONFIG_FILE) + if config_file: + image_config = patch_image_config(config_file, image_config) + cli_logger.debug( f"Running pyflyte register from {os.getcwd()} " f"with images {image_config} " @@ -155,12 +163,12 @@ def register( detected_root = find_common_root(package_or_module) cli_logger.debug(f"Using {detected_root} as root folder for project") - fast_serialization_settings = None # Create a zip file containing all the entries. zip_file = fast_package(detected_root, output, deref_symlinks) md5_bytes, _ = hash_file(pathlib.Path(zip_file)) + fast_serialization_settings = None if non_fast is False: # Upload zip file to Admin using FlyteRemote. md5_bytes, native_url = remote._upload_file(pathlib.Path(zip_file)) @@ -194,5 +202,9 @@ def register( version = remote._version_from_hash(md5_bytes, serialization_settings, service_account, raw_data_prefix) # noqa cli_logger.info(f"Computed version is {version}") + click.echo( + f"Registering entities under version {version} using the following serialization settings = {serialization_settings}" + ) + # Register using repo code repo_register(registerable_entities, project, domain, version, remote.client) diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index d0b890ba7b..7d7e1e3e2f 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -6,7 +6,7 @@ import os import pathlib import typing -from dataclasses import dataclass, replace +from dataclasses import dataclass from typing import cast import click @@ -22,7 +22,11 @@ CTX_PROJECT, CTX_PROJECT_ROOT, ) -from flytekit.clis.sdk_in_container.helpers import FLYTE_REMOTE_INSTANCE_KEY, get_and_save_remote_with_click_context +from flytekit.clis.sdk_in_container.helpers import ( + FLYTE_REMOTE_INSTANCE_KEY, + get_and_save_remote_with_click_context, + patch_image_config, +) from flytekit.configuration import ImageConfig from flytekit.configuration.default_images import DefaultImages from flytekit.core import context_manager @@ -517,27 +521,8 @@ def _run(*args, **kwargs): remote = ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] config_file = ctx.obj.get(CTX_CONFIG_FILE) - # Images come from three places: - # * The default flytekit images, which are already supplied by the base run_level_params. - # * The images provided by the user on the command line. - # * The images provided by the user via the config file, if there is one. (Images on the command line should - # override all). - # - # However, the run_level_params already contains both the default flytekit images (lowest priority), as well - # as the images from the command line (highest priority). So when we read from the config file, we only - # want to add in the images that are missing, including the default, if that's also missing. - image_config_from_parent_cmd = run_level_params.get("image_config", None) - additional_image_names = set([v.name for v in image_config_from_parent_cmd.images]) - new_additional_images = [v for v in image_config_from_parent_cmd.images] - new_default = image_config_from_parent_cmd.default_image - if config_file: - cfg_ic = ImageConfig.auto(config_file=config_file) - new_default = new_default or cfg_ic.default_image - for addl in cfg_ic.images: - if addl.name not in additional_image_names: - new_additional_images.append(addl) - - image_config = replace(image_config_from_parent_cmd, default_image=new_default, images=new_additional_images) + image_config = run_level_params.get("image_config") + image_config = patch_image_config(config_file, image_config) remote_entity = remote.register_script( entity, diff --git a/tests/flytekit/unit/cli/pyflyte/test_register.py b/tests/flytekit/unit/cli/pyflyte/test_register.py index e9661dff6a..05660aa054 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_register.py +++ b/tests/flytekit/unit/cli/pyflyte/test_register.py @@ -79,6 +79,26 @@ def test_non_fast_register(mock_client, mock_remote): with open(os.path.join("core", "sample.py"), "w") as f: f.write(sample_file_contents) f.close() - result = runner.invoke(pyflyte.main, ["register", "--non-fast", "core"]) + result = runner.invoke(pyflyte.main, ["register", "--non-fast", "--version", "a-version", "core"]) assert "Output given as None, using a temporary directory at" in result.output shutil.rmtree("core") + + +@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote", spec=FlyteRemote) +@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient) +def test_non_fast_register_require_version(mock_client, mock_remote): + mock_remote._client = mock_client + mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash" + mock_remote.return_value._upload_file.return_value = "dummy_md5_bytes", "dummy_native_url" + runner = CliRunner() + with runner.isolated_filesystem(): + out = subprocess.run(["git", "init"], capture_output=True) + assert out.returncode == 0 + os.makedirs("core", exist_ok=True) + with open(os.path.join("core", "sample.py"), "w") as f: + f.write(sample_file_contents) + f.close() + result = runner.invoke(pyflyte.main, ["register", "--non-fast", "core"]) + assert result.exit_code == 1 + assert str(result.exception) == "Version is a required parameter in case --non-fast is specified." + shutil.rmtree("core")