diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 3bc586034..37a41f30d 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -59,9 +59,12 @@ def _runner() -> WorkerDispatcher: return RUNNER -def setup_runner(config: ApplicationConfig | None = None, use_subprocess: bool = True): +def setup_runner( + config: ApplicationConfig | None = None, + runner: WorkerDispatcher | None = None, +): global RUNNER - runner = WorkerDispatcher(config, use_subprocess) + runner = runner or WorkerDispatcher(config) runner.start() RUNNER = runner diff --git a/src/blueapi/service/runner.py b/src/blueapi/service/runner.py index be4c49ba6..d8d957055 100644 --- a/src/blueapi/service/runner.py +++ b/src/blueapi/service/runner.py @@ -8,7 +8,6 @@ from typing import Any, ParamSpec, TypeVar from observability_utils.tracing import ( - add_span_attributes, get_context_propagator, get_tracer, start_as_current_span, @@ -44,17 +43,19 @@ class WorkerDispatcher: _config: ApplicationConfig _subprocess: PoolClass | None - _use_subprocess: bool _state: EnvironmentResponse def __init__( self, config: ApplicationConfig | None = None, - use_subprocess: bool = True, + subprocess_factory: Callable[[], PoolClass] | None = None, ) -> None: + def default_subprocess_factory(): + return Pool(initializer=_init_worker, processes=1) + self._config = config or ApplicationConfig() self._subprocess = None - self._use_subprocess = use_subprocess + self._subprocess_factory = subprocess_factory or default_subprocess_factory self._state = EnvironmentResponse( initialized=False, ) @@ -68,12 +69,8 @@ def reload(self): @start_as_current_span(TRACER) def start(self): - add_span_attributes( - {"_use_subprocess": self._use_subprocess, "_config": str(self._config)} - ) try: - if self._use_subprocess: - self._subprocess = Pool(initializer=_init_worker, processes=1) + self._subprocess = self._subprocess_factory() self.run(setup, self._config) self._state = EnvironmentResponse(initialized=True) except Exception as e: @@ -107,40 +104,25 @@ def run( function: Callable[P, T], *args: P.args, **kwargs: P.kwargs, - ) -> T: - """Calls the supplied function, which is modified to accept a dict as it's new - first param, before being passed to the subprocess runner, or just run in place. - """ - add_span_attributes({"use_subprocess": self._use_subprocess}) - if self._use_subprocess: - return self._run_in_subprocess(function, *args, **kwargs) - else: - return function(*args, **kwargs) - - @start_as_current_span(TRACER, "function", "args", "kwargs") - def _run_in_subprocess( - self, - function: Callable[P, T], - *args: P.args, - **kwargs: P.kwargs, ) -> T: """Call the supplied function, passing the current Span ID, if one - exists,from the observability context inro the _rpc caller function. + exists,from the observability context into the import_and_run_function + caller function. + When this is deserialized in and run by the subprocess, this will allow its functions to use the corresponding span as their parent span.""" + if self._subprocess is None: raise InvalidRunnerStateError("Subprocess runner has not been started") if not (hasattr(function, "__name__") and hasattr(function, "__module__")): raise RpcError(f"{function} is anonymous, cannot be run in subprocess") - if not callable(function): - raise RpcError(f"{function} is not Callable, cannot be run in subprocess") try: return_type = inspect.signature(function).return_annotation except TypeError: return_type = None return self._subprocess.apply( - _rpc, + import_and_run_function, ( function.__module__, function.__name__, @@ -164,7 +146,7 @@ def __init__(self, message): class RpcError(Exception): ... -def _rpc( +def import_and_run_function( module_name: str, function_name: str, expected_type: type[T] | None, diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index 6e5fa6ee0..c18e98a68 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -1,7 +1,7 @@ import uuid from collections.abc import Iterator from dataclasses import dataclass -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, Mock, patch import jwt import pytest @@ -14,40 +14,62 @@ from blueapi.config import ApplicationConfig, OIDCConfig from blueapi.core.bluesky_types import Plan from blueapi.service import main +from blueapi.service.interface import ( + cancel_active_task, + get_device, + get_plan, + pause_worker, + resume_worker, + submit_task, +) from blueapi.service.model import ( DeviceModel, + EnvironmentResponse, PlanModel, StateChangeRequest, WorkerTask, ) +from blueapi.service.runner import WorkerDispatcher from blueapi.worker.event import WorkerState from blueapi.worker.task import Task from blueapi.worker.task_worker import TrackableTask +class MockCountModel(BaseModel): ... + + +COUNT = Plan(name="count", model=MockCountModel) + + +@pytest.fixture +def mock_runner() -> Mock: + return Mock(spec=WorkerDispatcher) + + @pytest.fixture -def client() -> Iterator[TestClient]: +def client(mock_runner: Mock) -> Iterator[TestClient]: with patch("blueapi.service.interface.worker"): - main.setup_runner(use_subprocess=False) + main.setup_runner(runner=mock_runner) yield TestClient(main.get_app(ApplicationConfig())) main.teardown_runner() @pytest.fixture -def client_with_auth(oidc_config: OIDCConfig) -> Iterator[TestClient]: +def client_with_auth( + mock_runner: Mock, oidc_config: OIDCConfig +) -> Iterator[TestClient]: with patch("blueapi.service.interface.worker"): - main.setup_runner(use_subprocess=False) + main.setup_runner(runner=mock_runner) yield TestClient(main.get_app(ApplicationConfig(oidc=oidc_config))) main.teardown_runner() -@patch("blueapi.service.interface.get_plans") -def test_get_plans(get_plans_mock: MagicMock, client: TestClient) -> None: +def test_get_plans(mock_runner: Mock, client: TestClient) -> None: class MyModel(BaseModel): id: str plan = Plan(name="my-plan", model=MyModel) - get_plans_mock.return_value = [PlanModel.from_plan(plan)] + mock_runner.run.return_value = [PlanModel.from_plan(plan)] response = client.get("/plans") @@ -68,17 +90,16 @@ class MyModel(BaseModel): } -@patch("blueapi.service.interface.get_plan") -def test_get_plan_by_name(get_plan_mock: MagicMock, client: TestClient) -> None: +def test_get_plan_by_name(mock_runner: Mock, client: TestClient) -> None: class MyModel(BaseModel): id: str plan = Plan(name="my-plan", model=MyModel) - get_plan_mock.return_value = PlanModel.from_plan(plan) + mock_runner.run.return_value = PlanModel.from_plan(plan) response = client.get("/plans/my-plan") - get_plan_mock.assert_called_once_with("my-plan") + mock_runner.run.assert_called_once_with(get_plan, "my-plan") assert response.status_code == status.HTTP_200_OK assert response.json() == { "description": None, @@ -92,25 +113,21 @@ class MyModel(BaseModel): } -@patch("blueapi.service.interface.get_plan") -def test_get_non_existent_plan_by_name( - get_plan_mock: MagicMock, client: TestClient -) -> None: - get_plan_mock.side_effect = KeyError("my-plan") +def test_get_non_existent_plan_by_name(mock_runner: Mock, client: TestClient) -> None: + mock_runner.run.side_effect = KeyError("my-plan") response = client.get("/plans/my-plan") assert response.status_code == status.HTTP_404_NOT_FOUND assert response.json() == {"detail": "Item not found"} -@patch("blueapi.service.interface.get_devices") -def test_get_devices(get_devices_mock: MagicMock, client: TestClient) -> None: +def test_get_devices(mock_runner: Mock, client: TestClient) -> None: @dataclass class MyDevice: name: str device = MyDevice("my-device") - get_devices_mock.return_value = [DeviceModel.from_device(device)] + mock_runner.run.return_value = [DeviceModel.from_device(device)] response = client.get("/devices") @@ -125,18 +142,17 @@ class MyDevice: } -@patch("blueapi.service.interface.get_device") -def test_get_device_by_name(get_device_mock: MagicMock, client: TestClient) -> None: +def test_get_device_by_name(mock_runner: Mock, client: TestClient) -> None: @dataclass class MyDevice: name: str device = MyDevice("my-device") - get_device_mock.return_value = DeviceModel.from_device(device) + mock_runner.run.return_value = DeviceModel.from_device(device) response = client.get("/devices/my-device") - get_device_mock.assert_called_once_with("my-device") + mock_runner.run.assert_called_once_with(get_device, "my-device") assert response.status_code == status.HTTP_200_OK assert response.json() == { "name": "my-device", @@ -144,51 +160,44 @@ class MyDevice: } -@patch("blueapi.service.interface.get_device") -def test_get_non_existent_device_by_name( - get_device_mock: MagicMock, client: TestClient -) -> None: - get_device_mock.side_effect = KeyError("my-device") +def test_get_non_existent_device_by_name(mock_runner: Mock, client: TestClient) -> None: + mock_runner.run.side_effect = KeyError("my-device") response = client.get("/devices/my-device") assert response.status_code == status.HTTP_404_NOT_FOUND assert response.json() == {"detail": "Item not found"} -@patch("blueapi.service.interface.submit_task") -@patch("blueapi.service.interface.get_plan") -def test_create_task( - get_plan_mock: MagicMock, submit_task_mock: MagicMock, client: TestClient -) -> None: +def test_create_task(mock_runner: Mock, client: TestClient) -> None: task = Task(name="count", params={"detectors": ["x"]}) task_id = str(uuid.uuid4()) - submit_task_mock.return_value = task_id + mock_runner.run.side_effect = [COUNT, task_id] response = client.post("/tasks", json=task.model_dump()) - submit_task_mock.assert_called_once_with(task) + mock_runner.run.assert_called_with(submit_task, task) assert response.json() == {"task_id": task_id} -@patch("blueapi.service.interface.submit_task") -@patch("blueapi.service.interface.get_plan") -def test_create_task_validation_error( - get_plan_mock: MagicMock, submit_task_mock: MagicMock, client: TestClient -) -> None: +def test_create_task_validation_error(mock_runner: Mock, client: TestClient) -> None: class MyModel(BaseModel): id: str plan = Plan(name="my-plan", model=MyModel) - get_plan_mock.return_value = PlanModel.from_plan(plan) - submit_task_mock.side_effect = ValidationError.from_exception_data( - title="ValueError", - line_errors=[ - InitErrorDetails( - type="missing", loc=("id",), msg="value is required for Identifier" - ) # type: ignore - ], - ) + + mock_runner.run.side_effect = [ + PlanModel.from_plan(plan), + ValidationError.from_exception_data( + title="ValueError", + line_errors=[ + InitErrorDetails( + type="missing", loc=("id",), msg="value is required for Identifier" + ) # type: ignore + ], + ), + ] + response = client.post("/tasks", json={"name": "my-plan"}) assert response.status_code == 422 assert response.json() == { @@ -202,32 +211,21 @@ class MyModel(BaseModel): } -@patch("blueapi.service.interface.begin_task") -@patch("blueapi.service.interface.get_active_task") -def test_put_plan_begins_task( - get_active_task_mock: MagicMock, begin_task_mock: MagicMock, client: TestClient -) -> None: +def test_put_plan_begins_task(client: TestClient) -> None: task_id = "04cd9aa6-b902-414b-ae4b-49ea4200e957" - # Set to idle - get_active_task_mock.return_value = None - begin_task_mock.return_value = WorkerTask(task_id=task_id) - resp = client.put("/worker/task", json={"task_id": task_id}) assert resp.status_code == status.HTTP_200_OK assert resp.json() == {"task_id": task_id} -@patch("blueapi.service.interface.get_active_task") -def test_put_plan_fails_if_not_idle( - get_active_task_mock: MagicMock, client: TestClient -) -> None: +def test_put_plan_fails_if_not_idle(mock_runner: Mock, client: TestClient) -> None: task_id_current = "260f7de3-b608-4cdc-a66c-257e95809792" task_id_new = "07e98d68-21b5-4ad7-ac34-08b2cb992d42" # Set to non idle - get_active_task_mock.return_value = TrackableTask( + mock_runner.run.return_value = TrackableTask( task=None, task_id=task_id_current, is_complete=False ) @@ -237,8 +235,7 @@ def test_put_plan_fails_if_not_idle( assert resp.json() == {"detail": "Worker already active"} -@patch("blueapi.service.interface.get_tasks") -def test_get_tasks(get_tasks_mock: MagicMock, client: TestClient) -> None: +def test_get_tasks(mock_runner: Mock, client: TestClient) -> None: tasks = [ TrackableTask(task_id="0", task=Task(name="sleep", params={"time": 0.0})), TrackableTask( @@ -249,7 +246,7 @@ def test_get_tasks(get_tasks_mock: MagicMock, client: TestClient) -> None: ), ] - get_tasks_mock.return_value = tasks + mock_runner.run.return_value = tasks response = client.get("/tasks") assert response.status_code == status.HTTP_200_OK @@ -276,10 +273,7 @@ def test_get_tasks(get_tasks_mock: MagicMock, client: TestClient) -> None: } -@patch("blueapi.service.interface.get_tasks_by_status") -def test_get_tasks_by_status( - get_tasks_by_status_mock: MagicMock, client: TestClient -) -> None: +def test_get_tasks_by_status(mock_runner: Mock, client: TestClient) -> None: tasks = [ TrackableTask( task_id="3", @@ -289,7 +283,7 @@ def test_get_tasks_by_status( ), ] - get_tasks_by_status_mock.return_value = tasks + mock_runner.run.return_value = tasks response = client.get("/tasks", params={"task_status": "PENDING"}) assert response.json() == { @@ -311,19 +305,14 @@ def test_get_tasks_by_status_invalid(client: TestClient) -> None: assert response.status_code == status.HTTP_400_BAD_REQUEST -@patch("blueapi.service.interface.clear_task") -def test_delete_submitted_task(clear_task_mock: MagicMock, client: TestClient) -> None: +def test_delete_submitted_task(mock_runner: Mock, client: TestClient) -> None: task_id = str(uuid.uuid4()) - clear_task_mock.return_value = task_id + mock_runner.run.return_value = task_id response = client.delete(f"/tasks/{task_id}") assert response.json() == {"task_id": f"{task_id}"} -@patch("blueapi.service.interface.begin_task") -@patch("blueapi.service.interface.get_active_task") -def test_set_active_task( - get_active_task_mock: MagicMock, begin_task_mock: MagicMock, client: TestClient -) -> None: +def test_set_active_task(client: TestClient) -> None: task_id = str(uuid.uuid4()) task = WorkerTask(task_id=task_id) @@ -333,15 +322,13 @@ def test_set_active_task( assert response.json() == {"task_id": f"{task_id}"} -@patch("blueapi.service.interface.begin_task") -@patch("blueapi.service.interface.get_active_task") def test_set_active_task_active_task_complete( - get_active_task_mock: MagicMock, begin_task_mock: MagicMock, client: TestClient + mock_runner: Mock, client: TestClient ) -> None: task_id = str(uuid.uuid4()) task = WorkerTask(task_id=task_id) - get_active_task_mock.return_value = TrackableTask( + mock_runner.run.return_value = TrackableTask( task_id="1", task=Task(name="a_completed_task"), is_complete=True, @@ -354,15 +341,13 @@ def test_set_active_task_active_task_complete( assert response.json() == {"task_id": f"{task_id}"} -@patch("blueapi.service.interface.begin_task") -@patch("blueapi.service.interface.get_active_task") def test_set_active_task_worker_already_running( - get_active_task_mock: MagicMock, begin_task_mock: MagicMock, client: TestClient + mock_runner: Mock, client: TestClient ) -> None: task_id = str(uuid.uuid4()) task = WorkerTask(task_id=task_id) - get_active_task_mock.return_value = TrackableTask( + mock_runner.run.return_value = TrackableTask( task_id="1", task=Task(name="a_running_task"), is_complete=False, @@ -375,15 +360,14 @@ def test_set_active_task_worker_already_running( assert response.json() == {"detail": "Worker already active"} -@patch("blueapi.service.interface.get_task_by_id") -def test_get_task(get_task_by_id: MagicMock, client: TestClient): +def test_get_task(mock_runner: Mock, client: TestClient): task_id = str(uuid.uuid4()) task = TrackableTask( task_id=task_id, task=Task(name="third_task"), ) - get_task_by_id.return_value = task + mock_runner.run.return_value = task response = client.get(f"/tasks/{task_id}") assert response.json() == { @@ -396,8 +380,7 @@ def test_get_task(get_task_by_id: MagicMock, client: TestClient): } -@patch("blueapi.service.interface.get_tasks") -def test_get_all_tasks(get_all_tasks: MagicMock, client: TestClient): +def test_get_all_tasks(mock_runner: Mock, client: TestClient): task_id = str(uuid.uuid4()) tasks = [ TrackableTask( @@ -406,7 +389,7 @@ def test_get_all_tasks(get_all_tasks: MagicMock, client: TestClient): ) ] - get_all_tasks.return_value = tasks + mock_runner.run.return_value = tasks response = client.get("/tasks") assert response.status_code == status.HTTP_200_OK assert response.json() == { @@ -423,138 +406,108 @@ def test_get_all_tasks(get_all_tasks: MagicMock, client: TestClient): } -@patch("blueapi.service.interface.get_task_by_id") -def test_get_task_error(get_task_by_id_mock: MagicMock, client: TestClient): +def test_get_task_error(mock_runner: Mock, client: TestClient): task_id = 567 - get_task_by_id_mock.return_value = None + mock_runner.run.return_value = None response = client.get(f"/tasks/{task_id}") assert response.json() == {"detail": "Item not found"} -@patch("blueapi.service.interface.get_active_task") -def test_get_active_task(get_active_task_mock: MagicMock, client: TestClient): +def test_get_active_task(mock_runner: Mock, client: TestClient): task_id = str(uuid.uuid4()) task = TrackableTask( task_id=task_id, task=Task(name="third_task"), ) - get_active_task_mock.return_value = task + mock_runner.run.return_value = task response = client.get("/worker/task") assert response.json() == {"task_id": f"{task_id}"} -@patch("blueapi.service.interface.get_active_task") -def test_get_active_task_none(get_active_task_mock: MagicMock, client: TestClient): - get_active_task_mock.return_value = None +def test_get_active_task_none(mock_runner: Mock, client: TestClient): + mock_runner.run.return_value = None response = client.get("/worker/task") assert response.json() == {"task_id": None} -@patch("blueapi.service.interface.get_worker_state") -def test_get_state(get_worker_state_mock: MagicMock, client: TestClient): +def test_get_state(mock_runner: Mock, client: TestClient): state = WorkerState.SUSPENDING - get_worker_state_mock.return_value = state + mock_runner.run.return_value = state response = client.get("/worker/state") assert response.json() == state -@patch("blueapi.service.interface.pause_worker") -@patch("blueapi.service.interface.get_worker_state") -def test_set_state_running_to_paused( - get_worker_state_mock: MagicMock, pause_worker_mock: MagicMock, client: TestClient -): +def test_set_state_running_to_paused(mock_runner: Mock, client: TestClient): current_state = WorkerState.RUNNING final_state = WorkerState.PAUSED - get_worker_state_mock.side_effect = [current_state, final_state] + mock_runner.run.side_effect = [current_state, None, final_state] response = client.put( "/worker/state", json=StateChangeRequest(new_state=final_state).model_dump() ) - pause_worker_mock.assert_called_once_with(False) + mock_runner.run.assert_any_call(pause_worker, False) assert response.status_code == status.HTTP_202_ACCEPTED assert response.json() == final_state -@patch("blueapi.service.interface.resume_worker") -@patch("blueapi.service.interface.get_worker_state") -def test_set_state_paused_to_running( - get_worker_state_mock: MagicMock, resume_worker_mock: MagicMock, client: TestClient -): +def test_set_state_paused_to_running(mock_runner: Mock, client: TestClient): current_state = WorkerState.PAUSED final_state = WorkerState.RUNNING - get_worker_state_mock.side_effect = [current_state, final_state] + mock_runner.run.side_effect = [current_state, None, final_state] response = client.put( "/worker/state", json=StateChangeRequest(new_state=final_state).model_dump() ) - resume_worker_mock.assert_called_once() + mock_runner.run.assert_any_call(resume_worker) assert response.status_code == status.HTTP_202_ACCEPTED assert response.json() == final_state -@patch("blueapi.service.interface.cancel_active_task") -@patch("blueapi.service.interface.get_worker_state") -def test_set_state_running_to_aborting( - get_worker_state_mock: MagicMock, - cancel_active_task_mock: MagicMock, - client: TestClient, -): +def test_set_state_running_to_aborting(mock_runner: Mock, client: TestClient): current_state = WorkerState.RUNNING final_state = WorkerState.ABORTING - get_worker_state_mock.side_effect = [current_state, final_state] + mock_runner.run.side_effect = [current_state, None, final_state] response = client.put( "/worker/state", json=StateChangeRequest(new_state=final_state).model_dump() ) - cancel_active_task_mock.assert_called_once_with(True, None) + mock_runner.run.assert_any_call(cancel_active_task, True, None) assert response.status_code == status.HTTP_202_ACCEPTED assert response.json() == final_state -@patch("blueapi.service.interface.cancel_active_task") -@patch("blueapi.service.interface.get_worker_state") def test_set_state_running_to_stopping_including_reason( - get_worker_state_mock: MagicMock, - cancel_active_task_mock: MagicMock, - client: TestClient, + mock_runner: Mock, client: TestClient ): current_state = WorkerState.RUNNING final_state = WorkerState.STOPPING reason = "blueapi is being stopped" - get_worker_state_mock.side_effect = [current_state, final_state] + mock_runner.run.side_effect = [current_state, None, final_state] response = client.put( "/worker/state", json=StateChangeRequest(new_state=final_state, reason=reason).model_dump(), ) - cancel_active_task_mock.assert_called_once_with(False, reason) + mock_runner.run.assert_any_call(cancel_active_task, False, reason) assert response.status_code == status.HTTP_202_ACCEPTED assert response.json() == final_state -@patch("blueapi.service.interface.cancel_active_task") -@patch("blueapi.service.interface.get_worker_state") -def test_set_state_transition_error( - get_worker_state_mock: MagicMock, - cancel_active_task_mock: MagicMock, - client: TestClient, -): +def test_set_state_transition_error(mock_runner: Mock, client: TestClient): current_state = WorkerState.RUNNING final_state = WorkerState.STOPPING - get_worker_state_mock.side_effect = [current_state, final_state] - - cancel_active_task_mock.side_effect = TransitionError() + mock_runner.run.side_effect = [current_state, TransitionError(), final_state] response = client.put( "/worker/state", @@ -565,15 +518,12 @@ def test_set_state_transition_error( assert response.json() == final_state -@patch("blueapi.service.interface.get_worker_state") -def test_set_state_invalid_transition( - get_worker_state_mock: MagicMock, client: TestClient -): +def test_set_state_invalid_transition(mock_runner: Mock, client: TestClient): current_state = WorkerState.STOPPING requested_state = WorkerState.PAUSED final_state = WorkerState.STOPPING - get_worker_state_mock.side_effect = [current_state, final_state] + mock_runner.run.side_effect = [current_state, final_state] response = client.put( "/worker/state", @@ -584,14 +534,19 @@ def test_set_state_invalid_transition( assert response.json() == final_state -def test_get_environment_idle(client: TestClient) -> None: +def test_get_environment_idle(mock_runner: Mock, client: TestClient) -> None: + mock_runner.state = EnvironmentResponse( + initialized=True, + error_message=None, + ) + assert client.get("/environment").json() == { "initialized": True, "error_message": None, } -def test_delete_environment(client: TestClient) -> None: +def test_delete_environment(mock_runner: Mock, client: TestClient) -> None: response = client.delete("/environment") assert response.status_code is status.HTTP_200_OK @@ -604,11 +559,8 @@ def test_subprocess_enabled_by_default(mp_pool_mock: MagicMock): main.teardown_runner() -@patch("blueapi.service.interface.get_device") -def test_get_without_authentication( - get_device_mock: MagicMock, client: TestClient -) -> None: - get_device_mock.side_effect = jwt.PyJWTError +def test_get_without_authentication(mock_runner: Mock, client: TestClient) -> None: + mock_runner.run.side_effect = jwt.PyJWTError response = client.get("/devices/my-device") assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -621,14 +573,13 @@ def test_oidc_config_not_found_when_auth_is_disabled(client: TestClient): assert response.json() == {"detail": "Not Found"} -@patch("blueapi.service.interface.get_oidc_config") def test_get_oidc_config( - get_oidc_config: MagicMock, + mock_runner: Mock, oidc_config: OIDCConfig, mock_authn_server, client_with_auth: TestClient, ): - get_oidc_config.return_value = oidc_config + mock_runner.run.return_value = oidc_config response = client_with_auth.get("/config/oidc") assert response.status_code == status.HTTP_200_OK assert response.json() == oidc_config.model_dump() diff --git a/tests/unit_tests/service/test_runner.py b/tests/unit_tests/service/test_runner.py index 1162d8108..2c7c4b7b0 100644 --- a/tests/unit_tests/service/test_runner.py +++ b/tests/unit_tests/service/test_runner.py @@ -1,6 +1,6 @@ +from multiprocessing.pool import Pool as PoolClass from typing import Any, Generic, TypeVar -from unittest import mock -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, Mock, patch import pytest from observability_utils.tracing import ( @@ -16,17 +16,19 @@ InvalidRunnerStateError, RpcError, WorkerDispatcher, + import_and_run_function, ) @pytest.fixture -def local_runner(): - return WorkerDispatcher(use_subprocess=False) +def mock_subprocess() -> Mock: + subprocess = Mock(spec=PoolClass) + return subprocess @pytest.fixture -def runner(): - return WorkerDispatcher() +def runner(mock_subprocess: Mock): + return WorkerDispatcher(subprocess_factory=lambda: mock_subprocess) @pytest.fixture @@ -36,13 +38,22 @@ def started_runner(runner: WorkerDispatcher): runner.stop() -def test_initialize(runner: WorkerDispatcher): +def test_initialize(runner: WorkerDispatcher, mock_subprocess: Mock): + mock_subprocess.apply.return_value = None + + assert runner.state.error_message is None assert not runner.state.initialized runner.start() + + assert runner.state.error_message is None assert runner.state.initialized + # Run a single call to the runner for coverage of dispatch to subprocess - assert runner.run(interface.get_worker_state) + mock_subprocess.apply.return_value = 123 + assert runner.run(interface.get_worker_state) == 123 runner.stop() + + assert runner.state.error_message is None assert not runner.state.initialized @@ -59,22 +70,20 @@ def test_raises_if_used_before_started(runner: WorkerDispatcher): runner.run(interface.get_plans) -def test_error_on_runner_setup(local_runner: WorkerDispatcher): +def test_error_on_runner_setup(runner: WorkerDispatcher, mock_subprocess: Mock): + error_message = "Intentional start_worker exception" expected_state = EnvironmentResponse( initialized=False, - error_message="Intentional start_worker exception", + error_message=error_message, ) + mock_subprocess.apply.side_effect = Exception(error_message) - with mock.patch( - "blueapi.service.runner.setup", - side_effect=Exception("Intentional start_worker exception"), - ): - # Calling reload here instead of start also indirectly - # tests that stop() doesn't raise if there is no error message - # and the runner is not yet initialised - local_runner.reload() - state = local_runner.state - assert state == expected_state + # Calling reload here instead of start also indirectly + # tests that stop() doesn't raise if there is no error message + # and the runner is not yet initialised + runner.reload() + state = runner.state + assert state == expected_state def start_worker_mock(): @@ -99,7 +108,7 @@ def test_can_reload_after_an_error(pool_mock: MagicMock): another_mock.apply.side_effect = subprocess_calls_return_values - runner = WorkerDispatcher(use_subprocess=True) + runner = WorkerDispatcher() runner.start() assert runner.state == EnvironmentResponse( @@ -111,55 +120,59 @@ def test_can_reload_after_an_error(pool_mock: MagicMock): assert runner.state == EnvironmentResponse(initialized=True, error_message=None) -def test_function_not_findable_on_subprocess(started_runner: WorkerDispatcher): - from tests.unit_tests.core.fake_device_module import fake_motor_y - - # Valid target on main but not sub process - # Change in this process not reflected in subprocess - fake_motor_y.__name__ = "not_exported" - - with pytest.raises( - RpcError, match="not_exported: No such function in subprocess API" - ): - started_runner.run(fake_motor_y) +@patch("blueapi.service.runner.Pool") +def test_subprocess_enabled_by_default(pool_mock: MagicMock): + runner = WorkerDispatcher() + runner.start() + pool_mock.assert_called_once() + runner.stop() -def test_non_callable_excepts_in_main_process(started_runner: WorkerDispatcher): - # Not a valid target on main or sub process - from tests.unit_tests.core.fake_device_module import fetchable_non_callable +def test_clear_message_for_anonymous_function(started_runner: WorkerDispatcher): + non_fetchable_callable = MagicMock() with pytest.raises( RpcError, - match=" is not Callable, " - + "cannot be run in subprocess", + match=" is anonymous, cannot be run in subprocess", ): - started_runner.run(fetchable_non_callable) + started_runner.run(non_fetchable_callable) -def test_non_callable_excepts_in_sub_process(started_runner: WorkerDispatcher): - # Valid target on main but finds non-callable in sub process - from tests.unit_tests.core.fake_device_module import ( - fetchable_callable, - fetchable_non_callable, - ) +def test_function_not_findable_on_subprocess(): + with pytest.raises(RpcError, match="unknown: No such function in subprocess API"): + import_and_run_function("blueapi", "unknown", None, {}) - fetchable_callable.__name__ = fetchable_non_callable.__name__ - with pytest.raises( - RpcError, - match="fetchable_non_callable: Object in subprocess is not a function", - ): - started_runner.run(fetchable_callable) +def test_module_not_findable_on_subprocess(): + with pytest.raises(ModuleNotFoundError): + import_and_run_function("unknown", "unknown", None, {}) -def test_clear_message_for_anonymous_function(started_runner: WorkerDispatcher): - non_fetchable_callable = MagicMock() +def run_rpc_function( + func: Callable[..., Any], + expected_type: type[Any], + *args: Any, + **kwargs: Any, +) -> Any: + import_and_run_function( + func.__module__, + func.__name__, + expected_type, + {}, + *args, + **kwargs, + ) + + +def test_non_callable_excepts(started_runner: WorkerDispatcher): + # Not a valid target on main or sub process + from tests.unit_tests.core.fake_device_module import fetchable_non_callable with pytest.raises( RpcError, - match=" is anonymous, cannot be run in subprocess", + match="fetchable_non_callable: Object in subprocess is not a function", ): - started_runner.run(non_fetchable_callable) + run_rpc_function(fetchable_non_callable, Mock) def test_clear_message_for_wrong_return(started_runner: WorkerDispatcher): @@ -169,7 +182,7 @@ def test_clear_message_for_wrong_return(started_runner: WorkerDispatcher): ValidationError, match="1 validation error for int", ): - started_runner.run(wrong_return_type) + run_rpc_function(wrong_return_type, int) T = TypeVar("T") diff --git a/tests/unit_tests/test_cli.py b/tests/unit_tests/test_cli.py index 72e98c9d2..10d5e03ff 100644 --- a/tests/unit_tests/test_cli.py +++ b/tests/unit_tests/test_cli.py @@ -241,7 +241,7 @@ def test_reset_env_client_behavior( reload_result = runner.invoke(main, ["controller", "env", "-r"]) # Verify if sleep was called between polling iterations - assert mock_sleep.call_count == 2 # Since the last check doesn't require a sleep + mock_sleep.assert_called() for index, call in enumerate(responses.calls): if index == 0: