Skip to content

Commit

Permalink
Allow getting the state from the RunEngineWorker and expose the state…
Browse files Browse the repository at this point in the history
… to the REST API (#218)

* Allow querying the state of the RunEngineWorker, which maps to the state of the RunEngine

* Add Status get to openapi schema and CLI

* Simplify mocking of Handler, add tests for wrapping RunEngineState
  • Loading branch information
DiamondJoseph authored May 23, 2023
1 parent 03c4d1c commit febaf8e
Show file tree
Hide file tree
Showing 8 changed files with 107 additions and 43 deletions.
27 changes: 27 additions & 0 deletions docs/user/reference/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,21 @@ components:
- type
title: ValidationError
type: object
WorkerState:
description: The state of the Worker.
enum:
- IDLE
- RUNNING
- PAUSING
- PAUSED
- HALTING
- STOPPING
- ABORTING
- SUSPENDING
- PANICKED
- UNKNOWN
title: WorkerState
type: string
info:
title: BlueAPI Control
version: 0.1.0
Expand Down Expand Up @@ -215,3 +230,15 @@ paths:
$ref: '#/components/schemas/HTTPValidationError'
description: Validation Error
summary: Submit Task
/worker/state:
get:
description: Get the State of the Worker
operationId: get_state_worker_state_get
responses:
'200':
content:
application/json:
schema:
$ref: '#/components/schemas/WorkerState'
description: Successful Response
summary: Get State
11 changes: 11 additions & 0 deletions src/blueapi/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,14 @@ def run_plan(obj: dict, name: str, parameters: Optional[str]) -> None:
json=json.loads(parameters),
)
print(f"Response returned with {resp.status_code}")


@controller.command(name="state")
@check_connection
@click.pass_obj
def get_state(obj: dict) -> None:
config: ApplicationConfig = obj["config"]

resp = requests.get(f"http://{config.api.host}:{config.api.port}/worker/state")
print(f"Response returned with {resp.status_code}: ")
pprint(resp.json())
35 changes: 22 additions & 13 deletions src/blueapi/service/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,37 +14,46 @@ class Handler:
context: BlueskyContext
worker: Worker
config: ApplicationConfig
message_bus: MessagingTemplate

def __init__(self, config: Optional[ApplicationConfig] = None) -> None:
self.context = BlueskyContext()
self.config = config if config is not None else ApplicationConfig()
messaging_template: MessagingTemplate

def __init__(
self,
config: Optional[ApplicationConfig] = None,
context: Optional[BlueskyContext] = None,
messaging_template: Optional[MessagingTemplate] = None,
worker: Optional[Worker] = None,
) -> None:
self.config = config or ApplicationConfig()
self.context = context or BlueskyContext()

logging.basicConfig(level=self.config.logging.level)

self.context.with_config(self.config.env)

self.worker = RunEngineWorker(self.context)
self.message_bus = StompMessagingTemplate.autoconfigured(self.config.stomp)
self.worker = worker or RunEngineWorker(self.context)
self.messaging_template = (
messaging_template
or StompMessagingTemplate.autoconfigured(self.config.stomp)
)

def start(self) -> None:
self.worker.start()

