Skip to content

Commit

Permalink
update type traits for fft, square_dist, and trace funcs
Browse files Browse the repository at this point in the history
  • Loading branch information
SteveBronder committed Jul 15, 2024
1 parent d7350f2 commit a54fb01
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 58 deletions.
24 changes: 12 additions & 12 deletions stan/math/rev/fun/fft.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ namespace math {
*/
template <typename V, require_eigen_vector_vt<is_complex, V>* = nullptr,
require_var_t<base_type_t<value_type_t<V>>>* = nullptr>
inline plain_type_t<V> fft(const V& x) {
inline auto fft(const V& x) {
if (unlikely(x.size() <= 1)) {
return plain_type_t<V>(x);
return arena_t<plain_type_t<V>>(x);
}

arena_t<V> arena_v = x;
arena_t<V> res = fft(to_complex(arena_v.real().val(), arena_v.imag().val()));
arena_t<plain_type_t<V>> res = fft(to_complex(arena_v.real().val(), arena_v.imag().val()));

reverse_pass_callback([arena_v, res]() mutable {
auto adj_inv_fft = inv_fft(to_complex(res.real().adj(), res.imag().adj()));
Expand All @@ -54,7 +54,7 @@ inline plain_type_t<V> fft(const V& x) {
arena_v.imag().adj() += adj_inv_fft.imag();
});

return plain_type_t<V>(res);
return res;
}

/**
Expand Down Expand Up @@ -84,13 +84,13 @@ inline plain_type_t<V> fft(const V& x) {
*/
template <typename V, require_eigen_vector_vt<is_complex, V>* = nullptr,
require_var_t<base_type_t<value_type_t<V>>>* = nullptr>
inline plain_type_t<V> inv_fft(const V& y) {
inline auto inv_fft(const V& y) {
if (unlikely(y.size() <= 1)) {
return plain_type_t<V>(y);
return arena_t<plain_type_t<V>>(y);
}

arena_t<V> arena_v = y;
arena_t<V> res
arena_t<plain_type_t<V>> res
= inv_fft(to_complex(arena_v.real().val(), arena_v.imag().val()));

reverse_pass_callback([arena_v, res]() mutable {
Expand All @@ -100,7 +100,7 @@ inline plain_type_t<V> inv_fft(const V& y) {
arena_v.real().adj() += adj_fft.real();
arena_v.imag().adj() += adj_fft.imag();
});
return plain_type_t<V>(res);
return res;
}

/**
Expand All @@ -120,7 +120,7 @@ inline plain_type_t<V> inv_fft(const V& y) {
*/
template <typename M, require_eigen_dense_dynamic_vt<is_complex, M>* = nullptr,
require_var_t<base_type_t<value_type_t<M>>>* = nullptr>
inline plain_type_t<M> fft2(const M& x) {
inline auto fft2(const M& x) {
arena_t<M> arena_v = x;
arena_t<M> res = fft2(to_complex(arena_v.real().val(), arena_v.imag().val()));

Expand All @@ -131,7 +131,7 @@ inline plain_type_t<M> fft2(const M& x) {
arena_v.imag().adj() += adj_inv_fft.imag();
});

return plain_type_t<M>(res);
return res;
}

/**
Expand All @@ -152,7 +152,7 @@ inline plain_type_t<M> fft2(const M& x) {
*/
template <typename M, require_eigen_dense_dynamic_vt<is_complex, M>* = nullptr,
require_var_t<base_type_t<value_type_t<M>>>* = nullptr>
inline plain_type_t<M> inv_fft2(const M& y) {
inline auto inv_fft2(const M& y) {
arena_t<M> arena_v = y;
arena_t<M> res
= inv_fft2(to_complex(arena_v.real().val(), arena_v.imag().val()));
Expand All @@ -164,7 +164,7 @@ inline plain_type_t<M> inv_fft2(const M& y) {
arena_v.real().adj() += adj_fft.real();
arena_v.imag().adj() += adj_fft.imag();
});
return plain_type_t<M>(res);
return res;
}

} // namespace math
Expand Down
19 changes: 6 additions & 13 deletions stan/math/rev/fun/squared_distance.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,12 @@ inline var squared_distance(const T1& A, const T2& B) {
check_matching_sizes("squared_distance", "A", A.val(), "B", B.val());
if (unlikely(A.size() == 0)) {
return var(0.0);
} else if constexpr (is_autodiffable_v<T1, T2>) {
arena_t<T1> arena_A = A;
arena_t<T2> arena_B = B;
arena_t<Eigen::VectorXd> res_diff(arena_A.size());
double res_val = 0.0;
}
arena_t<T1> arena_A = A;
arena_t<T2> arena_B = B;
arena_t<Eigen::VectorXd> res_diff(arena_A.size());
double res_val = 0.0;
if constexpr (is_autodiffable_v<T1, T2>) {
for (size_t i = 0; i < arena_A.size(); ++i) {
const double diff = arena_A.val().coeff(i) - arena_B.val().coeff(i);
res_diff.coeffRef(i) = diff;
Expand All @@ -178,10 +179,6 @@ inline var squared_distance(const T1& A, const T2& B) {
}
}));
} else if constexpr (is_autodiffable_v<T1>) {
arena_t<T1> arena_A = A;
arena_t<T2> arena_B = value_of(B);
arena_t<Eigen::VectorXd> res_diff(arena_A.size());
double res_val = 0.0;
for (size_t i = 0; i < arena_A.size(); ++i) {
const double diff = arena_A.val().coeff(i) - arena_B.coeff(i);
res_diff.coeffRef(i) = diff;
Expand All @@ -192,10 +189,6 @@ inline var squared_distance(const T1& A, const T2& B) {
arena_A.adj() += 2.0 * res.adj() * res_diff;
}));
} else {
arena_t<T1> arena_A = value_of(A);
arena_t<T2> arena_B = B;
arena_t<Eigen::VectorXd> res_diff(arena_A.size());
double res_val = 0.0;
for (size_t i = 0; i < arena_A.size(); ++i) {
const double diff = arena_A.coeff(i) - arena_B.val().coeff(i);
res_diff.coeffRef(i) = diff;
Expand Down
6 changes: 3 additions & 3 deletions stan/math/rev/fun/trace_gen_inv_quad_form_ldlt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, LDLT_factor<Ta>& A,
return 0;
}

arena_t<Ta> arena_A = A.matrix();
arena_t<Tb> arena_B = B;
arena_t<Td> arena_D = D;
if constexpr (is_autodiffable_v<Ta, Tb, Td>) {
arena_t<Ta> arena_A = A.matrix();
arena_t<Tb> arena_B = B;
arena_t<Td> arena_D = D;
auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
auto BTAsolveB = to_arena(arena_B.val_op().transpose() * AsolveB);

Expand Down
48 changes: 18 additions & 30 deletions stan/math/rev/fun/trace_quad_form.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,68 +119,56 @@ inline var trace_quad_form(Mat1&& A, Mat2&& B) {
check_square("trace_quad_form", "A", A);
check_multiplicable("trace_quad_form", "A", A, "B", B);

var res;

arena_t<Mat1> arena_A = std::forward<Mat1>(A);
arena_t<Mat2> arena_B = std::forward<Mat2>(B);
if constexpr (is_autodiffable_v<Mat1, Mat2>) {
arena_t<Mat1> arena_A = std::forward<Mat1>(A);
arena_t<Mat2> arena_B = std::forward<Mat2>(B);

res = (value_of(arena_B).transpose() * value_of(arena_A)
* value_of(arena_B))
var res = (arena_B.val_op().transpose() * arena_A.val_op()
* arena_B.val_op())
.trace();

reverse_pass_callback([arena_A, arena_B, res]() mutable {
if constexpr (is_var_matrix<Mat1>::value) {
arena_A.adj().noalias()
+= res.adj() * value_of(arena_B) * value_of(arena_B).transpose();
+= res.adj() * arena_B.val_op() * arena_B.val_op().transpose();
} else {
arena_A.adj()
+= res.adj() * value_of(arena_B) * value_of(arena_B).transpose();
+= res.adj() * arena_B.val_op() * arena_B.val_op().transpose();
}

if constexpr (is_var_matrix<Mat2>::value) {
arena_B.adj().noalias()
+= res.adj() * (value_of(arena_A) + value_of(arena_A).transpose())
* value_of(arena_B);
+= res.adj() * (arena_A.val_op() + arena_A.val_op().transpose())
* arena_B.val_op();
} else {
arena_B.adj() += res.adj()
* (value_of(arena_A) + value_of(arena_A).transpose())
* value_of(arena_B);
* (arena_A.val_op() + arena_A.val_op().transpose())
* arena_B.val_op();
}
});
return res;
} else if constexpr (is_autodiffable_v<Mat2>) {
arena_t<Mat1> arena_A = value_of(std::forward<Mat1>(A));
arena_t<Mat2> arena_B = std::forward<Mat2>(B);

res = (value_of(arena_B).transpose() * value_of(arena_A)
* value_of(arena_B))
var res = (arena_B.val_op().transpose() * arena_A
* arena_B.val_op())
.trace();

reverse_pass_callback([arena_A, arena_B, res]() mutable {
if constexpr (is_var_matrix<Mat2>::value) {
arena_B.adj().noalias()
+= res.adj() * (arena_A + arena_A.transpose()) * value_of(arena_B);
+= res.adj() * (arena_A + arena_A.transpose()) * arena_B.val_op();
} else {
arena_B.adj()
+= res.adj() * (arena_A + arena_A.transpose()) * value_of(arena_B);
+= res.adj() * (arena_A + arena_A.transpose()) * arena_B.val_op();
}
});
return res;
} else {
arena_t<Mat1> arena_A = A;
arena_t<Mat2> arena_B = value_of(B);

res = (arena_B.transpose() * value_of(arena_A) * arena_B).trace();

var res = (arena_B.transpose() * arena_A.val_op() * arena_B).trace();
reverse_pass_callback([arena_A, arena_B, res]() mutable {
if constexpr (is_var_matrix<Mat1>::value) {
arena_A.adj().noalias() += res.adj() * arena_B * arena_B.transpose();
} else {
arena_A.adj() += res.adj() * arena_B * arena_B.transpose();
}
});
return res;
}

return res;
}

} // namespace math
Expand Down

0 comments on commit a54fb01

Please sign in to comment.