Skip to content

Commit

Permalink
Added complex number support to automatic autodiff tester and added s…
Browse files Browse the repository at this point in the history
…ome complex tests (Issue stan-dev#123)
  • Loading branch information
bbbales2 committed Nov 1, 2018
1 parent 48c9d6a commit d18f00e
Show file tree
Hide file tree
Showing 13 changed files with 508 additions and 10 deletions.
7 changes: 7 additions & 0 deletions stan/math/prim/mat/fun/promote_double_to.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ std::tuple<std::vector<T> > promote_element_double_to(
return std::make_tuple(promoted_input);
}

template <typename T>
std::tuple<std::complex<T> > promote_element_double_to(
const std::complex<double>& input) {
std::complex<T> promoted_input(input.real(), input.imag());
return std::make_tuple(promoted_input);
}

template <typename T>
std::tuple<T> promote_element_double_to(const double& input) {
return std::make_tuple(T(input));
Expand Down
22 changes: 22 additions & 0 deletions stan/math/prim/mat/fun/variable_adapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ class variable_adapter {
return count_T_impl(count + x.size(), args...);
}

template <typename... Pargs>
size_t count_T_impl(size_t count, const std::complex<T>& x,
const Pargs&... args) {
return count_T_impl(count + 2, args...);
}

template <typename... Pargs>
size_t count_T_impl(size_t count, const T& x, const Pargs&... args) {
return count_T_impl(count + 1, args...);
Expand Down Expand Up @@ -109,6 +115,22 @@ class variable_adapter {
return get(i - arg.size(), args...);
}

/**
* Return the ith element of arg if i is less than 2.
* Otherwise return the (i - 2)th element of the
* remaining args
*
* @tparam Pargs Types of the rest of the input arguments to process
* @return Reference to ith T in args_
*/
template <typename... Pargs>
T& get(size_t i, std::complex<T>& arg, Pargs&... args) {
if (i < 2)
return reinterpret_cast<T(&)[2]>(arg)[i];
else
return get(i - 2, args...);
}

/**
* Return arg if i == 0. Otherwise, return
* the (i - 1)th element of the remaining args
Expand Down
15 changes: 15 additions & 0 deletions stan/math/prim/scal/fun/value_of.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,21 @@ inline double value_of<double>(double x) {
return x;
}

/**
* Return the specified argument.
*
* <p>See <code>value_of(T)</code> for a polymorphic
* implementation using static casts.
*
* <p>This inline pass-through no-op should be compiled away.
*
* @param x value
* @return input value
*/
inline std::complex<double> value_of(std::complex<double> x) {
return x;
}

/**
* Return the specified argument.
*
Expand Down
20 changes: 20 additions & 0 deletions stan/math/rev/scal/fun/to_var.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,26 @@ inline var to_var(double x) { return var(x); }
*/
inline var to_var(const var& x) { return x; }

/**
* Converts argument to an automatic differentiation variable.
*
* Returns a std::complex<var> variable with the values given in x.
*
* @param[in] x A complex value
* @return An automatic differentiation variable with the input value.
*/
inline std::complex<var> to_var(std::complex<double> x) {
return std::complex<var>(x);
}

/**
* Return input if it is already a std::complex<var>.
*
* @param[in] x An automatic differentiation variable.
* @return An automatic differentiation variable with the input value.
*/
inline std::complex<var> to_var(const std::complex<var>& x) { return x; }

} // namespace math
} // namespace stan
#endif
10 changes: 10 additions & 0 deletions stan/math/rev/scal/fun/value_of.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@ namespace math {
*/
inline double value_of(const var& v) { return v.vi_->val_; }

/**
* Return the values of the specified std::complex<var> variable.
*
* @param v Variable.
* @return Value of variable.
*/
inline std::complex<double> value_of(const std::complex<var>& v) {
return std::complex<double>(value_of(v.real()), value_of(v.imag()));
}

} // namespace math
} // namespace stan
#endif
39 changes: 39 additions & 0 deletions test/unit/math/prim/mat/fun/promote_double_to_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,45 @@ TEST(MathFunctions, promote_double_to_mix) {
EXPECT_FLOAT_EQ(static_cast<float>(1.0), std::get<0>(y2));
}

TEST(MathFunctions, promote_double_to_std_complex_float) {
std::complex<float> x(1, 2);
auto y = stan::math::promote_double_to<double>(std::make_tuple(x));

EXPECT_EQ(0, std::tuple_size<decltype(y)>::value);
EXPECT_TRUE((std::is_same<std::tuple<>, decltype(y)>::value));
}

TEST(MathFunctions, promote_double_to_std_complex_double) {
std::complex<double> x(1.0, 2.0);
auto y = stan::math::promote_double_to<float>(std::make_tuple(x));

EXPECT_TRUE(
(std::is_same<std::tuple<std::complex<float> >, decltype(y)>::value));
EXPECT_EQ(1, std::tuple_size<decltype(y)>::value);
EXPECT_FLOAT_EQ(static_cast<float>(1.0), std::get<0>(y).real());
EXPECT_FLOAT_EQ(static_cast<float>(2.0), std::get<0>(y).imag());
}

