Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable use of image stanza in pyflyte register and other small changes #1227

Merged
merged 5 commits into from
Oct 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion flytekit/clis/sdk_in_container/helpers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from dataclasses import replace
from typing import Optional

import click

from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE
from flytekit.configuration import Config
from flytekit.configuration import Config, ImageConfig
from flytekit.loggers import cli_logger
from flytekit.remote.remote import FlyteRemote

Expand Down Expand Up @@ -30,3 +33,28 @@ def get_and_save_remote_with_click_context(
if save:
ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] = r
return r


def patch_image_config(config_file: Optional[str], image_config: ImageConfig) -> ImageConfig:
"""
Merge ImageConfig object with images defined in config file
"""
# Images come from three places:
# * The default flytekit images, which are already supplied by the base run_level_params.
# * The images provided by the user on the command line.
# * The images provided by the user via the config file, if there is one. (Images on the command line should
# override all).
#
# However, the run_level_params already contains both the default flytekit images (lowest priority), as well
# as the images from the command line (highest priority). So when we read from the config file, we only
# want to add in the images that are missing, including the default, if that's also missing.
additional_image_names = set([v.name for v in image_config.images])
new_additional_images = [v for v in image_config.images]
new_default = image_config.default_image
if config_file:
cfg_ic = ImageConfig.auto(config_file=config_file)
new_default = new_default or cfg_ic.default_image
for addl in cfg_ic.images:
if addl.name not in additional_image_names:
new_additional_images.append(addl)
return replace(image_config, default_image=new_default, images=new_additional_images)
16 changes: 14 additions & 2 deletions flytekit/clis/sdk_in_container/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from flytekit.clis.helpers import display_help_with_error
from flytekit.clis.sdk_in_container import constants
from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context
from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context, patch_image_config
from flytekit.configuration import FastSerializationSettings, ImageConfig, SerializationSettings
from flytekit.configuration.default_images import DefaultImages
from flytekit.loggers import cli_logger
Expand Down Expand Up @@ -137,12 +137,20 @@ def register(
if pkgs:
raise ValueError("Unimplemented, just specify pkgs like folder/files as args at the end of the command")

if non_fast and not version:
raise ValueError("Version is a required parameter in case --non-fast is specified.")

if len(package_or_module) == 0:
display_help_with_error(
ctx,
"Missing argument 'PACKAGE_OR_MODULE...', at least one PACKAGE_OR_MODULE is required but multiple can be passed",
)

# Use extra images in the config file if that file exists
config_file = ctx.obj.get(constants.CTX_CONFIG_FILE)
if config_file:
image_config = patch_image_config(config_file, image_config)

cli_logger.debug(
f"Running pyflyte register from {os.getcwd()} "
f"with images {image_config} "
Expand All @@ -155,12 +163,12 @@ def register(

detected_root = find_common_root(package_or_module)
cli_logger.debug(f"Using {detected_root} as root folder for project")
fast_serialization_settings = None

# Create a zip file containing all the entries.
zip_file = fast_package(detected_root, output, deref_symlinks)
md5_bytes, _ = hash_file(pathlib.Path(zip_file))

fast_serialization_settings = None
if non_fast is False:
# Upload zip file to Admin using FlyteRemote.
md5_bytes, native_url = remote._upload_file(pathlib.Path(zip_file))
Expand Down Expand Up @@ -194,5 +202,9 @@ def register(
version = remote._version_from_hash(md5_bytes, serialization_settings, service_account, raw_data_prefix) # noqa
cli_logger.info(f"Computed version is {version}")

click.echo(
f"Registering entities under version {version} using the following serialization settings = {serialization_settings}"
)

# Register using repo code
repo_register(registerable_entities, project, domain, version, remote.client)
31 changes: 8 additions & 23 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os
import pathlib
import typing
from dataclasses import dataclass, replace
from dataclasses import dataclass
from typing import cast

import click
Expand All @@ -22,7 +22,11 @@
CTX_PROJECT,
CTX_PROJECT_ROOT,
)
from flytekit.clis.sdk_in_container.helpers import FLYTE_REMOTE_INSTANCE_KEY, get_and_save_remote_with_click_context
from flytekit.clis.sdk_in_container.helpers import (
FLYTE_REMOTE_INSTANCE_KEY,
get_and_save_remote_with_click_context,
patch_image_config,
)
from flytekit.configuration import ImageConfig
from flytekit.configuration.default_images import DefaultImages
from flytekit.core import context_manager
Expand Down Expand Up @@ -517,27 +521,8 @@ def _run(*args, **kwargs):
remote = ctx.obj[FLYTE_REMOTE_INSTANCE_KEY]
config_file = ctx.obj.get(CTX_CONFIG_FILE)

# Images come from three places:
# * The default flytekit images, which are already supplied by the base run_level_params.
# * The images provided by the user on the command line.
# * The images provided by the user via the config file, if there is one. (Images on the command line should
# override all).
#
# However, the run_level_params already contains both the default flytekit images (lowest priority), as well
# as the images from the command line (highest priority). So when we read from the config file, we only
# want to add in the images that are missing, including the default, if that's also missing.
image_config_from_parent_cmd = run_level_params.get("image_config", None)
additional_image_names = set([v.name for v in image_config_from_parent_cmd.images])
new_additional_images = [v for v in image_config_from_parent_cmd.images]
new_default = image_config_from_parent_cmd.default_image
if config_file:
cfg_ic = ImageConfig.auto(config_file=config_file)
new_default = new_default or cfg_ic.default_image
for addl in cfg_ic.images:
if addl.name not in additional_image_names:
new_additional_images.append(addl)

image_config = replace(image_config_from_parent_cmd, default_image=new_default, images=new_additional_images)
image_config = run_level_params.get("image_config")
image_config = patch_image_config(config_file, image_config)

remote_entity = remote.register_script(
entity,
Expand Down
22 changes: 21 additions & 1 deletion tests/flytekit/unit/cli/pyflyte/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,26 @@ def test_non_fast_register(mock_client, mock_remote):
with open(os.path.join("core", "sample.py"), "w") as f:
f.write(sample_file_contents)
f.close()
result = runner.invoke(pyflyte.main, ["register", "--non-fast", "core"])
result = runner.invoke(pyflyte.main, ["register", "--non-fast", "--version", "a-version", "core"])
assert "Output given as None, using a temporary directory at" in result.output
shutil.rmtree("core")


@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote", spec=FlyteRemote)
@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient)
def test_non_fast_register_require_version(mock_client, mock_remote):
mock_remote._client = mock_client
mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash"
mock_remote.return_value._upload_file.return_value = "dummy_md5_bytes", "dummy_native_url"
runner = CliRunner()
with runner.isolated_filesystem():
out = subprocess.run(["git", "init"], capture_output=True)
assert out.returncode == 0
os.makedirs("core", exist_ok=True)
with open(os.path.join("core", "sample.py"), "w") as f:
f.write(sample_file_contents)
f.close()
result = runner.invoke(pyflyte.main, ["register", "--non-fast", "core"])
assert result.exit_code == 1
assert str(result.exception) == "Version is a required parameter in case --non-fast is specified."
shutil.rmtree("core")