diff --git a/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py b/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py index c38ad33834..9e8e5ef937 100644 --- a/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py +++ b/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py @@ -13,6 +13,7 @@ from flytekit.models import task as _task_models _PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name" +PRIMARY_CONTAINER_DEFAULT_NAME = "primary" def _sanitize_resource_name(resource: _task_models.Resources.ResourceEntry) -> str: @@ -26,17 +27,19 @@ class Pod(object): This plugin helps expose a fully modifiable Kubernetes pod spec to customize the task execution runtime. To use pod tasks: (1) Define a pod spec, and (2) Specify the primary container name. :param V1PodSpec pod_spec: Kubernetes pod spec. https://kubernetes.io/docs/concepts/workloads/pods - :param str primary_container_name: the primary container name + :param str primary_container_name: the primary container name. If provided the pod-spec can contain a container whose name matches the primary_container_name. This will force Flyte to give up control of the primary + container and will expect users to control setting up the container. If you expect your python function to run as is, simply create containers that do not match the default primary-container-name and Flyte will auto-inject a + container for the python function based on the default image provided during serialization. :param Optional[Dict[str, str]] labels: Labels are key/value pairs that are attached to pod spec :param Optional[Dict[str, str]] annotations: Annotations are key/value pairs that are attached to arbitrary non-identifying metadata to pod spec. """ pod_spec: V1PodSpec - primary_container_name: str = _PRIMARY_CONTAINER_NAME_FIELD + primary_container_name: str = PRIMARY_CONTAINER_DEFAULT_NAME labels: Optional[Dict[str, str]] = None annotations: Optional[Dict[str, str]] = None - def __post_init_(self): + def __post_init__(self): if not self.pod_spec: raise _user_exceptions.FlyteValidationException("A pod spec cannot be undefined") if not self.primary_container_name: diff --git a/plugins/flytekit-k8s-pod/tests/test_pod.py b/plugins/flytekit-k8s-pod/tests/test_pod.py index 716190b4df..0d6788ac92 100644 --- a/plugins/flytekit-k8s-pod/tests/test_pod.py +++ b/plugins/flytekit-k8s-pod/tests/test_pod.py @@ -12,6 +12,7 @@ from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings from flytekit.core import context_manager from flytekit.core.type_engine import TypeEngine +from flytekit.exceptions import user from flytekit.extend import ExecutionState from flytekit.tools.translator import get_serializable @@ -473,3 +474,32 @@ def dynamic_task_with_pod_subtask(dummy_input: str) -> str: assert dynamic_job_spec.tasks[0].k8s_pod.pod_spec["containers"][0]["resources"]["requests"]["gpu"] == "1" assert context_manager.FlyteContextManager.size() == 1 + + +def test_pod_config(): + with pytest.raises(user.FlyteValidationException): + Pod(pod_spec=None) + + with pytest.raises(user.FlyteValidationException): + Pod(pod_spec=V1PodSpec(containers=[]), primary_container_name=None) + + selector = {"node_group": "memory"} + + @task( + task_config=Pod( + pod_spec=V1PodSpec( + containers=[], + node_selector=selector, + ), + ), + requests=Resources( + mem="1G", + ), + ) + def my_pod_task(): + print("hello world") + time.sleep(30000) + + assert my_pod_task.task_config + assert isinstance(my_pod_task.task_config, Pod) + assert my_pod_task.task_config.pod_spec.node_selector == selector