Skip to content

Commit

Permalink
Add the unstarted tasks endpoint (#508)
Browse files Browse the repository at this point in the history
  • Loading branch information
stan-dot authored Jun 20, 2024
1 parent d18fcf8 commit 31c94c5
Show file tree
Hide file tree
Showing 13 changed files with 321 additions and 14 deletions.
40 changes: 40 additions & 0 deletions docs/reference/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,20 @@ components:
- task_id
title: TaskResponse
type: object
TasksListResponse:
additionalProperties: false
description: Diagnostic information on the tasks
properties:
tasks:
description: List of tasks
items:
$ref: '#/components/schemas/TrackableTask'
title: Tasks
type: array
required:
- tasks
title: TasksListResponse
type: object
TrackableTask:
additionalProperties: false
description: A representation of a task that the worker recognizes
Expand Down Expand Up @@ -307,6 +321,32 @@ paths:
description: Validation Error
summary: Get Plan By Name
/tasks:
get:
description: 'Retrieve tasks based on their status.
The status of a newly created task is ''unstarted''.'
operationId: get_tasks_tasks_get
parameters:
- in: query
name: task_status
required: false
schema:
title: Task Status
type: string
responses:
'200':
content:
application/json:
schema:
$ref: '#/components/schemas/TasksListResponse'
description: Successful Response
'422':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
description: Validation Error
summary: Get Tasks
post:
description: Submit a task to the worker.
operationId: submit_task_tasks_post
Expand Down
5 changes: 4 additions & 1 deletion src/blueapi/service/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from blueapi.messaging.base import MessagingTemplate
from blueapi.service.handler_base import BlueskyHandler
from blueapi.service.model import DeviceModel, PlanModel, WorkerTask
from blueapi.worker.event import WorkerState
from blueapi.worker.event import TaskStatusEnum, WorkerState
from blueapi.worker.reworker import TaskWorker
from blueapi.worker.task import Task
from blueapi.worker.worker import TrackableTask, Worker
Expand Down Expand Up @@ -115,6 +115,9 @@ def begin_task(self, task: WorkerTask) -> WorkerTask:
self._worker.begin_task(task.task_id)
return task

def get_tasks_by_status(self, status: TaskStatusEnum) -> list[TrackableTask]:
return self._worker.get_tasks_by_status(status)

@property
def active_task(self) -> TrackableTask | None:
return self._worker.get_active_task()
Expand Down
12 changes: 11 additions & 1 deletion src/blueapi/service/handler_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod

from blueapi.service.model import DeviceModel, PlanModel, WorkerTask
from blueapi.worker.event import WorkerState
from blueapi.worker.event import TaskStatusEnum, WorkerState
from blueapi.worker.task import Task
from blueapi.worker.worker import TrackableTask

Expand Down Expand Up @@ -49,6 +49,16 @@ def clear_task(self, task_id: str) -> str:
def begin_task(self, task: WorkerTask) -> WorkerTask:
"""Trigger a task. Will fail if the worker is busy"""

@abstractmethod
def get_tasks_by_status(self, status: TaskStatusEnum) -> list[TrackableTask]:
"""
Retrieve a list of tasks based on their status.
Args:
str: The status to filter tasks by.
Returns:
list[TrackableTask]: A list of tasks that match the given status.
"""

@property
@abstractmethod
def active_task(self) -> TrackableTask | None:
Expand Down
34 changes: 34 additions & 0 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from blueapi.config import ApplicationConfig
from blueapi.worker import Task, TrackableTask, WorkerState
from blueapi.worker.event import TaskStatusEnum

from .handler_base import BlueskyHandler
from .model import (
Expand All @@ -26,6 +27,7 @@
PlanResponse,
StateChangeRequest,
TaskResponse,
TasksListResponse,
WorkerTask,
)
from .subprocess_handler import SubprocessHandler
Expand Down Expand Up @@ -173,6 +175,38 @@ def delete_submitted_task(
return TaskResponse(task_id=handler.clear_task(task_id))


def validate_task_status(v: str) -> TaskStatusEnum:
v_upper = v.upper()
if v_upper not in TaskStatusEnum.__members__:
raise ValueError("Invalid status query parameter")
return TaskStatusEnum(v_upper)


@app.get("/tasks", response_model=TasksListResponse, status_code=status.HTTP_200_OK)
def get_tasks(
task_status: str | None = None,
handler: BlueskyHandler = Depends(get_handler),
) -> TasksListResponse:
"""
Retrieve tasks based on their status.
The status of a newly created task is 'unstarted'.
"""
tasks = []
if task_status:
try:
desired_status = validate_task_status(task_status)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid status query parameter",
) from e

tasks = handler.get_tasks_by_status(desired_status)
else:
tasks = handler.tasks
return TasksListResponse(tasks=tasks)


@app.put(
"/worker/task",
response_model=WorkerTask,
Expand Down
9 changes: 9 additions & 0 deletions src/blueapi/service/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from blueapi.core import BLUESKY_PROTOCOLS, Device, Plan
from blueapi.utils import BlueapiBaseModel
from blueapi.worker import Worker, WorkerState
from blueapi.worker.worker import TrackableTask

_UNKNOWN_NAME = "UNKNOWN"

Expand All @@ -33,6 +34,14 @@ def _protocol_names(device: Device) -> Iterable[str]:
yield protocol.__name__


class TasksListResponse(BlueapiBaseModel):
"""
Diagnostic information on the tasks
"""

tasks: list[TrackableTask] = Field(description="List of tasks")


class DeviceRequest(BlueapiBaseModel):
"""
A query for devices
Expand Down
9 changes: 8 additions & 1 deletion src/blueapi/service/subprocess_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from blueapi.service.handler import get_handler, setup_handler, teardown_handler
from blueapi.service.handler_base import BlueskyHandler, HandlerNotStartedError
from blueapi.service.model import DeviceModel, PlanModel, WorkerTask
from blueapi.worker.event import WorkerState
from blueapi.worker.event import TaskStatusEnum, WorkerState
from blueapi.worker.task import Task
from blueapi.worker.worker import TrackableTask

Expand Down Expand Up @@ -82,6 +82,9 @@ def submit_task(self, task: Task) -> str:
def clear_task(self, task_id: str) -> str:
return self._run_in_subprocess(clear_task_by_id, [task_id])

def get_tasks_by_status(self, task_status: TaskStatusEnum) -> list[TrackableTask]:
return self._run_in_subprocess(get_tasks_by_status, [task_status])

def begin_task(self, task: WorkerTask) -> WorkerTask:
return self._run_in_subprocess(begin_task, [task])

Expand Down Expand Up @@ -137,6 +140,10 @@ def submit_task(task: Task) -> str:
return get_handler().submit_task(task)


def get_tasks_by_status(task_status: TaskStatusEnum) -> list[TrackableTask]:
return get_handler().get_tasks_by_status(task_status)


def clear_task_by_id(task_id: str) -> str:
return get_handler().clear_task(task_id)

Expand Down
8 changes: 8 additions & 0 deletions src/blueapi/worker/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@
RawRunEngineState = type[PropertyMachine | ProxyString | str]


# NOTE this is interim until refactor
class TaskStatusEnum(str, Enum):
PENDING = "PENDING"
COMPLETE = "COMPLETE"
ERROR = "ERROR"
RUNNING = "RUNNING"


class WorkerState(str, Enum):
"""
The state of the Worker.
Expand Down
18 changes: 16 additions & 2 deletions src/blueapi/worker/reworker.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
RawRunEngineState,
StatusView,
TaskStatus,
TaskStatusEnum,
WorkerEvent,
WorkerState,
)
Expand Down Expand Up @@ -113,12 +114,25 @@ def cancel_active_task(
self._ctx.run_engine.stop()
return self._current.task_id

def get_tasks(self) -> list[TrackableTask[Task]]:
def get_tasks(self) -> list[TrackableTask]:
return list(self._tasks.values())

def get_task_by_id(self, task_id: str) -> TrackableTask[Task] | None:
def get_task_by_id(self, task_id: str) -> TrackableTask | None:
return self._tasks.get(task_id)

def get_tasks_by_status(self, status: TaskStatusEnum) -> list[TrackableTask]:
if status == TaskStatusEnum.RUNNING:
return [
task
for task in self._tasks.values()
if not task.is_pending and not task.is_complete
]
elif status == TaskStatusEnum.PENDING:
return [task for task in self._tasks.values() if task.is_pending]
elif status == TaskStatusEnum.COMPLETE:
return [task for task in self._tasks.values() if task.is_complete]
return []

def get_active_task(self) -> TrackableTask[Task] | None:
return self._current

Expand Down
12 changes: 11 additions & 1 deletion src/blueapi/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from blueapi.core import DataEvent, EventStream
from blueapi.utils import BlueapiBaseModel

from .event import ProgressEvent, WorkerEvent, WorkerState
from .event import ProgressEvent, TaskStatusEnum, WorkerEvent, WorkerState

T = TypeVar("T")

Expand Down Expand Up @@ -107,6 +107,16 @@ def submit_task(self, task: T) -> str:
str: A unique ID to refer to this task
"""

@abstractmethod
def get_tasks_by_status(self, status: TaskStatusEnum) -> list[TrackableTask]:
"""
Retrieve a list of tasks based on their status.
Args:
str: The status to filter tasks by.
Returns:
list[TrackableTask]: A list of tasks that match the given status.
"""

@abstractmethod
def start(self) -> None:
"""
Expand Down
46 changes: 44 additions & 2 deletions tests/service/test_rest_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from dataclasses import dataclass
from unittest.mock import MagicMock, call
from unittest.mock import MagicMock, call, patch

import pytest
from bluesky.run_engine import RunEngineStateMachine
Expand All @@ -12,8 +12,9 @@
from blueapi.service.handler import Handler
from blueapi.service.main import get_handler, setup_handler, teardown_handler
from blueapi.service.model import WorkerTask
from blueapi.worker import WorkerState
from blueapi.worker.event import WorkerState
from blueapi.worker.task import Task
from blueapi.worker.worker import TrackableTask

_TASK = Task(name="count", params={"detectors": ["x"]})

Expand Down Expand Up @@ -532,3 +533,44 @@ def test_teardown_handler():

def test_teardown_handler_does_not_raise():
assert teardown_handler() is None


tasks_data = [
TrackableTask(
task_id="1", task=Task(name="first_task"), is_complete=False, is_pending=False
),
TrackableTask(
task_id="2", task=Task(name="first_task"), is_complete=False, is_pending=True
),
]


def test_get_unstarted_tasks(handler: Handler, client: TestClient):
# handler.tasks = tasks_data # overriding the property
with patch.object(handler._worker, "get_tasks_by_status", return_value=tasks_data):
response = client.get("/tasks/?task_status=pending")
assert response.status_code == 200
r = response.json()
assert len(r) == 1 # As per our mock data, only 1 task should be 'unstarted'
assert (
r["tasks"][0]["task_id"] == "1"
) # Check that the correct task ID is returned


def test_get_tasks_bad_status(handler: Handler, client: TestClient):
with patch.object(handler._worker, "get_tasks_by_status", return_value=tasks_data):
response = client.get("/tasks/?task_status=invalid")
assert response.status_code == 400
assert "Invalid status query parameter" in response.json()["detail"]


def test_get_just_all_tasks(handler: Handler, client: TestClient):
with patch.object(handler._worker, "get_tasks", return_value=tasks_data):
response = client.get("/tasks")
assert response.status_code == 200
r = response.json()
response_tasks = r["tasks"]
assert len(response_tasks) == 2
assert (
response_tasks[0]["task_id"] == "1"
) # Check that the correct task ID is returned
Loading

0 comments on commit 31c94c5

Please sign in to comment.