Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add transposition to kernel generator #1769

Merged
merged 22 commits into from
Mar 23, 2020
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
b3c5a6b
added transpose to kernel_generator
t4c1 Mar 10, 2020
969c7dc
removed existing transpose kernel
t4c1 Mar 10, 2020
b3196f5
fix cpplint
t4c1 Mar 10, 2020
f931e1d
replaced transpose with another kernel in kernel_cl_test.cpp
t4c1 Mar 10, 2020
7296a2e
[Jenkins] auto-formatting by clang-format version 6.0.0 (tags/google/…
stan-buildbot Mar 10, 2020
f101f37
Merge branch 'develop' into cl_kernel_generator_transpose
t4c1 Mar 11, 2020
94a4b11
addressed review comments
t4c1 Mar 13, 2020
a65bb77
Merge branch 'develop' into cl_kernel_generator_transpose
t4c1 Mar 13, 2020
987e3d5
Fixed the bug and enabled operation_cl to store arguments by reference
t4c1 Mar 17, 2020
a994e42
Merge commit 'a0e4ba0290262f0b5ce962648cd3aa50ae61d08b' into HEAD
yashikno Mar 17, 2020
a72d379
[Jenkins] auto-formatting by clang-format version 6.0.0 (tags/google/…
stan-buildbot Mar 17, 2020
5b425f8
fix headers check
t4c1 Mar 17, 2020
f90ab06
fix view calculation of transpose
t4c1 Mar 18, 2020
92f0b01
Merge commit '3a66a331aecce071ebc402dcaa3213cff075feec' into HEAD
yashikno Mar 18, 2020
4a0c836
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot Mar 18, 2020
60896bc
fix cpplint
t4c1 Mar 18, 2020
57712ee
fixed aliasing in tri_inverse
t4c1 Mar 19, 2020
2e8796d
Merge commit 'f3cbe214b3da41316694673f715a434e27a9e6d0' into HEAD
yashikno Mar 19, 2020
805a293
[Jenkins] auto-formatting by clang-format version 6.0.0 (tags/google/…
stan-buildbot Mar 19, 2020
bf60476
fixed transpose test
t4c1 Mar 19, 2020
f1cbd9e
Fixed usings in test
t4c1 Mar 20, 2020
9174755
added docs about aliasing
t4c1 Mar 23, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion stan/math/opencl/cholesky_decompose.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include <stan/math/opencl/sub_block.hpp>
#include <stan/math/opencl/kernels/cholesky_decompose.hpp>
#include <stan/math/opencl/kernel_generator.hpp>
#include <stan/math/opencl/prim/transpose.hpp>
#include <stan/math/prim/meta.hpp>
#include <CL/cl2.hpp>
#include <algorithm>
Expand Down
1 change: 1 addition & 0 deletions stan/math/opencl/kernel_generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <stan/math/opencl/kernel_generator/select.hpp>
#include <stan/math/opencl/kernel_generator/rowwise_reduction.hpp>
#include <stan/math/opencl/kernel_generator/colwise_reduction.hpp>
#include <stan/math/opencl/kernel_generator/transpose.hpp>

#include <stan/math/opencl/kernel_generator/multi_result_kernel.hpp>
#include <stan/math/opencl/kernel_generator/get_kernel_source_for_evaluating_into.hpp>
Expand Down
24 changes: 12 additions & 12 deletions stan/math/opencl/kernel_generator/binary_operation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ class binary_operation : public operation_cl<Derived, T_res, T_a, T_b> {
* @return view
*/
inline matrix_cl_view view() const {
return either(std::get<0>(arguments_).view(),
std::get<1>(arguments_).view());
return either(this->template get_arg<0>().view(),
this->template get_arg<1>().view());
}
};

Expand Down Expand Up @@ -116,9 +116,9 @@ class binary_operation : public operation_cl<Derived, T_res, T_a, T_b> {
public: \
class_name(T_a&& a, T_b&& b) /* NOLINT */ \
: base(std::forward<T_a>(a), std::forward<T_b>(b), operation) {} \
inline auto deep_copy() { \
auto&& a_copy = std::get<0>(arguments_).deep_copy(); \
auto&& b_copy = std::get<1>(arguments_).deep_copy(); \
inline auto deep_copy() const { \
auto&& a_copy = this->template get_arg<0>().deep_copy(); \
auto&& b_copy = this->template get_arg<1>().deep_copy(); \
return class_name<std::remove_reference_t<decltype(a_copy)>, \
std::remove_reference_t<decltype(b_copy)>>( \
std::move(a_copy), std::move(b_copy)); \
Expand Down Expand Up @@ -163,9 +163,9 @@ class binary_operation : public operation_cl<Derived, T_res, T_a, T_b> {
public: \
class_name(T_a&& a, T_b&& b) /* NOLINT */ \
: base(std::forward<T_a>(a), std::forward<T_b>(b), operation) {} \
inline auto deep_copy() { \
auto&& a_copy = std::get<0>(arguments_).deep_copy(); \
auto&& b_copy = std::get<1>(arguments_).deep_copy(); \
inline auto deep_copy() const { \
auto&& a_copy = this->template get_arg<0>().deep_copy(); \
auto&& b_copy = this->template get_arg<1>().deep_copy(); \
return class_name<std::remove_reference_t<decltype(a_copy)>, \
std::remove_reference_t<decltype(b_copy)>>( \
std::move(a_copy), std::move(b_copy)); \
Expand All @@ -189,14 +189,14 @@ ADD_BINARY_OPERATION_WITH_CUSTOM_VIEW(
common_scalar_t<T_a COMMA T_b>, "*",
using base = binary_operation<elewise_multiplication_<T_a, T_b>,
common_scalar_t<T_a, T_b>, T_a, T_b>;
return both(std::get<0>(base::arguments_).view(),
std::get<1>(base::arguments_).view()););
return both(this->template get_arg<0>().view(),
this->template get_arg<1>().view()););
ADD_BINARY_OPERATION_WITH_CUSTOM_VIEW(
elewise_division_, elewise_division, common_scalar_t<T_a COMMA T_b>, "/",
using base = binary_operation<elewise_division_<T_a, T_b>,
common_scalar_t<T_a, T_b>, T_a, T_b>;
return either(std::get<0>(base::arguments_).view(),
invert(std::get<1>(base::arguments_).view())););
return either(this->template get_arg<0>().view(),
invert(this->template get_arg<1>().view())););
ADD_BINARY_OPERATION(less_than_, operator<, bool, "<");
ADD_BINARY_OPERATION_WITH_CUSTOM_VIEW(less_than_or_equal_, operator<=, bool,
"<=", return matrix_cl_view::Entire);
Expand Down
17 changes: 8 additions & 9 deletions stan/math/opencl/kernel_generator/block.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ class block_

protected:
int start_row_, start_col_, rows_, cols_;
using base::arguments_;

public:
/**
Expand All @@ -61,8 +60,8 @@ class block_
* Creates a deep copy of this expression.
* @return copy of \c *this
*/
inline auto deep_copy() {
auto&& arg_copy = std::get<0>(arguments_).deep_copy();
inline auto deep_copy() const {
auto&& arg_copy = this->template get_arg<0>().deep_copy();
return block_<std::remove_reference_t<decltype(arg_copy)>>{
std::move(arg_copy), start_row_, start_col_, rows_, cols_};
}
Expand Down Expand Up @@ -122,7 +121,7 @@ class block_
cl::Kernel& kernel, int& arg_num) const {
if (generated.count(this) == 0) {
generated.insert(this);
std::get<0>(arguments_).set_args(generated, kernel, arg_num);
this->template get_arg<0>().set_args(generated, kernel, arg_num);
kernel.setArg(arg_num++, start_row_);
kernel.setArg(arg_num++, start_col_);
}
Expand Down Expand Up @@ -175,9 +174,9 @@ class block_
inline void set_view(int bottom_diagonal, int top_diagonal,
int bottom_zero_diagonal, int top_zero_diagonal) const {
int change = start_col_ - start_row_;
std::get<0>(arguments_)
.set_view(bottom_diagonal + change, top_diagonal + change,
bottom_zero_diagonal + change, top_zero_diagonal + change);
this->template get_arg<0>().set_view(
bottom_diagonal + change, top_diagonal + change,
bottom_zero_diagonal + change, top_zero_diagonal + change);
}

/**
Expand All @@ -186,7 +185,7 @@ class block_
*/
inline int bottom_diagonal() const {
return std::max(
std::get<0>(arguments_).bottom_diagonal() - start_col_ + start_row_,
this->template get_arg<0>().bottom_diagonal() - start_col_ + start_row_,
1 - rows_);
}

Expand All @@ -196,7 +195,7 @@ class block_
*/
inline int top_diagonal() const {
return std::min(
std::get<0>(arguments_).top_diagonal() - start_col_ + start_row_,
this->template get_arg<0>().top_diagonal() - start_col_ + start_row_,
cols_ - 1);
}

Expand Down
14 changes: 6 additions & 8 deletions stan/math/opencl/kernel_generator/calc_if.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@ class calc_if_
using base = operation_cl<calc_if_<Do_Calculate, T>, Scalar, T>;
using base::var_name;

protected:
using base::arguments_;

public:
/**
* Constructor
* @param a expression to calc_if
Expand Down Expand Up @@ -66,8 +62,8 @@ class calc_if_
const std::string& i, const std::string& j,
const T_result& result) const {
if (Do_Calculate) {
return std::get<0>(arguments_)
.get_whole_kernel_parts(generated, ng, i, j, result);
return this->template get_arg<0>().get_whole_kernel_parts(generated, ng,
i, j, result);
} else {
return {};
}
Expand All @@ -84,15 +80,17 @@ class calc_if_
inline void set_args(std::set<const operation_cl_base*>& generated,
cl::Kernel& kernel, int& arg_num) const {
if (Do_Calculate) {
std::get<0>(arguments_).set_args(generated, kernel, arg_num);
this->template get_arg<0>().set_args(generated, kernel, arg_num);
}
}

/**
* View of a matrix that would be the result of evaluating this expression.
* @return view
*/
inline matrix_cl_view view() const { return std::get<0>(arguments_).view(); }
inline matrix_cl_view view() const {
return this->template get_arg<0>().view();
}
};

template <bool Do_Calculate, typename T,
Expand Down
17 changes: 8 additions & 9 deletions stan/math/opencl/kernel_generator/colwise_reduction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ class colwise_reduction

protected:
std::string init_;
using base::arguments_;
using base::derived;

public:
Expand Down Expand Up @@ -121,15 +120,15 @@ class colwise_reduction
inline int rows() const {
int local_rows = opencl_context.base_opts().at("LOCAL_SIZE_");
int wgs_rows
= (std::get<0>(arguments_).rows() + local_rows - 1) / local_rows;
= (this->template get_arg<0>().rows() + local_rows - 1) / local_rows;
return wgs_rows;
}

/**
* Number of rows threads need to be launched for.
* @return number of rows
*/
inline int thread_rows() const { return std::get<0>(arguments_).rows(); }
inline int thread_rows() const { return this->template get_arg<0>().rows(); }

/**
* View of a matrix that would be the result of evaluating this expression.
Expand Down Expand Up @@ -161,8 +160,8 @@ class colwise_sum_ : public colwise_reduction<colwise_sum_<T>, T, sum_op> {
* Creates a deep copy of this expression.
* @return copy of \c *this
*/
inline auto deep_copy() {
auto&& arg_copy = std::get<0>(arguments_).deep_copy();
inline auto deep_copy() const {
auto&& arg_copy = this->template get_arg<0>().deep_copy();
return colwise_sum_<std::remove_reference_t<decltype(arg_copy)>>(
std::move(arg_copy));
}
Expand Down Expand Up @@ -209,8 +208,8 @@ class colwise_max_ : public colwise_reduction<
* Creates a deep copy of this expression.
* @return copy of \c *this
*/
inline auto deep_copy() {
auto&& arg_copy = std::get<0>(arguments_).deep_copy();
inline auto deep_copy() const {
auto&& arg_copy = this->template get_arg<0>().deep_copy();
return colwise_max_<std::remove_reference_t<decltype(arg_copy)>>(
std::move(arg_copy));
}
Expand Down Expand Up @@ -257,8 +256,8 @@ class colwise_min_ : public colwise_reduction<
* Creates a deep copy of this expression.
* @return copy of \c *this
*/
inline auto deep_copy() {
auto&& arg_copy = std::get<0>(arguments_).deep_copy();
inline auto deep_copy() const {
auto&& arg_copy = this->template get_arg<0>().deep_copy();
return colwise_min_<std::remove_reference_t<decltype(arg_copy)>>(
std::move(arg_copy));
}
Expand Down
3 changes: 2 additions & 1 deletion stan/math/opencl/kernel_generator/load.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ class load_
* Creates a deep copy of this expression.
* @return copy of \c *this
*/
inline load_<T&> deep_copy() { return load_<T&>(a_); }
inline load_<T&> deep_copy() const& { return load_<T&>(a_); }
inline load_<T> deep_copy() && { return load_<T>(std::forward<T>(a_)); }

/**
* generates kernel code for this expression.
Expand Down
16 changes: 1 addition & 15 deletions stan/math/opencl/kernel_generator/multi_result_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define STAN_MATH_OPENCL_KERNEL_GENERATOR_MULTI_RESULT_KERNEL_HPP
#ifdef STAN_OPENCL

#include <stan/math/opencl/kernel_generator/wrapper.hpp>
#include <stan/math/opencl/kernel_generator/is_valid_expression.hpp>
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
Expand All @@ -18,21 +19,6 @@ namespace math {

namespace internal {

/**
* A wrapper for references. This is used to wrap references when putting them
* in tuples.
*
* The wrapped object is held in the public member \c x, whose type is exactly
* \c T: a reference when \c T is a reference type, a value otherwise.
* @tparam T type of the wrapped object (may be a reference type)
*/
template <typename T>
struct wrapper {
T x;
// T&& is not a forwarding reference here (T is a class template parameter):
// for a reference T it collapses to that reference type and x binds to the
// caller's object; for a non-reference T it accepts an rvalue, which is
// moved into x.
explicit wrapper(T&& x) : x(std::forward<T>(x)) {}
};

/**
* Wraps an object in a \c wrapper, deducing the wrapper's template argument
* from the call. An lvalue argument deduces \c T as an lvalue reference, so
* the resulting wrapper refers to the caller's object; an rvalue argument
* deduces a non-reference \c T and is moved into the wrapper.
* @tparam T deduced type of the argument
* @param x object to wrap
* @return \c wrapper<T> holding \c x
*/
template <typename T>
wrapper<T> make_wrapper(T&& x) {
return wrapper<T>(std::forward<T>(x));
}

// Template parameter pack can only be at the end of the template list in
// structs. We need 2 packs for expressions and results, so we nest structs.
template <int n, typename... T_results>
Expand Down
37 changes: 22 additions & 15 deletions stan/math/opencl/kernel_generator/operation_cl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#ifdef STAN_OPENCL

#include <stan/math/prim/meta.hpp>
#include <stan/math/opencl/kernel_generator/wrapper.hpp>
#include <stan/math/opencl/kernel_generator/type_str.hpp>
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
#include <stan/math/opencl/kernel_generator/is_valid_expression.hpp>
Expand Down Expand Up @@ -48,7 +49,7 @@ class operation_cl : public operation_cl_base {
"operation_cl: all arguments to operation must be operations!");

protected:
std::tuple<Args...> arguments_;
std::tuple<internal::wrapper<Args>...> arguments_;
mutable std::string var_name; // name of the variable that holds result of
// this operation in the kernel

Expand All @@ -74,13 +75,22 @@ class operation_cl : public operation_cl_base {
// value representing a not yet determined size
static const int dynamic = -1;

/**
* Returns an argument to this operation.
* Arguments are stored wrapped inside the \c arguments_ tuple; this accessor
* selects the N-th element and unwraps it by returning its \c x member.
* @tparam N zero-based index of the argument
* @return const reference to the N-th argument
*/
template <size_t N>
const auto& get_arg() const {
return std::get<N>(arguments_).x;
}

/**
* Constructor
* @param arguments Arguments of this expression that are also valid
* expressions
*/
explicit operation_cl(Args&&... arguments)
: arguments_(std::forward<Args>(arguments)...) {}
: arguments_(internal::wrapper<Args>(std::forward<Args>(arguments))...) {}

/**
* Evaluates the expression.
Expand Down Expand Up @@ -165,9 +175,8 @@ class operation_cl : public operation_cl_base {
std::string j_arg = j;
derived().modify_argument_indices(i_arg, j_arg);
std::array<kernel_parts, N> args_parts = index_apply<N>([&](auto... Is) {
return std::array<kernel_parts, N>{
std::get<Is>(arguments_)
.get_kernel_parts(generated, name_gen, i_arg, j_arg)...};
return std::array<kernel_parts, N>{this->get_arg<Is>().get_kernel_parts(
generated, name_gen, i_arg, j_arg)...};
});
res.initialization
= std::accumulate(args_parts.begin(), args_parts.end(), std::string(),
Expand All @@ -190,8 +199,7 @@ class operation_cl : public operation_cl_base {
return a + b.args;
});
kernel_parts my_part = index_apply<N>([&](auto... Is) {
return this->derived().generate(i, j,
std::get<Is>(arguments_).var_name...);
return this->derived().generate(i, j, this->get_arg<Is>().var_name...);
});
res.initialization += my_part.initialization;
res.body = my_part.body_prefix + res.body + my_part.body;
Expand Down Expand Up @@ -230,8 +238,7 @@ class operation_cl : public operation_cl_base {
// expression.
index_apply<N>([&](auto... Is) {
static_cast<void>(std::initializer_list<int>{
(std::get<Is>(arguments_).set_args(generated, kernel, arg_num),
0)...});
(this->get_arg<Is>().set_args(generated, kernel, arg_num), 0)...});
});
}
}
Expand All @@ -243,7 +250,7 @@ class operation_cl : public operation_cl_base {
inline void add_read_event(cl::Event& e) const {
index_apply<N>([&](auto... Is) {
(void)std::initializer_list<int>{
(std::get<Is>(arguments_).add_read_event(e), 0)...};
(this->get_arg<Is>().add_read_event(e), 0)...};
});
}

Expand All @@ -255,7 +262,7 @@ class operation_cl : public operation_cl_base {
inline int rows() const {
return index_apply<N>([&](auto... Is) {
// assuming all non-dynamic sizes match
return std::max({std::get<Is>(arguments_).rows()...});
return std::max({this->get_arg<Is>().rows()...});
});
}

Expand All @@ -267,7 +274,7 @@ class operation_cl : public operation_cl_base {
inline int cols() const {
return index_apply<N>([&](auto... Is) {
// assuming all non-dynamic sizes match
return std::max({std::get<Is>(arguments_).cols()...});
return std::max({this->get_arg<Is>().cols()...});
});
}

Expand All @@ -293,7 +300,7 @@ class operation_cl : public operation_cl_base {
inline int bottom_diagonal() const {
return index_apply<N>([&](auto... Is) {
return std::min(std::initializer_list<int>(
{std::get<Is>(arguments_).bottom_diagonal()...}));
{this->get_arg<Is>().bottom_diagonal()...}));
});
}

Expand All @@ -304,8 +311,8 @@ class operation_cl : public operation_cl_base {
*/
inline int top_diagonal() const {
return index_apply<N>([&](auto... Is) {
return std::max(std::initializer_list<int>(
{std::get<Is>(arguments_).top_diagonal()...}));
return std::max(
std::initializer_list<int>({this->get_arg<Is>().top_diagonal()...}));
});
}
};
Expand Down
Loading