-
-
Notifications
You must be signed in to change notification settings - Fork 187
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add row vector, array and int_array construction utilities #1636
Changes from all commits
f5b2efb
0531583
8c5e293
57af538
578eb6c
fd72020
26687de
e2eeb51
0777321
6a74d60
439300c
4510e06
fd882bf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
#ifndef STAN_MATH_PRIM_FUN_ONE_HOT_ARRAY_HPP | ||
#define STAN_MATH_PRIM_FUN_ONE_HOT_ARRAY_HPP | ||
|
||
#include <stan/math/prim/err.hpp> | ||
#include <vector> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
/** | ||
* Return an array with 1 in the k-th position and zero elsewhere. | ||
* | ||
* @param K size of the array | ||
* @param k position of the 1 (indexing from 1) | ||
* @return An array of size K with all elements initialised to zero | ||
* and a 1 in the k-th position. | ||
* @throw std::domain_error if K is not positive, or if k is less than 1 or | ||
* greater than K. | ||
*/ | ||
inline std::vector<double> one_hot_array(int K, int k) { | ||
static const char* function = "one_hot_array"; | ||
check_positive(function, "size", K); | ||
check_bounded(function, "k", k, 1, K); | ||
|
||
std::vector<double> v(K, 0); | ||
v[k - 1] = 1; | ||
return v; | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
|
||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
#ifndef STAN_MATH_PRIM_FUN_ONE_HOT_INT_ARRAY_HPP | ||
#define STAN_MATH_PRIM_FUN_ONE_HOT_INT_ARRAY_HPP | ||
|
||
#include <stan/math/prim/err.hpp> | ||
#include <vector> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
/** | ||
* Return an integer array with 1 in the k-th position and zero elsewhere. | ||
* | ||
* @param K size of the array | ||
* @param k position of the 1 (indexing from 1) | ||
* @return An integer array of size K with all elements initialised to zero | ||
* and a 1 in the k-th position. | ||
* @throw std::domain_error if K is not positive, or if k is less than 1 or | ||
* greater than K. | ||
*/ | ||
inline std::vector<int> one_hot_int_array(int K, int k) { | ||
static const char* function = "one_hot_int_array"; | ||
check_positive(function, "size", K); | ||
check_bounded(function, "k", k, 1, K); | ||
|
||
std::vector<int> v(K, 0); | ||
v[k - 1] = 1; | ||
return v; | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
|
||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
#ifndef STAN_MATH_PRIM_FUN_ONE_HOT_ROW_VECTOR_HPP | ||
#define STAN_MATH_PRIM_FUN_ONE_HOT_ROW_VECTOR_HPP | ||
|
||
#include <stan/math/prim/err.hpp> | ||
#include <stan/math/prim/fun/Eigen.hpp> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
/** | ||
* Return a row vector with 1 in the k-th position and zero elsewhere | ||
mcol marked this conversation as resolved.
Show resolved
Hide resolved
|
||
* | ||
* @param K size of the row vector | ||
* @param k position of the 1 (indexing from 1) | ||
* @return A row vector of size K with all elements initialised to zero | ||
* and a 1 in the k-th position. | ||
* @throw std::domain_error if K is not positive, or if k is less than 1 or | ||
* greater than K. | ||
*/ | ||
inline Eigen::RowVectorXd one_hot_row_vector(int K, int k) { | ||
static const char* function = "one_hot_row_vector"; | ||
check_positive(function, "size", K); | ||
check_bounded(function, "k", k, 1, K); | ||
|
||
Eigen::RowVectorXd ret = Eigen::RowVectorXd::Zero(K); | ||
ret(k - 1) = 1; | ||
return ret; | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
|
||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#ifndef STAN_MATH_PRIM_FUN_ONES_ARRAY_HPP | ||
#define STAN_MATH_PRIM_FUN_ONES_ARRAY_HPP | ||
|
||
#include <stan/math/prim/err.hpp> | ||
#include <vector> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
/** | ||
* Return an array of ones. | ||
* | ||
* @param K size of the array | ||
* @return An array of size K with all elements initialised to 1. | ||
* @throw std::domain_error if K is negative. | ||
*/ | ||
inline std::vector<double> ones_array(int K) { | ||
check_nonnegative("ones_array", "size", K); | ||
return std::vector<double>(K, 1); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This can just be And same for next definition. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This doesn't work:
If I use the static cast around There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for trying. I've never seen that error and it's suprrising in that I thought you could always replace return-type constructor calls with braces. It clearly requires a stricter match. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It does look like declaring |
||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
|
||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#ifndef STAN_MATH_PRIM_FUN_ONES_INT_ARRAY_HPP | ||
#define STAN_MATH_PRIM_FUN_ONES_INT_ARRAY_HPP | ||
|
||
#include <stan/math/prim/err.hpp> | ||
#include <vector> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
/** | ||
* Return an integer array of ones. | ||
* | ||
* @param K size of the array | ||
* @return An integer array of size K with all elements initialised to 1. | ||
* @throw std::domain_error if K is negative. | ||
*/ | ||
inline std::vector<int> ones_int_array(int K) { | ||
check_nonnegative("ones_int_array", "size", K); | ||
return std::vector<int>(K, 1); | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
|
||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#ifndef STAN_MATH_PRIM_FUN_ONES_ROW_VECTOR_HPP | ||
#define STAN_MATH_PRIM_FUN_ONES_ROW_VECTOR_HPP | ||
|
||
#include <stan/math/prim/err.hpp> | ||
#include <stan/math/prim/fun/Eigen.hpp> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
/** | ||
* Return a row vector of ones | ||
* | ||
* @param K size of the row vector | ||
* @return A row vector of size K with all elements initialised to 1. | ||
* @throw std::domain_error if K is negative. | ||
*/ | ||
inline Eigen::RowVectorXd ones_row_vector(int K) { | ||
check_nonnegative("ones_row_vector", "size", K); | ||
return Eigen::RowVectorXd::Constant(K, 1); | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
|
||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
#ifndef STAN_MATH_PRIM_FUN_SPACED_ARRAY_HPP | ||
#define STAN_MATH_PRIM_FUN_SPACED_ARRAY_HPP | ||
|
||
#include <stan/math/prim/err.hpp> | ||
#include <stan/math/prim/fun/Eigen.hpp> | ||
#include <vector> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
/** | ||
* Return an array of linearly spaced elements. | ||
* | ||
* This produces an array from low to high (included) with elements spaced | ||
* as (high - low) / (K - 1). For K=1, the array will contain the high value; | ||
* for K=0 it returns an empty array. | ||
* | ||
* @param K size of the array | ||
* @param low smallest value | ||
* @param high largest value | ||
* @return An array of size K with elements linearly spaced between | ||
* low and high. | ||
* @throw std::domain_error if K is negative, if low is nan or infinite, | ||
* if high is nan or infinite, or if high is less than low. | ||
*/ | ||
inline std::vector<double> spaced_array(int K, double low, double high) { | ||
static const char* function = "spaced_array"; | ||
check_nonnegative(function, "size", K); | ||
check_finite(function, "low", low); | ||
check_finite(function, "high", high); | ||
check_greater_or_equal(function, "high", high, low); | ||
|
||
if (K == 0) { | ||
mcol marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return {}; | ||
} | ||
|
||
Eigen::VectorXd v = Eigen::VectorXd::LinSpaced(K, low, high); | ||
return {&v[0], &v[0] + K}; | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
|
||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
#ifndef STAN_MATH_PRIM_FUN_SPACED_ROW_VECTOR_HPP | ||
#define STAN_MATH_PRIM_FUN_SPACED_ROW_VECTOR_HPP | ||
|
||
#include <stan/math/prim/err.hpp> | ||
#include <stan/math/prim/fun/Eigen.hpp> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
/** | ||
* Return a row vector of linearly spaced elements. | ||
* | ||
* This produces a row vector from low to high (included) with elements spaced | ||
* as (high - low) / (K - 1). For K=1, the vector will contain the high value; | ||
* for K=0 it returns an empty vector. | ||
* | ||
* @param K size of the row vector | ||
* @param low smallest value | ||
* @param high largest value | ||
* @return A row vector of size K with elements linearly spaced between | ||
* low and high. | ||
* @throw std::domain_error if K is negative, if low is nan or infinite, | ||
* if high is nan or infinite, or if high is less than low. | ||
*/ | ||
inline Eigen::RowVectorXd spaced_row_vector(int K, double low, double high) { | ||
static const char* function = "spaced_row_vector"; | ||
check_nonnegative(function, "size", K); | ||
check_finite(function, "low", low); | ||
check_finite(function, "high", high); | ||
check_greater_or_equal(function, "high", high, low); | ||
|
||
return Eigen::RowVectorXd::LinSpaced(K, low, high); | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
|
||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#ifndef STAN_MATH_PRIM_FUN_ZEROS_ARRAY_HPP | ||
#define STAN_MATH_PRIM_FUN_ZEROS_ARRAY_HPP | ||
|
||
#include <stan/math/prim/err.hpp> | ||
#include <vector> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
/** | ||
* Return an array of zeros. | ||
* | ||
* @param K size of the array | ||
* @return an array of size K with all elements initialised to 0. | ||
* @throw std::domain_error if K is negative. | ||
*/ | ||
inline std::vector<double> zeros_array(int K) { | ||
check_nonnegative("zeros_array", "size", K); | ||
return std::vector<double>(K, 0); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should use braces constructor return. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as in the ones case. |
||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
|
||
#endif |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[optional]
Probably not worth refactoring, but a single implementation would work here. I realize there'd need to be two fronting functions to work easily with the Stan language, but the core implementation is shared. They'd probably need an enable-if to ensure the type is arithmetic.
That's true of most of these functions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll pass on this for now.