Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add transaction mode to worker #202

Merged
merged 24 commits into from
May 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
335285a
Add transaction mode to worker class
callumforrester May 15, 2023
300f68d
Add tests for worker transactions
callumforrester May 15, 2023
b963dab
Integrate transactions into previous worker API
callumforrester May 17, 2023
392fa50
Synchronise concurrency test
callumforrester May 18, 2023
9d17d7c
Fix imports
callumforrester May 18, 2023
67b586f
Use transactions in REST API
callumforrester May 18, 2023
5cd1a5a
Fix tests
callumforrester May 19, 2023
f01ee21
Rename test
callumforrester May 19, 2023
db307c3
Separate out worker tests with comments
callumforrester May 19, 2023
2185dcd
Make clear task return bool rather than raise an exception
callumforrester May 19, 2023
300b1d8
Rename ActiveTask to TrackableTask
callumforrester May 19, 2023
17f34c4
Rename field task_name to task_id in various event types
callumforrester May 19, 2023
90c6147
Fix docstrings for worker
callumforrester May 19, 2023
d9b6c9d
Remove unecessary test repetition
callumforrester May 19, 2023
2ea816c
Add architecture descision record for queueing
callumforrester May 22, 2023
e0a68d2
Rename active task references
callumforrester May 22, 2023
5e4a704
Promote trackable task to API level
callumforrester May 22, 2023
7b7e9a0
Fix imports
callumforrester May 22, 2023
57603cb
Use fake device to test simultaneous plans
callumforrester May 23, 2023
10c0cc6
Rename private variable
callumforrester May 23, 2023
422f770
Rename test
callumforrester May 23, 2023
96e01f6
Fix test
callumforrester May 23, 2023
de633dc
Fix imports
callumforrester May 23, 2023
ebb94e8
Add extra tests
callumforrester May 23, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions docs/developer/explanations/decisions/0002-no-queues.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
2. No Queues
============

Date: 2023-05-22

Status
------

Accepted

Context
-------

In asking whether this service should hold and execute a queue of tasks.

Decision
--------

We will not hold any queues. The worker can execute one task at a time and will return
an error if asked to execute one task while another is running. Queueing should be the
responsibility of a different service.

Consequences
------------

