-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Initial support for column-split cpu predictor #8676
Merged
Merged
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
9908028
add predict batch test for column split
rongou 4599dca
fix context
rongou acfb694
set up in-memory communicator
rongou 893d140
support cpu prediction with column split
rongou fdb6418
fix test interaction
rongou 69f8bd9
encapsulate allreduce calls
rongou 7908b29
fix test
rongou 693a6ed
extract a help class for column split
rongou 553ade4
add comments
rongou 20bfbcb
try to fix windows build
rongou a39375f
Merge remote-tracking branch 'upstream/master' into split-col-pred
rongou 2de9bb3
Merge remote-tracking branch 'upstream/master' into split-col-pred
rongou 205d19e
add review comments
rongou File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,12 +8,12 @@ | |
#include <limits> | ||
#include <mutex> | ||
|
||
#include "../collective/communicator-inl.h" | ||
#include "../common/categorical.h" | ||
#include "../common/math.h" | ||
#include "../common/threading_utils.h" | ||
#include "../data/adapter.h" | ||
#include "../data/gradient_index.h" | ||
#include "../data/proxy_dmatrix.h" | ||
#include "../gbm/gbtree_model.h" | ||
#include "cpu_treeshap.h" // CalculateContributions | ||
#include "predict_fn.h" | ||
|
@@ -23,7 +23,6 @@ | |
#include "xgboost/logging.h" | ||
#include "xgboost/predictor.h" | ||
#include "xgboost/tree_model.h" | ||
#include "xgboost/tree_updater.h" | ||
|
||
namespace xgboost { | ||
namespace predictor { | ||
|
@@ -284,16 +283,277 @@ void FillNodeMeanValues(RegTree const* tree, std::vector<float>* mean_values) { | |
FillNodeMeanValues(tree, 0, mean_values); | ||
} | ||
|
||
class CPUPredictor : public Predictor { | ||
protected: | ||
// init thread buffers | ||
static void InitThreadTemp(int nthread, std::vector<RegTree::FVec> *out) { | ||
int prev_thread_temp_size = out->size(); | ||
if (prev_thread_temp_size < nthread) { | ||
out->resize(nthread, RegTree::FVec()); | ||
namespace { | ||
/** | ||
 * @brief Grow the feature-vector scratch buffer so it holds at least `nthread` entries. | ||
 * | ||
 * The buffer is never shrunk; existing entries are left untouched. | ||
 * | ||
 * @param nthread Minimum number of entries required; non-positive values are a no-op. | ||
 * @param out     Buffer to grow in place. | ||
 */ | ||
void InitThreadTemp(int nthread, std::vector<RegTree::FVec> *out) { | ||
  // The unnamed namespace already gives internal linkage, so `static` is not needed. | ||
  // Compare as std::size_t so that size() is not narrowed to int. | ||
  if (nthread > 0 && out->size() < static_cast<std::size_t>(nthread)) { | ||
    out->resize(nthread, RegTree::FVec()); | ||
  } | ||
} | ||
}  // anonymous namespace | ||
|
||
/** | ||
* @brief A helper class for prediction when the DMatrix is split by column. | ||
* | ||
* When data is split by column, a local DMatrix only contains a subset of features. All the workers | ||
* in a distributed/federated environment need to cooperate to produce a prediction. This is done in | ||
* two passes with the help of bit vectors. | ||
* | ||
* First pass: | ||
* for each tree: | ||
* for each row: | ||
* for each node: | ||
* if the feature is available and passes the filter, mark the corresponding decision bit | ||
* if the feature is missing, mark the missing bit | ||
* | ||
* Once the two bit vectors are populated, run allreduce on both, using bitwise OR for the decision | ||
* bits, and bitwise AND for the missing bits. | ||
* | ||
* Second pass: | ||
* for each tree: | ||
* for each row: | ||
* find the leaf node using the decision and missing bits, return the leaf value | ||
* | ||
* The size of the decision/missing bit vector is: | ||
* number of rows in a batch * sum(number of nodes in each tree) | ||
*/ | ||
class ColumnSplitHelper { | ||
public: | ||
/** | ||
 * @brief Precompute the bit-vector layout for trees [tree_begin, tree_end) of `model`. | ||
 * | ||
 * Records the node count of every tree in the range and the exclusive prefix sum of those | ||
 * counts, then sizes the feature-vector scratch buffer. | ||
 * | ||
 * @param n_threads  Number of threads handed to the parallel loops. | ||
 * @param model      Tree model to predict with; held by reference. | ||
 * @param tree_begin Index of the first tree to use. | ||
 * @param tree_end   One past the index of the last tree to use. | ||
 */ | ||
ColumnSplitHelper(std::int32_t n_threads, gbm::GBTreeModel const &model, uint32_t tree_begin, | ||
                  uint32_t tree_end) | ||
    : n_threads_{n_threads}, model_{model}, tree_begin_{tree_begin}, tree_end_{tree_end} { | ||
  // An inverted range would wrap around in unsigned arithmetic and request a huge allocation. | ||
  CHECK(tree_begin_ <= tree_end_) << "Invalid tree range for column-split prediction."; | ||
  std::size_t const n_trees = tree_end_ - tree_begin_; | ||
  tree_sizes_.resize(n_trees); | ||
  tree_offsets_.resize(n_trees); | ||
| 
||
  // Single pass: store each tree's node count and the running (exclusive) sum before it. | ||
  // An empty range skips the loop, so no element of an empty vector is ever touched. | ||
  std::size_t running_offset = 0; | ||
  for (std::size_t i = 0; i < n_trees; ++i) { | ||
    tree_sizes_[i] = model_.trees[tree_begin_ + i]->GetNodes().size(); | ||
    tree_offsets_[i] = running_offset; | ||
    running_offset += tree_sizes_[i]; | ||
  } | ||
  // Total number of nodes over all selected trees, i.e. bits needed per row. | ||
  bits_per_row_ = running_offset; | ||
| 
||
  InitThreadTemp(n_threads_ * kBlockOfRowsSize, &feat_vecs_); | ||
} | ||
|
||
// Disable copy (and move) semantics. | ||
// The helper holds a reference to the model, and its BitVector members are built from spans | ||
// over this object's own storage vectors (see InitBitVectors). | ||
// NOTE(review): presumably a copied or moved helper would keep views into the source | ||
// object's storage -- confirm that is the reason these are deleted. | ||
ColumnSplitHelper(ColumnSplitHelper const &) = delete; | ||
ColumnSplitHelper &operator=(ColumnSplitHelper const &) = delete; | ||
ColumnSplitHelper(ColumnSplitHelper &&) noexcept = delete; | ||
ColumnSplitHelper &operator=(ColumnSplitHelper &&) noexcept = delete; | ||
|
||
/** | ||
 * @brief Run column-split prediction over every SparsePage batch of `p_fmat`. | ||
 * | ||
 * Leaf values are accumulated into `out_preds`, which must already hold | ||
 * num_row_ * num_output_group entries. | ||
 * | ||
 * @param p_fmat    Local data matrix. | ||
 * @param out_preds Prediction buffer that is added to in place. | ||
 */ | ||
void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds) { | ||
  CHECK(xgboost::collective::IsDistributed()) | ||
      << "column-split prediction is only supported for distributed training"; | ||
| 
||
  for (auto const &page : p_fmat->GetBatches<SparsePage>()) { | ||
    // The output buffer size is re-checked for every page before predicting on it. | ||
    CHECK_EQ(out_preds->size(), | ||
             p_fmat->Info().num_row_ * model_.learner_model_param->num_output_group); | ||
    PredictBatchKernel<SparsePageView, kBlockOfRowsSize>(SparsePageView{&page}, out_preds); | ||
  } | ||
} | ||
|
||
private: | ||
// Bit-field view type used for both the decision bits and the missing bits. | ||
using BitVector = RBitField8; | ||
|
||
/** | ||
 * @brief Size both bit vectors for a batch of `n_rows` rows and rebind the bit views. | ||
 * | ||
 * @param n_rows Number of rows in the batch about to be processed. | ||
 */ | ||
void InitBitVectors(std::size_t n_rows) { | ||
  using StorageSpan = common::Span<BitVector::value_type>; | ||
  n_rows_ = n_rows; | ||
  auto const n_words = BitVector::ComputeStorageSize(bits_per_row_ * n_rows_); | ||
  // resize() may reallocate, so each view is rebuilt after its storage is resized. | ||
  decision_storage_.resize(n_words); | ||
  missing_storage_.resize(n_words); | ||
  decision_bits_ = BitVector(StorageSpan(decision_storage_)); | ||
  missing_bits_ = BitVector(StorageSpan(missing_storage_)); | ||
} | ||
|
||
/** @brief Zero every word of both bit vectors while keeping their current size. */ | ||
void ClearBitVectors() { | ||
  decision_storage_.assign(decision_storage_.size(), BitVector::value_type{0}); | ||
  missing_storage_.assign(missing_storage_.size(), BitVector::value_type{0}); | ||
} | ||
|
||
/** | ||
 * @brief Flat bit position of (tree, row, node) in the decision/missing bit vectors. | ||
 * | ||
 * @param tree_id Absolute tree index in the model (not relative to `tree_begin_`). | ||
 * @param row_id  Row index within the current batch. | ||
 * @param node_id Node index within the tree. | ||
 */ | ||
std::size_t BitIndex(std::size_t tree_id, std::size_t row_id, std::size_t node_id) const { | ||
  std::size_t const local_tree = tree_id - tree_begin_; | ||
  // Start of this tree's region: all rows of all preceding trees. | ||
  std::size_t const tree_base = tree_offsets_[local_tree] * n_rows_; | ||
  // Start of this row inside the tree's region. | ||
  std::size_t const row_base = row_id * tree_sizes_[local_tree]; | ||
  return tree_base + row_base + node_id; | ||
} | ||
|
||
/** | ||
 * @brief Combine the local bit vectors with those of the other workers. | ||
 * | ||
 * Decision bits are merged with bitwise OR, missing bits with bitwise AND. | ||
 */ | ||
void AllreduceBitVectors() { | ||
  using collective::Operation; | ||
  auto *const decision_words = decision_storage_.data(); | ||
  auto const n_decision_words = decision_storage_.size(); | ||
  collective::Allreduce<Operation::kBitwiseOR>(decision_words, n_decision_words); | ||
| 
||
  auto *const missing_words = missing_storage_.data(); | ||
  auto const n_missing_words = missing_storage_.size(); | ||
  collective::Allreduce<Operation::kBitwiseAND>(missing_words, n_missing_words); | ||
} | ||
|
||
/** | ||
 * @brief First pass for one (tree, row) pair: record this worker's split outcomes as bits. | ||
 * | ||
 * For every non-deleted, non-leaf node the missing bit is set when the split feature is | ||
 * missing in `feat`; otherwise the decision bit is set when the split test selects the | ||
 * child that GetNextNode reaches via `LeftChild() + 1`. | ||
 * | ||
 * @param feat    Feature vector of the row as seen by this worker. | ||
 * @param tree_id Absolute tree index in the model. | ||
 * @param row_id  Row index within the current batch. | ||
 */ | ||
void MaskOneTree(RegTree::FVec const &feat, std::size_t tree_id, std::size_t row_id) { | ||
  auto const &tree = *model_.trees[tree_id]; | ||
  auto const &cats = tree.GetCategoriesMatrix(); | ||
  auto const has_categorical = tree.HasCategoricalSplit(); | ||
  // Hoisted out of the loop and converted once, so the loop compares like-signed values. | ||
  auto const n_nodes = static_cast<bst_node_t>(tree.GetNodes().size()); | ||
| 
||
  for (bst_node_t nid = 0; nid < n_nodes; ++nid) { | ||
    auto const &node = tree[nid]; | ||
    if (node.IsDeleted() || node.IsLeaf()) { | ||
      continue; | ||
    } | ||
| 
||
    auto const bit_index = BitIndex(tree_id, row_id, nid); | ||
    unsigned split_index = node.SplitIndex(); | ||
    if (feat.IsMissing(split_index)) { | ||
      missing_bits_.Set(bit_index); | ||
      continue; | ||
    } | ||
| 
||
    auto const fvalue = feat.GetFvalue(split_index); | ||
    bool decision = false; | ||
    if (has_categorical && common::IsCat(cats.split_type, nid)) { | ||
      auto const node_categories = | ||
          cats.categories.subspan(cats.node_ptr[nid].beg, cats.node_ptr[nid].size); | ||
      decision = !common::Decision(node_categories, fvalue); | ||
    } else { | ||
      decision = fvalue >= node.SplitCond(); | ||
    } | ||
    if (decision) { | ||
      decision_bits_.Set(bit_index); | ||
    } | ||
  } | ||
} | ||
|
||
/** | ||
 * @brief Run MaskOneTree for every selected tree over one block of rows. | ||
 * | ||
 * @param batch_offset Index of the block's first row within the batch. | ||
 * @param fvec_offset  Index of the block's first feature vector in `feat_vecs_`. | ||
 * @param block_size   Number of rows in the block. | ||
 */ | ||
void MaskAllTrees(std::size_t batch_offset, std::size_t fvec_offset, std::size_t block_size) { | ||
  for (uint32_t tree_id = tree_begin_; tree_id < tree_end_; ++tree_id) { | ||
    for (std::size_t r = 0; r < block_size; ++r) { | ||
      auto const &row_features = feat_vecs_[fvec_offset + r]; | ||
      MaskOneTree(row_features, tree_id, batch_offset + r); | ||
    } | ||
  } | ||
} | ||
|
||
/** | ||
 * @brief Pick the child to visit after `node`, using the bits at `bit_index`. | ||
 * | ||
 * A set missing bit selects the default child; otherwise the decision bit is added to the | ||
 * left child's index. | ||
 */ | ||
bst_node_t GetNextNode(RegTree::Node const &node, std::size_t bit_index) { | ||
  bool const is_missing = missing_bits_.Check(bit_index); | ||
  return is_missing ? node.DefaultChild() | ||
                    : node.LeftChild() + decision_bits_.Check(bit_index); | ||
} | ||
|
||
/** | ||
 * @brief Walk `tree` from the root to a leaf for one row, driven by the bit vectors. | ||
 * | ||
 * @return Node index of the leaf that was reached. | ||
 */ | ||
bst_node_t GetLeafIndex(RegTree const &tree, std::size_t tree_id, std::size_t row_id) { | ||
  bst_node_t current = 0; | ||
  for (;;) { | ||
    auto const &node = tree[current]; | ||
    if (node.IsLeaf()) { | ||
      return current; | ||
    } | ||
    current = GetNextNode(node, BitIndex(tree_id, row_id, current)); | ||
  } | ||
} | ||
|
||
/** @brief Leaf value that tree `tree_id` yields for batch row `row_id`. */ | ||
bst_float PredictOneTree(std::size_t tree_id, std::size_t row_id) { | ||
  auto const &tree = *model_.trees[tree_id]; | ||
  return tree[GetLeafIndex(tree, tree_id, row_id)].LeafValue(); | ||
} | ||
|
||
/** | ||
 * @brief Second pass for one block of rows: add each tree's leaf value to the output. | ||
 * | ||
 * @param out_preds      Prediction buffer, indexed as row * num_group + group. | ||
 * @param batch_offset   Index of the block's first row within the batch. | ||
 * @param predict_offset Index of the block's first row within `out_preds`. | ||
 * @param num_group      Number of output groups per row. | ||
 * @param block_size     Number of rows in the block. | ||
 */ | ||
void PredictAllTrees(std::vector<bst_float> *out_preds, std::size_t batch_offset, | ||
                     std::size_t predict_offset, std::size_t num_group, std::size_t block_size) { | ||
  for (size_t tree_id = tree_begin_; tree_id < tree_end_; ++tree_id) { | ||
    // Output group that this tree contributes to. | ||
    auto const gid = model_.tree_info[tree_id]; | ||
    for (size_t r = 0; r < block_size; ++r) { | ||
      auto const out_index = (predict_offset + r) * num_group + gid; | ||
      (*out_preds)[out_index] += PredictOneTree(tree_id, batch_offset + r); | ||
    } | ||
  } | ||
} | ||
|
||
// Predict one batch in two parallel passes separated by an allreduce. | ||
// Pass 1 fills the bit vectors from the locally available features, the allreduce merges | ||
// them across workers, and pass 2 walks the trees using the merged bits and adds the leaf | ||
// values to `out_preds`. The bit vectors are cleared before returning. | ||
// `batch` is the data view for this batch; `out_preds` is added to in place, at row | ||
// offset `batch.base_rowid`. | ||
// NOTE(review): pass 1 calls Set() on the shared bit vectors from several threads; | ||
// confirm that bits written by different row blocks can never share a storage word. | ||
template <typename DataView, size_t block_of_rows_size> | ||
void PredictBatchKernel(DataView batch, std::vector<bst_float> *out_preds) { | ||
auto const num_group = model_.learner_model_param->num_output_group; | ||
|
||
CHECK_EQ(model_.param.size_leaf_vector, 0) << "size_leaf_vector is enforced to 0 so far"; | ||
// parallel over local batch | ||
auto const nsize = batch.Size(); | ||
auto const num_feature = model_.learner_model_param->num_feature; | ||
auto const n_blocks = common::DivRoundUp(nsize, block_of_rows_size); | ||
// Size the bit vectors for the rows of this batch. | ||
InitBitVectors(nsize); | ||
|
||
// Pass 1: each block of rows marks its decision/missing bits for all trees. | ||
// auto block_id has the same type as `n_blocks`. | ||
common::ParallelFor(n_blocks, n_threads_, [&](auto block_id) { | ||
auto const batch_offset = block_id * block_of_rows_size; | ||
// The last block may hold fewer than block_of_rows_size rows. | ||
auto const block_size = std::min(nsize - batch_offset, block_of_rows_size); | ||
// Each thread works in its own block_of_rows_size-wide slice of feat_vecs_. | ||
auto const fvec_offset = omp_get_thread_num() * block_of_rows_size; | ||
|
||
FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset, &feat_vecs_); | ||
MaskAllTrees(batch_offset, fvec_offset, block_size); | ||
FVecDrop(block_size, batch_offset, &batch, fvec_offset, &feat_vecs_); | ||
}); | ||
|
||
// Merge the bits with the other workers before any tree is traversed. | ||
AllreduceBitVectors(); | ||
|
||
// Pass 2: traverse the trees using the merged bits and accumulate leaf values. | ||
// auto block_id has the same type as `n_blocks`. | ||
common::ParallelFor(n_blocks, n_threads_, [&](auto block_id) { | ||
auto const batch_offset = block_id * block_of_rows_size; | ||
auto const block_size = std::min(nsize - batch_offset, block_of_rows_size); | ||
PredictAllTrees(out_preds, batch_offset, batch_offset + batch.base_rowid, num_group, | ||
block_size); | ||
}); | ||
|
||
// Zero the bits so the next batch starts from a clean state. | ||
ClearBitVectors(); | ||
} | ||
|
||
// Number of rows handled by one parallel task. | ||
static std::size_t constexpr kBlockOfRowsSize = 64; | ||
| 
||
// Thread count passed to the parallel loops. | ||
std::int32_t const n_threads_; | ||
// Model being predicted with; not owned. | ||
gbm::GBTreeModel const &model_; | ||
// Selected tree range is [tree_begin_, tree_end_). | ||
uint32_t const tree_begin_; | ||
uint32_t const tree_end_; | ||
| 
||
// Node count of each selected tree. | ||
std::vector<std::size_t> tree_sizes_{}; | ||
// Exclusive prefix sum of `tree_sizes_`. | ||
std::vector<std::size_t> tree_offsets_{}; | ||
// Sum of all selected trees' node counts. | ||
std::size_t bits_per_row_{}; | ||
// Scratch feature vectors, sized for n_threads_ * kBlockOfRowsSize entries. | ||
std::vector<RegTree::FVec> feat_vecs_{}; | ||
| 
||
// Number of rows in the current batch; set by InitBitVectors. Value-initialized like the | ||
// other members so BitIndex never reads an indeterminate value. | ||
std::size_t n_rows_{}; | ||
/** | ||
* @brief Stores decision bit for each split node. | ||
* | ||
* A bit is set when `fvalue >= SplitCond()` for a numerical split, or when | ||
* `common::Decision` returns false for a categorical split; GetNextNode then moves to | ||
* `LeftChild() + 1` instead of `LeftChild()`. | ||
* | ||
* Conceptually it's a 3-dimensional bit matrix: | ||
* - 1st dimension is the tree index, from `tree_begin_` to `tree_end_`. | ||
* - 2nd dimension is the row index, for each row in the batch. | ||
* - 3rd dimension is the node id, for each node in the tree. | ||
* | ||
* Since we have to ship the whole thing over the wire to do an allreduce, the matrix is flattened | ||
* into a 1-dimensional array. | ||
* | ||
* First, it's divided by the tree index: | ||
* | ||
* [ tree 0 ] [ tree 1 ] ... | ||
* | ||
* Then each tree is divided by row: | ||
* | ||
* [ tree 0 ] [ tree 1 ] ... | ||
* [ row 0 ] [ row 1 ] ... [ row n-1 ] [ row 0 ] ... | ||
* | ||
* Finally, each row is divided by the node id: | ||
* | ||
* [ tree 0 ] | ||
* [ row 0 ] [ row 1 ] ... | ||
* [ node 0 ] [ node 1 ] ... [ node n-1 ] [ node 0 ] ... | ||
* | ||
* The first two dimensions are fixed length, while the last dimension is variable length since | ||
* each tree may have a different number of nodes. We precompute the tree offsets, which are the | ||
* cumulative sums of tree sizes. The index of tree t, row r, node n is: | ||
* index(t, r, n) = tree_offsets[t] * n_rows + r * tree_sizes[t] + n | ||
*/ | ||
std::vector<BitVector::value_type> decision_storage_{}; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you please add some comments about the layout of the storage and how to index it? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done. |
||
// Bit-level view over `decision_storage_`; rebound by InitBitVectors for every batch. | ||
BitVector decision_bits_{}; | ||
/** | ||
* @brief Stores whether the feature is missing for each split node. | ||
* | ||
* See above for the storage layout. | ||
*/ | ||
std::vector<BitVector::value_type> missing_storage_{}; | ||
// Bit-level view over `missing_storage_`; rebound by InitBitVectors for every batch. | ||
BitVector missing_bits_{}; | ||
}; | ||
|
||
class CPUPredictor : public Predictor { | ||
protected: | ||
void PredictGHistIndex(DMatrix *p_fmat, gbm::GBTreeModel const &model, int32_t tree_begin, | ||
int32_t tree_end, std::vector<bst_float> *out_preds) const { | ||
auto const n_threads = this->ctx_->Threads(); | ||
|
@@ -323,6 +583,12 @@ class CPUPredictor : public Predictor { | |
|
||
void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds, | ||
gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end) const { | ||
if (p_fmat->Info().data_split_mode == DataSplitMode::kCol) { | ||
ColumnSplitHelper helper(this->ctx_->Threads(), model, tree_begin, tree_end); | ||
helper.PredictDMatrix(p_fmat, out_preds); | ||
return; | ||
} | ||
|
||
if (!p_fmat->PageExists<SparsePage>()) { | ||
this->PredictGHistIndex(p_fmat, model, tree_begin, tree_end, out_preds); | ||
return; | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use an anonymous namespace instead of static if it's not intended to be used outside of the TU.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.