Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restructure WatchedSubprocess and CommsDecoder for reuse in DagParsing #44874

Merged
merged 1 commit into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 69 additions & 65 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import signal
import sys
import time
import weakref
from collections.abc import Generator
from contextlib import suppress
from datetime import datetime, timezone
Expand Down Expand Up @@ -64,6 +63,8 @@
if TYPE_CHECKING:
from structlog.typing import FilteringBoundLogger, WrappedLogger

from airflow.typing_compat import Self


__all__ = ["WatchedSubprocess", "supervise"]

Expand Down Expand Up @@ -263,7 +264,7 @@ def exit(n: int) -> NoReturn:

@attrs.define()
class WatchedSubprocess:
ti_id: UUID
id: UUID
pid: int

stdin: BinaryIO
Expand Down Expand Up @@ -292,20 +293,16 @@ class WatchedSubprocess:

selector: selectors.BaseSelector = attrs.field(factory=selectors.DefaultSelector)

procs: ClassVar[weakref.WeakValueDictionary[int, WatchedSubprocess]] = weakref.WeakValueDictionary()

def __attrs_post_init__(self):
self.procs[self.pid] = self

@classmethod
def start(
cls,
path: str | os.PathLike[str],
ti: TaskInstance,
what: TaskInstance,
client: Client,
target: Callable[[], None] = _subprocess_main,
logger: FilteringBoundLogger | None = None,
) -> WatchedSubprocess:
**constructor_kwargs,
) -> Self:
"""Fork and start a new subprocess to execute the given task."""
# Create socketpairs/"pipes" to connect to the stdin and out from the subprocess
child_stdin, feed_stdin = mkpipe(remote_read=True)
Expand All @@ -324,31 +321,27 @@ def start(
# around in the forked processes, especially things that might involve open files or sockets!
del path
del client
del ti
del what
del logger

# Run the child entrypoint
_fork_main(child_stdin, child_stdout, child_stderr, child_logs.fileno(), target)

requests_fd = child_comms.fileno()

# Close the remaining parent-end of the sockets we've passed to the child via fork. We still have the
# other end of the pair open
cls._close_unused_sockets(child_stdin, child_stdout, child_stderr, child_comms, child_logs)

proc = cls(
ti_id=ti.id,
id=constructor_kwargs.get("id") or getattr(what, "id"),
pid=pid,
stdin=feed_stdin,
process=psutil.Process(pid),
client=client,
**constructor_kwargs,
)

# We've forked, but the task won't start until we send it the StartupDetails message. But before we do
# that, we need to tell the server it's started (so it has the chance to tell us "no, stop!" for any
# reason)
try:
client.task_instances.start(ti.id, pid, datetime.now(tz=timezone.utc))
proc._last_successful_heartbeat = time.monotonic()
except Exception:
# On any error kill that subprocess!
proc.kill(signal.SIGKILL)
raise

logger = logger or cast("FilteringBoundLogger", structlog.get_logger(logger_name="task").bind())
proc._register_pipe_readers(
logger=logger,
Expand All @@ -359,11 +352,8 @@ def start(
)

# Tell the task process what it needs to do!
proc._send_startup_message(ti, path, child_comms)
proc._on_child_started(what, path, requests_fd)

# Close the remaining parent-end of the sockets we've passed to the child via fork. We still have the
# other end of the pair open
proc._close_unused_sockets(child_stdin, child_stdout, child_stderr, child_comms, child_logs)
return proc

def _register_pipe_readers(
Expand Down Expand Up @@ -401,12 +391,23 @@ def _close_unused_sockets(*sockets):
for sock in sockets:
sock.close()

def _send_startup_message(self, ti: TaskInstance, path: str | os.PathLike[str], child_comms: socket):
def _on_child_started(self, ti: TaskInstance, path: str | os.PathLike[str], requests_fd: int):
"""Tell the API server the task has started, then send the startup message to the subprocess."""
try:
# We've forked, but the task won't start doing anything until we send it the StartupDetails
# message. But before we do that, we need to tell the server it's started (so it has the chance to
# tell us "no, stop!" for any reason)
self.client.task_instances.start(ti.id, self.pid, datetime.now(tz=timezone.utc))
self._last_successful_heartbeat = time.monotonic()
except Exception:
# On any error kill that subprocess!
self.kill(signal.SIGKILL)
raise

msg = StartupDetails.model_construct(
ti=ti,
file=str(path),
requests_fd=child_comms.fileno(),
file=os.fspath(path),
requests_fd=requests_fd,
)

# Send the message to tell the process what it needs to execute
Expand Down Expand Up @@ -490,7 +491,7 @@ def wait(self) -> int:
# by the subprocess in the `handle_requests` method.
if self.final_state in TerminalTIState:
self.client.task_instances.finish(
id=self.ti_id, state=self.final_state, when=datetime.now(tz=timezone.utc)
id=self.id, state=self.final_state, when=datetime.now(tz=timezone.utc)
)
return self._exit_code

Expand Down Expand Up @@ -525,9 +526,9 @@ def _monitor_subprocess(self):
# logs
self._send_heartbeat_if_needed()

self._handle_task_overtime_if_needed()
self._handle_process_overtime_if_needed()

def _handle_task_overtime_if_needed(self):
def _handle_process_overtime_if_needed(self):
"""Handle termination of auxiliary processes if the task exceeds the configured overtime."""
# If the task has reached a terminal state, we can start monitoring the overtime
if not self._terminal_state:
Expand All @@ -537,7 +538,7 @@ def _handle_task_overtime_if_needed(self):
self._task_end_time_monotonic
and (time.monotonic() - self._task_end_time_monotonic) > self.TASK_OVERTIME_THRESHOLD
):
log.warning("Task success overtime reached; terminating process", ti_id=self.ti_id)
log.warning("Workload success overtime reached; terminating process", ti_id=self.id)
self.kill(signal.SIGTERM, force=True)

def _service_subprocess(self, max_wait_time: float, raise_on_timeout: bool = False):
Expand Down Expand Up @@ -579,7 +580,7 @@ def _check_subprocess_exit(self, raise_on_timeout: bool = False) -> int | None:
if self._exit_code is None:
try:
self._exit_code = self._process.wait(timeout=0)
log.debug("Task process exited", exit_code=self._exit_code)
log.debug("Workload process exited", exit_code=self._exit_code)
except psutil.TimeoutExpired:
if raise_on_timeout:
raise
Expand All @@ -593,7 +594,7 @@ def _send_heartbeat_if_needed(self):

self._last_heartbeat_attempt = time.monotonic()
try:
self.client.task_instances.heartbeat(self.ti_id, pid=self._process.pid)
self.client.task_instances.heartbeat(self.id, pid=self._process.pid)
# Update the last heartbeat time on success
self._last_successful_heartbeat = time.monotonic()

Expand All @@ -619,7 +620,7 @@ def _handle_heartbeat_failures(self):
log.warning(
"Failed to send heartbeat. Will be retried",
failed_heartbeats=self.failed_heartbeats,
ti_id=self.ti_id,
ti_id=self.id,
max_retries=MAX_FAILED_HEARTBEATS,
exc_info=True,
)
Expand All @@ -646,15 +647,15 @@ def final_state(self):
return TerminalTIState.FAILED

def __rich_repr__(self):
yield "ti_id", self.ti_id
yield "id", self.id
yield "pid", self.pid
# only include this if it's not the default (third argument)
yield "exit_code", self._exit_code, None

__rich_repr__.angular = True # type: ignore[attr-defined]

def __repr__(self) -> str:
rep = f"<WatchedSubprocess ti_id={self.ti_id} pid={self.pid}"
rep = f"<WatchedSubprocess id={self.id} pid={self.pid}"
if self._exit_code is not None:
rep += f" exit_code={self._exit_code}"
return rep + " >"
Expand All @@ -672,35 +673,38 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N
log.exception("Unable to decode message", line=line)
continue

self._handle_request(msg, log)

def _handle_request(self, msg, log):
resp = None
if isinstance(msg, TaskState):
self._terminal_state = msg.state
self._task_end_time_monotonic = time.monotonic()
elif isinstance(msg, GetConnection):
conn = self.client.connections.get(msg.conn_id)
resp = conn.model_dump_json(exclude_unset=True).encode()
elif isinstance(msg, GetVariable):
var = self.client.variables.get(msg.key)
resp = var.model_dump_json(exclude_unset=True).encode()
elif isinstance(msg, GetXCom):
xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index)
resp = xcom.model_dump_json(exclude_unset=True).encode()
elif isinstance(msg, DeferTask):
self._terminal_state = IntermediateTIState.DEFERRED
self.client.task_instances.defer(self.id, msg)
resp = None
if isinstance(msg, TaskState):
self._terminal_state = msg.state
self._task_end_time_monotonic = time.monotonic()
elif isinstance(msg, GetConnection):
conn = self.client.connections.get(msg.conn_id)
resp = conn.model_dump_json(exclude_unset=True).encode()
elif isinstance(msg, GetVariable):
var = self.client.variables.get(msg.key)
resp = var.model_dump_json(exclude_unset=True).encode()
elif isinstance(msg, GetXCom):
xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index)
resp = xcom.model_dump_json(exclude_unset=True).encode()
elif isinstance(msg, DeferTask):
self._terminal_state = IntermediateTIState.DEFERRED
self.client.task_instances.defer(self.ti_id, msg)
resp = None
elif isinstance(msg, SetXCom):
self.client.xcoms.set(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.value, msg.map_index)
resp = None
elif isinstance(msg, PutVariable):
self.client.variables.set(msg.key, msg.value, msg.description)
resp = None
else:
log.error("Unhandled request", msg=msg)
continue
elif isinstance(msg, SetXCom):
self.client.xcoms.set(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.value, msg.map_index)
resp = None
elif isinstance(msg, PutVariable):
self.client.variables.set(msg.key, msg.value, msg.description)
resp = None
else:
log.error("Unhandled request", msg=msg)
return

