From bd5b0cf52da396bd06e1f6cc81ed2e813249deaf Mon Sep 17 00:00:00 2001 From: Dustin Ngo Date: Mon, 12 Dec 2022 19:10:52 -0500 Subject: [PATCH] Implement out-of-process flow run pauses (#7863) Co-authored-by: Zach Angell <42625717+zangell44@users.noreply.github.com> --- src/prefect/engine.py | 57 ++++++++++- src/prefect/exceptions.py | 2 +- src/prefect/orion/api/flow_runs.py | 12 +-- .../orion/orchestration/core_policy.py | 65 ++++++++++--- src/prefect/orion/schemas/states.py | 10 +- src/prefect/states.py | 10 +- tests/fixtures/database.py | 8 ++ tests/orion/api/test_task_runs.py | 12 +++ tests/orion/models/test_task_run_states.py | 15 +++ tests/orion/models/test_task_runs.py | 8 ++ tests/orion/orchestration/test_core_policy.py | 51 ++++++++++ tests/test_engine.py | 95 +++++++++++++++++++ 12 files changed, 315 insertions(+), 30 deletions(-) diff --git a/src/prefect/engine.py b/src/prefect/engine.py index c7da5dace037..b19e6cfd5fd1 100644 --- a/src/prefect/engine.py +++ b/src/prefect/engine.py @@ -718,7 +718,11 @@ async def orchestrate_flow_run( @sync_compatible async def pause_flow_run( - timeout: int = 300, poll_interval: int = 10, reschedule=False, key: str = None + flow_run_id: UUID = None, + timeout: int = 300, + poll_interval: int = 10, + reschedule=False, + key: str = None, ): """ Pauses the current flow run by stopping execution until resumed. @@ -729,6 +733,14 @@ async def pause_flow_run( been resumed within the specified time. Args: + flow_run_id: a flow run id. If supplied, this function will attempt to pause + the specified flow run outside of the flow run process. When paused, the + flow run will continue execution until the NEXT task is orchestrated, at + which point the flow will exit. Any tasks that have already started will + run until completion. When resumed, the flow run will be rescheduled to + finish execution. In order pause a flow run in this way, the flow needs to + have an associated deployment and results need to be configured with the + `persist_results` option. timeout: the number of seconds to wait for the flow to be resumed before failing. Defaults to 5 minutes (300 seconds). If the pause timeout exceeds any configured flow-level timeout, the flow might fail even after resuming. @@ -737,12 +749,28 @@ async def pause_flow_run( reschedule: Flag that will reschedule the flow run if resumed. Instead of blocking execution, the flow will gracefully exit (with no result returned) instead. To use this flag, a flow needs to have an associated deployment and - results need to be configured. + results need to be configured with the `persist_results` option. key: An optional key to prevent calling pauses more than once. This defaults to the number of pauses observed by the flow so far, and prevents pauses that use the "reschedule" option from running the same pause twice. A custom key can be supplied for custom pausing behavior. """ + if flow_run_id: + return await _out_of_process_pause( + flow_run_id=flow_run_id, + timeout=timeout, + reschedule=reschedule, + key=key, + ) + else: + return await _in_process_pause( + timeout=timeout, poll_interval=poll_interval, reschedule=reschedule, key=key + ) + + +async def _in_process_pause( + timeout: int = 300, poll_interval: int = 10, reschedule=False, key: str = None +): if TaskRunContext.get(): raise RuntimeError("Cannot pause task runs.") @@ -795,6 +823,26 @@ async def pause_flow_run( raise FlowPauseTimeout("Flow run was paused and never resumed.") +async def _out_of_process_pause( + flow_run_id: UUID, + timeout: int = 300, + reschedule: bool = True, + key: str = None, +): + if reschedule: + raise RuntimeError( + "Pausing a flow run out of process requires the `reschedule` option set to True." + ) + + client = get_client() + response = await client.set_flow_run_state( + flow_run_id, + Paused(timeout_seconds=timeout, reschedule=True, pause_key=key), + ) + if response.status != SetStateStatus.ACCEPT: + raise RuntimeError(response.details.reason) + + @sync_compatible async def resume_flow_run(flow_run_id): """ @@ -1543,6 +1591,9 @@ async def report_task_run_crashes(task_run: TaskRun, client: OrionClient): """ try: yield + except PausedRun: + # Do not capture PausedRuns as crashes + raise except Abort: # Do not capture aborts as crashes raise @@ -1712,6 +1763,8 @@ async def propose_state( ) elif response.status == SetStateStatus.REJECT: + if response.state.is_paused(): + raise PausedRun(response.details.reason) return response.state else: diff --git a/src/prefect/exceptions.py b/src/prefect/exceptions.py index 3d1ef764039d..125dd89edbe9 100644 --- a/src/prefect/exceptions.py +++ b/src/prefect/exceptions.py @@ -351,5 +351,5 @@ class FlowPauseTimeout(PrefectException): """Raised when a flow pause times out""" -class PausedRun(BaseException): +class PausedRun(PrefectSignal): """Signal raised when exiting a flow early for nonblocking pauses""" diff --git a/src/prefect/orion/api/flow_runs.py b/src/prefect/orion/api/flow_runs.py index 0dfd962036dd..ef9190613540 100644 --- a/src/prefect/orion/api/flow_runs.py +++ b/src/prefect/orion/api/flow_runs.py @@ -191,6 +191,8 @@ async def resume_flow_run( """ Resume a paused flow run. """ + now = pendulum.now() + async with db.session_context(begin_transaction=True) as session: flow_run = await models.flow_runs.read_flow_run(session, flow_run_id) state = flow_run.state @@ -226,13 +228,11 @@ async def resume_flow_run( orchestration_parameters=orchestration_parameters, ) - # only set the 201 when a new state was created - if orchestration_result.status == schemas.responses.SetStateStatus.WAIT: - response.status_code = status.HTTP_200_OK - elif orchestration_result.status == schemas.responses.SetStateStatus.ABORT: - response.status_code = status.HTTP_200_OK - else: + # set the 201 if a new state was created + if orchestration_result.state and orchestration_result.state.timestamp >= now: response.status_code = status.HTTP_201_CREATED + else: + response.status_code = status.HTTP_200_OK return orchestration_result diff --git a/src/prefect/orion/orchestration/core_policy.py b/src/prefect/orion/orchestration/core_policy.py index 6d14787b5eb1..170aff138046 100644 --- a/src/prefect/orion/orchestration/core_policy.py +++ b/src/prefect/orion/orchestration/core_policy.py @@ -58,6 +58,7 @@ def priority(): return [ CacheRetrieval, HandleTaskTerminalStateTransitions, + PreventRunningTasksFromStoppedFlows, PreventRedundantTransitions, SecureTaskConcurrencySlots, # retrieve cached states even if slots are full CopyScheduledTime, @@ -95,7 +96,7 @@ class SecureTaskConcurrencySlots(BaseOrchestrationRule): """ FROM_STATES = ALL_ORCHESTRATION_STATES - TO_STATES = [states.StateType.RUNNING] + TO_STATES = [StateType.RUNNING] async def before_transition( self, @@ -190,7 +191,7 @@ class CacheInsertion(BaseOrchestrationRule): """ FROM_STATES = ALL_ORCHESTRATION_STATES - TO_STATES = [states.StateType.COMPLETED] + TO_STATES = [StateType.COMPLETED] @inject_db async def after_transition( @@ -223,7 +224,7 @@ class CacheRetrieval(BaseOrchestrationRule): """ FROM_STATES = ALL_ORCHESTRATION_STATES - TO_STATES = [states.StateType.RUNNING] + TO_STATES = [StateType.RUNNING] @inject_db async def before_transition( @@ -269,8 +270,8 @@ class RetryFailedFlows(BaseOrchestrationRule): instructed to transition into a scheduled state to retry flow execution. """ - FROM_STATES = [states.StateType.RUNNING] - TO_STATES = [states.StateType.FAILED] + FROM_STATES = [StateType.RUNNING] + TO_STATES = [StateType.FAILED] async def before_transition( self, @@ -337,8 +338,8 @@ class RetryFailedTasks(BaseOrchestrationRule): instructed to transition into a scheduled state to retry task execution. """ - FROM_STATES = [states.StateType.RUNNING] - TO_STATES = [states.StateType.FAILED] + FROM_STATES = [StateType.RUNNING] + TO_STATES = [StateType.FAILED] async def before_transition( self, @@ -368,7 +369,7 @@ class RenameReruns(BaseOrchestrationRule): """ FROM_STATES = ALL_ORCHESTRATION_STATES - TO_STATES = [states.StateType.RUNNING] + TO_STATES = [StateType.RUNNING] async def before_transition( self, @@ -392,8 +393,8 @@ class CopyScheduledTime(BaseOrchestrationRule): on the scheduled state will be ignored. """ - FROM_STATES = [states.StateType.SCHEDULED] - TO_STATES = [states.StateType.PENDING] + FROM_STATES = [StateType.SCHEDULED] + TO_STATES = [StateType.PENDING] async def before_transition( self, @@ -418,8 +419,8 @@ class WaitForScheduledTime(BaseOrchestrationRule): before attempting the transition again. """ - FROM_STATES = [states.StateType.SCHEDULED, states.StateType.PENDING] - TO_STATES = [states.StateType.RUNNING] + FROM_STATES = [StateType.SCHEDULED, StateType.PENDING] + TO_STATES = [StateType.RUNNING] async def before_transition( self, @@ -446,7 +447,7 @@ class HandlePausingFlows(BaseOrchestrationRule): """ FROM_STATES = ALL_ORCHESTRATION_STATES - TO_STATES = [states.StateType.PAUSED] + TO_STATES = [StateType.PAUSED] async def before_transition( self, @@ -504,7 +505,7 @@ class HandleResumingPausedFlows(BaseOrchestrationRule): Governs runs attempting to leave a Paused state """ - FROM_STATES = [states.StateType.PAUSED] + FROM_STATES = [StateType.PAUSED] TO_STATES = ALL_ORCHESTRATION_STATES async def before_transition( @@ -559,7 +560,7 @@ class UpdateFlowRunTrackerOnTasks(BaseOrchestrationRule): """ FROM_STATES = ALL_ORCHESTRATION_STATES - TO_STATES = [states.StateType.RUNNING] + TO_STATES = [StateType.RUNNING] async def after_transition( self, @@ -717,3 +718,37 @@ async def before_transition( await self.abort_transition( reason=f"This run cannot transition to the {proposed_state_type} state from the {initial_state_type} state." ) + + +class PreventRunningTasksFromStoppedFlows(BaseOrchestrationRule): + """ + Prevents running tasks from stopped flows. + + A running state implies execution, but also the converse. This rule ensures that a + flow's tasks cannot be run unless the flow is also running. + """ + + FROM_STATES = ALL_ORCHESTRATION_STATES + TO_STATES = [StateType.RUNNING] + + async def before_transition( + self, + initial_state: Optional[states.State], + proposed_state: Optional[states.State], + context: TaskOrchestrationContext, + ) -> None: + flow_run = await context.flow_run() + if flow_run.state is None: + await self.abort_transition( + reason=f"The enclosing flow must be running to begin task execution." + ) + elif flow_run.state.type == StateType.PAUSED: + await self.reject_transition( + state=states.Paused(name="NotReady"), + reason=f"The flow is paused, new tasks can execute after resuming flow run: {flow_run.id}.", + ) + elif not flow_run.state.type == StateType.RUNNING: + # task runners should abort task run execution + await self.abort_transition( + reason=f"The enclosing flow must be running to begin task execution.", + ) diff --git a/src/prefect/orion/schemas/states.py b/src/prefect/orion/schemas/states.py index f225d9e0942b..d5378563db37 100644 --- a/src/prefect/orion/schemas/states.py +++ b/src/prefect/orion/schemas/states.py @@ -308,9 +308,13 @@ def Paused( "Cannot supply both a pause_expiration_time and timeout_seconds" ) - state_details.pause_timeout = pause_expiration_time or ( - pendulum.now("UTC") + pendulum.Duration(seconds=timeout_seconds) - ) + if pause_expiration_time is None and timeout_seconds is None: + pass + else: + state_details.pause_timeout = pause_expiration_time or ( + pendulum.now("UTC") + pendulum.Duration(seconds=timeout_seconds) + ) + state_details.pause_reschedule = reschedule state_details.pause_key = pause_key diff --git a/src/prefect/states.py b/src/prefect/states.py index c730aebb0e3e..e6738f6b27ab 100644 --- a/src/prefect/states.py +++ b/src/prefect/states.py @@ -529,9 +529,13 @@ def Paused( "Cannot supply both a pause_expiration_time and timeout_seconds" ) - state_details.pause_timeout = pause_expiration_time or ( - pendulum.now("UTC") + pendulum.Duration(seconds=timeout_seconds) - ) + if pause_expiration_time is None and timeout_seconds is None: + pass + else: + state_details.pause_timeout = pause_expiration_time or ( + pendulum.now("UTC") + pendulum.Duration(seconds=timeout_seconds) + ) + state_details.pause_reschedule = reschedule state_details.pause_key = pause_key diff --git a/tests/fixtures/database.py b/tests/fixtures/database.py index 5f65e73148ba..5002efe0bb42 100644 --- a/tests/fixtures/database.py +++ b/tests/fixtures/database.py @@ -518,6 +518,7 @@ async def initializer( run_type, initial_state_type, proposed_state_type, + initial_flow_run_state_type=None, run_override=None, run_tags=None, initial_details=None, @@ -556,6 +557,13 @@ async def initializer( context = FlowOrchestrationContext state_constructor = commit_flow_run_state elif run_type == "task": + if initial_flow_run_state_type: + flow_state_constructor = commit_flow_run_state + await flow_state_constructor( + session, + flow_run, + initial_flow_run_state_type, + ) task_run = await models.task_runs.create_task_run( session=session, task_run=schemas.actions.TaskRunCreate( diff --git a/tests/orion/api/test_task_runs.py b/tests/orion/api/test_task_runs.py index ac7c515d0764..e216f6ec3554 100644 --- a/tests/orion/api/test_task_runs.py +++ b/tests/orion/api/test_task_runs.py @@ -291,6 +291,12 @@ async def test_delete_task_run_returns_404_if_does_not_exist(self, client): class TestSetTaskRunState: async def test_set_task_run_state(self, task_run, client, session): + # first ensure the parent flow run is in a running state + await client.post( + f"/flow_runs/{task_run.flow_run_id}/set_state", + json=dict(state=dict(type="RUNNING")), + ) + response = await client.post( f"/task_runs/{task_run.id}/set_state", json=dict(state=dict(type="RUNNING", name="Test State")), @@ -317,6 +323,12 @@ async def test_setting_task_run_state_twice_aborts( # this test ensures that a 2nd agent cannot re-propose a state that's already # been set + # first ensure the parent flow run is in a running state + await client.post( + f"/flow_runs/{task_run.flow_run_id}/set_state", + json=dict(state=dict(type="RUNNING")), + ) + response = await client.post( f"/task_runs/{task_run.id}/set_state", json=dict(state=dict(type=proposed_state, name="Test State")), diff --git a/tests/orion/models/test_task_run_states.py b/tests/orion/models/test_task_run_states.py index ef323bf339ad..bc26260713e0 100644 --- a/tests/orion/models/test_task_run_states.py +++ b/tests/orion/models/test_task_run_states.py @@ -69,6 +69,14 @@ async def test_run_details_are_updated_entering_running(self, task_run, session) assert task_run.total_run_time == (dt2 - dt) async def test_failed_becomes_awaiting_retry(self, task_run, client, session): + # first ensure the task run's flow run is in a running state + await models.flow_runs.set_flow_run_state( + session=session, + flow_run_id=task_run.flow_run_id, + state=Running(), + force=True, + ) + # set max retries to 1 # copy to trigger ORM updates task_run.empirical_policy = task_run.empirical_policy.copy() @@ -125,6 +133,13 @@ async def test_failed_doesnt_retry_if_flag_set(self, task_run, client, session): async def test_database_is_not_updated_when_no_transition_takes_place( self, task_run, session ): + # first ensure the task run's flow run is in a running state + await models.flow_runs.set_flow_run_state( + session=session, + flow_run_id=task_run.flow_run_id, + state=Running(), + force=True, + ) # place the run in a scheduled state in the future trs = await models.task_runs.set_task_run_state( diff --git a/tests/orion/models/test_task_runs.py b/tests/orion/models/test_task_runs.py index 85b06a75cd79..592011b9fa1c 100644 --- a/tests/orion/models/test_task_runs.py +++ b/tests/orion/models/test_task_runs.py @@ -705,6 +705,14 @@ async def task_run_2(self, session, flow_run): return model async def test_force_releases_concurrency(self, session, task_run_1, task_run_2): + # first set flow runs in a running state + await models.flow_runs.set_flow_run_state( + session=session, flow_run_id=task_run_1.flow_run_id, state=Running() + ) + await models.flow_runs.set_flow_run_state( + session=session, flow_run_id=task_run_2.flow_run_id, state=Running() + ) + await concurrency_limits.create_concurrency_limit( session=session, concurrency_limit=schemas.core.ConcurrencyLimit( diff --git a/tests/orion/orchestration/test_core_policy.py b/tests/orion/orchestration/test_core_policy.py index bb7d6cde5fb3..d83ce39029c9 100644 --- a/tests/orion/orchestration/test_core_policy.py +++ b/tests/orion/orchestration/test_core_policy.py @@ -19,6 +19,7 @@ HandleResumingPausedFlows, HandleTaskTerminalStateTransitions, PreventRedundantTransitions, + PreventRunningTasksFromStoppedFlows, ReleaseTaskConcurrencySlots, RenameReruns, RetryFailedFlows, @@ -2220,3 +2221,53 @@ async def test_marks_flow_run_as_resuming_upon_leaving_paused_state( assert ctx.response_status == SetStateStatus.ACCEPT assert ctx.run.empirical_policy.resuming + + +class TestPreventRunningTasksFromStoppedFlows: + async def test_allows_task_runs_to_run(self, session, initialize_orchestration): + initial_state_type = states.StateType.PENDING + proposed_state_type = states.StateType.RUNNING + intended_transition = (initial_state_type, proposed_state_type) + ctx = await initialize_orchestration( + session, + "task", + *intended_transition, + initial_flow_run_state_type=states.StateType.RUNNING, + ) + + run_preventer = PreventRunningTasksFromStoppedFlows(ctx, *intended_transition) + + async with run_preventer as ctx: + await ctx.validate_proposed_state() + + assert ctx.response_status == SetStateStatus.ACCEPT + assert ctx.validated_state.is_running() + + @pytest.mark.parametrize( + "initial_flow_run_state_type", + sorted(list(set(states.StateType) - {states.StateType.RUNNING})), + ) + async def test_prevents_tasks_From_running( + self, session, initial_flow_run_state_type, initialize_orchestration + ): + initial_state_type = states.StateType.PENDING + proposed_state_type = states.StateType.RUNNING + intended_transition = (initial_state_type, proposed_state_type) + ctx = await initialize_orchestration( + session, + "task", + *intended_transition, + initial_flow_run_state_type=initial_flow_run_state_type, + ) + + run_preventer = PreventRunningTasksFromStoppedFlows(ctx, *intended_transition) + + async with run_preventer as ctx: + await ctx.validate_proposed_state() + + if initial_flow_run_state_type == states.StateType.PAUSED: + assert ctx.response_status == SetStateStatus.REJECT + assert ctx.validated_state.is_paused() + else: + assert ctx.response_status == SetStateStatus.ABORT + assert ctx.validated_state.is_pending() diff --git a/tests/test_engine.py b/tests/test_engine.py index f5a351f37e86..33df7ede2f1a 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -417,6 +417,71 @@ async def pausing_flow_without_blocking(): flow_run_state = await pausing_flow_without_blocking() +class TestOutOfProcessPause: + async def test_flows_can_be_paused_out_of_process( + self, orion_client, deployment, monkeypatch + ): + frc = partial(FlowRunCreate, deployment_id=deployment.id) + monkeypatch.setattr("prefect.client.orion.schemas.actions.FlowRunCreate", frc) + + @task + async def foo(): + return 42 + + # when pausing the flow run with a specific flow run id, `pause_flow_run` + # attempts an out-of-process pause; this continues execution until the NEXT + # task run attempts to start, then gracefully exits + + @flow(task_runner=SequentialTaskRunner()) + async def pausing_flow_without_blocking(): + context = FlowRunContext.get() + x = await foo.submit() + y = await foo.submit() + await pause_flow_run(flow_run_id=context.flow_run.id, timeout=20) + z = await foo(wait_for=[x]) + alpha = await foo(wait_for=[y]) + omega = await foo(wait_for=[x, y]) + + flow_run_state = await pausing_flow_without_blocking(return_state=True) + assert flow_run_state.is_paused() + flow_run_id = flow_run_state.state_details.flow_run_id + task_runs = await orion_client.read_task_runs( + flow_run_filter=FlowRunFilter(id={"any_": [flow_run_id]}) + ) + completed_task_runs = list( + filter(lambda tr: tr.state.is_completed(), task_runs) + ) + paused_task_runs = list(filter(lambda tr: tr.state.is_paused(), task_runs)) + assert len(task_runs) == 3, "only three tasks should have tried to run" + assert len(completed_task_runs) == 2, "only two task runs should have completed" + assert ( + len(paused_task_runs) == 1 + ), "one task run should have exited with a paused state" + + async def test_out_of_process_pauses_exit_gracefully( + self, orion_client, deployment, monkeypatch + ): + frc = partial(FlowRunCreate, deployment_id=deployment.id) + monkeypatch.setattr("prefect.client.orion.schemas.actions.FlowRunCreate", frc) + + @task + async def foo(): + return 42 + + @flow(task_runner=SequentialTaskRunner()) + async def pausing_flow_without_blocking(): + context = FlowRunContext.get() + x = await foo.submit() + y = await foo.submit() + await pause_flow_run(flow_run_id=context.flow_run.id, timeout=20) + z = await foo(wait_for=[x]) + alpha = await foo(wait_for=[y]) + omega = await foo(wait_for=[x, y]) + + with pytest.raises(PausedRun): + flow_run_state = await pausing_flow_without_blocking() + + class TestOrchestrateTaskRun: async def test_waits_until_scheduled_start_time( self, @@ -427,6 +492,12 @@ async def test_waits_until_scheduled_start_time( result_factory, monkeypatch, ): + # the flow run must be running prior to running tasks + await orion_client.set_flow_run_state( + flow_run_id=flow_run.id, + state=Running(), + ) + @task def foo(): return 1 @@ -461,6 +532,12 @@ def foo(): async def test_does_not_wait_for_scheduled_time_in_past( self, orion_client, flow_run, mock_anyio_sleep, result_factory, local_filesystem ): + # the flow run must be running prior to running tasks + await orion_client.set_flow_run_state( + flow_run_id=flow_run.id, + state=Running(), + ) + @task def foo(): return 1 @@ -495,6 +572,12 @@ def foo(): async def test_waits_for_awaiting_retry_scheduled_time( self, mock_anyio_sleep, orion_client, flow_run, result_factory, local_filesystem ): + # the flow run must be running prior to running tasks + await orion_client.set_flow_run_state( + flow_run_id=flow_run.id, + state=Running(), + ) + # Define a task that fails once and then succeeds mock = MagicMock() @@ -616,6 +699,12 @@ def my_task(x): async def test_quoted_parameters_are_resolved( self, orion_client, flow_run, result_factory, local_filesystem ): + # the flow run must be running prior to running tasks + await orion_client.set_flow_run_state( + flow_run_id=flow_run.id, + state=Running(), + ) + # Define a mock to ensure the task was not run mock = MagicMock() @@ -661,6 +750,12 @@ async def test_states_in_parameters_can_be_incomplete_if_quoted( result_factory, local_filesystem, ): + # the flow run must be running prior to running tasks + await orion_client.set_flow_run_state( + flow_run_id=flow_run.id, + state=Running(), + ) + # Define a mock to ensure the task was not run mock = MagicMock()