Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix: make kernel generator wait for events on matrices #1796

Merged
merged 15 commits into from
Mar 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion stan/math/opencl/kernel_generator/load.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <string>
#include <utility>
#include <set>
#include <vector>

namespace stan {
namespace math {
Expand Down Expand Up @@ -48,7 +49,7 @@ class load_
* Creates a deep copy of this expression.
* @return copy of \c *this
*/
inline load_<T&> deep_copy() const & { return load_<T&>(a_); }
inline load_<T&> deep_copy() const& { return load_<T&>(a_); }
inline load_<T> deep_copy() && { return load_<T>(std::forward<T>(a_)); }

/**
Expand Down Expand Up @@ -120,6 +121,31 @@ class load_
*/
inline void add_write_event(cl::Event& e) const { a_.add_write_event(e); }

/**
* Adds all read and write events on the matrix used by this expression to a
* list and clears them from the matrix.
* @param[out] events List of all events.
*/
inline void get_clear_read_write_events(
std::vector<cl::Event>& events) const {
events.insert(events.end(), a_.read_events().begin(),
a_.read_events().end());
events.insert(events.end(), a_.write_events().begin(),
a_.write_events().end());
a_.clear_read_write_events();
}

/**
* Adds all write events on the matrix used by this expression to a list and
* clears them from the matrix.
* @param[out] events List of all events.
*/
inline void get_clear_write_events(std::vector<cl::Event>& events) const {
events.insert(events.end(), a_.write_events().begin(),
a_.write_events().end());
a_.clear_write_events();
}

/**
* Number of rows of a matrix that would be the result of evaluating this
* expression.
Expand Down
27 changes: 25 additions & 2 deletions stan/math/opencl/kernel_generator/multi_result_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <tuple>
#include <utility>
#include <set>
#include <vector>

namespace stan {
namespace math {
Expand All @@ -32,6 +33,21 @@ struct multi_result_kernel_internal {
std::tuple_element_t<n, std::tuple<T_results...>>>;
using T_current_expression = std::remove_reference_t<
std::tuple_element_t<n, std::tuple<T_expressions...>>>;
/**
* Generates list of all events kernel assigning expressions to results must
* wait on. Also clears those events from matrices.
* @param[out] events list of events
* @param results results
* @param expressions expressions
*/
static void get_clear_events(
std::vector<cl::Event>& events,
const std::tuple<wrapper<T_results>...>& results,
const std::tuple<wrapper<T_expressions>...>& expressions) {
next::get_clear_events(events, results, expressions);
std::get<n>(expressions).x.get_clear_write_events(events);
std::get<n>(results).x.get_clear_read_write_events(events);
}
/**
* Assigns the dimensions of expressions to matching results if possible.
* Otherwise checks that dimensions match. Also checks that all expressions
Expand Down Expand Up @@ -139,6 +155,11 @@ template <typename... T_results>
struct multi_result_kernel_internal<-1, T_results...> {
template <typename... T_expressions>
struct inner {
static void get_clear_events(
std::vector<cl::Event>& events,
const std::tuple<wrapper<T_results>...>& results,
const std::tuple<wrapper<T_expressions>...>& expressions) {}

static void check_assign_dimensions(
int n_rows, int n_cols,
const std::tuple<wrapper<T_results>...>& results,
Expand Down Expand Up @@ -406,6 +427,8 @@ class results_cl {
std::set<const operation_cl_base*> generated;
impl::set_args(generated, kernel, arg_num, results, expressions);

std::vector<cl::Event> events;
impl::get_clear_events(events, results, expressions);
cl::Event e;
if (require_specific_local_size) {
kernel.setArg(arg_num++, n_rows);
Expand All @@ -416,11 +439,11 @@ class results_cl {

opencl_context.queue().enqueueNDRangeKernel(
kernel, cl::NullRange, cl::NDRange(local * wgs_rows, wgs_cols),
cl::NDRange(local, 1), nullptr, &e);
cl::NDRange(local, 1), &events, &e);
} else {
opencl_context.queue().enqueueNDRangeKernel(kernel, cl::NullRange,
cl::NDRange(n_rows, n_cols),
cl::NullRange, nullptr, &e);
cl::NullRange, &events, &e);
}
impl::add_event(e, results, expressions);
} catch (cl::Error e) {
Expand Down
17 changes: 15 additions & 2 deletions stan/math/opencl/kernel_generator/operation_cl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <set>
#include <array>
#include <numeric>
#include <vector>

namespace stan {
namespace math {
Expand Down Expand Up @@ -249,8 +250,20 @@ class operation_cl : public operation_cl_base {
*/
inline void add_read_event(cl::Event& e) const {
index_apply<N>([&](auto... Is) {
(void)std::initializer_list<int>{
(this->get_arg<Is>().add_read_event(e), 0)...};
static_cast<void>(std::initializer_list<int>{
(this->get_arg<Is>().add_read_event(e), 0)...});
});
}

/**
* Adds all write events on any matrices used by nested expressions to a list
* and clears them from those matrices.
* @param[out] events List of all events.
*/
inline void get_clear_write_events(std::vector<cl::Event>& events) const {
index_apply<N>([&](auto... Is) {
static_cast<void>(std::initializer_list<int>{
(this->template get_arg<Is>().get_clear_write_events(events), 0)...});
});
}

Expand Down
27 changes: 21 additions & 6 deletions stan/math/opencl/kernel_generator/operation_cl_lhs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <set>
#include <array>
#include <numeric>
#include <vector>

namespace stan {
namespace math {
Expand Down Expand Up @@ -89,11 +90,11 @@ class operation_cl_lhs : public operation_cl<Derived, Scalar, Args...> {
inline void set_view(int bottom_diagonal, int top_diagonal,
int bottom_zero_diagonal, int top_zero_diagonal) const {
index_apply<N>([&](auto... Is) {
(void)std::initializer_list<int>{
static_cast<void>(std::initializer_list<int>{
(this->template get_arg<Is>().set_view(bottom_diagonal, top_diagonal,
bottom_zero_diagonal,
top_zero_diagonal),
0)...};
0)...});
});
}

Expand All @@ -107,9 +108,9 @@ class operation_cl_lhs : public operation_cl<Derived, Scalar, Args...> {
*/
inline void check_assign_dimensions(int rows, int cols) const {
index_apply<N>([&](auto... Is) {
(void)std::initializer_list<int>{
static_cast<void>(std::initializer_list<int>{
(this->template get_arg<Is>().check_assign_dimensions(rows, cols),
0)...};
0)...});
});
}

Expand All @@ -119,8 +120,22 @@ class operation_cl_lhs : public operation_cl<Derived, Scalar, Args...> {
*/
inline void add_write_event(cl::Event& e) const {
index_apply<N>([&](auto... Is) {
(void)std::initializer_list<int>{
(this->template get_arg<Is>().add_write_event(e), 0)...};
static_cast<void>(std::initializer_list<int>{
(this->template get_arg<Is>().add_write_event(e), 0)...});
});
}

/**
* Adds all read and write events on any matrices used by nested expressions
* to a list and clears them from those matrices.
* @param[out] events List of all events.
*/
inline void get_clear_read_write_events(
std::vector<cl::Event>& events) const {
index_apply<N>([&](auto... Is) {
static_cast<void>(std::initializer_list<int>{
(this->template get_arg<Is>().get_clear_read_write_events(events),
0)...});
});
}
};
Expand Down
38 changes: 38 additions & 0 deletions test/unit/math/opencl/kernel_generator/operation_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,42 @@ TEST(MathMatrixCL, kernel_caching) {
EXPECT_EQ(unused_cache::kernel_(), nullptr);
}

TEST(MathMatrixCL, events_write_after_write) {
using stan::math::matrix_cl;
matrix_cl<double> zero_cl(3, 3);
zero_cl.zeros();
zero_cl.wait_for_read_write_events();

for (int j = 0; j < 3000; j++) {
matrix_cl<double> m_cl(3, 3);

for (int i = 0; i < 4; i++) {
m_cl = zero_cl + i;
}

Eigen::MatrixXd res = stan::math::from_matrix_cl(m_cl);
Eigen::MatrixXd correct = Eigen::MatrixXd::Constant(3, 3, 3);

EXPECT_MATRIX_NEAR(res, correct, 1e-13);
}
}

TEST(MathMatrixCL, events_read_after_write_and_write_after_read) {
using stan::math::matrix_cl;
int iters = 3000;

matrix_cl<double> m1_cl(3, 3);
matrix_cl<double> m2_cl(3, 3);
m1_cl.zeros();

for (int j = 0; j < iters; j++) {
m2_cl = m1_cl + 1;
m1_cl = m2_cl + 1;
}
Eigen::MatrixXd res = stan::math::from_matrix_cl(m1_cl);
Eigen::MatrixXd correct = Eigen::MatrixXd::Constant(3, 3, 2 * iters);

EXPECT_MATRIX_NEAR(res, correct, 1e-13);
}

#endif