Skip to content

Commit

Permalink
Merge pull request #1769 from bstatcomp/cl_kernel_generator_transpose
Browse files Browse the repository at this point in the history
Add transposition to kernel generator
  • Loading branch information
t4c1 authored Mar 23, 2020
2 parents 7690b41 + 9174755 commit d903537
Show file tree
Hide file tree
Showing 24 changed files with 475 additions and 210 deletions.
1 change: 0 additions & 1 deletion stan/math/opencl/cholesky_decompose.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include <stan/math/opencl/sub_block.hpp>
#include <stan/math/opencl/kernels/cholesky_decompose.hpp>
#include <stan/math/opencl/kernel_generator.hpp>
#include <stan/math/opencl/prim/transpose.hpp>
#include <stan/math/prim/meta.hpp>
#include <CL/cl2.hpp>
#include <algorithm>
Expand Down
1 change: 1 addition & 0 deletions stan/math/opencl/kernel_generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <stan/math/opencl/kernel_generator/select.hpp>
#include <stan/math/opencl/kernel_generator/rowwise_reduction.hpp>
#include <stan/math/opencl/kernel_generator/colwise_reduction.hpp>
#include <stan/math/opencl/kernel_generator/transpose.hpp>

#include <stan/math/opencl/kernel_generator/multi_result_kernel.hpp>
#include <stan/math/opencl/kernel_generator/get_kernel_source_for_evaluating_into.hpp>
Expand Down
24 changes: 12 additions & 12 deletions stan/math/opencl/kernel_generator/binary_operation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ class binary_operation : public operation_cl<Derived, T_res, T_a, T_b> {
* @return view
*/
inline matrix_cl_view view() const {
return either(std::get<0>(arguments_).view(),
std::get<1>(arguments_).view());
return either(this->template get_arg<0>().view(),
this->template get_arg<1>().view());
}
};

