Skip to content

Commit

Permalink
AIP-72: Pass context keys from API Server to Workers (#44899)
Browse files Browse the repository at this point in the history
Part of #44481

This commit augments the TI context available in the Task Execution Interface with the run context (DAG run details, variables, and connections) returned by the Execution API Server.

In future PRs the following will be added:

- More methods on TI, such as ti.xcom_pull and ti.xcom_push
- Lazy fetching of connections and variables
- Verifying that "get_current_context" works
  • Loading branch information
kaxil authored Dec 16, 2024
1 parent 4b38bed commit dbff6e3
Show file tree
Hide file tree
Showing 12 changed files with 506 additions and 91 deletions.
39 changes: 35 additions & 4 deletions airflow/api_fastapi/execution_api/datamodels/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@

from airflow.api_fastapi.common.types import UtcDateTime
from airflow.api_fastapi.core_api.base import BaseModel
from airflow.api_fastapi.execution_api.datamodels.connection import ConnectionResponse
from airflow.api_fastapi.execution_api.datamodels.variable import VariableResponse
from airflow.utils.state import IntermediateTIState, TaskInstanceState as TIState, TerminalTIState
from airflow.utils.types import DagRunType


class TIEnterRunningPayload(BaseModel):
Expand Down Expand Up @@ -94,9 +97,7 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str:
state = v.get("state")
else:
state = getattr(v, "state", None)
if state == TIState.RUNNING:
return str(state)
elif state in set(TerminalTIState):
if state in set(TerminalTIState):
return "_terminal_"
elif state == TIState.DEFERRED:
return "deferred"
Expand All @@ -107,7 +108,6 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str:
# and "_other_" is a catch-all for all other states that are not covered by the other schemas.
TIStateUpdate = Annotated[
Union[
Annotated[TIEnterRunningPayload, Tag("running")],
Annotated[TITerminalStatePayload, Tag("_terminal_")],
Annotated[TITargetStatePayload, Tag("_other_")],
Annotated[TIDeferredStatePayload, Tag("deferred")],
Expand Down Expand Up @@ -135,3 +135,34 @@ class TaskInstance(BaseModel):
run_id: str
try_number: int
map_index: int | None = None


class DagRun(BaseModel):
    """Schema for DagRun model with minimal required fields needed for Runtime."""

    # Identity of the run.
    # TODO: `dag_id` and `run_id` are duplicated from TaskInstance.
    #   See if we can avoid sending these fields from the API server and instead
    #   use the TaskInstance data to get the DAG run information in the client
    #   (Task Execution Interface).
    dag_id: str
    run_id: str

    # Timing information. The `| None` fields declare no default, so they accept
    # None as a value but are not given one implicitly.
    logical_date: UtcDateTime
    data_interval_start: UtcDateTime | None
    data_interval_end: UtcDateTime | None
    start_date: UtcDateTime
    end_date: UtcDateTime | None

    run_type: DagRunType

    # Falls back to a fresh empty dict when no conf is supplied.
    conf: dict[str, Any] = Field(default_factory=dict)


class TIRunContext(BaseModel):
    """Response schema for TaskInstance run context."""

    # Only `dag_run` is required; the two collections below fall back to a
    # fresh empty list when they are not supplied.
    dag_run: DagRun
    """DAG run information for the task instance."""

    variables: list[VariableResponse] = Field(default_factory=list)
    """Variables that can be accessed by the task instance."""

    connections: list[ConnectionResponse] = Field(default_factory=list)
    """Connections that can be accessed by the task instance."""
139 changes: 108 additions & 31 deletions airflow/api_fastapi/execution_api/routes/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,15 @@
from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
DagRun,
TIDeferredStatePayload,
TIEnterRunningPayload,
TIHeartbeatInfo,
TIRunContext,
TIStateUpdate,
TITerminalStatePayload,
)
from airflow.models.dagrun import DagRun as DR
from airflow.models.taskinstance import TaskInstance as TI, _update_rtif
from airflow.models.trigger import Trigger
from airflow.utils import timezone
Expand All @@ -48,6 +51,110 @@
log = logging.getLogger(__name__)


@router.patch(
    "/{task_instance_id}/run",
    status_code=status.HTTP_200_OK,
    responses={
        status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"},
        status.HTTP_409_CONFLICT: {"description": "The TI is already in the requested state"},
        status.HTTP_422_UNPROCESSABLE_ENTITY: {"description": "Invalid payload for the state transition"},
    },
)
def ti_run(
    task_instance_id: UUID, ti_run_payload: Annotated[TIEnterRunningPayload, Body()], session: SessionDep
) -> TIRunContext:
    """
    Run a TaskInstance.
    This endpoint is used to start a TaskInstance that is in the QUEUED state.
    """
    # The UUID type on the path parameter is only there for validation; the
    # queries below compare against its string form.
    ti_key = str(task_instance_id)

    # Read the current state (plus the DagRun key) with SELECT ... FOR UPDATE.
    lookup = select(TI.state, TI.dag_id, TI.run_id).where(TI.id == ti_key).with_for_update()
    try:
        previous_state, dag_id, run_id = session.execute(lookup).one()
    except NoResultFound:
        log.error("Task Instance %s not found", ti_key)
        not_found_detail = {
            "reason": "not_found",
            "message": "Task Instance not found",
        }
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=not_found_detail)

    # Guard: only a QUEUED task instance may be moved to RUNNING here.
    # TODO: We will need to change this for other states like:
    #   reschedule, retry, defer etc.
    if previous_state != State.QUEUED:
        log.warning(
            "Can not start Task Instance ('%s') in invalid state: %s",
            ti_key,
            previous_state,
        )

        # TODO: Pass a RFC 9457 compliant error message in "detail" field
        #   https://datatracker.ietf.org/doc/html/rfc9457
        #   to provide more information about the error
        #   FastAPI will automatically convert this to a JSON response
        #   This might be added in FastAPI in https://github.com/fastapi/fastapi/issues/10370
        conflict_detail = {
            "reason": "invalid_state",
            "message": "TI was not in a state where it could be marked as running",
            "previous_state": previous_state,
        }
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=conflict_detail)

    log.info("Task with %s state started on %s ", previous_state, ti_run_payload.hostname)

    # exclude_unset keeps fields the caller never set out of the UPDATE.
    payload_values = ti_run_payload.model_dump(exclude_unset=True)
    start_query = (
        update(TI)
        .where(TI.id == ti_key)
        .values(payload_values)
        # Ensure there is no end date set, and mark the TI as running.
        .values(
            end_date=None,
            hostname=ti_run_payload.hostname,
            unixname=ti_run_payload.unixname,
            pid=ti_run_payload.pid,
            state=State.RUNNING,
        )
    )

    try:
        outcome = session.execute(start_query)
        log.info("TI %s state updated: %s row(s) affected", ti_key, outcome.rowcount)

        # Fetch only the DagRun columns that the response schema exposes.
        dag_run_columns = select(
            DR.run_id,
            DR.dag_id,
            DR.data_interval_start,
            DR.data_interval_end,
            DR.start_date,
            DR.end_date,
            DR.run_type,
            DR.conf,
            DR.logical_date,
        ).filter_by(dag_id=dag_id, run_id=run_id)
        dag_run_row = session.execute(dag_run_columns).one_or_none()

        if not dag_run_row:
            # Not a SQLAlchemyError, so the handler below does not catch it.
            raise ValueError(f"DagRun with dag_id={dag_id} and run_id={run_id} not found.")

        return TIRunContext(
            dag_run=DagRun.model_validate(dag_run_row, from_attributes=True),
            # TODO: Add the variables and connections that the task needs (and has permission for)
            variables=[],
            connections=[],
        )
    except SQLAlchemyError as e:
        log.error("Error marking Task Instance state as running: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Database error occurred"
        )


