Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add row vector, array and int_array construction utilities #1636

Merged
14 changes: 12 additions & 2 deletions stan/math/prim/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
#include <stan/math/prim/fun/columns_dot_product.hpp>
#include <stan/math/prim/fun/columns_dot_self.hpp>
#include <stan/math/prim/fun/common_type.hpp>
#include <stan/math/prim/fun/constant_vector.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/corr_constrain.hpp>
#include <stan/math/prim/fun/corr_free.hpp>
Expand Down Expand Up @@ -198,7 +197,13 @@
#include <stan/math/prim/fun/num_elements.hpp>
#include <stan/math/prim/fun/offset_multiplier_constrain.hpp>
#include <stan/math/prim/fun/offset_multiplier_free.hpp>
#include <stan/math/prim/fun/one_hot_array.hpp>
#include <stan/math/prim/fun/one_hot_int_array.hpp>
#include <stan/math/prim/fun/one_hot_row_vector.hpp>
#include <stan/math/prim/fun/one_hot_vector.hpp>
#include <stan/math/prim/fun/ones_array.hpp>
#include <stan/math/prim/fun/ones_int_array.hpp>
#include <stan/math/prim/fun/ones_row_vector.hpp>
#include <stan/math/prim/fun/ones_vector.hpp>
#include <stan/math/prim/fun/ordered_constrain.hpp>
#include <stan/math/prim/fun/ordered_free.hpp>
Expand Down Expand Up @@ -243,7 +248,6 @@
#include <stan/math/prim/fun/scaled_add.hpp>
#include <stan/math/prim/fun/sd.hpp>
#include <stan/math/prim/fun/segment.hpp>
#include <stan/math/prim/fun/set_spaced_vector.hpp>
#include <stan/math/prim/fun/sign.hpp>
#include <stan/math/prim/fun/simplex_constrain.hpp>
#include <stan/math/prim/fun/simplex_free.hpp>
Expand All @@ -257,6 +261,9 @@
#include <stan/math/prim/fun/sort_indices.hpp>
#include <stan/math/prim/fun/sort_indices_asc.hpp>
#include <stan/math/prim/fun/sort_indices_desc.hpp>
#include <stan/math/prim/fun/spaced_array.hpp>
#include <stan/math/prim/fun/spaced_row_vector.hpp>
#include <stan/math/prim/fun/spaced_vector.hpp>
#include <stan/math/prim/fun/sqrt.hpp>
#include <stan/math/prim/fun/square.hpp>
#include <stan/math/prim/fun/squared_distance.hpp>
Expand Down Expand Up @@ -295,6 +302,9 @@
#include <stan/math/prim/fun/variance.hpp>
#include <stan/math/prim/fun/welford_covar_estimator.hpp>
#include <stan/math/prim/fun/welford_var_estimator.hpp>
#include <stan/math/prim/fun/zeros_array.hpp>
#include <stan/math/prim/fun/zeros_int_array.hpp>
#include <stan/math/prim/fun/zeros_row_vector.hpp>
#include <stan/math/prim/fun/zeros_vector.hpp>

#endif
27 changes: 0 additions & 27 deletions stan/math/prim/fun/constant_vector.hpp

This file was deleted.

33 changes: 33 additions & 0 deletions stan/math/prim/fun/one_hot_array.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#ifndef STAN_MATH_PRIM_FUN_ONE_HOT_ARRAY_HPP
#define STAN_MATH_PRIM_FUN_ONE_HOT_ARRAY_HPP

#include <stan/math/prim/err.hpp>
#include <vector>

namespace stan {
namespace math {

/**
* Return an array with 1 in the k-th position and zero elsewhere.
*
* @param K size of the array
* @param k position of the 1 (indexing from 1)
* @return An array of size K with all elements initialised to zero
* and a 1 in the k-th position.
* @throw std::domain_error if K is not positive, or if k is less than 1 or
* greater than K.
*/
inline std::vector<double> one_hot_array(int K, int k) {
static const char* function = "one_hot_array";
check_positive(function, "size", K);
check_bounded(function, "k", k, 1, K);

std::vector<double> v(K, 0);
v[k - 1] = 1;
return v;
}

} // namespace math
} // namespace stan

#endif
33 changes: 33 additions & 0 deletions stan/math/prim/fun/one_hot_int_array.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#ifndef STAN_MATH_PRIM_FUN_ONE_HOT_INT_ARRAY_HPP
#define STAN_MATH_PRIM_FUN_ONE_HOT_INT_ARRAY_HPP

#include <stan/math/prim/err.hpp>
#include <vector>

