diff --git a/stan/math/rev/core/var.hpp b/stan/math/rev/core/var.hpp index 4b88509a6e2..5bfd61fda72 100644 --- a/stan/math/rev/core/var.hpp +++ b/stan/math/rev/core/var.hpp @@ -7,6 +7,9 @@ #include #include #include +#include +#include +#include namespace stan { namespace math { @@ -179,6 +182,69 @@ class var { // NOLINTNEXTLINE var(unsigned long x) : vi_(new vari(static_cast(x))) {} // NOLINT + /** + * Construct a variable from the specified arithmetic argument + * by constructing a new vari with the argument + * cast to double, and a zero adjoint. Only works + * if the imaginary part is zero. + * + * @param x Value of the variable. + */ + explicit var(const std::complex& x) { + if (imag(x) == 0) { + vi_ = new vari(real(x)); + } else { + std::stringstream ss; + ss << "Imaginary part of std::complex used to construct var" + << " must be zero. Found real part = " << real(x) << " and " + << " found imaginary part = " << imag(x) << std::endl; + std::string msg = ss.str(); + throw std::invalid_argument(msg); + } + } + + /** + * Construct a variable from the specified arithmetic argument + * by constructing a new vari with the argument + * cast to double, and a zero adjoint. Only works + * if the imaginary part is zero. + * + * @param x Value of the variable. + */ + explicit var(const std::complex& x) { + if (imag(x) == 0) { + vi_ = new vari(static_cast(real(x))); + } else { + std::stringstream ss; + ss << "Imaginary part of std::complex used to construct var" + << " must be zero. Found real part = " << real(x) << " and " + << " found imaginary part = " << imag(x) << std::endl; + std::string msg = ss.str(); + throw std::invalid_argument(msg); + } + } + + /** + * Construct a variable from the specified arithmetic argument + * by constructing a new vari with the argument + * cast to double, and a zero adjoint. Only works + * if the imaginary part is zero. + * + * @param x Value of the variable. + */ + explicit var(const std::complex& x) { + if (imag(x) == 0) { + vi_ = new vari(static_cast(real(x))); + } else { + std::stringstream ss; + ss << "Imaginary part of std::complex used to construct var" + << " must be zero. Found real part = " << real(x) << " and " + << " found imaginary part = " << imag(x) << std::endl; + std::string msg = ss.str(); + throw std::invalid_argument(msg); + } + } + #ifdef _WIN64 // these two ctors are for Win64 to enable 64-bit signed diff --git a/test/unit/math/rev/core/var_test.cpp b/test/unit/math/rev/core/var_test.cpp index 8805d36ab92..dd174f64e6e 100644 --- a/test/unit/math/rev/core/var_test.cpp +++ b/test/unit/math/rev/core/var_test.cpp @@ -69,6 +69,16 @@ TEST_F(AgradRev, ctorOverloads) { // ptrdiff_t EXPECT_FLOAT_EQ(37, var(static_cast(37)).val()); EXPECT_FLOAT_EQ(0, var(static_cast(0)).val()); + + // complex but with zero imaginary part + EXPECT_FLOAT_EQ(37, var(std::complex(37, 0)).val()); + EXPECT_FLOAT_EQ(37, var(std::complex(37, 0)).val()); + EXPECT_FLOAT_EQ(37, var(std::complex(37, 0)).val()); + + // complex but with non-zero imaginary part + EXPECT_THROW(var(std::complex(37, 10)), std::invalid_argument); + EXPECT_THROW(var(std::complex(37, 10)), std::invalid_argument); + EXPECT_THROW(var(std::complex(37, 10)), std::invalid_argument); } TEST_F(AgradRev, a_eq_x) {