Enable ending the task directly from the triggerer without going into the worker. #40084

Merged: 6 commits, Jul 25, 2024
19 changes: 16 additions & 3 deletions airflow/callbacks/callback_requests.py
@@ -19,6 +19,8 @@
import json
from typing import TYPE_CHECKING

from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from airflow.models.taskinstance import SimpleTaskInstance

@@ -68,22 +70,33 @@ class TaskCallbackRequest(CallbackRequest):

:param full_filepath: File Path to use to run the callback
:param simple_task_instance: Simplified Task Instance representation
:param is_failure_callback: Flag to determine whether it is a Failure Callback or Success Callback
:param msg: Additional Message that can be used for logging to determine failure/zombie
:param processor_subdir: Directory used by the DAG processor when the DAG was parsed.
:param task_callback_type: e.g. whether on success, on failure, on retry.
"""

def __init__(
self,
full_filepath: str,
simple_task_instance: SimpleTaskInstance,
is_failure_callback: bool | None = True,
processor_subdir: str | None = None,
msg: str | None = None,
task_callback_type: TaskInstanceState | None = None,
):
super().__init__(full_filepath=full_filepath, processor_subdir=processor_subdir, msg=msg)
self.simple_task_instance = simple_task_instance
self.is_failure_callback = is_failure_callback
self.task_callback_type = task_callback_type

@property
def is_failure_callback(self) -> bool:
"""Returns True if the callback is a failure callback."""
if self.task_callback_type is None:
return True
return self.task_callback_type in {
TaskInstanceState.FAILED,
TaskInstanceState.UP_FOR_RETRY,
TaskInstanceState.UPSTREAM_FAILED,
}

def to_json(self) -> str:
from airflow.serialization.serialized_objects import BaseSerialization
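A minimal sketch of how the derived property behaves; TaskCallbackRequest and TaskInstanceState are from this diff, while the file path and simple_ti are illustrative placeholders:

from airflow.callbacks.callback_requests import TaskCallbackRequest
from airflow.utils.state import TaskInstanceState

# simple_ti is a placeholder; in practice it comes from SimpleTaskInstance.from_ti(ti).
request = TaskCallbackRequest(
    full_filepath="/dags/example.py",
    simple_task_instance=simple_ti,
    task_callback_type=TaskInstanceState.SUCCESS,
)
assert request.is_failure_callback is False  # SUCCESS is not a failure state

legacy = TaskCallbackRequest(
    full_filepath="/dags/example.py",
    simple_task_instance=simple_ti,
)
assert legacy.is_failure_callback is True  # no callback type given, defaults to failure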
37 changes: 32 additions & 5 deletions airflow/dag_processing/processor.py
@@ -47,7 +47,7 @@
from airflow.models.dagwarning import DagWarning, DagWarningType
from airflow.models.errors import ParseImportError
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstance import TaskInstance, TaskInstance as TI
from airflow.models.taskinstance import TaskInstance, TaskInstance as TI, _run_finished_callback
from airflow.stats import Stats
from airflow.utils import timezone
from airflow.utils.email import get_email_address_list, send_email
@@ -808,8 +808,26 @@ def _execute_dag_callbacks(cls, dagbag: DagBag, request: DagCallbackRequest, ses
@provide_session
def _execute_task_callbacks(
cls, dagbag: DagBag | None, request: TaskCallbackRequest, unit_test_mode: bool, session: Session
):
if not request.is_failure_callback:
) -> None:
"""
Execute the task callbacks.

:param dagbag: the DagBag to use to get the task instance
:param request: the task callback request
:param unit_test_mode: whether to run the callback in unit-test mode
:param session: the session to use
"""
try:
callback_type = TaskInstanceState(request.task_callback_type)
except ValueError:
callback_type = None
is_remote = callback_type in (TaskInstanceState.SUCCESS, TaskInstanceState.FAILED)

# Previously, we ignored any request that was not a failure callback. Now, if a callback
# type is given directly, we respect it and execute it. Additionally, because in this
# scenario the callback is submitted remotely, we assume there is no need to modify
# state; we simply run the callback.

if not is_remote and not request.is_failure_callback:
return

simple_ti = request.simple_task_instance
@@ -820,6 +838,7 @@ def _execute_task_callbacks(
map_index=simple_ti.map_index,
session=session,
)

if not ti:
return

@@ -841,8 +860,16 @@ def _execute_task_callbacks(
if task:
ti.refresh_from_task(task)

ti.handle_failure(error=request.msg, test_mode=unit_test_mode, session=session)
cls.logger().info("Executed failure callback for %s in state %s", ti, ti.state)
if callback_type is TaskInstanceState.SUCCESS:
context = ti.get_template_context(session=session)
if TYPE_CHECKING:
assert ti.task
callbacks = ti.task.on_success_callback
_run_finished_callback(callbacks=callbacks, context=context)
cls.logger().info("Executed callback for %s in state %s", ti, ti.state)
elif not is_remote or callback_type is TaskInstanceState.FAILED:
ti.handle_failure(error=request.msg, test_mode=unit_test_mode, session=session)
cls.logger().info("Executed callback for %s in state %s", ti, ti.state)
session.flush()

@classmethod
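For context on the try/except above: a request submitted from the triggerer carries an explicit state, while a legacy request has task_callback_type=None, which the Enum constructor rejects, so callback_type falls back to None and the pre-existing is_failure_callback path applies. A quick illustration:

from airflow.utils.state import TaskInstanceState

# An explicit callback type round-trips through the Enum constructor...
assert TaskInstanceState("success") is TaskInstanceState.SUCCESS

# ...while None is not a valid member, which triggers the fallback.
try:
    callback_type = TaskInstanceState(None)
except ValueError:
    callback_type = None  # legacy request: use is_failure_callback instead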
5 changes: 4 additions & 1 deletion airflow/exceptions.py
@@ -372,7 +372,10 @@ class TaskDeferred(BaseException):
Signal an operator moving to deferred state.

Special exception raised to signal that the operator it was raised from
wishes to defer until a trigger fires.
wishes to defer until a trigger fires. Triggers can send execution back to the task or end the task
instance directly. If the trigger should end the task instance itself, ``method_name`` does not
matter and can be None; otherwise, provide the name of the method that should be used when
resuming execution in the task.
"""

def __init__(
5 changes: 4 additions & 1 deletion airflow/models/baseoperator.py
@@ -1764,7 +1764,10 @@ def defer(
Mark this Operator "deferred", suspending its execution until the provided trigger fires an event.

This is achieved by raising a special exception (TaskDeferred)
which is caught in the main _execute_task wrapper.
which is caught in the main _execute_task wrapper. Triggers can send execution back to the task or
end the task instance directly. If the trigger will end the task instance itself, ``method_name``
should be None; otherwise, provide the name of the method that should be used when resuming
execution in the task.
"""
raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout)

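A sketch of the two defer styles this docstring describes; the operator below is hypothetical, while DateTimeTrigger and its end_from_trigger flag come from this PR:

from __future__ import annotations

import datetime

from airflow.models.baseoperator import BaseOperator
from airflow.triggers.temporal import DateTimeTrigger


class WaitUntilOperator(BaseOperator):
    """Hypothetical example operator, not part of this PR."""

    def __init__(self, *, moment: datetime.datetime, end_from_trigger: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.moment = moment
        self.end_from_trigger = end_from_trigger

    def execute(self, context):
        if self.end_from_trigger:
            # The triggerer ends the task instance itself; nothing resumes in a
            # worker, so method_name can be None per the docstring above.
            self.defer(
                trigger=DateTimeTrigger(moment=self.moment, end_from_trigger=True),
                method_name=None,
            )
        else:
            # Resume in the worker via execute_complete when the trigger fires.
            self.defer(
                trigger=DateTimeTrigger(moment=self.moment),
                method_name="execute_complete",
            )

    def execute_complete(self, context, event=None):
        return None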
12 changes: 12 additions & 0 deletions airflow/models/dagbag.py
@@ -46,8 +46,10 @@
AirflowClusterPolicyViolation,
AirflowDagCycleException,
AirflowDagDuplicatedIdException,
AirflowException,
RemovedInAirflow3Warning,
)
from airflow.listeners.listener import get_listener_manager
from airflow.models.base import Base
from airflow.stats import Stats
from airflow.utils import timezone
@@ -512,6 +514,16 @@ def _bag_dag(self, *, dag, root_dag, recursive):
settings.dag_policy(dag)

for task in dag.tasks:
# The listeners are not supported when ending a task via a trigger on asynchronous operators.
if getattr(task, "end_from_trigger", False) and get_listener_manager().has_listeners:
raise AirflowException(
    "Listeners are not supported with end_from_trigger=True for deferrable operators. "
    f"Task {task.task_id} in DAG {dag.dag_id} has end_from_trigger=True with listeners from plugins. "
    "Set end_from_trigger=False to use listeners."
)

settings.task_policy(task)
except (AirflowClusterPolicyViolation, AirflowClusterPolicySkipDag):
raise
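To illustrate when the new check fires: parsing a DAG like the sketch below raises the exception above as soon as any listener plugin is registered (the DAG and task ids are illustrative; TimeSensorAsync gains end_from_trigger in this PR):

from datetime import time

from airflow import DAG
from airflow.sensors.time_sensor import TimeSensorAsync
from airflow.utils import timezone

with DAG(dag_id="end_from_trigger_example", start_date=timezone.datetime(2024, 1, 1)):
    TimeSensorAsync(
        task_id="wait_until_noon",
        target_time=time(12, 0),
        # Fails DAG parsing when listener plugins are registered; keep the
        # default False to continue using listeners.
        end_from_trigger=True,
    )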
9 changes: 8 additions & 1 deletion airflow/models/taskinstance.py
@@ -94,7 +94,6 @@
from airflow.models.dagbag import DagBag
from airflow.models.dataset import DatasetAliasModel, DatasetModel
from airflow.models.log import Log
from airflow.models.mappedoperator import MappedOperator
from airflow.models.param import process_params
from airflow.models.renderedtifields import get_serialized_template_fields
from airflow.models.taskfail import TaskFail
@@ -699,6 +698,8 @@ def _execute_task(task_instance: TaskInstance | TaskInstancePydantic, context: C

:meta private:
"""
from airflow.models.mappedoperator import MappedOperator

task_to_execute = task_instance.task

if TYPE_CHECKING:
@@ -1288,6 +1289,8 @@ def _record_task_map_for_downstreams(

:meta private:
"""
from airflow.models.mappedoperator import MappedOperator

if task.dag.__class__ is AttributeRemoved:
task.dag = dag # required after deserialization

@@ -3454,6 +3457,8 @@ def render_templates(
the unmapped, fully rendered BaseOperator. The original ``self.task``
before replacement is returned.
"""
from airflow.models.mappedoperator import MappedOperator

if not context:
context = self.get_template_context()
original_task = self.task
@@ -3989,6 +3994,8 @@ def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> Mapp

def _is_further_mapped_inside(operator: Operator, container: TaskGroup) -> bool:
"""Whether given operator is *further* mapped inside a task group."""
from airflow.models.mappedoperator import MappedOperator

if isinstance(operator, MappedOperator):
return True
task_group = operator.task_group
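All of the taskinstance.py hunks apply the same mechanical change: the module-level MappedOperator import moves inside each function that needs it, a standard way to defer an import and break an import cycle. A generic illustration of the pattern (module names are placeholders):

# module_a.py
def uses_b():
    # Imported at call time rather than import time, so module_a can be
    # imported even while module_b (which imports module_a) is initializing.
    from module_b import helper
    return helper()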
9 changes: 1 addition & 8 deletions airflow/models/trigger.py
@@ -203,14 +203,7 @@ def submit_event(cls, trigger_id, event, session: Session = NEW_SESSION) -> None
TaskInstance.trigger_id == trigger_id, TaskInstance.state == TaskInstanceState.DEFERRED
)
):
# Add the event's payload into the kwargs for the task
next_kwargs = task_instance.next_kwargs or {}
next_kwargs["event"] = event.payload
task_instance.next_kwargs = next_kwargs
# Remove ourselves as its trigger
task_instance.trigger_id = None
# Finally, mark it as scheduled so it gets re-queued
task_instance.state = TaskInstanceState.SCHEDULED
event.handle_submit(task_instance=task_instance)

@classmethod
@internal_api_call
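The inline logic removed above moves behind the event's handle_submit hook, so task-end events fired with end_from_trigger can finalize the task instance without rescheduling it. A sketch of what the base TriggerEvent implementation presumably preserves, reconstructed from the deleted lines:

from airflow.utils.state import TaskInstanceState


def handle_submit(self, *, task_instance) -> None:
    # Add the event's payload into the kwargs for the task.
    next_kwargs = task_instance.next_kwargs or {}
    next_kwargs["event"] = self.payload
    task_instance.next_kwargs = next_kwargs
    # Remove ourselves as its trigger.
    task_instance.trigger_id = None
    # Finally, mark it as scheduled so it gets re-queued.
    task_instance.state = TaskInstanceState.SCHEDULED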
15 changes: 9 additions & 6 deletions airflow/sensors/date_time.py
@@ -18,7 +18,7 @@
from __future__ import annotations

import datetime
from typing import TYPE_CHECKING, NoReturn, Sequence
from typing import TYPE_CHECKING, Any, NoReturn, Sequence

from airflow.sensors.base import BaseSensorOperator
from airflow.triggers.temporal import DateTimeTrigger
@@ -85,18 +85,21 @@ class DateTimeSensorAsync(DateTimeSensor):
It is a drop-in replacement for DateTimeSensor.

:param target_time: datetime after which the job succeeds. (templated)
:param end_from_trigger: End the task directly from the triggerer without going into the worker.
"""

def __init__(self, **kwargs) -> None:
def __init__(self, *, end_from_trigger: bool = False, **kwargs) -> None:
super().__init__(**kwargs)
self.end_from_trigger = end_from_trigger

def execute(self, context: Context) -> NoReturn:
trigger = DateTimeTrigger(moment=timezone.parse(self.target_time))
self.defer(
trigger=trigger,
method_name="execute_complete",
trigger=DateTimeTrigger(
moment=timezone.parse(self.target_time), end_from_trigger=self.end_from_trigger
),
)

def execute_complete(self, context, event=None) -> None:
"""Execute when the trigger fires - returns immediately."""
def execute_complete(self, context: Context, event: Any = None) -> None:
"""Handle the event when the trigger fires and return immediately."""
return None
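A usage sketch of the new flag (DAG id, date, and target time are illustrative):

from airflow import DAG
from airflow.sensors.date_time import DateTimeSensorAsync
from airflow.utils import timezone

with DAG(dag_id="date_time_example", start_date=timezone.datetime(2024, 1, 1)):
    # With end_from_trigger=True the triggerer marks the task successful itself,
    # so no worker slot is spent running execute_complete.
    DateTimeSensorAsync(
        task_id="wait_for_moment",
        target_time="2024-07-25T00:00:00+00:00",
        end_from_trigger=True,
    )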
13 changes: 9 additions & 4 deletions airflow/sensors/time_delta.py
@@ -17,7 +17,7 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, NoReturn
from typing import TYPE_CHECKING, Any, NoReturn

from airflow.exceptions import AirflowSkipException
from airflow.sensors.base import BaseSensorOperator
@@ -59,28 +59,33 @@ class TimeDeltaSensorAsync(TimeDeltaSensor):
Will defer itself to avoid taking up a worker slot while it is waiting.

:param delta: time length to wait after the data interval before succeeding.
:param end_from_trigger: End the task directly from the triggerer without going into the worker.

.. seealso::
For more information on how to use this sensor, take a look at the guide:
:ref:`howto/operator:TimeDeltaSensorAsync`

"""

def __init__(self, *, end_from_trigger: bool = False, delta, **kwargs) -> None:
super().__init__(delta=delta, **kwargs)
self.end_from_trigger = end_from_trigger

def execute(self, context: Context) -> bool | NoReturn:
target_dttm = context["data_interval_end"]
target_dttm += self.delta
if timezone.utcnow() > target_dttm:
# If the target datetime is in the past, return immediately
return True
try:
trigger = DateTimeTrigger(moment=target_dttm)
trigger = DateTimeTrigger(moment=target_dttm, end_from_trigger=self.end_from_trigger)
except (TypeError, ValueError) as e:
if self.soft_fail:
raise AirflowSkipException("Skipping due to soft_fail is set to True.") from e
raise

self.defer(trigger=trigger, method_name="execute_complete")

def execute_complete(self, context, event=None) -> None:
"""Execute for when the trigger fires - return immediately."""
def execute_complete(self, context: Context, event: Any = None) -> None:
"""Handle the event when the trigger fires and return immediately."""
return None
13 changes: 7 additions & 6 deletions airflow/sensors/time_sensor.py
@@ -18,7 +18,7 @@
from __future__ import annotations

import datetime
from typing import TYPE_CHECKING, NoReturn
from typing import TYPE_CHECKING, Any, NoReturn

from airflow.sensors.base import BaseSensorOperator
from airflow.triggers.temporal import DateTimeTrigger
@@ -56,14 +56,16 @@ class TimeSensorAsync(BaseSensorOperator):
This frees up a worker slot while it is waiting.

:param target_time: time after which the job succeeds
:param end_from_trigger: End the task directly from the triggerer without going into the worker.

.. seealso::
For more information on how to use this sensor, take a look at the guide:
:ref:`howto/operator:TimeSensorAsync`
"""

def __init__(self, *, target_time: datetime.time, **kwargs) -> None:
def __init__(self, *, end_from_trigger: bool = False, target_time: datetime.time, **kwargs) -> None:
super().__init__(**kwargs)
self.end_from_trigger = end_from_trigger
self.target_time = target_time

aware_time = timezone.coerce_datetime(
@@ -73,12 +75,11 @@ def __init__(self, *, target_time: datetime.time, **kwargs) -> None:
self.target_datetime = timezone.convert_to_utc(aware_time)

def execute(self, context: Context) -> NoReturn:
trigger = DateTimeTrigger(moment=self.target_datetime)
self.defer(
trigger=trigger,
trigger=DateTimeTrigger(moment=self.target_datetime, end_from_trigger=self.end_from_trigger),
method_name="execute_complete",
)

def execute_complete(self, context, event=None) -> None:
"""Execute when the trigger fires - returns immediately."""
def execute_complete(self, context: Context, event: Any = None) -> None:
"""Handle the event when the trigger fires and return immediately."""
return None