namespace stan {
namespace math {

/**
* Return an integer array with 1 in the k-th position and zero elsewhere.
*
* @param K size of the array
* @param k position of the 1 (indexing from 1)
* @return An integer array of size K with all elements initialised to zero
* and a 1 in the k-th position.
* @throw std::domain_error if K is not positive, or if k is less than 1 or
* greater than K.
*/
inline std::vector<int> one_hot_int_array(int K, int k) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[optional]
Probably not worth refactoring, but a single implementation would work here. I realize there'd need to be two fronting functions to work easily with the Stan language, but the core implementation is shared. They'd probably need an enable-if to ensure the type is arithmetic.

That's true of most of these functions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll pass on this for now.

static const char* function = "one_hot_int_array";
check_positive(function, "size", K);
check_bounded(function, "k", k, 1, K);

std::vector<int> v(K, 0);
v[k - 1] = 1;
return v;
}

} // namespace math
} // namespace stan

#endif
33 changes: 33 additions & 0 deletions stan/math/prim/fun/one_hot_row_vector.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#ifndef STAN_MATH_PRIM_FUN_ONE_HOT_ROW_VECTOR_HPP
#define STAN_MATH_PRIM_FUN_ONE_HOT_ROW_VECTOR_HPP

#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/Eigen.hpp>

namespace stan {
namespace math {

/**
* Return a row vector with 1 in the k-th position and zero elsewhere
mcol marked this conversation as resolved.
Show resolved Hide resolved
*
* @param K size of the row vector
* @param k position of the 1 (indexing from 1)
* @return A row vector of size K with all elements initialised to zero
* and a 1 in the k-th position.
* @throw std::domain_error if K is not positive, or if k is less than 1 or
* greater than K.
*/
inline Eigen::RowVectorXd one_hot_row_vector(int K, int k) {
static const char* function = "one_hot_row_vector";
check_positive(function, "size", K);
check_bounded(function, "k", k, 1, K);

Eigen::RowVectorXd ret = Eigen::RowVectorXd::Zero(K);
ret(k - 1) = 1;
return ret;
}

} // namespace math
} // namespace stan

#endif
25 changes: 25 additions & 0 deletions stan/math/prim/fun/ones_array.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#ifndef STAN_MATH_PRIM_FUN_ONES_ARRAY_HPP
#define STAN_MATH_PRIM_FUN_ONES_ARRAY_HPP

#include <stan/math/prim/err.hpp>
#include <vector>

