Skip to content

Commit

Permalink
Switch from using apischema to pydantic (#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
callumforrester authored Apr 14, 2023
1 parent 8e754f3 commit 1ca16e9
Show file tree
Hide file tree
Showing 24 changed files with 1,159 additions and 395 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ dependencies = [
"ophyd",
"nslsii",
"pyepics",
"apischema",
"pydantic",
"stomp.py",
"scanspec<=0.5.5",
"scanspec",
"PyYAML",
"click",
]
Expand Down
9 changes: 8 additions & 1 deletion src/blueapi/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,14 @@ def controller(ctx, host: str, port: int, log_level: str):
return
logging.basicConfig(level=log_level)
ctx.ensure_object(dict)
client = AmqClient(StompMessagingTemplate.autoconfigured(StompConfig(host, port)))
client = AmqClient(
StompMessagingTemplate.autoconfigured(
StompConfig(
host=host,
port=port,
)
)
)
ctx.obj["client"] = client
client.app.connect()

Expand Down
21 changes: 9 additions & 12 deletions src/blueapi/config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from dataclasses import dataclass, field
from pathlib import Path
from typing import Union

from pydantic import BaseModel, Field

@dataclass
class StompConfig:

class StompConfig(BaseModel):
"""
Config for connecting to stomp broker
"""
Expand All @@ -13,27 +13,24 @@ class StompConfig:
port: int = 61613


@dataclass
class EnvironmentConfig:
class EnvironmentConfig(BaseModel):
"""
Config for the RunEngine environment
"""

startup_script: Union[Path, str] = "blueapi.startup.example"


@dataclass
class LoggingConfig:
class LoggingConfig(BaseModel):
level: str = "INFO"


@dataclass
class ApplicationConfig:
class ApplicationConfig(BaseModel):
"""
Config for the worker application as a whole. Root of
config tree.
"""

stomp: StompConfig = field(default_factory=StompConfig)
env: EnvironmentConfig = field(default_factory=EnvironmentConfig)
logging: LoggingConfig = field(default_factory=LoggingConfig)
stomp: StompConfig = Field(default_factory=StompConfig)
env: EnvironmentConfig = Field(default_factory=EnvironmentConfig)
logging: LoggingConfig = Field(default_factory=LoggingConfig)
4 changes: 2 additions & 2 deletions src/blueapi/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
PlanGenerator,
WatchableStatus,
is_bluesky_compatible_device,
is_bluesky_compatible_device_type,
is_bluesky_plan_generator,
)
from .context import BlueskyContext
from .device_lookup import create_bluesky_protocol_conversions
from .event import EventPublisher, EventStream

__all__ = [
Expand All @@ -19,12 +19,12 @@
"MsgGenerator",
"Device",
"BLUESKY_PROTOCOLS",
"create_bluesky_protocol_conversions",
"BlueskyContext",
"EventPublisher",
"EventStream",
"DataEvent",
"WatchableStatus",
"is_bluesky_compatible_device",
"is_bluesky_plan_generator",
"is_bluesky_compatible_device_type",
]
29 changes: 18 additions & 11 deletions src/blueapi/core/bluesky_types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import inspect
from dataclasses import dataclass
from typing import Any, Callable, Generator, Mapping, Type, Union

from bluesky.protocols import (
Expand All @@ -20,6 +19,7 @@
WritesExternalAssets,
)
from bluesky.utils import Msg
from pydantic import BaseModel, Field

try:
from typing import Protocol, runtime_checkable
Expand Down Expand Up @@ -57,12 +57,19 @@

def is_bluesky_compatible_device(obj: Any) -> bool:
is_object = not inspect.isclass(obj)
follows_protocols = any(
map(lambda protocol: isinstance(obj, protocol), BLUESKY_PROTOCOLS)
)
# We must separately check if Obj refers to an instance rather than a
# class, as both follow the protocols but only one is a "device".
return is_object and follows_protocols
return is_object and _follows_bluesky_protocols(obj)


def is_bluesky_compatible_device_type(cls: Type[Any]) -> bool:
# We must separately check if Obj refers to an class rather than an
# instance, as both follow the protocols but only one is a type.
return inspect.isclass(cls) and _follows_bluesky_protocols(cls)


