Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add pod_template and pod_template_name arguments for ContainerTask #1515

Merged
merged 8 commits into from
Mar 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ flask==2.2.3
# via mlflow
flatbuffers==23.1.21
# via tensorflow
flyteidl==1.3.7
flyteidl==1.3.12
# via flytekit
fonttools==4.38.0
# via matplotlib
Expand Down
55 changes: 46 additions & 9 deletions flytekit/core/container_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
from flytekit.configuration import SerializationSettings
from flytekit.core.base_task import PythonTask, TaskMetadata
from flytekit.core.interface import Interface
from flytekit.core.pod_template import PodTemplate
from flytekit.core.resources import Resources, ResourceSpec
from flytekit.core.utils import _get_container_definition
from flytekit.core.utils import _get_container_definition, _serialize_pod_spec
from flytekit.models import task as _task_model
from flytekit.models.security import Secret, SecurityContext

_PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name"


# TODO: do we need pod_template here? Seems that it is a raw container not running in pods
class ContainerTask(PythonTask):
"""
This is an intermediate class that represents Flyte Tasks that run a container at execution time. This is the vast
Expand Down Expand Up @@ -47,6 +49,8 @@ def __init__(
metadata_format: MetadataFormat = MetadataFormat.JSON,
io_strategy: Optional[IOStrategy] = None,
secret_requests: Optional[List[Secret]] = None,
pod_template: Optional[PodTemplate] = None,
pod_template_name: Optional[str] = None,
**kwargs,
):
sec_ctx = None
Expand All @@ -55,6 +59,11 @@ def __init__(
if not isinstance(s, Secret):
raise AssertionError(f"Secret {s} should be of type flytekit.Secret, received {type(s)}")
sec_ctx = SecurityContext(secrets=secret_requests)

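# NOTE: pod_template_name unconditionally overwrites metadata.pod_template_name, so a value already set on a
# caller-supplied TaskMetadata is replaced (with None when the pod_template_name argument is omitted).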
metadata = metadata or TaskMetadata()
metadata.pod_template_name = pod_template_name

super().__init__(
task_type="raw-container",
name=name,
Expand All @@ -74,6 +83,7 @@ def __init__(
self._resources = ResourceSpec(
requests=requests if requests else Resources(), limits=limits if limits else Resources()
)
self.pod_template = pod_template

@property
def resources(self) -> ResourceSpec:
Expand All @@ -91,19 +101,29 @@ def execute(self, **kwargs) -> Any:
return None

def get_container(self, settings: SerializationSettings) -> Optional[_task_model.Container]:
    """
    Return the container definition for this task, or ``None`` when a pod template is set.

    When ``pod_template`` is specified no standalone container is returned here; ``get_k8s_pod`` returns the
    pod template merged with the container built by ``_get_container`` instead.

    :param settings: serialization settings forwarded to ``_get_container``.
    :return: the container built by ``_get_container``, or ``None`` if ``pod_template`` is not ``None``.
    """
    if self.pod_template is not None:
        return None

    return self._get_container(settings)

def _get_data_loading_config(self) -> _task_model.DataLoadingConfig:
    """
    Build the data loading config from this task's input/output data directories, metadata format and
    (optional) IO strategy. Data loading is always marked as enabled.
    """
    config_kwargs = dict(
        input_path=self._input_data_dir,
        output_path=self._output_data_dir,
        format=self._md_format.value,
        enabled=True,
    )
    # The IO strategy is optional; forward its value only when one was configured.
    if self._io_strategy:
        config_kwargs["io_strategy"] = self._io_strategy.value
    else:
        config_kwargs["io_strategy"] = None
    return _task_model.DataLoadingConfig(**config_kwargs)

def _get_container(self, settings: SerializationSettings) -> _task_model.Container:
env = settings.env or {}
env = {**env, **self.environment} if self.environment else env
return _get_container_definition(
image=self._image,
command=self._cmd,
args=self._args,
data_loading_config=_task_model.DataLoadingConfig(
input_path=self._input_data_dir,
output_path=self._output_data_dir,
format=self._md_format.value,
enabled=True,
io_strategy=self._io_strategy.value if self._io_strategy else None,
),
data_loading_config=self._get_data_loading_config(),
environment=env,
storage_request=self.resources.requests.storage,
ephemeral_storage_request=self.resources.requests.ephemeral_storage,
Expand All @@ -116,3 +136,20 @@ def get_container(self, settings: SerializationSettings) -> _task_model.Containe
gpu_limit=self.resources.limits.gpu,
memory_limit=self.resources.limits.mem,
)

def get_k8s_pod(self, settings: SerializationSettings) -> Optional[_task_model.K8sPod]:
    """
    Return the K8s pod definition for this task, or ``None`` when no pod template is set.

    The pod spec is the pod template merged with the container built by ``_get_container``; the pod metadata
    carries the template's labels and annotations, and the data loading config is attached as ``data_config``.

    :param settings: serialization settings forwarded to ``_get_container``.
    :return: the pod model, or ``None`` if ``pod_template`` is ``None``.
    """
    if self.pod_template is None:
        return None

    pod_metadata = _task_model.K8sObjectMetadata(
        labels=self.pod_template.labels,
        annotations=self.pod_template.annotations,
    )
    return _task_model.K8sPod(
        pod_spec=_serialize_pod_spec(self.pod_template, self._get_container(settings)),
        metadata=pod_metadata,
        data_config=self._get_data_loading_config(),
    )

def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]]:
    """
    Return the task config: empty without a pod template, otherwise a single entry naming the
    template's primary container.
    """
    template = self.pod_template
    return {} if template is None else {_PRIMARY_CONTAINER_NAME_FIELD: template.primary_container_name}
55 changes: 3 additions & 52 deletions flytekit/core/python_auto_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,7 @@
import importlib
import re
from abc import ABC
from typing import Any, Callable, Dict, List, Optional, TypeVar, cast

from flyteidl.core import tasks_pb2 as _core_task
from kubernetes.client import ApiClient
from kubernetes.client.models import V1Container, V1EnvVar, V1ResourceRequirements
from typing import Callable, Dict, List, Optional, TypeVar

from flytekit.configuration import ImageConfig, SerializationSettings
from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin
Expand All @@ -16,7 +12,7 @@
from flytekit.core.resources import Resources, ResourceSpec
from flytekit.core.tracked_abc import FlyteTrackedABC
from flytekit.core.tracker import TrackedInstance, extract_task_module
from flytekit.core.utils import _get_container_definition
from flytekit.core.utils import _get_container_definition, _serialize_pod_spec
from flytekit.loggers import logger
from flytekit.models import task as _task_model
from flytekit.models.security import Secret, SecurityContext
Expand All @@ -25,10 +21,6 @@
_PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name"


def _sanitize_resource_name(resource: _task_model.Resources.ResourceEntry) -> str:
    """
    Return the ``ResourceName`` enum name of ``resource`` lower-cased, with underscores replaced by hyphens.
    """
    enum_name = _core_task.Resources.ResourceName.Name(resource.name)
    lowered = enum_name.lower()
    return lowered.replace("_", "-")


class PythonAutoContainerTask(PythonTask[T], ABC, metaclass=FlyteTrackedABC):
"""
A Python AutoContainer task should be used as the base for all extensions that want the user's code to be in the
Expand Down Expand Up @@ -206,52 +198,11 @@ def _get_container(self, settings: SerializationSettings) -> _task_model.Contain
memory_limit=self.resources.limits.mem,
)

def _serialize_pod_spec(self, settings: SerializationSettings) -> Dict[str, Any]:
    """
    Merge the SDK-generated container into this task's pod template and return the serialized pod spec.

    If the pod spec has no container named after the template's primary container, a placeholder with that
    name is appended. For the primary container the image, command and args are overwritten with the values
    from ``_get_container``, resources are replaced only when the SDK container declares some, and the SDK
    env vars are placed in front of any env already present on the container.

    :param settings: serialization settings forwarded to ``_get_container``.
    :return: the pod spec as produced by ``ApiClient().sanitize_for_serialization``.
    """
    pod_template = cast(PodTemplate, self.pod_template)
    primary_name = pod_template.primary_container_name
    containers = pod_template.pod_spec.containers

    if not any(container.name == primary_name for container in containers):
        # insert a placeholder primary container if it is not defined in the pod spec.
        containers.append(V1Container(name=primary_name))

    final_containers = []
    for container in containers:
        # In the case of the primary container, we overwrite specific container attributes
        # with the default values used in the regular Python task.
        # The attributes include: image, command, args, resource, and env (env is unioned)
        if container.name == primary_name:
            sdk_default_container = self._get_container(settings)
            container.image = sdk_default_container.image
            # clear existing commands
            container.command = sdk_default_container.command
            # also clear existing args
            container.args = sdk_default_container.args

            limits = {_sanitize_resource_name(r): r.value for r in sdk_default_container.resources.limits}
            requests = {_sanitize_resource_name(r): r.value for r in sdk_default_container.resources.requests}
            if limits or requests:
                # Important! Only copy over resource requirements if they are non-empty.
                container.resources = V1ResourceRequirements(limits=limits, requests=requests)

            # The SDK container may carry no env at all; leave the template's env untouched in that case
            # instead of failing on ``None.items()``.
            if sdk_default_container.env is not None:
                sdk_env = [V1EnvVar(name=key, value=val) for key, val in sdk_default_container.env.items()]
                container.env = sdk_env + (container.env or [])
        final_containers.append(container)
    pod_template.pod_spec.containers = final_containers

    return ApiClient().sanitize_for_serialization(pod_template.pod_spec)

def get_k8s_pod(self, settings: SerializationSettings) -> _task_model.K8sPod:
if self.pod_template is None:
return None
return _task_model.K8sPod(
pod_spec=self._serialize_pod_spec(settings),
pod_spec=_serialize_pod_spec(self.pod_template, self._get_container(settings)),
metadata=_task_model.K8sObjectMetadata(
labels=self.pod_template.labels,
annotations=self.pod_template.annotations,
Expand Down
53 changes: 52 additions & 1 deletion flytekit/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,15 @@
import time as _time
from hashlib import sha224 as _sha224
from pathlib import Path
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional, cast

from flyteidl.core import tasks_pb2 as _core_task
from kubernetes.client import ApiClient
from kubernetes.client.models import V1Container, V1EnvVar, V1ResourceRequirements

from flytekit.core.pod_template import PodTemplate
from flytekit.loggers import logger
from flytekit.models import task as _task_model
from flytekit.models import task as task_models


Expand Down Expand Up @@ -125,6 +131,51 @@ def _get_container_definition(
)


def _sanitize_resource_name(resource: _task_model.Resources.ResourceEntry) -> str:
    """
    Return the ``ResourceName`` enum name of ``resource`` lower-cased, with underscores replaced by hyphens.
    """
    enum_name = _core_task.Resources.ResourceName.Name(resource.name)
    lowered = enum_name.lower()
    return lowered.replace("_", "-")


def _serialize_pod_spec(pod_template: PodTemplate, primary_container: _task_model.Container) -> Dict[str, Any]:
    """
    Merge ``primary_container`` into the pod template's pod spec and return the serialized result.

    If the pod spec has no container named ``pod_template.primary_container_name``, a placeholder with that
    name is added. For the primary container the image, command and args are overwritten with the values of
    ``primary_container``, resources are replaced only when ``primary_container`` declares some, and its env
    vars (when present) are placed in front of any env already defined on the template's container.

    The merge is performed on a copy of the pod spec, so ``pod_template`` itself is left unmodified and
    serializing the same template repeatedly yields the same result (no accumulated env vars or placeholder
    containers).

    :param pod_template: the user-supplied pod template.
    :param primary_container: the container model whose values are applied to the primary container.
    :return: the pod spec as produced by ``ApiClient().sanitize_for_serialization``.
    """
    import copy  # function-scope import: only this helper needs it

    primary_name = pod_template.primary_container_name
    # Never mutate the caller's template: the same PodTemplate can be serialized more than once.
    pod_spec = copy.deepcopy(pod_template.pod_spec)
    containers = list(pod_spec.containers or [])

    if not any(container.name == primary_name for container in containers):
        # insert a placeholder primary container if it is not defined in the pod spec.
        containers.append(V1Container(name=primary_name))

    for container in containers:
        if container.name != primary_name:
            continue
        # In the case of the primary container, we overwrite specific container attributes
        # with the values given to ContainerTask.
        # The attributes include: image, command, args, resource, and env (env is unioned)
        container.image = primary_container.image
        container.command = primary_container.command
        container.args = primary_container.args

        limits = {_sanitize_resource_name(r): r.value for r in primary_container.resources.limits}
        requests = {_sanitize_resource_name(r): r.value for r in primary_container.resources.requests}
        if limits or requests:
            # Important! Only copy over resource requirements if they are non-empty.
            container.resources = V1ResourceRequirements(limits=limits, requests=requests)

        if primary_container.env is not None:
            sdk_env = [V1EnvVar(name=key, value=val) for key, val in primary_container.env.items()]
            container.env = sdk_env + (container.env or [])
    pod_spec.containers = containers

    return ApiClient().sanitize_for_serialization(pod_spec)


def load_proto_from_file(pb2_type, path):
with open(path, "rb") as reader:
out = pb2_type()
Expand Down
16 changes: 15 additions & 1 deletion flytekit/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,12 +868,18 @@ def from_flyte_idl(cls, pb2_object: _core_task.K8sObjectMetadata):


class K8sPod(_common.FlyteIdlEntity):
def __init__(self, metadata: K8sObjectMetadata = None, pod_spec: typing.Dict[str, typing.Any] = None):
def __init__(
    self,
    metadata: K8sObjectMetadata = None,
    pod_spec: typing.Dict[str, typing.Any] = None,
    data_config: typing.Optional[DataLoadingConfig] = None,
):
    """
    This defines a kubernetes pod target. It will build the pod target during task execution

    :param metadata: labels/annotations metadata for the pod.
    :param pod_spec: the pod spec as a plain dictionary.
    :param data_config: optional data loading config attached to the pod.
    """
    self._metadata, self._pod_spec, self._data_config = metadata, pod_spec, data_config

@property
def metadata(self) -> K8sObjectMetadata:
Expand All @@ -883,17 +889,25 @@ def metadata(self) -> K8sObjectMetadata:
def pod_spec(self) -> typing.Dict[str, typing.Any]:
return self._pod_spec

@property
def data_config(self) -> typing.Optional[DataLoadingConfig]:
    """
    The data loading config passed to the constructor, or ``None`` if none was given.
    """
    return self._data_config

def to_flyte_idl(self) -> _core_task.K8sPod:
    """
    Convert this model into its ``K8sPod`` protobuf message.

    Every field is optional on the model (all constructor arguments default to ``None``), so each one is
    only converted when it is set; unset fields are passed as ``None``.

    :return: the protobuf message built from metadata, pod spec and data config.
    """
    return _core_task.K8sPod(
        # metadata defaults to None in the constructor; guard it like the other optional fields.
        metadata=self._metadata.to_flyte_idl() if self._metadata is not None else None,
        pod_spec=_json_format.Parse(_json.dumps(self.pod_spec), _struct.Struct()) if self.pod_spec else None,
        data_config=self.data_config.to_flyte_idl() if self.data_config else None,
    )

@classmethod
def from_flyte_idl(cls, pb2_object: _core_task.K8sPod):
    """
    Build the model from a ``K8sPod`` protobuf message.

    ``pod_spec`` and ``data_config`` are only converted when the corresponding field is set on the message;
    otherwise they are ``None``.
    """
    metadata = K8sObjectMetadata.from_flyte_idl(pb2_object.metadata)

    pod_spec = None
    if pb2_object.HasField("pod_spec"):
        pod_spec = _json_format.MessageToDict(pb2_object.pod_spec)

    data_config = None
    if pb2_object.HasField("data_config"):
        data_config = DataLoadingConfig.from_flyte_idl(pb2_object.data_config)

    return cls(metadata=metadata, pod_spec=pod_spec, data_config=data_config)


Expand Down
3 changes: 2 additions & 1 deletion flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from flytekit.core import constants as _common_constants
from flytekit.core.base_task import PythonTask
from flytekit.core.condition import BranchNode
from flytekit.core.container_task import ContainerTask
from flytekit.core.gate import Gate
from flytekit.core.launch_plan import LaunchPlan, ReferenceLaunchPlan
from flytekit.core.map_task import MapPythonTask
Expand Down Expand Up @@ -189,7 +190,7 @@ def get_serializable_task(
# If the pod spec is not None, we have to get it again, because the one we retrieved above will be incorrect.
# The reason we have to call get_k8s_pod again, instead of just modifying the command in this file, is because
# the pod spec is a K8s library object, and we shouldn't be messing around with it in this file.
elif pod:
elif pod and not isinstance(entity, ContainerTask):
if isinstance(entity, MapPythonTask):
entity.set_command_prefix(get_command_prefix_for_fast_execute(settings))
pod = entity.get_k8s_pod(settings)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
},
install_requires=[
"googleapis-common-protos>=1.57",
"flyteidl>=1.3.5,<1.4.0",
"flyteidl>=1.3.12,<1.4.0",
"wheel>=0.30.0,<1.0.0",
"pandas>=1.0.0,<2.0.0",
"pyarrow>=4.0.0,<11.0.0",
Expand Down
Loading