Skip to content

Commit

Permalink
service: Add generic create_session(s) API (#426)
Browse files Browse the repository at this point in the history
* service: Add a helper function to get the gRPC device channel

* tests: Add tests for get_insecure_grpc_device_channel

* service: Add closing_session helper function

* tests: Add unit tests for closing_session

* service: Add generic create_session(s) methods

* tests: Add tests for create_session(s)

* service: Add TypedSessionInformation

* service: Remove an unnecessary ExitStack

* tests: Add comment about api_key

* tests: Rename a couple test cases

* service: Extract _get_matching_session_infos and refactor _with_session

* tests: Update test cases

* tests: Add a type-checking test case

* service: Make instrument_type_id required for generic create_session(s)

* tests: Update grpcdevice tests

* service: Remove reference to session tuple

* service: Update docstrings
  • Loading branch information
bkeryan authored Oct 3, 2023
1 parent 032ad4d commit b8346fc
Show file tree
Hide file tree
Showing 9 changed files with 775 additions and 6 deletions.
28 changes: 28 additions & 0 deletions ni_measurementlink_service/_drivers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Shared code for interfacing with driver APIs."""
from __future__ import annotations

import contextlib
from typing import ContextManager, TypeVar

TSession = TypeVar("TSession")


def closing_session(session: TSession) -> ContextManager[TSession]:
"""Create a context manager that closes the session.
Args:
session: A driver session.
Returns:
A context manager that yields the session and closes it.
"""
if isinstance(session, contextlib.AbstractContextManager):
# Assume the session yields itself.
return session
elif hasattr(session, "close"):
return contextlib.closing(session)
else:
raise TypeError(
f"Invalid session type '{type(session)}'. A session must be a context manager and/or "
"have a close() method."
)
49 changes: 49 additions & 0 deletions ni_measurementlink_service/_drivers/_grpcdevice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Shared functions for interacting with NI gRPC Device Server."""
from __future__ import annotations

from typing import Optional

import grpc

from ni_measurementlink_service._channelpool import GrpcChannelPool
from ni_measurementlink_service._configuration import (
GRPC_DEVICE_ADDRESS,
USE_GRPC_DEVICE_SERVER,
)
from ni_measurementlink_service._internal.discovery_client import DiscoveryClient

SERVICE_CLASS = "ni.measurementlink.v1.grpcdeviceserver"
"""The service class for NI gRPC Device Server."""


def get_insecure_grpc_device_channel(
discovery_client: DiscoveryClient,
grpc_channel_pool: GrpcChannelPool,
provided_interface: str,
) -> Optional[grpc.Channel]:
"""Get an unencrypted gRPC channel targeting NI gRPC Device Server.
Args:
discovery_client: The discovery client.
grpc_channel_pool: The gRPC channel pool.
provided_interface: The driver API's NI gRPC Device Server interface
name.
Returns:
A gRPC channel targeting the NI gRPC Device Server, or ``None`` if the
configuration file specifies that ``USE_GRPC_DEVICE_SERVER`` is false.
"""
if not USE_GRPC_DEVICE_SERVER:
return None

address = GRPC_DEVICE_ADDRESS
if not address:
service_location = discovery_client.resolve_service(
provided_interface=provided_interface,
service_class=SERVICE_CLASS,
)
address = service_location.insecure_address

return grpc_channel_pool.get_channel(address)
206 changes: 200 additions & 6 deletions ni_measurementlink_service/session_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,44 @@
from __future__ import annotations

import abc
import contextlib
import logging
import sys
import threading
import warnings
from contextlib import ExitStack
from functools import cached_property
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Callable,
ContextManager,
Dict,
Generator,
Generic,
Iterable,
List,
Literal,
NamedTuple,
Optional,
Protocol,
Sequence,
Type,
TypeVar,
Union,
cast,
)

import grpc
from deprecation import DeprecatedWarning

