diff --git a/src/prefect/agent.py b/src/prefect/agent.py index c40c3e6fb359..e409780edf20 100644 --- a/src/prefect/agent.py +++ b/src/prefect/agent.py @@ -370,6 +370,7 @@ async def _mark_flow_run_as_cancelled( ) -> None: state_updates = state_updates or {} state_updates.setdefault("name", "Cancelled") + state_updates.setdefault("type", StateType.CANCELLED) state = flow_run.state.copy(update=state_updates) await self.client.set_flow_run_state(flow_run.id, state, force=True) diff --git a/tests/agent/test_agent_run_cancellation.py b/tests/agent/test_agent_run_cancellation.py index 1e6d2bec1d92..02b31fdb7905 100644 --- a/tests/agent/test_agent_run_cancellation.py +++ b/tests/agent/test_agent_run_cancellation.py @@ -11,7 +11,15 @@ from prefect.infrastructure.base import Infrastructure from prefect.orion.database.orm_models import ORMDeployment from prefect.orion.schemas.core import Deployment -from prefect.states import Cancelled, Cancelling, Completed, Pending, Running, Scheduled +from prefect.states import ( + Cancelled, + Cancelling, + Completed, + Pending, + Running, + Scheduled, + StateType, +) from prefect.testing.utilities import AsyncMock from prefect.utilities.dispatch import get_registry_for_type @@ -286,7 +294,7 @@ async def test_agent_cancel_run_with_missing_infrastructure_pid( @pytest.mark.parametrize( "cancelling_constructor", [legacy_named_cancelling_state, Cancelling] ) -async def test_agent_cancel_run_updates_state_name( +async def test_agent_cancel_run_updates_state_type( orion_client: OrionClient, deployment: ORMDeployment, cancelling_constructor, @@ -304,7 +312,7 @@ async def test_agent_cancel_run_updates_state_name( await agent.check_for_cancelled_flow_runs() post_flow_run = await orion_client.read_flow_run(flow_run.id) - assert post_flow_run.state.name == "Cancelled" + assert post_flow_run.state.type == StateType.CANCELLED @pytest.mark.usefixtures("mock_infrastructure_kill") @@ -316,7 +324,7 @@ async def test_agent_cancel_run_preserves_other_state_properties( deployment: ORMDeployment, cancelling_constructor, ): - expected_changed_fields = {"name", "timestamp", "id"} + expected_changed_fields = {"type", "name", "timestamp", "id"} flow_run = await orion_client.create_flow_run_from_deployment( deployment.id,