Skip to content

Commit

Permalink
[pipes] subprocess termination forwarding
Browse files Browse the repository at this point in the history
  • Loading branch information
alangenfeld committed Dec 13, 2023
1 parent 69dc264 commit f8fd3ea
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 6 deletions.
23 changes: 17 additions & 6 deletions python_modules/dagster/dagster/_core/pipes/subprocess.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import signal
from subprocess import Popen
from typing import Mapping, Optional, Sequence, Union

Expand All @@ -7,7 +8,7 @@
from dagster import _check as check
from dagster._annotations import experimental, public
from dagster._core.definitions.resource_annotation import ResourceParam
from dagster._core.errors import DagsterPipesExecutionError
from dagster._core.errors import DagsterExecutionInterruptedError, DagsterPipesExecutionError
from dagster._core.execution.context.compute import OpExecutionContext
from dagster._core.pipes.client import (
PipesClient,
Expand All @@ -21,6 +22,8 @@
open_pipes_session,
)

# Seconds to wait for the external process to exit after SIGINT has been
# forwarded to it when the Dagster execution is interrupted.
INTERRUPT_FWD_WAIT_SECONDS = 2


@experimental
class _PipesSubprocess(PipesClient):
Expand Down Expand Up @@ -108,11 +111,19 @@ def run(
**pipes_session.get_bootstrap_env_vars(),
},
)
process.wait()
if process.returncode != 0:
raise DagsterPipesExecutionError(
f"External execution process failed with code {process.returncode}"
)
try:
process.wait()
if process.returncode != 0:
raise DagsterPipesExecutionError(
f"External execution process failed with code {process.returncode}"
)
except DagsterExecutionInterruptedError:
context.log.info("[pipes] execution interrupted, sending SIGINT to subprocess.")
# send sigint to give external process chance to exit gracefully
process.send_signal(signal.SIGINT)
process.wait(timeout=INTERRUPT_FWD_WAIT_SECONDS)
raise

return PipesClientCompletedInvocation(pipes_session)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import shutil
import subprocess
import textwrap
import time
from contextlib import contextmanager
from multiprocessing import Process
from tempfile import NamedTemporaryFile
from typing import Any, Callable, Iterator

Expand Down Expand Up @@ -37,6 +39,7 @@
from dagster._core.errors import DagsterInvariantViolationError, DagsterPipesExecutionError
from dagster._core.execution.context.compute import AssetExecutionContext, OpExecutionContext
from dagster._core.execution.context.invocation import build_asset_context
from dagster._core.instance import DagsterInstance
from dagster._core.instance_for_test import instance_for_test
from dagster._core.pipes.subprocess import (
PipesSubprocessClient,
Expand All @@ -48,6 +51,7 @@
open_pipes_session,
)
from dagster._core.storage.asset_check_execution_record import AssetCheckExecutionRecordStatus
from dagster._utils import process_is_alive
from dagster._utils.env import environ
from dagster_pipes import DagsterPipesError

Expand Down Expand Up @@ -694,3 +698,66 @@ def bad_msg(context: OpExecutionContext, pipes_client: PipesSubprocessClient):
"Object of type Cursed is not JSON serializable"
in pipes_events[1].dagster_event.engine_event_data.error.message
)


def _execute_job(spin_timeout, subproc_log_path):
    """Run a one-op job whose op launches a long-spinning pipes subprocess.

    Args:
        spin_timeout: Passed to the external script as the ``timeout`` extra;
            the script loops until this many seconds have elapsed.
        subproc_log_path: Passed to the external script as the ``log_path``
            extra; the script writes its own PID to this path.

    Returns:
        The result of ``pipes_job.execute_in_process`` run against
        ``DagsterInstance.get()`` with ``raise_on_error=False``.
    """

    # NOTE(review): script_fn is handed to temp_script (defined outside this
    # view), which presumably turns it into a script file on disk -- confirm
    # before restructuring the function or its signature line.
    def script_fn():
        import os
        import time

        from dagster_pipes import open_dagster_pipes

        with open_dagster_pipes() as pipes:
            timeout = pipes.get_extra("timeout")
            log_path = pipes.get_extra("log_path")
            # Publish this process's PID so the caller can check on it later.
            with open(log_path, "w") as f:
                f.write(f"{os.getpid()}")
                f.flush()
                start = time.time()
                # Spin until the timeout elapses; the log file stays open
                # for the whole duration of the loop.
                while time.time() - start < timeout:
                    ...

    with temp_script(script_fn) as script_path:

        @op
        def stalling_pipes_op(
            context: OpExecutionContext,
        ):
            cmd = [_PYTHON_EXECUTABLE, script_path]
            # The extras are read back in script_fn via pipes.get_extra.
            PipesSubprocessClient().run(
                command=cmd,
                context=context,
                extras={
                    "timeout": spin_timeout,
                    "log_path": subproc_log_path,
                },
            )

        @job
        def pipes_job():
            stalling_pipes_op()

        # raise_on_error=False: failures are reported on the returned result
        # rather than raised out of this function.
        return pipes_job.execute_in_process(
            instance=DagsterInstance.get(),
            raise_on_error=False,
        )


def test_cancellation():
    """Terminating the job process must also stop the pipes subprocess.

    The job runs in a separate ``multiprocessing.Process``. Its pipes
    subprocess writes its PID to a temp file; once that PID is visible the
    job process is terminated, and the test asserts that both the job
    process and the pipes subprocess are gone.
    """
    spin_timeout = 600
    poll_interval = 0.05
    with instance_for_test(), NamedTemporaryFile() as subproc_log_path:
        p = Process(target=_execute_job, args=(spin_timeout, subproc_log_path.name))
        p.start()
        pid = None
        while p.is_alive():
            data = subproc_log_path.read().decode("utf-8")
            if not data:
                # The subprocess has not written its PID yet; back off briefly
                # instead of re-reading the file in a tight busy loop.
                time.sleep(poll_interval)
                continue

            pid = int(data)
            # Brief pause before terminating the job process.
            time.sleep(0.1)
            p.terminate()
            break

        p.join(timeout=1)
        assert not p.is_alive()
        # pid stays None if the job process died before the PID was written.
        assert pid
        assert not process_is_alive(pid)

0 comments on commit f8fd3ea

Please sign in to comment.