diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index 8372f8006..d8be05ad9 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -134,7 +134,10 @@ def get_devices(obj: dict) -> None: def listen_to_events(obj: dict) -> None: """Listen to events output by blueapi""" config: ApplicationConfig = obj["config"] - amq_client = AmqClient(StompMessagingTemplate.autoconfigured(config.stomp)) + if config.stomp is not None: + amq_client = AmqClient(StompMessagingTemplate.autoconfigured(config.stomp)) + else: + raise RuntimeError("Message bus needs to be configured") def on_event( context: MessageContext, @@ -172,8 +175,13 @@ def run_plan( client: BlueapiRestClient = obj["rest_client"] logger = logging.getLogger(__name__) - - amq_client = AmqClient(StompMessagingTemplate.autoconfigured(config.stomp)) + if config.stomp is not None: + _message_template = StompMessagingTemplate.autoconfigured(config.stomp) + else: + raise RuntimeError( + "Cannot run plans without Stomp configuration to track progress" + ) + amq_client = AmqClient(_message_template) finished_event: deque[WorkerEvent] = deque() def store_finished_event(event: WorkerEvent) -> None: diff --git a/src/blueapi/config.py b/src/blueapi/config.py index 3fafdc8c6..d4376fb4e 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -98,7 +98,7 @@ class ApplicationConfig(BlueapiBaseModel): config tree. """ - stomp: StompConfig = Field(default_factory=StompConfig) + stomp: StompConfig | None = None env: EnvironmentConfig = Field(default_factory=EnvironmentConfig) logging: LoggingConfig = Field(default_factory=LoggingConfig) api: RestConfig = Field(default_factory=RestConfig) diff --git a/src/blueapi/service/handler.py b/src/blueapi/service/handler.py index f38dcf434..0280c5d3d 100644 --- a/src/blueapi/service/handler.py +++ b/src/blueapi/service/handler.py @@ -1,5 +1,6 @@ import logging from collections.abc import Mapping +from typing import Any from blueapi.config import ApplicationConfig from blueapi.core import BlueskyContext @@ -27,7 +28,7 @@ class Handler(BlueskyHandler): _context: BlueskyContext _worker: Worker _config: ApplicationConfig - _messaging_template: MessagingTemplate + _messaging_template: MessagingTemplate | None _initialized: bool = False def __init__( @@ -46,25 +47,30 @@ def __init__( self._context, broadcast_statuses=self._config.env.events.broadcast_status_events, ) - self._messaging_template = ( - messaging_template - or StompMessagingTemplate.autoconfigured(self._config.stomp) - ) + if self._config.stomp is None: + self._messaging_template = messaging_template + else: + self._messaging_template = ( + messaging_template + or StompMessagingTemplate.autoconfigured(self._config.stomp) + ) def start(self) -> None: self._worker.start() - event_topic = self._messaging_template.destinations.topic("public.worker.event") - - self._publish_event_streams( - { - self._worker.worker_events: event_topic, - self._worker.progress_events: event_topic, - self._worker.data_events: event_topic, - } - ) + if self._messaging_template is not None: + event_topic = self._messaging_template.destinations.topic( + "public.worker.event" + ) + self._publish_event_streams( + { + self._worker.worker_events: event_topic, + self._worker.progress_events: event_topic, + self._worker.data_events: event_topic, + } + ) - self._messaging_template.connect() + self._messaging_template.connect() self._initialized = True def _publish_event_streams( @@ -74,16 +80,19 @@ def _publish_event_streams( self._publish_event_stream(stream, destination) def _publish_event_stream(self, stream: EventStream, destination: str) -> None: - stream.subscribe( - lambda event, correlation_id: self._messaging_template.send( - destination, event, None, correlation_id - ) - ) + def forward_message(event: Any, correlation_id: str | None) -> None: + if self._messaging_template is not None: + self._messaging_template.send(destination, event, None, correlation_id) + + stream.subscribe(forward_message) def stop(self) -> None: self._initialized = False self._worker.stop() - if self._messaging_template.is_connected(): + if ( + self._messaging_template is not None + and self._messaging_template.is_connected() + ): self._messaging_template.disconnect() @property diff --git a/tests/example_yaml/rest_config.yaml b/tests/example_yaml/rest_config.yaml index 51a4714b1..b6ea43a5e 100644 --- a/tests/example_yaml/rest_config.yaml +++ b/tests/example_yaml/rest_config.yaml @@ -1,3 +1,6 @@ api: host: a.fake.host port: 12345 +stomp: + host: localhost + port: 61613 diff --git a/tests/example_yaml/valid_stomp_config.yaml b/tests/example_yaml/valid_stomp_config.yaml new file mode 100644 index 000000000..b4004a7d3 --- /dev/null +++ b/tests/example_yaml/valid_stomp_config.yaml @@ -0,0 +1,3 @@ +stomp: + host: localhost + port: 61613 diff --git a/tests/messaging/test_stomptemplate.py b/tests/messaging/test_stomptemplate.py index 0e7f8afdf..66e0f8b03 100644 --- a/tests/messaging/test_stomptemplate.py +++ b/tests/messaging/test_stomptemplate.py @@ -12,6 +12,7 @@ from blueapi.config import StompConfig from blueapi.messaging import MessageContext, MessagingTemplate, StompMessagingTemplate +from blueapi.service.handler import get_handler, setup_handler, teardown_handler _TIMEOUT: float = 10.0 _COUNT = itertools.count() @@ -28,13 +29,16 @@ def test_stomp_configs(self) -> Iterable[StompConfig]: @pytest.fixture(params=StompTestingSettings().test_stomp_configs()) def disconnected_template(request: pytest.FixtureRequest) -> MessagingTemplate: stomp_config = request.param - return StompMessagingTemplate.autoconfigured(stomp_config) + template = StompMessagingTemplate.autoconfigured(stomp_config) + assert template is not None + return template @pytest.fixture(params=StompTestingSettings().test_stomp_configs()) def template(request: pytest.FixtureRequest) -> Iterable[MessagingTemplate]: stomp_config = request.param template = StompMessagingTemplate.autoconfigured(stomp_config) + assert template is not None template.connect() yield template template.disconnect() @@ -218,3 +222,10 @@ def server(ctx: MessageContext, message: str) -> None: template.send(reply_queue, "ack", correlation_id=ctx.correlation_id) template.subscribe(destination, server) + + +def test_messaging_template_can_be_set_with_none(): + setup_handler(None) + teardown_handler() + with pytest.raises(ValueError): + get_handler() diff --git a/tests/test_cli.py b/tests/test_cli.py index eb42d0673..67a25d544 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -169,3 +169,40 @@ def test_config_passed_down_to_command_children( "params": {"time": 5}, } } + + +def test_invalid_stomp_config_for_listener(runner: CliRunner): + result = runner.invoke(main, ["controller", "listen"]) + assert ( + isinstance(result.exception, RuntimeError) + and str(result.exception) == "Message bus needs to be configured" + ) + + +def test_cannot_run_plans_without_stomp_config(runner: CliRunner): + result = runner.invoke(main, ["controller", "run", "sleep", '{"time": 5}']) + assert ( + isinstance(result.exception, RuntimeError) + and str(result.exception) + == "Cannot run plans without Stomp configuration to track progress" + ) + + +@pytest.mark.stomp +def test_valid_stomp_config_for_listener(runner: CliRunner): + result = runner.invoke( + main, + [ + "-c", + "tests/example_yaml/valid_stomp_config.yaml", + "controller", + "listen", + ], + input="\n", + ) + assert result.exit_code == 0 + + +def test_invalid_condition_for_run(runner: CliRunner): + result = runner.invoke(main, ["controller", "run", "sleep", '{"time": 5}']) + assert type(result.exception) is RuntimeError