Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

avoid vari on chain-stack if var is constructed from an arithmetic type #1675

Merged
merged 11 commits into from
Feb 19, 2020
Merged
39 changes: 21 additions & 18 deletions stan/math/rev/core/var.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class var {
*
* @param x Value of the variable.
*/
var(float x) : vi_(new vari(static_cast<double>(x))) {} // NOLINT
var(float x) : vi_(new vari(static_cast<double>(x), false)) {} // NOLINT

/**
* Construct a variable from the specified arithmetic argument
Expand All @@ -88,7 +88,7 @@ class var {
*
* @param x Value of the variable.
*/
var(double x) : vi_(new vari(x)) {} // NOLINT
var(double x) : vi_(new vari(x, false)) {} // NOLINT

/**
* Construct a variable from the specified arithmetic argument
Expand All @@ -97,7 +97,7 @@ class var {
*
* @param x Value of the variable.
*/
var(long double x) : vi_(new vari(x)) {} // NOLINT
var(long double x) : vi_(new vari(x, false)) {} // NOLINT

/**
* Construct a variable from the specified arithmetic argument
Expand All @@ -106,7 +106,7 @@ class var {
*
* @param x Value of the variable.
*/
var(bool x) : vi_(new vari(static_cast<double>(x))) {} // NOLINT
var(bool x) : vi_(new vari(static_cast<double>(x), false)) {} // NOLINT

/**
* Construct a variable from the specified arithmetic argument
Expand All @@ -115,7 +115,7 @@ class var {
*
* @param x Value of the variable.
*/
var(char x) : vi_(new vari(static_cast<double>(x))) {} // NOLINT
var(char x) : vi_(new vari(static_cast<double>(x), false)) {} // NOLINT

/**
* Construct a variable from the specified arithmetic argument
Expand All @@ -124,7 +124,7 @@ class var {
*
* @param x Value of the variable.
*/
var(short x) : vi_(new vari(static_cast<double>(x))) {} // NOLINT
var(short x) : vi_(new vari(static_cast<double>(x), false)) {} // NOLINT

/**
* Construct a variable from the specified arithmetic argument
Expand All @@ -133,7 +133,7 @@ class var {
*
* @param x Value of the variable.
*/
var(int x) : vi_(new vari(static_cast<double>(x))) {} // NOLINT
var(int x) : vi_(new vari(static_cast<double>(x), false)) {} // NOLINT

/**
* Construct a variable from the specified arithmetic argument
Expand All @@ -142,7 +142,7 @@ class var {
*
* @param x Value of the variable.
*/
var(long x) : vi_(new vari(static_cast<double>(x))) {} // NOLINT
var(long x) : vi_(new vari(static_cast<double>(x), false)) {} // NOLINT

/**
* Construct a variable from the specified arithmetic argument
Expand All @@ -152,7 +152,7 @@ class var {
* @param x Value of the variable.
*/
var(unsigned char x) // NOLINT(runtime/explicit)
: vi_(new vari(static_cast<double>(x))) {}
: vi_(new vari(static_cast<double>(x), false)) {}

/**
* Construct a variable from the specified arithmetic argument
Expand All @@ -162,7 +162,7 @@ class var {
* @param x Value of the variable.
*/
// NOLINTNEXTLINE
var(unsigned short x) : vi_(new vari(static_cast<double>(x))) {}
var(unsigned short x) : vi_(new vari(static_cast<double>(x), false)) {}

/**
* Construct a variable from the specified arithmetic argument
Expand All @@ -171,7 +171,9 @@ class var {
*
* @param x Value of the variable.
*/
var(unsigned int x) : vi_(new vari(static_cast<double>(x))) {} // NOLINT
// NOLINTNEXTLINE
var(unsigned int x)
: vi_(new vari(static_cast<double>(x), false)) {} // NOLINT

/**
* Construct a variable from the specified arithmetic argument
Expand All @@ -181,7 +183,8 @@ class var {
* @param x Value of the variable.
*/
// NOLINTNEXTLINE
var(unsigned long x) : vi_(new vari(static_cast<double>(x))) {} // NOLINT
var(unsigned long x)
: vi_(new vari(static_cast<double>(x), false)) {} // NOLINT

/**
* Construct a variable from the specified arithmetic argument
Expand All @@ -193,7 +196,7 @@ class var {
*/
explicit var(const std::complex<double>& x) {
if (imag(x) == 0) {
vi_ = new vari(real(x));
vi_ = new vari(real(x), false);
} else {
std::stringstream ss;
ss << "Imaginary part of std::complex used to construct var"
Expand All @@ -214,7 +217,7 @@ class var {
*/
explicit var(const std::complex<float>& x) {
if (imag(x) == 0) {
vi_ = new vari(static_cast<double>(real(x)));
vi_ = new vari(static_cast<double>(real(x)), false);
} else {
std::stringstream ss;
ss << "Imaginary part of std::complex used to construct var"
Expand All @@ -235,7 +238,7 @@ class var {
*/
explicit var(const std::complex<long double>& x) {
if (imag(x) == 0) {
vi_ = new vari(static_cast<double>(real(x)));
vi_ = new vari(static_cast<double>(real(x)), false);
} else {
std::stringstream ss;
ss << "Imaginary part of std::complex used to construct var"
Expand All @@ -259,7 +262,7 @@ class var {
*
* @param x Value of the variable.
*/
var(size_t x) : vi_(new vari(static_cast<double>(x))) {} // NOLINT
var(size_t x) : vi_(new vari(static_cast<double>(x), false)) {} // NOLINT

/**
* Construct a variable from the specified arithmetic argument
Expand All @@ -268,7 +271,7 @@ class var {
*
* @param x Value of the variable.
*/
var(ptrdiff_t x) : vi_(new vari(static_cast<double>(x))) {} // NOLINT
var(ptrdiff_t x) : vi_(new vari(static_cast<double>(x), false)) {} // NOLINT
#endif

#ifdef BOOST_MATH_USE_FLOAT128
Expand All @@ -283,7 +286,7 @@ class var {
*
* @param x Value of the variable.
*/
var(__float128 x) : vi_(new vari(static_cast<double>(x))) {} // NOLINT
var(__float128 x) : vi_(new vari(static_cast<double>(x), false)) {} // NOLINT

#endif

Expand Down
15 changes: 3 additions & 12 deletions stan/math/rev/functor/coupled_ode_system.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,7 @@ struct coupled_ode_system<F, double, var> {
try {
start_nested();

vector<var> y_vars;
y_vars.reserve(N_);
for (std::size_t i = 0; i < N_; ++i)
y_vars.emplace_back(new vari(z[i], false));
vector<var> y_vars(z.begin(), z.begin() + N_);

vector<var> dy_dt_vars = f_(t, y_vars, theta_nochain_, x_, x_int_, msgs_);

Expand Down Expand Up @@ -279,10 +276,7 @@ struct coupled_ode_system<F, var, double> {
try {
start_nested();

vector<var> y_vars;
y_vars.reserve(N_);
for (std::size_t i = 0; i < N_; ++i)
y_vars.emplace_back(new vari(z[i], false));
vector<var> y_vars(z.begin(), z.begin() + N_);

vector<var> dy_dt_vars = f_(t, y_vars, theta_dbl_, x_, x_int_, msgs_);

Expand Down Expand Up @@ -460,10 +454,7 @@ struct coupled_ode_system<F, var, var> {
try {
start_nested();

vector<var> y_vars;
y_vars.reserve(N_);
for (std::size_t i = 0; i < N_; ++i)
y_vars.emplace_back(new vari(z[i], false));
vector<var> y_vars(z.begin(), z.begin() + N_);

vector<var> dy_dt_vars = f_(t, y_vars, theta_nochain_, x_, x_int_, msgs_);

Expand Down
4 changes: 1 addition & 3 deletions stan/math/rev/functor/gradient.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,7 @@ void gradient(const F& f, const Eigen::Matrix<double, Eigen::Dynamic, 1>& x,
double& fx, Eigen::Matrix<double, Eigen::Dynamic, 1>& grad_fx) {
start_nested();
try {
Eigen::Matrix<var, Eigen::Dynamic, 1> x_var(x.size());
for (int i = 0; i < x.size(); ++i)
x_var(i) = var(new vari(x(i), false));
Eigen::Matrix<var, Eigen::Dynamic, 1> x_var(x);
var fx_var = f(x_var);
fx = fx_var.val();
grad_fx.resize(x.size());
Expand Down
12 changes: 8 additions & 4 deletions test/unit/math/rev/core/thread_stack_instance_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,19 +84,23 @@ TEST(thread_stack_instance, child_instances) {
// thread will be different at initialization (if STAN_THREADS is
// set)
stan::math::var a = 1;
stan::math::var b = a * a;

ChainableStack::AutodiffStackStorage* main_ad_stack
= ChainableStack::instance_;

auto thread_tester = [&]() -> void {
ChainableStack thread_instance;
EXPECT_TRUE(main_ad_stack->var_stack_.size()
EXPECT_TRUE(
main_ad_stack->var_stack_.size()
+ stan::math::ChainableStack::instance_->var_nochain_stack_.size()
#ifdef STAN_THREADS
>
>
#else
==
==
#endif
ChainableStack::instance_->var_stack_.size());
ChainableStack::instance_->var_stack_.size()
+ stan::math::ChainableStack::instance_->var_nochain_stack_.size());
};

std::thread other_work(thread_tester);
Expand Down
14 changes: 10 additions & 4 deletions test/unit/math/rev/err/check_bounded_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@ TEST(AgradRevErrorHandlingScalar, CheckBoundedVarCheckVectorized) {
for (int i = 0; i < N; ++i)
a.push_back(var(i));

size_t stack_size = stan::math::ChainableStack::instance_->var_stack_.size();
size_t stack_size
= stan::math::ChainableStack::instance_->var_stack_.size()
+ stan::math::ChainableStack::instance_->var_nochain_stack_.size();

EXPECT_EQ(5U, stack_size);
EXPECT_NO_THROW(check_bounded(function, "a", a, -1.0, 6.0));

size_t stack_size_after_call
= stan::math::ChainableStack::instance_->var_stack_.size();
= stan::math::ChainableStack::instance_->var_stack_.size()
+ stan::math::ChainableStack::instance_->var_nochain_stack_.size();
EXPECT_EQ(5U, stack_size_after_call);
stan::math::recover_memory();
}
Expand Down Expand Up @@ -145,13 +148,16 @@ TEST(AgradRevErrorHandlingScalar, CheckBoundedVarCheckUnivariate) {
const char* function = "check_bounded";
var a(5.0);

size_t stack_size = stan::math::ChainableStack::instance_->var_stack_.size();
size_t stack_size
= stan::math::ChainableStack::instance_->var_stack_.size()
+ stan::math::ChainableStack::instance_->var_nochain_stack_.size();

EXPECT_EQ(1U, stack_size);
EXPECT_NO_THROW(check_bounded(function, "a", a, 4.0, 6.0));

size_t stack_size_after_call
= stan::math::ChainableStack::instance_->var_stack_.size();
= stan::math::ChainableStack::instance_->var_stack_.size()
+ stan::math::ChainableStack::instance_->var_nochain_stack_.size();
EXPECT_EQ(1U, stack_size_after_call);

stan::math::recover_memory();
Expand Down
7 changes: 5 additions & 2 deletions test/unit/math/rev/err/check_consistent_size_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@ TEST(AgradRevErrorHandlingScalar, CheckConsistentSizeVarCheckVectorized) {
for (int i = 0; i < N; ++i)
a.push_back(var(i));

size_t stack_size = stan::math::ChainableStack::instance_->var_stack_.size();
size_t stack_size
= stan::math::ChainableStack::instance_->var_stack_.size()
+ stan::math::ChainableStack::instance_->var_nochain_stack_.size();

EXPECT_EQ(5U, stack_size);
EXPECT_NO_THROW(check_consistent_size(function, "a", a, 5U));

size_t stack_size_after_call
= stan::math::ChainableStack::instance_->var_stack_.size();
= stan::math::ChainableStack::instance_->var_stack_.size()
+ stan::math::ChainableStack::instance_->var_nochain_stack_.size();
EXPECT_EQ(5U, stack_size_after_call);
stan::math::recover_memory();
}
Expand Down
7 changes: 5 additions & 2 deletions test/unit/math/rev/err/check_consistent_sizes_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,16 @@ TEST(AgradRevErrorHandlingScalar, CheckConsistentSizesVarCheckVectorized) {
a.push_back(var(i));
}

size_t stack_size = stan::math::ChainableStack::instance_->var_stack_.size();
size_t stack_size
= stan::math::ChainableStack::instance_->var_stack_.size()
+ stan::math::ChainableStack::instance_->var_nochain_stack_.size();

EXPECT_EQ(10U, stack_size);
EXPECT_NO_THROW(check_consistent_sizes(function, "a", a, "b", b));

size_t stack_size_after_call
= stan::math::ChainableStack::instance_->var_stack_.size();
= stan::math::ChainableStack::instance_->var_stack_.size()
+ stan::math::ChainableStack::instance_->var_nochain_stack_.size();
EXPECT_EQ(10U, stack_size_after_call);
stan::math::recover_memory();
}
Expand Down
20 changes: 14 additions & 6 deletions test/unit/math/rev/err/check_finite_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,23 @@ TEST(AgradRevErrorHandlingScalar, CheckFiniteVarCheckVectorized) {
for (int i = 0; i < N; ++i)
a.push_back(var(i));

size_t stack_size = stan::math::ChainableStack::instance_->var_stack_.size();
size_t stack_size
= stan::math::ChainableStack::instance_->var_stack_.size()
+ stan::math::ChainableStack::instance_->var_nochain_stack_.size();

EXPECT_EQ(5U, stack_size);
EXPECT_NO_THROW(check_finite(function, "a", a));

size_t stack_size_after_call
= stan::math::ChainableStack::instance_->var_stack_.size();
= stan::math::ChainableStack::instance_->var_stack_.size()
+ stan::math::ChainableStack::instance_->var_nochain_stack_.size();
EXPECT_EQ(5U, stack_size_after_call);

a[1] = std::numeric_limits<double>::infinity();
EXPECT_THROW(check_finite(function, "a", a), std::domain_error);
stack_size_after_call
= stan::math::ChainableStack::instance_->var_stack_.size();
= stan::math::ChainableStack::instance_->var_stack_.size()
+ stan::math::ChainableStack::instance_->var_nochain_stack_.size();
EXPECT_EQ(6U, stack_size_after_call);

stan::math::recover_memory();
Expand Down Expand Up @@ -65,19 +69,23 @@ TEST(AgradRevErrorHandlingScalar, CheckFiniteVarCheckUnivariate) {
const char* function = "check_finite";
var a(5.0);

size_t stack_size = stan::math::ChainableStack::instance_->var_stack_.size();
size_t stack_size
= stan::math::ChainableStack::instance_->var_stack_.size()
+ stan::math::ChainableStack::instance_->var_nochain_stack_.size();

EXPECT_EQ(1U, stack_size);
EXPECT_NO_THROW(check_finite(function, "a", a));

size_t stack_size_after_call
= stan::math::ChainableStack::instance_->var_stack_.size();
= stan::math::ChainableStack::instance_->var_stack_.size()
+ stan::math::ChainableStack::instance_->var_nochain_stack_.size();
EXPECT_EQ(1U, stack_size_after_call);

a = std::numeric_limits<double>::infinity();
EXPECT_THROW(check_finite(function, "a", a), std::domain_error);
stack_size_after_call
= stan::math::ChainableStack::instance_->var_stack_.size();
= stan::math::ChainableStack::instance_->var_stack_.size()
+ stan::math::ChainableStack::instance_->var_nochain_stack_.size();
EXPECT_EQ(2U, stack_size_after_call);

stan::math::recover_memory();
Expand Down
Loading