From d7a7959c529b18a18794ad3dfaf61469d0a97fc5 Mon Sep 17 00:00:00 2001
From: Dmitry Razdoburdin <>
Date: Mon, 19 Feb 2024 23:29:13 -0800
Subject: [PATCH] initial

---
 plugin/sycl/common/row_set.h                  | 144 ++++++++++++++++++
 .../plugin/test_sycl_row_set_collection.cc    |  78 ++++++++++
 2 files changed, 222 insertions(+)
 create mode 100644 plugin/sycl/common/row_set.h
 create mode 100644 tests/cpp/plugin/test_sycl_row_set_collection.cc

diff --git a/plugin/sycl/common/row_set.h b/plugin/sycl/common/row_set.h
new file mode 100644
index 000000000000..574adbf8d9b9
--- /dev/null
+++ b/plugin/sycl/common/row_set.h
@@ -0,0 +1,144 @@
+/*!
+ * Copyright 2017-2023 XGBoost contributors
+ */
+#ifndef PLUGIN_SYCL_COMMON_ROW_SET_H_
+#define PLUGIN_SYCL_COMMON_ROW_SET_H_
+
+// NOTE(review): the <...> include targets were missing from this copy of the patch.
+// <algorithm> and <vector> are needed by the code below; the other targets are a
+// best guess - confirm against the original commit.
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
+#pragma GCC diagnostic ignored "-W#pragma-messages"
+#include <xgboost/data.h>
+#pragma GCC diagnostic pop
+#include <algorithm>
+#include <vector>
+#include <utility>
+
+#include "../data.h"
+
+#include <CL/sycl.hpp>
+
+namespace xgboost {
+namespace sycl {
+namespace common {
+
+/*! \brief Collection of rowsets stored on device in USM memory */
+class RowSetCollection {
+ public:
+  /*! \brief data structure to store an instance set, a subset of
+   *  rows (instances) associated with a particular node in a decision
+   *  tree. begin/end are non-owning pointers; RowSetCollection fills them
+   *  with ranges of its own row index buffer. */
+  struct Elem {
+    const size_t* begin{nullptr};
+    const size_t* end{nullptr};
+    // id of node associated with this instance set; -1 means uninitialized
+    bst_node_t node_id{-1};
+
+    Elem() = default;
+    Elem(const size_t* begin, const size_t* end, bst_node_t node_id = -1)
+        : begin(begin), end(end), node_id(node_id) {}
+
+    /*! \brief number of rows in this instance set */
+    inline size_t Size() const {
+      return static_cast<size_t>(end - begin);
+    }
+  };
+
+  /*! \brief number of node slots in the collection, empty ones included */
+  inline size_t Size() const {
+    return elem_of_each_node_.size();
+  }
+
+  /*! \brief return corresponding element set given the node_id;
+   *  fails if node_id is out of range or the element is empty */
+  inline const Elem& operator[](unsigned node_id) const {
+    CHECK_LT(node_id, elem_of_each_node_.size());
+    const Elem& e = elem_of_each_node_[node_id];
+    CHECK(e.begin != nullptr)
+        << "access element that is not in the set";
+    return e;
+  }
+
+  /*! \brief return corresponding element set given the node_id;
+   *  unlike the const overload, an empty element is returned as is */
+  inline Elem& operator[](unsigned node_id) {
+    CHECK_LT(node_id, elem_of_each_node_.size());
+    return elem_of_each_node_[node_id];
+  }
+
+  /*! \brief remove all elements; the row index buffer is not touched */
+  inline void Clear() {
+    elem_of_each_node_.clear();
+  }
+
+  /*! \brief assign the whole row index buffer to node 0;
+   *  the collection must not contain any element yet */
+  inline void Init() {
+    CHECK_EQ(elem_of_each_node_.size(), 0U);
+
+    const size_t* begin = row_indices_.Begin();
+    const size_t* end = row_indices_.End();
+    elem_of_each_node_.emplace_back(begin, end, 0);
+  }
+
+  /*! \brief access the row index buffer, e.g. to size and fill it before Init() */
+  auto& Data() { return row_indices_; }
+
+  /*!
+   * \brief split the rowset of a node into two child rowsets
+   *
+   *  The first n_left row indices of the parent go to the left child, the
+   *  remaining n_right to the right child, and the parent element becomes
+   *  empty. Only the element table is updated; row indices are not moved.
+   *
+   * \param node_id       node being split; its element must not be empty
+   * \param left_node_id  id of the left child
+   * \param right_node_id id of the right child
+   * \param n_left        number of rows assigned to the left child
+   * \param n_right       number of rows assigned to the right child;
+   *                      n_left + n_right must equal the parent's row count
+   */
+  inline void AddSplit(unsigned node_id,
+                       unsigned left_node_id,
+                       unsigned right_node_id,
+                       size_t n_left,
+                       size_t n_right) {
+    CHECK_LT(node_id, elem_of_each_node_.size());
+    // Copy by value: the resize below may reallocate the element table.
+    const Elem e = elem_of_each_node_[node_id];
+    CHECK(e.begin != nullptr);
+
+    // Validate the counts themselves so no out-of-range pointer is ever formed.
+    const size_t n_rows = e.Size();
+    CHECK_LE(n_left, n_rows);
+    CHECK_EQ(n_right, n_rows - n_left);
+    const size_t* split = e.begin + n_left;
+
+    // Grow the table once so that both child ids are valid slots.
+    const size_t n_slots = static_cast<size_t>(std::max(left_node_id, right_node_id)) + 1;
+    if (n_slots > elem_of_each_node_.size()) {
+      elem_of_each_node_.resize(n_slots, Elem(nullptr, nullptr, -1));
+    }
+
+    elem_of_each_node_[left_node_id] = Elem(e.begin, split, left_node_id);
+    elem_of_each_node_[right_node_id] = Elem(split, e.end, right_node_id);
+    elem_of_each_node_[node_id] = Elem(nullptr, nullptr, -1);
+  }
+
+ private:
+  // stores the row indexes in the set
+  // NOTE(review): template arguments were missing from this copy of the patch; the
+  // element type follows from Begin()/End() usage, further arguments may exist - confirm.
+  USMVector<size_t> row_indices_;
+  // vector: node_id -> elements
+  std::vector<Elem> elem_of_each_node_;
+};
+
+}  // namespace common
+}  // namespace sycl
+}  // namespace xgboost
+
+
+#endif  //
PLUGIN_SYCL_COMMON_ROW_SET_H_
diff --git a/tests/cpp/plugin/test_sycl_row_set_collection.cc b/tests/cpp/plugin/test_sycl_row_set_collection.cc
new file mode 100644
index 000000000000..f527d9f16d1b
--- /dev/null
+++ b/tests/cpp/plugin/test_sycl_row_set_collection.cc
@@ -0,0 +1,78 @@
+/**
+ * Copyright 2020-2023 by XGBoost contributors
+ */
+// NOTE(review): the <...> include targets were missing from this copy of the patch;
+// gtest is required by TEST(), the three standard headers are a best guess - confirm.
+#include <gtest/gtest.h>
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "../../../plugin/sycl/common/row_set.h"
+#include "../../../plugin/sycl/device_manager.h"
+#include "../helpers.h"
+
+namespace xgboost::sycl::common {
+// Init() must create one root element spanning all rows; AddSplit() must give the first
+// n_left rows to the left child, the rest to the right, and reset the parent to empty.
+TEST(SyclRowSetCollection, AddSplits) {
+  constexpr size_t kNumRows = 16;
+
+  DeviceManager device_manager;
+  auto qu = device_manager.GetQueue(DeviceOrd::SyclDefault());
+
+  RowSetCollection row_set_collection;
+  auto& row_indices = row_set_collection.Data();
+  row_indices.Resize(&qu, kNumRows);
+  size_t* p_row_indices = row_indices.Data();
+
+  // Fill the row index buffer with 0..kNumRows-1 from a SYCL kernel.
+  qu.submit([&](::sycl::handler& cgh) {
+    cgh.parallel_for<>(::sycl::range<1>(kNumRows),
+                       [p_row_indices](::sycl::item<1> pid) {
+      const size_t idx = pid.get_id(0);
+      p_row_indices[idx] = idx;
+    });
+  }).wait_and_throw();
+  row_set_collection.Init();
+
+  ASSERT_EQ(row_set_collection.Size(), 1u);
+  {
+    const auto& root = row_set_collection[0];
+    ASSERT_EQ(root.begin, row_indices.Begin());
+    ASSERT_EQ(root.end, row_indices.End());
+    ASSERT_EQ(root.node_id, 0);
+  }
+
+  const bst_node_t nid = 0;
+  const bst_node_t nid_left = 1;
+  const bst_node_t nid_right = 2;
+  const size_t n_left = 4;
+  const size_t n_right = kNumRows - n_left;
+  row_set_collection.AddSplit(nid, nid_left, nid_right, n_left, n_right);
+  ASSERT_EQ(row_set_collection.Size(), 3u);
+
+  {
+    // Non-const operator[] is selected here; unlike the const one it accepts empty elements.
+    const auto& parent = row_set_collection[nid];
+    ASSERT_EQ(parent.begin, nullptr);
+    ASSERT_EQ(parent.end, nullptr);
+    ASSERT_EQ(parent.node_id, -1);
+  }
+
+  {
+    const auto& left = row_set_collection[nid_left];
+    ASSERT_EQ(left.begin, row_indices.Begin());
+    ASSERT_EQ(left.end, row_indices.Begin() + n_left);
+    ASSERT_EQ(left.node_id, nid_left);
+  }
+
+  {
+    const auto& right = row_set_collection[nid_right];
+    ASSERT_EQ(right.begin, row_indices.Begin() + n_left);
+    ASSERT_EQ(right.end, row_indices.End());
+    ASSERT_EQ(right.node_id, nid_right);
+  }
+}
+}  // namespace xgboost::sycl::common