Skip to content

Commit

Permalink
Handle parameterised generics (#598)
Browse files Browse the repository at this point in the history
Closes #597

---------

Co-authored-by: Callum Forrester <callum.forrester@diamond.ac.uk>
  • Loading branch information
2 people authored and ZohebShaikh committed Aug 29, 2024
1 parent 730191c commit 3cd7945
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 3 deletions.
5 changes: 3 additions & 2 deletions src/blueapi/service/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,10 @@ def _rpc(


def _valid_return(value: Any, expected_type: type[T] | None = None) -> T:
if expected_type is None or expected_type is Any:
if expected_type is None:
return value
return TypeAdapter(expected_type).validate_python(value)
else:
return TypeAdapter(expected_type).validate_python(value)


def _validate_function(func: Any, function_name: str) -> Callable:
Expand Down
82 changes: 81 additions & 1 deletion tests/service/test_runner.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Any, Generic, TypeVar
from unittest import mock
from unittest.mock import MagicMock, patch

import pytest
from pydantic import ValidationError
from ophyd import Callable
from pydantic import BaseModel, ValidationError

from blueapi.service import interface
from blueapi.service.model import EnvironmentResponse
Expand Down Expand Up @@ -161,3 +163,81 @@ def test_clear_message_for_wrong_return(started_runner: WorkerDispatcher):
match="1 validation error for int",
):
started_runner.run(wrong_return_type)


T = TypeVar("T")


class SimpleModel(BaseModel):
a: int
b: str


class NestedModel(BaseModel):
nested: SimpleModel
c: bool


class GenericModel(BaseModel, Generic[T]):
a: T
b: str


def return_int() -> int:
return 1


def return_str() -> str:
return "hello"


def return_list() -> list[int]:
return [1, 2, 3]


def return_dict() -> dict[str, int]:
return {
"test": 1,
"other_test": 2,
}


def return_simple_model() -> SimpleModel:
return SimpleModel(a=1, b="hi")


def return_nested_model() -> NestedModel:
return NestedModel(nested=return_simple_model(), c=False)


def return_unbound_generic_model() -> GenericModel:
return GenericModel(a="foo", b="bar")


def return_bound_generic_model() -> GenericModel[int]:
return GenericModel(a=1, b="hi")


def return_explicitly_bound_generic_model() -> GenericModel[int]:
return GenericModel[int](a=1, b="hi")


@pytest.mark.parametrize(
"rpc_function",
[
return_int,
return_str,
return_list,
return_dict,
return_simple_model,
return_nested_model,
return_unbound_generic_model,
# https://github.com/pydantic/pydantic/issues/6870 return_bound_generic_model,
return_explicitly_bound_generic_model,
],
)
def test_accepts_return_type(
started_runner: WorkerDispatcher,
rpc_function: Callable[[], Any],
):
started_runner.run(rpc_function)

0 comments on commit 3cd7945

Please sign in to comment.