def _follows_bluesky_protocols(obj: Any) -> bool:
return any(map(lambda protocol: isinstance(obj, protocol), BLUESKY_PROTOCOLS))


def is_bluesky_plan_generator(func: PlanGenerator) -> bool:
Expand All @@ -72,18 +79,18 @@ def is_bluesky_plan_generator(func: PlanGenerator) -> bool:
)


@dataclass
class Plan:
class Plan(BaseModel):
"""
A plan that can be run
"""

name: str
model: Type[Any]
name: str = Field(description="Referenceable name of the plan")
model: Type[BaseModel] = Field(
description="Validation model of the parameters for the plan"
)


@dataclass
class DataEvent:
class DataEvent(BaseModel):
"""
Event representing collection of some data. Conforms to the Bluesky event model:
https://github.com/bluesky/event-model
Expand Down
35 changes: 31 additions & 4 deletions src/blueapi/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,20 @@
from importlib import import_module
from pathlib import Path
from types import ModuleType
from typing import Dict, List, Optional, Union
from typing import Dict, Iterable, List, Optional, Union

from bluesky import RunEngine
from bluesky.protocols import Flyable, Readable
from pydantic import BaseConfig

from blueapi.utils import load_module_all, schema_for_func
from blueapi.utils import (
TypeValidatorDefinition,
create_model_with_type_validators,
load_module_all,
)

from .bluesky_types import (
BLUESKY_PROTOCOLS,
Device,
Plan,
PlanGenerator,
Expand All @@ -22,6 +28,10 @@
LOGGER = logging.getLogger(__name__)


class PlanModelConfig(BaseConfig):
arbitrary_types_allowed = True


@dataclass
class BlueskyContext:
"""
Expand Down Expand Up @@ -107,8 +117,14 @@ def my_plan(a: int, b: str):
if not is_bluesky_plan_generator(plan):
raise TypeError(f"{plan} is not a valid plan generator function")

schema = schema_for_func(plan)
self.plans[plan.__name__] = Plan(plan.__name__, schema)
validators = list(device_validators(self))
model = create_model_with_type_validators(
plan.__name__,
validators,
func=plan,
config=PlanModelConfig,
)
self.plans[plan.__name__] = Plan(name=plan.__name__, model=model)
self.plan_functions[plan.__name__] = plan
return plan

Expand Down Expand Up @@ -138,3 +154,14 @@ def device(self, device: Device, name: Optional[str] = None) -> None:
raise KeyError(f"Must supply a name for this device: {device}")

self.devices[name] = device


def device_validators(ctx: BlueskyContext) -> Iterable[TypeValidatorDefinition]:
def get_device(name: str) -> Device:
device = ctx.find_device(name)
if device is None:
raise KeyError(f"Could not find a device named {name}")
return device

for proto in BLUESKY_PROTOCOLS:
yield TypeValidatorDefinition(proto, get_device)
46 changes: 2 additions & 44 deletions src/blueapi/core/device_lookup.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,6 @@
from functools import partial
from typing import Any, Callable, Iterable, List, Optional, Type, TypeVar

from apischema.conversions.conversions import Conversion

from .bluesky_types import BLUESKY_PROTOCOLS, Device, is_bluesky_compatible_device


def create_bluesky_protocol_conversions(
device_lookup: Callable[[str], Device],
) -> Iterable[Conversion]:
"""
Generate a series of APISchema Conversions for the valid Device types.
The conversions use a user-defined function to lookup devices by name.
Args:
device_lookup (Callable[[str], Device]): Function to lookup Device by name,
expects an Exception if name not
found
Returns:
Iterable[Conversion]: Conversions for locating devices
"""

def find_device_matching_name_and_type(target_type: Type, name: str) -> Any:
# Find the device in the
device = device_lookup(name)

# The schema has asked for a particular protocol, at this point in the code we
# have found the device but need to check that it complies with the requested
# protocol. If it doesn't, it means there is a typing error in the plan.
if isinstance(device, target_type):
return device
else:
raise TypeError(f"{name} needs to be of type {target_type}")

# Create a conversion for each type, the conversion function will automatically
# perform a structural subtyping check
for a_type in BLUESKY_PROTOCOLS:
yield Conversion(
partial(find_device_matching_name_and_type, a_type),
source=str,
target=a_type,
)
from typing import Any, List, Optional, TypeVar