if resp:
self.stdin.write(resp + b"\n")
if resp:
self.stdin.write(resp + b"\n")


# Sockets, even the `.makefile()` function don't correctly do line buffering on reading. If a chunk is read
Expand Down
21 changes: 14 additions & 7 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
import sys
from datetime import datetime, timezone
from io import FileIO
from typing import TYPE_CHECKING, TextIO
from typing import TYPE_CHECKING, Generic, TextIO, TypeVar

import attrs
import structlog
from pydantic import ConfigDict, TypeAdapter
from pydantic import BaseModel, ConfigDict, TypeAdapter

from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.definitions.baseoperator import BaseOperator
Expand Down Expand Up @@ -77,17 +77,24 @@ def parse(what: StartupDetails) -> RuntimeTaskInstance:
return RuntimeTaskInstance.model_construct(**what.ti.model_dump(exclude_unset=True), task=task)


SendMsgType = TypeVar("SendMsgType", bound=BaseModel)
ReceiveMsgType = TypeVar("ReceiveMsgType", bound=BaseModel)


@attrs.define()
class CommsDecoder:
class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]):
"""Handle communication between the task in this process and the supervisor parent process."""

input: TextIO

request_socket: FileIO = attrs.field(init=False, default=None)

decoder: TypeAdapter[ToTask] = attrs.field(init=False, factory=lambda: TypeAdapter(ToTask))
# We could be "clever" here and set the default based on the type parameters and a custom
# `__class_getitem__`, but that's a lot of code for the one subclass we've got currently. So we'll just use a
# "sort of wrong default"
decoder: TypeAdapter[ReceiveMsgType] = attrs.field(factory=lambda: TypeAdapter(ToTask), repr=False)

def get_message(self) -> ToTask:
def get_message(self) -> ReceiveMsgType:
"""
Get a message from the parent.
Expand All @@ -106,7 +113,7 @@ def get_message(self) -> ToTask:
self.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0)
return msg

def send_request(self, log: Logger, msg: ToSupervisor):
def send_request(self, log: Logger, msg: SendMsgType):
encoded_msg = msg.model_dump_json().encode() + b"\n"

log.debug("Sending request", json=encoded_msg)
Expand All @@ -123,7 +130,7 @@ def send_request(self, log: Logger, msg: ToSupervisor):
# deeply nested execution stack.
# - By defining `SUPERVISOR_COMMS` as a global, it ensures that this communication mechanism is readily
# accessible wherever needed during task execution without modifying every layer of the call stack.
SUPERVISOR_COMMS: CommsDecoder
SUPERVISOR_COMMS: CommsDecoder[ToTask, ToSupervisor]

# State machine!
# 1. Start up (receive details from supervisor)
Expand Down
Loading