From 4d6b54bd47fbe029570800b81e7019f80c87663c Mon Sep 17 00:00:00 2001 From: Felix Ruess Date: Wed, 22 Feb 2023 14:50:38 +0100 Subject: [PATCH 1/7] add pod_template and pod_template_name arguments for ContainerTask Signed-off-by: Felix Ruess --- flytekit/core/container_task.py | 105 +++++++++++++++++++++++++++++--- 1 file changed, 96 insertions(+), 9 deletions(-) diff --git a/flytekit/core/container_task.py b/flytekit/core/container_task.py index 677142736c..a6694a5c70 100644 --- a/flytekit/core/container_task.py +++ b/flytekit/core/container_task.py @@ -1,16 +1,26 @@ from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type, cast + +from flyteidl.core import tasks_pb2 as _core_task +from kubernetes.client import ApiClient +from kubernetes.client.models import V1Container, V1EnvVar, V1ResourceRequirements from flytekit.configuration import SerializationSettings from flytekit.core.base_task import PythonTask, TaskMetadata from flytekit.core.interface import Interface +from flytekit.core.pod_template import PodTemplate from flytekit.core.resources import Resources, ResourceSpec from flytekit.core.utils import _get_container_definition from flytekit.models import task as _task_model from flytekit.models.security import Secret, SecurityContext +_PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name" + + +def _sanitize_resource_name(resource: _task_model.Resources.ResourceEntry) -> str: + return _core_task.Resources.ResourceName.Name(resource.name).lower().replace("_", "-") + -# TODO: do we need pod_template here? Seems that it is a raw container not running in pods class ContainerTask(PythonTask): """ This is an intermediate class that represents Flyte Tasks that run a container at execution time. This is the vast @@ -47,6 +57,8 @@ def __init__( metadata_format: MetadataFormat = MetadataFormat.JSON, io_strategy: Optional[IOStrategy] = None, secret_requests: Optional[List[Secret]] = None, + pod_template: Optional[PodTemplate] = None, + pod_template_name: Optional[str] = None, **kwargs, ): sec_ctx = None @@ -55,6 +67,11 @@ def __init__( if not isinstance(s, Secret): raise AssertionError(f"Secret {s} should be of type flytekit.Secret, received {type(s)}") sec_ctx = SecurityContext(secrets=secret_requests) + + # pod_template_name overwrites the metadata.pod_template_name + kwargs["metadata"] = kwargs["metadata"] if "metadata" in kwargs else TaskMetadata() + kwargs["metadata"].pod_template_name = pod_template_name + super().__init__( task_type="raw-container", name=name, @@ -74,6 +91,7 @@ def __init__( self._resources = ResourceSpec( requests=requests if requests else Resources(), limits=limits if limits else Resources() ) + self.pod_template = pod_template @property def resources(self) -> ResourceSpec: @@ -91,19 +109,29 @@ def execute(self, **kwargs) -> Any: return None def get_container(self, settings: SerializationSettings) -> _task_model.Container: + # if pod_template is specified, return None here but in get_k8s_pod, return pod_template merged with container + if self.pod_template is not None: + return None + + return self._get_container(settings) + + def _get_data_loading_config(self) -> _task_model.DataLoadingConfig: + return _task_model.DataLoadingConfig( + input_path=self._input_data_dir, + output_path=self._output_data_dir, + format=self._md_format.value, + enabled=True, + io_strategy=self._io_strategy.value if self._io_strategy else None, + ) + + def _get_container(self, settings: SerializationSettings) -> _task_model.Container: env = settings.env or {} env = {**env, **self.environment} if self.environment else env return _get_container_definition( image=self._image, command=self._cmd, args=self._args, - data_loading_config=_task_model.DataLoadingConfig( - input_path=self._input_data_dir, - output_path=self._output_data_dir, - format=self._md_format.value, - enabled=True, - io_strategy=self._io_strategy.value if self._io_strategy else None, - ), + data_loading_config=self._get_data_loading_config(), environment=env, storage_request=self.resources.requests.storage, ephemeral_storage_request=self.resources.requests.ephemeral_storage, @@ -116,3 +144,62 @@ def get_container(self, settings: SerializationSettings) -> _task_model.Containe gpu_limit=self.resources.limits.gpu, memory_limit=self.resources.limits.mem, ) + + def _serialize_pod_spec(self, settings: SerializationSettings) -> Dict[str, Any]: + containers = cast(PodTemplate, self.pod_template).pod_spec.containers + primary_exists = False + + for container in containers: + if container.name == cast(PodTemplate, self.pod_template).primary_container_name: + primary_exists = True + break + + if not primary_exists: + # insert a placeholder primary container if it is not defined in the pod spec. + containers.append(V1Container(name=cast(PodTemplate, self.pod_template).primary_container_name)) + final_containers = [] + for container in containers: + # In the case of the primary container, we overwrite specific container attributes + # with the values given to ContainerTask. + # The attributes include: image, command, args, resource, and env (env is unioned) + if container.name == cast(PodTemplate, self.pod_template).primary_container_name: + prim_container = self._get_container(settings) + + container.image = self._image + container.command = self._cmd + container.args = self._args + + limits, requests = {}, {} + for resource in prim_container.resources.limits: + limits[_sanitize_resource_name(resource)] = resource.value + for resource in prim_container.resources.requests: + requests[_sanitize_resource_name(resource)] = resource.value + resource_requirements = V1ResourceRequirements(limits=limits, requests=requests) + if len(limits) > 0 or len(requests) > 0: + # Important! Only copy over resource requirements if they are non-empty. + container.resources = resource_requirements + env = settings.env or {} + env = {**env, **self.environment} if self.environment else env + container.env = [V1EnvVar(name=key, value=val) for key, val in env.items()] + (container.env or []) + final_containers.append(container) + cast(PodTemplate, self.pod_template).pod_spec.containers = final_containers + + cast(PodTemplate, self.pod_template).data_config = self._get_data_loading_config() + + return ApiClient().sanitize_for_serialization(cast(PodTemplate, self.pod_template).pod_spec) + + def get_k8s_pod(self, settings: SerializationSettings) -> _task_model.K8sPod: + if self.pod_template is None: + return None + return _task_model.K8sPod( + pod_spec=self._serialize_pod_spec(settings), + metadata=_task_model.K8sObjectMetadata( + labels=self.pod_template.labels, + annotations=self.pod_template.annotations, + ), + ) + + def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]]: + if self.pod_template is None: + return {} + return {_PRIMARY_CONTAINER_NAME_FIELD: self.pod_template.primary_container_name} From 7cb1515dd3d848682866a33462f61effe5dac1ac Mon Sep 17 00:00:00 2001 From: Felix Ruess Date: Mon, 27 Feb 2023 16:00:15 +0100 Subject: [PATCH 2/7] factor out _serialize_pod_spec into separate util function Signed-off-by: Felix Ruess --- flytekit/core/container_task.py | 57 ++------------------------ flytekit/core/python_auto_container.py | 55 ++----------------------- flytekit/core/utils.py | 55 ++++++++++++++++++++++++- 3 files changed, 60 insertions(+), 107 deletions(-) diff --git a/flytekit/core/container_task.py b/flytekit/core/container_task.py index a6694a5c70..e152d3ac5a 100644 --- a/flytekit/core/container_task.py +++ b/flytekit/core/container_task.py @@ -1,26 +1,18 @@ from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, Type, cast - -from flyteidl.core import tasks_pb2 as _core_task -from kubernetes.client import ApiClient -from kubernetes.client.models import V1Container, V1EnvVar, V1ResourceRequirements +from typing import Any, Dict, List, Optional, Tuple, Type from flytekit.configuration import SerializationSettings from flytekit.core.base_task import PythonTask, TaskMetadata from flytekit.core.interface import Interface from flytekit.core.pod_template import PodTemplate from flytekit.core.resources import Resources, ResourceSpec -from flytekit.core.utils import _get_container_definition +from flytekit.core.utils import _get_container_definition, _serialize_pod_spec from flytekit.models import task as _task_model from flytekit.models.security import Secret, SecurityContext _PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name" -def _sanitize_resource_name(resource: _task_model.Resources.ResourceEntry) -> str: - return _core_task.Resources.ResourceName.Name(resource.name).lower().replace("_", "-") - - class ContainerTask(PythonTask): """ This is an intermediate class that represents Flyte Tasks that run a container at execution time. This is the vast @@ -145,54 +137,11 @@ def _get_container(self, settings: SerializationSettings) -> _task_model.Contain memory_limit=self.resources.limits.mem, ) - def _serialize_pod_spec(self, settings: SerializationSettings) -> Dict[str, Any]: - containers = cast(PodTemplate, self.pod_template).pod_spec.containers - primary_exists = False - - for container in containers: - if container.name == cast(PodTemplate, self.pod_template).primary_container_name: - primary_exists = True - break - - if not primary_exists: - # insert a placeholder primary container if it is not defined in the pod spec. - containers.append(V1Container(name=cast(PodTemplate, self.pod_template).primary_container_name)) - final_containers = [] - for container in containers: - # In the case of the primary container, we overwrite specific container attributes - # with the values given to ContainerTask. - # The attributes include: image, command, args, resource, and env (env is unioned) - if container.name == cast(PodTemplate, self.pod_template).primary_container_name: - prim_container = self._get_container(settings) - - container.image = self._image - container.command = self._cmd - container.args = self._args - - limits, requests = {}, {} - for resource in prim_container.resources.limits: - limits[_sanitize_resource_name(resource)] = resource.value - for resource in prim_container.resources.requests: - requests[_sanitize_resource_name(resource)] = resource.value - resource_requirements = V1ResourceRequirements(limits=limits, requests=requests) - if len(limits) > 0 or len(requests) > 0: - # Important! Only copy over resource requirements if they are non-empty. - container.resources = resource_requirements - env = settings.env or {} - env = {**env, **self.environment} if self.environment else env - container.env = [V1EnvVar(name=key, value=val) for key, val in env.items()] + (container.env or []) - final_containers.append(container) - cast(PodTemplate, self.pod_template).pod_spec.containers = final_containers - - cast(PodTemplate, self.pod_template).data_config = self._get_data_loading_config() - - return ApiClient().sanitize_for_serialization(cast(PodTemplate, self.pod_template).pod_spec) - def get_k8s_pod(self, settings: SerializationSettings) -> _task_model.K8sPod: if self.pod_template is None: return None return _task_model.K8sPod( - pod_spec=self._serialize_pod_spec(settings), + pod_spec=_serialize_pod_spec(self.pod_template, self._get_container(settings)), metadata=_task_model.K8sObjectMetadata( labels=self.pod_template.labels, annotations=self.pod_template.annotations, diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index 113f94a998..774825f347 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -3,11 +3,7 @@ import importlib import re from abc import ABC -from typing import Any, Callable, Dict, List, Optional, TypeVar, cast - -from flyteidl.core import tasks_pb2 as _core_task -from kubernetes.client import ApiClient -from kubernetes.client.models import V1Container, V1EnvVar, V1ResourceRequirements +from typing import Callable, Dict, List, Optional, TypeVar from flytekit.configuration import ImageConfig, SerializationSettings from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin @@ -16,7 +12,7 @@ from flytekit.core.resources import Resources, ResourceSpec from flytekit.core.tracked_abc import FlyteTrackedABC from flytekit.core.tracker import TrackedInstance, extract_task_module -from flytekit.core.utils import _get_container_definition +from flytekit.core.utils import _get_container_definition, _serialize_pod_spec from flytekit.loggers import logger from flytekit.models import task as _task_model from flytekit.models.security import Secret, SecurityContext @@ -25,10 +21,6 @@ _PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name" -def _sanitize_resource_name(resource: _task_model.Resources.ResourceEntry) -> str: - return _core_task.Resources.ResourceName.Name(resource.name).lower().replace("_", "-") - - class PythonAutoContainerTask(PythonTask[T], ABC, metaclass=FlyteTrackedABC): """ A Python AutoContainer task should be used as the base for all extensions that want the user's code to be in the @@ -206,52 +198,11 @@ def _get_container(self, settings: SerializationSettings) -> _task_model.Contain memory_limit=self.resources.limits.mem, ) - def _serialize_pod_spec(self, settings: SerializationSettings) -> Dict[str, Any]: - containers = cast(PodTemplate, self.pod_template).pod_spec.containers - primary_exists = False - - for container in containers: - if container.name == cast(PodTemplate, self.pod_template).primary_container_name: - primary_exists = True - break - - if not primary_exists: - # insert a placeholder primary container if it is not defined in the pod spec. - containers.append(V1Container(name=cast(PodTemplate, self.pod_template).primary_container_name)) - final_containers = [] - for container in containers: - # In the case of the primary container, we overwrite specific container attributes - # with the default values used in the regular Python task. - # The attributes include: image, command, args, resource, and env (env is unioned) - if container.name == cast(PodTemplate, self.pod_template).primary_container_name: - sdk_default_container = self._get_container(settings) - container.image = sdk_default_container.image - # clear existing commands - container.command = sdk_default_container.command - # also clear existing args - container.args = sdk_default_container.args - limits, requests = {}, {} - for resource in sdk_default_container.resources.limits: - limits[_sanitize_resource_name(resource)] = resource.value - for resource in sdk_default_container.resources.requests: - requests[_sanitize_resource_name(resource)] = resource.value - resource_requirements = V1ResourceRequirements(limits=limits, requests=requests) - if len(limits) > 0 or len(requests) > 0: - # Important! Only copy over resource requirements if they are non-empty. - container.resources = resource_requirements - container.env = [V1EnvVar(name=key, value=val) for key, val in sdk_default_container.env.items()] + ( - container.env or [] - ) - final_containers.append(container) - cast(PodTemplate, self.pod_template).pod_spec.containers = final_containers - - return ApiClient().sanitize_for_serialization(cast(PodTemplate, self.pod_template).pod_spec) - def get_k8s_pod(self, settings: SerializationSettings) -> _task_model.K8sPod: if self.pod_template is None: return None return _task_model.K8sPod( - pod_spec=self._serialize_pod_spec(settings), + pod_spec=_serialize_pod_spec(self.pod_template, self._get_container(settings)), metadata=_task_model.K8sObjectMetadata( labels=self.pod_template.labels, annotations=self.pod_template.annotations, diff --git a/flytekit/core/utils.py b/flytekit/core/utils.py index ee2c841465..ee2427d358 100644 --- a/flytekit/core/utils.py +++ b/flytekit/core/utils.py @@ -4,9 +4,15 @@ import time as _time from hashlib import sha224 as _sha224 from pathlib import Path -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional, cast +from flyteidl.core import tasks_pb2 as _core_task +from kubernetes.client import ApiClient +from kubernetes.client.models import V1Container, V1EnvVar, V1ResourceRequirements + +from flytekit.core.pod_template import PodTemplate from flytekit.loggers import logger +from flytekit.models import task as _task_model from flytekit.models import task as task_models @@ -125,6 +131,53 @@ def _get_container_definition( ) +def _sanitize_resource_name(resource: _task_model.Resources.ResourceEntry) -> str: + return _core_task.Resources.ResourceName.Name(resource.name).lower().replace("_", "-") + + +def _serialize_pod_spec(pod_template: PodTemplate, primary_container: _task_model.Container) -> Dict[str, Any]: + containers = cast(PodTemplate, pod_template).pod_spec.containers + primary_exists = False + + for container in containers: + if container.name == cast(PodTemplate, pod_template).primary_container_name: + primary_exists = True + break + + if not primary_exists: + # insert a placeholder primary container if it is not defined in the pod spec. + containers.append(V1Container(name=cast(PodTemplate, pod_template).primary_container_name)) + final_containers = [] + for container in containers: + # In the case of the primary container, we overwrite specific container attributes + # with the values given to ContainerTask. + # The attributes include: image, command, args, resource, and env (env is unioned) + if container.name == cast(PodTemplate, pod_template).primary_container_name: + container.image = primary_container.image + container.command = primary_container.command + container.args = primary_container.args + + limits, requests = {}, {} + for resource in primary_container.resources.limits: + limits[_sanitize_resource_name(resource)] = resource.value + for resource in primary_container.resources.requests: + requests[_sanitize_resource_name(resource)] = resource.value + resource_requirements = V1ResourceRequirements(limits=limits, requests=requests) + if len(limits) > 0 or len(requests) > 0: + # Important! Only copy over resource requirements if they are non-empty. + container.resources = resource_requirements + if primary_container.env is not None: + container.env = [V1EnvVar(name=key, value=val) for key, val in primary_container.env.items()] + ( + container.env or [] + ) + final_containers.append(container) + cast(PodTemplate, pod_template).pod_spec.containers = final_containers + + cast(PodTemplate, pod_template).data_config = primary_container.data_loading_config + + return ApiClient().sanitize_for_serialization(cast(PodTemplate, pod_template).pod_spec) + + def load_proto_from_file(pb2_type, path): with open(path, "rb") as reader: out = pb2_type() From 171ebe8b9b26d946319977a5faa186945ec90c0e Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 28 Feb 2023 18:18:08 -0800 Subject: [PATCH 3/7] model file changes, couple other changes Signed-off-by: Yee Hing Tong --- flytekit/core/container_task.py | 3 ++- flytekit/models/task.py | 16 +++++++++++++++- flytekit/tools/translator.py | 3 ++- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/flytekit/core/container_task.py b/flytekit/core/container_task.py index e152d3ac5a..15c05b410c 100644 --- a/flytekit/core/container_task.py +++ b/flytekit/core/container_task.py @@ -68,7 +68,7 @@ def __init__( task_type="raw-container", name=name, interface=Interface(inputs, outputs), - metadata=metadata, + # metadata=metadata, task_config=None, security_ctx=sec_ctx, **kwargs, @@ -146,6 +146,7 @@ def get_k8s_pod(self, settings: SerializationSettings) -> _task_model.K8sPod: labels=self.pod_template.labels, annotations=self.pod_template.annotations, ), + data_config=self._get_data_loading_config(), ) def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]]: diff --git a/flytekit/models/task.py b/flytekit/models/task.py index fc79c87a2d..f7f1d710c9 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -868,12 +868,18 @@ def from_flyte_idl(cls, pb2_object: _core_task.K8sObjectMetadata): class K8sPod(_common.FlyteIdlEntity): - def __init__(self, metadata: K8sObjectMetadata = None, pod_spec: typing.Dict[str, typing.Any] = None): + def __init__( + self, + metadata: K8sObjectMetadata = None, + pod_spec: typing.Dict[str, typing.Any] = None, + data_config: typing.Optional[DataLoadingConfig] = None, + ): """ This defines a kubernetes pod target. It will build the pod target during task execution """ self._metadata = metadata self._pod_spec = pod_spec + self._data_config = data_config @property def metadata(self) -> K8sObjectMetadata: @@ -883,10 +889,15 @@ def metadata(self) -> K8sObjectMetadata: def pod_spec(self) -> typing.Dict[str, typing.Any]: return self._pod_spec + @property + def data_config(self) -> typing.Optional[DataLoadingConfig]: + return self._data_config + def to_flyte_idl(self) -> _core_task.K8sPod: return _core_task.K8sPod( metadata=self._metadata.to_flyte_idl(), pod_spec=_json_format.Parse(_json.dumps(self.pod_spec), _struct.Struct()) if self.pod_spec else None, + data_config=self.data_config.to_flyte_idl() if self.data_config else None, ) @classmethod @@ -894,6 +905,9 @@ def from_flyte_idl(cls, pb2_object: _core_task.K8sPod): return cls( metadata=K8sObjectMetadata.from_flyte_idl(pb2_object.metadata), pod_spec=_json_format.MessageToDict(pb2_object.pod_spec) if pb2_object.HasField("pod_spec") else None, + data_config=DataLoadingConfig.from_flyte_idl(pb2_object.data_config) + if pb2_object.HasField("data_config") + else None, ) diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index 5ec249fa4b..8b30fc4d36 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -9,6 +9,7 @@ from flytekit.core import constants as _common_constants from flytekit.core.base_task import PythonTask from flytekit.core.condition import BranchNode +from flytekit.core.container_task import ContainerTask from flytekit.core.gate import Gate from flytekit.core.launch_plan import LaunchPlan, ReferenceLaunchPlan from flytekit.core.map_task import MapPythonTask @@ -189,7 +190,7 @@ def get_serializable_task( # If the pod spec is not None, we have to get it again, because the one we retrieved above will be incorrect. # The reason we have to call get_k8s_pod again, instead of just modifying the command in this file, is because # the pod spec is a K8s library object, and we shouldn't be messing around with it in this file. - elif pod: + elif pod and not isinstance(entity, ContainerTask): if isinstance(entity, MapPythonTask): entity.set_command_prefix(get_command_prefix_for_fast_execute(settings)) pod = entity.get_k8s_pod(settings) From 569b563347bfb8c5892d1f61cd83f42edd78a87b Mon Sep 17 00:00:00 2001 From: Felix Ruess Date: Wed, 1 Mar 2023 17:43:10 +0100 Subject: [PATCH 4/7] minor cleanup Signed-off-by: Felix Ruess --- flytekit/core/container_task.py | 6 +++--- flytekit/core/utils.py | 2 -- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/flytekit/core/container_task.py b/flytekit/core/container_task.py index 15c05b410c..d51f71d837 100644 --- a/flytekit/core/container_task.py +++ b/flytekit/core/container_task.py @@ -61,14 +61,14 @@ def __init__( sec_ctx = SecurityContext(secrets=secret_requests) # pod_template_name overwrites the metadata.pod_template_name - kwargs["metadata"] = kwargs["metadata"] if "metadata" in kwargs else TaskMetadata() - kwargs["metadata"].pod_template_name = pod_template_name + metadata = metadata or TaskMetadata() + metadata.pod_template_name = pod_template_name super().__init__( task_type="raw-container", name=name, interface=Interface(inputs, outputs), - # metadata=metadata, + metadata=metadata, task_config=None, security_ctx=sec_ctx, **kwargs, diff --git a/flytekit/core/utils.py b/flytekit/core/utils.py index ee2427d358..24ce4d07d8 100644 --- a/flytekit/core/utils.py +++ b/flytekit/core/utils.py @@ -173,8 +173,6 @@ def _serialize_pod_spec(pod_template: PodTemplate, primary_container: _task_mode final_containers.append(container) cast(PodTemplate, pod_template).pod_spec.containers = final_containers - cast(PodTemplate, pod_template).data_config = primary_container.data_loading_config - return ApiClient().sanitize_for_serialization(cast(PodTemplate, pod_template).pod_spec) From b959e6573cb75c126bff3710cbddf5556618047a Mon Sep 17 00:00:00 2001 From: Felix Ruess Date: Thu, 9 Mar 2023 09:38:09 +0100 Subject: [PATCH 5/7] add unit test for container_task pod_template Signed-off-by: Felix Ruess --- .../flytekit/unit/core/test_container_task.py | 80 +++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 tests/flytekit/unit/core/test_container_task.py diff --git a/tests/flytekit/unit/core/test_container_task.py b/tests/flytekit/unit/core/test_container_task.py new file mode 100644 index 0000000000..599061d403 --- /dev/null +++ b/tests/flytekit/unit/core/test_container_task.py @@ -0,0 +1,80 @@ +from kubernetes.client.models import ( + V1Affinity, + V1NodeAffinity, + V1NodeSelectorRequirement, + V1NodeSelectorTerm, + V1PodSpec, + V1PreferredSchedulingTerm, + V1Toleration, +) + +from flytekit import kwtypes +from flytekit.configuration import Image, ImageConfig, SerializationSettings +from flytekit.core.container_task import ContainerTask +from flytekit.core.pod_template import PodTemplate +from flytekit.tools.translator import get_serializable_task + + +def test_pod_template(): + ps = V1PodSpec( + containers=[], tolerations=[V1Toleration(effect="NoSchedule", key="nvidia.com/gpu", operator="Exists")] + ) + ps.runtime_class_name = "nvidia" + nsr = V1NodeSelectorRequirement(key="nvidia.com/gpu.memory", operator="Gt", values=["10000"]) + pref_sched = V1PreferredSchedulingTerm(preference=V1NodeSelectorTerm(match_expressions=[nsr]), weight=1) + ps.affinity = V1Affinity( + node_affinity=V1NodeAffinity(preferred_during_scheduling_ignored_during_execution=[pref_sched]) + ) + pt = PodTemplate(pod_spec=ps, labels={"somelabel": "foobar"}) + + image = "ghcr.io/flyteorg/rawcontainers-shell:v2" + cmd = [ + "./calculate-ellipse-area.sh", + "{{.inputs.a}}", + "{{.inputs.b}}", + "/var/outputs", + ] + ct = ContainerTask( + name="ellipse-area-metadata-shell", + input_data_dir="/var/inputs", + output_data_dir="/var/outputs", + inputs=kwtypes(a=float, b=float), + outputs=kwtypes(area=float, metadata=str), + image=image, + command=cmd, + pod_template=pt, + pod_template_name="my-base-template", + ) + + assert ct.metadata.pod_template_name == "my-base-template" + + default_image = Image(name="default", fqn="docker.io/xyz", tag="some-git-hash") + default_image_config = ImageConfig(default_image=default_image) + default_serialization_settings = SerializationSettings( + project="p", domain="d", version="v", image_config=default_image_config + ) + + container = ct.get_container(default_serialization_settings) + assert container is None + + k8s_pod = ct.get_k8s_pod(default_serialization_settings) + assert k8s_pod.metadata.labels == {"somelabel": "foobar"} + + primary_container = k8s_pod.pod_spec["containers"][0] + + assert primary_container["image"] == image + assert primary_container["command"] == cmd + + ################# + # Test Serialization + ################# + ts = get_serializable_task(default_serialization_settings, ct) + assert ts.template.metadata.pod_template_name == "my-base-template" + assert ts.template.container is None + assert ts.template.k8s_pod is not None + serialized_pod_spec = ts.template.k8s_pod.pod_spec + assert serialized_pod_spec["affinity"]["nodeAffinity"] is not None + assert serialized_pod_spec["tolerations"] == [ + {"effect": "NoSchedule", "key": "nvidia.com/gpu", "operator": "Exists"} + ] + assert serialized_pod_spec["runtimeClassName"] == "nvidia" From 6e29a1d236578e9204511fec444d6e643cbe8f1d Mon Sep 17 00:00:00 2001 From: Felix Ruess Date: Mon, 13 Mar 2023 18:29:29 +0100 Subject: [PATCH 6/7] bump min version of flyteidl to 1.3.12 for pod template data config support Signed-off-by: Felix Ruess --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 74a466394b..4bc751de1c 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ }, install_requires=[ "googleapis-common-protos>=1.57", - "flyteidl>=1.3.5,<1.4.0", + "flyteidl>=1.3.12,<1.4.0", "wheel>=0.30.0,<1.0.0", "pandas>=1.0.0,<2.0.0", "pyarrow>=4.0.0,<11.0.0", From 5ae7de7d222b588a7f311c91620f6b4868b936dd Mon Sep 17 00:00:00 2001 From: Felix Ruess Date: Thu, 23 Mar 2023 12:34:00 +0100 Subject: [PATCH 7/7] require flyteidl==1.3.12 in doc-requirements.txt Signed-off-by: Felix Ruess --- doc-requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc-requirements.txt b/doc-requirements.txt index 2eb0532253..ceb15916d5 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -199,7 +199,7 @@ flask==2.2.2 # via mlflow flatbuffers==23.1.21 # via tensorflow -flyteidl==1.3.5 +flyteidl==1.3.12 # via flytekit fonttools==4.38.0 # via matplotlib