Skip to content

Commit

Permalink
fix bug in reverse pass for sum to zero constraint
Browse files Browse the repository at this point in the history
  • Loading branch information
SteveBronder committed Jul 30, 2024
1 parent fb97a24 commit 3fb6ff4
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 24 deletions.
7 changes: 4 additions & 3 deletions stan/math/prim/constraint/sum_to_zero_constrain.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@ namespace math {
template <typename Vec, require_eigen_col_vector_t<Vec>* = nullptr,
require_not_st_var<Vec>* = nullptr>
inline plain_type_t<Vec> sum_to_zero_constrain(const Vec& y) {
int Km1 = y.size();
const auto Km1 = y.size();
plain_type_t<Vec> x(Km1 + 1);
// copy the first Km1 elements
x.head(Km1) = y;
auto&& y_ref = to_ref(y);
x.head(Km1) = y_ref;
// set the last element to -sum(y)
x.coeffRef(Km1) = -sum(y);
x.coeffRef(Km1) = -sum(y_ref);
return x;
}

Expand Down
28 changes: 10 additions & 18 deletions stan/math/rev/constraint/sum_to_zero_constrain.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,28 +29,20 @@ template <typename T, require_rev_col_vector_t<T>* = nullptr>
inline auto sum_to_zero_constrain(const T& y) {
using ret_type = plain_type_t<T>;

size_t N = y.size();
Eigen::VectorXd x_val(N + 1);

arena_t<T> arena_y = y;

const auto N = y.size();
if (unlikely(N == 0)) {
x_val << 0;
return ret_type(x_val);
return ret_type(Eigen::VectorXd{{0}});
}

x_val.head(N) = y.val();
Eigen::VectorXd x_val = Eigen::VectorXd::Zero(N + 1);
auto arena_y = to_arena(y);
double x_sum = -sum(arena_y.val());
x_val.head(N) = arena_y.val();
x_val(N) = x_sum;
arena_t<ret_type> arena_x = x_val;

var x_N = -sum(y);

arena_x.coeffRef(N) = x_N;

reverse_pass_callback([arena_y, arena_x, x_N, N]() mutable {
arena_y.adj() += arena_x.adj().head(N);
x_N.adj() += arena_x.adj().coeff(N);
reverse_pass_callback([arena_y, arena_x, x_sum, N]() mutable {
arena_y.adj().array() -= arena_x.adj_op()(N);
arena_y.adj() += arena_x.adj_op().head(N);
});

return ret_type(arena_x);
}

Expand Down
6 changes: 3 additions & 3 deletions test/unit/math/mix/constraint/sum_to_zero_constrain_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ void expect_sum_to_zero_transform(const T& x) {
auto f2 = [](const auto& x) { return g2(x); };
auto f3 = [](const auto& x) { return g3(x); };
stan::test::expect_ad(f1, x);
// stan::test::expect_ad_matvar(f1, x);
stan::test::expect_ad_matvar(f1, x);
stan::test::expect_ad(f2, x);
// stan::test::expect_ad_matvar(f2, x);
stan::test::expect_ad_matvar(f2, x);
stan::test::expect_ad(f3, x);
// stan::test::expect_ad_matvar(f3, x);
stan::test::expect_ad_matvar(f3, x);
}
} // namespace sum_to_zero_constrain_test

Expand Down

0 comments on commit 3fb6ff4

Please sign in to comment.