From 3cd7945046032c0364bfd3b66a84a4c33be4d7e0 Mon Sep 17 00:00:00 2001 From: DiamondJoseph <53935796+DiamondJoseph@users.noreply.github.com> Date: Thu, 29 Aug 2024 13:27:49 +0100 Subject: [PATCH] Handle parameterised generics (#598) Closes #597 --------- Co-authored-by: Callum Forrester --- src/blueapi/service/runner.py | 5 ++- tests/service/test_runner.py | 82 ++++++++++++++++++++++++++++++++++- 2 files changed, 84 insertions(+), 3 deletions(-) diff --git a/src/blueapi/service/runner.py b/src/blueapi/service/runner.py index feca3b908..e122bf7d7 100644 --- a/src/blueapi/service/runner.py +++ b/src/blueapi/service/runner.py @@ -149,9 +149,10 @@ def _rpc( def _valid_return(value: Any, expected_type: type[T] | None = None) -> T: - if expected_type is None or expected_type is Any: + if expected_type is None: return value - return TypeAdapter(expected_type).validate_python(value) + else: + return TypeAdapter(expected_type).validate_python(value) def _validate_function(func: Any, function_name: str) -> Callable: diff --git a/tests/service/test_runner.py b/tests/service/test_runner.py index e3edde0c4..b0c66b745 100644 --- a/tests/service/test_runner.py +++ b/tests/service/test_runner.py @@ -1,8 +1,10 @@ +from typing import Any, Generic, TypeVar from unittest import mock from unittest.mock import MagicMock, patch import pytest -from pydantic import ValidationError +from ophyd import Callable +from pydantic import BaseModel, ValidationError from blueapi.service import interface from blueapi.service.model import EnvironmentResponse @@ -161,3 +163,81 @@ def test_clear_message_for_wrong_return(started_runner: WorkerDispatcher): match="1 validation error for int", ): started_runner.run(wrong_return_type) + + +T = TypeVar("T") + + +class SimpleModel(BaseModel): + a: int + b: str + + +class NestedModel(BaseModel): + nested: SimpleModel + c: bool + + +class GenericModel(BaseModel, Generic[T]): + a: T + b: str + + +def return_int() -> int: + return 1 + + +def return_str() -> str: + return "hello" + + +def return_list() -> list[int]: + return [1, 2, 3] + + +def return_dict() -> dict[str, int]: + return { + "test": 1, + "other_test": 2, + } + + +def return_simple_model() -> SimpleModel: + return SimpleModel(a=1, b="hi") + + +def return_nested_model() -> NestedModel: + return NestedModel(nested=return_simple_model(), c=False) + + +def return_unbound_generic_model() -> GenericModel: + return GenericModel(a="foo", b="bar") + + +def return_bound_generic_model() -> GenericModel[int]: + return GenericModel(a=1, b="hi") + + +def return_explicitly_bound_generic_model() -> GenericModel[int]: + return GenericModel[int](a=1, b="hi") + + +@pytest.mark.parametrize( + "rpc_function", + [ + return_int, + return_str, + return_list, + return_dict, + return_simple_model, + return_nested_model, + return_unbound_generic_model, + # https://github.com/pydantic/pydantic/issues/6870 return_bound_generic_model, + return_explicitly_bound_generic_model, + ], +) +def test_accepts_return_type( + started_runner: WorkerDispatcher, + rpc_function: Callable[[], Any], +): + started_runner.run(rpc_function)