Skip to content

Commit

Permalink
Replace State by TaskInstanceState in Airflow executors (#32627)
Browse files Browse the repository at this point in the history
* Replace State by TaskInstanceState in Airflow executors

* chaneg state type in change_state method, KubernetesResultsType and KubernetesWatchType to TaskInstanceState

* Fix change_state annotation in CeleryExecutor

---------

Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>
(cherry picked from commit 9556d6d)
  • Loading branch information
hussein-awala authored and ephraimbuddy committed Aug 8, 2023
1 parent 0655d88 commit 097d2be
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 35 deletions.
8 changes: 4 additions & 4 deletions airflow/executors/base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from airflow.exceptions import RemovedInAirflow3Warning
from airflow.stats import Stats
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import State
from airflow.utils.state import TaskInstanceState

PARALLELISM: int = conf.getint("core", "PARALLELISM")

Expand Down Expand Up @@ -295,7 +295,7 @@ def _process_tasks(self, task_tuples: list[TaskTuple]) -> None:
self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config)
self.running.add(key)

def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None:
def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, info=None) -> None:
"""
Changes state of the task.
Expand All @@ -317,7 +317,7 @@ def fail(self, key: TaskInstanceKey, info=None) -> None:
:param info: Executor information for the task instance
:param key: Unique key for the task instance
"""
self.change_state(key, State.FAILED, info)
self.change_state(key, TaskInstanceState.FAILED, info)

def success(self, key: TaskInstanceKey, info=None) -> None:
"""
Expand All @@ -326,7 +326,7 @@ def success(self, key: TaskInstanceKey, info=None) -> None:
:param info: Executor information for the task instance
:param key: Unique key for the task instance
"""
self.change_state(key, State.SUCCESS, info)
self.change_state(key, TaskInstanceState.SUCCESS, info)

