Skip to content

Commit

Permalink
Added a from_cpp() method for the Eigen::Ref type_caster (#334)
Browse files Browse the repository at this point in the history
  • Loading branch information
ThoreWietzke authored Nov 1, 2023
1 parent 45d9415 commit 7ae190c
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 0 deletions.
26 changes: 26 additions & 0 deletions include/nanobind/eigen/dense.h
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,12 @@ struct type_caster<Eigen::Ref<T, Options, StrideType>,
// Restrict to contiguous 'T' (limitation in Eigen, see PR #215)
can_map_contiguous_memory_v<Ref>;

using NDArray =
array_for_eigen_t<Ref, std::conditional_t<std::is_const_v<T>,
const typename Ref::Scalar,
typename Ref::Scalar>>;
using NDArrayCaster = type_caster<NDArray>;

/// Eigen::Map<T> caster with fixed strides
using Map = Eigen::Map<T, Options, StrideType>;
using MapCaster = make_caster<Map>;
Expand Down Expand Up @@ -429,6 +435,26 @@ struct type_caster<Eigen::Ref<T, Options, StrideType>,
return false;
}

/*Duplicate of Eigen::Map caster*/
static handle from_cpp(const Ref &v, rv_policy, cleanup_list *cleanup) noexcept {
size_t shape[ndim_v<T>];
int64_t strides[ndim_v<T>];

if constexpr (ndim_v<T> == 1) {
shape[0] = v.size();
strides[0] = v.innerStride();
} else {
shape[0] = v.rows();
shape[1] = v.cols();
strides[0] = v.rowStride();
strides[1] = v.colStride();
}

return NDArrayCaster::from_cpp(
NDArray((void *) v.data(), ndim_v<T>, shape, handle(), strides),
rv_policy::reference, cleanup);
}

operator Ref() {
if constexpr (MaybeConvert) {
if (dcaster.caster.value.is_valid())
Expand Down
30 changes: 30 additions & 0 deletions tests/test_eigen.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <nanobind/eigen/dense.h>
#include <nanobind/eigen/sparse.h>
#include <nanobind/trampoline.h>

namespace nb = nanobind;

Expand Down Expand Up @@ -218,4 +219,33 @@ NB_MODULE(test_eigen_ext, m) {
return nb::cast<Eigen::Ref<const Eigen::VectorXi, Eigen::Unaligned, Eigen::InnerStride<3>>>(obj);
});

class Base {
public:
~Base() {};
virtual void modRefData(Eigen::Ref<Eigen::VectorXd> a) {};
virtual void modRefDataConst(Eigen::Ref<const Eigen::VectorXd> a) {};
};

class PyBase : public Base {
NB_TRAMPOLINE(Base, 2);
public:
void modRefData(Eigen::Ref<Eigen::VectorXd> a) override {NB_OVERRIDE_PURE(modRefData, a);}
void modRefDataConst(Eigen::Ref<const Eigen::VectorXd> a) override {NB_OVERRIDE_PURE(modRefDataConst, a);}
};

nb::class_<Base, PyBase>(m, "Base")
.def(nb::init<>())
.def("modRefData", &Base::modRefData)
.def("modRefDataConst", &Base::modRefDataConst);

m.def("modifyRef", [](Base* base) {
Eigen::Vector2d input {{1.0}, {2.0}};
base->modRefData(input);
return input;
});
m.def("modifyRefConst", [](Base* base) {
Eigen::Vector2d input {{1.0}, {2.0}};
base->modRefDataConst(input);
return input;
});
}
15 changes: 15 additions & 0 deletions tests/test_eigen.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,3 +352,18 @@ def test12_cast():
for v in vec, vec2, vecf:
with pytest.raises(RuntimeError, match='bad[_ ]cast'):
t.castToRef03CnstVXi(v)

@needs_numpy_and_eigen
def test13_mutate_python():
class Derived(t.Base):
def modRefData(self, input):
input[0] = 3.0

def modRefDataConst(self, input):
input[0] = 3.0

vecRef = np.array([3.0, 2.0])
der = Derived()
assert_array_equal(t.modifyRef(der), vecRef)
with pytest.raises(ValueError):
t.modifyRefConst(der)

0 comments on commit 7ae190c

Please sign in to comment.