Expand Down Expand Up @@ -116,9 +116,9 @@ class binary_operation : public operation_cl<Derived, T_res, T_a, T_b> {
public: \
class_name(T_a&& a, T_b&& b) /* NOLINT */ \
: base(std::forward<T_a>(a), std::forward<T_b>(b), operation) {} \
inline auto deep_copy() { \
auto&& a_copy = std::get<0>(arguments_).deep_copy(); \
auto&& b_copy = std::get<1>(arguments_).deep_copy(); \
inline auto deep_copy() const { \
auto&& a_copy = this->template get_arg<0>().deep_copy(); \
auto&& b_copy = this->template get_arg<1>().deep_copy(); \
return class_name<std::remove_reference_t<decltype(a_copy)>, \
std::remove_reference_t<decltype(b_copy)>>( \
std::move(a_copy), std::move(b_copy)); \
Expand Down Expand Up @@ -163,9 +163,9 @@ class binary_operation : public operation_cl<Derived, T_res, T_a, T_b> {
public: \
class_name(T_a&& a, T_b&& b) /* NOLINT */ \
: base(std::forward<T_a>(a), std::forward<T_b>(b), operation) {} \
inline auto deep_copy() { \
auto&& a_copy = std::get<0>(arguments_).deep_copy(); \
auto&& b_copy = std::get<1>(arguments_).deep_copy(); \
inline auto deep_copy() const { \
auto&& a_copy = this->template get_arg<0>().deep_copy(); \
auto&& b_copy = this->template get_arg<1>().deep_copy(); \
return class_name<std::remove_reference_t<decltype(a_copy)>, \
std::remove_reference_t<decltype(b_copy)>>( \
std::move(a_copy), std::move(b_copy)); \
Expand All @@ -189,14 +189,14 @@ ADD_BINARY_OPERATION_WITH_CUSTOM_VIEW(
common_scalar_t<T_a COMMA T_b>, "*",
using base = binary_operation<elewise_multiplication_<T_a, T_b>,
common_scalar_t<T_a, T_b>, T_a, T_b>;
return both(std::get<0>(base::arguments_).view(),
std::get<1>(base::arguments_).view()););
return both(this->template get_arg<0>().view(),
this->template get_arg<1>().view()););
ADD_BINARY_OPERATION_WITH_CUSTOM_VIEW(
elewise_division_, elewise_division, common_scalar_t<T_a COMMA T_b>, "/",
using base = binary_operation<elewise_division_<T_a, T_b>,
common_scalar_t<T_a, T_b>, T_a, T_b>;
return either(std::get<0>(base::arguments_).view(),
invert(std::get<1>(base::arguments_).view())););
return either(this->template get_arg<0>().view(),
invert(this->template get_arg<1>().view())););
ADD_BINARY_OPERATION(less_than_, operator<, bool, "<");
ADD_BINARY_OPERATION_WITH_CUSTOM_VIEW(less_than_or_equal_, operator<=, bool,
"<=", return matrix_cl_view::Entire);
Expand Down
25 changes: 16 additions & 9 deletions stan/math/opencl/kernel_generator/block.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ class block_

protected:
int start_row_, start_col_, rows_, cols_;
using base::arguments_;

public:
/**
Expand All @@ -61,8 +60,8 @@ class block_
* Creates a deep copy of this expression.
* @return copy of \c *this
*/
inline auto deep_copy() {
auto&& arg_copy = std::get<0>(arguments_).deep_copy();
inline auto deep_copy() const {
auto&& arg_copy = this->template get_arg<0>().deep_copy();
return block_<std::remove_reference_t<decltype(arg_copy)>>{
std::move(arg_copy), start_row_, start_col_, rows_, cols_};
}
Expand Down Expand Up @@ -122,7 +121,7 @@ class block_
cl::Kernel& kernel, int& arg_num) const {
if (generated.count(this) == 0) {
generated.insert(this);
std::get<0>(arguments_).set_args(generated, kernel, arg_num);
this->template get_arg<0>().set_args(generated, kernel, arg_num);
kernel.setArg(arg_num++, start_row_);
kernel.setArg(arg_num++, start_col_);
}
Expand Down Expand Up @@ -175,9 +174,9 @@ class block_
inline void set_view(int bottom_diagonal, int top_diagonal,
int bottom_zero_diagonal, int top_zero_diagonal) const {
int change = start_col_ - start_row_;
std::get<0>(arguments_)
.set_view(bottom_diagonal + change, top_diagonal + change,
bottom_zero_diagonal + change, top_zero_diagonal + change);
this->template get_arg<0>().set_view(
bottom_diagonal + change, top_diagonal + change,
bottom_zero_diagonal + change, top_zero_diagonal + change);
}

/**
Expand All @@ -186,7 +185,7 @@ class block_
*/
inline int bottom_diagonal() const {
return std::max(
std::get<0>(arguments_).bottom_diagonal() - start_col_ + start_row_,
this->template get_arg<0>().bottom_diagonal() - start_col_ + start_row_,
1 - rows_);
}

Expand All @@ -196,7 +195,7 @@ class block_
*/
inline int top_diagonal() const {
return std::min(
std::get<0>(arguments_).top_diagonal() - start_col_ + start_row_,
this->template get_arg<0>().top_diagonal() - start_col_ + start_row_,
cols_ - 1);
}

Expand Down Expand Up @@ -234,6 +233,14 @@ class block_

/**
* Block of a kernel generator expression.
*
* Block operation modifies how its argument is indexed. If a matrix is both an
* argument and result of such an operation (such as in <code> block(a, row1,
* col1, rows, cols) = block(a, row2, col2, rows, cols);
* </code>), the result can be wrong due to aliasing. In such case the
* expression should be evaluating in a temporary by doing <code> block(a, row1,
* col1, rows, cols) = block(a, row2, col2, rows, cols).eval();</code>. This is
* not necessary if the bolcks do not overlap or if they are the same block.
* @tparam T type of argument
* @param a input argument
* @param start_row first row of block
Expand Down
14 changes: 6 additions & 8 deletions stan/math/opencl/kernel_generator/calc_if.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@ class calc_if_
using base = operation_cl<calc_if_<Do_Calculate, T>, Scalar, T>;
using base::var_name;

protected:
using base::arguments_;

public:
/**
* Constructor
* @param a expression to calc_if
Expand Down Expand Up @@ -66,8 +62,8 @@ class calc_if_
const std::string& i, const std::string& j,
const T_result& result) const {
if (Do_Calculate) {
return std::get<0>(arguments_)
.get_whole_kernel_parts(generated, ng, i, j, result);
return this->template get_arg<0>().get_whole_kernel_parts(generated, ng,
i, j, result);
} else {
return {};
}
Expand All @@ -84,15 +80,17 @@ class calc_if_
inline void set_args(std::set<const operation_cl_base*>& generated,
cl::Kernel& kernel, int& arg_num) const {
if (Do_Calculate) {
std::get<0>(arguments_).set_args(generated, kernel, arg_num);
this->template get_arg<0>().set_args(generated, kernel, arg_num);
}
}

/**
* View of a matrix that would be the result of evaluating this expression.
* @return view
*/
inline matrix_cl_view view() const { return std::get<0>(arguments_).view(); }
inline matrix_cl_view view() const {
return this->template get_arg<0>().view();
}
};

template <bool Do_Calculate, typename T,
Expand Down
17 changes: 8 additions & 9 deletions stan/math/opencl/kernel_generator/colwise_reduction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ class colwise_reduction

protected:
std::string init_;
using base::arguments_;
using base::derived;

public:
Expand Down Expand Up @@ -121,15 +120,15 @@ class colwise_reduction
inline int rows() const {
int local_rows = opencl_context.base_opts().at("LOCAL_SIZE_");
int wgs_rows
= (std::get<0>(arguments_).rows() + local_rows - 1) / local_rows;
= (this->template get_arg<0>().rows() + local_rows - 1) / local_rows;
return wgs_rows;
}

/**
* Number of rows threads need to be launched for.
* @return number of rows
*/
inline int thread_rows() const { return std::get<0>(arguments_).rows(); }
inline int thread_rows() const { return this->template get_arg<0>().rows(); }

/**
* View of a matrix that would be the result of evaluating this expression.
Expand Down Expand Up @@ -161,8 +160,8 @@ class colwise_sum_ : public colwise_reduction<colwise_sum_<T>, T, sum_op> {
* Creates a deep copy of this expression.
* @return copy of \c *this
*/
inline auto deep_copy() {
auto&& arg_copy = std::get<0>(arguments_).deep_copy();
inline auto deep_copy() const {
auto&& arg_copy = this->template get_arg<0>().deep_copy();
return colwise_sum_<std::remove_reference_t<decltype(arg_copy)>>(
std::move(arg_copy));
}
Expand Down Expand Up @@ -209,8 +208,8 @@ class colwise_max_ : public colwise_reduction<
* Creates a deep copy of this expression.
* @return copy of \c *this
*/
inline auto deep_copy() {
auto&& arg_copy = std::get<0>(arguments_).deep_copy();
inline auto deep_copy() const {
auto&& arg_copy = this->template get_arg<0>().deep_copy();
return colwise_max_<std::remove_reference_t<decltype(arg_copy)>>(
std::move(arg_copy));
}
Expand Down Expand Up @@ -257,8 +256,8 @@ class colwise_min_ : public colwise_reduction<
* Creates a deep copy of this expression.
* @return copy of \c *this
*/
inline auto deep_copy() {
auto&& arg_copy = std::get<0>(arguments_).deep_copy();
inline auto deep_copy() const {
auto&& arg_copy = this->template get_arg<0>().deep_copy();
return colwise_min_<std::remove_reference_t<decltype(arg_copy)>>(
std::move(arg_copy));
}
Expand Down
3 changes: 2 additions & 1 deletion stan/math/opencl/kernel_generator/load.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ class load_
* Creates a deep copy of this expression.
* @return copy of \c *this
*/
inline load_<T&> deep_copy() { return load_<T&>(a_); }
inline load_<T&> deep_copy() const& { return load_<T&>(a_); }
inline load_<T> deep_copy() && { return load_<T>(std::forward<T>(a_)); }

/**
* generates kernel code for this expression.
Expand Down
16 changes: 1 addition & 15 deletions stan/math/opencl/kernel_generator/multi_result_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define STAN_MATH_OPENCL_KERNEL_GENERATOR_MULTI_RESULT_KERNEL_HPP
#ifdef STAN_OPENCL

#include <stan/math/opencl/kernel_generator/wrapper.hpp>
#include <stan/math/opencl/kernel_generator/is_valid_expression.hpp>
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
Expand All @@ -18,21 +19,6 @@ namespace math {

namespace internal {

/**
* A wrapper for references. This is used to wrap references when putting them
* in tuples.
*/
template <typename T>
struct wrapper {
T x;
explicit wrapper(T&& x) : x(std::forward<T>(x)) {}
};

template <typename T>
wrapper<T> make_wrapper(T&& x) {
return wrapper<T>(std::forward<T>(x));
}

// Template parameter pack can only be at the end of the template list in
// structs. We need 2 packs for expressions and results, so we nest structs.
template <int n, typename... T_results>
Expand Down
Loading

0 comments on commit d903537

Please sign in to comment.