diff --git a/stan/math/fwd/fun/abs.hpp b/stan/math/fwd/fun/abs.hpp index ffd21791ff0..1f50b149fd6 100644 --- a/stan/math/fwd/fun/abs.hpp +++ b/stan/math/fwd/fun/abs.hpp @@ -32,7 +32,7 @@ inline fvar abs(const fvar& x) { * @return absolute value of the argument */ template -inline std::complex> abs(const std::complex>& z) { +inline fvar abs(const std::complex>& z) { return internal::complex_abs(z); } diff --git a/stan/math/rev/fun/abs.hpp b/stan/math/rev/fun/abs.hpp index d77abd3e09e..9c9ad10dee4 100644 --- a/stan/math/rev/fun/abs.hpp +++ b/stan/math/rev/fun/abs.hpp @@ -44,9 +44,7 @@ inline var abs(const var& a) { return fabs(a); } * @param[in] z argument * @return absolute value of the argument */ -inline std::complex abs(const std::complex& z) { - return internal::complex_abs(z); -} +inline var abs(const std::complex& z) { return internal::complex_abs(z); } } // namespace math } // namespace stan diff --git a/test/unit/math/mix/fun/abs_test.cpp b/test/unit/math/mix/fun/abs_test.cpp index 11f5da342b2..47b59f23bc7 100644 --- a/test/unit/math/mix/fun/abs_test.cpp +++ b/test/unit/math/mix/fun/abs_test.cpp @@ -2,8 +2,9 @@ #include #include #include +#include -TEST(mixScalFun, abs) { +TEST(mixFun, abs) { auto f = [](const auto& x) { using std::abs; return abs(x); @@ -29,3 +30,12 @@ TEST(mixScalFun, abs) { } } } +TEST(mixFun, absReturnType) { + // validate return types not overpromoted to complex by assignability + std::complex a = 3; + stan::math::var b = abs(a); + + std::complex> c = 3; + stan::math::fvar d = abs(c); + SUCCEED(); +}