From 429a229f0f0d9a7697da0e1a8555bdf756e60ddb Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Tue, 8 Oct 2024 22:07:03 +0200 Subject: [PATCH] Add support for Kokkos::View element access in the rvs mode --- include/clad/Differentiator/KokkosBuiltins.h | 216 +++++++++++++++++++ unittests/Kokkos/ViewAccess.cpp | 9 +- 2 files changed, 220 insertions(+), 5 deletions(-) diff --git a/include/clad/Differentiator/KokkosBuiltins.h b/include/clad/Differentiator/KokkosBuiltins.h index 1a6253027..d17d33a69 100644 --- a/include/clad/Differentiator/KokkosBuiltins.h +++ b/include/clad/Differentiator/KokkosBuiltins.h @@ -30,6 +30,37 @@ constructor_pushforward( Kokkos::View( "_diff_" + name, idx0, idx1, idx2, idx3, idx4, idx5, idx6, idx7)}; } +template +clad::ValueAndAdjoint<::Kokkos::View, + ::Kokkos::View> +constructor_reverse_forw( + clad::ConstructorReverseForwTag<::Kokkos::View>, + const ::std::string& name, const size_t& idx0, const size_t& idx1, + const size_t& idx2, const size_t& idx3, const size_t& idx4, + const size_t& idx5, const size_t& idx6, const size_t& idx7, + const ::std::string& /*d_name*/, const size_t& /*d_idx0*/, + const size_t& /*d_idx1*/, const size_t& /*d_idx2*/, + const size_t& /*d_idx3*/, const size_t& /*d_idx4*/, + const size_t& /*d_idx5*/, const size_t& /*d_idx6*/, + const size_t& /*d_idx7*/) { + return {::Kokkos::View(name, idx0, idx1, idx2, idx3, + idx4, idx5, idx6, idx7), + ::Kokkos::View( + "_diff_" + name, idx0, idx1, idx2, idx3, idx4, idx5, idx6, idx7)}; +} +template +void constructor_pullback(::Kokkos::View* v, + const ::std::string& name, const size_t& idx0, + const size_t& idx1, const size_t& idx2, + const size_t& idx3, const size_t& idx4, + const size_t& idx5, const size_t& idx6, + const size_t& idx7, + ::Kokkos::View* d_v, + const ::std::string* /*d_name*/, + const size_t& /*d_idx0*/, const size_t* /*d_idx1*/, + const size_t* /*d_idx2*/, const size_t* /*d_idx3*/, + const size_t* /*d_idx4*/, const size_t* /*d_idx5*/, + const size_t* /*d_idx6*/, const size_t* /*d_idx7*/) {} /// View indexing template @@ -107,6 +138,191 @@ operator_call_pushforward(const View* v, Idx0 i0, Idx1 i1, Idx2 i2, Idx3 i3, return {(*v)(i0, i1, i2, i3, i4, i5, i6, i7), (*d_v)(i0, i1, i2, i3, i4, i5, i6, i7)}; } +template +clad::ValueAndAdjoint< + typename ::Kokkos::View::reference_type&, + typename ::Kokkos::View::reference_type&> +operator_call_reverse_forw(const ::Kokkos::View* v, + Idx i0, + const ::Kokkos::View* d_v, + Idx /*d_i0*/) { + return {(*v)(i0), (*d_v)(i0)}; +} +template +void operator_call_pullback(const ::Kokkos::View* v, + Idx i0, Diff d_y, + ::Kokkos::View* d_v, + dIdx* /*d_i0*/) { + (*d_v)(i0) += d_y; +} +template +clad::ValueAndAdjoint< + typename ::Kokkos::View::reference_type&, + typename ::Kokkos::View::reference_type&> +operator_call_reverse_forw(const ::Kokkos::View* v, + Idx0 i0, Idx1 i1, + const ::Kokkos::View* d_v, + Idx0 /*d_i0*/, Idx1 /*d_i1*/) { + return {(*v)(i0, i1), (*d_v)(i0, i1)}; +} +template +void operator_call_pullback(const ::Kokkos::View* v, + Idx0 i0, Idx1 i1, Diff d_y, + ::Kokkos::View* d_v, + dIdx0* /*d_i0*/, dIdx1* /*d_i1*/) { + (*d_v)(i0, i1) += d_y; +} +template +clad::ValueAndAdjoint< + typename ::Kokkos::View::reference_type&, + typename ::Kokkos::View::reference_type&> +operator_call_reverse_forw(const ::Kokkos::View* v, + Idx0 i0, Idx1 i1, Idx2 i2, + const ::Kokkos::View* d_v, + Idx0 /*d_i0*/, Idx1 /*d_i1*/, Idx2 /*d_i2*/) { + return {(*v)(i0, i1, i2), (*d_v)(i0, i1, i2)}; +} +template +void operator_call_pullback(const ::Kokkos::View* v, + Idx0 i0, Idx1 i1, Idx2 i2, Diff d_y, + ::Kokkos::View* d_v, + dIdx0* /*d_i0*/, dIdx1* /*d_i1*/, dIdx2* /*d_i2*/) { + (*d_v)(i0, i1, i2) += d_y; +} +template +clad::ValueAndAdjoint< + typename ::Kokkos::View::reference_type&, + typename ::Kokkos::View::reference_type&> +operator_call_reverse_forw(const ::Kokkos::View* v, + Idx0 i0, Idx1 i1, Idx2 i2, Idx3 i3, + const ::Kokkos::View* d_v, + Idx0 /*d_i0*/, Idx1 /*d_i1*/, Idx2 /*d_i2*/, + Idx3 /*d_i3*/) { + return {(*v)(i0, i1, i2, i3), (*d_v)(i0, i1, i2, i3)}; +} +template +void operator_call_pullback(const ::Kokkos::View* v, + Idx0 i0, Idx1 i1, Idx2 i2, Idx3 i3, Diff d_y, + ::Kokkos::View* d_v, + dIdx0* /*d_i0*/, dIdx1* /*d_i1*/, dIdx2* /*d_i2*/, + dIdx3* /*d_i3*/) { + (*d_v)(i0, i1, i2, i3) += d_y; +} +template +clad::ValueAndAdjoint< + typename ::Kokkos::View::reference_type&, + typename ::Kokkos::View::reference_type&> +operator_call_reverse_forw(const ::Kokkos::View* v, + Idx0 i0, Idx1 i1, Idx2 i2, Idx3 i3, Idx4 i4, + const ::Kokkos::View* d_v, + Idx0 /*d_i0*/, Idx1 /*d_i1*/, Idx2 /*d_i2*/, + Idx3 /*d_i3*/, Idx4 /*d_i4*/) { + return {(*v)(i0, i1, i2, i3, i4), (*d_v)(i0, i1, i2, i3, i4)}; +} +template +void operator_call_pullback(const ::Kokkos::View* v, + Idx0 i0, Idx1 i1, Idx2 i2, Idx3 i3, Idx4 i4, + Diff d_y, + ::Kokkos::View* d_v, + dIdx0* /*d_i0*/, dIdx1* /*d_i1*/, dIdx2* /*d_i2*/, + dIdx3* /*d_i3*/, dIdx4* /*d_i4*/) { + (*d_v)(i0, i1, i2, i3, i4) += d_y; +} +template +clad::ValueAndAdjoint< + typename ::Kokkos::View::reference_type&, + typename ::Kokkos::View::reference_type&> +operator_call_reverse_forw(const ::Kokkos::View* v, + Idx0 i0, Idx1 i1, Idx2 i2, Idx3 i3, Idx4 i4, Idx5 i5, + const ::Kokkos::View* d_v, + Idx0 /*d_i0*/, Idx1 /*d_i1*/, Idx2 /*d_i2*/, + Idx3 /*d_i3*/, Idx4 /*d_i4*/, Idx5 /*d_i5*/) { + return {(*v)(i0, i1, i2, i3, i4, i5), (*d_v)(i0, i1, i2, i3, i4, i5)}; +} +template +void operator_call_pullback(const ::Kokkos::View* v, + Idx0 i0, Idx1 i1, Idx2 i2, Idx3 i3, Idx4 i4, + Idx5 i5, Diff d_y, + ::Kokkos::View* d_v, + dIdx0* /*d_i0*/, dIdx1* /*d_i1*/, dIdx2* /*d_i2*/, + dIdx3* /*d_i3*/, dIdx4* /*d_i4*/, dIdx5* /*d_i5*/) { + (*d_v)(i0, i1, i2, i3, i4, i5) += d_y; +} +template +clad::ValueAndAdjoint< + typename ::Kokkos::View::reference_type&, + typename ::Kokkos::View::reference_type&> +operator_call_reverse_forw(const ::Kokkos::View* v, + Idx0 i0, Idx1 i1, Idx2 i2, Idx3 i3, Idx4 i4, Idx5 i5, + Idx6 i6, + const ::Kokkos::View* d_v, + Idx0 /*d_i0*/, Idx1 /*d_i1*/, Idx2 /*d_i2*/, + Idx3 /*d_i3*/, Idx4 /*d_i4*/, Idx5 /*d_i5*/, + Idx6 /*d_i6*/) { + return {(*v)(i0, i1, i2, i3, i4, i5, i6), (*d_v)(i0, i1, i2, i3, i4, i5, i6)}; +} +template +void operator_call_pullback(const ::Kokkos::View* v, + Idx0 i0, Idx1 i1, Idx2 i2, Idx3 i3, Idx4 i4, + Idx5 i5, Idx6 i6, Diff d_y, + ::Kokkos::View* d_v, + dIdx0* /*d_i0*/, dIdx1* /*d_i1*/, dIdx2* /*d_i2*/, + dIdx3* /*d_i3*/, dIdx4* /*d_i3*/, dIdx5* /*d_i3*/, + dIdx6* /*d_i3*/) { + (*d_v)(i0, i1, i2, i3, i4, i5, i6) += d_y; +} +template +clad::ValueAndAdjoint< + typename ::Kokkos::View::reference_type&, + typename ::Kokkos::View::reference_type&> +operator_call_reverse_forw(const ::Kokkos::View* v, + Idx0 i0, Idx1 i1, Idx2 i2, Idx3 i3, Idx4 i4, Idx5 i5, + Idx6 i6, Idx7 i7, + const ::Kokkos::View* d_v, + Idx0 /*d_i0*/, Idx1 /*d_i1*/, Idx2 /*d_i2*/, + Idx3 /*d_i3*/, Idx4 /*d_i4*/, Idx5 /*d_i5*/, + Idx6 /*d_i6*/, Idx7 /*d_i7*/) { + return {(*v)(i0, i1, i2, i3, i4, i5, i6, i7), + (*d_v)(i0, i1, i2, i3, i4, i5, i6, i7)}; +} +template +void operator_call_pullback(const ::Kokkos::View* v, + Idx0 i0, Idx1 i1, Idx2 i2, Idx3 i3, Idx4 i4, + Idx5 i5, Idx6 i6, Idx7 i7, Diff d_y, + ::Kokkos::View* d_v, + dIdx0* /*d_i0*/, dIdx1* /*d_i1*/, dIdx2* /*d_i2*/, + dIdx3* /*d_i3*/, dIdx4* /*d_i3*/, dIdx5* /*d_i3*/, + dIdx6* /*d_i3*/, dIdx7* /*d_i3*/) { + (*d_v)(i0, i1, i2, i3, i4, i5, i6, i7) += d_y; +} } // namespace class_functions /// Kokkos functions (view utils) diff --git a/unittests/Kokkos/ViewAccess.cpp b/unittests/Kokkos/ViewAccess.cpp index e77b278f0..12cc355d1 100644 --- a/unittests/Kokkos/ViewAccess.cpp +++ b/unittests/Kokkos/ViewAccess.cpp @@ -60,11 +60,10 @@ TEST(ViewAccess, Test2) { double dx_f_2_FD = finite_difference_tangent(f_2_tmp, 3., epsilon); EXPECT_NEAR(f_2_x.execute(3, 4), dx_f_2_FD, tolerance * dx_f_2_FD); - // TODO: uncomment this once it has been implemented - // auto f_grad_exe = clad::gradient(f); - // double dx, dy; - // f_grad_exe.execute(3., 4., &dx, &dy); - // EXPECT_NEAR(f_x.execute(3, 4),dx,tolerance*dx); + auto f_grad_exe = clad::gradient(f); + double dx, dy; + f_grad_exe.execute(3., 4., &dx, &dy); + EXPECT_NEAR(f_x.execute(3, 4), dx, tolerance * dx); // double dx_2, dy_2; // auto f_2_grad_exe = clad::gradient(f_2);