Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial support for column-split cpu predictor #8676

Merged
merged 13 commits into from
Jan 17, 2023
2 changes: 2 additions & 0 deletions src/collective/communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ class Communicator {
result = CommunicatorType::kRabit;
} else if (!CompareStringsCaseInsensitive("federated", str)) {
result = CommunicatorType::kFederated;
} else if (!CompareStringsCaseInsensitive("in-memory", str)) {
result = CommunicatorType::kInMemory;
} else {
LOG(FATAL) << "Unknown communicator type " << str;
}
Expand Down
284 changes: 275 additions & 9 deletions src/predictor/cpu_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
#include <limits>
#include <mutex>

#include "../collective/communicator-inl.h"
#include "../common/categorical.h"
#include "../common/math.h"
#include "../common/threading_utils.h"
#include "../data/adapter.h"
#include "../data/gradient_index.h"
#include "../data/proxy_dmatrix.h"
#include "../gbm/gbtree_model.h"
#include "cpu_treeshap.h" // CalculateContributions
#include "predict_fn.h"
Expand All @@ -23,7 +23,6 @@
#include "xgboost/logging.h"
#include "xgboost/predictor.h"
#include "xgboost/tree_model.h"
#include "xgboost/tree_updater.h"

namespace xgboost {
namespace predictor {
Expand Down Expand Up @@ -284,16 +283,277 @@ void FillNodeMeanValues(RegTree const* tree, std::vector<float>* mean_values) {
FillNodeMeanValues(tree, 0, mean_values);
}

class CPUPredictor : public Predictor {
protected:
// init thread buffers
static void InitThreadTemp(int nthread, std::vector<RegTree::FVec> *out) {
int prev_thread_temp_size = out->size();
if (prev_thread_temp_size < nthread) {
out->resize(nthread, RegTree::FVec());
namespace {
// init thread buffers
static void InitThreadTemp(int nthread, std::vector<RegTree::FVec> *out) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use an anonymous namespace instead of static if it's not intended to be used outside of the TU.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

int prev_thread_temp_size = out->size();
if (prev_thread_temp_size < nthread) {
out->resize(nthread, RegTree::FVec());
}
}
} // anonymous namespace

/**
 * @brief A helper class for prediction when the DMatrix is split by column.
 *
 * When data is split by column, a local DMatrix only contains a subset of features. All the workers
 * in a distributed/federated environment need to cooperate to produce a prediction. This is done in
 * two passes with the help of bit vectors.
 *
 * First pass:
 *   for each tree:
 *     for each row:
 *       for each split node:
 *         if the feature is missing locally, mark the missing bit
 *         otherwise, if the row goes to the right child at this split (the feature value is not
 *         less than the split condition, or the categorical test fails), mark the decision bit
 *
 * Once the two bit vectors are populated, run allreduce on both, using bitwise OR for the decision
 * bits, and bitwise AND for the missing bits.
 *
 * Second pass:
 *   for each tree:
 *     for each row:
 *       find the leaf node using the decision and missing bits, return the leaf value
 *
 * The size of the decision/missing bit vector is:
 *   number of rows in a batch * sum(number of nodes in each tree)
 */
class ColumnSplitHelper {
public:
ColumnSplitHelper(std::int32_t n_threads, gbm::GBTreeModel const &model, uint32_t tree_begin,
uint32_t tree_end)
: n_threads_{n_threads}, model_{model}, tree_begin_{tree_begin}, tree_end_{tree_end} {
auto const n_trees = tree_end_ - tree_begin_;
tree_sizes_.resize(n_trees);
tree_offsets_.resize(n_trees);
for (auto i = 0; i < n_trees; i++) {
auto const &tree = *model_.trees[tree_begin_ + i];
tree_sizes_[i] = tree.GetNodes().size();
}
// std::exclusive_scan (only available in c++17) equivalent to get tree offsets.
tree_offsets_[0] = 0;
for (auto i = 1; i < n_trees; i++) {
tree_offsets_[i] = tree_offsets_[i - 1] + tree_sizes_[i - 1];
}
bits_per_row_ = tree_offsets_.back() + tree_sizes_.back();

InitThreadTemp(n_threads_ * kBlockOfRowsSize, &feat_vecs_);
}

// Neither copyable nor movable. (The object keeps a reference to the model and bit-field views
// that point into its own storage vectors.)
ColumnSplitHelper(ColumnSplitHelper const &) = delete;
ColumnSplitHelper(ColumnSplitHelper &&) = delete;
ColumnSplitHelper &operator=(ColumnSplitHelper const &) = delete;
ColumnSplitHelper &operator=(ColumnSplitHelper &&) = delete;

/**
 * @brief Accumulate predictions for every SparsePage batch of a column-split DMatrix.
 *
 * @param p_fmat    Local DMatrix.
 * @param out_preds Prediction buffer; its size is checked against
 *                  num_row * num_output_group for every batch.
 */
void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds) {
  CHECK(xgboost::collective::IsDistributed())
      << "column-split prediction is only supported for distributed training";

  for (auto const &page : p_fmat->GetBatches<SparsePage>()) {
    // The output buffer must already be sized for the whole matrix.
    CHECK_EQ(out_preds->size(),
             p_fmat->Info().num_row_ * model_.learner_model_param->num_output_group);
    PredictBatchKernel<SparsePageView, kBlockOfRowsSize>(SparsePageView{&page}, out_preds);
  }
}

private:
using BitVector = RBitField8;

void InitBitVectors(std::size_t n_rows) {
n_rows_ = n_rows;
auto const size = BitVector::ComputeStorageSize(bits_per_row_ * n_rows_);
decision_storage_.resize(size);
decision_bits_ = BitVector(common::Span<BitVector::value_type>(decision_storage_));
missing_storage_.resize(size);
missing_bits_ = BitVector(common::Span<BitVector::value_type>(missing_storage_));
}

/**
 * @brief Reset every decision and missing bit to zero, keeping the storage size unchanged.
 */
void ClearBitVectors() {
  for (auto &word : decision_storage_) {
    word = 0;
  }
  for (auto &word : missing_storage_) {
    word = 0;
  }
}

/**
 * @brief Flattened bit position of (tree, row, node).
 *
 * @param tree_id Absolute tree index (not yet relative to `tree_begin_`).
 * @param row_id  Row index within the current batch.
 * @param node_id Node index within the tree.
 */
std::size_t BitIndex(std::size_t tree_id, std::size_t row_id, std::size_t node_id) const {
  std::size_t const local_tree = tree_id - tree_begin_;
  // Bits of all preceding trees, for every row of the batch.
  std::size_t const tree_base = tree_offsets_[local_tree] * n_rows_;
  // Bits of the preceding rows inside this tree.
  std::size_t const row_base = row_id * tree_sizes_[local_tree];
  return tree_base + row_base + node_id;
}

/**
 * @brief Combine the locally computed bit vectors across all workers.
 *
 * Decision bits are merged with bitwise OR, missing bits with bitwise AND, so a missing bit stays
 * set only when it was set by every worker.
 */
void AllreduceBitVectors() {
// The decision allreduce is issued before the missing allreduce; keep this order.
collective::Allreduce<collective::Operation::kBitwiseOR>(decision_storage_.data(),
decision_storage_.size());
collective::Allreduce<collective::Operation::kBitwiseAND>(missing_storage_.data(),
missing_storage_.size());
}

/**
 * @brief First pass for one (tree, row) pair: fill in the decision/missing bits of every split node.
 *
 * For each non-deleted, non-leaf node the missing bit is set when the split feature is missing in
 * `feat`; otherwise the decision bit is set when the row would take the right child.
 *
 * @param feat    Feature vector of the row.
 * @param tree_id Absolute tree index.
 * @param row_id  Row index within the current batch.
 */
void MaskOneTree(RegTree::FVec const &feat, std::size_t tree_id, std::size_t row_id) {
  auto const &tree = *model_.trees[tree_id];
  auto const &cats = tree.GetCategoriesMatrix();
  auto const has_categorical = tree.HasCategoricalSplit();
  // Hoisted and converted once so the loop compares like-signed values.
  auto const n_nodes = static_cast<bst_node_t>(tree.GetNodes().size());

  // NOTE(review): this runs concurrently for different rows (see the parallel loop in
  // PredictBatchKernel); confirm that BitVector::Set is safe when bits of neighbouring rows share
  // a storage word.
  for (bst_node_t nid = 0; nid < n_nodes; ++nid) {
    auto const &node = tree[nid];
    if (node.IsDeleted() || node.IsLeaf()) {
      continue;
    }

    auto const bit_index = BitIndex(tree_id, row_id, nid);
    unsigned split_index = node.SplitIndex();
    if (feat.IsMissing(split_index)) {
      missing_bits_.Set(bit_index);
      continue;
    }

    auto const fvalue = feat.GetFvalue(split_index);
    if (has_categorical && common::IsCat(cats.split_type, nid)) {
      auto const node_categories =
          cats.categories.subspan(cats.node_ptr[nid].beg, cats.node_ptr[nid].size);
      // A failed categorical decision sends the row to the right child.
      if (!common::Decision(node_categories, fvalue)) {
        decision_bits_.Set(bit_index);
      }
      continue;
    }

    // Numerical split: values not below the split condition go right.
    if (fvalue >= node.SplitCond()) {
      decision_bits_.Set(bit_index);
    }
  }
}

/**
 * @brief Run the first pass for one block of rows over every selected tree.
 *
 * @param batch_offset Index of the block's first row within the batch.
 * @param fvec_offset  Index of the block's first feature vector in `feat_vecs_`.
 * @param block_size   Number of rows in the block.
 */
void MaskAllTrees(std::size_t batch_offset, std::size_t fvec_offset, std::size_t block_size) {
  auto tree_id = tree_begin_;
  while (tree_id < tree_end_) {
    for (std::size_t row = 0; row != block_size; ++row) {
      auto const &row_features = feat_vecs_[fvec_offset + row];
      MaskOneTree(row_features, tree_id, batch_offset + row);
    }
    ++tree_id;
  }
}

/**
 * @brief Choose the child to visit from `node` using the reduced bits at `bit_index`.
 *
 * @return The default child when the missing bit is set; otherwise the left child, or the node
 *         right after it when the decision bit is set.
 */
bst_node_t GetNextNode(RegTree::Node const &node, std::size_t bit_index) {
  if (missing_bits_.Check(bit_index)) {
    return node.DefaultChild();
  }
  return node.LeftChild() + decision_bits_.Check(bit_index);
}

/**
 * @brief Walk `tree` from the root for row `row_id` and return the id of the leaf that is reached.
 */
bst_node_t GetLeafIndex(RegTree const &tree, std::size_t tree_id, std::size_t row_id) {
  bst_node_t current = 0;
  for (;;) {
    if (tree[current].IsLeaf()) {
      return current;
    }
    current = GetNextNode(tree[current], BitIndex(tree_id, row_id, current));
  }
}

/**
 * @brief Leaf value of tree `tree_id` for row `row_id`, resolved through the bit vectors.
 */
bst_float PredictOneTree(std::size_t tree_id, std::size_t row_id) {
  auto const &tree = *model_.trees[tree_id];
  return tree[GetLeafIndex(tree, tree_id, row_id)].LeafValue();
}

/**
 * @brief Second pass for one block of rows: add every selected tree's leaf value to the output.
 *
 * @param out_preds      Prediction buffer, indexed as row * num_group + group.
 * @param batch_offset   Index of the block's first row within the batch.
 * @param predict_offset Index of the block's first row within `out_preds`.
 * @param num_group      Number of output groups.
 * @param block_size     Number of rows in the block.
 */
void PredictAllTrees(std::vector<bst_float> *out_preds, std::size_t batch_offset,
                     std::size_t predict_offset, std::size_t num_group, std::size_t block_size) {
  for (size_t tree_id = tree_begin_; tree_id < tree_end_; ++tree_id) {
    // Output group this tree contributes to.
    auto const gid = model_.tree_info[tree_id];
    for (size_t row = 0; row < block_size; ++row) {
      auto const out_index = (predict_offset + row) * num_group + gid;
      (*out_preds)[out_index] += PredictOneTree(tree_id, batch_offset + row);
    }
  }
}

/**
 * @brief Predict one batch: mask bits locally, allreduce them, then look up the leaf values.
 *
 * @tparam DataView           View type over the batch.
 * @tparam block_of_rows_size Number of rows handled by one parallel work item.
 * @param batch     View over the rows of the batch.
 * @param out_preds Prediction buffer that the leaf values are added to.
 */
template <typename DataView, size_t block_of_rows_size>
void PredictBatchKernel(DataView batch, std::vector<bst_float> *out_preds) {
auto const num_group = model_.learner_model_param->num_output_group;

CHECK_EQ(model_.param.size_leaf_vector, 0) << "size_leaf_vector is enforced to 0 so far";
// parallel over local batch
auto const nsize = batch.Size();
auto const num_feature = model_.learner_model_param->num_feature;
auto const n_blocks = common::DivRoundUp(nsize, block_of_rows_size);
// Size the bit vectors for this batch before any thread writes to them.
InitBitVectors(nsize);

// First pass: every block marks its decision/missing bits from the local features.
// auto block_id has the same type as `n_blocks`.
common::ParallelFor(n_blocks, n_threads_, [&](auto block_id) {
auto const batch_offset = block_id * block_of_rows_size;
// The last block may hold fewer than block_of_rows_size rows.
auto const block_size = std::min(nsize - batch_offset, block_of_rows_size);
// Each thread works on its own block_of_rows_size-wide slice of feat_vecs_.
auto const fvec_offset = omp_get_thread_num() * block_of_rows_size;

// presumably FVecFill loads the block's rows into feat_vecs_ and FVecDrop resets them afterwards
// -- TODO confirm against the helper definitions.
FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset, &feat_vecs_);
MaskAllTrees(batch_offset, fvec_offset, block_size);
FVecDrop(block_size, batch_offset, &batch, fvec_offset, &feat_vecs_);
});

// Merge the bits across workers; must complete before the second pass reads them.
AllreduceBitVectors();

// Second pass: resolve each row's leaf from the reduced bits and accumulate into out_preds.
// auto block_id has the same type as `n_blocks`.
common::ParallelFor(n_blocks, n_threads_, [&](auto block_id) {
auto const batch_offset = block_id * block_of_rows_size;
auto const block_size = std::min(nsize - batch_offset, block_of_rows_size);
// Output rows are offset by the batch's base row id.
PredictAllTrees(out_preds, batch_offset, batch_offset + batch.base_rowid, num_group,
block_size);
});

// Zero the bits so they do not carry over into the next batch.
ClearBitVectors();
}

// Number of rows handled by one parallel work item; also the number of feature vectors reserved
// per thread.
static std::size_t constexpr kBlockOfRowsSize = 64;

std::int32_t const n_threads_;
gbm::GBTreeModel const &model_;
// Range of trees used for prediction: [tree_begin_, tree_end_).
uint32_t const tree_begin_;
uint32_t const tree_end_;

// Node count of each selected tree, indexed relative to `tree_begin_`.
std::vector<std::size_t> tree_sizes_{};
// Sum of the node counts of all preceding selected trees, indexed relative to `tree_begin_`.
std::vector<std::size_t> tree_offsets_{};
// Total node count over all selected trees.
std::size_t bits_per_row_{};
// Per-thread feature vectors used during the first pass.
std::vector<RegTree::FVec> feat_vecs_{};

// Number of rows in the batch currently being processed; set by InitBitVectors().
// Value-initialized like the other members so it never holds an indeterminate value.
std::size_t n_rows_{};
/**
 * @brief Stores decision bit for each split node.
 *
 * Conceptually it's a 3-dimensional bit matrix:
 *   - 1st dimension is the tree index, from `tree_begin_` (inclusive) to `tree_end_` (exclusive).
 *   - 2nd dimension is the row index, for each row in the batch.
 *   - 3rd dimension is the node id, for each node in the tree.
 *
 * Since we have to ship the whole thing over the wire to do an allreduce, the matrix is flattened
 * into a 1-dimensional array.
 *
 * First, it's divided by the tree index:
 *
 * [ tree 0 ] [ tree 1 ] ...
 *
 * Then each tree is divided by row:
 *
 * [             tree 0              ] [ tree 1 ] ...
 * [ row 0 ] [ row 1 ] ... [ row n-1 ] [ row 0 ] ...
 *
 * Finally, each row is divided by the node id:
 *
 * [                             tree 0                                  ]
 * [              row 0                 ] [        row 1           ] ...
 * [ node 0 ] [ node 1 ] ... [ node n-1 ] [ node 0 ] ...
 *
 * The first two dimensions are fixed length, while the last dimension is variable length since
 * each tree may have a different number of nodes. We precompute the tree offsets, which are the
 * cumulative sums of tree sizes. The index of tree t (counted from `tree_begin_`), row r, node n
 * is:
 *   index(t, r, n) = tree_offsets[t] * n_rows + r * tree_sizes[t] + n
 */
std::vector<BitVector::value_type> decision_storage_{};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add some comments about the layout of the storage? How to index it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

// Bit-level view over `decision_storage_`; bound in InitBitVectors().
BitVector decision_bits_{};
/**
 * @brief Stores whether the feature is missing for each split node.
 *
 * Uses the same flattened (tree, row, node) layout as `decision_storage_`.
 */
std::vector<BitVector::value_type> missing_storage_{};
// Bit-level view over `missing_storage_`; bound in InitBitVectors().
BitVector missing_bits_{};
};

class CPUPredictor : public Predictor {
protected:
void PredictGHistIndex(DMatrix *p_fmat, gbm::GBTreeModel const &model, int32_t tree_begin,
int32_t tree_end, std::vector<bst_float> *out_preds) const {
auto const n_threads = this->ctx_->Threads();
Expand Down Expand Up @@ -323,6 +583,12 @@ class CPUPredictor : public Predictor {

void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end) const {
if (p_fmat->Info().data_split_mode == DataSplitMode::kCol) {
ColumnSplitHelper helper(this->ctx_->Threads(), model, tree_begin, tree_end);
helper.PredictDMatrix(p_fmat, out_preds);
return;
}

if (!p_fmat->PageExists<SparsePage>()) {
this->PredictGHistIndex(p_fmat, model, tree_begin, tree_end, out_preds);
return;
Expand Down
13 changes: 11 additions & 2 deletions tests/cpp/collective/test_communicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,17 @@ namespace collective {
// Checks how Communicator::GetTypeFromEnv() maps the XGBOOST_COMMUNICATOR environment variable to
// a CommunicatorType. Mixed-case values ("Federated", "In-Memory") are expected to be accepted.
TEST(CommunicatorFactory, TypeFromEnv) {
// Queried before the test sets the variable.
EXPECT_EQ(CommunicatorType::kUnknown, Communicator::GetTypeFromEnv());

// An unrecognised value is expected to raise dmlc::Error.
dmlc::SetEnv<std::string>("XGBOOST_COMMUNICATOR", "foo");
EXPECT_THROW(Communicator::GetTypeFromEnv(), dmlc::Error);

dmlc::SetEnv<std::string>("XGBOOST_COMMUNICATOR", "rabit");
EXPECT_EQ(CommunicatorType::kRabit, Communicator::GetTypeFromEnv());

dmlc::SetEnv<std::string>("XGBOOST_COMMUNICATOR", "Federated");
EXPECT_EQ(CommunicatorType::kFederated, Communicator::GetTypeFromEnv());

// NOTE(review): this repeats the unknown-value check above; it looks like one of the two copies
// is a leftover -- confirm which one should remain.
dmlc::SetEnv<std::string>("XGBOOST_COMMUNICATOR", "foo");
EXPECT_THROW(Communicator::GetTypeFromEnv(), dmlc::Error);
dmlc::SetEnv<std::string>("XGBOOST_COMMUNICATOR", "In-Memory");
EXPECT_EQ(CommunicatorType::kInMemory, Communicator::GetTypeFromEnv());
}

TEST(CommunicatorFactory, TypeFromArgs) {
Expand All @@ -32,6 +35,9 @@ TEST(CommunicatorFactory, TypeFromArgs) {
config["xgboost_communicator"] = String("federated");
EXPECT_EQ(CommunicatorType::kFederated, Communicator::GetTypeFromConfig(config));

config["xgboost_communicator"] = String("in-memory");
EXPECT_EQ(CommunicatorType::kInMemory, Communicator::GetTypeFromConfig(config));

config["xgboost_communicator"] = String("foo");
EXPECT_THROW(Communicator::GetTypeFromConfig(config), dmlc::Error);
}
Expand All @@ -46,6 +52,9 @@ TEST(CommunicatorFactory, TypeFromArgsUpperCase) {
config["XGBOOST_COMMUNICATOR"] = String("federated");
EXPECT_EQ(CommunicatorType::kFederated, Communicator::GetTypeFromConfig(config));

config["XGBOOST_COMMUNICATOR"] = String("in-memory");
EXPECT_EQ(CommunicatorType::kInMemory, Communicator::GetTypeFromConfig(config));

config["XGBOOST_COMMUNICATOR"] = String("foo");
EXPECT_THROW(Communicator::GetTypeFromConfig(config), dmlc::Error);
}
Expand Down
Loading