from .bluesky_types import Device, is_bluesky_compatible_device

#: Device obeying Bluesky protocols
D = TypeVar("D", bound=Device)
Expand Down
6 changes: 3 additions & 3 deletions src/blueapi/messaging/stomptemplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
from typing import Any, Callable, Dict, List, Optional, Set

import stomp
from apischema import deserialize, serialize
from pydantic import parse_obj_as
from stomp.exception import ConnectFailedException
from stomp.utils import Frame

from blueapi.config import StompConfig
from blueapi.utils import handle_all_exceptions
from blueapi.utils import handle_all_exceptions, serialize

from .base import DestinationProvider, MessageListener, MessagingTemplate
from .context import MessageContext
Expand Down Expand Up @@ -140,7 +140,7 @@ def subscribe(self, destination: str, callback: MessageListener) -> None:

def wrapper(frame: Frame) -> None:
as_dict = json.loads(frame.body)
value = deserialize(obj_type, as_dict)
value = parse_obj_as(obj_type, as_dict)

context = MessageContext(
frame.headers["destination"],
Expand Down
40 changes: 15 additions & 25 deletions src/blueapi/plans/plans.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import operator
from functools import reduce
from typing import Any, List, Mapping, Optional, Tuple, Type, Union
from typing import Any, List, Mapping, Optional, Union

import bluesky.plans as bp
from apischema import serialize
from apischema.conversions.conversions import Conversion
from apischema.conversions.converters import AnyConversion, default_serialization
from bluesky.protocols import Movable, Readable
from cycler import Cycler, cycler
from scanspec.specs import Spec
Expand All @@ -15,17 +12,20 @@

def scan(
detectors: List[Readable],
spec: Spec[Movable],
axes_to_move: Mapping[str, Movable],
spec: Spec[str],
metadata: Optional[Mapping[str, Any]] = None,
) -> MsgGenerator:
"""
Scan wrapping `bp.scan_nd`
Args:
detectors (List[Readable]): List of readable devices, will take a reading at
detectors: List of readable devices, will take a reading at
each point
spec (Spec[Movable]): ScanSpec modelling the path of the scan
metadata (Optional[Mapping[str, Any]], optional): Key-value metadata to include
axes_to_move: All axes involved in this scan, names and
objects
spec: ScanSpec modelling the path of the scan
metadata: Key-value metadata to include
in exported data, defaults to
None.
Expand All @@ -38,41 +38,31 @@ def scan(

metadata = {
"detectors": [detector.name for detector in detectors],
"scanspec": serialize(spec, default_conversion=_convert_devices),
"shape": _shape(spec),
# "scanspec": serialize(spec, default_conversion=_convert_devices),
"shape": spec.shape(),
**(metadata or {}),
}

cycler = _scanspec_to_cycler(spec)
cycler = _scanspec_to_cycler(spec, axes_to_move)
yield from bp.scan_nd(detectors, cycler, md=metadata)


# TODO: Use built-in scanspec utility method following completion of DAQ-4487
def _shape(spec: Spec[Movable]) -> Tuple[int, ...]:
return tuple(len(dim) for dim in spec.calculate())


def _convert_devices(a_type: Type[Any]) -> Optional[AnyConversion]:
if issubclass(a_type, Movable):
return Conversion(str, source=a_type)
else:
return default_serialization(a_type)


def _scanspec_to_cycler(spec: Spec) -> Cycler:
def _scanspec_to_cycler(spec: Spec[str], axes: Mapping[str, Movable]) -> Cycler:
"""
Convert a scanspec to a cycler for compatibility with legacy Bluesky plans such as
`bp.scan_nd`. Use the midpoints of the scanspec since cyclers are noramlly used
for software triggered scans.
Args:
spec (Spec): A scanspec
spec: A scanspec
axes: Names and axes to move
Returns:
Cycler: A new cycler
"""

midpoints = spec.frames().midpoints
midpoints = {axes[name]: points for name, points in midpoints.items()}

# Need to "add" the cyclers for all the axes together. The code below is
# effectively: cycler(motor1, [...]) + cycler(motor2, [...]) + ...
Expand Down
Loading

0 comments on commit 1ca16e9

Please sign in to comment.