From a54fb016fbd0df9c97466b5f284ac57bf11c8d47 Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Mon, 15 Jul 2024 17:39:57 -0400 Subject: [PATCH] update type traits for fft, square_dist, and trace funcs --- stan/math/rev/fun/fft.hpp | 24 +++++----- stan/math/rev/fun/squared_distance.hpp | 19 +++----- .../rev/fun/trace_gen_inv_quad_form_ldlt.hpp | 6 +-- stan/math/rev/fun/trace_quad_form.hpp | 48 +++++++------------ 4 files changed, 39 insertions(+), 58 deletions(-) diff --git a/stan/math/rev/fun/fft.hpp b/stan/math/rev/fun/fft.hpp index 797950b3056..34f2b2a13e8 100644 --- a/stan/math/rev/fun/fft.hpp +++ b/stan/math/rev/fun/fft.hpp @@ -39,13 +39,13 @@ namespace math { */ template * = nullptr, require_var_t>>* = nullptr> -inline plain_type_t fft(const V& x) { +inline auto fft(const V& x) { if (unlikely(x.size() <= 1)) { - return plain_type_t(x); + return arena_t>(x); } arena_t arena_v = x; - arena_t res = fft(to_complex(arena_v.real().val(), arena_v.imag().val())); + arena_t> res = fft(to_complex(arena_v.real().val(), arena_v.imag().val())); reverse_pass_callback([arena_v, res]() mutable { auto adj_inv_fft = inv_fft(to_complex(res.real().adj(), res.imag().adj())); @@ -54,7 +54,7 @@ inline plain_type_t fft(const V& x) { arena_v.imag().adj() += adj_inv_fft.imag(); }); - return plain_type_t(res); + return res; } /** @@ -84,13 +84,13 @@ inline plain_type_t fft(const V& x) { */ template * = nullptr, require_var_t>>* = nullptr> -inline plain_type_t inv_fft(const V& y) { +inline auto inv_fft(const V& y) { if (unlikely(y.size() <= 1)) { - return plain_type_t(y); + return arena_t>(y); } arena_t arena_v = y; - arena_t res + arena_t> res = inv_fft(to_complex(arena_v.real().val(), arena_v.imag().val())); reverse_pass_callback([arena_v, res]() mutable { @@ -100,7 +100,7 @@ inline plain_type_t inv_fft(const V& y) { arena_v.real().adj() += adj_fft.real(); arena_v.imag().adj() += adj_fft.imag(); }); - return plain_type_t(res); + return res; } /** @@ -120,7 +120,7 @@ inline plain_type_t inv_fft(const V& y) { */ template * = nullptr, require_var_t>>* = nullptr> -inline plain_type_t fft2(const M& x) { +inline auto fft2(const M& x) { arena_t arena_v = x; arena_t res = fft2(to_complex(arena_v.real().val(), arena_v.imag().val())); @@ -131,7 +131,7 @@ inline plain_type_t fft2(const M& x) { arena_v.imag().adj() += adj_inv_fft.imag(); }); - return plain_type_t(res); + return res; } /** @@ -152,7 +152,7 @@ inline plain_type_t fft2(const M& x) { */ template * = nullptr, require_var_t>>* = nullptr> -inline plain_type_t inv_fft2(const M& y) { +inline auto inv_fft2(const M& y) { arena_t arena_v = y; arena_t res = inv_fft2(to_complex(arena_v.real().val(), arena_v.imag().val())); @@ -164,7 +164,7 @@ inline plain_type_t inv_fft2(const M& y) { arena_v.real().adj() += adj_fft.real(); arena_v.imag().adj() += adj_fft.imag(); }); - return plain_type_t(res); + return res; } } // namespace math diff --git a/stan/math/rev/fun/squared_distance.hpp b/stan/math/rev/fun/squared_distance.hpp index bbce21bb3f5..13b5c0f0bb9 100644 --- a/stan/math/rev/fun/squared_distance.hpp +++ b/stan/math/rev/fun/squared_distance.hpp @@ -158,11 +158,12 @@ inline var squared_distance(const T1& A, const T2& B) { check_matching_sizes("squared_distance", "A", A.val(), "B", B.val()); if (unlikely(A.size() == 0)) { return var(0.0); - } else if constexpr (is_autodiffable_v) { - arena_t arena_A = A; - arena_t arena_B = B; - arena_t res_diff(arena_A.size()); - double res_val = 0.0; + } + arena_t arena_A = A; + arena_t arena_B = B; + arena_t res_diff(arena_A.size()); + double res_val = 0.0; + if constexpr (is_autodiffable_v) { for (size_t i = 0; i < arena_A.size(); ++i) { const double diff = arena_A.val().coeff(i) - arena_B.val().coeff(i); res_diff.coeffRef(i) = diff; @@ -178,10 +179,6 @@ inline var squared_distance(const T1& A, const T2& B) { } })); } else if constexpr (is_autodiffable_v) { - arena_t arena_A = A; - arena_t arena_B = value_of(B); - arena_t res_diff(arena_A.size()); - double res_val = 0.0; for (size_t i = 0; i < arena_A.size(); ++i) { const double diff = arena_A.val().coeff(i) - arena_B.coeff(i); res_diff.coeffRef(i) = diff; @@ -192,10 +189,6 @@ inline var squared_distance(const T1& A, const T2& B) { arena_A.adj() += 2.0 * res.adj() * res_diff; })); } else { - arena_t arena_A = value_of(A); - arena_t arena_B = B; - arena_t res_diff(arena_A.size()); - double res_val = 0.0; for (size_t i = 0; i < arena_A.size(); ++i) { const double diff = arena_A.coeff(i) - arena_B.val().coeff(i); res_diff.coeffRef(i) = diff; diff --git a/stan/math/rev/fun/trace_gen_inv_quad_form_ldlt.hpp b/stan/math/rev/fun/trace_gen_inv_quad_form_ldlt.hpp index c15969bc8b6..12184718347 100644 --- a/stan/math/rev/fun/trace_gen_inv_quad_form_ldlt.hpp +++ b/stan/math/rev/fun/trace_gen_inv_quad_form_ldlt.hpp @@ -40,10 +40,10 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, LDLT_factor& A, return 0; } + arena_t arena_A = A.matrix(); + arena_t arena_B = B; + arena_t arena_D = D; if constexpr (is_autodiffable_v) { - arena_t arena_A = A.matrix(); - arena_t arena_B = B; - arena_t arena_D = D; auto AsolveB = to_arena(A.ldlt().solve(arena_B.val())); auto BTAsolveB = to_arena(arena_B.val_op().transpose() * AsolveB); diff --git a/stan/math/rev/fun/trace_quad_form.hpp b/stan/math/rev/fun/trace_quad_form.hpp index 8c4673821f4..f61cb743923 100644 --- a/stan/math/rev/fun/trace_quad_form.hpp +++ b/stan/math/rev/fun/trace_quad_form.hpp @@ -119,58 +119,47 @@ inline var trace_quad_form(Mat1&& A, Mat2&& B) { check_square("trace_quad_form", "A", A); check_multiplicable("trace_quad_form", "A", A, "B", B); - var res; - + arena_t arena_A = std::forward(A); + arena_t arena_B = std::forward(B); if constexpr (is_autodiffable_v) { - arena_t arena_A = std::forward(A); - arena_t arena_B = std::forward(B); - - res = (value_of(arena_B).transpose() * value_of(arena_A) - * value_of(arena_B)) + var res = (arena_B.val_op().transpose() * arena_A.val_op() + * arena_B.val_op()) .trace(); - reverse_pass_callback([arena_A, arena_B, res]() mutable { if constexpr (is_var_matrix::value) { arena_A.adj().noalias() - += res.adj() * value_of(arena_B) * value_of(arena_B).transpose(); + += res.adj() * arena_B.val_op() * arena_B.val_op().transpose(); } else { arena_A.adj() - += res.adj() * value_of(arena_B) * value_of(arena_B).transpose(); + += res.adj() * arena_B.val_op() * arena_B.val_op().transpose(); } - if constexpr (is_var_matrix::value) { arena_B.adj().noalias() - += res.adj() * (value_of(arena_A) + value_of(arena_A).transpose()) - * value_of(arena_B); + += res.adj() * (arena_A.val_op() + arena_A.val_op().transpose()) + * arena_B.val_op(); } else { arena_B.adj() += res.adj() - * (value_of(arena_A) + value_of(arena_A).transpose()) - * value_of(arena_B); + * (arena_A.val_op() + arena_A.val_op().transpose()) + * arena_B.val_op(); } }); + return res; } else if constexpr (is_autodiffable_v) { - arena_t arena_A = value_of(std::forward(A)); - arena_t arena_B = std::forward(B); - - res = (value_of(arena_B).transpose() * value_of(arena_A) - * value_of(arena_B)) + var res = (arena_B.val_op().transpose() * arena_A + * arena_B.val_op()) .trace(); - reverse_pass_callback([arena_A, arena_B, res]() mutable { if constexpr (is_var_matrix::value) { arena_B.adj().noalias() - += res.adj() * (arena_A + arena_A.transpose()) * value_of(arena_B); + += res.adj() * (arena_A + arena_A.transpose()) * arena_B.val_op(); } else { arena_B.adj() - += res.adj() * (arena_A + arena_A.transpose()) * value_of(arena_B); + += res.adj() * (arena_A + arena_A.transpose()) * arena_B.val_op(); } }); + return res; } else { - arena_t arena_A = A; - arena_t arena_B = value_of(B); - - res = (arena_B.transpose() * value_of(arena_A) * arena_B).trace(); - + var res = (arena_B.transpose() * arena_A.val_op() * arena_B).trace(); reverse_pass_callback([arena_A, arena_B, res]() mutable { if constexpr (is_var_matrix::value) { arena_A.adj().noalias() += res.adj() * arena_B * arena_B.transpose(); @@ -178,9 +167,8 @@ inline var trace_quad_form(Mat1&& A, Mat2&& B) { arena_A.adj() += res.adj() * arena_B * arena_B.transpose(); } }); + return res; } - - return res; } } // namespace math