The API must be kept queue-free, although transactions are permitted where the server
caches requests.
2 changes: 1 addition & 1 deletion src/blueapi/cli/amq.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def on_progress_event_wrapper(
task_response = self.app.send_and_receive(
"worker.run", {"name": name, "params": params}, reply_type=TaskResponse
).result(5.0)
task_id = task_response.task_name
task_id = task_response.task_id

if timeout is not None:
complete.wait(timeout)
Expand Down
12 changes: 6 additions & 6 deletions src/blueapi/cli/updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ def _update(self, name: str, view: StatusView) -> None:


class CliEventRenderer:
_task_name: Optional[str]
_task_id: Optional[str]
_pbar_renderer: ProgressBarRenderer

def __init__(
self,
task_name: Optional[str] = None,
task_id: Optional[str] = None,
pbar_renderer: Optional[ProgressBarRenderer] = None,
) -> None:
self._task_name = task_name
self._task_id = task_id
if pbar_renderer is None:
pbar_renderer = ProgressBarRenderer()
self._pbar_renderer = pbar_renderer
Expand All @@ -65,14 +65,14 @@ def on_worker_event(self, event: WorkerEvent) -> None:
print(str(event.state))

def _relates_to_task(self, event: Union[WorkerEvent, ProgressEvent]) -> bool:
if self._task_name is None:
if self._task_id is None:
return True
elif isinstance(event, WorkerEvent):
return (
event.task_status is not None
and event.task_status.task_name == self._task_name
and event.task_status.task_id == self._task_id
)
elif isinstance(event, ProgressEvent):
return event.task_name == self._task_name
return event.task_id == self._task_id
else:
return False
5 changes: 3 additions & 2 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@ def submit_task(
handler: Handler = Depends(get_handler),
):
"""Submit a task onto the worker queue."""
handler.worker.submit_task(name, RunPlan(name=name, params=task))
return TaskResponse(task_name=name)
task_id = handler.worker.submit_task(RunPlan(name=name, params=task))
handler.worker.begin_task(task_id)
return TaskResponse(task_id=task_id)


@app.get("/worker/state")
Expand Down
2 changes: 1 addition & 1 deletion src/blueapi/service/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,4 @@ class TaskResponse(BlueapiBaseModel):
Acknowledgement that a task has started, includes its ID
"""

task_name: str = Field(description="Unique identifier for the task")
task_id: str = Field(description="Unique identifier for the task")
5 changes: 4 additions & 1 deletion src/blueapi/worker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from .multithread import run_worker_in_own_thread
from .reworker import RunEngineWorker
from .task import RunPlan, Task
from .worker import Worker
from .worker import TrackableTask, Worker
from .worker_busy_error import WorkerBusyError

__all__ = [
"run_worker_in_own_thread",
Expand All @@ -15,4 +16,6 @@
"StatusView",
"ProgressEvent",
"TaskStatus",
"TrackableTask",
"WorkerBusyError",
]
4 changes: 2 additions & 2 deletions src/blueapi/worker/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class ProgressEvent(BlueapiBaseModel):
such as moving motors and exposing detectors.
"""

task_name: str
task_id: str
statuses: Mapping[str, StatusView] = Field(default_factory=dict)


Expand All @@ -97,7 +97,7 @@ class TaskStatus(BlueapiBaseModel):
Status of a task the worker is running.
"""

task_name: str
task_id: str
task_complete: bool
task_failed: bool

Expand Down
60 changes: 43 additions & 17 deletions src/blueapi/worker/reworker.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
WorkerState,
)
from .multithread import run_worker_in_own_thread
from .task import ActiveTask, Task
from .worker import Worker
from .task import Task
from .worker import TrackableTask, Worker
from .worker_busy_error import WorkerBusyError

LOGGER = logging.getLogger(__name__)
Expand All @@ -47,11 +47,13 @@ class RunEngineWorker(Worker[Task]):
_ctx: BlueskyContext
_stop_timeout: float

_pending_tasks: Dict[str, TrackableTask]

_state: WorkerState
_errors: List[str]
_warnings: List[str]
_task_queue: Queue # type: ignore
_current: Optional[ActiveTask]
_task_channel: Queue # type: ignore
_current: Optional[TrackableTask]
_status_lock: RLock
_status_snapshot: Dict[str, StatusView]
_completed_statuses: Set[str]
Expand All @@ -70,10 +72,12 @@ def __init__(
self._ctx = ctx
self._stop_timeout = stop_timeout

self._pending_tasks = {}
callumforrester marked this conversation as resolved.
Show resolved Hide resolved

self._state = WorkerState.from_bluesky_state(ctx.run_engine.state)
self._errors = []
self._warnings = []
self._task_queue = Queue(maxsize=1)
self._task_channel = Queue(maxsize=1)
self._current = None
self._worker_events = EventPublisher()
self._progress_events = EventPublisher()
Expand All @@ -85,11 +89,33 @@ def __init__(
self._stopping = Event()
self._stopped = Event()

def submit_task(self, name: str, task: Task) -> None:
active_task = ActiveTask(name, task)
LOGGER.info(f"Submitting: {active_task}")
def clear_task(self, task_id: str) -> bool:
if task_id in self._pending_tasks:
del self._pending_tasks[task_id]
return True
else:
return False

def get_pending_tasks(self) -> List[TrackableTask[Task]]:
return list(self._pending_tasks.values())

def begin_task(self, task_id: str) -> None:
task = self._pending_tasks.get(task_id)
if task is not None:
self._submit_trackable_task(task)
else:
raise KeyError(f"No pending task with ID {task_id}")
callumforrester marked this conversation as resolved.
Show resolved Hide resolved

def submit_task(self, task: Task) -> str:
task_id: str = str(uuid.uuid4())
trackable_task = TrackableTask(task_id=task_id, task=task)
self._pending_tasks[task_id] = trackable_task
return task_id

def _submit_trackable_task(self, trackable_task: TrackableTask) -> None:
LOGGER.info(f"Submitting: {trackable_task}")
try:
self._task_queue.put_nowait(active_task)
self._task_channel.put_nowait(trackable_task)
except Full:
LOGGER.error("Cannot submit task while another is running")
raise WorkerBusyError("Cannot submit task while another is running")
Expand All @@ -104,7 +130,7 @@ def stop(self) -> None:

# If the worker has not yet started there is nothing to do.
if self._started.is_set():
self._task_queue.put(KillSignal())
self._task_channel.put(KillSignal())
self._stopped.wait(timeout=self._stop_timeout)
# Event timeouts do not actually raise errors
if not self._stopped.is_set():
Expand Down Expand Up @@ -138,8 +164,8 @@ def _cycle_with_error_handling(self) -> None:
def _cycle(self) -> None:
try:
LOGGER.info("Awaiting task")
next_task: Union[ActiveTask, KillSignal] = self._task_queue.get()
if isinstance(next_task, ActiveTask):
next_task: Union[TrackableTask, KillSignal] = self._task_channel.get()
if isinstance(next_task, TrackableTask):
LOGGER.info(f"Got new task: {next_task}")
self._current = next_task # Informing mypy that the task is not None
self._current.task.do_task(self._ctx)
Expand Down Expand Up @@ -200,11 +226,11 @@ def _report_status(
warnings = self._warnings
if self._current is not None:
task_status = TaskStatus(
task_name=self._current.name,
task_id=self._current.task_id,
task_complete=self._current.is_complete,
task_failed=self._current.is_error or bool(errors),
)
correlation_id = self._current.name
correlation_id = self._current.task_id
callumforrester marked this conversation as resolved.
Show resolved Hide resolved
else:
task_status = None
correlation_id = None
Expand All @@ -219,7 +245,7 @@ def _report_status(

def _on_document(self, name: str, document: Mapping[str, Any]) -> None:
if self._current is not None:
correlation_id = self._current.name
correlation_id = self._current.task_id
callumforrester marked this conversation as resolved.
Show resolved Hide resolved
self._data_events.publish(
DataEvent(name=name, doc=document), correlation_id
)
Expand Down Expand Up @@ -293,10 +319,10 @@ def _publish_status_snapshot(self) -> None:
else:
self._progress_events.publish(
ProgressEvent(
task_name=self._current.name,
task_id=self._current.task_id,
statuses=self._status_snapshot,
),
self._current.name,
self._current.task_id,
)


Expand Down
9 changes: 0 additions & 9 deletions src/blueapi/worker/task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Mapping

from pydantic import BaseModel, Field, parse_obj_as
Expand Down Expand Up @@ -65,11 +64,3 @@ def _lookup_params(

model = plan.model
return parse_obj_as(model, params)


@dataclass
class ActiveTask:
name: str
task: Task
is_complete: bool = False
is_error: bool = False
56 changes: 51 additions & 5 deletions src/blueapi/worker/worker.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,73 @@
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from typing import Generic, List, TypeVar

from blueapi.core import DataEvent, EventStream
from blueapi.utils import BlueapiBaseModel

from .event import ProgressEvent, WorkerEvent, WorkerState

T = TypeVar("T")


class TrackableTask(BlueapiBaseModel, Generic[T]):
"""
A representation of a task that the worker recognizes
"""

task_id: str
task: T
is_complete: bool = False
is_error: bool = False
callumforrester marked this conversation as resolved.
Show resolved Hide resolved


class Worker(ABC, Generic[T]):
"""
Entity that takes and runs tasks. Intended to be a central,
atomic worker rather than a load distributor
"""

@abstractmethod
def submit_task(self, __name: str, __task: T) -> None:
def get_pending_tasks(self) -> List[TrackableTask[T]]:
"""
Submit a task to be run
Return a list of all tasks pending on the worker,
any one of which can be triggered with begin_task.

Returns:
List[TrackableTask[T]]: List of task objects
"""

@abstractmethod
def clear_task(self, task_id: str) -> bool:
"""
Remove a pending task from the worker

Args:
__name (str): A unique name to identify this task
__task (T): The task to run
task_id: The ID of the task to be removed
Returns:
bool: True if the task existed in the first place
"""

@abstractmethod
def begin_task(self, task_id: str) -> None:
"""
Trigger a pending task. Will fail if the worker is busy.

Args:
task_id: The ID of the task to be triggered
Throws:
WorkerBusyError: If the worker is already running a task.
KeyError: If the task ID does not exist
"""

@abstractmethod
def submit_task(self, task: T) -> str:
"""
Submit a task to be run on begin_task

Args:
task: A description of the task
Returns:
str: A unique ID to refer to this task
"""

@abstractmethod
Expand Down
6 changes: 3 additions & 3 deletions tests/service/test_rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ def test_put_plan_submits_task(handler: Handler, client: TestClient) -> None:

client.put(f"/task/{task_name}", json=task_json)

task_queue = handler.worker._task_queue.queue # type: ignore
assert len(task_queue) == 1
assert task_queue[0].task == RunPlan(name=task_name, params=task_json)
assert handler.worker.get_pending_tasks()[0].task == RunPlan(
name=task_name, params=task_json
)


def test_get_state_updates(handler: Handler, client: TestClient) -> None:
Expand Down
Loading