From b8346fc2531db83eb44b9bd76fc362fc22a91448 Mon Sep 17 00:00:00 2001 From: Brad Keryan Date: Tue, 3 Oct 2023 14:39:21 -0500 Subject: [PATCH] service: Add generic create_session(s) API (#426) * service: Add a helper function to get the gRPC device channel * tests: Add tests for get_insecure_grpc_device_channel * service: Add closing_session helper function * tests: Add unit tests for closing_session * service: Add generic create_session(s) methods * tests: Add tests for create_session(s) * service: Add TypedSessionInformation * service: Remove an unnecessary ExitStack * tests: Add comment about api_key * tests: Rename a couple test cases * service: Extract _get_matching_session_infos and refactor _with_session * tests: Update test cases * tests: Add a type-checking test case * service: Make instrument_type_id required for generic create_session(s) * tests: Update grpcdevice tests * service: Remove reference to session tuple * service: Update docstrings --- .../_drivers/__init__.py | 28 ++ .../_drivers/_grpcdevice.py | 49 ++++ .../session_management.py | 206 ++++++++++++++- tests/unit/_drivers/__init__.py | 1 + tests/unit/_drivers/test_grpcdevice.py | 64 +++++ tests/unit/test_drivers.py | 26 ++ tests/unit/test_reservation.py | 250 ++++++++++++++++++ tests/unit/test_session_management.py | 9 + tests/utilities/fake_driver.py | 148 +++++++++++ 9 files changed, 775 insertions(+), 6 deletions(-) create mode 100644 ni_measurementlink_service/_drivers/__init__.py create mode 100644 ni_measurementlink_service/_drivers/_grpcdevice.py create mode 100644 tests/unit/_drivers/__init__.py create mode 100644 tests/unit/_drivers/test_grpcdevice.py create mode 100644 tests/unit/test_drivers.py create mode 100644 tests/unit/test_reservation.py create mode 100644 tests/utilities/fake_driver.py diff --git a/ni_measurementlink_service/_drivers/__init__.py b/ni_measurementlink_service/_drivers/__init__.py new file mode 100644 index 000000000..34a85bdf6 --- /dev/null +++ b/ni_measurementlink_service/_drivers/__init__.py @@ -0,0 +1,28 @@ +"""Shared code for interfacing with driver APIs.""" +from __future__ import annotations + +import contextlib +from typing import ContextManager, TypeVar + +TSession = TypeVar("TSession") + + +def closing_session(session: TSession) -> ContextManager[TSession]: + """Create a context manager that closes the session. + + Args: + session: A driver session. + + Returns: + A context manager that yields the session and closes it. + """ + if isinstance(session, contextlib.AbstractContextManager): + # Assume the session yields itself. + return session + elif hasattr(session, "close"): + return contextlib.closing(session) + else: + raise TypeError( + f"Invalid session type '{type(session)}'. A session must be a context manager and/or " + "have a close() method." + ) diff --git a/ni_measurementlink_service/_drivers/_grpcdevice.py b/ni_measurementlink_service/_drivers/_grpcdevice.py new file mode 100644 index 000000000..10ca95882 --- /dev/null +++ b/ni_measurementlink_service/_drivers/_grpcdevice.py @@ -0,0 +1,49 @@ +"""Shared functions for interacting with NI gRPC Device Server.""" +from __future__ import annotations + +from typing import Optional + +import grpc + +from ni_measurementlink_service._channelpool import GrpcChannelPool +from ni_measurementlink_service._configuration import ( + GRPC_DEVICE_ADDRESS, + USE_GRPC_DEVICE_SERVER, +) +from ni_measurementlink_service._internal.discovery_client import DiscoveryClient + +SERVICE_CLASS = "ni.measurementlink.v1.grpcdeviceserver" +"""The service class for NI gRPC Device Server.""" + + +def get_insecure_grpc_device_channel( + discovery_client: DiscoveryClient, + grpc_channel_pool: GrpcChannelPool, + provided_interface: str, +) -> Optional[grpc.Channel]: + """Get an unencrypted gRPC channel targeting NI gRPC Device Server. + + Args: + discovery_client: The discovery client. + + grpc_channel_pool: The gRPC channel pool. + + provided_interface: The driver API's NI gRPC Device Server interface + name. + + Returns: + A gRPC channel targeting the NI gRPC Device Server, or ``None`` if the + configuration file specifies that ``USE_GRPC_DEVICE_SERVER`` is false. + """ + if not USE_GRPC_DEVICE_SERVER: + return None + + address = GRPC_DEVICE_ADDRESS + if not address: + service_location = discovery_client.resolve_service( + provided_interface=provided_interface, + service_class=SERVICE_CLASS, + ) + address = service_location.insecure_address + + return grpc_channel_pool.get_channel(address) diff --git a/ni_measurementlink_service/session_management.py b/ni_measurementlink_service/session_management.py index 29c3af87c..90f5260e9 100644 --- a/ni_measurementlink_service/session_management.py +++ b/ni_measurementlink_service/session_management.py @@ -2,29 +2,44 @@ from __future__ import annotations import abc +import contextlib import logging import sys import threading import warnings +from contextlib import ExitStack from functools import cached_property from types import TracebackType from typing import ( TYPE_CHECKING, Any, + Callable, + ContextManager, + Dict, + Generator, + Generic, Iterable, List, Literal, NamedTuple, Optional, + Protocol, Sequence, Type, + TypeVar, Union, + cast, ) import grpc from deprecation import DeprecatedWarning from ni_measurementlink_service._channelpool import GrpcChannelPool +from ni_measurementlink_service._drivers import closing_session +from ni_measurementlink_service._featuretoggles import ( + SESSION_MANAGEMENT_2024Q1, + requires_feature, +) from ni_measurementlink_service._internal.discovery_client import DiscoveryClient from ni_measurementlink_service._internal.stubs import session_pb2 from ni_measurementlink_service._internal.stubs.ni.measurementlink import ( @@ -65,6 +80,10 @@ INSTRUMENT_TYPE_NI_SWITCH_EXECUTIVE_VIRTUAL_DEVICE = "niSwitchExecutiveVirtualDevice" +TSession = TypeVar("TSession") +TSession_co = TypeVar("TSession_co", covariant=True) + + class PinMapContext(NamedTuple): """Container for the pin map and sites.""" @@ -98,9 +117,7 @@ class SessionInformation(NamedTuple): """Container for the session information.""" session_name: str - """Session identifier used to identify the session in the session management service, as well - as in driver services such as grpc-device. - """ + """Session name used by the session management service and NI gRPC Device Server.""" resource_name: str """Resource name used to open this session in the driver.""" @@ -113,7 +130,7 @@ class SessionInformation(NamedTuple): """ instrument_type_id: str - """Instrument type ID to identify which type of instrument the session represents. + """Indicates the instrument type for this session. Pin maps have built in instrument definitions using the instrument type id constants such as `INSTRUMENT_TYPE_NI_DCPOWER`. For custom instruments, the @@ -121,7 +138,7 @@ class SessionInformation(NamedTuple): """ session_exists: bool - """Indicates whether the session has been registered with the session management service. + """Indicates whether the session is registered with the session management service. When calling measurements from TestStand, the test sequence's ``ProcessSetup`` callback creates instrument sessions and registers them with the session management service so that @@ -136,13 +153,73 @@ class SessionInformation(NamedTuple): """ channel_mappings: Iterable[ChannelMapping] - """List of site and pin/relay mappings that correspond to each channel in the channel_list. + """List of mappings from channels to pins and sites. Each item contains a mapping for a channel in this instrument resource, in the order of the channel_list. This field is empty for any SessionInformation returned from Client.reserve_all_registered_sessions. """ + session: object = None + """The driver session object. + + This field is None until the appropriate create_session(s) method is called. + """ + + def _as_typed(self, session_type: Type[TSession]) -> TypedSessionInformation[TSession]: + assert isinstance(self.session, session_type) + return cast(TypedSessionInformation[TSession], self) + + def _with_session(self, session: object) -> SessionInformation: + return self._replace(session=session) + + def _with_typed_session(self, session: TSession) -> TypedSessionInformation[TSession]: + return self._with_session(session)._as_typed(type(session)) + + +# Python versions <3.11 do not support generic named tuples, so we use a generic +# protocol to return typed session information. +class TypedSessionInformation(Protocol, Generic[TSession_co]): + """Generic version of :any:`SessionInformation` that preserves the session type. + + For more details, see the corresponding documentation for :any:`SessionInformation`. + """ + + @property + def session_name(self) -> str: + """Session name used by the session management service and NI gRPC Device Server.""" + ... + + @property + def resource_name(self) -> str: + """Resource name used to open this session in the driver.""" + ... + + @property + def channel_list(self) -> str: + """Channel list used for driver initialization and measurement methods.""" + ... + + @property + def instrument_type_id(self) -> str: + """Indicates the instrument type for this session.""" + ... + + @property + def session_exists(self) -> bool: + """Indicates whether the session is registered with the session management service.""" + ... + + @property + def channel_mappings(self) -> Iterable[ChannelMapping]: + """List of mappings from channels to pins and sites.""" + ... + + @property + def session(self) -> TSession_co: + """The driver session object.""" + ... + def _convert_channel_mapping_from_grpc( channel_mapping: session_management_service_pb2.ChannelMapping, @@ -205,6 +282,7 @@ def __init__( """Initialize reservation object.""" self._session_manager = session_manager self._session_info = session_info + self._session_cache: Dict[str, object] = {} def __enter__(self: Self) -> Self: """Context management protocol. Returns self.""" @@ -224,6 +302,122 @@ def unreserve(self) -> None: """Unreserve sessions.""" self._session_manager._unreserve_sessions(self._session_info) + @contextlib.contextmanager + def _cache_session(self, session_name: str, session: TSession) -> Generator[None, None, None]: + if session_name in self._session_cache: + raise RuntimeError(f"Session '{session_name}' already exists.") + self._session_cache[session_name] = session + try: + yield + finally: + del self._session_cache[session_name] + + def _get_matching_session_infos(self, instrument_type_id: str) -> List[SessionInformation]: + return [ + _convert_session_info_from_grpc(info)._with_session( + self._session_cache.get(info.session.name) + ) + for info in self._session_info + if instrument_type_id and instrument_type_id == info.instrument_type_id + ] + + @contextlib.contextmanager + def _create_session_core( + self, + session_constructor: Callable[[SessionInformation], TSession], + instrument_type_id: str, + ) -> Generator[TypedSessionInformation[TSession], None, None]: + if not instrument_type_id: + raise ValueError("This method requires an instrument type ID.") + session_infos = self._get_matching_session_infos(instrument_type_id) + if len(session_infos) == 0: + raise ValueError( + f"No sessions matched instrument type ID '{instrument_type_id}'. " + "Expected single session, got 0 sessions." + ) + elif len(session_infos) > 1: + raise ValueError( + f"Too many sessions matched instrument type ID '{instrument_type_id}'. " + f"Expected single session, got {len(session_infos)} sessions." + ) + + session_info = session_infos[0] + with closing_session(session_constructor(session_info)) as session: + with self._cache_session(session_info.session_name, session): + yield session_info._with_typed_session(session) + + @contextlib.contextmanager + def _create_sessions_core( + self, + session_constructor: Callable[[SessionInformation], TSession], + instrument_type_id: str, + ) -> Generator[Sequence[TypedSessionInformation[TSession]], None, None]: + if not instrument_type_id: + raise ValueError("This method requires an instrument type ID.") + session_infos = self._get_matching_session_infos(instrument_type_id) + if len(session_infos) == 0: + raise ValueError( + f"No sessions matched instrument type ID '{instrument_type_id}'. " + "Expected single or multiple sessions, got 0 sessions." + ) + + with ExitStack() as stack: + typed_session_infos: List[TypedSessionInformation[TSession]] = [] + for session_info in session_infos: + session = stack.enter_context(closing_session(session_constructor(session_info))) + stack.enter_context(self._cache_session(session_info.session_name, session)) + typed_session_infos.append(session_info._with_typed_session(session)) + yield typed_session_infos + + @requires_feature(SESSION_MANAGEMENT_2024Q1) + def create_session( + self, + session_constructor: Callable[[SessionInformation], TSession], + instrument_type_id: str, + ) -> ContextManager[TypedSessionInformation[TSession]]: + """Create a single instrument session. + + This is a generic method that supports any instrument driver. + + Args: + session_constructor: A function that constructs sessions based on session + information. + instrument_type_id: Instrument type ID for the session. + For NI instruments, use instrument type id constants, such as + :py:const:`INSTRUMENT_TYPE_NI_DCPOWER` or :py:const:`INSTRUMENT_TYPE_NI_DMM`. + For custom instruments, use the instrument type id defined in the pin map file. + + Returns: + A context manager that yields a session information object. The + created session is available via the ``session`` field. + """ + return self._create_session_core(session_constructor, instrument_type_id) + + @requires_feature(SESSION_MANAGEMENT_2024Q1) + def create_sessions( + self, + session_constructor: Callable[[SessionInformation], TSession], + instrument_type_id: str, + ) -> ContextManager[Sequence[TypedSessionInformation[TSession]]]: + """Create multiple instrument sessions. + + This is a generic method that supports any instrument driver. + + Args: + session_constructor: A function that constructs sessions based on session + information. + instrument_type_id: Instrument type ID for the session. + For NI instruments, use instrument type id constants, such as + :py:const:`INSTRUMENT_TYPE_NI_DCPOWER` or :py:const:`INSTRUMENT_TYPE_NI_DMM`. + For custom instruments, use the instrument type id defined in the pin map file. + + Returns: + A context manager that yields a sequence of session information + objects. The created sessions are available via the ``session`` + field. + """ + return self._create_sessions_core(session_constructor, instrument_type_id) + class SingleSessionReservation(BaseReservation): """Manages reservation for a single session.""" diff --git a/tests/unit/_drivers/__init__.py b/tests/unit/_drivers/__init__.py new file mode 100644 index 000000000..80d5b67e8 --- /dev/null +++ b/tests/unit/_drivers/__init__.py @@ -0,0 +1 @@ +"""Unit tests for ni_measurementlink_service._drivers.""" diff --git a/tests/unit/_drivers/test_grpcdevice.py b/tests/unit/_drivers/test_grpcdevice.py new file mode 100644 index 000000000..0578c77b3 --- /dev/null +++ b/tests/unit/_drivers/test_grpcdevice.py @@ -0,0 +1,64 @@ +from unittest.mock import Mock + +from pytest_mock import MockerFixture + +from ni_measurementlink_service._drivers._grpcdevice import ( + SERVICE_CLASS, + get_insecure_grpc_device_channel, +) +from ni_measurementlink_service._internal.discovery_client import ServiceLocation +from tests.utilities import fake_driver + + +def test___default_configuration___get_insecure_grpc_device_channel___service_discovered_and_channel_returned( + discovery_client: Mock, + grpc_channel: Mock, + grpc_channel_pool: Mock, +) -> None: + discovery_client.resolve_service.return_value = ServiceLocation("localhost", "1234", "") + grpc_channel_pool.get_channel.return_value = grpc_channel + + returned_channel = get_insecure_grpc_device_channel( + discovery_client, grpc_channel_pool, fake_driver.GRPC_SERVICE_INTERFACE_NAME + ) + + discovery_client.resolve_service.assert_called_with( + provided_interface=fake_driver.GRPC_SERVICE_INTERFACE_NAME, + service_class=SERVICE_CLASS, + ) + grpc_channel_pool.get_channel.assert_called_with("localhost:1234") + assert returned_channel is grpc_channel + + +def test___use_grpc_device_server_false___get_insecure_grpc_device_channel___none_returned( + discovery_client: Mock, + grpc_channel_pool: Mock, + mocker: MockerFixture, +) -> None: + mocker.patch("ni_measurementlink_service._drivers._grpcdevice.USE_GRPC_DEVICE_SERVER", False) + + returned_channel = get_insecure_grpc_device_channel( + discovery_client, grpc_channel_pool, fake_driver.GRPC_SERVICE_INTERFACE_NAME + ) + + assert returned_channel is None + + +def test___grpc_device_address_set___get_insecure_grpc_device_channel___address_used_and_channel_returned( + discovery_client: Mock, + grpc_channel: Mock, + grpc_channel_pool: Mock, + mocker: MockerFixture, +) -> None: + mocker.patch( + "ni_measurementlink_service._drivers._grpcdevice.GRPC_DEVICE_ADDRESS", "[::1]:31763" + ) + grpc_channel_pool.get_channel.return_value = grpc_channel + + returned_channel = get_insecure_grpc_device_channel( + discovery_client, grpc_channel_pool, fake_driver.GRPC_SERVICE_INTERFACE_NAME + ) + + discovery_client.resolve_service.assert_not_called() + grpc_channel_pool.get_channel.assert_called_with("[::1]:31763") + assert returned_channel is grpc_channel diff --git a/tests/unit/test_drivers.py b/tests/unit/test_drivers.py new file mode 100644 index 000000000..54abd26b5 --- /dev/null +++ b/tests/unit/test_drivers.py @@ -0,0 +1,26 @@ +from ni_measurementlink_service._drivers import closing_session +from tests.utilities import fake_driver + + +def test___closable_session___with_closing_session___session_closed() -> None: + with closing_session(fake_driver.ClosableSession("Dev1")) as session: + assert isinstance(session, fake_driver.ClosableSession) + assert not session.is_closed + + assert session.is_closed + + +def test___context_manager_session___with_closing_session___session_closed() -> None: + with closing_session(fake_driver.ContextManagerSession("Dev1")) as session: + assert isinstance(session, fake_driver.ContextManagerSession) + assert not session.is_closed + + assert session.is_closed + + +def test___session___with_closing_session___session_closed() -> None: + with closing_session(fake_driver.Session("Dev1")) as session: + assert isinstance(session, fake_driver.Session) + assert not session.is_closed + + assert session.is_closed diff --git a/tests/unit/test_reservation.py b/tests/unit/test_reservation.py new file mode 100644 index 000000000..c299c5eb3 --- /dev/null +++ b/tests/unit/test_reservation.py @@ -0,0 +1,250 @@ +from typing import List +from unittest.mock import Mock + +import pytest + +from ni_measurementlink_service._internal.stubs import session_pb2 +from ni_measurementlink_service._internal.stubs.ni.measurementlink.sessionmanagement.v1 import ( + session_management_service_pb2, +) +from ni_measurementlink_service.session_management import ( + MultiSessionReservation, + SessionInformation, +) +from tests.utilities import fake_driver + + +def test___single_session_info___create_session___session_info_yielded( + session_management_client: Mock, +) -> None: + reservation = MultiSessionReservation( + session_management_client, _create_grpc_session_infos(1, "nifake") + ) + + with reservation.create_session(_construct_session, "nifake") as session_info: + assert session_info.session_name == "MySession0" + assert session_info.resource_name == "Dev0" + assert session_info.instrument_type_id == "nifake" + + +def test___single_session_info___create_session___session_created( + session_management_client: Mock, +) -> None: + reservation = MultiSessionReservation( + session_management_client, _create_grpc_session_infos(1, "nifake") + ) + + with reservation.create_session(_construct_session, "nifake") as session_info: + assert isinstance(session_info.session, fake_driver.Session) + assert session_info.session.resource_name == "Dev0" + + +def test___single_session_info___create_session___session_lifetime_tracked( + session_management_client: Mock, +) -> None: + reservation = MultiSessionReservation( + session_management_client, _create_grpc_session_infos(1, "nifake") + ) + + with reservation.create_session(_construct_session, "nifake") as session_info: + assert reservation._session_cache["MySession0"] is session_info.session + assert not session_info.session.is_closed + + assert len(reservation._session_cache) == 0 + assert session_info.session.is_closed + + +def test___empty_instrument_type_id___create_session___value_error_raised( + session_management_client: Mock, +) -> None: + reservation = MultiSessionReservation( + session_management_client, _create_grpc_session_infos(1, "nifake") + ) + + with pytest.raises(ValueError) as exc_info: + with reservation.create_session(_construct_session, ""): + pass + + assert "This method requires an instrument type ID." in exc_info.value.args[0] + + +def test___no_session_infos___create_session___value_error_raised( + session_management_client: Mock, +) -> None: + reservation = MultiSessionReservation( + session_management_client, _create_grpc_session_infos(0, "nifake") + ) + + with pytest.raises(ValueError) as exc_info: + with reservation.create_session(_construct_session, "nifake"): + pass + + assert "No sessions matched instrument type ID 'nifake'." in exc_info.value.args[0] + + +def test___multi_session_infos___create_session___value_error_raised( + session_management_client: Mock, +) -> None: + reservation = MultiSessionReservation( + session_management_client, _create_grpc_session_infos(2, "nifake") + ) + + with pytest.raises(ValueError) as exc_info: + with reservation.create_session(_construct_session, "nifake"): + pass + + assert "Too many sessions matched instrument type ID 'nifake'." in exc_info.value.args[0] + + +def test___session_already_exists___create_session___runtime_error_raised( + session_management_client: Mock, +) -> None: + reservation = MultiSessionReservation( + session_management_client, _create_grpc_session_infos(1, "nifake") + ) + + with reservation.create_session(_construct_session, "nifake"): + with pytest.raises(RuntimeError) as exc_info: + with reservation.create_session(_construct_session, "nifake"): + pass + + assert "Session 'MySession0' already exists." in exc_info.value.args[0] + + +def test___heterogenous_session_infos___create_session___grouped_by_instrument_type( + session_management_client: Mock, +) -> None: + grpc_session_infos = _create_grpc_session_infos(2, "nifoo") + grpc_session_infos[1].instrument_type_id = "nibar" + reservation = MultiSessionReservation(session_management_client, grpc_session_infos) + + with reservation.create_session( + _construct_session, "nifoo" + ) as nifoo_info, reservation.create_session(_construct_session, "nibar") as nibar_info: + assert nifoo_info.session_name == "MySession0" + assert nifoo_info.instrument_type_id == "nifoo" + assert nibar_info.session_name == "MySession1" + assert nibar_info.instrument_type_id == "nibar" + + +def test___multi_session_infos___create_sessions___session_infos_yielded( + session_management_client: Mock, +) -> None: + reservation = MultiSessionReservation( + session_management_client, _create_grpc_session_infos(3, "nifake") + ) + + with reservation.create_sessions(_construct_session, "nifake") as session_infos: + assert [info.session_name for info in session_infos] == [ + "MySession0", + "MySession1", + "MySession2", + ] + assert [info.resource_name for info in session_infos] == ["Dev0", "Dev1", "Dev2"] + assert [info.instrument_type_id for info in session_infos] == ["nifake", "nifake", "nifake"] + + +def test___multi_session_infos___create_sessions___sessions_created( + session_management_client: Mock, +) -> None: + reservation = MultiSessionReservation( + session_management_client, _create_grpc_session_infos(3, "nifake") + ) + + with reservation.create_sessions(_construct_session, "nifake") as session_infos: + assert all([isinstance(info.session, fake_driver.Session) for info in session_infos]) + assert [info.session.resource_name for info in session_infos] == ["Dev0", "Dev1", "Dev2"] + + +def test___multi_session_infos___create_sessions___session_lifetime_tracked( + session_management_client: Mock, +) -> None: + reservation = MultiSessionReservation( + session_management_client, _create_grpc_session_infos(3, "nifake") + ) + + with reservation.create_sessions(_construct_session, "nifake") as session_infos: + assert reservation._session_cache["MySession0"] is session_infos[0].session + assert reservation._session_cache["MySession1"] is session_infos[1].session + assert reservation._session_cache["MySession2"] is session_infos[2].session + assert all([not info.session.is_closed for info in session_infos]) + + assert len(reservation._session_cache) == 0 + assert all([info.session.is_closed for info in session_infos]) + + +def test___empty_instrument_type_id___create_sessions___value_error_raised( + session_management_client: Mock, +) -> None: + reservation = MultiSessionReservation( + session_management_client, _create_grpc_session_infos(3, "nifake") + ) + + with pytest.raises(ValueError) as exc_info: + with reservation.create_sessions(_construct_session, ""): + pass + + assert "This method requires an instrument type ID." in exc_info.value.args[0] + + +def test___no_session_infos___create_sessions___value_error_raised( + session_management_client: Mock, +) -> None: + reservation = MultiSessionReservation( + session_management_client, _create_grpc_session_infos(0, "nifake") + ) + + with pytest.raises(ValueError) as exc_info: + with reservation.create_sessions(_construct_session, "nifake"): + pass + + assert "No sessions matched instrument type ID 'nifake'." in exc_info.value.args[0] + + +def test___session_already_exists___create_sessions___runtime_error_raised( + session_management_client: Mock, +) -> None: + reservation = MultiSessionReservation( + session_management_client, _create_grpc_session_infos(3, "nifake") + ) + + with reservation.create_sessions(_construct_session, "nifake"): + with pytest.raises(RuntimeError) as exc_info: + with reservation.create_sessions(_construct_session, "nifake"): + pass + + assert "Session 'MySession0' already exists." in exc_info.value.args[0] + + +def test___heterogenous_session_infos___create_sessions___grouped_by_instrument_type( + session_management_client: Mock, +) -> None: + grpc_session_infos = _create_grpc_session_infos(3, "nifoo") + grpc_session_infos[1].instrument_type_id = "nibar" + reservation = MultiSessionReservation(session_management_client, grpc_session_infos) + + with reservation.create_sessions( + _construct_session, "nifoo" + ) as nifoo_infos, reservation.create_sessions(_construct_session, "nibar") as nibar_infos: + assert [info.session_name for info in nifoo_infos] == ["MySession0", "MySession2"] + assert [info.instrument_type_id for info in nifoo_infos] == ["nifoo", "nifoo"] + assert [info.session_name for info in nibar_infos] == ["MySession1"] + assert [info.instrument_type_id for info in nibar_infos] == ["nibar"] + + +def _construct_session(session_info: SessionInformation) -> fake_driver.Session: + return fake_driver.Session(session_info.resource_name) + + +def _create_grpc_session_infos( + session_count: int, + instrument_type_id: str, +) -> List[session_management_service_pb2.SessionInformation]: + return [ + session_management_service_pb2.SessionInformation( + session=session_pb2.Session(name=f"MySession{i}"), + resource_name=f"Dev{i}", + instrument_type_id=instrument_type_id, + ) + for i in range(session_count) + ] diff --git a/tests/unit/test_session_management.py b/tests/unit/test_session_management.py index 7c0fa3efa..eeab910d2 100644 --- a/tests/unit/test_session_management.py +++ b/tests/unit/test_session_management.py @@ -20,6 +20,7 @@ SessionInformation, SessionManagementClient, SingleSessionReservation, + TypedSessionInformation, ) @@ -548,6 +549,14 @@ def test___use_reservation_type___reports_deprecated_warning_and_aliases_to_mult assert isinstance(reservation, MultiSessionReservation) +def test___session_information___type_check___implements_typed_session_information_object() -> None: + # This is a type-checking test. It does nothing at run time. + def f(typed_session_info: TypedSessionInformation[object]) -> None: + pass + + f(SessionInformation("MySession", "Dev1", "0", "niDCPower", False, [])) + + def _create_session_infos(session_count: int) -> List[SessionInformation]: return [ SessionInformation(f"MySession{i}", "", "", "", False, []) for i in range(session_count) diff --git a/tests/utilities/fake_driver.py b/tests/utilities/fake_driver.py new file mode 100644 index 000000000..5b9c78fd2 --- /dev/null +++ b/tests/utilities/fake_driver.py @@ -0,0 +1,148 @@ +"""Fake driver API for testing.""" +from __future__ import annotations + +import sys +from enum import Enum, IntEnum +from types import TracebackType +from typing import TYPE_CHECKING, Any, ContextManager, Dict, Optional, Type + +if TYPE_CHECKING: + import grpc + + if sys.version_info >= (3, 11): + from typing import Self + else: + from typing_extensions import Self + + +GRPC_SERVICE_INTERFACE_NAME = "nifake_grpc.NiFake" + +# The GrpcSessionOptions classes in nimi-python and nidaqmx-python have an api_key field. +_API_KEY = "00000000-0000-0000-0000-000000000000" + + +class SessionInitializationBehavior(IntEnum): + """Specifies whether to initialize a new session or attach to an existing session.""" + + AUTO = 0 + INITIALIZE_SERVER_SESSION = 1 + ATTACH_TO_SERVER_SESSION = 2 + + +class GrpcSessionOptions: + """gRPC session options.""" + + def __init__( + self, + grpc_channel: grpc.Channel, + session_name: str, + *, + api_key: str = _API_KEY, + initialization_behavior: SessionInitializationBehavior = SessionInitializationBehavior.AUTO, + ) -> None: + """Initialize the gRPC session options.""" + self.grpc_channel = grpc_channel + self.session_name = session_name + self.api_key = api_key + self.initialization_behavior = initialization_behavior + + +class MeasurementType(Enum): + """Measurement type.""" + + VOLTAGE = 1 + CURRENT = 2 + + +class _Acquisition: + def __init__(self, session: _SessionBase) -> None: + self._session = session + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + self._session.abort() + + +class _SessionBase: + """Base class for driver sessions.""" + + def __init__(self, resource_name: str, options: Dict[str, Any] = {}) -> None: + """Initialize the session.""" + self.resource_name = resource_name + self.options = options + self.is_closed = False + + def configure(self, measurement_type: MeasurementType, range: float) -> None: + """Configure the session.""" + pass + + def initiate(self) -> ContextManager[object]: + """Initiate an acquisition.""" + return _Acquisition(self) + + def abort(self) -> None: + """Abort (stop) the acquisition.""" + pass + + def read(self) -> float: + """Read a sample.""" + return 0.0 + + +class ClosableSession(_SessionBase): + """A driver session that supports close().""" + + def close(self) -> None: + """Close the session.""" + self.is_closed = True + + +class ContextManagerSession(_SessionBase): + """A driver session that supports the context manager protocol.""" + + def __enter__(self) -> Self: + """Enter the session's runtime context.""" + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + """Exit the session's runtime context.""" + self.is_closed = True + + +class Session(_SessionBase): + """A driver session that supports both close() and the context manager protocol.""" + + def __init__(self, resource_name: str, options: Dict[str, Any] = {}) -> None: + """Initialize the session.""" + self.resource_name = resource_name + self.options = options + self.is_closed = False + + def close(self) -> None: + """Close the session.""" + self.is_closed = True + + def __enter__(self) -> Self: + """Enter the session's runtime context.""" + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + """Exit the session's runtime context.""" + self.close()