Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generic var templates for operators and std::iterator_trait var/fvar specialization #1525

Merged
merged 39 commits into from
Feb 3, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
8309363
Adds std iterator_traits specialization for fvar and var. The var ope…
SteveBronder Dec 18, 2019
bd9b20c
cleanup the template names
SteveBronder Dec 18, 2019
f8f2e3a
[Jenkins] auto-formatting by clang-format version 5.0.2-svn328729-1~e…
stan-buildbot Dec 18, 2019
8806331
cleanup the unary var templates
SteveBronder Dec 18, 2019
a6b4d81
cleanup the unary var templates
SteveBronder Dec 18, 2019
3868acf
[Jenkins] auto-formatting by clang-format version 5.0.2-svn328729-1~e…
stan-buildbot Dec 18, 2019
5e07ef9
replace var_or_fvar_t with autodiff_t in require_generics tests
SteveBronder Dec 18, 2019
119a7e7
[Jenkins] auto-formatting by clang-format version 6.0.0 (tags/google/…
stan-buildbot Dec 18, 2019
1348dda
remove template from ostream for var
SteveBronder Dec 18, 2019
858e1f2
Merge branch 'cleanup/generic-templates-var' of github.com:stan-dev/m…
SteveBronder Dec 18, 2019
ec6345a
fix order of members in kinsoldata
SteveBronder Dec 18, 2019
4300497
revert beta-binomial changes
SteveBronder Dec 20, 2019
fd35443
Fixup template order for core var impls
SteveBronder Dec 24, 2019
73181b4
Merge commit 'cb961050fd2bb013cc658eb8e3eafaa6bdf1cf38' into HEAD
yashikno Dec 24, 2019
36777ab
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot Dec 24, 2019
ad62f4f
Moves cmath var and fvar stuff over to prim and uses initlializer bra…
SteveBronder Dec 30, 2019
fd2a4a8
Merge commit '65aec14f5caea8b9e38a7475afcd4a6681648497' into HEAD
yashikno Dec 30, 2019
930329d
[Jenkins] auto-formatting by clang-format version 5.0.2-svn328729-1~e…
stan-buildbot Dec 30, 2019
6185370
put back the namespace call for exp in gp_periodic_cov and remove sta…
SteveBronder Dec 30, 2019
12ec611
Merge branch 'cleanup/generic-templates-var' of github.com:stan-dev/m…
SteveBronder Dec 30, 2019
88c86e8
merge to develop
SteveBronder Dec 30, 2019
5564039
merge to develop
SteveBronder Dec 31, 2019
a694d1a
Merge branch 'develop' into cleanup/generic-templates-var
SteveBronder Jan 3, 2020
5740cd4
Merge remote-tracking branch 'origin/develop' into cleanup/generic-te…
SteveBronder Jan 5, 2020
4cba532
Merge remote-tracking branch 'origin/develop' into cleanup/generic-te…
SteveBronder Jan 5, 2020
30ee8c9
Remove Var&& arguments for const var& arguments in rev/core
SteveBronder Jan 6, 2020
1a565b6
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot Jan 6, 2020
75eaf43
merge to develop
SteveBronder Jan 13, 2020
1342d33
Merge branch 'develop' of https://github.com/stan-dev/math into clean…
Jan 21, 2020
3604151
split cmath.hpp into files w. unit tests
Jan 22, 2020
9eebaa7
fix isnan tests to call right fun
Jan 22, 2020
993ceaa
testing pass by value for var ops instead of by reference
SteveBronder Jan 25, 2020
cd60021
Merge branch 'cleanup/pass-var-by-value', remote-tracking branch 'ori…
SteveBronder Jan 25, 2020
ced2df7
Merge branch 'cleanup/generic-templates-var' of github.com:stan-dev/m…
SteveBronder Jan 25, 2020
8d460ac
Merge commit 'c0d2265f842a1b2df04855fac49d87e9962aa878' into HEAD
yashikno Jan 25, 2020
2795719
[Jenkins] auto-formatting by clang-format version 6.0.0 (tags/google/…
stan-buildbot Jan 25, 2020
99a916b
Merge branch 'develop' of https://github.com/stan-dev/math into clean…
Jan 29, 2020
84f5b26
Merge branch 'develop' of https://github.com/stan-dev/math into clean…
Jan 31, 2020
06ad089
templated var overloads for pow
Jan 31, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions stan/math/rev/core/operator_greater_than.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ inline bool operator>(Var1&& a, Var2&& b) {
* @param b Second value.
* @return True if first variable's value is greater than second value.
*/
template <typename Var, typename Arith, require_var_t<Var>...,
require_arithmetic_t<Arith>...>
template <typename Var, typename Arith, require_var_t<Var>..., require_arithmetic_t<Arith>...>
inline bool operator>(Var&& a, Arith b) {
return a.val() > b;
}
Expand All @@ -56,8 +55,7 @@ inline bool operator>(Var&& a, Arith b) {
* @param b Second variable.
* @return True if first value is greater than second variable's value.
*/
template <typename Var, typename Arith, require_var_t<Var>...,
require_arithmetic_t<Arith>...>
template <typename Arith, typename Var, require_arithmetic_t<Arith>..., require_var_t<Var>...>
inline bool operator>(Arith a, Var&& b) {
return a > b.val();
}
Expand Down
5 changes: 2 additions & 3 deletions stan/math/rev/core/operator_greater_than_or_equal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ inline bool operator>=(Var1&& a, Var2&& b) {
* @return True if first variable's value is greater than or equal
* to second value.
*/
template <typename Var, typename Arith, require_arithmetic_t<Arith>...,
require_var_t<Var>...>
template <typename Var, typename Arith, require_var_t<Var>..., require_arithmetic_t<Arith>...>
inline bool operator>=(Var&& a, Arith b) {
return a.val() >= b;
}
Expand All @@ -60,7 +59,7 @@ inline bool operator>=(Var&& a, Arith b) {
* @return True if the first value is greater than or equal to the
* second variable's value.
*/
template <typename Var, typename Arith, require_arithmetic_t<Arith>...,
template <typename Arith, typename Var, require_arithmetic_t<Arith>...,
require_var_t<Var>...>
inline bool operator>=(Arith a, Var&& b) {
return a >= b.val();
Expand Down
6 changes: 2 additions & 4 deletions stan/math/rev/core/operator_less_than.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ inline bool operator<(Var1&& a, Var2&& b) {
* @param b Second value.
* @return True if first variable's value is less than second value.
*/
template <typename Var, typename Arith, require_var_t<Var>...,
require_arithmetic_t<Arith>...>
template <typename Var, typename Arith, require_var_t<Var>..., require_arithmetic_t<Arith>...>
inline bool operator<(Var&& a, Arith b) {
return a.val() < b;
}
Expand All @@ -55,8 +54,7 @@ inline bool operator<(Var&& a, Arith b) {
* @param b Second variable.
* @return True if first value is less than second variable's value.
*/
template <typename Var, typename Arith, require_var_t<Var>...,
require_arithmetic_t<Arith>...>
template <typename Arith, typename Var, require_arithmetic_t<Arith>..., require_var_t<Var>...>
inline bool operator<(Arith a, Var&& b) {
return a < b.val();
}
Expand Down
6 changes: 2 additions & 4 deletions stan/math/rev/core/operator_less_than_or_equal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ inline bool operator<=(Var1&& a, Var2&& b) {
* @return True if first variable's value is less than or equal to
* the second value.
*/
template <typename Var, typename Arith, require_var_t<Var>...,
require_arithmetic_t<Arith>...>
template <typename Var, typename Arith, require_var_t<Var>..., require_arithmetic_t<Arith>...>
inline bool operator<=(Var&& a, Arith b) {
return a.val() <= b;
}
Expand All @@ -59,8 +58,7 @@ inline bool operator<=(Var&& a, Arith b) {
* @return True if first value is less than or equal to the second
* variable's value.
*/
template <typename Var, typename Arith, require_var_t<Var>...,
require_arithmetic_t<Arith>...>
template <typename Arith, typename Var, require_arithmetic_t<Arith>..., require_var_t<Var>...>
inline bool operator<=(Arith a, Var&& b) {
return a <= b.val();
}
Expand Down
6 changes: 2 additions & 4 deletions stan/math/rev/core/operator_logical_and.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ inline bool operator&&(Var1&& x, Var2&& y) {
* @return conjunction of first argument's value and second
* argument
*/
template <typename Var, typename Arith, require_var_t<Var>...,
require_arithmetic_t<Arith>...>
template <typename Var, typename Arith, require_var_t<Var>..., require_arithmetic_t<Arith>...>
inline bool operator&&(Var&& x, Arith y) {
return x.val() && y;
}
Expand All @@ -50,8 +49,7 @@ inline bool operator&&(Var&& x, Arith y) {
* @return conjunction of first argument and second argument's
* value
*/
template <typename Var, typename Arith, require_var_t<Var>...,
require_arithmetic_t<Arith>...>
template <typename Arith, typename Var, require_arithmetic_t<Arith>..., require_var_t<Var>...>
inline bool operator&&(Arith x, Var&& y) {
return x && y.val();
}
Expand Down
6 changes: 2 additions & 4 deletions stan/math/rev/core/operator_logical_or.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ inline bool operator||(Var1&& x, Var2&& y) {
* @return disjunction of first argument's value and second
* argument
*/
template <typename Var, typename Arith, require_var_t<Var>...,
require_arithmetic_t<Arith>...>
template <typename Var, typename Arith, require_var_t<Var>..., require_arithmetic_t<Arith>...>
inline bool operator||(Var&& x, Arith y) {
return x.val() || y;
}
Expand All @@ -50,8 +49,7 @@ inline bool operator||(Var&& x, Arith y) {
* @return disjunction of first argument and the second
* argument's value
*/
template <typename Var, typename Arith, require_var_t<Var>...,
require_arithmetic_t<Arith>...>
template <typename Arith, typename Var, require_arithmetic_t<Arith>..., require_var_t<Var>...>
inline bool operator||(Arith x, Var&& y) {
return x || y.val();
}
Expand Down
6 changes: 2 additions & 4 deletions stan/math/rev/core/operator_multiplication.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ inline var operator*(Var1&& a, Var2&& b) {
* @param b Scalar operand.
* @return Variable result of multiplying operands.
*/
template <typename Var, typename Arith, require_var_t<Var>...,
require_arithmetic_t<Arith>...>
template <typename Var, typename Arith, require_var_t<Var>..., require_arithmetic_t<Arith>...>
inline var operator*(Var&& a, Arith b) {
if (b == 1.0) {
return a;
Expand All @@ -119,8 +118,7 @@ inline var operator*(Var&& a, Arith b) {
* @param b Variable operand.
* @return Variable result of multiplying the operands.
*/
template <typename Var, typename Arith, require_var_t<Var>...,
require_arithmetic_t<Arith>...>
template <typename Arith, typename Var, require_arithmetic_t<Arith>..., require_var_t<Var>...>
inline var operator*(Arith a, Var&& b) {
if (a == 1.0) {
return b;
Expand Down
6 changes: 2 additions & 4 deletions stan/math/rev/core/operator_not_equal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ inline bool operator!=(Var1&& a, Var2&& b) {
* @return True if the first variable's value is not the same as the
* second value.
*/
template <typename Var, typename Arith, require_var_t<Var>...,
require_arithmetic_t<Arith>...>
template <typename Var, typename Arith, require_var_t<Var>..., require_arithmetic_t<Arith>...>
inline bool operator!=(Var&& a, Arith b) {
return a.val() != b;
}
Expand All @@ -59,8 +58,7 @@ inline bool operator!=(Var&& a, Arith b) {
* @return True if the first value is not the same as the
* second variable's value.
*/
template <typename Var, typename Arith, require_var_t<Var>...,
require_arithmetic_t<Arith>...>
template <typename Arith, typename Var, require_arithmetic_t<Arith>..., require_var_t<Var>...>
inline bool operator!=(Arith a, Var&& b) {
return a != b.val();
}
Expand Down
6 changes: 2 additions & 4 deletions stan/math/rev/core/operator_subtraction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,7 @@ inline var operator-(Var1&& a, Var2&& b) {
* @param b Second scalar operand.
* @return Result of subtracting the scalar from the variable.
*/
template <typename Var, typename Arith, require_var_t<Var>...,
require_arithmetic_t<Arith>...>
template <typename Var, typename Arith, require_var_t<Var>..., require_arithmetic_t<Arith>...>
inline var operator-(Var&& a, Arith b) {
if (b == 0.0) {
return a;
Expand All @@ -133,8 +132,7 @@ inline var operator-(Var&& a, Arith b) {
* @param b Second variable operand.
* @return Result of sutracting a variable from a scalar.
*/
template <typename Var, typename Arith, require_var_t<Var>...,
require_arithmetic_t<Arith>...>
template <typename Arith, typename Var, require_arithmetic_t<Arith>..., require_var_t<Var>...>
inline var operator-(Arith a, Var&& b) {
return var(new internal::subtract_dv_vari(a, b.vi_));
}
Expand Down
19 changes: 14 additions & 5 deletions stan/math/rev/core/precomputed_gradients.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,11 @@ class precomputed_gradients_vari : public vari {
* @throws std::invalid_argument if the sizes of the vectors
* don't match.
*/
precomputed_gradients_vari(double val, const std::vector<var>& vars,
const std::vector<double>& gradients)
template <typename Arith, typename VecVar, typename VecArith,
require_arithmetic_t<Arith>...,
require_vector_like_vt<is_var, VecVar>...,
require_vector_like_vt<std::is_arithmetic, VecArith>...>
precomputed_gradients_vari(Arith val, VecVar&& vars, VecArith&& gradients)
: vari(val),
size_(vars.size()),
varis_(ChainableStack::instance_->memalloc_.alloc_array<vari*>(
Expand Down Expand Up @@ -80,16 +83,22 @@ class precomputed_gradients_vari : public vari {
* specified value, vector of operands, and vector of partial
* derivatives of value with respect to the operands.
*
* @tparam Arith An arithmetic type
* @tparam VecVar A vector of vars
* @tparam VecArith A vector of
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops...

* @param[in] value The value of the resulting dependent variable.
* @param[in] operands operands.
* @param[in] gradients vector of partial derivatives of result with
* respect to operands.
* @return An auto-diff variable that uses the precomputed
* gradients provided.
*/
inline var precomputed_gradients(double value, const std::vector<var>& operands,
const std::vector<double>& gradients) {
return var(new precomputed_gradients_vari(value, operands, gradients));
template <typename Arith, typename VecVar, typename VecArith,
require_arithmetic_t<Arith>...,
require_vector_like_vt<is_var, VecVar>...,
require_vector_like_vt<std::is_arithmetic, VecArith>...>
inline auto precomputed_gradients(Arith value, VecVar&& operands, VecArith&& gradients) {
return var(new precomputed_gradients_vari(value, std::forward<VecVar>(operands), std::forward<VecArith>(gradients)));
}
} // namespace math
} // namespace stan
Expand Down
3 changes: 2 additions & 1 deletion stan/math/rev/core/std_isinf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ namespace std {
* @param a Argument.
* @return 1 if argument is infinite and 0 otherwise.
*/
inline int isinf(const stan::math::var& a) {
template <typename Var, stan::require_var_t<Var>...>
inline bool isinf(Var&& a) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can also be made fully generic in prim.

return stan::math::is_inf(a.val());
}

Expand Down
3 changes: 2 additions & 1 deletion stan/math/rev/core/std_isnan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ namespace std {
* @param a Variable to test.
* @return <code>true</code> if value is not a number.
*/
inline int isnan(const stan::math::var& a) { return isnan(a.val()); }
template <typename Var, stan::require_var_t<Var>...>
inline bool isnan(Var&& a) { return isnan(a.val()); }

} // namespace std
#endif
3 changes: 2 additions & 1 deletion stan/math/rev/core/vector_vari.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ class op_vector_vari : public vari {
vari** vis_;

public:
op_vector_vari(double f, const std::vector<var>& vs)
template <typename Arith, typename VecVar, require_arithmetic_t<Arith>..., require_vector_like_vt<is_var, VecVar>...>
op_vector_vari(Arith f, VecVar&& vs)
: vari(f), size_(vs.size()) {
vis_ = reinterpret_cast<vari**>(operator new(sizeof(vari*) * vs.size()));
for (size_t i = 0; i < vs.size(); ++i) {
Expand Down