def get_event_buffer(self, dag_ids=None) -> dict[TaskInstanceKey, EventBufferValueType]:
"""
Expand Down
28 changes: 14 additions & 14 deletions airflow/executors/debug_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from typing import TYPE_CHECKING, Any

from airflow.executors.base_executor import BaseExecutor
from airflow.utils.state import State
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstance
Expand Down Expand Up @@ -68,15 +68,15 @@ def sync(self) -> None:
while self.tasks_to_run:
ti = self.tasks_to_run.pop(0)
if self.fail_fast and not task_succeeded:
self.log.info("Setting %s to %s", ti.key, State.UPSTREAM_FAILED)
ti.set_state(State.UPSTREAM_FAILED)
self.change_state(ti.key, State.UPSTREAM_FAILED)
self.log.info("Setting %s to %s", ti.key, TaskInstanceState.UPSTREAM_FAILED)
ti.set_state(TaskInstanceState.UPSTREAM_FAILED)
self.change_state(ti.key, TaskInstanceState.UPSTREAM_FAILED)
continue

if self._terminated.is_set():
self.log.info("Executor is terminated! Stopping %s to %s", ti.key, State.FAILED)
ti.set_state(State.FAILED)
self.change_state(ti.key, State.FAILED)
self.log.info("Executor is terminated! Stopping %s to %s", ti.key, TaskInstanceState.FAILED)
ti.set_state(TaskInstanceState.FAILED)
self.change_state(ti.key, TaskInstanceState.FAILED)
continue

task_succeeded = self._run_task(ti)
Expand All @@ -87,11 +87,11 @@ def _run_task(self, ti: TaskInstance) -> bool:
try:
params = self.tasks_params.pop(ti.key, {})
ti.run(job_id=ti.job_id, **params)
self.change_state(key, State.SUCCESS)
self.change_state(key, TaskInstanceState.SUCCESS)
return True
except Exception as e:
ti.set_state(State.FAILED)
self.change_state(key, State.FAILED)
ti.set_state(TaskInstanceState.FAILED)
self.change_state(key, TaskInstanceState.FAILED)
self.log.exception("Failed to execute task: %s.", str(e))
return False

Expand Down Expand Up @@ -148,14 +148,14 @@ def trigger_tasks(self, open_slots: int) -> None:
def end(self) -> None:
"""Set states of queued tasks to UPSTREAM_FAILED marking them as not executed."""
for ti in self.tasks_to_run:
self.log.info("Setting %s to %s", ti.key, State.UPSTREAM_FAILED)
ti.set_state(State.UPSTREAM_FAILED)
self.change_state(ti.key, State.UPSTREAM_FAILED)
self.log.info("Setting %s to %s", ti.key, TaskInstanceState.UPSTREAM_FAILED)
ti.set_state(TaskInstanceState.UPSTREAM_FAILED)
self.change_state(ti.key, TaskInstanceState.UPSTREAM_FAILED)

def terminate(self) -> None:
self._terminated.set()

def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None:
def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, info=None) -> None:
self.log.debug("Popping %s from executor task queue.", key)
self.running.remove(key)
self.event_buffer[key] = state, info
6 changes: 3 additions & 3 deletions airflow/executors/sequential_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from typing import TYPE_CHECKING, Any

from airflow.executors.base_executor import BaseExecutor
from airflow.utils.state import State
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from airflow.executors.base_executor import CommandType
Expand Down Expand Up @@ -75,9 +75,9 @@ def sync(self) -> None:

try:
subprocess.check_call(command, close_fds=True)
self.change_state(key, State.SUCCESS)
self.change_state(key, TaskInstanceState.SUCCESS)
except subprocess.CalledProcessError as e:
self.change_state(key, State.FAILED)
self.change_state(key, TaskInstanceState.FAILED)
self.log.error("Failed to execute task %s.", str(e))

self.commands_to_run = []
Expand Down
8 changes: 4 additions & 4 deletions airflow/providers/celery/executors/celery_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
from airflow.exceptions import AirflowTaskTimeout
from airflow.executors.base_executor import BaseExecutor
from airflow.stats import Stats
from airflow.utils.state import State
from airflow.utils.state import TaskInstanceState

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -299,7 +299,7 @@ def _process_tasks(self, task_tuples: list[TaskTuple]) -> None:
self.task_publish_retries.pop(key, None)
if isinstance(result, ExceptionWithTraceback):
self.log.error(CELERY_SEND_ERR_MSG_HEADER + ": %s\n%s\n", result.exception, result.traceback)
self.event_buffer[key] = (State.FAILED, None)
self.event_buffer[key] = (TaskInstanceState.FAILED, None)
elif result is not None:
result.backend = cached_celery_backend
self.running.add(key)
Expand All @@ -308,7 +308,7 @@ def _process_tasks(self, task_tuples: list[TaskTuple]) -> None:
# Store the Celery task_id in the event buffer. This will get "overwritten" if the task
# has another event, but that is fine, because the only other events are success/failed at
# which point we don't need the ID anymore anyway
self.event_buffer[key] = (State.QUEUED, result.task_id)
self.event_buffer[key] = (TaskInstanceState.QUEUED, result.task_id)

# If the task runs _really quickly_ we may already have a result!
self.update_task_state(key, result.state, getattr(result, "info", None))
Expand Down Expand Up @@ -355,7 +355,7 @@ def update_all_task_states(self) -> None:
if state:
self.update_task_state(key, state, info)

def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None:
def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, info=None) -> None:
super().change_state(key, state, info)
self.tasks.pop(key, None)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
from airflow.utils.event_scheduler import EventScheduler
from airflow.utils.log.logging_mixin import remove_escape_codes
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import State, TaskInstanceState
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from kubernetes import client
Expand Down Expand Up @@ -425,20 +425,20 @@ def sync(self) -> None:
def _change_state(
self,
key: TaskInstanceKey,
state: str | None,
state: TaskInstanceState | None,
pod_name: str,
namespace: str,
session: Session = NEW_SESSION,
) -> None:
if TYPE_CHECKING:
assert self.kube_scheduler

if state == State.RUNNING:
if state == TaskInstanceState.RUNNING:
self.event_buffer[key] = state, None
return

if self.kube_config.delete_worker_pods:
if state != State.FAILED or self.kube_config.delete_worker_pods_on_failure:
if state != TaskInstanceState.FAILED or self.kube_config.delete_worker_pods_on_failure:
self.kube_scheduler.delete_pod(pod_name=pod_name, namespace=namespace)
self.log.info("Deleted pod: %s in namespace %s", str(key), str(namespace))
else:
Expand All @@ -455,6 +455,7 @@ def _change_state(
from airflow.models.taskinstance import TaskInstance

state = session.query(TaskInstance.state).filter(TaskInstance.filter_for_tis([key])).scalar()
state = TaskInstanceState(state)

self.event_buffer[key] = state, None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,16 @@
if TYPE_CHECKING:
from airflow.executors.base_executor import CommandType
from airflow.models.taskinstance import TaskInstanceKey
from airflow.utils.state import TaskInstanceState

# TaskInstance key, command, configuration, pod_template_file
KubernetesJobType = Tuple[TaskInstanceKey, CommandType, Any, Optional[str]]

# key, pod state, pod_name, namespace, resource_version
KubernetesResultsType = Tuple[TaskInstanceKey, Optional[str], str, str, str]
KubernetesResultsType = Tuple[TaskInstanceKey, Optional[TaskInstanceState], str, str, str]

# pod_name, namespace, pod state, annotations, resource_version
KubernetesWatchType = Tuple[str, str, Optional[str], Dict[str, str], str]
KubernetesWatchType = Tuple[str, str, Optional[TaskInstanceState], Dict[str, str], str]

ALL_NAMESPACES = "ALL_NAMESPACES"
POD_EXECUTOR_DONE_KEY = "airflow_executor_done"
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
)
from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import State
from airflow.utils.state import TaskInstanceState

try:
from airflow.providers.cncf.kubernetes.executors.kubernetes_executor_types import (
Expand Down Expand Up @@ -223,12 +223,16 @@ def process_status(
# since kube server have received request to delete pod set TI state failed
if event["type"] == "DELETED" and pod.metadata.deletion_timestamp:
self.log.info("Event: Failed to start pod %s, annotations: %s", pod_name, annotations_string)
self.watcher_queue.put((pod_name, namespace, State.FAILED, annotations, resource_version))
self.watcher_queue.put(
(pod_name, namespace, TaskInstanceState.FAILED, annotations, resource_version)
)
else:
self.log.debug("Event: %s Pending, annotations: %s", pod_name, annotations_string)
elif status == "Failed":
self.log.error("Event: %s Failed, annotations: %s", pod_name, annotations_string)
self.watcher_queue.put((pod_name, namespace, State.FAILED, annotations, resource_version))
self.watcher_queue.put(
(pod_name, namespace, TaskInstanceState.FAILED, annotations, resource_version)
)
elif status == "Succeeded":
# We get multiple events once the pod hits a terminal state, and we only want to
# send it along to the scheduler once.
Expand Down Expand Up @@ -256,7 +260,9 @@ def process_status(
pod_name,
annotations_string,
)
self.watcher_queue.put((pod_name, namespace, State.FAILED, annotations, resource_version))
self.watcher_queue.put(
(pod_name, namespace, TaskInstanceState.FAILED, annotations, resource_version)
)
else:
self.log.info("Event: %s is Running, annotations: %s", pod_name, annotations_string)
else:
Expand Down

0 comments on commit 097d2be

Please sign in to comment.