diff --git a/python_modules/dagster/dagster/_core/pipes/subprocess.py b/python_modules/dagster/dagster/_core/pipes/subprocess.py
index 4ec12fbb3834e..8d5bd1dc96d03 100644
--- a/python_modules/dagster/dagster/_core/pipes/subprocess.py
+++ b/python_modules/dagster/dagster/_core/pipes/subprocess.py
@@ -1,4 +1,5 @@
 import os
+import signal
 from subprocess import Popen
 from typing import Mapping, Optional, Sequence, Union
 
@@ -7,7 +8,7 @@
 from dagster import _check as check
 from dagster._annotations import experimental, public
 from dagster._core.definitions.resource_annotation import ResourceParam
-from dagster._core.errors import DagsterPipesExecutionError
+from dagster._core.errors import DagsterExecutionInterruptedError, DagsterPipesExecutionError
 from dagster._core.execution.context.compute import OpExecutionContext
 from dagster._core.pipes.client import (
     PipesClient,
@@ -21,6 +22,8 @@
     open_pipes_session,
 )
 
+INTERRUPT_FWD_WAIT_SECONDS = 2
+
 
 @experimental
 class _PipesSubprocess(PipesClient):
@@ -108,11 +111,19 @@ def run(
                     **pipes_session.get_bootstrap_env_vars(),
                 },
             )
-            process.wait()
-            if process.returncode != 0:
-                raise DagsterPipesExecutionError(
-                    f"External execution process failed with code {process.returncode}"
-                )
+            try:
+                process.wait()
+                if process.returncode != 0:
+                    raise DagsterPipesExecutionError(
+                        f"External execution process failed with code {process.returncode}"
+                    )
+            except DagsterExecutionInterruptedError:
+                context.log.info("[pipes] execution interrupted, sending SIGINT to subprocess.")
+                # send sigint to give external process chance to exit gracefully
+                process.send_signal(signal.SIGINT)
+                process.wait(timeout=INTERRUPT_FWD_WAIT_SECONDS)
+                raise
+
         return PipesClientCompletedInvocation(pipes_session)
 
 
diff --git a/python_modules/dagster/dagster_tests/execution_tests/pipes_tests/test_subprocess.py b/python_modules/dagster/dagster_tests/execution_tests/pipes_tests/test_subprocess.py
index 906ab6cec4f00..d5bd1c5974ef1 100644
--- a/python_modules/dagster/dagster_tests/execution_tests/pipes_tests/test_subprocess.py
+++ b/python_modules/dagster/dagster_tests/execution_tests/pipes_tests/test_subprocess.py
@@ -3,7 +3,9 @@
 import shutil
 import subprocess
 import textwrap
+import time
 from contextlib import contextmanager
+from multiprocessing import Process
 from tempfile import NamedTemporaryFile
 from typing import Any, Callable, Iterator
 
@@ -37,6 +39,7 @@
 from dagster._core.errors import DagsterInvariantViolationError, DagsterPipesExecutionError
 from dagster._core.execution.context.compute import AssetExecutionContext, OpExecutionContext
 from dagster._core.execution.context.invocation import build_asset_context
+from dagster._core.instance import DagsterInstance
 from dagster._core.instance_for_test import instance_for_test
 from dagster._core.pipes.subprocess import (
     PipesSubprocessClient,
@@ -48,6 +51,7 @@
     open_pipes_session,
 )
 from dagster._core.storage.asset_check_execution_record import AssetCheckExecutionRecordStatus
+from dagster._utils import process_is_alive
 from dagster._utils.env import environ
 from dagster_pipes import DagsterPipesError
 
@@ -694,3 +698,66 @@
         "Object of type Cursed is not JSON serializable"
         in pipes_events[1].dagster_event.engine_event_data.error.message
     )
+
+
+def _execute_job(spin_timeout, subproc_log_path):
+    def script_fn():
+        import os
+        import time
+
+        from dagster_pipes import open_dagster_pipes
+
+        with open_dagster_pipes() as pipes:
+            timeout = pipes.get_extra("timeout")
+            log_path = pipes.get_extra("log_path")
+            with open(log_path, "w") as f:
+                f.write(f"{os.getpid()}")
+                f.flush()
+            start = time.time()
+            while time.time() - start < timeout:
+                ...
+
+    with temp_script(script_fn) as script_path:
+
+        @op
+        def stalling_pipes_op(
+            context: OpExecutionContext,
+        ):
+            cmd = [_PYTHON_EXECUTABLE, script_path]
+            PipesSubprocessClient().run(
+                command=cmd,
+                context=context,
+                extras={
+                    "timeout": spin_timeout,
+                    "log_path": subproc_log_path,
+                },
+            )
+
+        @job
+        def pipes_job():
+            stalling_pipes_op()
+
+        return pipes_job.execute_in_process(
+            instance=DagsterInstance.get(),
+            raise_on_error=False,
+        )
+
+
+def test_cancellation():
+    spin_timeout = 600
+    with instance_for_test(), NamedTemporaryFile() as subproc_log_path:
+        p = Process(target=_execute_job, args=(spin_timeout, subproc_log_path.name))
+        p.start()
+        pid = None
+        while p.is_alive():
+            data = subproc_log_path.read().decode("utf-8")
+            if data:
+                pid = int(data)
+                time.sleep(0.1)
+                p.terminate()
+                break
+
+        p.join(timeout=1)
+        assert not p.is_alive()
+        assert pid
+        assert not process_is_alive(pid)
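
Note (not part of the diff): a minimal sketch of an external script that takes advantage of the forwarded SIGINT. open_dagster_pipes and pipes.log are existing dagster_pipes APIs; the busy-wait work loop and the cleanup step are illustrative assumptions, not code from this change.

    # External script launched via PipesSubprocessClient. When the Dagster run
    # is interrupted, _PipesSubprocess.run forwards SIGINT, which Python
    # delivers here as KeyboardInterrupt.
    import time

    from dagster_pipes import open_dagster_pipes

    with open_dagster_pipes() as pipes:
        try:
            while True:
                time.sleep(0.1)  # stand-in for long-running work
        except KeyboardInterrupt:
            # Illustrative cleanup hook: any flushing/teardown must fit inside
            # the INTERRUPT_FWD_WAIT_SECONDS grace window before the
            # orchestration side re-raises the interruption.
            pipes.log.info("received SIGINT, shutting down cleanly")

Handling the signal is optional: the orchestration process re-raises DagsterExecutionInterruptedError either way. The forwarded SIGINT only gives the external process a chance to exit gracefully, and the pipes session to close, before the run is torn down.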