self._publish_event_streams(
{
self.worker.worker_events: self.message_bus.destinations.topic(
self.worker.worker_events: self.messaging_template.destinations.topic(
"public.worker.event"
),
self.worker.progress_events: self.message_bus.destinations.topic(
self.worker.progress_events: self.messaging_template.destinations.topic(
"public.worker.event.progress"
),
self.worker.data_events: self.message_bus.destinations.topic(
self.worker.data_events: self.messaging_template.destinations.topic(
"public.worker.event.data"
),
}
)

self.message_bus.connect()
self.messaging_template.connect()

def _publish_event_streams(
self, streams_to_destinations: Mapping[EventStream, str]
Expand All @@ -54,14 +63,14 @@ def _publish_event_streams(

def _publish_event_stream(self, stream: EventStream, destination: str) -> None:
stream.subscribe(
lambda event, correlation_id: self.message_bus.send(
lambda event, correlation_id: self.messaging_template.send(
destination, event, None, correlation_id
)
)

def stop(self) -> None:
self.worker.stop()
self.message_bus.disconnect()
self.messaging_template.disconnect()


HANDLER: Optional[Handler] = None
Expand Down
8 changes: 7 additions & 1 deletion src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from fastapi import Body, Depends, FastAPI, HTTPException

from blueapi.config import ApplicationConfig
from blueapi.worker import RunPlan
from blueapi.worker import RunPlan, WorkerState

from .handler import Handler, get_handler, setup_handler, teardown_handler
from .model import DeviceModel, DeviceResponse, PlanModel, PlanResponse, TaskResponse
Expand Down Expand Up @@ -74,6 +74,12 @@ def submit_task(
return TaskResponse(task_name=name)


@app.get("/worker/state")
async def get_state(handler: Handler = Depends(get_handler)) -> WorkerState:
"""Get the State of the Worker"""
return handler.worker.state


def start(config: ApplicationConfig):
import uvicorn

Expand Down
4 changes: 4 additions & 0 deletions src/blueapi/worker/reworker.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ def stop(self) -> None:
LOGGER.info("Stopping worker: nothing to do")
LOGGER.info("Stopped")

@property
def state(self) -> WorkerState:
return self._state

def run(self) -> None:
LOGGER.info("Worker starting")
self._ctx.run_engine.state_hook = self._on_state_change
Expand Down
9 changes: 8 additions & 1 deletion src/blueapi/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from blueapi.core import DataEvent, EventStream

from .event import ProgressEvent, WorkerEvent
from .event import ProgressEvent, WorkerEvent, WorkerState

T = TypeVar("T")

Expand Down Expand Up @@ -43,6 +43,13 @@ def stop(self) -> None:
Command the worker to gracefully stop. Blocks until it has shut down.
"""

@property
@abstractmethod
def state(self) -> WorkerState:
"""
:return: state of the worker
"""

@property
@abstractmethod
def worker_events(self) -> EventStream[WorkerEvent, int]:
Expand Down
35 changes: 14 additions & 21 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# Based on https://docs.pytest.org/en/latest/example/simple.html#control-skipping-of-tests-according-to-command-line-option # noqa: E501

import pytest
from bluesky.run_engine import RunEngineStateMachine
from fastapi.testclient import TestClient
from mock import Mock

from blueapi.core.context import BlueskyContext
from blueapi.service.handler import Handler, get_handler
from blueapi.service.main import app
from blueapi.worker.reworker import RunEngineWorker
from src.blueapi.core import BlueskyContext


def pytest_addoption(parser):
Expand All @@ -31,23 +31,8 @@ def pytest_collection_modifyitems(config, items):
item.add_marker(skip_stomp)


class MockHandler(Handler):
context: BlueskyContext
worker: RunEngineWorker

def __init__(self) -> None:
self.context = Mock()
self.worker = Mock()

def start(self):
return None

def stop(self):
return None


class Client:
def __init__(self, handler: MockHandler) -> None:
def __init__(self, handler: Handler) -> None:
"""Create tester object"""
self.handler = handler

Expand All @@ -58,10 +43,18 @@ def client(self) -> TestClient:


@pytest.fixture(scope="session")
def handler() -> MockHandler:
return MockHandler()
def handler() -> Handler:
context: BlueskyContext = Mock()
context.run_engine.state = RunEngineStateMachine.States.IDLE
handler = Handler(context=context)

def no_op():
return

handler.start = handler.stop = no_op # type: ignore
return handler


@pytest.fixture(scope="session")
def client(handler: MockHandler) -> TestClient:
def client(handler: Handler) -> TestClient:
return Client(handler).client
21 changes: 14 additions & 7 deletions tests/service/test_rest_api.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from dataclasses import dataclass

from bluesky.run_engine import RunEngineStateMachine
from fastapi.testclient import TestClient
from pydantic import BaseModel

from blueapi.core.bluesky_types import Plan
from blueapi.service.handler import Handler
from blueapi.worker.task import RunPlan, Task
from blueapi.worker.task import RunPlan
from src.blueapi.worker import WorkerState


def test_get_plans(handler: Handler, client: TestClient) -> None:
Expand Down Expand Up @@ -75,12 +77,17 @@ class MyDevice:
def test_put_plan_submits_task(handler: Handler, client: TestClient) -> None:
task_json = {"detectors": ["x"]}
task_name = "count"
submitted_tasks = {}

def on_submit(name: str, task: Task):
submitted_tasks[name] = task
client.put(f"/task/{task_name}", json=task_json)

handler.worker.submit_task.side_effect = on_submit # type: ignore
task_queue = handler.worker._task_queue.queue # type: ignore
assert len(task_queue) == 1
assert task_queue[0].task == RunPlan(name=task_name, params=task_json)

client.put(f"/task/{task_name}", json=task_json)
assert submitted_tasks == {task_name: RunPlan(name=task_name, params=task_json)}

def test_get_state_updates(handler: Handler, client: TestClient) -> None:
assert client.get("/worker/state").text == f'"{WorkerState.IDLE.name}"'
handler.worker._on_state_change( # type: ignore
RunEngineStateMachine.States.RUNNING
)
assert client.get("/worker/state").text == f'"{WorkerState.RUNNING.name}"'

0 comments on commit febaf8e

Please sign in to comment.