Skip to content

Commit

Permalink
store Mocks for Devices and MockSignalBackends for Signals in diction…
Browse files Browse the repository at this point in the history
…aries
  • Loading branch information
jsouter committed Oct 22, 2024
1 parent 9c1d9c2 commit 49c78f6
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 23 deletions.
18 changes: 14 additions & 4 deletions src/ophyd_async/core/_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from ._protocol import Connectable
from ._utils import DEFAULT_TIMEOUT, NotConnected, wait_for_connection

_device_mocks: dict[Device, Mock] = {}


class DeviceConnector:
async def connect(
Expand Down Expand Up @@ -56,7 +58,6 @@ def __init__(
self, name: str = "", connector: DeviceConnector | None = None
) -> None:
self._connector = connector or DeviceConnector()
self.mock = None
self.set_name(name)

@property
Expand Down Expand Up @@ -107,16 +108,18 @@ async def connect(
timeout:
Time to wait before failing with a TimeoutError.
"""
uses_mock = bool(mock)
can_use_previous_connect = (
bool(mock) is self._connect_mock_arg
uses_mock is self._connect_mock_arg
and self._connect_task
and not (self._connect_task.done() and self._connect_task.exception())
)
if mock is True:
mock = Mock() # create a new Mock if one not provided
self.mock = mock
if force_reconnect or not can_use_previous_connect:
self._connect_mock_arg = bool(mock)
self._connect_mock_arg = uses_mock
if self._connect_mock_arg:
_device_mocks[self] = mock
coro = self._connector.connect(
device=self, mock=mock, timeout=timeout, force_reconnect=force_reconnect
)
Expand All @@ -126,6 +129,10 @@ async def connect(
# Wait for it to complete
await self._connect_task

def __hash__(self):
# to allow Devices to be used as dict keys
return hash(id(self))


DeviceT = TypeVar("DeviceT", bound=Device)

Expand Down Expand Up @@ -180,6 +187,9 @@ def children(self) -> Iterator[tuple[str, Device]]:
for key, child in self._children.items():
yield str(key), child

def __hash__(self): # to allow DeviceVector to be used as dict keys
return hash(id(self))


class DeviceCollector:
"""Collector of top level Device instances to be used as a context manager
Expand Down
4 changes: 2 additions & 2 deletions src/ophyd_async/core/_mock_signal_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class MockSignalBackend(SignalBackend[SignalDatatypeT]):
def __init__(
self,
initial_backend: SignalBackend[SignalDatatypeT],
mock: bool | Mock = True,
mock: Mock,
) -> None:
if isinstance(initial_backend, MockSignalBackend):
raise ValueError("Cannot make a MockSignalBackend for a MockSignalBackend")
Expand All @@ -33,7 +33,7 @@ def __init__(
)

# use existing Mock if provided
self.mock = Mock() if isinstance(mock, bool) else mock
self.mock = mock
self.mock.attach_mock(AsyncMock(name="put", spec=Callable), "put")

super().__init__(datatype=self.initial_backend.datatype)
Expand Down
16 changes: 7 additions & 9 deletions src/ophyd_async/core/_mock_signal_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,17 @@
from typing import Any
from unittest.mock import AsyncMock, Mock

from ._device import Device
from ._device import Device, _device_mocks
from ._mock_signal_backend import MockSignalBackend
from ._signal import Signal
from ._signal import Signal, _mock_signal_backends
from ._soft_signal_backend import SignalDatatypeT


def _get_mock_signal_backend(signal: Signal) -> MockSignalBackend:
backend = signal._connector.backend # noqa:SLF001
assert isinstance(backend, MockSignalBackend), (
"Expected to receive a `MockSignalBackend`, instead "
f" received {type(backend)}. "
)
return backend
assert (
signal in _mock_signal_backends
), "No `MockSignalBackend` registered for signal."
return _mock_signal_backends[signal]


def set_mock_value(signal: Signal[SignalDatatypeT], value: SignalDatatypeT):
Expand Down Expand Up @@ -51,7 +49,7 @@ def get_mock_put(signal: Signal) -> AsyncMock:
def get_mock(device: Device | Signal) -> Mock:
if isinstance(device, Signal):
return _get_mock_signal_backend(device).mock
return device.mock
return _device_mocks[device]


def reset_mock_put_calls(signal: Signal):
Expand Down
3 changes: 3 additions & 0 deletions src/ophyd_async/core/_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
from ._status import AsyncStatus
from ._utils import CALCULATE_TIMEOUT, DEFAULT_TIMEOUT, CalculatableTimeout, Callback, T

_mock_signal_backends: dict[Device, MockSignalBackend] = {}


async def _wait_for(coro: Awaitable[T], timeout: float | None, source: str) -> T:
try:
Expand Down Expand Up @@ -61,6 +63,7 @@ async def connect(
):
if mock:
self.backend = MockSignalBackend(self._init_backend, mock)
_mock_signal_backends[device] = self.backend
else:
self.backend = self._init_backend
device.log.debug(f"Connecting to {self.backend.source(device.name, read=True)}")
Expand Down
5 changes: 1 addition & 4 deletions tests/core/test_mock_signal_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,7 @@ async def test_mock_utils_throw_error_if_backend_isnt_mock_signal_backend():
exc_msgs.append(str(exc.value))

for msg in exc_msgs:
assert msg == (
"Expected to receive a `MockSignalBackend`, instead "
f" received {SoftSignalBackend}. "
)
assert msg == "No `MockSignalBackend` registered for signal."


async def test_get_mock_put():
Expand Down
8 changes: 4 additions & 4 deletions tests/core/test_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
import time
from asyncio import Event
from unittest.mock import ANY
from unittest.mock import ANY, Mock

import pytest
from bluesky.protocols import Reading
Expand Down Expand Up @@ -42,7 +42,7 @@ def num_occurrences(substring: str, string: str) -> int:

async def test_signal_connects_to_previous_backend(caplog):
caplog.set_level(logging.DEBUG)
int_mock_backend = MockSignalBackend(SoftSignalBackend(int))
int_mock_backend = MockSignalBackend(SoftSignalBackend(int), Mock())
original_connect = int_mock_backend.connect
times_backend_connect_called = 0

Expand All @@ -61,7 +61,7 @@ async def new_connect(timeout=1):

async def test_signal_connects_with_force_reconnect(caplog):
caplog.set_level(logging.DEBUG)
signal = Signal(MockSignalBackend(SoftSignalBackend(int)))
signal = Signal(MockSignalBackend(SoftSignalBackend(int), Mock()))
await signal.connect()
assert num_occurrences(f"Connecting to {signal.source}", caplog.text) == 1
await signal.connect(force_reconnect=True)
Expand All @@ -80,7 +80,7 @@ async def connect(self, timeout=DEFAULT_TIMEOUT):
self.succeed_on_connect = True
raise RuntimeError("connect fail")

signal = SignalRW(MockSignalBackendFailingFirst(SoftSignalBackend(int)))
signal = SignalRW(MockSignalBackendFailingFirst(SoftSignalBackend(int), Mock()))

with pytest.raises(RuntimeError, match="connect fail"):
await signal.connect(mock=False)
Expand Down

0 comments on commit 49c78f6

Please sign in to comment.