diff --git a/include/pybind11/typing.h b/include/pybind11/typing.h index bc275fc50b..2c9eaffbcb 100644 --- a/include/pybind11/typing.h +++ b/include/pybind11/typing.h @@ -63,6 +63,16 @@ class Callable : public function { using function::function; }; +template +class Union : public object { + using object::object; +}; + +template +class Optional : public object { + using object::object; +}; + PYBIND11_NAMESPACE_END(typing) PYBIND11_NAMESPACE_BEGIN(detail) @@ -121,5 +131,17 @@ struct handle_type_name> { + const_name("], ") + make_caster::name + const_name("]"); }; +template +struct handle_type_name> { + static constexpr auto name = const_name("Union[") + + ::pybind11::detail::concat(make_caster::name...) + + const_name("]"); +}; + +template +struct handle_type_name> { + static constexpr auto name = const_name("Optional[") + make_caster::name + const_name("]"); +}; + PYBIND11_NAMESPACE_END(detail) PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/tests/test_pytypes.cpp b/tests/test_pytypes.cpp index f3709d40c6..e97dceee3e 100644 --- a/tests/test_pytypes.cpp +++ b/tests/test_pytypes.cpp @@ -844,4 +844,26 @@ TEST_SUBMODULE(pytypes, m) { m.def("annotate_iterator_int", [](const py::typing::Iterator &) {}); m.def("annotate_fn", [](const py::typing::Callable, py::str)> &) {}); + + m.def("annotate_union", + [](py::typing::List> l, + py::str a, + py::int_ b, + py::object c) -> py::typing::List> { + l.append(a); + l.append(b); + l.append(c); + return l; + }); + + m.def("union_typing_only", + [](py::typing::List> &l) + -> py::typing::List> { return l; }); + + m.def("annotate_optional", + [](py::list &list) -> py::typing::List> { + list.append(py::str("hi")); + list.append(py::none()); + return list; + }); } diff --git a/tests/test_pytypes.py b/tests/test_pytypes.py index 38edfd9998..ca2c51c405 100644 --- a/tests/test_pytypes.py +++ b/tests/test_pytypes.py @@ -955,3 +955,24 @@ def test_fn_annotations(doc): doc(m.annotate_fn) == "annotate_fn(arg0: Callable[[list[str], str], int]) -> None" ) + + +def test_union_annotations(doc): + assert ( + doc(m.annotate_union) + == "annotate_union(arg0: list[Union[str, int, object]], arg1: str, arg2: int, arg3: object) -> list[Union[str, int, object]]" + ) + + +def test_union_typing_only(doc): + assert ( + doc(m.union_typing_only) + == "union_typing_only(arg0: list[Union[str]]) -> list[Union[int]]" + ) + + +def test_optional_annotations(doc): + assert ( + doc(m.annotate_optional) + == "annotate_optional(arg0: list) -> list[Optional[str]]" + )