Skip to content

Commit

Permalink
Refactor pause_flow_run for consistency with engine state handling (#…
Browse files Browse the repository at this point in the history
…7857)

Co-authored-by: Chris Pickett <chris.pickett@prefect.io>
Co-authored-by: Dustin Ngo <dustin.ngo@gmail.com>
  • Loading branch information
3 people committed Feb 3, 2023
1 parent 8d3c824 commit 453c47c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 11 deletions.
4 changes: 0 additions & 4 deletions src/prefect/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,3 @@ class NotPausedError(PrefectException):

class FlowPauseTimeout(PrefectException):
"""Raised when a flow pause times out"""


class PausedRun(PrefectSignal):
"""Signal raised when exiting a flow early for nonblocking pauses"""
25 changes: 18 additions & 7 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,23 +553,29 @@ async def test_paused_flows_do_not_block_execution_with_reschedule_flag(
):
frc = partial(FlowRunCreate, deployment_id=deployment.id)
monkeypatch.setattr("prefect.client.orion.schemas.actions.FlowRunCreate", frc)
flow_run_id = None

@task
async def foo():
return 42

@flow(task_runner=SequentialTaskRunner())
async def pausing_flow_without_blocking():
nonlocal flow_run_id
flow_run_id = get_run_context().flow_run.id
x = await foo.submit()
y = await foo.submit()
await pause_flow_run(timeout=20, reschedule=True)
z = await foo(wait_for=[x])
alpha = await foo(wait_for=[y])
omega = await foo(wait_for=[x, y])
assert False, "This line should not be reached"

flow_run_state = await pausing_flow_without_blocking(return_state=True)
assert flow_run_state.is_paused()
flow_run_id = flow_run_state.state_details.flow_run_id
with pytest.raises(Pause):
await pausing_flow_without_blocking(return_state=True)

flow_run = await orion_client.read_flow_run(flow_run_id)
assert flow_run.state.is_paused()
task_runs = await orion_client.read_task_runs(
flow_run_filter=FlowRunFilter(id={"any_": [flow_run_id]})
)
Expand All @@ -594,31 +600,36 @@ async def pausing_flow_without_blocking():
alpha = await foo(wait_for=[y])
omega = await foo(wait_for=[x, y])

with pytest.raises(PausedRun):
with pytest.raises(Pause):
await pausing_flow_without_blocking()

async def test_paused_flows_can_be_resumed_then_rescheduled(
self, orion_client, deployment, monkeypatch
):
frc = partial(FlowRunCreate, deployment_id=deployment.id)
monkeypatch.setattr("prefect.client.orion.schemas.actions.FlowRunCreate", frc)
flow_run_id = None

@task
async def foo():
return 42

@flow(task_runner=SequentialTaskRunner())
async def pausing_flow_without_blocking():
nonlocal flow_run_id
flow_run_id = get_run_context().flow_run.id
x = await foo.submit()
y = await foo.submit()
await pause_flow_run(timeout=20, reschedule=True)
z = await foo(wait_for=[x])
alpha = await foo(wait_for=[y])
omega = await foo(wait_for=[x, y])

flow_run_state = await pausing_flow_without_blocking(return_state=True)
assert flow_run_state.is_paused()
flow_run_id = flow_run_state.state_details.flow_run_id
with pytest.raises(Pause):
await pausing_flow_without_blocking()

flow_run = await orion_client.read_flow_run(flow_run_id)
assert flow_run.state.is_paused()

await resume_flow_run(flow_run_id)
flow_run = await orion_client.read_flow_run(flow_run_id)
Expand Down

0 comments on commit 453c47c

Please sign in to comment.