TEST(MathFunctions, promote_double_to_std_complex_mix) {
std::complex<float> xi(1, 2);
std::complex<double> xd1(3.0, 4.0);
std::complex<double> xd2(5.0, 6.0);
auto y1 = stan::math::promote_double_to<float>(std::make_tuple(xd1, xi, xd2));
auto y2 = stan::math::promote_double_to<float>(std::make_tuple(xd1, xi));
EXPECT_TRUE((std::is_same<std::tuple<std::complex<float>, std::complex<float> >,
decltype(y1)>::value));
EXPECT_TRUE(
(std::is_same<std::tuple<std::complex<float> >, decltype(y2)>::value));
EXPECT_EQ(2, std::tuple_size<decltype(y1)>::value);
EXPECT_EQ(1, std::tuple_size<decltype(y2)>::value);
EXPECT_FLOAT_EQ(static_cast<float>(3.0), std::get<0>(y1).real());
EXPECT_FLOAT_EQ(static_cast<float>(4.0), std::get<0>(y1).imag());
EXPECT_FLOAT_EQ(static_cast<float>(5.0), std::get<1>(y1).real());
EXPECT_FLOAT_EQ(static_cast<float>(6.0), std::get<1>(y1).imag());
EXPECT_FLOAT_EQ(static_cast<float>(3.0), std::get<0>(y2).real());
EXPECT_FLOAT_EQ(static_cast<float>(4.0), std::get<0>(y2).imag());
}

TEST(MathFunctions, promote_double_to_std_vector_int) {
std::vector<int> x(3);
auto y = stan::math::promote_double_to<double>(std::make_tuple(x));
Expand Down
42 changes: 32 additions & 10 deletions test/unit/math/prim/mat/fun/variable_adapter_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,24 @@ TEST(MathFunctions, adapt_int) {
EXPECT_EQ(0, a.size());
}

TEST(MathFunctions, adapt_std_complex_double) {
std::complex<double> x(1.0, 2.0);

auto a = stan::math::make_variable_adapter<double>(x);

EXPECT_EQ(2, a.size());
EXPECT_FLOAT_EQ(x.real(), a(0));
EXPECT_FLOAT_EQ(x.imag(), a(1));
}

TEST(MathFunctions, adapt_std_complex_float) {
std::complex<float> x(1.0, 2.0);

auto a = stan::math::make_variable_adapter<double>(x);

EXPECT_EQ(0, a.size());
}

TEST(MathFunctions, adapt_std_vector_double) {
std::vector<double> x = {{1.0, 2.0}};

Expand Down Expand Up @@ -81,6 +99,8 @@ TEST(MathFunctions, adapt_std_eigen_matrix_double) {
TEST(MathFunctions, adapt_all_types) {
double xd = 1.0;
int xi = 1;
std::complex<double> xdc(1, 2);
std::complex<float> xfc(3, 4);
std::vector<double> xdv = {{1.0, 2.0}};
std::vector<int> xiv = {{1, 2}};
Eigen::VectorXd xev(2);
Expand All @@ -91,23 +111,25 @@ TEST(MathFunctions, adapt_all_types) {
xrev << 2.0, 3.0;
xem << 4.0, 5.0, 6.0, 7.0, 8.0, 9.0;

auto a = stan::math::make_variable_adapter<double>(xd, xi, xdv, xiv, xd, xev,
xrev, xem);
auto a = stan::math::make_variable_adapter<double>(xd, xi, xdc, xfc, xdv, xiv,
xd, xev, xrev, xem);

EXPECT_EQ(2 + xdv.size() + xev.size() + xrev.size() + xem.size(), a.size());
EXPECT_EQ(1 + 2 + xdv.size() + 1 + xev.size() + xrev.size() + xem.size(), a.size());
EXPECT_FLOAT_EQ(xd, a(0));
EXPECT_FLOAT_EQ(xdc.real(), a(1));
EXPECT_FLOAT_EQ(xdc.imag(), a(2));
for (size_t i = 0; i < xdv.size(); ++i)
EXPECT_FLOAT_EQ(xdv[i], a(1 + i));
EXPECT_FLOAT_EQ(xd, a(1 + xdv.size()));
EXPECT_FLOAT_EQ(xdv[i], a(3 + i));
EXPECT_FLOAT_EQ(xd, a(3 + xdv.size()));

a(xdv.size()) = 5.0;
EXPECT_FLOAT_EQ(5.0, a(xdv.size()));
a(1 + xdv.size()) = 4.0;
EXPECT_FLOAT_EQ(4.0, a(1 + xdv.size()));
a(3 + xdv.size()) = 4.0;
EXPECT_FLOAT_EQ(4.0, a(3 + xdv.size()));
for (size_t i = 0; i < xev.size(); ++i)
EXPECT_FLOAT_EQ(xev(i), a(2 + xdv.size() + i));
EXPECT_FLOAT_EQ(xev(i), a(4 + xdv.size() + i));
for (size_t i = 0; i < xrev.size(); ++i)
EXPECT_FLOAT_EQ(xrev(i), a(2 + xev.size() + xdv.size() + i));
EXPECT_FLOAT_EQ(xrev(i), a(4 + xev.size() + xdv.size() + i));
for (size_t i = 0; i < xem.size(); ++i)
EXPECT_FLOAT_EQ(xem(i), a(2 + xrev.size() + xev.size() + xdv.size() + i));
EXPECT_FLOAT_EQ(xem(i), a(4 + xrev.size() + xev.size() + xdv.size() + i));
}
7 changes: 7 additions & 0 deletions test/unit/math/prim/scal/fun/value_of_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ TEST(MathFunctions, value_of) {
EXPECT_FLOAT_EQ(5.0, value_of(5));
}

TEST(MathFunctions, value_of_complex) {
using stan::math::value_of;
std::complex<double> x(1.0, 2.0);
EXPECT_FLOAT_EQ(1.0, value_of(x).real());
EXPECT_FLOAT_EQ(2.0, value_of(x).imag());
}

TEST(MathFunctions, value_of_nan) {
double nan = std::numeric_limits<double>::quiet_NaN();

Expand Down
Loading

0 comments on commit d18f00e

Please sign in to comment.