diff --git a/stan/math/prim/constraint/sum_to_zero_constrain.hpp b/stan/math/prim/constraint/sum_to_zero_constrain.hpp index 92fa4a2daa0..212aed2790a 100644 --- a/stan/math/prim/constraint/sum_to_zero_constrain.hpp +++ b/stan/math/prim/constraint/sum_to_zero_constrain.hpp @@ -25,12 +25,13 @@ namespace math { template * = nullptr, require_not_st_var* = nullptr> inline plain_type_t sum_to_zero_constrain(const Vec& y) { - int Km1 = y.size(); + const auto Km1 = y.size(); plain_type_t x(Km1 + 1); // copy the first Km1 elements - x.head(Km1) = y; + auto&& y_ref = to_ref(y); + x.head(Km1) = y_ref; // set the last element to -sum(y) - x.coeffRef(Km1) = -sum(y); + x.coeffRef(Km1) = -sum(y_ref); return x; } diff --git a/stan/math/rev/constraint/sum_to_zero_constrain.hpp b/stan/math/rev/constraint/sum_to_zero_constrain.hpp index 10cbcd295de..01878004691 100644 --- a/stan/math/rev/constraint/sum_to_zero_constrain.hpp +++ b/stan/math/rev/constraint/sum_to_zero_constrain.hpp @@ -29,28 +29,20 @@ template * = nullptr> inline auto sum_to_zero_constrain(const T& y) { using ret_type = plain_type_t; - size_t N = y.size(); - Eigen::VectorXd x_val(N + 1); - - arena_t arena_y = y; - + const auto N = y.size(); if (unlikely(N == 0)) { - x_val << 0; - return ret_type(x_val); + return ret_type(Eigen::VectorXd{{0}}); } - - x_val.head(N) = y.val(); + Eigen::VectorXd x_val = Eigen::VectorXd::Zero(N + 1); + auto arena_y = to_arena(y); + double x_sum = -sum(arena_y.val()); + x_val.head(N) = arena_y.val(); + x_val(N) = x_sum; arena_t arena_x = x_val; - - var x_N = -sum(y); - - arena_x.coeffRef(N) = x_N; - - reverse_pass_callback([arena_y, arena_x, x_N, N]() mutable { - arena_y.adj() += arena_x.adj().head(N); - x_N.adj() += arena_x.adj().coeff(N); + reverse_pass_callback([arena_y, arena_x, x_sum, N]() mutable { + arena_y.adj().array() -= arena_x.adj_op()(N); + arena_y.adj() += arena_x.adj_op().head(N); }); - return ret_type(arena_x); } diff --git a/test/unit/math/mix/constraint/sum_to_zero_constrain_test.cpp b/test/unit/math/mix/constraint/sum_to_zero_constrain_test.cpp index cec8c2cd068..ae4ce5516db 100644 --- a/test/unit/math/mix/constraint/sum_to_zero_constrain_test.cpp +++ b/test/unit/math/mix/constraint/sum_to_zero_constrain_test.cpp @@ -24,11 +24,11 @@ void expect_sum_to_zero_transform(const T& x) { auto f2 = [](const auto& x) { return g2(x); }; auto f3 = [](const auto& x) { return g3(x); }; stan::test::expect_ad(f1, x); - // stan::test::expect_ad_matvar(f1, x); + stan::test::expect_ad_matvar(f1, x); stan::test::expect_ad(f2, x); - // stan::test::expect_ad_matvar(f2, x); + stan::test::expect_ad_matvar(f2, x); stan::test::expect_ad(f3, x); - // stan::test::expect_ad_matvar(f3, x); + stan::test::expect_ad_matvar(f3, x); } } // namespace sum_to_zero_constrain_test