diff --git a/stan/math/opencl/cholesky_decompose.hpp b/stan/math/opencl/cholesky_decompose.hpp index 38ae51505fa..a30564f82de 100644 --- a/stan/math/opencl/cholesky_decompose.hpp +++ b/stan/math/opencl/cholesky_decompose.hpp @@ -10,7 +10,6 @@ #include #include #include -#include #include #include #include diff --git a/stan/math/opencl/kernel_generator.hpp b/stan/math/opencl/kernel_generator.hpp index 3d84b673d6b..6a76beb320f 100644 --- a/stan/math/opencl/kernel_generator.hpp +++ b/stan/math/opencl/kernel_generator.hpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include diff --git a/stan/math/opencl/kernel_generator/binary_operation.hpp b/stan/math/opencl/kernel_generator/binary_operation.hpp index 3063e8ce105..99b55a41334 100644 --- a/stan/math/opencl/kernel_generator/binary_operation.hpp +++ b/stan/math/opencl/kernel_generator/binary_operation.hpp @@ -81,8 +81,8 @@ class binary_operation : public operation_cl { * @return view */ inline matrix_cl_view view() const { - return either(std::get<0>(arguments_).view(), - std::get<1>(arguments_).view()); + return either(this->template get_arg<0>().view(), + this->template get_arg<1>().view()); } }; @@ -116,9 +116,9 @@ class binary_operation : public operation_cl { public: \ class_name(T_a&& a, T_b&& b) /* NOLINT */ \ : base(std::forward(a), std::forward(b), operation) {} \ - inline auto deep_copy() { \ - auto&& a_copy = std::get<0>(arguments_).deep_copy(); \ - auto&& b_copy = std::get<1>(arguments_).deep_copy(); \ + inline auto deep_copy() const { \ + auto&& a_copy = this->template get_arg<0>().deep_copy(); \ + auto&& b_copy = this->template get_arg<1>().deep_copy(); \ return class_name, \ std::remove_reference_t>( \ std::move(a_copy), std::move(b_copy)); \ @@ -163,9 +163,9 @@ class binary_operation : public operation_cl { public: \ class_name(T_a&& a, T_b&& b) /* NOLINT */ \ : base(std::forward(a), std::forward(b), operation) {} \ - inline auto deep_copy() { \ - auto&& a_copy = 
std::get<0>(arguments_).deep_copy(); \ - auto&& b_copy = std::get<1>(arguments_).deep_copy(); \ + inline auto deep_copy() const { \ + auto&& a_copy = this->template get_arg<0>().deep_copy(); \ + auto&& b_copy = this->template get_arg<1>().deep_copy(); \ return class_name, \ std::remove_reference_t>( \ std::move(a_copy), std::move(b_copy)); \ @@ -189,14 +189,14 @@ ADD_BINARY_OPERATION_WITH_CUSTOM_VIEW( common_scalar_t, "*", using base = binary_operation, common_scalar_t, T_a, T_b>; - return both(std::get<0>(base::arguments_).view(), - std::get<1>(base::arguments_).view());); + return both(this->template get_arg<0>().view(), + this->template get_arg<1>().view());); ADD_BINARY_OPERATION_WITH_CUSTOM_VIEW( elewise_division_, elewise_division, common_scalar_t, "/", using base = binary_operation, common_scalar_t, T_a, T_b>; - return either(std::get<0>(base::arguments_).view(), - invert(std::get<1>(base::arguments_).view()));); + return either(this->template get_arg<0>().view(), + invert(this->template get_arg<1>().view()));); ADD_BINARY_OPERATION(less_than_, operator<, bool, "<"); ADD_BINARY_OPERATION_WITH_CUSTOM_VIEW(less_than_or_equal_, operator<=, bool, "<=", return matrix_cl_view::Entire); diff --git a/stan/math/opencl/kernel_generator/block.hpp b/stan/math/opencl/kernel_generator/block.hpp index 6c02bfdf3a9..e28a82ecd86 100644 --- a/stan/math/opencl/kernel_generator/block.hpp +++ b/stan/math/opencl/kernel_generator/block.hpp @@ -34,7 +34,6 @@ class block_ protected: int start_row_, start_col_, rows_, cols_; - using base::arguments_; public: /** @@ -61,8 +60,8 @@ class block_ * Creates a deep copy of this expression. 
* @return copy of \c *this */ - inline auto deep_copy() { - auto&& arg_copy = std::get<0>(arguments_).deep_copy(); + inline auto deep_copy() const { + auto&& arg_copy = this->template get_arg<0>().deep_copy(); return block_>{ std::move(arg_copy), start_row_, start_col_, rows_, cols_}; } @@ -122,7 +121,7 @@ class block_ cl::Kernel& kernel, int& arg_num) const { if (generated.count(this) == 0) { generated.insert(this); - std::get<0>(arguments_).set_args(generated, kernel, arg_num); + this->template get_arg<0>().set_args(generated, kernel, arg_num); kernel.setArg(arg_num++, start_row_); kernel.setArg(arg_num++, start_col_); } @@ -175,9 +174,9 @@ class block_ inline void set_view(int bottom_diagonal, int top_diagonal, int bottom_zero_diagonal, int top_zero_diagonal) const { int change = start_col_ - start_row_; - std::get<0>(arguments_) - .set_view(bottom_diagonal + change, top_diagonal + change, - bottom_zero_diagonal + change, top_zero_diagonal + change); + this->template get_arg<0>().set_view( + bottom_diagonal + change, top_diagonal + change, + bottom_zero_diagonal + change, top_zero_diagonal + change); } /** @@ -186,7 +185,7 @@ class block_ */ inline int bottom_diagonal() const { return std::max( - std::get<0>(arguments_).bottom_diagonal() - start_col_ + start_row_, + this->template get_arg<0>().bottom_diagonal() - start_col_ + start_row_, 1 - rows_); } @@ -196,7 +195,7 @@ class block_ */ inline int top_diagonal() const { return std::min( - std::get<0>(arguments_).top_diagonal() - start_col_ + start_row_, + this->template get_arg<0>().top_diagonal() - start_col_ + start_row_, cols_ - 1); } @@ -234,6 +233,14 @@ class block_ /** * Block of a kernel generator expression. + * + * Block operation modifies how its argument is indexed. If a matrix is both an + * argument and result of such an operation (such as in block(a, row1, + * col1, rows, cols) = block(a, row2, col2, rows, cols); + * ), the result can be wrong due to aliasing. 
In such a case the + * expression should be evaluated into a temporary by doing block(a, row1, + * col1, rows, cols) = block(a, row2, col2, rows, cols).eval();. This is + * not necessary if the blocks do not overlap or if they are the same block. * @tparam T type of argument * @param a input argument * @param start_row first row of block diff --git a/stan/math/opencl/kernel_generator/calc_if.hpp b/stan/math/opencl/kernel_generator/calc_if.hpp index 2bea3e60e94..f5b8444e89e 100644 --- a/stan/math/opencl/kernel_generator/calc_if.hpp +++ b/stan/math/opencl/kernel_generator/calc_if.hpp @@ -31,10 +31,6 @@ class calc_if_ using base = operation_cl, Scalar, T>; using base::var_name; - protected: - using base::arguments_; - - public: /** * Constructor * @param a expression to calc_if @@ -66,8 +62,8 @@ class calc_if_ const std::string& i, const std::string& j, const T_result& result) const { if (Do_Calculate) { - return std::get<0>(arguments_) - .get_whole_kernel_parts(generated, ng, i, j, result); + return this->template get_arg<0>().get_whole_kernel_parts(generated, ng, + i, j, result); } else { return {}; } @@ -84,7 +80,7 @@ class calc_if_ inline void set_args(std::set& generated, cl::Kernel& kernel, int& arg_num) const { if (Do_Calculate) { - std::get<0>(arguments_).set_args(generated, kernel, arg_num); + this->template get_arg<0>().set_args(generated, kernel, arg_num); } } @@ -92,7 +88,9 @@ class calc_if_ * View of a matrix that would be the result of evaluating this expression. * @return view */ - inline matrix_cl_view view() const { return std::get<0>(arguments_).view(); } + inline matrix_cl_view view() const { + return this->template get_arg<0>().view(); + } }; template (arguments_).rows() + local_rows - 1) / local_rows; + = (this->template get_arg<0>().rows() + local_rows - 1) / local_rows; return wgs_rows; } @@ -129,7 +128,7 @@ class colwise_reduction * Number of rows threads need to be launched for.
* @return number of rows */ - inline int thread_rows() const { return std::get<0>(arguments_).rows(); } + inline int thread_rows() const { return this->template get_arg<0>().rows(); } /** * View of a matrix that would be the result of evaluating this expression. @@ -161,8 +160,8 @@ class colwise_sum_ : public colwise_reduction, T, sum_op> { * Creates a deep copy of this expression. * @return copy of \c *this */ - inline auto deep_copy() { - auto&& arg_copy = std::get<0>(arguments_).deep_copy(); + inline auto deep_copy() const { + auto&& arg_copy = this->template get_arg<0>().deep_copy(); return colwise_sum_>( std::move(arg_copy)); } @@ -209,8 +208,8 @@ class colwise_max_ : public colwise_reduction< * Creates a deep copy of this expression. * @return copy of \c *this */ - inline auto deep_copy() { - auto&& arg_copy = std::get<0>(arguments_).deep_copy(); + inline auto deep_copy() const { + auto&& arg_copy = this->template get_arg<0>().deep_copy(); return colwise_max_>( std::move(arg_copy)); } @@ -257,8 +256,8 @@ class colwise_min_ : public colwise_reduction< * Creates a deep copy of this expression. * @return copy of \c *this */ - inline auto deep_copy() { - auto&& arg_copy = std::get<0>(arguments_).deep_copy(); + inline auto deep_copy() const { + auto&& arg_copy = this->template get_arg<0>().deep_copy(); return colwise_min_>( std::move(arg_copy)); } diff --git a/stan/math/opencl/kernel_generator/load.hpp b/stan/math/opencl/kernel_generator/load.hpp index 8ed6f4220e2..63836cc7a07 100644 --- a/stan/math/opencl/kernel_generator/load.hpp +++ b/stan/math/opencl/kernel_generator/load.hpp @@ -48,7 +48,8 @@ class load_ * Creates a deep copy of this expression. * @return copy of \c *this */ - inline load_ deep_copy() { return load_(a_); } + inline load_ deep_copy() const& { return load_(a_); } + inline load_ deep_copy() && { return load_(std::forward(a_)); } /** * generates kernel code for this expression. 
diff --git a/stan/math/opencl/kernel_generator/multi_result_kernel.hpp b/stan/math/opencl/kernel_generator/multi_result_kernel.hpp index 22ba13ec022..30ae07d941f 100644 --- a/stan/math/opencl/kernel_generator/multi_result_kernel.hpp +++ b/stan/math/opencl/kernel_generator/multi_result_kernel.hpp @@ -2,6 +2,7 @@ #define STAN_MATH_OPENCL_KERNEL_GENERATOR_MULTI_RESULT_KERNEL_HPP #ifdef STAN_OPENCL +#include #include #include #include @@ -18,21 +19,6 @@ namespace math { namespace internal { -/** - * A wrapper for references. This is used to wrap references when putting them - * in tuples. - */ -template -struct wrapper { - T x; - explicit wrapper(T&& x) : x(std::forward(x)) {} -}; - -template -wrapper make_wrapper(T&& x) { - return wrapper(std::forward(x)); -} - // Template parameter pack can only be at the end of the template list in // structs. We need 2 packs for expressions and results, so we nest structs. template diff --git a/stan/math/opencl/kernel_generator/operation_cl.hpp b/stan/math/opencl/kernel_generator/operation_cl.hpp index 172b03ff0d2..b09b0579a33 100644 --- a/stan/math/opencl/kernel_generator/operation_cl.hpp +++ b/stan/math/opencl/kernel_generator/operation_cl.hpp @@ -3,6 +3,7 @@ #ifdef STAN_OPENCL #include +#include #include #include #include @@ -48,7 +49,7 @@ class operation_cl : public operation_cl_base { "operation_cl: all arguments to operation must be operations!"); protected: - std::tuple arguments_; + std::tuple...> arguments_; mutable std::string var_name; // name of the variable that holds result of // this operation in the kernel @@ -74,13 +75,22 @@ class operation_cl : public operation_cl_base { // value representing a not yet determined size static const int dynamic = -1; + /** + Returns an argument to this operation + @tparam N index of the argument + */ + template + const auto& get_arg() const { + return std::get(arguments_).x; + } + /** * Constructor * @param arguments Arguments of this expression that are also valid * expressions */ 
explicit operation_cl(Args&&... arguments) - : arguments_(std::forward(arguments)...) {} + : arguments_(internal::wrapper(std::forward(arguments))...) {} /** * Evaluates the expression. @@ -165,9 +175,8 @@ class operation_cl : public operation_cl_base { std::string j_arg = j; derived().modify_argument_indices(i_arg, j_arg); std::array args_parts = index_apply([&](auto... Is) { - return std::array{ - std::get(arguments_) - .get_kernel_parts(generated, name_gen, i_arg, j_arg)...}; + return std::array{this->get_arg().get_kernel_parts( + generated, name_gen, i_arg, j_arg)...}; }); res.initialization = std::accumulate(args_parts.begin(), args_parts.end(), std::string(), @@ -190,8 +199,7 @@ class operation_cl : public operation_cl_base { return a + b.args; }); kernel_parts my_part = index_apply([&](auto... Is) { - return this->derived().generate(i, j, - std::get(arguments_).var_name...); + return this->derived().generate(i, j, this->get_arg().var_name...); }); res.initialization += my_part.initialization; res.body = my_part.body_prefix + res.body + my_part.body; @@ -230,8 +238,7 @@ class operation_cl : public operation_cl_base { // expression. index_apply([&](auto... Is) { static_cast(std::initializer_list{ - (std::get(arguments_).set_args(generated, kernel, arg_num), - 0)...}); + (this->get_arg().set_args(generated, kernel, arg_num), 0)...}); }); } } @@ -243,7 +250,7 @@ class operation_cl : public operation_cl_base { inline void add_read_event(cl::Event& e) const { index_apply([&](auto... Is) { (void)std::initializer_list{ - (std::get(arguments_).add_read_event(e), 0)...}; + (this->get_arg().add_read_event(e), 0)...}; }); } @@ -255,7 +262,7 @@ class operation_cl : public operation_cl_base { inline int rows() const { return index_apply([&](auto... 
Is) { // assuming all non-dynamic sizes match - return std::max({std::get(arguments_).rows()...}); + return std::max({this->get_arg().rows()...}); }); } @@ -267,7 +274,7 @@ class operation_cl : public operation_cl_base { inline int cols() const { return index_apply([&](auto... Is) { // assuming all non-dynamic sizes match - return std::max({std::get(arguments_).cols()...}); + return std::max({this->get_arg().cols()...}); }); } @@ -293,7 +300,7 @@ class operation_cl : public operation_cl_base { inline int bottom_diagonal() const { return index_apply([&](auto... Is) { return std::min(std::initializer_list( - {std::get(arguments_).bottom_diagonal()...})); + {this->get_arg().bottom_diagonal()...})); }); } @@ -304,8 +311,8 @@ class operation_cl : public operation_cl_base { */ inline int top_diagonal() const { return index_apply([&](auto... Is) { - return std::max(std::initializer_list( - {std::get(arguments_).top_diagonal()...})); + return std::max( + std::initializer_list({this->get_arg().top_diagonal()...})); }); } }; diff --git a/stan/math/opencl/kernel_generator/operation_cl_lhs.hpp b/stan/math/opencl/kernel_generator/operation_cl_lhs.hpp index 9390bbdfbe2..1a1411a2e5f 100644 --- a/stan/math/opencl/kernel_generator/operation_cl_lhs.hpp +++ b/stan/math/opencl/kernel_generator/operation_cl_lhs.hpp @@ -24,7 +24,6 @@ class operation_cl_lhs : public operation_cl { protected: using base = operation_cl; static constexpr int N = sizeof...(Args); - using base::arguments_; using base::derived; public: @@ -47,11 +46,11 @@ class operation_cl_lhs : public operation_cl { } std::string i_arg = i; std::string j_arg = j; - this->derived().modify_argument_indices(i_arg, j_arg); + derived().modify_argument_indices(i_arg, j_arg); std::array args_parts = index_apply([&](auto... 
Is) { return std::array{ - std::get(this->arguments_) - .get_kernel_parts_lhs(generated, name_gen, i_arg, j_arg)...}; + this->template get_arg().get_kernel_parts_lhs(generated, name_gen, + i_arg, j_arg)...}; }); kernel_parts res{}; res.body = std::accumulate( @@ -66,7 +65,7 @@ class operation_cl_lhs : public operation_cl { }); kernel_parts my_part = index_apply([&](auto... Is) { return this->derived().generate_lhs( - i, j, std::get(this->arguments_).var_name...); + i, j, this->template get_arg().var_name...); }); res.body += my_part.body; res.args += my_part.args; @@ -91,9 +90,9 @@ class operation_cl_lhs : public operation_cl { int bottom_zero_diagonal, int top_zero_diagonal) const { index_apply([&](auto... Is) { (void)std::initializer_list{ - (std::get(this->arguments_) - .set_view(bottom_diagonal, top_diagonal, bottom_zero_diagonal, - top_zero_diagonal), + (this->template get_arg().set_view(bottom_diagonal, top_diagonal, + bottom_zero_diagonal, + top_zero_diagonal), 0)...}; }); } @@ -109,7 +108,7 @@ class operation_cl_lhs : public operation_cl { inline void check_assign_dimensions(int rows, int cols) const { index_apply([&](auto... Is) { (void)std::initializer_list{ - (std::get(this->arguments_).check_assign_dimensions(rows, cols), + (this->template get_arg().check_assign_dimensions(rows, cols), 0)...}; }); } @@ -121,7 +120,7 @@ class operation_cl_lhs : public operation_cl { inline void add_write_event(cl::Event& e) const { index_apply([&](auto... 
Is) { (void)std::initializer_list{ - (std::get(this->arguments_).add_write_event(e), 0)...}; + (this->template get_arg().add_write_event(e), 0)...}; }); } }; diff --git a/stan/math/opencl/kernel_generator/rowwise_reduction.hpp b/stan/math/opencl/kernel_generator/rowwise_reduction.hpp index e148c1717d7..eb87b6d7933 100644 --- a/stan/math/opencl/kernel_generator/rowwise_reduction.hpp +++ b/stan/math/opencl/kernel_generator/rowwise_reduction.hpp @@ -36,7 +36,6 @@ class rowwise_reduction protected: std::string init_; - using base::arguments_; public: /** @@ -98,9 +97,9 @@ class rowwise_reduction cl::Kernel& kernel, int& arg_num) const { if (generated.count(this) == 0) { generated.insert(this); - std::get<0>(arguments_).set_args(generated, kernel, arg_num); - kernel.setArg(arg_num++, std::get<0>(arguments_).view()); - kernel.setArg(arg_num++, std::get<0>(arguments_).cols()); + this->template get_arg<0>().set_args(generated, kernel, arg_num); + kernel.setArg(arg_num++, this->template get_arg<0>().view()); + kernel.setArg(arg_num++, this->template get_arg<0>().cols()); } } @@ -157,8 +156,8 @@ class rowwise_sum_ * Creates a deep copy of this expression. * @return copy of \c *this */ - inline auto deep_copy() { - auto&& arg_copy = std::get<0>(arguments_).deep_copy(); + inline auto deep_copy() const { + auto&& arg_copy = this->template get_arg<0>().deep_copy(); return rowwise_sum_>( std::move(arg_copy)); } @@ -225,8 +224,8 @@ class rowwise_max_ * Creates a deep copy of this expression. * @return copy of \c *this */ - inline auto deep_copy() { - auto&& arg_copy = std::get<0>(arguments_).deep_copy(); + inline auto deep_copy() const { + auto&& arg_copy = this->template get_arg<0>().deep_copy(); return rowwise_max_>( std::move(arg_copy)); } @@ -292,8 +291,8 @@ class rowwise_min_ * Creates a deep copy of this expression. 
* @return copy of \c *this */ - inline auto deep_copy() { - auto&& arg_copy = std::get<0>(arguments_).deep_copy(); + inline auto deep_copy() const { + auto&& arg_copy = this->template get_arg<0>().deep_copy(); return rowwise_min_>( std::move(arg_copy)); } diff --git a/stan/math/opencl/kernel_generator/scalar.hpp b/stan/math/opencl/kernel_generator/scalar.hpp index c68ef7b38b5..b9a4937f30d 100644 --- a/stan/math/opencl/kernel_generator/scalar.hpp +++ b/stan/math/opencl/kernel_generator/scalar.hpp @@ -41,7 +41,7 @@ class scalar_ : public operation_cl, T> { * Creates a deep copy of this expression. * @return copy of \c *this */ - inline scalar_ deep_copy() { return scalar_(a_); } + inline scalar_ deep_copy() const { return scalar_(a_); } /** * generates kernel code for this expression. diff --git a/stan/math/opencl/kernel_generator/select.hpp b/stan/math/opencl/kernel_generator/select.hpp index 287fb8c626f..53d5629b589 100644 --- a/stan/math/opencl/kernel_generator/select.hpp +++ b/stan/math/opencl/kernel_generator/select.hpp @@ -37,10 +37,6 @@ class select_ : public operation_cl, T_condition, T_then, T_else>; using base::var_name; - protected: - using base::arguments_; - - public: /** * Constructor * @param condition condition expression @@ -73,10 +69,10 @@ class select_ : public operation_cl, * Creates a deep copy of this expression. 
* @return copy of \c *this */ - inline auto deep_copy() { - auto&& condition_copy = std::get<0>(arguments_).deep_copy(); - auto&& then_copy = std::get<0>(arguments_).deep_copy(); - auto&& else_copy = std::get<0>(arguments_).deep_copy(); + inline auto deep_copy() const { + auto&& condition_copy = this->template get_arg<0>().deep_copy(); + auto&& then_copy = this->template get_arg<1>().deep_copy(); + auto&& else_copy = this->template get_arg<2>().deep_copy(); return select_, std::remove_reference_t, std::remove_reference_t>( @@ -107,9 +103,9 @@ class select_ : public operation_cl, * @return view */ inline matrix_cl_view view() const { - matrix_cl_view condition_view = std::get<0>(arguments_).view(); - matrix_cl_view then_view = std::get<1>(arguments_).view(); - matrix_cl_view else_view = std::get<2>(arguments_).view(); + matrix_cl_view condition_view = this->template get_arg<0>().view(); + matrix_cl_view then_view = this->template get_arg<1>().view(); + matrix_cl_view else_view = this->template get_arg<2>().view(); return both(either(then_view, else_view), both(condition_view, then_view)); } }; diff --git a/stan/math/opencl/kernel_generator/transpose.hpp b/stan/math/opencl/kernel_generator/transpose.hpp new file mode 100644 index 00000000000..afa1359cbe1 --- /dev/null +++ b/stan/math/opencl/kernel_generator/transpose.hpp @@ -0,0 +1,136 @@ +#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_TRANSPOSE_HPP +#define STAN_MATH_OPENCL_KERNEL_GENERATOR_TRANSPOSE_HPP +#ifdef STAN_OPENCL + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace stan { +namespace math { + +/** + * Represents a transpose in kernel generator expressions. 
+ * + * @tparam Arg type of the argument; row and column indices are swapped + * when the argument is indexed + */ +template +class transpose_ + : public operation_cl, + typename std::remove_reference_t::Scalar, Arg> { + public: + using Scalar = typename std::remove_reference_t::Scalar; + using base = operation_cl, Scalar, Arg>; + using base::var_name; + + /** + * Constructor + * @param a expression to transpose + */ + explicit transpose_(Arg&& a) : base(std::forward(a)) {} + + /** + * Creates a deep copy of this expression. + * @return copy of \c *this + */ + inline transpose_> deep_copy() const { + return transpose_>{ + this->template get_arg<0>().deep_copy()}; + } + + /** + * Generates kernel code for this expression. No code of its own is emitted; + * the variable name of the argument is reused as the name of the result. + * @param i row index variable name + * @param j column index variable name + * @param var_name_arg name of the kernel variable that holds the argument + * @return part of kernel with code for this expression (empty) + */ + inline kernel_parts generate(const std::string& i, const std::string& j, + const std::string& var_name_arg) const { + var_name = var_name_arg; + return {}; + } + + /** + * Swaps indices \c i and \c j for the argument expression. + * @param[in, out] i row index + * @param[in, out] j column index + */ + inline void modify_argument_indices(std::string& i, std::string& j) const { + std::swap(i, j); + } + + /** + * Number of rows of a matrix that would be the result of evaluating this + * expression. + * @return number of rows + */ + inline int rows() const { return this->template get_arg<0>().cols(); } + + /** + * Number of columns of a matrix that would be the result of evaluating this + * expression. + * @return number of columns + */ + inline int cols() const { return this->template get_arg<0>().rows(); } + + /** + * View of a matrix that would be the result of evaluating this expression.
+ * @return view + */ + inline matrix_cl_view view() const { + return transpose(this->template get_arg<0>().view()); + } + + /** + * Determine index of bottom diagonal written. + * @return index of bottom diagonal + */ + inline int bottom_diagonal() const { + return -this->template get_arg<0>().top_diagonal(); + } + + /** + * Determine index of top diagonal written. + * @return index of top diagonal + */ + inline int top_diagonal() const { + return -this->template get_arg<0>().bottom_diagonal(); + } +}; + +/** + * Transposes a kernel generator expression. + * + * Transposition modifies how its argument is indexed. If a matrix is both an + * argument and result of such an operation (such as in a = transpose(a); + * ), the result will be wrong due to aliasing. In such a case the + * expression should be evaluated into a temporary by doing a = + * transpose(a).eval();. + * @tparam Arg type of the argument expression. + * @param a argument to transposition + */ +template > +inline auto transpose(Arg&& a) { + auto&& a_operation = as_operation_cl(std::forward(a)).deep_copy(); + return transpose_>{ + std::move(a_operation)}; +} + +} // namespace math +} // namespace stan + +#endif +#endif diff --git a/stan/math/opencl/kernel_generator/unary_function_cl.hpp b/stan/math/opencl/kernel_generator/unary_function_cl.hpp index 70467bb2c5c..af1021ed6cb 100644 --- a/stan/math/opencl/kernel_generator/unary_function_cl.hpp +++ b/stan/math/opencl/kernel_generator/unary_function_cl.hpp @@ -63,7 +63,7 @@ class unary_function_cl * @return view */ inline matrix_cl_view view() const { - return std::get<0>(base::arguments_).view(); + return this->template get_arg<0>().view(); } protected: @@ -83,8 +83,8 @@ class unary_function_cl \ public: \ explicit fun##_(T&& a) : base(std::forward(a), #fun) {} \ - inline auto deep_copy() { \ - auto&& arg_copy = std::get<0>(arguments_).deep_copy(); \ + inline auto deep_copy() const { \ + auto&& arg_copy = this->template get_arg<0>().deep_copy(); \ return
fun##_>{ \ std::move(arg_copy)}; \ } \ @@ -111,8 +111,8 @@ class unary_function_cl \ public: \ explicit fun##_(T&& a) : base(std::forward(a), #fun) {} \ - inline auto deep_copy() { \ - auto&& arg_copy = std::get<0>(arguments_).deep_copy(); \ + inline auto deep_copy() const { \ + auto&& arg_copy = this->template get_arg<0>().deep_copy(); \ return fun##_>{ \ std::move(arg_copy)}; \ } \ diff --git a/stan/math/opencl/kernel_generator/wrapper.hpp b/stan/math/opencl/kernel_generator/wrapper.hpp new file mode 100644 index 00000000000..19446f4bab4 --- /dev/null +++ b/stan/math/opencl/kernel_generator/wrapper.hpp @@ -0,0 +1,28 @@ +#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_WRAPPER +#define STAN_MATH_OPENCL_KERNEL_GENERATOR_WRAPPER +#include + +namespace stan { +namespace math { +namespace internal { + +/** + * A wrapper for references. This is used to wrap references when putting them + * in tuples. + */ +template +struct wrapper { + T x; + explicit wrapper(T&& x) : x(std::forward(x)) {} +}; + +template +wrapper make_wrapper(T&& x) { + return wrapper(std::forward(x)); +} + +} // namespace internal +} // namespace math +} // namespace stan + +#endif diff --git a/stan/math/opencl/kernels/transpose.hpp b/stan/math/opencl/kernels/transpose.hpp deleted file mode 100644 index eb9ac598751..00000000000 --- a/stan/math/opencl/kernels/transpose.hpp +++ /dev/null @@ -1,48 +0,0 @@ -#ifndef STAN_MATH_OPENCL_KERNELS_TRANSPOSE_HPP -#define STAN_MATH_OPENCL_KERNELS_TRANSPOSE_HPP -#ifdef STAN_OPENCL - -#include -#include -#include - -namespace stan { -namespace math { -namespace opencl_kernels { -// \cond -static const std::string transpose_kernel_code = STRINGIFY( - // \endcond - /** \ingroup opencl_kernels - * Takes the transpose of the matrix on the OpenCL device. - * - * @param[out] B The output matrix to hold transpose of A. - * @param[in] A The input matrix to transpose into B. - * @param rows The number of rows for A. - * @param cols The number of columns for A. 
- * @note Code is a const char* held in - * transpose_kernel_code. - * This kernel uses the helper macros available in helpers.cl. - */ - __kernel void transpose(__global double *B, __global double *A, - unsigned int rows, unsigned int cols) { - int i = get_global_id(0); - int j = get_global_id(1); - if (i < rows && j < cols) { - BT(j, i) = A(i, j); - } - } - // \cond -); -// \endcond - -/** \ingroup opencl_kernels - * See the docs for \link kernels/transpose.hpp transpose() \endlink - */ -const kernel_cl transpose( - "transpose", {indexing_helpers, transpose_kernel_code}); - -} // namespace opencl_kernels -} // namespace math -} // namespace stan -#endif -#endif diff --git a/stan/math/opencl/opencl.hpp b/stan/math/opencl/opencl.hpp index e0212b5946d..9eccd20525b 100644 --- a/stan/math/opencl/opencl.hpp +++ b/stan/math/opencl/opencl.hpp @@ -43,7 +43,6 @@ #include #include #include -#include #include diff --git a/stan/math/opencl/prim/categorical_logit_glm_lpmf.hpp b/stan/math/opencl/prim/categorical_logit_glm_lpmf.hpp index 05b2f5351d4..d1c36a19233 100644 --- a/stan/math/opencl/prim/categorical_logit_glm_lpmf.hpp +++ b/stan/math/opencl/prim/categorical_logit_glm_lpmf.hpp @@ -12,7 +12,6 @@ #include #include #include -#include namespace stan { namespace math { diff --git a/stan/math/opencl/prim/transpose.hpp b/stan/math/opencl/prim/transpose.hpp deleted file mode 100644 index 140fad8ab53..00000000000 --- a/stan/math/opencl/prim/transpose.hpp +++ /dev/null @@ -1,40 +0,0 @@ -#ifndef STAN_MATH_OPENCL_PRIM_TRANSPOSE_HPP -#define STAN_MATH_OPENCL_PRIM_TRANSPOSE_HPP -#ifdef STAN_OPENCL - -#include -#include -#include -#include - -#include - -namespace stan { -namespace math { -/** \ingroup opencl - * Takes the transpose of the matrix on the OpenCL device. 
- * - * @param src the input matrix - * - * @return transposed input matrix - * - */ -template > -inline matrix_cl transpose(const matrix_cl& src) { - matrix_cl dst(src.cols(), src.rows(), transpose(src.view())); - if (dst.size() == 0) { - return dst; - } - try { - opencl_kernels::transpose(cl::NDRange(src.rows(), src.cols()), dst, src, - src.rows(), src.cols()); - } catch (const cl::Error& e) { - check_opencl_error("transpose", e); - } - return dst; -} -} // namespace math -} // namespace stan - -#endif -#endif diff --git a/stan/math/opencl/tri_inverse.hpp b/stan/math/opencl/tri_inverse.hpp index ba2b09dcfbb..48abb2126ae 100644 --- a/stan/math/opencl/tri_inverse.hpp +++ b/stan/math/opencl/tri_inverse.hpp @@ -12,7 +12,7 @@ #include #include #include -#include +#include #include #include #include @@ -81,7 +81,7 @@ inline matrix_cl tri_inverse(const matrix_cl& A) { zero_mat.template zeros(); inv_padded.template zeros(); if (tri_view == matrix_cl_view::Upper) { - inv_mat = transpose(inv_mat); + inv_mat = transpose(inv_mat).eval(); } int work_per_thread = opencl_kernels::inv_lower_tri_multiply.make_functor.get_opts().at( @@ -109,7 +109,7 @@ inline matrix_cl tri_inverse(const matrix_cl& A) { if (parts == 1) { inv_mat.sub_block(inv_padded, 0, 0, 0, 0, inv_mat.rows(), inv_mat.rows()); if (tri_view == matrix_cl_view::Upper) { - inv_mat = transpose(inv_mat); + inv_mat = transpose(inv_mat).eval(); } return inv_mat; } @@ -151,7 +151,7 @@ inline matrix_cl tri_inverse(const matrix_cl& A) { // un-pad and return inv_mat.sub_block(inv_padded, 0, 0, 0, 0, inv_mat.rows(), inv_mat.rows()); if (tri_view == matrix_cl_view::Upper) { - inv_mat = transpose(inv_mat); + inv_mat = transpose(inv_mat).eval(); } inv_mat.view(tri_view); return inv_mat; diff --git a/test/unit/math/opencl/kernel_cl_test.cpp b/test/unit/math/opencl/kernel_cl_test.cpp index b8eb61dca70..dfafc117e9b 100644 --- a/test/unit/math/opencl/kernel_cl_test.cpp +++ b/test/unit/math/opencl/kernel_cl_test.cpp @@ -10,15 +10,13 
@@ EXPECT_NEAR(A(i), B(i), DELTA); TEST(MathGpu, make_kernel) { - stan::math::matrix_d m0(3, 3); - stan::math::matrix_d m0_dst(3, 3); - m0 << 1, 2, 3, 4, 5, 6, 7, 8, 9; - stan::math::matrix_cl m00(m0); - - stan::math::matrix_cl m00_dst(m0.cols(), m0.rows()); - stan::math::opencl_kernels::transpose(cl::NDRange(m00.rows(), m00.cols()), - m00_dst, m00, m00.rows(), m00.cols()); - m0_dst = stan::math::from_matrix_cl(m00_dst); + stan::math::matrix_d m(3, 3); + + stan::math::matrix_cl m_cl(m.cols(), m.rows()); + stan::math::opencl_kernels::fill(cl::NDRange(m_cl.rows(), m_cl.cols()), m_cl, + 0, m_cl.rows(), m_cl.cols(), + stan::math::matrix_cl_view::Entire); + m = stan::math::from_matrix_cl(m_cl); } TEST(MathGpu, write_after_write) { diff --git a/test/unit/math/opencl/kernel_generator/reference_kernels/a+aT.cl b/test/unit/math/opencl/kernel_generator/reference_kernels/a+aT.cl new file mode 100644 index 00000000000..ef28073fc25 --- /dev/null +++ b/test/unit/math/opencl/kernel_generator/reference_kernels/a+aT.cl @@ -0,0 +1,10 @@ +kernel void calculate(__global double* var3_global, int var3_rows, int var3_view, int var4, __global double* var7_global, int var7_rows, int var7_view, int var8, __global double* var9_global, int var9_rows, int var9_view){ +int i = get_global_id(0); +int j = get_global_id(1); +double var3 = 0; if (!((!contains_nonzero(var3_view, LOWER) && j < i) || (!contains_nonzero(var3_view, UPPER) && j > i))) {var3 = var3_global[i + var3_rows * j];} +double var2 = var3 + var4; +double var7 = 0; if (!((!contains_nonzero(var7_view, LOWER) && i < j) || (!contains_nonzero(var7_view, UPPER) && i > j))) {var7 = var7_global[j + var7_rows * i];} +double var6 = var7 + var8; +double var1 = var2 + var6; +var9_global[i + var9_rows * j] = var1; +} diff --git a/test/unit/math/opencl/kernel_generator/transpose_test.cpp b/test/unit/math/opencl/kernel_generator/transpose_test.cpp new file mode 100644 index 00000000000..9baaec1577d --- /dev/null +++ 
b/test/unit/math/opencl/kernel_generator/transpose_test.cpp @@ -0,0 +1,191 @@ +#ifdef STAN_OPENCL +#include +#include +#include +#include +#include +#include +#include +#include + +#define EXPECT_MATRIX_NEAR(A, B, DELTA) \ + for (int i = 0; i < A.size(); i++) \ + EXPECT_NEAR(A(i), B(i), DELTA); + +TEST(MathMatrixCL, transpose_rvalue_test) { + using Eigen::MatrixXd; + using stan::math::matrix_cl; + MatrixXd m(3, 2); + m << 1.1, 1.2, 1.3, 1.4, 1.5, 1.6; + + auto tmp = stan::math::transpose(stan::math::to_matrix_cl(m)); + matrix_cl res_cl = tmp; + + MatrixXd res = stan::math::from_matrix_cl(res_cl); + MatrixXd correct = stan::math::transpose(m); + EXPECT_EQ(correct.rows(), res.rows()); + EXPECT_EQ(correct.cols(), res.cols()); + EXPECT_EQ(stan::math::matrix_cl_view::Entire, res_cl.view()); + EXPECT_MATRIX_NEAR(correct, res, 1e-9); +} + +TEST(MathMatrixCL, transpose_test) { + using Eigen::MatrixXd; + using stan::math::matrix_cl; + MatrixXd m(3, 2); + m << 1.1, 1.2, 1.3, 1.4, 1.5, 1.6; + + matrix_cl m_cl(m); + auto tmp = stan::math::transpose(m_cl); + matrix_cl res_cl = tmp; + + MatrixXd res = stan::math::from_matrix_cl(res_cl); + MatrixXd correct = stan::math::transpose(m); + EXPECT_EQ(correct.rows(), res.rows()); + EXPECT_EQ(correct.cols(), res.cols()); + EXPECT_EQ(stan::math::matrix_cl_view::Entire, res_cl.view()); + EXPECT_MATRIX_NEAR(correct, res, 1e-9); +} + +TEST(MathMatrixCL, transpose_triangular_test) { + using Eigen::MatrixXd; + using stan::math::matrix_cl; + MatrixXd m(3, 2); + m << 1.1, 1.2, 1.3, 1.4, 1.5, 1.6; + + matrix_cl m_cl(m, stan::math::matrix_cl_view::Upper); + auto tmp = stan::math::transpose(m_cl); + matrix_cl res_cl = tmp; + + MatrixXd res = stan::math::from_matrix_cl(res_cl); + MatrixXd correct = stan::math::transpose(m).triangularView(); + EXPECT_EQ(correct.rows(), res.rows()); + EXPECT_EQ(correct.cols(), res.cols()); + EXPECT_EQ(stan::math::matrix_cl_view::Lower, res_cl.view()); + EXPECT_MATRIX_NEAR(correct, res, 1e-9); +} + +TEST(MathMatrixCL, 
double_transpose_test) { + using Eigen::MatrixXd; + using stan::math::matrix_cl; + MatrixXd m(3, 2); + m << 1.1, 1.2, 1.3, 1.4, 1.5, 1.6; + + matrix_cl m_cl(m); + auto tmp = stan::math::transpose(stan::math::transpose(m_cl)); + matrix_cl res_cl = tmp; + + MatrixXd res = stan::math::from_matrix_cl(res_cl); + MatrixXd correct = m; + EXPECT_EQ(correct.rows(), res.rows()); + EXPECT_EQ(correct.cols(), res.cols()); + EXPECT_EQ(stan::math::matrix_cl_view::Entire, res_cl.view()); + EXPECT_MATRIX_NEAR(correct, res, 1e-9); +} + +TEST(MathMatrixCL, double_transpose_triangular_test) { + using Eigen::MatrixXd; + using stan::math::matrix_cl; + MatrixXd m(3, 2); + m << 1.1, 1.2, 1.3, 1.4, 1.5, 1.6; + + matrix_cl m_cl(m, stan::math::matrix_cl_view::Upper); + auto tmp = stan::math::transpose(stan::math::transpose(m_cl)); + matrix_cl res_cl = tmp; + + MatrixXd res = stan::math::from_matrix_cl(res_cl); + MatrixXd correct = m.triangularView(); + EXPECT_EQ(correct.rows(), res.rows()); + EXPECT_EQ(correct.cols(), res.cols()); + EXPECT_EQ(stan::math::matrix_cl_view::Upper, res_cl.view()); + EXPECT_MATRIX_NEAR(correct, res, 1e-9); +} + +TEST(MathMatrixCL, double_transpose_accepts_lvalue_test) { + using Eigen::MatrixXd; + using stan::math::matrix_cl; + MatrixXd m(3, 2); + m << 1.1, 1.2, 1.3, 1.4, 1.5, 1.6; + + matrix_cl m_cl(m); + auto tmp2 = stan::math::transpose(m_cl); + auto tmp = stan::math::transpose(tmp2); + matrix_cl res_cl = tmp; + + MatrixXd res = stan::math::from_matrix_cl(res_cl); + MatrixXd correct = m; + EXPECT_EQ(correct.rows(), res.rows()); + EXPECT_EQ(correct.cols(), res.cols()); + EXPECT_EQ(stan::math::matrix_cl_view::Entire, res_cl.view()); + EXPECT_MATRIX_NEAR(correct, res, 1e-9); +} + +TEST(MathMatrixCL, transpose_block_test) { + using Eigen::MatrixXd; + using stan::math::matrix_cl; + MatrixXd m(5, 5); + m << 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25; + + matrix_cl m_cl(m); + auto tmp2 = stan::math::block(m_cl, 2, 1, 2, 
3); + auto tmp = stan::math::transpose(tmp2); + matrix_cl res_cl = tmp; + + MatrixXd res = stan::math::from_matrix_cl(res_cl); + MatrixXd correct = m.block(2, 1, 2, 3).transpose(); + EXPECT_EQ(correct.rows(), res.rows()); + EXPECT_EQ(correct.cols(), res.cols()); + EXPECT_EQ(stan::math::matrix_cl_view::Entire, res_cl.view()); + EXPECT_MATRIX_NEAR(correct, res, 1e-9); +} + +TEST(MathMatrixCL, block_of_transpose_test) { + using Eigen::MatrixXd; + using stan::math::matrix_cl; + MatrixXd m(5, 5); + m << 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25; + + matrix_cl m_cl(m); + auto tmp = stan::math::transpose(m_cl); + auto tmp2 = stan::math::block(tmp, 2, 1, 2, 3); + matrix_cl res_cl = tmp2; + + MatrixXd res = stan::math::from_matrix_cl(res_cl); + MatrixXd correct = m.transpose().block(2, 1, 2, 3); + EXPECT_EQ(correct.rows(), res.rows()); + EXPECT_EQ(correct.cols(), res.cols()); + EXPECT_EQ(stan::math::matrix_cl_view::Entire, res_cl.view()); + EXPECT_MATRIX_NEAR(correct, res, 1e-9); +} + +TEST(MathMatrixCL, a_plus_a_transpose_test) { + using Eigen::MatrixXd; + using stan::math::matrix_cl; + std::string kernel_filename = "a+aT.cl"; + MatrixXd m(3, 3); + m << 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9; + + matrix_cl m_cl(m); + auto tmp2 = m_cl + 1; + auto tmp = tmp2 + stan::math::transpose(tmp2); + matrix_cl res_cl; + + std::string kernel_src = tmp.get_kernel_source_for_evaluating_into(res_cl); + stan::test::store_reference_kernel_if_needed(kernel_filename, kernel_src); + std::string expected_kernel_src + = stan::test::load_reference_kernel(kernel_filename); + EXPECT_EQ(expected_kernel_src, kernel_src); + + res_cl = tmp; + + MatrixXd res = stan::math::from_matrix_cl(res_cl); + MatrixXd correct = m.array() + 2 + m.array().transpose(); + EXPECT_EQ(correct.rows(), res.rows()); + EXPECT_EQ(correct.cols(), res.cols()); + EXPECT_EQ(stan::math::matrix_cl_view::Entire, res_cl.view()); + EXPECT_MATRIX_NEAR(correct, res, 1e-9); +} + 
+#endif