Skip to content

Commit

Permalink
fresh take on the change of names
Browse files Browse the repository at this point in the history
  • Loading branch information
stan-dot committed Nov 11, 2024
1 parent 00572da commit ad46d8a
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 70 deletions.
4 changes: 2 additions & 2 deletions src/blueapi/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@ def __init__(
def from_config(cls, config: ApplicationConfig) -> "BlueapiClient":
rest = BlueapiRestClient(config.api)
if config.stomp is not None:
template = StompClient.for_broker(
stomp_client = StompClient.for_broker(
broker=Broker(
host=config.stomp.host,
port=config.stomp.port,
auth=config.stomp.auth,
)
)
events = EventBusClient(template)
events = EventBusClient(stomp_client)
else:
events = None
return cls(rest, events)
Expand Down
8 changes: 4 additions & 4 deletions src/blueapi/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def plan_2(...) -> MsgGenerator:

for obj in load_module_all(module):
if is_bluesky_plan_generator(obj):
self.plan(obj)
self.register_plan(obj)

def with_device_module(self, module: ModuleType) -> None:
self.with_dodal_module(module)
Expand All @@ -116,7 +116,7 @@ def with_dodal_module(self, module: ModuleType, **kwargs) -> None:
devices, exceptions = make_all_devices(module, **kwargs)

for device in devices.values():
self.device(device)
self.register_device(device)

# If exceptions have occurred, we log them but we do not make blueapi
# fall over
Expand All @@ -126,7 +126,7 @@ def with_dodal_module(self, module: ModuleType, **kwargs) -> None:
)
LOGGER.exception(NotConnected(exceptions))

def plan(self, plan: PlanGenerator) -> PlanGenerator:
def register_plan(self, plan: PlanGenerator) -> PlanGenerator:
"""
Register the argument as a plan in the context. Can be used as a decorator e.g.
@ctx.plan
Expand Down Expand Up @@ -154,7 +154,7 @@ def my_plan(a: int, b: str):
self.plan_functions[plan.__name__] = plan
return plan

def device(self, device: Device, name: str | None = None) -> None:
def register_device(self, device: Device, name: str | None = None) -> None:
"""
Register an device in the context. The device needs to be registered with a
name. If the device is Readable, Movable or Flyable it has a `name`
Expand Down
16 changes: 9 additions & 7 deletions src/blueapi/service/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def worker() -> TaskWorker:
def stomp_client() -> StompClient | None:
stomp_config = config().stomp
if stomp_config is not None:
template = StompClient.for_broker(
stomp_client = StompClient.for_broker(
broker=Broker(
host=stomp_config.host, port=stomp_config.port, auth=stomp_config.auth
)
Expand All @@ -68,8 +68,8 @@ def stomp_client() -> StompClient | None:
task_worker.data_events: event_topic,
}
)
template.connect()
return template
stomp_client.connect()
return stomp_client
else:
return None

Expand All @@ -88,8 +88,8 @@ def setup(config: ApplicationConfig) -> None:

def teardown() -> None:
worker().stop()
if (template := stomp_client()) is not None:
template.disconnect()
if (stomp_client_ref := stomp_client()) is not None:
stomp_client_ref.disconnect()
context.cache_clear()
worker.cache_clear()
stomp_client.cache_clear()
Expand All @@ -104,8 +104,10 @@ def _publish_event_streams(

def _publish_event_stream(stream: EventStream, destination: DestinationBase) -> None:
def forward_message(event: Any, correlation_id: str | None) -> None:
if (template := stomp_client()) is not None:
template.send(destination, event, None, correlation_id=correlation_id)
if (stomp_client_ref := stomp_client()) is not None:
stomp_client_ref.send(
destination, event, None, correlation_id=correlation_id
)

stream.subscribe(forward_message)

Expand Down
12 changes: 6 additions & 6 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ async def delete_environment(
@start_as_current_span(TRACER)
def get_plans(runner: WorkerDispatcher = Depends(_runner)):
"""Retrieve information about all available plans."""
return PlanResponse(plans=runner.run(interface.get_plans))
plans = runner.run(interface.get_plans)
return PlanResponse(plans=plans)


@router.get(
Expand All @@ -150,7 +151,8 @@ def get_plan_by_name(name: str, runner: WorkerDispatcher = Depends(_runner)):
@start_as_current_span(TRACER)
def get_devices(runner: WorkerDispatcher = Depends(_runner)):
"""Retrieve information about all available devices."""
return DeviceResponse(devices=runner.run(interface.get_devices))
devices = runner.run(interface.get_devices)
return DeviceResponse(devices=devices)


@router.get(
Expand Down Expand Up @@ -285,10 +287,8 @@ def get_task(
@start_as_current_span(TRACER)
def get_active_task(runner: WorkerDispatcher = Depends(_runner)) -> WorkerTask:
active = runner.run(interface.get_active_task)
if active is not None:
return WorkerTask(task_id=active.task_id)
else:
return WorkerTask(task_id=None)
task_id = active.task_id if active is not None else None
return WorkerTask(task_id=task_id)


@router.get("/worker/state")
Expand Down
26 changes: 13 additions & 13 deletions tests/unit_tests/client/test_event_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,44 +7,44 @@


@pytest.fixture
def mock_template() -> StompClient:
def mock_stomp_client() -> StompClient:
return Mock(spec=StompClient)


@pytest.fixture
def events(mock_template: StompClient) -> EventBusClient:
return EventBusClient(app=mock_template)
def events(mock_stomp_client: StompClient) -> EventBusClient:
return EventBusClient(app=mock_stomp_client)


def test_context_manager_connects_and_disconnects(
events: EventBusClient,
mock_template: Mock,
mock_stomp_client: Mock,
):
mock_template.connect.assert_not_called()
mock_template.disconnect.assert_not_called()
mock_stomp_client.connect.assert_not_called()
mock_stomp_client.disconnect.assert_not_called()

with events:
mock_template.connect.assert_called_once()
mock_template.disconnect.assert_not_called()
mock_stomp_client.connect.assert_called_once()
mock_stomp_client.disconnect.assert_not_called()

mock_template.disconnect.assert_called_once()
mock_stomp_client.disconnect.assert_called_once()


def test_client_subscribes_to_all_events(
events: EventBusClient,
mock_template: Mock,
mock_stomp_client: Mock,
):
on_event = Mock
with events:
events.subscribe_to_all_events(on_event=on_event) # type: ignore
mock_template.subscribe.assert_called_once_with(ANY, on_event)
mock_stomp_client.subscribe.assert_called_once_with(ANY, on_event)


def test_client_raises_streaming_error_on_subscribe_failure(
events: EventBusClient,
mock_template: Mock,
mock_stomp_client: Mock,
):
mock_template.subscribe.side_effect = RuntimeError("Foo")
mock_stomp_client.subscribe.side_effect = RuntimeError("Foo")
on_event = Mock
with events:
with pytest.raises(
Expand Down
36 changes: 18 additions & 18 deletions tests/unit_tests/core/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ def empty_context() -> BlueskyContext:
@pytest.fixture
def devicey_context(sim_motor: SynAxis, sim_detector: SynGauss) -> BlueskyContext:
ctx = BlueskyContext()
ctx.device(sim_motor)
ctx.device(sim_detector)
ctx.register_device(sim_motor)
ctx.register_device(sim_detector)
return ctx


Expand All @@ -117,7 +117,7 @@ def some_configurable() -> SomeConfigurable:

@pytest.mark.parametrize("plan", [has_no_params, has_one_param, has_some_params])
def test_add_plan(empty_context: BlueskyContext, plan: PlanGenerator) -> None:
empty_context.plan(plan)
empty_context.register_plan(plan)
assert plan.__name__ in empty_context.plans


Expand All @@ -127,7 +127,7 @@ def test_generated_schema(
def demo_plan(foo: int, mov: Movable) -> MsgGenerator: # type: ignore
...

empty_context.plan(demo_plan)
empty_context.register_plan(demo_plan)
schema = empty_context.plans["demo_plan"].model.schema()
assert schema["properties"] == {
"foo": {"title": "Foo", "type": "integer"},
Expand All @@ -140,7 +140,7 @@ def demo_plan(foo: int, mov: Movable) -> MsgGenerator: # type: ignore
)
def test_add_invalid_plan(empty_context: BlueskyContext, plan: PlanGenerator) -> None:
with pytest.raises(ValueError):
empty_context.plan(plan)
empty_context.register_plan(plan)


def test_add_plan_from_module(empty_context: BlueskyContext) -> None:
Expand All @@ -151,14 +151,14 @@ def test_add_plan_from_module(empty_context: BlueskyContext) -> None:


def test_add_named_device(empty_context: BlueskyContext, sim_motor: SynAxis) -> None:
empty_context.device(sim_motor)
empty_context.register_device(sim_motor)
assert empty_context.devices[SIM_MOTOR_NAME] is sim_motor


def test_add_nameless_device(
empty_context: BlueskyContext, some_configurable: SomeConfigurable
) -> None:
empty_context.device(some_configurable, "conf")
empty_context.register_device(some_configurable, "conf")
assert empty_context.devices["conf"] is some_configurable


Expand All @@ -167,13 +167,13 @@ def test_add_nameless_device_without_override(
some_configurable: SomeConfigurable,
) -> None:
with pytest.raises(KeyError):
empty_context.device(some_configurable)
empty_context.register_device(some_configurable)


def test_override_device_name(
empty_context: BlueskyContext, sim_motor: SynAxis
) -> None:
empty_context.device(sim_motor, "foo")
empty_context.register_device(sim_motor, "foo")
assert empty_context.devices["foo"] is sim_motor


Expand Down Expand Up @@ -246,12 +246,12 @@ def test_lookup_non_device(devicey_context: BlueskyContext) -> None:

def test_add_non_plan(empty_context: BlueskyContext) -> None:
with pytest.raises(TypeError):
empty_context.plan("not a plan") # type: ignore
empty_context.register_plan("not a plan") # type: ignore


def test_add_non_device(empty_context: BlueskyContext) -> None:
with pytest.raises(TypeError):
empty_context.device("not a device") # type: ignore
empty_context.register_device("not a device") # type: ignore


def test_add_devices_and_plans_from_modules_with_config(
Expand Down Expand Up @@ -362,8 +362,8 @@ def test_str_default(
empty_context: BlueskyContext, sim_motor: SynAxis, alt_motor: SynAxis
):
movable_ref = empty_context._reference(Movable)
empty_context.device(sim_motor)
empty_context.plan(has_default_reference)
empty_context.register_device(sim_motor)
empty_context.register_plan(has_default_reference)

spec = empty_context._type_spec_for_function(has_default_reference)
assert spec["m"][0] is movable_ref
Expand All @@ -373,16 +373,16 @@ def test_str_default(
model = empty_context.plans[has_default_reference.__name__].model
adapter = TypeAdapter(model)
assert adapter.validate_python({}).m is sim_motor # type: ignore
empty_context.device(alt_motor)
empty_context.register_device(alt_motor)
assert adapter.validate_python({"m": ALT_MOTOR_NAME}).m is alt_motor # type: ignore


def test_nested_str_default(
empty_context: BlueskyContext, sim_motor: SynAxis, alt_motor: SynAxis
):
movable_ref = empty_context._reference(Movable)
empty_context.device(sim_motor)
empty_context.plan(has_default_nested_reference)
empty_context.register_device(sim_motor)
empty_context.register_plan(has_default_nested_reference)

spec = empty_context._type_spec_for_function(has_default_nested_reference)
assert spec["m"][0] == list[movable_ref] # type: ignore
Expand All @@ -393,7 +393,7 @@ def test_nested_str_default(
adapter = TypeAdapter(model)

assert adapter.validate_python({}).m == [sim_motor] # type: ignore
empty_context.device(alt_motor)
empty_context.register_device(alt_motor)
assert adapter.validate_python({"m": [ALT_MOTOR_NAME]}).m == [alt_motor] # type: ignore


Expand All @@ -402,6 +402,6 @@ def a_plan(foo_bar: int, baz: str) -> MsgGenerator:
if False:
yield

empty_context.plan(a_plan)
empty_context.register_plan(a_plan)
with pytest.raises(ValidationError):
empty_context.plans[a_plan.__name__].model(fooBar=1, baz="test")
Loading

0 comments on commit ad46d8a

Please sign in to comment.