from ni_measurementlink_service._channelpool import GrpcChannelPool
from ni_measurementlink_service._drivers import closing_session
from ni_measurementlink_service._featuretoggles import (
SESSION_MANAGEMENT_2024Q1,
requires_feature,
)
from ni_measurementlink_service._internal.discovery_client import DiscoveryClient
from ni_measurementlink_service._internal.stubs import session_pb2
from ni_measurementlink_service._internal.stubs.ni.measurementlink import (
Expand Down Expand Up @@ -65,6 +80,10 @@
INSTRUMENT_TYPE_NI_SWITCH_EXECUTIVE_VIRTUAL_DEVICE = "niSwitchExecutiveVirtualDevice"


TSession = TypeVar("TSession")
TSession_co = TypeVar("TSession_co", covariant=True)


class PinMapContext(NamedTuple):
"""Container for the pin map and sites."""

Expand Down Expand Up @@ -98,9 +117,7 @@ class SessionInformation(NamedTuple):
"""Container for the session information."""

session_name: str
"""Session identifier used to identify the session in the session management service, as well
as in driver services such as grpc-device.
"""
"""Session name used by the session management service and NI gRPC Device Server."""

resource_name: str
"""Resource name used to open this session in the driver."""
Expand All @@ -113,15 +130,15 @@ class SessionInformation(NamedTuple):
"""

instrument_type_id: str
"""Instrument type ID to identify which type of instrument the session represents.
"""Indicates the instrument type for this session.
Pin maps have built in instrument definitions using the instrument
type id constants such as `INSTRUMENT_TYPE_NI_DCPOWER`. For custom instruments, the
user defined instrument type id is defined in the pin map file.
"""

session_exists: bool
"""Indicates whether the session has been registered with the session management service.
"""Indicates whether the session is registered with the session management service.
When calling measurements from TestStand, the test sequence's ``ProcessSetup`` callback
creates instrument sessions and registers them with the session management service so that
Expand All @@ -136,13 +153,73 @@ class SessionInformation(NamedTuple):
"""

channel_mappings: Iterable[ChannelMapping]
"""List of site and pin/relay mappings that correspond to each channel in the channel_list.
"""List of mappings from channels to pins and sites.
Each item contains a mapping for a channel in this instrument resource, in the order of the
channel_list. This field is empty for any SessionInformation returned from
Client.reserve_all_registered_sessions.
"""

session: object = None
"""The driver session object.
This field is None until the appropriate create_session(s) method is called.
"""

def _as_typed(self, session_type: Type[TSession]) -> TypedSessionInformation[TSession]:
assert isinstance(self.session, session_type)
return cast(TypedSessionInformation[TSession], self)

def _with_session(self, session: object) -> SessionInformation:
return self._replace(session=session)

def _with_typed_session(self, session: TSession) -> TypedSessionInformation[TSession]:
return self._with_session(session)._as_typed(type(session))


# Python versions <3.11 do not support generic named tuples, so we use a generic
# protocol to return typed session information.
class TypedSessionInformation(Protocol, Generic[TSession_co]):
"""Generic version of :any:`SessionInformation` that preserves the session type.
For more details, see the corresponding documentation for :any:`SessionInformation`.
"""

@property
def session_name(self) -> str:
"""Session name used by the session management service and NI gRPC Device Server."""
...

@property
def resource_name(self) -> str:
"""Resource name used to open this session in the driver."""
...

@property
def channel_list(self) -> str:
"""Channel list used for driver initialization and measurement methods."""
...

@property
def instrument_type_id(self) -> str:
"""Indicates the instrument type for this session."""
...

@property
def session_exists(self) -> bool:
"""Indicates whether the session is registered with the session management service."""
...

@property
def channel_mappings(self) -> Iterable[ChannelMapping]:
"""List of mappings from channels to pins and sites."""
...

@property
def session(self) -> TSession_co:
"""The driver session object."""
...


def _convert_channel_mapping_from_grpc(
channel_mapping: session_management_service_pb2.ChannelMapping,
Expand Down Expand Up @@ -205,6 +282,7 @@ def __init__(
"""Initialize reservation object."""
self._session_manager = session_manager
self._session_info = session_info
self._session_cache: Dict[str, object] = {}

def __enter__(self: Self) -> Self:
"""Context management protocol. Returns self."""
Expand All @@ -224,6 +302,122 @@ def unreserve(self) -> None:
"""Unreserve sessions."""
self._session_manager._unreserve_sessions(self._session_info)

@contextlib.contextmanager
def _cache_session(self, session_name: str, session: TSession) -> Generator[None, None, None]:
if session_name in self._session_cache:
raise RuntimeError(f"Session '{session_name}' already exists.")
self._session_cache[session_name] = session
try:
yield
finally:
del self._session_cache[session_name]

def _get_matching_session_infos(self, instrument_type_id: str) -> List[SessionInformation]:
return [
_convert_session_info_from_grpc(info)._with_session(
self._session_cache.get(info.session.name)
)
for info in self._session_info
if instrument_type_id and instrument_type_id == info.instrument_type_id
]

@contextlib.contextmanager
def _create_session_core(
self,
session_constructor: Callable[[SessionInformation], TSession],
instrument_type_id: str,
) -> Generator[TypedSessionInformation[TSession], None, None]:
if not instrument_type_id:
raise ValueError("This method requires an instrument type ID.")
session_infos = self._get_matching_session_infos(instrument_type_id)
if len(session_infos) == 0:
raise ValueError(
f"No sessions matched instrument type ID '{instrument_type_id}'. "
"Expected single session, got 0 sessions."
)
elif len(session_infos) > 1:
raise ValueError(
f"Too many sessions matched instrument type ID '{instrument_type_id}'. "
f"Expected single session, got {len(session_infos)} sessions."
)

session_info = session_infos[0]
with closing_session(session_constructor(session_info)) as session:
with self._cache_session(session_info.session_name, session):
yield session_info._with_typed_session(session)

@contextlib.contextmanager
def _create_sessions_core(
self,
session_constructor: Callable[[SessionInformation], TSession],
instrument_type_id: str,
) -> Generator[Sequence[TypedSessionInformation[TSession]], None, None]:
if not instrument_type_id:
raise ValueError("This method requires an instrument type ID.")
session_infos = self._get_matching_session_infos(instrument_type_id)
if len(session_infos) == 0:
raise ValueError(
f"No sessions matched instrument type ID '{instrument_type_id}'. "
"Expected single or multiple sessions, got 0 sessions."
)

with ExitStack() as stack:
typed_session_infos: List[TypedSessionInformation[TSession]] = []
for session_info in session_infos:
session = stack.enter_context(closing_session(session_constructor(session_info)))
stack.enter_context(self._cache_session(session_info.session_name, session))
typed_session_infos.append(session_info._with_typed_session(session))
yield typed_session_infos

@requires_feature(SESSION_MANAGEMENT_2024Q1)
def create_session(
self,
session_constructor: Callable[[SessionInformation], TSession],
instrument_type_id: str,
) -> ContextManager[TypedSessionInformation[TSession]]:
"""Create a single instrument session.
This is a generic method that supports any instrument driver.
Args:
session_constructor: A function that constructs sessions based on session
information.
instrument_type_id: Instrument type ID for the session.
For NI instruments, use instrument type id constants, such as
:py:const:`INSTRUMENT_TYPE_NI_DCPOWER` or :py:const:`INSTRUMENT_TYPE_NI_DMM`.
For custom instruments, use the instrument type id defined in the pin map file.
Returns:
A context manager that yields a session information object. The
created session is available via the ``session`` field.
"""
return self._create_session_core(session_constructor, instrument_type_id)

@requires_feature(SESSION_MANAGEMENT_2024Q1)
def create_sessions(
self,
session_constructor: Callable[[SessionInformation], TSession],
instrument_type_id: str,
) -> ContextManager[Sequence[TypedSessionInformation[TSession]]]:
"""Create multiple instrument sessions.
This is a generic method that supports any instrument driver.
Args:
session_constructor: A function that constructs sessions based on session
information.
instrument_type_id: Instrument type ID for the session.
For NI instruments, use instrument type id constants, such as
:py:const:`INSTRUMENT_TYPE_NI_DCPOWER` or :py:const:`INSTRUMENT_TYPE_NI_DMM`.
For custom instruments, use the instrument type id defined in the pin map file.
Returns:
A context manager that yields a sequence of session information
objects. The created sessions are available via the ``session``
field.
"""
return self._create_sessions_core(session_constructor, instrument_type_id)


class SingleSessionReservation(BaseReservation):
"""Manages reservation for a single session."""
Expand Down
1 change: 1 addition & 0 deletions tests/unit/_drivers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Unit tests for ni_measurementlink_service._drivers."""
Loading

0 comments on commit b8346fc

Please sign in to comment.