diff --git a/packages/exchange/src/exchange/utils.py b/packages/exchange/src/exchange/utils.py index b95f1c48..bcdd18f0 100644 --- a/packages/exchange/src/exchange/utils.py +++ b/packages/exchange/src/exchange/utils.py @@ -1,7 +1,7 @@ import inspect import uuid from importlib.metadata import entry_points -from typing import get_args, get_origin +from typing import Literal, get_args, get_origin, Any, Union from griffe import ( Docstring, @@ -107,16 +107,27 @@ def json_schema(func: any) -> dict[str, any]: # noqa: ANN401 return schema -def _map_type_to_schema(py_type: type) -> dict[str, any]: # noqa: ANN401 +def _map_type_to_schema(py_type: type) -> dict[str, Any]: origin = get_origin(py_type) args = get_args(py_type) - if origin is list or origin is tuple: - return {"type": "array", "items": _map_type_to_schema(args[0] if args else any)} + if origin is Union: + # Handle Optional[X], which is Union[X, NoneType] + non_none_args = [arg for arg in args if arg is not type(None)] + if len(non_none_args) == 1: + # This is Optional[X] + return _map_type_to_schema(non_none_args[0]) + else: + # General Union + return {"anyOf": [_map_type_to_schema(arg) for arg in non_none_args]} + elif origin is Literal: + return {"enum": list(args)} + elif origin in (list, tuple): + return {"type": "array", "items": _map_type_to_schema(args[0] if args else Any)} elif origin is dict: return { "type": "object", - "additionalProperties": _map_type_to_schema(args[1] if len(args) > 1 else any), + "additionalProperties": _map_type_to_schema(args[1] if len(args) > 1 else Any), } elif py_type is int: return {"type": "integer"} diff --git a/packages/exchange/tests/test_utils.py b/packages/exchange/tests/test_utils.py index 6bc00f9e..09f72970 100644 --- a/packages/exchange/tests/test_utils.py +++ b/packages/exchange/tests/test_utils.py @@ -1,3 +1,4 @@ +from typing import Literal import pytest from exchange import utils from unittest.mock import patch @@ -65,6 +66,34 @@ def dummy_func(a, b, c): assert "Attempted to load from a function" in str(e.value) +def test_parse_docstring_with_optional_params() -> None: + from typing import Optional, List + + def dummy_func(a: int, b: List[int], c: Literal["foo", "bar"] = "foo", d: Optional[str] = None) -> None: + """This function does something. + + Args: + a (int): The first required parameter. + b (List[int]): The second parameter. + c (Literal["foo", "bar"], optional): A parameter with a literal default value. Defaults to "foo". + d (Optional[str], optional): Optional fourth parameter. Defaults to None. + """ + pass + + description, parameters = utils.parse_docstring(dummy_func) + assert description == "This function does something." + assert parameters == [ + {"name": "a", "annotation": "int", "description": "The first required parameter."}, + {"name": "b", "annotation": "List[int]", "description": "The second parameter."}, + { + "name": "c", + "annotation": 'Literal["foo", "bar"]', + "description": 'A parameter with a literal default value. Defaults to "foo".', + }, + {"name": "d", "annotation": "Optional[str]", "description": "Optional fourth parameter. Defaults to None."}, + ] + + def test_json_schema() -> None: def dummy_func(a: int, b: str, c: list) -> None: pass @@ -82,6 +111,31 @@ def dummy_func(a: int, b: str, c: list) -> None: } +def test_json_schema_with_optional_params() -> None: + from typing import Optional, List + + def dummy_func( + a: int, + b: Literal["foo", "bar"] = "foo", + c: Optional[List[int]] = None, + d: Optional[str] = None, + ) -> None: + pass + + schema = utils.json_schema(dummy_func) + + assert schema == { + "type": "object", + "properties": { + "a": {"type": "integer"}, + "b": {"enum": ["foo", "bar"], "default": "foo"}, + "c": {"type": "array", "items": {"type": "integer"}, "default": None}, + "d": {"type": "string", "default": None}, + }, + "required": ["a"], + } + + def test_load_plugins() -> None: class DummyEntryPoint: def __init__(self, name, plugin):