diff --git a/include/nanobind/eigen/dense.h b/include/nanobind/eigen/dense.h index 56c867e77..000f5f0a2 100644 --- a/include/nanobind/eigen/dense.h +++ b/include/nanobind/eigen/dense.h @@ -372,6 +372,12 @@ struct type_caster, // Restrict to contiguous 'T' (limitation in Eigen, see PR #215) can_map_contiguous_memory_v; + using NDArray = + array_for_eigen_t, + const typename Ref::Scalar, + typename Ref::Scalar>>; + using NDArrayCaster = type_caster; + /// Eigen::Map caster with fixed strides using Map = Eigen::Map; using MapCaster = make_caster; @@ -429,6 +435,26 @@ struct type_caster, return false; } + /*Duplicate of Eigen::Map caster*/ + static handle from_cpp(const Ref &v, rv_policy, cleanup_list *cleanup) noexcept { + size_t shape[ndim_v]; + int64_t strides[ndim_v]; + + if constexpr (ndim_v == 1) { + shape[0] = v.size(); + strides[0] = v.innerStride(); + } else { + shape[0] = v.rows(); + shape[1] = v.cols(); + strides[0] = v.rowStride(); + strides[1] = v.colStride(); + } + + return NDArrayCaster::from_cpp( + NDArray((void *) v.data(), ndim_v, shape, handle(), strides), + rv_policy::reference, cleanup); + } + operator Ref() { if constexpr (MaybeConvert) { if (dcaster.caster.value.is_valid()) diff --git a/tests/test_eigen.cpp b/tests/test_eigen.cpp index 4b6f1a99e..6666c4a09 100644 --- a/tests/test_eigen.cpp +++ b/tests/test_eigen.cpp @@ -1,5 +1,6 @@ #include #include +#include namespace nb = nanobind; @@ -218,4 +219,33 @@ NB_MODULE(test_eigen_ext, m) { return nb::cast>>(obj); }); + class Base { + public: + ~Base() {}; + virtual void modRefData(Eigen::Ref a) {}; + virtual void modRefDataConst(Eigen::Ref a) {}; + }; + + class PyBase : public Base { + NB_TRAMPOLINE(Base, 2); + public: + void modRefData(Eigen::Ref a) override {NB_OVERRIDE_PURE(modRefData, a);} + void modRefDataConst(Eigen::Ref a) override {NB_OVERRIDE_PURE(modRefDataConst, a);} + }; + + nb::class_(m, "Base") + .def(nb::init<>()) + .def("modRefData", &Base::modRefData) + .def("modRefDataConst", &Base::modRefDataConst); + + m.def("modifyRef", [](Base* base) { + Eigen::Vector2d input {{1.0}, {2.0}}; + base->modRefData(input); + return input; + }); + m.def("modifyRefConst", [](Base* base) { + Eigen::Vector2d input {{1.0}, {2.0}}; + base->modRefDataConst(input); + return input; + }); } diff --git a/tests/test_eigen.py b/tests/test_eigen.py index 138fdf834..f8b125f78 100644 --- a/tests/test_eigen.py +++ b/tests/test_eigen.py @@ -352,3 +352,18 @@ def test12_cast(): for v in vec, vec2, vecf: with pytest.raises(RuntimeError, match='bad[_ ]cast'): t.castToRef03CnstVXi(v) + +@needs_numpy_and_eigen +def test13_mutate_python(): + class Derived(t.Base): + def modRefData(self, input): + input[0] = 3.0 + + def modRefDataConst(self, input): + input[0] = 3.0 + + vecRef = np.array([3.0, 2.0]) + der = Derived() + assert_array_equal(t.modifyRef(der), vecRef) + with pytest.raises(ValueError): + t.modifyRefConst(der) \ No newline at end of file