From 1ca16e99ec97dc4e01aeca24a32abed3619995e2 Mon Sep 17 00:00:00 2001 From: Callum Forrester <29771545+callumforrester@users.noreply.github.com> Date: Fri, 14 Apr 2023 13:43:01 +0100 Subject: [PATCH] Switch from using apischema to pydantic (#90) --- pyproject.toml | 4 +- src/blueapi/cli/cli.py | 9 +- src/blueapi/config.py | 21 +- src/blueapi/core/__init__.py | 4 +- src/blueapi/core/bluesky_types.py | 29 +- src/blueapi/core/context.py | 35 +- src/blueapi/core/device_lookup.py | 46 +- src/blueapi/messaging/stomptemplate.py | 6 +- src/blueapi/plans/plans.py | 40 +- src/blueapi/service/app.py | 6 +- src/blueapi/service/model.py | 44 +- src/blueapi/utils/__init__.py | 8 +- src/blueapi/utils/config.py | 6 +- src/blueapi/utils/schema.py | 103 ----- src/blueapi/utils/serialization.py | 24 ++ src/blueapi/utils/type_validator.py | 340 +++++++++++++++ src/blueapi/worker/event.py | 79 ++-- src/blueapi/worker/reworker.py | 39 +- src/blueapi/worker/task.py | 72 ++-- tests/core/test_device_lookup.py | 13 - tests/messaging/test_stomptemplate.py | 7 +- tests/utils/test_config.py | 17 +- tests/utils/test_schema.py | 39 -- tests/utils/test_type_validator.py | 563 +++++++++++++++++++++++++ 24 files changed, 1159 insertions(+), 395 deletions(-) delete mode 100644 src/blueapi/utils/schema.py create mode 100644 src/blueapi/utils/serialization.py create mode 100644 src/blueapi/utils/type_validator.py delete mode 100644 tests/core/test_device_lookup.py delete mode 100644 tests/utils/test_schema.py create mode 100644 tests/utils/test_type_validator.py diff --git a/pyproject.toml b/pyproject.toml index b27dfef91..6ba3db4d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,9 +18,9 @@ dependencies = [ "ophyd", "nslsii", "pyepics", - "apischema", + "pydantic", "stomp.py", - "scanspec<=0.5.5", + "scanspec", "PyYAML", "click", ] diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index b969fe70e..57ac9414d 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -61,7 +61,14 @@ def controller(ctx, host: str, port: int, log_level: str): return logging.basicConfig(level=log_level) ctx.ensure_object(dict) - client = AmqClient(StompMessagingTemplate.autoconfigured(StompConfig(host, port))) + client = AmqClient( + StompMessagingTemplate.autoconfigured( + StompConfig( + host=host, + port=port, + ) + ) + ) ctx.obj["client"] = client client.app.connect() diff --git a/src/blueapi/config.py b/src/blueapi/config.py index e66196543..6000628d2 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -1,10 +1,10 @@ -from dataclasses import dataclass, field from pathlib import Path from typing import Union +from pydantic import BaseModel, Field -@dataclass -class StompConfig: + +class StompConfig(BaseModel): """ Config for connecting to stomp broker """ @@ -13,8 +13,7 @@ class StompConfig: port: int = 61613 -@dataclass -class EnvironmentConfig: +class EnvironmentConfig(BaseModel): """ Config for the RunEngine environment """ @@ -22,18 +21,16 @@ class EnvironmentConfig: startup_script: Union[Path, str] = "blueapi.startup.example" -@dataclass -class LoggingConfig: +class LoggingConfig(BaseModel): level: str = "INFO" -@dataclass -class ApplicationConfig: +class ApplicationConfig(BaseModel): """ Config for the worker application as a whole. Root of config tree. 
""" - stomp: StompConfig = field(default_factory=StompConfig) - env: EnvironmentConfig = field(default_factory=EnvironmentConfig) - logging: LoggingConfig = field(default_factory=LoggingConfig) + stomp: StompConfig = Field(default_factory=StompConfig) + env: EnvironmentConfig = Field(default_factory=EnvironmentConfig) + logging: LoggingConfig = Field(default_factory=LoggingConfig) diff --git a/src/blueapi/core/__init__.py b/src/blueapi/core/__init__.py index 1040d86a3..061c0ce03 100644 --- a/src/blueapi/core/__init__.py +++ b/src/blueapi/core/__init__.py @@ -7,10 +7,10 @@ PlanGenerator, WatchableStatus, is_bluesky_compatible_device, + is_bluesky_compatible_device_type, is_bluesky_plan_generator, ) from .context import BlueskyContext -from .device_lookup import create_bluesky_protocol_conversions from .event import EventPublisher, EventStream __all__ = [ @@ -19,7 +19,6 @@ "MsgGenerator", "Device", "BLUESKY_PROTOCOLS", - "create_bluesky_protocol_conversions", "BlueskyContext", "EventPublisher", "EventStream", @@ -27,4 +26,5 @@ "WatchableStatus", "is_bluesky_compatible_device", "is_bluesky_plan_generator", + "is_bluesky_compatible_device_type", ] diff --git a/src/blueapi/core/bluesky_types.py b/src/blueapi/core/bluesky_types.py index 2a757b2c3..aca9d6444 100644 --- a/src/blueapi/core/bluesky_types.py +++ b/src/blueapi/core/bluesky_types.py @@ -1,5 +1,4 @@ import inspect -from dataclasses import dataclass from typing import Any, Callable, Generator, Mapping, Type, Union from bluesky.protocols import ( @@ -20,6 +19,7 @@ WritesExternalAssets, ) from bluesky.utils import Msg +from pydantic import BaseModel, Field try: from typing import Protocol, runtime_checkable @@ -57,12 +57,19 @@ def is_bluesky_compatible_device(obj: Any) -> bool: is_object = not inspect.isclass(obj) - follows_protocols = any( - map(lambda protocol: isinstance(obj, protocol), BLUESKY_PROTOCOLS) - ) # We must separately check if Obj refers to an instance rather than a # class, as both follow the protocols but only one is a "device". - return is_object and follows_protocols + return is_object and _follows_bluesky_protocols(obj) + + +def is_bluesky_compatible_device_type(cls: Type[Any]) -> bool: + # We must separately check if Obj refers to an class rather than an + # instance, as both follow the protocols but only one is a type. + return inspect.isclass(cls) and _follows_bluesky_protocols(cls) + + +def _follows_bluesky_protocols(obj: Any) -> bool: + return any(map(lambda protocol: isinstance(obj, protocol), BLUESKY_PROTOCOLS)) def is_bluesky_plan_generator(func: PlanGenerator) -> bool: @@ -72,18 +79,18 @@ def is_bluesky_plan_generator(func: PlanGenerator) -> bool: ) -@dataclass -class Plan: +class Plan(BaseModel): """ A plan that can be run """ - name: str - model: Type[Any] + name: str = Field(description="Referenceable name of the plan") + model: Type[BaseModel] = Field( + description="Validation model of the parameters for the plan" + ) -@dataclass -class DataEvent: +class DataEvent(BaseModel): """ Event representing collection of some data. 
Conforms to the Bluesky event model: https://github.com/bluesky/event-model diff --git a/src/blueapi/core/context.py b/src/blueapi/core/context.py index 6e484e94f..5d46506c1 100644 --- a/src/blueapi/core/context.py +++ b/src/blueapi/core/context.py @@ -3,14 +3,20 @@ from importlib import import_module from pathlib import Path from types import ModuleType -from typing import Dict, List, Optional, Union +from typing import Dict, Iterable, List, Optional, Union from bluesky import RunEngine from bluesky.protocols import Flyable, Readable +from pydantic import BaseConfig -from blueapi.utils import load_module_all, schema_for_func +from blueapi.utils import ( + TypeValidatorDefinition, + create_model_with_type_validators, + load_module_all, +) from .bluesky_types import ( + BLUESKY_PROTOCOLS, Device, Plan, PlanGenerator, @@ -22,6 +28,10 @@ LOGGER = logging.getLogger(__name__) +class PlanModelConfig(BaseConfig): + arbitrary_types_allowed = True + + @dataclass class BlueskyContext: """ @@ -107,8 +117,14 @@ def my_plan(a: int, b: str): if not is_bluesky_plan_generator(plan): raise TypeError(f"{plan} is not a valid plan generator function") - schema = schema_for_func(plan) - self.plans[plan.__name__] = Plan(plan.__name__, schema) + validators = list(device_validators(self)) + model = create_model_with_type_validators( + plan.__name__, + validators, + func=plan, + config=PlanModelConfig, + ) + self.plans[plan.__name__] = Plan(name=plan.__name__, model=model) self.plan_functions[plan.__name__] = plan return plan @@ -138,3 +154,14 @@ def device(self, device: Device, name: Optional[str] = None) -> None: raise KeyError(f"Must supply a name for this device: {device}") self.devices[name] = device + + +def device_validators(ctx: BlueskyContext) -> Iterable[TypeValidatorDefinition]: + def get_device(name: str) -> Device: + device = ctx.find_device(name) + if device is None: + raise KeyError(f"Could not find a device named {name}") + return device + + for proto in BLUESKY_PROTOCOLS: + yield TypeValidatorDefinition(proto, get_device) diff --git a/src/blueapi/core/device_lookup.py b/src/blueapi/core/device_lookup.py index 95c228a06..957a057f7 100644 --- a/src/blueapi/core/device_lookup.py +++ b/src/blueapi/core/device_lookup.py @@ -1,48 +1,6 @@ -from functools import partial -from typing import Any, Callable, Iterable, List, Optional, Type, TypeVar - -from apischema.conversions.conversions import Conversion - -from .bluesky_types import BLUESKY_PROTOCOLS, Device, is_bluesky_compatible_device - - -def create_bluesky_protocol_conversions( - device_lookup: Callable[[str], Device], -) -> Iterable[Conversion]: - """ - Generate a series of APISchema Conversions for the valid Device types. - The conversions use a user-defined function to lookup devices by name. - - Args: - device_lookup (Callable[[str], Device]): Function to lookup Device by name, - expects an Exception if name not - found - - Returns: - Iterable[Conversion]: Conversions for locating devices - """ - - def find_device_matching_name_and_type(target_type: Type, name: str) -> Any: - # Find the device in the - device = device_lookup(name) - - # The schema has asked for a particular protocol, at this point in the code we - # have found the device but need to check that it complies with the requested - # protocol. If it doesn't, it means there is a typing error in the plan. 
- if isinstance(device, target_type): - return device - else: - raise TypeError(f"{name} needs to be of type {target_type}") - - # Create a conversion for each type, the conversion function will automatically - # perform a structural subtyping check - for a_type in BLUESKY_PROTOCOLS: - yield Conversion( - partial(find_device_matching_name_and_type, a_type), - source=str, - target=a_type, - ) +from typing import Any, List, Optional, TypeVar +from .bluesky_types import Device, is_bluesky_compatible_device #: Device obeying Bluesky protocols D = TypeVar("D", bound=Device) diff --git a/src/blueapi/messaging/stomptemplate.py b/src/blueapi/messaging/stomptemplate.py index c69acd046..856abc001 100644 --- a/src/blueapi/messaging/stomptemplate.py +++ b/src/blueapi/messaging/stomptemplate.py @@ -8,12 +8,12 @@ from typing import Any, Callable, Dict, List, Optional, Set import stomp -from apischema import deserialize, serialize +from pydantic import parse_obj_as from stomp.exception import ConnectFailedException from stomp.utils import Frame from blueapi.config import StompConfig -from blueapi.utils import handle_all_exceptions +from blueapi.utils import handle_all_exceptions, serialize from .base import DestinationProvider, MessageListener, MessagingTemplate from .context import MessageContext @@ -140,7 +140,7 @@ def subscribe(self, destination: str, callback: MessageListener) -> None: def wrapper(frame: Frame) -> None: as_dict = json.loads(frame.body) - value = deserialize(obj_type, as_dict) + value = parse_obj_as(obj_type, as_dict) context = MessageContext( frame.headers["destination"], diff --git a/src/blueapi/plans/plans.py b/src/blueapi/plans/plans.py index de9baff93..da4ff3858 100644 --- a/src/blueapi/plans/plans.py +++ b/src/blueapi/plans/plans.py @@ -1,11 +1,8 @@ import operator from functools import reduce -from typing import Any, List, Mapping, Optional, Tuple, Type, Union +from typing import Any, List, Mapping, Optional, Union import bluesky.plans as bp -from apischema import serialize -from apischema.conversions.conversions import Conversion -from apischema.conversions.converters import AnyConversion, default_serialization from bluesky.protocols import Movable, Readable from cycler import Cycler, cycler from scanspec.specs import Spec @@ -15,17 +12,20 @@ def scan( detectors: List[Readable], - spec: Spec[Movable], + axes_to_move: Mapping[str, Movable], + spec: Spec[str], metadata: Optional[Mapping[str, Any]] = None, ) -> MsgGenerator: """ Scan wrapping `bp.scan_nd` Args: - detectors (List[Readable]): List of readable devices, will take a reading at + detectors: List of readable devices, will take a reading at each point - spec (Spec[Movable]): ScanSpec modelling the path of the scan - metadata (Optional[Mapping[str, Any]], optional): Key-value metadata to include + axes_to_move: All axes involved in this scan, names and + objects + spec: ScanSpec modelling the path of the scan + metadata: Key-value metadata to include in exported data, defaults to None. 
@@ -38,41 +38,31 @@ def scan( metadata = { "detectors": [detector.name for detector in detectors], - "scanspec": serialize(spec, default_conversion=_convert_devices), - "shape": _shape(spec), + # "scanspec": serialize(spec, default_conversion=_convert_devices), + "shape": spec.shape(), **(metadata or {}), } - cycler = _scanspec_to_cycler(spec) + cycler = _scanspec_to_cycler(spec, axes_to_move) yield from bp.scan_nd(detectors, cycler, md=metadata) -# TODO: Use built-in scanspec utility method following completion of DAQ-4487 -def _shape(spec: Spec[Movable]) -> Tuple[int, ...]: - return tuple(len(dim) for dim in spec.calculate()) - - -def _convert_devices(a_type: Type[Any]) -> Optional[AnyConversion]: - if issubclass(a_type, Movable): - return Conversion(str, source=a_type) - else: - return default_serialization(a_type) - - -def _scanspec_to_cycler(spec: Spec) -> Cycler: +def _scanspec_to_cycler(spec: Spec[str], axes: Mapping[str, Movable]) -> Cycler: """ Convert a scanspec to a cycler for compatibility with legacy Bluesky plans such as `bp.scan_nd`. Use the midpoints of the scanspec since cyclers are noramlly used for software triggered scans. Args: - spec (Spec): A scanspec + spec: A scanspec + axes: Names and axes to move Returns: Cycler: A new cycler """ midpoints = spec.frames().midpoints + midpoints = {axes[name]: points for name, points in midpoints.items()} # Need to "add" the cyclers for all the axes together. The code below is # effectively: cycler(motor1, [...]) + cycler(motor2, [...]) + ... diff --git a/src/blueapi/service/app.py b/src/blueapi/service/app.py index c53e6b592..ff641d93c 100644 --- a/src/blueapi/service/app.py +++ b/src/blueapi/service/app.py @@ -77,12 +77,12 @@ def _on_run_request(self, message_context: MessageContext, task: RunPlan) -> Non reply_queue = message_context.reply_destination if reply_queue is not None: - response = TaskResponse(correlation_id) + response = TaskResponse(task_name=correlation_id) self._template.send(reply_queue, response) def _get_plans(self, message_context: MessageContext, message: PlanRequest) -> None: plans = list(map(PlanModel.from_plan, self._ctx.plans.values())) - response = PlanResponse(plans) + response = PlanResponse(plans=plans) assert message_context.reply_destination is not None self._template.send(message_context.reply_destination, response) @@ -91,7 +91,7 @@ def _get_devices( self, message_context: MessageContext, message: DeviceRequest ) -> None: devices = list(map(DeviceModel.from_device, self._ctx.devices.values())) - response = DeviceResponse(devices) + response = DeviceResponse(devices=devices) assert message_context.reply_destination is not None self._template.send(message_context.reply_destination, response) diff --git a/src/blueapi/service/model.py b/src/blueapi/service/model.py index 9a9c6087d..ee220e64f 100644 --- a/src/blueapi/service/model.py +++ b/src/blueapi/service/model.py @@ -1,29 +1,27 @@ -from dataclasses import dataclass from typing import Iterable, List -from apischema import settings from bluesky.protocols import HasName +from pydantic import BaseModel, Field from blueapi.core import BLUESKY_PROTOCOLS, Device, Plan _UNKNOWN_NAME = "UNKNOWN" -settings.camel_case = True - -@dataclass -class DeviceModel: +class DeviceModel(BaseModel): """ Representation of a device """ - name: str - protocols: List[str] + name: str = Field(description="Name of the device") + protocols: List[str] = Field( + description="Protocols that a device conforms to, indicating its capabilities" + ) @classmethod def from_device(cls, 
device: Device) -> "DeviceModel": name = device.name if isinstance(device, HasName) else _UNKNOWN_NAME - return cls(name, list(_protocol_names(device))) + return cls(name=name, protocols=list(_protocol_names(device))) def _protocol_names(device: Device) -> Iterable[str]: @@ -32,8 +30,7 @@ def _protocol_names(device: Device) -> Iterable[str]: yield protocol.__name__ -@dataclass -class DeviceRequest: +class DeviceRequest(BaseModel): """ A query for devices """ @@ -41,30 +38,27 @@ class DeviceRequest: ... -@dataclass -class DeviceResponse: +class DeviceResponse(BaseModel): """ Response to a query for devices """ - devices: List[DeviceModel] + devices: List[DeviceModel] = Field(description="Devices available to use in plans") -@dataclass -class PlanModel: +class PlanModel(BaseModel): """ Representation of a plan """ - name: str + name: str = Field(description="Name of the plan") @classmethod def from_plan(cls, plan: Plan) -> "PlanModel": - return cls(plan.name) + return cls(name=plan.name) -@dataclass -class PlanRequest: +class PlanRequest(BaseModel): """ A query for plans """ @@ -72,19 +66,17 @@ class PlanRequest: ... -@dataclass -class PlanResponse: +class PlanResponse(BaseModel): """ Response to a query for plans """ - plans: List[PlanModel] + plans: List[PlanModel] = Field(description="Plans available to use by a worker") -@dataclass -class TaskResponse: +class TaskResponse(BaseModel): """ Acknowledgement that a task has started, includes its ID """ - task_name: str + task_name: str = Field(description="Unique identifier for the task") diff --git a/src/blueapi/utils/__init__.py b/src/blueapi/utils/__init__.py index ebb10c6e7..4d1bff4dc 100644 --- a/src/blueapi/utils/__init__.py +++ b/src/blueapi/utils/__init__.py @@ -1,12 +1,14 @@ from .config import ConfigLoader from .modules import load_module_all -from .schema import nested_deserialize_with_overrides, schema_for_func +from .serialization import serialize from .thread_exception import handle_all_exceptions +from .type_validator import TypeValidatorDefinition, create_model_with_type_validators __all__ = [ "handle_all_exceptions", - "nested_deserialize_with_overrides", - "schema_for_func", "load_module_all", "ConfigLoader", + "create_model_with_type_validators", + "TypeValidatorDefinition", + "serialize", ] diff --git a/src/blueapi/utils/config.py b/src/blueapi/utils/config.py index bcaa0c72c..93d14a2ab 100644 --- a/src/blueapi/utils/config.py +++ b/src/blueapi/utils/config.py @@ -2,10 +2,10 @@ from typing import Any, Generic, Mapping, Type, TypeVar import yaml -from apischema import deserialize +from pydantic import BaseModel, parse_obj_as #: Configuration schema dataclass -C = TypeVar("C") +C = TypeVar("C", bound=BaseModel) class ConfigLoader(Generic[C]): @@ -59,4 +59,4 @@ def load(self) -> C: C: Dataclass instance holding config """ - return deserialize(self._schema, self._values) + return parse_obj_as(self._schema, self._values) diff --git a/src/blueapi/utils/schema.py b/src/blueapi/utils/schema.py deleted file mode 100644 index a73ce7282..000000000 --- a/src/blueapi/utils/schema.py +++ /dev/null @@ -1,103 +0,0 @@ -from dataclasses import make_dataclass -from inspect import Parameter, signature -from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, TypeVar, Union - -from apischema import deserialize -from apischema.conversions.conversions import Conversion -from apischema.conversions.converters import AnyConversion, default_deserialization - - -def schema_for_func(func: Callable[..., Any]) -> Type: - """ - Generate a 
dataclass that acts as a schema for validation with apischema.
-    Inspect the parameters, default values and type annotations of a function and
-    generate the schema.
-
-    Example:
-
-    def foo(a: int, b: str, c: bool):
-        ...
-
-    schema = schema_for_func(foo)
-
-    Schema is the runtime equivalent of:
-
-    @dataclass
-    class fooo_params:
-        a: int
-        b: str
-        c: bool
-
-    Args:
-        func (Callable[..., Any]): The source function, all parameters must have type
-                                   annotations
-
-    Raises:
-        TypeError: If a type annotation is either `Any` or not supplied
-
-    Returns:
-        Type: A runtime dataclass whose fields encapsulate the names, types and default
-              values of the function parameters
-    """
-
-    class_name = f"{func.__name__}_params"
-    fields: List[Union[Tuple[str, Type, Any], Tuple[str, Type]]] = []
-
-    # Iterate through parameters and convert them to dataclass fields
-    for name, param in signature(func).parameters.items():
-        a_type = param.annotation
-        # Do not allow parameters without type annotations or with the `Any` annotation
-        if a_type is Parameter.empty:
-            raise TypeError(
-                f"Error serializing function {func.__name__}, all parameters must have "
-                "a type annotation"
-            )
-        elif a_type is Any:
-            raise TypeError(
-                f"Error serializing function {func.__name__} parameter {name} all "
-                "parameters cannot have `Any` as a type annotation"
-            )
-
-        default_value = param.default
-
-        # Include the default value in the field if there is onee
-        if default_value is not Parameter.empty:
-            fields.append((name, a_type, default_value))
-        else:
-            fields.append((name, a_type))
-
-    data_class = make_dataclass(class_name, fields)
-    return data_class
-
-
-T = TypeVar("T")
-
-
-def nested_deserialize_with_overrides(
-    schema: Type[T], obj: Any, overrides: Optional[Iterable[Conversion]] = None
-) -> T:
-    """
-    Deserialize a dictionary using apischema with custom overrides. Unlike apischema's
-    built-in override argument, this propagates the overrides to nested dictionaries.
-
-    Args:
-        schema (Type[T]): Type to deserialize to
-        obj (Any): Raw object to deserialize, usually a dictionary
-        overrides (Optional[Iterable[Conversion]], optional): apischema conversions to
-                                                              customize deserialization.
-                                                              Defaults to None.
-
-    Returns:
-        T: Deserialized object
-    """
-
-    conversions = {conversion.target: conversion for conversion in overrides or []}
-
-    def deserialize_with_converters(a_type: Type[Any]) -> Optional[AnyConversion]:
-        # If the type is in _conversions then we can override the function used to
-        # resolve the parameter, otherwise we use apischema's default deserializer
-        if a_type in conversions.keys():
-            return conversions[a_type]
-        return default_deserialization(a_type)
-
-    return deserialize(schema, obj, default_conversion=deserialize_with_converters)
diff --git a/src/blueapi/utils/serialization.py b/src/blueapi/utils/serialization.py
new file mode 100644
index 000000000..141d0b702
--- /dev/null
+++ b/src/blueapi/utils/serialization.py
@@ -0,0 +1,24 @@
+from typing import Any
+
+from pydantic import BaseModel
+
+
+def serialize(obj: Any) -> Any:
+    """
+    Pydantic-aware serialization routine that can also be
+    used on primitives. So serialize(4) is 4, but
+    serializing a pydantic model produces a dictionary.
+ + Args: + obj: The object to serialize + + Returns: + Any: The serialized object + """ + + if isinstance(obj, BaseModel): + return obj.dict() + elif hasattr(obj, "__pydantic_model__"): + return serialize(getattr(obj, "__pydantic_model__")) + else: + return obj diff --git a/src/blueapi/utils/type_validator.py b/src/blueapi/utils/type_validator.py new file mode 100644 index 000000000..461324c43 --- /dev/null +++ b/src/blueapi/utils/type_validator.py @@ -0,0 +1,340 @@ +from collections.abc import Mapping as AbcMapping +from dataclasses import dataclass +from inspect import Parameter, isclass, signature +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Generic, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, + get_args, + overload, +) + +from pydantic import BaseConfig, BaseModel, create_model, validator +from pydantic.fields import Undefined + +if TYPE_CHECKING: + from pydantic.typing import AnyCallable, AnyClassMethod +else: + AnyCallable, AnyClassMethod = Any, Any + + +_PYDANTIC_LIST_TYPES: List[Type] = [List, Tuple, Set] # type: ignore +_PYDANTIC_DICT_TYPES: List[Type] = [Dict, Mapping] + +T = TypeVar("T") +U = TypeVar("U") +FieldDefinition = Tuple[Type, Any] +Fields = Mapping[str, FieldDefinition] +Validator = Union[Callable[[AnyCallable], AnyClassMethod], classmethod] + + +@dataclass +class TypeValidatorDefinition(Generic[T]): + """ + Definition of a validator to be applied to all + types during validation. + + Args: + field_type: Convert all fields of this type + func: Convert using this function + """ + + field_type: Type[T] + func: Callable[[Any], T] + + def __str__(self) -> str: + type_name = getattr( + self.field_type, "__name__", str(hash(str(self.field_type))) + ) + return f"converter_{type_name}" + + +@overload +def create_model_with_type_validators( + name: str, + definitions: List[TypeValidatorDefinition], + *, + fields: Fields, + config: Optional[Type[BaseConfig]] = None, +) -> Type[BaseModel]: + """ + Create a model based on the fields supplied + + Args: + name: Name of the new model + definitions: Definitions of how to validate which types of field + fields: Definitions of fields from which to make the model. + config: Pydantic config for the model. Defaults to None. + + Returns: + Type[BaseModel]: A new pydantic model with the fields and + type validators supplied. + """ + + ... + + +@overload +def create_model_with_type_validators( + name: str, + definitions: List[TypeValidatorDefinition], + *, + func: Callable[..., Any], + config: Optional[Type[BaseConfig]] = None, +) -> Type[BaseModel]: + """ + Create a model from a function's parameters with type + validators. + + Args: + name: Name of the new model + definitions: Definitions of how to validate which types of field + func: The model is constructed from the function parameters, + which must be type-annotated. + config: Pydantic config for the model. Defaults to None. + + Returns: + Type[BaseModel]: A new pydantic model based on the + function parameters. + """ + + ... + + +@overload +def create_model_with_type_validators( + name: str, + definitions: List[TypeValidatorDefinition], + *, + base: Type[BaseModel], +) -> Type[BaseModel]: + """ + Apply type validators to an existing model + + Args: + name: Name of the new model + definitions: Definitions of how to validate which types of field + base (Type[BaseModel]): Base class for the model + + Returns: + Type[BaseModel]: A new version of `base` with type validators + """ + + ... 
+ + +def create_model_with_type_validators( + name: str, + definitions: List[TypeValidatorDefinition], + *, + fields: Optional[Fields] = None, + base: Optional[Type[BaseModel]] = None, + func: Optional[Callable[..., Any]] = None, + config: Optional[Type[BaseConfig]] = None, + cache: Optional[Dict[Type, Type]] = None, +) -> Type[BaseModel]: + """ + Create a pydantic model with type validators according to + definitions given. Validators are applied to all fields + of a particular type. + + Args: + name: Name of the new model + definitions: Definitions of how to validate which types of field + fields: Definitions of fields from which to make the model. + Defaults to None. + base: Optional base class for the model. Defaults to None. + func: Function, if supplied, the model is constructed from the + function parameters, which must be type-annotated. + Defaults to None. + config: Pydantic config for the model. Defaults to None. + + Returns: + Type[BaseModel]: A new pydantic model + """ + + cache = cache or {} + all_fields = {**(fields or {})} + if base is not None: + all_fields = {**all_fields, **_extract_fields_from_model(base)} + if func is not None: + all_fields = {**all_fields, **_extract_fields_from_function(func)} + for name, field in all_fields.items(): + annotation, val = field + if annotation in cache: + all_fields[name] = cache[annotation], val + else: + all_fields[name] = apply_type_validators(annotation, definitions), val + # model_type = find_model_type(annotation) + # if model_type is not None: + # recursed = create_model_with_type_validators( + # annotation.__name__, definitions, base=model_type + # ) + # all_fields[name] = recursed, val + validators = _type_validators(all_fields, definitions) + return create_model( # type: ignore + name, **all_fields, __base__=base, __validators__=validators, __config__=config + ) + + +def apply_type_validators( + model_type: Type, + definitions: List[TypeValidatorDefinition], + cache: Optional[Dict[Type, Type]] = None, +) -> Type: + cache = cache or {} + if model_type in cache: + return cache[model_type] + + if isclass(model_type) and issubclass(model_type, BaseModel): + if "__root__" in model_type.__fields__: + return apply_type_validators( + model_type.__fields__["__root__"].type_, definitions, cache=cache + ) + else: + return create_model_with_type_validators( + model_type.__name__, + definitions, + base=model_type, + ) + elif isclass(model_type) and hasattr(model_type, "__pydantic_model__"): + model = getattr(model_type, "__pydantic_model__") + return apply_type_validators(model, definitions, cache=cache) + else: + params = [ + apply_type_validators(param, definitions, cache=cache) + for param in get_args(model_type) + ] + if params and hasattr(model_type, "__origin__"): + origin = getattr(model_type, "__origin__") + origin = _sanitise_origin(origin) + return origin[tuple(params)] + return model_type + + +def _sanitise_origin(origin: Type) -> Type: + return { # type: ignore + list: List, + set: Set, + tuple: Tuple, + AbcMapping: Mapping, + dict: Mapping, + }.get(origin, origin) + + +def _extract_fields_from_model(model: Type[BaseModel]) -> Fields: + return { + name: (field.type_, field.field_info) + for name, field in model.__fields__.items() + } + + +def _extract_fields_from_function(func: Callable[..., Any]) -> Fields: + fields: Dict[str, FieldDefinition] = {} + for name, param in signature(func).parameters.items(): + type_annotation = param.annotation + if type_annotation is Parameter.empty: + raise TypeError(f"Missing type annotation 
for parameter {name}") + default_value = param.default + if default_value is Parameter.empty: + default_value = Undefined + + anno = (type_annotation, default_value) + fields[name] = anno + + return fields + + +def _type_validators( + fields: Fields, + definitions: Iterable[TypeValidatorDefinition], +) -> Mapping[str, Validator]: + """ + Generate type validators from fields and definitions. + + Args: + fields: fields to validate. + definitions: Definitions of how to validate which types of field + + Raises: + TypeError: If a validator can be applied to more than one field. + + Returns: + Mapping[str, Validator]: Dict-like structure mapping validator + names to pydantic validators. + """ + + all_validators = {} + + for definition in definitions: + field_names = _determine_fields_of_type(fields, definition.field_type) + for name in field_names: + val = _make_type_validator(name, definition) + val_method_name = f"validate_{name}" + if val_method_name in all_validators: + raise TypeError(f"Ambiguous type validator for field: {name}") + all_validators[val_method_name] = val + + return all_validators + + +def _make_type_validator(name: str, definition: TypeValidatorDefinition) -> Validator: + def validate_type(value: Any) -> Any: + return apply_to_scalars(definition.func, value) + + return validator(name, allow_reuse=True, pre=True, always=True)(validate_type) + + +def _determine_fields_of_type(fields: Fields, field_type: Type) -> Iterable[str]: + for name, field in fields.items(): + annotation, _ = field + if is_type_or_container_type(annotation, field_type): + yield name + + +def is_type_or_container_type(type_to_check: Type, field_type: Type) -> bool: + return params_contains(type_to_check, field_type) + + +def params_contains(type_to_check: Type, field_type: Type) -> bool: + type_params = get_args(type_to_check) + return type_to_check is field_type or any( + map(lambda v: params_contains(v, field_type), type_params) + ) + + +def apply_to_scalars(func: Callable[[T], U], obj: Any) -> Any: + if is_list_type(obj): + return list(map(lambda v: apply_to_scalars(func, v), obj)) + elif is_dict_type(obj): + return {k: apply_to_scalars(func, v) for k, v in obj.items()} + else: + return func(obj) + + +def is_list_type(obj: Any) -> bool: + return any(map(lambda t: isinstance(obj, t), _PYDANTIC_LIST_TYPES)) + + +def is_dict_type(obj: Any) -> bool: + return any(map(lambda t: isinstance(obj, t), _PYDANTIC_DICT_TYPES)) + + +def find_model_type(anno: Type) -> Optional[Type[BaseModel]]: + if isclass(anno): + if issubclass(anno, BaseModel): + return anno + elif hasattr(anno, "__pydantic_model__"): + return getattr(anno, "__pydantic_model__") + return None diff --git a/src/blueapi/worker/event.py b/src/blueapi/worker/event.py index 75edc413d..457ec443c 100644 --- a/src/blueapi/worker/event.py +++ b/src/blueapi/worker/event.py @@ -1,15 +1,15 @@ -from dataclasses import dataclass, field from enum import Enum from typing import List, Mapping, Optional, Union from bluesky.run_engine import RunEngineStateMachine +from pydantic import BaseModel, Field from super_state_machine.extras import PropertyMachine, ProxyString # The RunEngine can return any of these three types as its state RawRunEngineState = Union[PropertyMachine, ProxyString, str] -class WorkerState(Enum): +class WorkerState(str, Enum): """ The state of the Worker. 
""" @@ -27,42 +27,70 @@ class WorkerState(Enum): @classmethod def from_bluesky_state(cls, bluesky_state: RawRunEngineState) -> "WorkerState": + """Convert the state of a bluesky RunEngine + + Args: + bluesky_state: Bluesky RunEngine state + + Returns: + RunnerState: Mapped RunEngine state + """ + if isinstance(bluesky_state, RunEngineStateMachine.States): return cls.from_bluesky_state(bluesky_state.value) return WorkerState(str(bluesky_state).upper()) -@dataclass -class StatusView: +class StatusView(BaseModel): """ - A snapshot of a Status, optionally representing progress + A snapshot of a Status of an operation, optionally representing progress """ - display_name: str = "UNKNOWN" - current: Optional[float] = None - initial: Optional[float] = None - target: Optional[float] = None - unit: str = "units" - precision: int = 3 - done: bool = False - percentage: Optional[float] = None - time_elapsed: Optional[float] = None - time_remaining: Optional[float] = None - - -@dataclass -class ProgressEvent: + display_name: str = Field( + description="Human-readable name indicating what this status describes", + default="Unknown", + ) + current: Optional[float] = Field( + description="Current value of operation progress, if known", default=None + ) + initial: Optional[float] = Field( + description="Initial value of operation progress, if known", default=None + ) + target: Optional[float] = Field( + description="Target value operation of progress, if known", default=None + ) + unit: str = Field(description="Units of progress", default="units") + precision: int = Field( + description="Sensible precision of progress to display", default=3 + ) + done: bool = Field( + description="Whether the operation this status describes is complete", + default=False, + ) + percentage: Optional[float] = Field( + description="Percentage of status completion, if known", default=None + ) + time_elapsed: Optional[float] = Field( + description="Time elapsed since status operation beginning, if known", + default=None, + ) + time_remaining: Optional[float] = Field( + description="Estimated time remaining until operation completion, if known", + default=None, + ) + + +class ProgressEvent(BaseModel): """ Event describing the progress of processes within a running task, such as moving motors and exposing detectors. """ task_name: str - statuses: Mapping[str, StatusView] = field(default_factory=dict) + statuses: Mapping[str, StatusView] = Field(default_factory=dict) -@dataclass -class TaskStatus: +class TaskStatus(BaseModel): """ Status of a task the worker is running. """ @@ -72,8 +100,7 @@ class TaskStatus: task_failed: bool -@dataclass -class WorkerEvent: +class WorkerEvent(BaseModel): """ Event describing the state of the worker and any tasks it's running. Includes error and warning information. 
@@ -81,8 +108,8 @@ class WorkerEvent: state: WorkerState task_status: Optional[TaskStatus] = None - errors: List[str] = field(default_factory=list) - warnings: List[str] = field(default_factory=list) + errors: List[str] = Field(default_factory=list) + warnings: List[str] = Field(default_factory=list) def is_error(self) -> bool: return (self.task_status is not None and self.task_status.task_failed) or bool( diff --git a/src/blueapi/worker/reworker.py b/src/blueapi/worker/reworker.py index 382d99415..1415325a4 100644 --- a/src/blueapi/worker/reworker.py +++ b/src/blueapi/worker/reworker.py @@ -141,22 +141,29 @@ def _report_status( warnings = self._warnings if self._current is not None: task_status = TaskStatus( - self._current.name, - self._current.is_complete, - self._current.is_error or bool(errors), + task_name=self._current.name, + task_complete=self._current.is_complete, + task_failed=self._current.is_error or bool(errors), ) correlation_id = self._current.name else: task_status = None correlation_id = None - event = WorkerEvent(self._state, task_status, errors, warnings) + event = WorkerEvent( + state=self._state, + task_status=task_status, + errors=errors, + warnings=warnings, + ) self._worker_events.publish(event, correlation_id) def _on_document(self, name: str, document: Mapping[str, Any]) -> None: if self._current is not None: correlation_id = self._current.name - self._data_events.publish(DataEvent(name, document), correlation_id) + self._data_events.publish( + DataEvent(name=name, doc=document), correlation_id + ) else: raise KeyError( "Trying to emit a document despite the fact that the RunEngine is idle" @@ -204,16 +211,16 @@ def _on_status_event( else: percentage = 1.0 view = StatusView( - name or "UNKNOWN", - current, - initial, - target, - unit or "units", - precision or 3, - status.done, - percentage, - time_elapsed, - time_remaining, + display_name=name or "UNKNOWN", + current=current, + initial=initial, + target=target, + unit=unit or "units", + precision=precision or 3, + done=status.done, + percentage=percentage, + time_elapsed=time_elapsed, + time_remaining=time_remaining, ) self._status_snapshot[status_name] = view self._publish_status_snapshot() @@ -224,7 +231,7 @@ def _publish_status_snapshot(self) -> None: else: self._progress_events.publish( ProgressEvent( - self._current.name, + task_name=self._current.name, statuses=self._status_snapshot, ), self._current.name, diff --git a/src/blueapi/worker/task.py b/src/blueapi/worker/task.py index d136b65bd..efbb01313 100644 --- a/src/blueapi/worker/task.py +++ b/src/blueapi/worker/task.py @@ -1,40 +1,19 @@ import logging from abc import ABC, abstractmethod -from dataclasses import dataclass, field -from typing import Any, Mapping, Union +from dataclasses import dataclass +from typing import Any, Mapping -from apischema import deserializer, identity, serializer -from apischema.conversions import Conversion +from pydantic import BaseModel, Field, parse_obj_as -from blueapi.core import ( - BlueskyContext, - Device, - Plan, - create_bluesky_protocol_conversions, -) -from blueapi.utils import nested_deserialize_with_overrides +from blueapi.core import BlueskyContext, Plan # TODO: Make a TaggedUnion -class Task(ABC): +class Task(ABC, BaseModel): """ Object that can run with a TaskContext """ - _union: Any = None - - # You can use __init_subclass__ to register new subclass automatically - def __init_subclass__(cls, **kwargs): - super().__init_subclass__(**kwargs) - # Deserializers stack directly as a Union - 
deserializer(Conversion(identity, source=cls, target=Task))
-        # Only Base serializer must be registered (and updated for each subclass) as
-        # a Union, and not be inherited
-        Task._union = cls if Task._union is None else Union[Task._union, cls]
-        serializer(
-            Conversion(identity, source=Task, target=Task._union, inherited=False)
-        )
-
     @abstractmethod
     def do_task(self, __ctx: BlueskyContext) -> None:
         """
@@ -48,38 +27,43 @@ def do_task(self, __ctx: BlueskyContext) -> None:
 LOGGER = logging.getLogger(__name__)
 
 
-@dataclass
 class RunPlan(Task):
     """
     Task that will run a plan
     """
 
-    name: str
-    params: Mapping[str, Any] = field(default_factory=dict)
-    # plan: Generator[Msg, None, Any]
+    name: str = Field(description="Name of plan to run")
+    params: Mapping[str, Any] = Field(
+        description="Values for parameters to plan, if any", default_factory=dict
+    )
 
     def do_task(self, ctx: BlueskyContext) -> None:
         LOGGER.info(f"Asked to run plan {self.name} with {self.params}")
 
         plan = ctx.plans[self.name]
-        plan_function = ctx.plan_functions[self.name]
-        sanitized_params = lookup_params(ctx, plan, self.params)
-        plan_generator = plan_function(**sanitized_params)
+        func = ctx.plan_functions[self.name]
+        sanitized_params = _lookup_params(ctx, plan, self.params)
+        plan_generator = func(**sanitized_params.dict())
         ctx.run_engine(plan_generator)
 
 
-def lookup_params(
+def _lookup_params(
     ctx: BlueskyContext, plan: Plan, params: Mapping[str, Any]
-) -> Mapping[str, Any]:
-    def find_device(name: str) -> Device:
-        device = ctx.find_device(name)
-        if device is not None:
-            return device
-        else:
-            raise KeyError(f"Could not find device {name}")
-
-    overrides = list(create_bluesky_protocol_conversions(find_device))
-    return nested_deserialize_with_overrides(plan.model, params, overrides).__dict__
+) -> BaseModel:
+    """
+    Checks plan parameters against context
+
+    Args:
+        ctx: Context holding plans and devices
+        plan: Plan object including schema
+        params: Parameter values to be validated against schema
+
+    Returns:
+        BaseModel: Instance of the plan's parameter model, populated with
+            validated values
+    """
+
+    model = plan.model
+    return parse_obj_as(model, params)
 
 
 @dataclass
diff --git a/tests/core/test_device_lookup.py b/tests/core/test_device_lookup.py
deleted file mode 100644
index 7bac1d908..000000000
--- a/tests/core/test_device_lookup.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from typing import Any, Type
-from unittest.mock import MagicMock
-
-import pytest
-
-from blueapi.core import BLUESKY_PROTOCOLS, create_bluesky_protocol_conversions
-
-
-@pytest.mark.parametrize("a_type", BLUESKY_PROTOCOLS)
-def test_creates_resolver_for(a_type: Type[Any]):
-    converters = create_bluesky_protocol_conversions(MagicMock())
-    target_types = map(lambda c: c.target, converters)
-    assert a_type in list(target_types)
diff --git a/tests/messaging/test_stomptemplate.py b/tests/messaging/test_stomptemplate.py
index 35a775382..812838660 100644
--- a/tests/messaging/test_stomptemplate.py
+++ b/tests/messaging/test_stomptemplate.py
@@ -1,10 +1,10 @@
 import itertools
 from concurrent.futures import Future
-from dataclasses import dataclass
 from queue import Queue
 from typing import Any, Iterable, Type
 
 import pytest
+from pydantic import BaseModel
 
 from blueapi.config import StompConfig
 from blueapi.messaging import MessageContext, MessagingTemplate, StompMessagingTemplate
@@ -97,8 +97,7 @@ def server(ctx: MessageContext, message: str) -> None:
     assert reply == "ack"
 
 
-@dataclass
-class Foo:
+class Foo(BaseModel):
     a: int
     b: str
 
 
@@ -106,7 +105,7 @@ class Foo:
 @pytest.mark.stomp
 @pytest.mark.parametrize(
"message,message_type", - [("test", str), (1, int), (Foo(1, "test"), Foo)], + [("test", str), (1, int), (Foo(a=1, b="test"), Foo)], ) def test_deserialization( template: MessagingTemplate, test_queue: str, message: Any, message_type: Type diff --git a/tests/utils/test_config.py b/tests/utils/test_config.py index f4abb3b4c..6722ad5e1 100644 --- a/tests/utils/test_config.py +++ b/tests/utils/test_config.py @@ -1,35 +1,30 @@ import os -from dataclasses import dataclass, field from pathlib import Path from typing import Any, Type import pytest -from apischema import ValidationError +from pydantic import BaseModel, Field, ValidationError from blueapi.utils import ConfigLoader -@dataclass -class Config: +class Config(BaseModel): foo: int bar: str -@dataclass -class ConfigWithDefaults: +class ConfigWithDefaults(BaseModel): foo: int = 3 bar: str = "hello world" -@dataclass -class NestedConfig: +class NestedConfig(BaseModel): nested: Config baz: bool -@dataclass -class NestedConfigWithDefaults: - nested: ConfigWithDefaults = field(default_factory=ConfigWithDefaults) +class NestedConfigWithDefaults(BaseModel): + nested: ConfigWithDefaults = Field(default_factory=ConfigWithDefaults) baz: bool = False diff --git a/tests/utils/test_schema.py b/tests/utils/test_schema.py deleted file mode 100644 index 48728d9d7..000000000 --- a/tests/utils/test_schema.py +++ /dev/null @@ -1,39 +0,0 @@ -import dataclasses -from typing import Any - -import pytest - -from blueapi.utils import schema_for_func - - -def test_schema_generated() -> None: - def func(foo: int, bar: str = "hello") -> None: - ... - - schema = schema_for_func(func) - assert dataclasses.is_dataclass(schema) - foo, bar = dataclasses.fields(schema) - - assert foo.name == "foo" - assert foo.type == int - assert foo.default == dataclasses.MISSING - - assert bar.name == "bar" - assert bar.type == str - assert bar.default == "hello" - - -def test_rejects_any() -> None: - def func(foo: int, bar: Any) -> None: - ... - - with pytest.raises(TypeError): - schema_for_func(func) - - -def test_rejects_no_param() -> None: - def func(foo: int, bar) -> None: - ... - - with pytest.raises(TypeError): - schema_for_func(func) diff --git a/tests/utils/test_type_validator.py b/tests/utils/test_type_validator.py new file mode 100644 index 000000000..db9303e4d --- /dev/null +++ b/tests/utils/test_type_validator.py @@ -0,0 +1,563 @@ +from typing import Any, Dict, List, Literal, Mapping, Optional, Set, Tuple, Type, Union + +import pytest +from pydantic import BaseConfig, BaseModel, Field, parse_obj_as +from pydantic.dataclasses import dataclass +from pydantic.fields import Undefined +from scanspec.regions import Circle +from scanspec.specs import Line, Spec + +from blueapi.utils import TypeValidatorDefinition, create_model_with_type_validators + + +class DefaultConfig(BaseConfig): + arbitrary_types_allowed = True + + +_REG: Mapping[str, int] = { + letter: number for number, letter in enumerate("abcdefghijklmnopqrstuvwxyz") +} + + +class ComplexObject: + _name: str + + def __init__(self, name: str) -> None: + self._name = name + + def name(self) -> str: + return self._name + + def __eq__(self, __value: object) -> bool: + return isinstance(__value, ComplexObject) and __value.name() == self._name + + def __str__(self) -> str: + return f"ComplexObject({self._name})" + + def __repr__(self) -> str: + return f"ComplexObject({self._name})" + + +class SpecWrapper(BaseModel): + spec: Spec + + +def spec_wrapper(spec: Spec) -> None: + ... 
+
+
+class Bar(BaseModel):
+    a: int
+    b: ComplexObject
+    type: Literal["Bar"] = Field(default="Bar")
+
+    class Config:
+        arbitrary_types_allowed = True
+
+
+class Baz(BaseModel):
+    obj: Bar
+    c: str
+    type: Literal["Baz"] = Field(default="Baz")
+
+
+class ComplexLinkedList(BaseModel):
+    obj: ComplexObject
+    child: Optional["ComplexLinkedList"] = None
+
+    class Config:
+        arbitrary_types_allowed = True
+
+
+@dataclass(config=DefaultConfig)
+class DataclassBar:
+    a: int
+    b: ComplexObject
+
+
+@dataclass
+class DataclassBaz:
+    obj: DataclassBar
+    c: str
+
+
+@dataclass
+class DataclassMixed:
+    obj: Bar
+    c: str
+
+
+def foo(a: int, b: str) -> None:
+    ...
+
+
+def bar(obj: ComplexObject) -> None:
+    ...
+
+
+def baz(bar: Bar) -> None:
+    ...
+
+
+_DB: Mapping[str, ComplexObject] = {name: ComplexObject(name) for name in _REG.keys()}
+
+
+def lookup(letter: str) -> int:
+    assert type(letter) is str, f"Expected a string, got a {type(letter)}"
+    return _REG[letter]
+
+
+def has_even_length(msg: str) -> bool:
+    assert type(msg) is str, f"Expected a string, got a {type(msg)}"
+    return len(msg) % 2 == 0
+
+
+def lookup_complex(name: str) -> ComplexObject:
+    assert type(name) is str, f"Expected a string, got a {type(name)}"
+    return _DB[name]
+
+
+def test_validates_single_type() -> None:
+    assert_validates_single_type(int, "c", 2)
+
+
+def test_leaves_unvalidated_types_alone() -> None:
+    model = create_model_with_type_validators(
+        "Foo",
+        [TypeValidatorDefinition(int, lookup)],
+        fields={"a": (int, Undefined), "b": (str, Undefined)},
+    )
+    parsed = parse_obj_as(model, {"a": "c", "b": "hello"})
+    assert parsed.a == 2  # type: ignore
+    assert parsed.b == "hello"  # type: ignore
+
+
+def test_validates_multiple_types() -> None:
+    model = create_model_with_type_validators(
+        "Foo",
+        [
+            TypeValidatorDefinition(int, lookup),
+            TypeValidatorDefinition(bool, has_even_length),
+        ],
+        fields={"a": (int, Undefined), "b": (bool, Undefined)},
+    )
+    parsed = parse_obj_as(model, {"a": "c", "b": "hello"})
+    assert parsed.a == 2  # type: ignore
+    assert parsed.b is False  # type: ignore
+
+
+def test_validates_multiple_fields() -> None:
+    model = create_model_with_type_validators(
+        "Foo",
+        [TypeValidatorDefinition(int, lookup)],
+        fields={"a": (int, Undefined), "b": (int, Undefined)},
+    )
+    parsed = parse_obj_as(model, {"a": "c", "b": "d"})
+    assert parsed.a == 2  # type: ignore
+    assert parsed.b == 3  # type: ignore
+
+
+def test_validates_multiple_fields_and_types() -> None:
+    model = create_model_with_type_validators(
+        "Foo",
+        [
+            TypeValidatorDefinition(int, lookup),
+            TypeValidatorDefinition(bool, has_even_length),
+        ],
+        fields={
+            "a": (int, Undefined),
+            "b": (bool, Undefined),
+            "c": (int, Undefined),
+            "d": (bool, Undefined),
+        },
+    )
+    parsed = parse_obj_as(model, {"a": "c", "b": "hello", "c": "d", "d": "word"})
+    assert parsed.a == 2  # type: ignore
+    assert parsed.b is False  # type: ignore
+    assert parsed.c == 3  # type: ignore
+    assert parsed.d is True  # type: ignore
+
+
+def test_does_not_tolerate_multiple_converters_for_same_type() -> None:
+    with pytest.raises(TypeError):
+        create_model_with_type_validators(
+            "Foo",
+            [TypeValidatorDefinition(int, lookup), TypeValidatorDefinition(int, int)],
+            fields={"a": (int, Undefined), "b": (int, Undefined)},
+        )
+
+
+def test_validates_list_type() -> None:
+    assert_validates_single_type(List[int], ["a", "b", "c"], [0, 1, 2])
+
+
+def test_validates_set_type() -> None:
+    assert_validates_single_type(Set[int], ["a", "b", "c"], {0, 1, 2})
+
+
+def 
test_validates_tuple_type() -> None: + assert_validates_single_type( + Tuple[int, ...], # type: ignore + [ + "a", + "b", + "c", + ], + (0, 1, 2), + ) + + +def test_validates_nested_container_type() -> None: + assert_validates_single_type( + List[Set[Tuple[int, int]]], + [[["a", "b"], ["c", "d"]], [["e", "f"]]], + [{(0, 1), (2, 3)}, {(4, 5)}], + ) + + +@pytest.mark.parametrize("dict_type", [Dict, Mapping]) +def test_validates_dict_type(dict_type: Type) -> None: + assert_validates_single_type( + dict_type[str, int], + { + "a": "a", + "b": "b", + "c": "c", + }, + { + "a": 0, + "b": 1, + "c": 2, + }, + ) + + +def test_validates_nested_mapping() -> None: + assert_validates_single_type( + Dict[str, List[int]], + { + "a": ["a", "b"], + "b": ["c", "d", "e"], + "c": ["f"], + }, + { + "a": [0, 1], + "b": [2, 3, 4], + "c": [5], + }, + ) + + +def test_validates_complex_object() -> None: + assert_validates_complex_object(ComplexObject, "d", ComplexObject("d")) + + +def test_validates_complex_object_list() -> None: + assert_validates_complex_object( + List[ComplexObject], + ["a", "b", "c"], + [ + ComplexObject("a"), + ComplexObject("b"), + ComplexObject("c"), + ], + ) + + +def test_applies_to_base() -> None: + model = create_model_with_type_validators( + "Foo", + [TypeValidatorDefinition(ComplexObject, lookup_complex)], + base=Bar, + ) + parsed = parse_obj_as(model, {"a": 2, "b": "g"}) + assert parsed.a == 2 # type: ignore + assert parsed.b == ComplexObject("g") # type: ignore + + +def test_applies_to_nested_base() -> None: + model = create_model_with_type_validators( + "Foo", + [TypeValidatorDefinition(ComplexObject, lookup_complex)], + base=Baz, + ) + parsed = parse_obj_as(model, {"obj": {"a": 2, "b": "g"}, "c": "hello"}) + assert parsed.obj.a == 2 # type: ignore + assert parsed.obj.b == ComplexObject("g") # type: ignore + assert parsed.c == "hello" # type: ignore + + +def test_validates_submodel() -> None: + model = create_model_with_type_validators( + "Foo", + [TypeValidatorDefinition(ComplexObject, lookup_complex)], + fields={"obj": (Bar, Undefined)}, + ) + parsed = parse_obj_as( + model, + { + "obj": { + "a": 2, + "b": "g", + }, + }, + ) + assert parsed.obj.a == 2 # type: ignore + assert parsed.obj.b == ComplexObject("g") # type: ignore + + +def test_validates_nested_submodel() -> None: + model = create_model_with_type_validators( + "Foo", + [TypeValidatorDefinition(ComplexObject, lookup_complex)], + fields={"obj": (Baz, Undefined)}, + ) + parsed = parse_obj_as( + model, + { + "obj": { + "obj": { + "a": 2, + "b": "g", + }, + "c": "hello", + } + }, + ) + assert parsed.obj.obj.a == 2 # type: ignore + assert parsed.obj.obj.b == ComplexObject("g") # type: ignore + assert parsed.obj.c == "hello" # type: ignore + + +def test_validates_dataclass() -> None: + model = create_model_with_type_validators( + "Foo", + [TypeValidatorDefinition(ComplexObject, lookup_complex)], + fields={"obj": (DataclassBar, Undefined)}, + ) + parsed = parse_obj_as( + model, + { + "obj": { + "a": 2, + "b": "g", + }, + }, + ) + assert parsed.obj.a == 2 # type: ignore + assert parsed.obj.b == ComplexObject("g") # type: ignore + + +def test_validates_nested_dataclass() -> None: + model = create_model_with_type_validators( + "Foo", + [TypeValidatorDefinition(ComplexObject, lookup_complex)], + fields={"obj": (DataclassBaz, Undefined)}, + ) + parsed = parse_obj_as( + model, + { + "obj": { + "obj": { + "a": 2, + "b": "g", + }, + "c": "hello", + } + }, + ) + assert parsed.obj.obj.a == 2 # type: ignore + assert parsed.obj.obj.b == 
ComplexObject("g") # type: ignore + assert parsed.obj.c == "hello" # type: ignore + + +def test_validates_mixed_dataclass() -> None: + model = create_model_with_type_validators( + "Foo", + [TypeValidatorDefinition(ComplexObject, lookup_complex)], + fields={"obj": (DataclassMixed, Undefined)}, + ) + parsed = parse_obj_as( + model, + { + "obj": { + "obj": { + "a": 2, + "b": "g", + }, + "c": "hello", + } + }, + ) + assert parsed.obj.obj.a == 2 # type: ignore + assert parsed.obj.obj.b == ComplexObject("g") # type: ignore + assert parsed.obj.c == "hello" # type: ignore + + +def test_validates_default_value() -> None: + model = create_model_with_type_validators( + "Foo", + [TypeValidatorDefinition(int, lookup)], + fields={"a": (int, "e")}, + config=DefaultConfig, + ) + assert parse_obj_as(model, {}).a == 4 # type: ignore + + +def test_validates_complex_value() -> None: + model = create_model_with_type_validators( + "Foo", + [TypeValidatorDefinition(ComplexObject, lookup_complex)], + fields={"obj": (ComplexObject, "t")}, + config=DefaultConfig, + ) + assert parse_obj_as(model, {}).obj == ComplexObject("t") # type: ignore + + +def test_validates_field_info() -> None: + model = create_model_with_type_validators( + "Foo", + [TypeValidatorDefinition(int, lookup)], + fields={"a": (int, Field(default="f"))}, + config=DefaultConfig, + ) + assert parse_obj_as(model, {}).a == 5 # type: ignore + + +SPECS = [ + Line("x", 0.0, 10.0, 10), + Line("x", 0.0, 10.0, 10) * Line("y", 0.0, 10.0, 10), + (Line("x", 0.0, 10.0, 10) * Line("y", 0.0, 10.0, 10)) + & Circle("x", "y", 1.0, 2.8, radius=0.5), +] + + +@pytest.mark.parametrize("spec", SPECS) +def test_validates_scanspec(spec: Spec) -> None: + assert parse_spec(spec).spec == spec # type: ignore + + +@pytest.mark.parametrize("spec", SPECS) +def test_validates_scanspec_wrapper(spec: Spec) -> None: + model = create_model_with_type_validators( + "Foo", + [TypeValidatorDefinition(ComplexObject, lookup_complex)], + fields={"wrapper": (SpecWrapper, Undefined)}, + ) + parsed = parse_obj_as(model, {"wrapper": {"spec": spec.serialize()}}) + assert parsed.wrapper.spec == spec # type: ignore + + +@pytest.mark.parametrize("spec", SPECS) +def test_validates_scanspec_wrapping_function(spec: Spec) -> None: + model = create_model_with_type_validators( + "Foo", + [TypeValidatorDefinition(ComplexObject, lookup_complex)], + func=spec_wrapper, + ) + parsed = parse_obj_as(model, {"spec": spec.serialize()}) + assert parsed.spec == spec # type: ignore + + +def lookup_union(value: Union[int, str]) -> int: + if isinstance(value, str): + return lookup(value) + else: + return value + + +@pytest.mark.parametrize("value,expected", [(4, 4), ("b", 1)]) +def test_validates_union(value: Union[int, str], expected: int) -> None: + model = create_model_with_type_validators( + "Foo", + [TypeValidatorDefinition(Union[int, str], lookup_union)], # type: ignore + fields={"un": (Union[int, str], Undefined)}, # type: ignore + config=DefaultConfig, + ) + parsed = parse_obj_as(model, {"un": value}) + assert parsed.un == expected # type: ignore + + +def test_validates_model_union() -> None: + model = create_model_with_type_validators( + "Foo", + [TypeValidatorDefinition(ComplexObject, lookup_complex)], # type: ignore + fields={ + "un": ( # type: ignore + Union[Bar, Baz], + Field(..., discriminator="type"), + ) + }, + config=DefaultConfig, + ) + parsed = parse_obj_as(model, {"un": {"a": 5, "b": "g", "type": "Bar"}}) + assert parsed.un == Bar(a=5, b=ComplexObject("g")) # type: ignore + + +def 
test_model_from_simple_function_signature() -> None: + model = create_model_with_type_validators( + "Foo", [TypeValidatorDefinition(int, lookup)], func=foo + ) + parsed = parse_obj_as(model, {"a": "g", "b": "hello"}) + assert parsed.a == 6 # type: ignore + assert parsed.b == "hello" # type: ignore + + +def test_model_from_complex_function_signature() -> None: + model = create_model_with_type_validators( + "Foo", + [TypeValidatorDefinition(ComplexObject, lookup_complex)], + func=bar, + config=DefaultConfig, + ) + parsed = parse_obj_as(model, {"obj": "f"}) + assert parsed.obj == ComplexObject("f") # type: ignore + + +def test_model_from_nested_function_signature() -> None: + model = create_model_with_type_validators( + "Foo", + [TypeValidatorDefinition(ComplexObject, lookup_complex)], + func=baz, + config=DefaultConfig, + ) + parsed = parse_obj_as(model, {"bar": {"a": 4, "b": "k"}}) + assert parsed.bar.a == 4 # type: ignore + assert parsed.bar.b == ComplexObject("k") # type: ignore + + +def parse_spec(spec: Spec) -> Any: + model = create_model_with_type_validators( + "Foo", + [TypeValidatorDefinition(ComplexObject, lookup_complex)], + fields={"spec": (Spec, Undefined)}, + ) + return parse_obj_as(model, {"spec": spec.serialize()}) + + +def assert_validates_single_type( + field_type: Type, input_value: Any, expected_output: Any +) -> None: + model = create_model_with_type_validators( + "Foo", + [TypeValidatorDefinition(int, lookup)], + fields={"ch": (field_type, Undefined)}, + ) + parsed = parse_obj_as(model, {"ch": input_value}) + assert parsed.ch == expected_output # type: ignore + + +def assert_validates_complex_object( + field_type: Type, + input_value: Any, + expected_output: Any, + default_value: Any = Undefined, +) -> None: + model = create_model_with_type_validators( + "Foo", + [TypeValidatorDefinition(ComplexObject, lookup_complex)], + fields={"obj": (field_type, default_value)}, + config=DefaultConfig, + ) + parsed = parse_obj_as(model, {"obj": input_value}) + assert parsed.obj == expected_output # type: ignore
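
The core migration pattern in this patch swaps apischema's deserialize/serialize for pydantic's parse_obj_as and BaseModel.dict, with the new blueapi.utils.serialize handling models and primitives uniformly. A minimal sketch of the round trip, assuming this patch is applied (the Foo model here is illustrative and unrelated to the test class of the same name):

from typing import List

from pydantic import BaseModel, parse_obj_as

from blueapi.utils import serialize


class Foo(BaseModel):
    a: int
    b: str


# parse_obj_as replaces apischema.deserialize and also works on nested
# generic types such as List[Foo], as stomptemplate.py relies on
foo = parse_obj_as(Foo, {"a": 1, "b": "test"})
assert parse_obj_as(List[Foo], [{"a": 2, "b": "x"}]) == [Foo(a=2, b="x")]

# serialize replaces apischema.serialize: models become dicts via .dict(),
# primitives pass through unchanged
assert serialize(foo) == {"a": 1, "b": "test"}
assert serialize(4) == 4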
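The type-validator utilities are what let plan parameters refer to devices by name: create_model_with_type_validators builds a pydantic model from a plan function's signature, and each TypeValidatorDefinition routes raw string values for fields of a registered type through a lookup function, mirroring device_validators in context.py and _lookup_params in task.py. A minimal sketch under the same assumption that this patch is applied; Motor, REGISTRY and move are hypothetical stand-ins for real devices and plans:

from pydantic import BaseConfig, parse_obj_as

from blueapi.utils import TypeValidatorDefinition, create_model_with_type_validators


class ModelConfig(BaseConfig):
    # Devices are not pydantic types, so the generated model must allow them
    arbitrary_types_allowed = True


class Motor:  # hypothetical device
    ...


REGISTRY = {"x": Motor()}


def lookup(name: str) -> Motor:
    # Resolve a device name to the device itself, raising KeyError if unknown
    return REGISTRY[name]


def move(motor: Motor, position: float) -> None:
    ...  # hypothetical plan


# Any field annotated as Motor (including inside containers) is validated
# by passing the raw value, a string name, through lookup
model = create_model_with_type_validators(
    "move_params",
    [TypeValidatorDefinition(Motor, lookup)],
    func=move,
    config=ModelConfig,
)
params = parse_obj_as(model, {"motor": "x", "position": 1.0})
assert params.motor is REGISTRY["x"]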