diff --git a/docs/user/reference/openapi.yaml b/docs/user/reference/openapi.yaml index b0a79193b..1bca4e3c6 100644 --- a/docs/user/reference/openapi.yaml +++ b/docs/user/reference/openapi.yaml @@ -95,6 +95,10 @@ components: type: boolean new_state: $ref: '#/components/schemas/WorkerState' + reason: + description: The reason for the current run to be aborted + title: Reason + type: string required: - new_state title: StateChangeRequest @@ -292,6 +296,29 @@ paths: description: Validation Error summary: Submit Task /tasks/{task_id}: + delete: + operationId: delete_submitted_task_tasks__task_id__delete + parameters: + - in: path + name: task_id + required: true + schema: + title: Task Id + type: string + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/TaskResponse' + description: Successful Response + '422': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + description: Validation Error + summary: Delete Submitted Task get: description: Retrieve a task operationId: get_task_tasks__task_id__get @@ -329,17 +356,20 @@ paths: description: Successful Response summary: Get State put: - description: 'Request that the worker is put into a particular state. - - Returns the state of the worker at the end of the call. - - If the worker is PAUSED, new_state may be RUNNING to resume. - - If the worker is RUNNING, new_state may be PAUSED to pause and - - defer may be True to defer the pause until the new checkpoint. 
- - All other values of new_state will result in 400 "Bad Request"' + description: "Request that the worker is put into a particular state.\nReturns\ + \ the state of the worker at the end of the call.\n\n- **The following transitions\ + \ are allowed and return 202: Accepted**\n- If the worker is **PAUSED**, new_state\ + \ may be **RUNNING** to resume.\n- If the worker is **RUNNING**, new_state\ + \ may be **PAUSED** to pause:\n - If defer is False (default): pauses and\ + \ rewinds to the previous checkpoint\n - If defer is True: waits until\ + \ the next checkpoint to pause\n - **If the task has no checkpoints, the\ + \ task will instead be Aborted**\n- If the worker is **RUNNING/PAUSED**, new_state\ + \ may be **STOPPING** to stop.\n Stop marks any currently open Runs in\ + \ the Task as a success and ends the task.\n- If the worker is **RUNNING/PAUSED**,\ + \ new_state may be **ABORTING** to abort.\n Abort marks any currently open\ + \ Runs in the Task as a Failure and ends the task.\n - If reason is set,\ + \ the reason will be passed as the reason for the Run failure.\n- **All other\ + \ transitions return 400: Bad Request**" operationId: set_state_worker_state_put requestBody: content: @@ -349,14 +379,14 @@ paths: required: true responses: '202': - description: Accepted - detail: Transition requested - '400': content: application/json: schema: $ref: '#/components/schemas/WorkerState' description: Successful Response + detail: Transition requested + '400': + description: Bad Request detail: Transition not allowed '422': content: diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index f92b5d76d..f018be586 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -202,6 +202,32 @@ def resume(obj: dict) -> None: pprint(client.set_state(WorkerState.RUNNING)) +@controller.command(name="abort") +@check_connection +@click.argument("reason", type=str, required=False) +@click.pass_obj +def abort(obj: dict, reason: Optional[str] = None) -> None: + 
""" + Abort the execution of the current task, marking any ongoing runs as failed, + with optional reason + """ + + client: BlueapiRestClient = obj["rest_client"] + pprint(client.cancel_current_task(state=WorkerState.ABORTING, reason=reason)) + + +@controller.command(name="stop") +@check_connection +@click.pass_obj +def stop(obj: dict) -> None: + """ + Stop the execution of the current task, marking any ongoing runs as success + """ + + client: BlueapiRestClient = obj["rest_client"] + pprint(client.cancel_current_task(state=WorkerState.STOPPING)) + + # helper function def process_event_after_finished(event: WorkerEvent, logger: logging.Logger): if event.is_error(): diff --git a/src/blueapi/cli/rest.py b/src/blueapi/cli/rest.py index de9c12427..48dcc1e7a 100644 --- a/src/blueapi/cli/rest.py +++ b/src/blueapi/cli/rest.py @@ -72,6 +72,11 @@ def create_task(self, task: RunPlan) -> TaskResponse: data=task.dict(), ) + def delete_task(self, task_id: str) -> TaskResponse: + return self._request_and_deserialize( + f"/tasks/{task_id}", TaskResponse, method="DELETE" + ) + def update_worker_task(self, task: WorkerTask) -> WorkerTask: return self._request_and_deserialize( "/worker/task", @@ -80,6 +85,18 @@ def update_worker_task(self, task: WorkerTask) -> WorkerTask: data=task.dict(), ) + def cancel_current_task( + self, + state: Literal[WorkerState.ABORTING, WorkerState.STOPPING], + reason: Optional[str] = None, + ): + return self._request_and_deserialize( + "/worker/state", + target_type=WorkerState, + method="PUT", + data={"new_state": state, "reason": reason}, + ) + def _request_and_deserialize( self, suffix: str, diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index c67d6db6d..697cdd2cf 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -4,6 +4,7 @@ from fastapi import Body, Depends, FastAPI, HTTPException, Request, Response, status from pydantic import ValidationError from starlette.responses import JSONResponse +from 
super_state_machine.errors import TransitionError from blueapi.config import ApplicationConfig from blueapi.worker import RunPlan, TrackableTask, WorkerState @@ -108,6 +109,14 @@ def submit_task( ) +@app.delete("/tasks/{task_id}", status_code=status.HTTP_200_OK) +def delete_submitted_task( + task_id: str, + handler: Handler = Depends(get_handler), +) -> TaskResponse: + return TaskResponse(task_id=handler.worker.clear_task(task_id)) + + @app.put( "/worker/task", response_model=WorkerTask, @@ -156,14 +165,22 @@ def get_state(handler: Handler = Depends(get_handler)) -> WorkerState: # Map of current_state: allowed new_states _ALLOWED_TRANSITIONS: Dict[WorkerState, Set[WorkerState]] = { - WorkerState.RUNNING: {WorkerState.PAUSED}, - WorkerState.PAUSED: {WorkerState.RUNNING}, + WorkerState.RUNNING: { + WorkerState.PAUSED, + WorkerState.ABORTING, + WorkerState.STOPPING, + }, + WorkerState.PAUSED: { + WorkerState.RUNNING, + WorkerState.ABORTING, + WorkerState.STOPPING, + }, } @app.put( "/worker/state", - status_code=status.HTTP_400_BAD_REQUEST, + status_code=status.HTTP_202_ACCEPTED, responses={ status.HTTP_400_BAD_REQUEST: {"detail": "Transition not allowed"}, status.HTTP_202_ACCEPTED: {"detail": "Transition requested"}, @@ -177,10 +194,19 @@ def set_state( """ Request that the worker is put into a particular state. Returns the state of the worker at the end of the call. - If the worker is PAUSED, new_state may be RUNNING to resume. - If the worker is RUNNING, new_state may be PAUSED to pause and - defer may be True to defer the pause until the new checkpoint. - All other values of new_state will result in 400 "Bad Request" + + - **The following transitions are allowed and return 202: Accepted** + - If the worker is **PAUSED**, new_state may be **RUNNING** to resume. 
+ - If the worker is **RUNNING**, new_state may be **PAUSED** to pause: + - If defer is False (default): pauses and rewinds to the previous checkpoint + - If defer is True: waits until the next checkpoint to pause + - **If the task has no checkpoints, the task will instead be Aborted** + - If the worker is **RUNNING/PAUSED**, new_state may be **STOPPING** to stop. + Stop marks any currently open Runs in the Task as a success and ends the task. + - If the worker is **RUNNING/PAUSED**, new_state may be **ABORTING** to abort. + Abort marks any currently open Runs in the Task as a Failure and ends the task. + - If reason is set, the reason will be passed as the reason for the Run failure. + - **All other transitions return 400: Bad Request** """ current_state = handler.worker.state new_state = state_change_request.new_state @@ -188,11 +214,21 @@ def set_state( current_state in _ALLOWED_TRANSITIONS and new_state in _ALLOWED_TRANSITIONS[current_state] ): - response.status_code = status.HTTP_202_ACCEPTED if new_state == WorkerState.PAUSED: handler.worker.pause(defer=state_change_request.defer) elif new_state == WorkerState.RUNNING: handler.worker.resume() + elif new_state in {WorkerState.ABORTING, WorkerState.STOPPING}: + try: + handler.worker.cancel_active_task( + state_change_request.new_state is WorkerState.ABORTING, + state_change_request.reason, + ) + except TransitionError: + response.status_code = status.HTTP_400_BAD_REQUEST + else: + response.status_code = status.HTTP_400_BAD_REQUEST + return handler.worker.state diff --git a/src/blueapi/service/model.py b/src/blueapi/service/model.py index 52c567d2f..e7e25b894 100644 --- a/src/blueapi/service/model.py +++ b/src/blueapi/service/model.py @@ -94,7 +94,7 @@ class WorkerTask(BlueapiBaseModel): ) @classmethod - def of_worker(self, worker: Worker) -> "WorkerTask": + def of_worker(cls, worker: Worker) -> "WorkerTask": active = worker.get_active_task() if active is not None: return WorkerTask(task_id=active.task_id) @@ 
-112,3 +112,7 @@ class StateChangeRequest(BlueapiBaseModel): description="Should worker defer Pausing until the next checkpoint", default=False, ) + reason: Optional[str] = Field( + description="The reason for the current run to be aborted", + default=None, + ) diff --git a/src/blueapi/worker/reworker.py b/src/blueapi/worker/reworker.py index 676d234d4..7444abd79 100644 --- a/src/blueapi/worker/reworker.py +++ b/src/blueapi/worker/reworker.py @@ -7,6 +7,7 @@ from typing import Any, Dict, Iterable, List, Mapping, Optional, Set, Union from bluesky.protocols import Status +from super_state_machine.errors import TransitionError from blueapi.core import ( BlueskyContext, @@ -90,12 +91,24 @@ def __init__( self._stopped = Event() self._stopped.set() - def clear_task(self, task_id: str) -> bool: - if task_id in self._pending_tasks: - del self._pending_tasks[task_id] - return True + def clear_task(self, task_id: str) -> str: + task = self._pending_tasks.pop(task_id) + return task.task_id + + def cancel_active_task( + self, + failure: bool = False, + reason: Optional[str] = None, + ) -> str: + if self._current is None: + # Persuades mypy that self._current is not None + # We only allow this method to be called if a Plan is active + raise TransitionError("Attempted to cancel while no active Task") + if failure: + self._ctx.run_engine.abort(reason) else: - return False + self._ctx.run_engine.stop() + return self._current.task_id def get_pending_tasks(self) -> List[TrackableTask[Task]]: return list(self._pending_tasks.values()) diff --git a/src/blueapi/worker/worker.py b/src/blueapi/worker/worker.py index 7e298de74..5350a8cbc 100644 --- a/src/blueapi/worker/worker.py +++ b/src/blueapi/worker/worker.py @@ -63,14 +63,25 @@ def get_active_task(self) -> Optional[TrackableTask[T]]: """ @abstractmethod - def clear_task(self, task_id: str) -> bool: + def clear_task(self, task_id: str) -> str: """ Remove a pending task from the worker Args: task_id: The ID of the task to be removed 
Returns: - bool: True if the task existed in the first place + task_id of the removed task + """ + + @abstractmethod + def cancel_active_task( + self, + failure: bool = False, + reason: Optional[str] = None, + ) -> str: + """ + Remove the currently active task from the worker if there is one + Returns the task_id of the active task """ @abstractmethod diff --git a/tests/service/test_rest_api.py b/tests/service/test_rest_api.py index b5df21e47..378f60a3d 100644 --- a/tests/service/test_rest_api.py +++ b/tests/service/test_rest_api.py @@ -1,7 +1,7 @@ import json from dataclasses import dataclass from typing import Optional -from unittest.mock import MagicMock +from unittest.mock import MagicMock, call import pytest from bluesky.run_engine import RunEngineStateMachine @@ -256,3 +256,130 @@ def test_pause_and_resume(mockable_state_machine: Handler, client: TestClient) - assert re.request_pause.call_count == 1 # type: ignore assert re.resume.call_count == 1 # type: ignore assert client.get("/worker/state").text == f'"{WorkerState.RUNNING.name}"' + + +def test_clear_pending_task_no_longer_pending(handler: Handler, client: TestClient): + response = client.post("/tasks", json=_TASK.dict()) + task_id = response.json()["task_id"] + + pending = handler.worker.get_pending_task(task_id) + assert pending is not None + assert pending.task == _TASK + + delete_response = client.delete(f"/tasks/{task_id}") + assert delete_response.status_code is status.HTTP_200_OK + assert not handler.worker.get_pending_tasks() + assert handler.worker.get_pending_task(task_id) is None + + +def test_clear_not_pending_task_not_found(handler: Handler, client: TestClient): + response = client.post("/tasks", json=_TASK.dict()) + task_id = response.json()["task_id"] + + pending = handler.worker.get_pending_task(task_id) + assert pending is not None + assert pending.task == _TASK + + delete_response = client.delete("/tasks/wrong-task-id") + assert delete_response.status_code is status.HTTP_404_NOT_FOUND + 
pending = handler.worker.get_pending_task(task_id) + assert pending is not None + assert pending.task == _TASK + + +def test_clear_when_empty(handler: Handler, client: TestClient): + pending = handler.worker.get_pending_tasks() + assert not pending + + delete_response = client.delete("/tasks/wrong-task-id") + assert delete_response.status_code is status.HTTP_404_NOT_FOUND + assert not handler.worker.get_pending_tasks() + + +@pytest.mark.parametrize( + "worker_state,stops,aborts", + [(WorkerState.STOPPING, 1, 0), (WorkerState.ABORTING, 0, 1)], +) +def test_delete_running_task( + mockable_state_machine: Handler, + client: TestClient, + worker_state: WorkerState, + stops: int, + aborts: int, +): + stop = mockable_state_machine.context.run_engine.stop = MagicMock() # type: ignore + abort = ( + mockable_state_machine.context.run_engine.abort # type: ignore + ) = MagicMock() + + def start_task(_: str): + mockable_state_machine.worker._current = ( # type: ignore + mockable_state_machine.worker.get_pending_task(task_id) + ) + mockable_state_machine.worker._on_state_change( # type: ignore + RunEngineStateMachine.States.RUNNING + ) + + mockable_state_machine.worker.begin_task = start_task # type: ignore + response = client.post("/tasks", json=_TASK.dict()) + task_id = response.json()["task_id"] + + task_json = {"task_id": task_id} + client.put("/worker/task", json=task_json) + + active_task = mockable_state_machine.worker.get_active_task() + assert active_task is not None + assert active_task.task_id == task_id + + response = client.put("/worker/state", json={"new_state": worker_state.name}) + assert response.status_code is status.HTTP_202_ACCEPTED + assert stop.call_count is stops + assert abort.call_count is aborts + + +def test_reason_passed_to_abort(mockable_state_machine: Handler, client: TestClient): + abort = ( + mockable_state_machine.context.run_engine.abort # type: ignore + ) = MagicMock() + + def start_task(_: str): + mockable_state_machine.worker._current = ( # 
type: ignore + mockable_state_machine.worker.get_pending_task(task_id) + ) + mockable_state_machine.worker._on_state_change( # type: ignore + RunEngineStateMachine.States.RUNNING + ) + + mockable_state_machine.worker.begin_task = start_task # type: ignore + response = client.post("/tasks", json=_TASK.dict()) + task_id = response.json()["task_id"] + + task_json = {"task_id": task_id} + client.put("/worker/task", json=task_json) + + active_task = mockable_state_machine.worker.get_active_task() + assert active_task is not None + assert active_task.task_id == task_id + + response = client.put( + "/worker/state", json={"new_state": WorkerState.ABORTING.name, "reason": "foo"} + ) + assert response.status_code is status.HTTP_202_ACCEPTED + assert abort.call_args == call("foo") + + +@pytest.mark.parametrize( + "worker_state", + [WorkerState.ABORTING, WorkerState.STOPPING], +) +def test_current_complete_returns_400( + mockable_state_machine: Handler, client: TestClient, worker_state: WorkerState +): + mockable_state_machine.worker._current = MagicMock() # type: ignore + mockable_state_machine.worker._current.is_complete = True # type: ignore + + # As _current.is_complete, necessarily state of run_engine is IDLE + response = client.put( + "/worker/state", json={"new_state": WorkerState.ABORTING.name, "reason": "foo"} + ) + assert response.status_code is status.HTTP_400_BAD_REQUEST diff --git a/tests/worker/test_reworker.py b/tests/worker/test_reworker.py index bb1932b9d..6dad9b132 100644 --- a/tests/worker/test_reworker.py +++ b/tests/worker/test_reworker.py @@ -169,7 +169,8 @@ def test_clear_task(worker: Worker) -> None: def test_clear_nonexistant_task(worker: Worker) -> None: - assert not worker.clear_task("foo") + with pytest.raises(KeyError): + worker.clear_task("foo") def test_does_not_allow_simultaneous_running_tasks(