Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow shared parent mock to be passed to Device.connect #599

Merged
merged 4 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/ophyd_async/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ._mock_signal_backend import MockSignalBackend
from ._mock_signal_utils import (
callback_on_mock_put,
get_mock,
get_mock_put,
mock_puts_blocked,
reset_mock_put_calls,
Expand Down Expand Up @@ -116,6 +117,7 @@
"config_ophyd_async_logging",
"MockSignalBackend",
"callback_on_mock_put",
"get_mock",
"get_mock_put",
"mock_puts_blocked",
"reset_mock_put_calls",
Expand Down
29 changes: 20 additions & 9 deletions src/ophyd_async/core/_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
from collections.abc import Coroutine, Iterator, Mapping, MutableMapping
from logging import LoggerAdapter, getLogger
from typing import Any, TypeVar
from unittest.mock import Mock

from bluesky.protocols import HasName
from bluesky.run_engine import call_in_bluesky_event_loop, in_bluesky_event_loop

from ._protocol import Connectable
from ._utils import DEFAULT_TIMEOUT, NotConnected, wait_for_connection

_device_mocks: dict[Device, Mock] = {}


class DeviceConnector:
"""Defines how a `Device` should be connected and type hints processed."""
Expand All @@ -37,7 +40,7 @@ def create_children_from_annotations(self, device: Device):
async def connect(
self,
device: Device,
mock: bool,
mock: bool | Mock,
timeout: float,
force_reconnect: bool,
):
Expand All @@ -47,12 +50,12 @@ async def connect(
done in a different mock more. It should connect the Device and all its
children.
"""
coros = {
name: child_device.connect(
mock=mock, timeout=timeout, force_reconnect=force_reconnect
coros = {}
for name, child_device in device.children():
child_mock = getattr(mock, name) if mock else mock # Mock() or False
coros[name] = child_device.connect(
mock=child_mock, timeout=timeout, force_reconnect=force_reconnect
)
for name, child_device in device.children()
}
await wait_for_connection(**coros)


Expand Down Expand Up @@ -114,7 +117,7 @@ def __setattr__(self, name: str, value: Any) -> None:

async def connect(
self,
mock: bool = False,
mock: bool | Mock = False,
timeout: float = DEFAULT_TIMEOUT,
force_reconnect: bool = False,
) -> None:
Expand All @@ -129,13 +132,18 @@ async def connect(
timeout:
Time to wait before failing with a TimeoutError.
"""
uses_mock = bool(mock)
can_use_previous_connect = (
mock is self._connect_mock_arg
uses_mock is self._connect_mock_arg
and self._connect_task
and not (self._connect_task.done() and self._connect_task.exception())
)
if mock is True:
mock = Mock() # create a new Mock if one not provided
if force_reconnect or not can_use_previous_connect:
self._connect_mock_arg = mock
self._connect_mock_arg = uses_mock
if self._connect_mock_arg:
_device_mocks[self] = mock
coro = self._connector.connect(
device=self, mock=mock, timeout=timeout, force_reconnect=force_reconnect
)
Expand Down Expand Up @@ -198,6 +206,9 @@ def children(self) -> Iterator[tuple[str, Device]]:
for key, child in self._children.items():
yield str(key), child

def __hash__(self): # to allow DeviceVector to be used as dict keys and in sets
return hash(id(self))


class DeviceCollector:
"""Collector of top level Device instances to be used as a context manager
Expand Down
18 changes: 12 additions & 6 deletions src/ophyd_async/core/_mock_signal_backend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
from collections.abc import Callable
from functools import cached_property
from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, Mock

from bluesky.protocols import Descriptor, Reading

Expand All @@ -13,7 +13,11 @@
class MockSignalBackend(SignalBackend[SignalDatatypeT]):
"""Signal backend for testing, created by ``Device.connect(mock=True)``."""

def __init__(self, initial_backend: SignalBackend[SignalDatatypeT]) -> None:
def __init__(
self,
initial_backend: SignalBackend[SignalDatatypeT],
mock: Mock,
) -> None:
if isinstance(initial_backend, MockSignalBackend):
raise ValueError("Cannot make a MockSignalBackend for a MockSignalBackend")

Expand All @@ -27,6 +31,12 @@ def __init__(self, initial_backend: SignalBackend[SignalDatatypeT]) -> None:
self.soft_backend = SoftSignalBackend(
datatype=self.initial_backend.datatype
)

# use existing Mock if provided
self.mock = mock
self.put_mock = AsyncMock(name="put", spec=Callable)
self.mock.attach_mock(self.put_mock, "put")

super().__init__(datatype=self.initial_backend.datatype)

def set_value(self, value: SignalDatatypeT):
Expand All @@ -38,10 +48,6 @@ def source(self, name: str, read: bool) -> str:
async def connect(self, timeout: float) -> None:
pass

@cached_property
def put_mock(self) -> AsyncMock:
return AsyncMock(name="put", spec=Callable)

@cached_property
def put_proceeds(self) -> asyncio.Event:
put_proceeds = asyncio.Event()
Expand Down
21 changes: 13 additions & 8 deletions src/ophyd_async/core/_mock_signal_utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
from collections.abc import Awaitable, Callable, Iterable
from contextlib import asynccontextmanager, contextmanager
from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, Mock

from ._device import Device, _device_mocks
from ._mock_signal_backend import MockSignalBackend
from ._signal import Signal, SignalR
from ._signal import Signal, SignalR, _mock_signal_backends
from ._soft_signal_backend import SignalDatatypeT


def _get_mock_signal_backend(signal: Signal) -> MockSignalBackend:
backend = signal._connector.backend # noqa:SLF001
assert isinstance(backend, MockSignalBackend), (
"Expected to receive a `MockSignalBackend`, instead "
f" received {type(backend)}. "
)
return backend
assert (
signal in _mock_signal_backends
), f"Signal {signal} not connected in mock mode"
return _mock_signal_backends[signal]


def set_mock_value(signal: Signal[SignalDatatypeT], value: SignalDatatypeT):
Expand Down Expand Up @@ -46,6 +45,12 @@ def get_mock_put(signal: Signal) -> AsyncMock:
return _get_mock_signal_backend(signal).put_mock


def get_mock(device: Device | Signal) -> Mock:
if isinstance(device, Signal):
return _get_mock_signal_backend(device).mock
return _device_mocks[device]


def reset_mock_put_calls(signal: Signal):
backend = _get_mock_signal_backend(signal)
coretl marked this conversation as resolved.
Show resolved Hide resolved
backend.put_mock.reset_mock()
Expand Down
4 changes: 3 additions & 1 deletion src/ophyd_async/core/_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from ._utils import DEFAULT_TIMEOUT

if TYPE_CHECKING:
from unittest.mock import Mock

from ._status import AsyncStatus


Expand All @@ -24,7 +26,7 @@ class Connectable(Protocol):
@abstractmethod
async def connect(
self,
mock: bool = False,
mock: bool | Mock = False,
timeout: float = DEFAULT_TIMEOUT,
force_reconnect: bool = False,
):
Expand Down
8 changes: 6 additions & 2 deletions src/ophyd_async/core/_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import functools
from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping
from typing import Any, Generic, cast
from unittest.mock import Mock

from bluesky.protocols import (
Locatable,
Expand Down Expand Up @@ -31,6 +32,8 @@
from ._status import AsyncStatus
from ._utils import CALCULATE_TIMEOUT, DEFAULT_TIMEOUT, CalculatableTimeout, Callback, T

_mock_signal_backends: dict[Device, MockSignalBackend] = {}


async def _wait_for(coro: Awaitable[T], timeout: float | None, source: str) -> T:
try:
Expand All @@ -54,12 +57,13 @@ def __init__(self, backend: SignalBackend):
async def connect(
self,
device: Device,
mock: bool,
mock: bool | Mock,
timeout: float,
force_reconnect: bool,
):
if mock:
self.backend = MockSignalBackend(self._init_backend)
self.backend = MockSignalBackend(self._init_backend, mock)
_mock_signal_backends[device] = self.backend
else:
self.backend = self._init_backend
device.log.debug(f"Connecting to {self.backend.source(device.name, read=True)}")
Expand Down
4 changes: 3 additions & 1 deletion src/ophyd_async/epics/pvi/_pvi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from unittest.mock import Mock

from ophyd_async.core import (
Device,
DeviceConnector,
Expand Down Expand Up @@ -41,7 +43,7 @@ def create_children_from_annotations(self, device: Device):
)

async def connect(
self, device: Device, mock: bool, timeout: float, force_reconnect: bool
self, device: Device, mock: bool | Mock, timeout: float, force_reconnect: bool
) -> None:
if mock:
# Make 2 entries for each DeviceVector
Expand Down
4 changes: 3 additions & 1 deletion src/ophyd_async/plan_stubs/_ensure_connected.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from unittest.mock import Mock

import bluesky.plan_stubs as bps

from ophyd_async.core import DEFAULT_TIMEOUT, Device, wait_for_connection


def ensure_connected(
*devices: Device,
mock: bool = False,
mock: bool | Mock = False,
timeout: float = DEFAULT_TIMEOUT,
force_reconnect=False,
):
Expand Down
3 changes: 2 additions & 1 deletion src/ophyd_async/tango/base_devices/_base_device.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import TypeVar
from unittest.mock import Mock

from ophyd_async.core import Device, DeviceConnector, DeviceFiller
from ophyd_async.tango.signal import (
Expand Down Expand Up @@ -114,7 +115,7 @@ def create_children_from_annotations(self, device: Device):
)

async def connect(
self, device: Device, mock: bool, timeout: float, force_reconnect: bool
self, device: Device, mock: bool | Mock, timeout: float, force_reconnect: bool
) -> None:
if mock:
# Make 2 entries for each DeviceVector
Expand Down
11 changes: 7 additions & 4 deletions tests/core/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,21 +193,24 @@ async def test_device_with_children_lazily_connects(RE):
)


async def test_no_reconnect_signals_if_not_forced():
@pytest.mark.parametrize("use_Mock", [False, True])
async def test_no_reconnect_signals_if_not_forced(use_Mock):
parent = DummyDeviceGroup("parent")

connect_mock_arg = Mock() if use_Mock else True

async def inner_connect(mock, timeout, force_reconnect):
parent.child1.connected = True

parent.child1.connect = Mock(side_effect=inner_connect)
await parent.connect(mock=True, timeout=0.01)
await parent.connect(mock=connect_mock_arg, timeout=0.01)
assert parent.child1.connected
assert parent.child1.connect.call_count == 1
await parent.connect(mock=True, timeout=0.01)
await parent.connect(mock=connect_mock_arg, timeout=0.01)
assert parent.child1.connected
assert parent.child1.connect.call_count == 1

for count in range(2, 10):
await parent.connect(mock=True, timeout=0.01, force_reconnect=True)
await parent.connect(mock=connect_mock_arg, timeout=0.01, force_reconnect=True)
assert parent.child1.connected
assert parent.child1.connect.call_count == count
7 changes: 4 additions & 3 deletions tests/core/test_mock_signal_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,10 @@ async def test_mock_utils_throw_error_if_backend_isnt_mock_signal_backend():
exc_msgs.append(str(exc.value))

for msg in exc_msgs:
assert msg == (
"Expected to receive a `MockSignalBackend`, instead "
f" received {SoftSignalBackend}. "
assert re.match(
r"Signal <ophyd_async.core._signal.SignalRW object at [0-9a-z]+> "
r"not connected in mock mode",
msg,
)


Expand Down
8 changes: 4 additions & 4 deletions tests/core/test_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
import time
from asyncio import Event
from unittest.mock import ANY
from unittest.mock import ANY, Mock

import pytest
from bluesky.protocols import Reading
Expand Down Expand Up @@ -42,7 +42,7 @@ def num_occurrences(substring: str, string: str) -> int:

async def test_signal_connects_to_previous_backend(caplog):
caplog.set_level(logging.DEBUG)
int_mock_backend = MockSignalBackend(SoftSignalBackend(int))
int_mock_backend = MockSignalBackend(SoftSignalBackend(int), Mock())
original_connect = int_mock_backend.connect
times_backend_connect_called = 0

Expand All @@ -61,7 +61,7 @@ async def new_connect(timeout=1):

async def test_signal_connects_with_force_reconnect(caplog):
caplog.set_level(logging.DEBUG)
signal = Signal(MockSignalBackend(SoftSignalBackend(int)))
signal = Signal(MockSignalBackend(SoftSignalBackend(int), Mock()))
await signal.connect()
assert num_occurrences(f"Connecting to {signal.source}", caplog.text) == 1
await signal.connect(force_reconnect=True)
Expand All @@ -80,7 +80,7 @@ async def connect(self, timeout=DEFAULT_TIMEOUT):
self.succeed_on_connect = True
raise RuntimeError("connect fail")

signal = SignalRW(MockSignalBackendFailingFirst(SoftSignalBackend(int)))
signal = SignalRW(MockSignalBackendFailingFirst(SoftSignalBackend(int), Mock()))

with pytest.raises(RuntimeError, match="connect fail"):
await signal.connect(mock=False)
Expand Down
Loading
Loading