Skip to content

Commit

Permalink
Implement out-of-process flow run pauses (#7863)
Browse files Browse the repository at this point in the history
Co-authored-by: Zach Angell <42625717+zangell44@users.noreply.github.com>
  • Loading branch information
2 people authored and zanieb committed Dec 15, 2022
1 parent 1e1316e commit 88ccfa7
Show file tree
Hide file tree
Showing 12 changed files with 315 additions and 30 deletions.
57 changes: 55 additions & 2 deletions src/prefect/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,11 @@ async def orchestrate_flow_run(

@sync_compatible
async def pause_flow_run(
timeout: int = 300, poll_interval: int = 10, reschedule=False, key: str = None
flow_run_id: UUID = None,
timeout: int = 300,
poll_interval: int = 10,
reschedule=False,
key: str = None,
):
"""
Pauses the current flow run by stopping execution until resumed.
Expand All @@ -729,6 +733,14 @@ async def pause_flow_run(
been resumed within the specified time.
Args:
flow_run_id: a flow run id. If supplied, this function will attempt to pause
the specified flow run outside of the flow run process. When paused, the
flow run will continue execution until the NEXT task is orchestrated, at
which point the flow will exit. Any tasks that have already started will
run until completion. When resumed, the flow run will be rescheduled to
finish execution. In order to pause a flow run in this way, the flow needs to
have an associated deployment and results need to be configured with the
`persist_results` option.
timeout: the number of seconds to wait for the flow to be resumed before
failing. Defaults to 5 minutes (300 seconds). If the pause timeout exceeds
any configured flow-level timeout, the flow might fail even after resuming.
Expand All @@ -737,12 +749,28 @@ async def pause_flow_run(
reschedule: Flag that will reschedule the flow run if resumed. Instead of
blocking execution, the flow will gracefully exit (with no result returned)
instead. To use this flag, a flow needs to have an associated deployment and
results need to be configured.
results need to be configured with the `persist_results` option.
key: An optional key to prevent calling pauses more than once. This defaults to
the number of pauses observed by the flow so far, and prevents pauses that
use the "reschedule" option from running the same pause twice. A custom key
can be supplied for custom pausing behavior.
"""
if flow_run_id:
return await _out_of_process_pause(
flow_run_id=flow_run_id,
timeout=timeout,
reschedule=reschedule,
key=key,
)
else:
return await _in_process_pause(
timeout=timeout, poll_interval=poll_interval, reschedule=reschedule, key=key
)


async def _in_process_pause(
timeout: int = 300, poll_interval: int = 10, reschedule=False, key: str = None
):
if TaskRunContext.get():
raise RuntimeError("Cannot pause task runs.")

Expand Down Expand Up @@ -795,6 +823,26 @@ async def pause_flow_run(
raise FlowPauseTimeout("Flow run was paused and never resumed.")


async def _out_of_process_pause(
    flow_run_id: UUID,
    timeout: int = 300,
    reschedule: bool = True,
    key: str = None,
):
    """
    Pause a flow run from outside of the process that is executing it.

    Proposes a `Paused` state for the given flow run through the API client. The
    proposed state always carries `reschedule=True`.

    Args:
        flow_run_id: the id of the flow run to pause.
        timeout: the number of seconds forwarded as `timeout_seconds` of the
            proposed `Paused` state.
        reschedule: accepted so the signature mirrors the in-process pause. The
            value is not consulted; the proposed state is always created with
            `reschedule=True`.
        key: an optional pause key, forwarded as `pause_key` of the proposed state.

    Raises:
        RuntimeError: if the response status is anything other than ACCEPT. The
            message is the reason found in the response details.
    """
    # NOTE: an earlier guard here raised whenever `reschedule` was truthy, while
    # its message stated that `reschedule` must be True -- i.e. it rejected the
    # one value it claimed to require. Because the state below is built with
    # `reschedule=True` regardless of the argument, the guard is dropped rather
    # than inverted, so callers that pass `reschedule=False` (or rely on a
    # caller-side default of False) keep working.
    client = get_client()
    response = await client.set_flow_run_state(
        flow_run_id,
        Paused(timeout_seconds=timeout, reschedule=True, pause_key=key),
    )
    if response.status != SetStateStatus.ACCEPT:
        raise RuntimeError(response.details.reason)


@sync_compatible
async def resume_flow_run(flow_run_id):
"""
Expand Down Expand Up @@ -1543,6 +1591,9 @@ async def report_task_run_crashes(task_run: TaskRun, client: OrionClient):
"""
try:
yield
except PausedRun:
# Do not capture PausedRuns as crashes
raise
except Abort:
# Do not capture aborts as crashes
raise
Expand Down Expand Up @@ -1712,6 +1763,8 @@ async def propose_state(
)

elif response.status == SetStateStatus.REJECT:
if response.state.is_paused():
raise PausedRun(response.details.reason)
return response.state

else:
Expand Down
2 changes: 1 addition & 1 deletion src/prefect/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,5 +351,5 @@ class FlowPauseTimeout(PrefectException):
"""Raised when a flow pause times out"""


class PausedRun(BaseException):
class PausedRun(PrefectSignal):
"""Signal raised when exiting a flow early for nonblocking pauses"""
12 changes: 6 additions & 6 deletions src/prefect/orion/api/flow_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ async def resume_flow_run(
"""
Resume a paused flow run.
"""
now = pendulum.now()

async with db.session_context(begin_transaction=True) as session:
flow_run = await models.flow_runs.read_flow_run(session, flow_run_id)
state = flow_run.state
Expand Down Expand Up @@ -226,13 +228,11 @@ async def resume_flow_run(
orchestration_parameters=orchestration_parameters,
)

# only set the 201 when a new state was created
if orchestration_result.status == schemas.responses.SetStateStatus.WAIT:
response.status_code = status.HTTP_200_OK
elif orchestration_result.status == schemas.responses.SetStateStatus.ABORT:
response.status_code = status.HTTP_200_OK
else:
# set the 201 if a new state was created
if orchestration_result.state and orchestration_result.state.timestamp >= now:
response.status_code = status.HTTP_201_CREATED
else:
response.status_code = status.HTTP_200_OK

return orchestration_result

Expand Down
65 changes: 50 additions & 15 deletions src/prefect/orion/orchestration/core_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def priority():
return [
CacheRetrieval,
HandleTaskTerminalStateTransitions,
PreventRunningTasksFromStoppedFlows,
PreventRedundantTransitions,
SecureTaskConcurrencySlots, # retrieve cached states even if slots are full
CopyScheduledTime,
Expand Down Expand Up @@ -95,7 +96,7 @@ class SecureTaskConcurrencySlots(BaseOrchestrationRule):
"""

FROM_STATES = ALL_ORCHESTRATION_STATES
TO_STATES = [states.StateType.RUNNING]
TO_STATES = [StateType.RUNNING]

async def before_transition(
self,
Expand Down Expand Up @@ -190,7 +191,7 @@ class CacheInsertion(BaseOrchestrationRule):
"""

FROM_STATES = ALL_ORCHESTRATION_STATES
TO_STATES = [states.StateType.COMPLETED]
TO_STATES = [StateType.COMPLETED]

@inject_db
async def after_transition(
Expand Down Expand Up @@ -223,7 +224,7 @@ class CacheRetrieval(BaseOrchestrationRule):
"""

FROM_STATES = ALL_ORCHESTRATION_STATES
TO_STATES = [states.StateType.RUNNING]
TO_STATES = [StateType.RUNNING]

@inject_db
async def before_transition(
Expand Down Expand Up @@ -269,8 +270,8 @@ class RetryFailedFlows(BaseOrchestrationRule):
instructed to transition into a scheduled state to retry flow execution.
"""

FROM_STATES = [states.StateType.RUNNING]
TO_STATES = [states.StateType.FAILED]
FROM_STATES = [StateType.RUNNING]
TO_STATES = [StateType.FAILED]

async def before_transition(
self,
Expand Down Expand Up @@ -337,8 +338,8 @@ class RetryFailedTasks(BaseOrchestrationRule):
instructed to transition into a scheduled state to retry task execution.
"""

FROM_STATES = [states.StateType.RUNNING]
TO_STATES = [states.StateType.FAILED]
FROM_STATES = [StateType.RUNNING]
TO_STATES = [StateType.FAILED]

async def before_transition(
self,
Expand Down Expand Up @@ -368,7 +369,7 @@ class RenameReruns(BaseOrchestrationRule):
"""

FROM_STATES = ALL_ORCHESTRATION_STATES
TO_STATES = [states.StateType.RUNNING]
TO_STATES = [StateType.RUNNING]

async def before_transition(
self,
Expand All @@ -392,8 +393,8 @@ class CopyScheduledTime(BaseOrchestrationRule):
on the scheduled state will be ignored.
"""

FROM_STATES = [states.StateType.SCHEDULED]
TO_STATES = [states.StateType.PENDING]
FROM_STATES = [StateType.SCHEDULED]
TO_STATES = [StateType.PENDING]

async def before_transition(
self,
Expand All @@ -418,8 +419,8 @@ class WaitForScheduledTime(BaseOrchestrationRule):
before attempting the transition again.
"""

FROM_STATES = [states.StateType.SCHEDULED, states.StateType.PENDING]
TO_STATES = [states.StateType.RUNNING]
FROM_STATES = [StateType.SCHEDULED, StateType.PENDING]
TO_STATES = [StateType.RUNNING]

async def before_transition(
self,
Expand All @@ -446,7 +447,7 @@ class HandlePausingFlows(BaseOrchestrationRule):
"""

FROM_STATES = ALL_ORCHESTRATION_STATES
TO_STATES = [states.StateType.PAUSED]
TO_STATES = [StateType.PAUSED]

async def before_transition(
self,
Expand Down Expand Up @@ -504,7 +505,7 @@ class HandleResumingPausedFlows(BaseOrchestrationRule):
Governs runs attempting to leave a Paused state
"""

FROM_STATES = [states.StateType.PAUSED]
FROM_STATES = [StateType.PAUSED]
TO_STATES = ALL_ORCHESTRATION_STATES

async def before_transition(
Expand Down Expand Up @@ -559,7 +560,7 @@ class UpdateFlowRunTrackerOnTasks(BaseOrchestrationRule):
"""

FROM_STATES = ALL_ORCHESTRATION_STATES
TO_STATES = [states.StateType.RUNNING]
TO_STATES = [StateType.RUNNING]

async def after_transition(
self,
Expand Down Expand Up @@ -717,3 +718,37 @@ async def before_transition(
await self.abort_transition(
reason=f"This run cannot transition to the {proposed_state_type} state from the {initial_state_type} state."
)


class PreventRunningTasksFromStoppedFlows(BaseOrchestrationRule):
    """
    Blocks task runs from entering a running state unless their flow run is running.

    The rule applies to every transition whose proposed state type is RUNNING:

    - if the enclosing flow run has no state, or a state that is neither paused
      nor running, the transition is aborted;
    - if the enclosing flow run is paused, the transition is rejected and a
      `Paused` state named "NotReady" is offered instead;
    - if the enclosing flow run is running, the transition is left untouched.
    """

    FROM_STATES = ALL_ORCHESTRATION_STATES
    TO_STATES = [StateType.RUNNING]

    async def before_transition(
        self,
        initial_state: Optional[states.State],
        proposed_state: Optional[states.State],
        context: TaskOrchestrationContext,
    ) -> None:
        parent_run = await context.flow_run()
        parent_state = parent_run.state

        if parent_state is not None:
            if parent_state.type == StateType.PAUSED:
                # hand back a paused state instead of the proposed running state
                paused_message = f"The flow is paused, new tasks can execute after resuming flow run: {parent_run.id}."
                await self.reject_transition(
                    state=states.Paused(name="NotReady"),
                    reason=paused_message,
                )
                return

            if parent_state.type == StateType.RUNNING:
                # the flow is running, so the task may start
                return

        # Reached when the flow run has no state at all or is in any state other
        # than paused/running: task runners should abort task run execution.
        await self.abort_transition(
            reason="The enclosing flow must be running to begin task execution."
        )
10 changes: 7 additions & 3 deletions src/prefect/orion/schemas/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,13 @@ def Paused(
"Cannot supply both a pause_expiration_time and timeout_seconds"
)

state_details.pause_timeout = pause_expiration_time or (
pendulum.now("UTC") + pendulum.Duration(seconds=timeout_seconds)
)
if pause_expiration_time is None and timeout_seconds is None:
pass
else:
state_details.pause_timeout = pause_expiration_time or (
pendulum.now("UTC") + pendulum.Duration(seconds=timeout_seconds)
)

state_details.pause_reschedule = reschedule
state_details.pause_key = pause_key

Expand Down
10 changes: 7 additions & 3 deletions src/prefect/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,9 +529,13 @@ def Paused(
"Cannot supply both a pause_expiration_time and timeout_seconds"
)

state_details.pause_timeout = pause_expiration_time or (
pendulum.now("UTC") + pendulum.Duration(seconds=timeout_seconds)
)
if pause_expiration_time is None and timeout_seconds is None:
pass
else:
state_details.pause_timeout = pause_expiration_time or (
pendulum.now("UTC") + pendulum.Duration(seconds=timeout_seconds)
)

state_details.pause_reschedule = reschedule
state_details.pause_key = pause_key

Expand Down
8 changes: 8 additions & 0 deletions tests/fixtures/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,7 @@ async def initializer(
run_type,
initial_state_type,
proposed_state_type,
initial_flow_run_state_type=None,
run_override=None,
run_tags=None,
initial_details=None,
Expand Down Expand Up @@ -556,6 +557,13 @@ async def initializer(
context = FlowOrchestrationContext
state_constructor = commit_flow_run_state
elif run_type == "task":
if initial_flow_run_state_type:
flow_state_constructor = commit_flow_run_state
await flow_state_constructor(
session,
flow_run,
initial_flow_run_state_type,
)
task_run = await models.task_runs.create_task_run(
session=session,
task_run=schemas.actions.TaskRunCreate(
Expand Down
12 changes: 12 additions & 0 deletions tests/orion/api/test_task_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,12 @@ async def test_delete_task_run_returns_404_if_does_not_exist(self, client):

class TestSetTaskRunState:
async def test_set_task_run_state(self, task_run, client, session):
# first ensure the parent flow run is in a running state
await client.post(
f"/flow_runs/{task_run.flow_run_id}/set_state",
json=dict(state=dict(type="RUNNING")),
)

response = await client.post(
f"/task_runs/{task_run.id}/set_state",
json=dict(state=dict(type="RUNNING", name="Test State")),
Expand All @@ -317,6 +323,12 @@ async def test_setting_task_run_state_twice_aborts(
# this test ensures that a 2nd agent cannot re-propose a state that's already
# been set

# first ensure the parent flow run is in a running state
await client.post(
f"/flow_runs/{task_run.flow_run_id}/set_state",
json=dict(state=dict(type="RUNNING")),
)

response = await client.post(
f"/task_runs/{task_run.id}/set_state",
json=dict(state=dict(type=proposed_state, name="Test State")),
Expand Down
Loading

0 comments on commit 88ccfa7

Please sign in to comment.