-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SYCL] Implement row set collection. (#10057)
Co-authored-by: Dmitry Razdoburdin <>
- Loading branch information
1 parent
0ce4372
commit 761845f
Showing
2 changed files
with
201 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
/*! | ||
* Copyright 2017-2023 XGBoost contributors | ||
*/ | ||
#ifndef PLUGIN_SYCL_COMMON_ROW_SET_H_ | ||
#define PLUGIN_SYCL_COMMON_ROW_SET_H_ | ||
|
||
#pragma GCC diagnostic push | ||
#pragma GCC diagnostic ignored "-Wtautological-constant-compare" | ||
#pragma GCC diagnostic ignored "-W#pragma-messages" | ||
#include <xgboost/data.h> | ||
#pragma GCC diagnostic pop | ||
#include <algorithm> | ||
#include <vector> | ||
#include <utility> | ||
|
||
#include "../data.h" | ||
|
||
#include <CL/sycl.hpp> | ||
|
||
namespace xgboost { | ||
namespace sycl { | ||
namespace common { | ||
|
||
|
||
/*! \brief Collection of rowsets stored on device in USM memory */ | ||
class RowSetCollection { | ||
public: | ||
/*! \brief data structure to store an instance set, a subset of | ||
* rows (instances) associated with a particular node in a decision | ||
* tree. */ | ||
struct Elem { | ||
const size_t* begin{nullptr}; | ||
const size_t* end{nullptr}; | ||
bst_node_t node_id{-1}; // id of node associated with this instance set; -1 means uninitialized | ||
Elem() | ||
= default; | ||
Elem(const size_t* begin, | ||
const size_t* end, | ||
bst_node_t node_id = -1) | ||
: begin(begin), end(end), node_id(node_id) {} | ||
|
||
|
||
inline size_t Size() const { | ||
return end - begin; | ||
} | ||
}; | ||
|
||
inline size_t Size() const { | ||
return elem_of_each_node_.size(); | ||
} | ||
|
||
/*! \brief return corresponding element set given the node_id */ | ||
inline const Elem& operator[](unsigned node_id) const { | ||
const Elem& e = elem_of_each_node_[node_id]; | ||
CHECK(e.begin != nullptr) | ||
<< "access element that is not in the set"; | ||
return e; | ||
} | ||
|
||
/*! \brief return corresponding element set given the node_id */ | ||
inline Elem& operator[](unsigned node_id) { | ||
Elem& e = elem_of_each_node_[node_id]; | ||
return e; | ||
} | ||
|
||
// clear up things | ||
inline void Clear() { | ||
elem_of_each_node_.clear(); | ||
} | ||
// initialize node id 0->everything | ||
inline void Init() { | ||
CHECK_EQ(elem_of_each_node_.size(), 0U); | ||
|
||
const size_t* begin = row_indices_.Begin(); | ||
const size_t* end = row_indices_.End(); | ||
elem_of_each_node_.emplace_back(Elem(begin, end, 0)); | ||
} | ||
|
||
auto& Data() { return row_indices_; } | ||
|
||
// split rowset into two | ||
inline void AddSplit(unsigned node_id, | ||
unsigned left_node_id, | ||
unsigned right_node_id, | ||
size_t n_left, | ||
size_t n_right) { | ||
const Elem e = elem_of_each_node_[node_id]; | ||
CHECK(e.begin != nullptr); | ||
size_t* all_begin = row_indices_.Begin(); | ||
size_t* begin = all_begin + (e.begin - all_begin); | ||
|
||
|
||
CHECK_EQ(n_left + n_right, e.Size()); | ||
CHECK_LE(begin + n_left, e.end); | ||
CHECK_EQ(begin + n_left + n_right, e.end); | ||
|
||
|
||
if (left_node_id >= elem_of_each_node_.size()) { | ||
elem_of_each_node_.resize(left_node_id + 1, Elem(nullptr, nullptr, -1)); | ||
} | ||
if (right_node_id >= elem_of_each_node_.size()) { | ||
elem_of_each_node_.resize(right_node_id + 1, Elem(nullptr, nullptr, -1)); | ||
} | ||
|
||
|
||
elem_of_each_node_[left_node_id] = Elem(begin, begin + n_left, left_node_id); | ||
elem_of_each_node_[right_node_id] = Elem(begin + n_left, e.end, right_node_id); | ||
elem_of_each_node_[node_id] = Elem(nullptr, nullptr, -1); | ||
} | ||
|
||
private: | ||
// stores the row indexes in the set | ||
USMVector<size_t, MemoryType::on_device> row_indices_; | ||
// vector: node_id -> elements | ||
std::vector<Elem> elem_of_each_node_; | ||
}; | ||
|
||
} // namespace common | ||
} // namespace sycl | ||
} // namespace xgboost | ||
|
||
|
||
#endif // PLUGIN_SYCL_COMMON_ROW_SET_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
/** | ||
* Copyright 2020-2023 by XGBoost contributors | ||
*/ | ||
#include <gtest/gtest.h> | ||
|
||
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include "../../../plugin/sycl/common/row_set.h" | ||
#include "../../../plugin/sycl/device_manager.h" | ||
#include "../helpers.h" | ||
|
||
namespace xgboost::sycl::common { | ||
// Checks that Init() assigns every row to node 0 and that AddSplit() hands
// the first n_left rows to the left child, the rest to the right child, and
// resets the parent element.
TEST(SyclRowSetCollection, AddSplits) {
  const size_t num_rows = 16;

  DeviceManager device_manager;
  auto qu = device_manager.GetQueue(DeviceOrd::SyclDefault());

  RowSetCollection row_set_collection;

  auto& row_indices = row_set_collection.Data();
  row_indices.Resize(&qu, num_rows);
  size_t* p_row_indices = row_indices.Data();

  // Fill the row index buffer with 0, 1, ..., num_rows - 1.
  qu.submit([&](::sycl::handler& cgh) {
    cgh.parallel_for<>(::sycl::range<1>(num_rows),
                       [p_row_indices](::sycl::item<1> pid) {
      const size_t idx = pid.get_id(0);
      p_row_indices[idx] = idx;
    });
  }).wait_and_throw();
  row_set_collection.Init();

  // Compares one element of the collection against the expected range and id.
  // Uses the non-const operator[] so that cleared elements can be inspected.
  auto expect_elem = [&](unsigned nid_test, const size_t* begin, const size_t* end,
                         bst_node_t expected_id) {
    auto& elem = row_set_collection[nid_test];
    ASSERT_EQ(elem.begin, begin);
    ASSERT_EQ(elem.end, end);
    ASSERT_EQ(elem.node_id, expected_id);
  };

  ASSERT_EQ(row_set_collection.Size(), 1u);
  expect_elem(0, row_indices.Begin(), row_indices.End(), 0);

  const unsigned nid = 0;
  const unsigned nid_left = 1;
  const unsigned nid_right = 2;
  const size_t n_left = 4;
  const size_t n_right = num_rows - n_left;
  row_set_collection.AddSplit(nid, nid_left, nid_right, n_left, n_right);
  ASSERT_EQ(row_set_collection.Size(), 3u);

  // The parent is reset once its rows are distributed to the children.
  expect_elem(nid, nullptr, nullptr, -1);
  expect_elem(nid_left, row_indices.Begin(), row_indices.Begin() + n_left,
              static_cast<bst_node_t>(nid_left));
  expect_elem(nid_right, row_indices.Begin() + n_left, row_indices.End(),
              static_cast<bst_node_t>(nid_right));
}
} // namespace xgboost::sycl::common |