@router.patch(
"/{task_instance_id}/state",
status_code=status.HTTP_204_NO_CONTENT,
Expand Down Expand Up @@ -92,37 +199,7 @@ def ti_update_state(

query = update(TI).where(TI.id == ti_id_str).values(data)

if isinstance(ti_patch_payload, TIEnterRunningPayload):
if previous_state != State.QUEUED:
log.warning(
"Can not start Task Instance ('%s') in invalid state: %s",
ti_id_str,
previous_state,
)

# TODO: Pass a RFC 9457 compliant error message in "detail" field
# https://datatracker.ietf.org/doc/html/rfc9457
# to provide more information about the error
# FastAPI will automatically convert this to a JSON response
# This might be added in FastAPI in https://github.com/fastapi/fastapi/issues/10370
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail={
"reason": "invalid_state",
"message": "TI was not in a state where it could be marked as running",
"previous_state": previous_state,
},
)
log.info("Task with %s state started on %s ", previous_state, ti_patch_payload.hostname)
# Ensure there is no end date set.
query = query.values(
end_date=None,
hostname=ti_patch_payload.hostname,
unixname=ti_patch_payload.unixname,
pid=ti_patch_payload.pid,
state=State.RUNNING,
)
elif isinstance(ti_patch_payload, TITerminalStatePayload):
if isinstance(ti_patch_payload, TITerminalStatePayload):
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
elif isinstance(ti_patch_payload, TIDeferredStatePayload):
# Calculate timeout if it was passed
Expand Down
25 changes: 22 additions & 3 deletions task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,12 @@
from airflow.sdk import __version__
from airflow.sdk.api.datamodels._generated import (
ConnectionResponse,
DagRunType,
TerminalTIState,
TIDeferredStatePayload,
TIEnterRunningPayload,
TIHeartbeatInfo,
TIRunContext,
TITerminalStatePayload,
ValidationError as RemoteValidationError,
VariablePostBody,
Expand Down Expand Up @@ -110,11 +112,12 @@ class TaskInstanceOperations:
def __init__(self, client: Client):
self.client = client

def start(self, id: uuid.UUID, pid: int, when: datetime):
def start(self, id: uuid.UUID, pid: int, when: datetime) -> TIRunContext:
"""Tell the API server that this TI has started running."""
body = TIEnterRunningPayload(pid=pid, hostname=get_hostname(), unixname=getuser(), start_date=when)

self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())
resp = self.client.patch(f"task-instances/{id}/run", content=body.model_dump_json())
return TIRunContext.model_validate_json(resp.read())

def finish(self, id: uuid.UUID, state: TerminalTIState, when: datetime):
"""Tell the API server that this TI has reached a terminal state."""
Expand Down Expand Up @@ -218,7 +221,23 @@ def auth_flow(self, request: httpx.Request):
# This exists as an aid for debugging or local running via the `dry_run` argument to Client. It doesn't make
# sense for requests that are expected to return real data, such as connections.
def noop_handler(request: httpx.Request) -> httpx.Response:
log.debug("Dry-run request", method=request.method, path=request.url.path)
path = request.url.path
log.debug("Dry-run request", method=request.method, path=path)

if path.startswith("/task-instances/") and path.endswith("/run"):
# Return a fake context
return httpx.Response(
200,
json={
"dag_run": {
"dag_id": "test_dag",
"run_id": "test_run",
"logical_date": "2021-01-01T00:00:00Z",
"start_date": "2021-01-01T00:00:00Z",
"run_type": DagRunType.MANUAL,
},
},
)
return httpx.Response(200, json={"text": "Hello, world!"})


Expand Down
37 changes: 37 additions & 0 deletions task_sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,17 @@ class ConnectionResponse(BaseModel):
extra: Annotated[str | None, Field(title="Extra")] = None


class DagRunType(str, Enum):
    """Class with DagRun types."""

    # The `str` mixin makes every member compare equal to its plain string value.
    BACKFILL = "backfill"
    SCHEDULED = "scheduled"
    MANUAL = "manual"
    ASSET_TRIGGERED = "asset_triggered"


class IntermediateTIState(str, Enum):
"""
States that a Task Instance can be in that indicate it is not yet in a terminal or running state.
Expand Down Expand Up @@ -159,10 +170,36 @@ class TaskInstance(BaseModel):
map_index: Annotated[int | None, Field(title="Map Index")] = None


class DagRun(BaseModel):
    """
    Schema for DagRun model with minimal required fields needed for Runtime.
    """

    # Required fields: no default is declared.
    dag_id: str = Field(title="Dag Id")
    run_id: str = Field(title="Run Id")
    logical_date: datetime = Field(title="Logical Date")
    start_date: datetime = Field(title="Start Date")
    run_type: DagRunType

    # Optional fields: default to None when absent.
    data_interval_start: datetime | None = Field(default=None, title="Data Interval Start")
    data_interval_end: datetime | None = Field(default=None, title="Data Interval End")
    end_date: datetime | None = Field(default=None, title="End Date")
    conf: dict[str, Any] | None = Field(default=None, title="Conf")


class HTTPValidationError(BaseModel):
    # Optional list of individual validation errors; None when not provided.
    detail: list[ValidationError] | None = Field(default=None, title="Detail")


class TIRunContext(BaseModel):
    """
    Response schema for TaskInstance run context.
    """

    # The DAG run is always required; the two collections default to None.
    dag_run: DagRun
    variables: list[VariableResponse] | None = Field(default=None, title="Variables")
    connections: list[ConnectionResponse] | None = Field(default=None, title="Connections")


class TITerminalStatePayload(BaseModel):
"""
Schema for updating TaskInstance to a terminal state (e.g., SUCCESS or FAILED).
Expand Down
2 changes: 2 additions & 0 deletions task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
TaskInstance,
TerminalTIState,
TIDeferredStatePayload,
TIRunContext,
VariableResponse,
XComResponse,
)
Expand All @@ -70,6 +71,7 @@ class StartupDetails(BaseModel):
Responses will come back on stdin
"""
ti_context: TIRunContext
type: Literal["StartupDetails"] = "StartupDetails"


Expand Down
3 changes: 2 additions & 1 deletion task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ def _on_child_started(self, ti: TaskInstance, path: str | os.PathLike[str], requ
# We've forked, but the task won't start doing anything until we send it the StartupDetails
# message. But before we do that, we need to tell the server it's started (so it has the chance to
# tell us "no, stop!" for any reason)
self.client.task_instances.start(ti.id, self.pid, datetime.now(tz=timezone.utc))
ti_context = self.client.task_instances.start(ti.id, self.pid, datetime.now(tz=timezone.utc))
self._last_successful_heartbeat = time.monotonic()
except Exception:
# On any error kill that subprocess!
Expand All @@ -408,6 +408,7 @@ def _on_child_started(self, ti: TaskInstance, path: str | os.PathLike[str], requ
ti=ti,
file=os.fspath(path),
requests_fd=requests_fd,
ti_context=ti_context,
)

# Send the message to tell the process what it needs to execute
Expand Down
Loading

0 comments on commit dbff6e3

Please sign in to comment.