namespace stan {
namespace math {

/**
* Return an array of ones.
*
* @param K size of the array
* @return An array of size K with all elements initialised to 1.
* @throw std::domain_error if K is negative.
*/
inline std::vector<double> ones_array(int K) {
check_nonnegative("ones_array", "size", K);
return std::vector<double>(K, 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can just be return {K, 1}; You can always return curly braces around constructor args for the return type.

And same for next definition.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't work:

./stan/math/prim/fun/ones_array.hpp:19:11: error: non-constant-expression cannot be
      narrowed from type 'int' to 'double' in initializer list [-Wc++11-narrowing]
  return {K, 1};
          ^
./stan/math/prim/fun/ones_array.hpp:19:11: note: insert an explicit cast to silence
      this issue
  return {K, 1};
          ^
          static_cast<double>( )

If I use the static cast around K, it builds a vector of two elements.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for trying. I've never seen that error and it's suprrising in that I thought you could always replace return-type constructor calls with braces. It clearly requires a stricter match.

Copy link
Contributor

@bob-carpenter bob-carpenter Jan 30, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does look like declaring const int K might also work, but I'd rather not sprinkle consts around.

}

} // namespace math
} // namespace stan

#endif
25 changes: 25 additions & 0 deletions stan/math/prim/fun/ones_int_array.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#ifndef STAN_MATH_PRIM_FUN_ONES_INT_ARRAY_HPP
#define STAN_MATH_PRIM_FUN_ONES_INT_ARRAY_HPP

#include <stan/math/prim/err.hpp>
#include <vector>

namespace stan {
namespace math {

/**
* Return an integer array of ones.
*
* @param K size of the array
* @return An integer array of size K with all elements initialised to 1.
* @throw std::domain_error if K is negative.
*/
inline std::vector<int> ones_int_array(int K) {
check_nonnegative("ones_int_array", "size", K);
return std::vector<int>(K, 1);
}

} // namespace math
} // namespace stan

#endif
25 changes: 25 additions & 0 deletions stan/math/prim/fun/ones_row_vector.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#ifndef STAN_MATH_PRIM_FUN_ONES_ROW_VECTOR_HPP
#define STAN_MATH_PRIM_FUN_ONES_ROW_VECTOR_HPP

#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/Eigen.hpp>

namespace stan {
namespace math {

/**
* Return a row vector of ones
*
* @param K size of the row vector
* @return A row vector of size K with all elements initialised to 1.
* @throw std::domain_error if K is negative.
*/
inline Eigen::RowVectorXd ones_row_vector(int K) {
check_nonnegative("ones_row_vector", "size", K);
return Eigen::RowVectorXd::Constant(K, 1);
}

} // namespace math
} // namespace stan

#endif
44 changes: 44 additions & 0 deletions stan/math/prim/fun/spaced_array.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#ifndef STAN_MATH_PRIM_FUN_SPACED_ARRAY_HPP
#define STAN_MATH_PRIM_FUN_SPACED_ARRAY_HPP

#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <vector>

namespace stan {
namespace math {

/**
* Return an array of linearly spaced elements.
*
* This produces an array from low to high (included) with elements spaced
* as (high - low) / (K - 1). For K=1, the array will contain the high value;
* for K=0 it returns an empty array.
*
* @param K size of the array
* @param low smallest value
* @param high largest value
* @return An array of size K with elements linearly spaced between
* low and high.
* @throw std::domain_error if K is negative, if low is nan or infinite,
* if high is nan or infinite, or if high is less than low.
*/
inline std::vector<double> spaced_array(int K, double low, double high) {
static const char* function = "spaced_array";
check_nonnegative(function, "size", K);
check_finite(function, "low", low);
check_finite(function, "high", high);
check_greater_or_equal(function, "high", high, low);

if (K == 0) {
mcol marked this conversation as resolved.
Show resolved Hide resolved
return {};
}

Eigen::VectorXd v = Eigen::VectorXd::LinSpaced(K, low, high);
return {&v[0], &v[0] + K};
}

} // namespace math
} // namespace stan

#endif
38 changes: 38 additions & 0 deletions stan/math/prim/fun/spaced_row_vector.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#ifndef STAN_MATH_PRIM_FUN_SPACED_ROW_VECTOR_HPP
#define STAN_MATH_PRIM_FUN_SPACED_ROW_VECTOR_HPP

#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/Eigen.hpp>

namespace stan {
namespace math {

/**
* Return a row vector of linearly spaced elements.
*
* This produces a row vector from low to high (included) with elements spaced
* as (high - low) / (K - 1). For K=1, the vector will contain the high value;
* for K=0 it returns an empty vector.
*
* @param K size of the row vector
* @param low smallest value
* @param high largest value
* @return A row vector of size K with elements linearly spaced between
* low and high.
* @throw std::domain_error if K is negative, if low is nan or infinite,
* if high is nan or infinite, or if high is less than low.
*/
inline Eigen::RowVectorXd spaced_row_vector(int K, double low, double high) {
static const char* function = "spaced_row_vector";
check_nonnegative(function, "size", K);
check_finite(function, "low", low);
check_finite(function, "high", high);
check_greater_or_equal(function, "high", high, low);

return Eigen::RowVectorXd::LinSpaced(K, low, high);
}

} // namespace math
} // namespace stan

#endif
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef STAN_MATH_PRIM_FUN_SET_SPACED_VECTOR_HPP
#define STAN_MATH_PRIM_FUN_SET_SPACED_VECTOR_HPP
#ifndef STAN_MATH_PRIM_FUN_SPACED_VECTOR_HPP
#define STAN_MATH_PRIM_FUN_SPACED_VECTOR_HPP

#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
Expand All @@ -8,7 +8,7 @@ namespace stan {
namespace math {

/**
* Return a vector of linearly spaced elements
* Return a vector of linearly spaced elements.
*
* This produces a vector from low to high (included) with elements spaced
* as (high - low) / (K - 1). For K=1, the vector will contain the high value;
Expand All @@ -22,8 +22,8 @@ namespace math {
* @throw std::domain_error if K is negative, if low is nan or infinite,
* if high is nan or infinite, or if high is less than low.
*/
inline Eigen::VectorXd set_spaced_vector(int K, double low, double high) {
static const char* function = "set_spaced_vector";
inline Eigen::VectorXd spaced_vector(int K, double low, double high) {
static const char* function = "spaced_vector";
check_nonnegative(function, "size", K);
check_finite(function, "low", low);
check_finite(function, "high", high);
Expand Down
25 changes: 25 additions & 0 deletions stan/math/prim/fun/zeros_array.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#ifndef STAN_MATH_PRIM_FUN_ZEROS_ARRAY_HPP
#define STAN_MATH_PRIM_FUN_ZEROS_ARRAY_HPP

#include <stan/math/prim/err.hpp>
#include <vector>

namespace stan {
namespace math {

/**
* Return an array of zeros.
*
* @param K size of the array
* @return an array of size K with all elements initialised to 0.
* @throw std::domain_error if K is negative.
*/
inline std::vector<double> zeros_array(int K) {
check_nonnegative("zeros_array", "size", K);
return std::vector<double>(K, 0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should use braces constructor return.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as in the ones case.

}

} // namespace math
} // namespace stan

#endif
Loading