From 1400d6e4a5b1781fd0bd984d6d2e6876369ff425 Mon Sep 17 00:00:00 2001 From: Thore Wietzke Date: Fri, 20 Oct 2023 21:08:45 +0200 Subject: [PATCH 1/3] added from_cpp() method for the Eigen::Ref type_caster --- include/nanobind/eigen/dense.h | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/include/nanobind/eigen/dense.h b/include/nanobind/eigen/dense.h index 571b1900..acdf895c 100644 --- a/include/nanobind/eigen/dense.h +++ b/include/nanobind/eigen/dense.h @@ -372,6 +372,12 @@ struct type_caster, // Restrict to contiguous 'T' (limitation in Eigen, see PR #215) can_map_contiguous_memory_v; + using NDArray = + array_for_eigen_t, + const typename Ref::Scalar, + typename Ref::Scalar>>; + using NDArrayCaster = type_caster; + /// Eigen::Map caster with fixed strides using Map = Eigen::Map; using MapCaster = make_caster; @@ -429,6 +435,25 @@ struct type_caster, return false; } + static handle from_cpp(const Ref &v, rv_policy, cleanup_list *cleanup) noexcept { + size_t shape[ndim_v]; + int64_t strides[ndim_v]; + + if constexpr (ndim_v == 1) { + shape[0] = v.size(); + strides[0] = v.innerStride(); + } else { + shape[0] = v.rows(); + shape[1] = v.cols(); + strides[0] = v.rowStride(); + strides[1] = v.colStride(); + } + + return NDArrayCaster::from_cpp( + NDArray((void *) v.data(), ndim_v, shape, handle(), strides), + rv_policy::reference, cleanup); + } + operator Ref() { if constexpr (MaybeConvert) { if (dcaster.caster.value.is_valid()) From 1ad23f87f7e09d847c9a1441cc79ac530587b359 Mon Sep 17 00:00:00 2001 From: Thore Wietzke Date: Wed, 25 Oct 2023 12:28:36 +0200 Subject: [PATCH 2/3] added tests for modifying data from C++ with Python via Eigen::Ref --- tests/test_eigen.cpp | 38 ++++++++++++++++++++++++++++++++++---- tests/test_eigen.py | 15 +++++++++++++++ 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/tests/test_eigen.cpp b/tests/test_eigen.cpp index 4b6f1a99..e7c724c4 100644 --- a/tests/test_eigen.cpp +++ b/tests/test_eigen.cpp @@ -1,5 +1,6 @@ #include #include +#include namespace nb = nanobind; @@ -205,17 +206,46 @@ NB_MODULE(test_eigen_ext, m) { m.def("castToMapCnstVXi", [](nb::object obj) { return nb::cast>(obj); }); - m.def("castToRefVXi", [](nb::object obj) -> Eigen::VectorXi { + m.def("castToRefVXi", [](nb::object obj) { return nb::cast>(obj); }); - m.def("castToRefCnstVXi", [](nb::object obj) -> Eigen::VectorXi { + m.def("castToRefCnstVXi", [](nb::object obj) { return nb::cast>(obj); }); - m.def("castToDRefCnstVXi", [](nb::object obj) -> Eigen::VectorXi { + m.def("castToDRefCnstVXi", [](nb::object obj) { return nb::cast>(obj); }); - m.def("castToRef03CnstVXi", [](nb::object obj) -> Eigen::VectorXi { + m.def("castToRef03CnstVXi", [](nb::object obj) { return nb::cast>>(obj); }); + class Base { + public: + ~Base() {}; + virtual void modRefData(Eigen::Ref a) {}; + virtual void modRefDataConst(Eigen::Ref a) {}; + }; + + class PyBase : public Base { + NB_TRAMPOLINE(Base, 2); + public: + void modRefData(Eigen::Ref a) override {NB_OVERRIDE_PURE(modRefData, a);} + void modRefDataConst(Eigen::Ref a) override {NB_OVERRIDE_PURE(modRefDataConst, a);} + }; + + nb::class_(m, "Base") + .def(nb::init<>()) + .def("modRefData", &Base::modRefData) + .def("modRefDataConst", &Base::modRefDataConst); + + m.def("modifyRef", [](Base* base) { + Eigen::Vector2d input {{1.0}, {2.0}}; + base->modRefData(input); + return input; + }); + m.def("modifyRefConst", [](Base* base) { + Eigen::Vector2d input {{1.0}, {2.0}}; + base->modRefDataConst(input); + return input; + }); } diff --git a/tests/test_eigen.py b/tests/test_eigen.py index 138fdf83..f8b125f7 100644 --- a/tests/test_eigen.py +++ b/tests/test_eigen.py @@ -352,3 +352,18 @@ def test12_cast(): for v in vec, vec2, vecf: with pytest.raises(RuntimeError, match='bad[_ ]cast'): t.castToRef03CnstVXi(v) + +@needs_numpy_and_eigen +def test13_mutate_python(): + class Derived(t.Base): + def modRefData(self, input): + input[0] = 3.0 + + def modRefDataConst(self, input): + input[0] = 3.0 + + vecRef = np.array([3.0, 2.0]) + der = Derived() + assert_array_equal(t.modifyRef(der), vecRef) + with pytest.raises(ValueError): + t.modifyRefConst(der) \ No newline at end of file From bd19de13beeca40d37e517e6e5d9073e619ff337 Mon Sep 17 00:00:00 2001 From: Thore Wietzke Date: Fri, 27 Oct 2023 08:02:40 +0200 Subject: [PATCH 3/3] reverted cast tests for Eigen --- include/nanobind/eigen/dense.h | 1 + tests/test_eigen.cpp | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/include/nanobind/eigen/dense.h b/include/nanobind/eigen/dense.h index acdf895c..c4c93fe6 100644 --- a/include/nanobind/eigen/dense.h +++ b/include/nanobind/eigen/dense.h @@ -435,6 +435,7 @@ struct type_caster, return false; } + /*Duplicate of Eigen::Map caster*/ static handle from_cpp(const Ref &v, rv_policy, cleanup_list *cleanup) noexcept { size_t shape[ndim_v]; int64_t strides[ndim_v]; diff --git a/tests/test_eigen.cpp b/tests/test_eigen.cpp index e7c724c4..6666c4a0 100644 --- a/tests/test_eigen.cpp +++ b/tests/test_eigen.cpp @@ -206,16 +206,16 @@ NB_MODULE(test_eigen_ext, m) { m.def("castToMapCnstVXi", [](nb::object obj) { return nb::cast>(obj); }); - m.def("castToRefVXi", [](nb::object obj) { + m.def("castToRefVXi", [](nb::object obj) -> Eigen::VectorXi { return nb::cast>(obj); }); - m.def("castToRefCnstVXi", [](nb::object obj) { + m.def("castToRefCnstVXi", [](nb::object obj) -> Eigen::VectorXi { return nb::cast>(obj); }); - m.def("castToDRefCnstVXi", [](nb::object obj) { + m.def("castToDRefCnstVXi", [](nb::object obj) -> Eigen::VectorXi { return nb::cast>(obj); }); - m.def("castToRef03CnstVXi", [](nb::object obj) { + m.def("castToRef03CnstVXi", [](nb::object obj) -> Eigen::VectorXi { return nb::cast>>(obj); });