diff --git a/amalgamation/xgboost-all0.cc b/amalgamation/xgboost-all0.cc
index 1241ced409cd..83d003052149 100644
--- a/amalgamation/xgboost-all0.cc
+++ b/amalgamation/xgboost-all0.cc
@@ -57,6 +57,7 @@
 #include "../src/tree/updater_refresh.cc"
 #include "../src/tree/updater_sync.cc"
 #include "../src/tree/updater_histmaker.cc"
+#include "../src/tree/updater_approx.cc"
 #include "../src/tree/constraints.cc"
 
 // linear
diff --git a/demo/guide-python/external_memory.py b/demo/guide-python/external_memory.py
index 7bca5db03e01..5cf72ba82145 100644
--- a/demo/guide-python/external_memory.py
+++ b/demo/guide-python/external_memory.py
@@ -8,23 +8,24 @@
 import os
 import xgboost
 from typing import Callable, List, Tuple
+from sklearn.datasets import make_regression
 import tempfile
 import numpy as np
 
 
 def make_batches(
-    n_samples_per_batch: int, n_features: int, n_batches: int
-) -> Tuple[List[np.ndarray], List[np.ndarray]]:
-    """Generate random batches."""
-    X = []
-    y = []
+    n_samples_per_batch: int, n_features: int, n_batches: int, tmpdir: str,
+) -> List[Tuple[str, str]]:
+    files: List[Tuple[str, str]] = []
     rng = np.random.RandomState(1994)
     for i in range(n_batches):
-        _X = rng.randn(n_samples_per_batch, n_features)
-        _y = rng.randn(n_samples_per_batch)
-        X.append(_X)
-        y.append(_y)
-    return X, y
+        X, y = make_regression(n_samples_per_batch, n_features, random_state=rng)
+        X_path = os.path.join(tmpdir, "X-" + str(i) + ".npy")
+        y_path = os.path.join(tmpdir, "y-" + str(i) + ".npy")
+        np.save(X_path, X)
+        np.save(y_path, y)
+        files.append((X_path, y_path))
+    return files
 
 
 class Iterator(xgboost.DataIter):
@@ -38,8 +39,8 @@ def __init__(self, file_paths: List[Tuple[str, str]]):
 
     def load_file(self) -> Tuple[np.ndarray, np.ndarray]:
         X_path, y_path = self._file_paths[self._it]
-        X = np.loadtxt(X_path)
-        y = np.loadtxt(y_path)
+        X = np.load(X_path)
+        y = np.load(y_path)
         assert X.shape[0] == y.shape[0]
         return X, y
@@ -66,15 +67,7 @@ def reset(self) -> None:
 
 def main(tmpdir: str) -> xgboost.Booster:
     # generate some random data for demo
-    batches = make_batches(1024, 17, 31)
-    files = []
-    for i, (X, y) in enumerate(zip(*batches)):
-        X_path = os.path.join(tmpdir, "X-" + str(i) + ".txt")
-        np.savetxt(X_path, X)
-        y_path = os.path.join(tmpdir, "y-" + str(i) + ".txt")
-        np.savetxt(y_path, y)
-        files.append((X_path, y_path))
-
+    files = make_batches(1024, 17, 31, tmpdir)
     it = Iterator(files)
    # For non-data arguments, specify them here once instead of passing them through
    # the `next` method.
@@ -83,7 +76,7 @@ def main(tmpdir: str) -> xgboost.Booster:
    # Other tree methods including ``hist`` and ``gpu_hist`` also work, but have some
    # caveats.  This is still an experimental feature.
-    booster = xgboost.train({"tree_method": "approx"}, Xy)
+    booster = xgboost.train({"tree_method": "approx"}, Xy, evals=[(Xy, "Train")])
     return booster
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala
index 7a8bf6fa4d90..7017615c2b0d 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala
@@ -219,8 +219,12 @@ abstract class XGBoostRegressorSuiteBase extends FunSuite with PerTest {
   }
 }
 
-class XGBoostCpuRegressorSuite extends XGBoostRegressorSuiteBase {
+class XGBoostCpuRegressorSuiteApprox extends XGBoostRegressorSuiteBase {
+  override protected val treeMethod: String = "approx"
+}
+
+class XGBoostCpuRegressorSuiteHist extends XGBoostRegressorSuiteBase {
+  override protected val treeMethod: String = "hist"
 }
 
 @GpuTestSuite
diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc
index 40392a2bcfaa..6431487bb4b6 100644
--- a/src/common/hist_util.cc
+++ b/src/common/hist_util.cc
@@ -133,74 +133,84 @@ struct Prefetch {
 
 constexpr size_t Prefetch::kNoPrefetchSize;
 
-
-template
-void BuildHistKernel(const std::vector& gpair,
+template
+void BuildHistKernel(const std::vector &gpair,
                      const RowSetCollection::Elem row_indices,
-                     const GHistIndexMatrix& gmat,
-                     GHistRow hist) {
+                     const GHistIndexMatrix &gmat, GHistRow hist) {
   const size_t size = row_indices.Size();
-  const size_t* rid = row_indices.begin;
-  const float* pgh = reinterpret_cast(gpair.data());
-  const BinIdxType* gradient_index = gmat.index.data();
-  const size_t* row_ptr = gmat.row_ptr.data();
-  const uint32_t* offsets = gmat.index.Offset();
-  const size_t n_features = row_ptr[row_indices.begin[0]+1] - row_ptr[row_indices.begin[0]];
-  FPType* hist_data = reinterpret_cast(hist.data());
-  const uint32_t two {2};  // Each element from 'gpair' and 'hist' contains
-                           // 2 FP values: gradient and hessian.
-                           // So we need to multiply each row-index/bin-index by 2
-                           // to work with gradient pairs as a singe row FP array
+  const size_t *rid = row_indices.begin;
+  auto const *pgh = reinterpret_cast(gpair.data());
+  const BinIdxType *gradient_index = gmat.index.data();
+
+  auto const &row_ptr = gmat.row_ptr;
+  auto base_rowid = gmat.base_rowid;
+  const uint32_t *offsets = gmat.index.Offset();
+  auto get_row_ptr = [&](size_t ridx) { return row_ptr[ridx - base_rowid]; };
+  auto get_rid = [&](size_t ridx) { return ridx - base_rowid; };
+
+  const size_t n_features =
+      get_row_ptr(row_indices.begin[0] + 1) - get_row_ptr(row_indices.begin[0]);
+  auto hist_data = reinterpret_cast(hist.data());
+  const uint32_t two{2};  // Each element from 'gpair' and 'hist' contains
+                          // 2 FP values: gradient and hessian.
+                          // So we need to multiply each row-index/bin-index by 2
+                          // to work with gradient pairs as a single row FP array
   for (size_t i = 0; i < size; ++i) {
-    const size_t icol_start = any_missing ? row_ptr[rid[i]] : rid[i] * n_features;
-    const size_t icol_end = any_missing ? row_ptr[rid[i]+1] : icol_start + n_features;
+    const size_t icol_start =
+        any_missing ? get_row_ptr(rid[i]) : get_rid(rid[i]) * n_features;
+    const size_t icol_end =
+        any_missing ?
get_row_ptr(rid[i] + 1) : icol_start + n_features; + const size_t row_size = icol_end - icol_start; const size_t idx_gh = two * rid[i]; if (do_prefetch) { - const size_t icol_start_prefetch = any_missing ? row_ptr[rid[i+Prefetch::kPrefetchOffset]] : - rid[i + Prefetch::kPrefetchOffset] * n_features; - const size_t icol_end_prefetch = any_missing ? row_ptr[rid[i+Prefetch::kPrefetchOffset]+1] : - icol_start_prefetch + n_features; + const size_t icol_start_prefetch = + any_missing + ? get_row_ptr(rid[i + Prefetch::kPrefetchOffset]) + : get_rid(rid[i + Prefetch::kPrefetchOffset]) * n_features; + const size_t icol_end_prefetch = + any_missing ? get_row_ptr(rid[i + Prefetch::kPrefetchOffset] + 1) + : icol_start_prefetch + n_features; PREFETCH_READ_T0(pgh + two * rid[i + Prefetch::kPrefetchOffset]); for (size_t j = icol_start_prefetch; j < icol_end_prefetch; - j+=Prefetch::GetPrefetchStep()) { + j += Prefetch::GetPrefetchStep()) { PREFETCH_READ_T0(gradient_index + j); } } - const BinIdxType* gr_index_local = gradient_index + icol_start; + const BinIdxType *gr_index_local = gradient_index + icol_start; for (size_t j = 0; j < row_size; ++j) { - const uint32_t idx_bin = two * (static_cast(gr_index_local[j]) + ( - any_missing ? 0 : offsets[j])); - - hist_data[idx_bin] += pgh[idx_gh]; - hist_data[idx_bin+1] += pgh[idx_gh+1]; + const uint32_t idx_bin = two * (static_cast(gr_index_local[j]) + + (any_missing ? 0 : offsets[j])); + hist_data[idx_bin] += pgh[idx_gh]; + hist_data[idx_bin + 1] += pgh[idx_gh + 1]; } } } -template -void BuildHistDispatch(const std::vector& gpair, +template +void BuildHistDispatch(const std::vector &gpair, const RowSetCollection::Elem row_indices, - const GHistIndexMatrix& gmat, GHistRow hist) { + const GHistIndexMatrix &gmat, GHistRow hist) { switch (gmat.index.GetBinTypeSize()) { - case kUint8BinsTypeSize: - BuildHistKernel(gpair, row_indices, - gmat, hist); - break; - case kUint16BinsTypeSize: - BuildHistKernel(gpair, row_indices, - gmat, hist); - break; - case kUint32BinsTypeSize: - BuildHistKernel(gpair, row_indices, - gmat, hist); - break; - default: - CHECK(false); // no default behavior + case kUint8BinsTypeSize: + BuildHistKernel( + gpair, row_indices, gmat, hist); + break; + case kUint16BinsTypeSize: + BuildHistKernel( + gpair, row_indices, gmat, hist); + break; + case kUint32BinsTypeSize: + BuildHistKernel( + gpair, row_indices, gmat, hist); + break; + default: + CHECK(false); // no default behavior } } @@ -208,9 +218,8 @@ template template void GHistBuilder::BuildHist( const std::vector &gpair, - const RowSetCollection::Elem row_indices, - const GHistIndexMatrix &gmat, - GHistRowT hist) { + const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, + GHistRowT hist) const { const size_t nrows = row_indices.Size(); const size_t no_prefetch_size = Prefetch::NoPrefetchSize(nrows); @@ -233,22 +242,22 @@ template void GHistBuilder::BuildHist(const std::vector &gpair, const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, - GHistRow hist); + GHistRow hist) const; template void GHistBuilder::BuildHist(const std::vector &gpair, const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, - GHistRow hist); + GHistRow hist) const; template void GHistBuilder::BuildHist(const std::vector &gpair, const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, - GHistRow hist); + GHistRow hist) const; template void GHistBuilder::BuildHist(const std::vector &gpair, const RowSetCollection::Elem row_indices, const GHistIndexMatrix 
&gmat, - GHistRow hist); + GHistRow hist) const; template void GHistBuilder::SubtractionTrick(GHistRowT self, @@ -262,8 +271,9 @@ void GHistBuilder::SubtractionTrick(GHistRowT self, size_t n_blocks = size/block_size + !!(size%block_size); ParallelFor(omp_ulong(n_blocks), [&](omp_ulong iblock) { - const size_t ibegin = iblock*block_size; - const size_t iend = (((iblock+1)*block_size > size) ? size : ibegin + block_size); + const size_t ibegin = iblock * block_size; + const size_t iend = + (((iblock + 1) * block_size > size) ? size : ibegin + block_size); SubtractionHist(self, parent, sibling, ibegin, iend); }); } diff --git a/src/common/hist_util.h b/src/common/hist_util.h index 9dc0bd1c5fd1..1fa3a2240258 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -441,7 +441,7 @@ class ParallelGHistBuilder { } // Reduce following bins (begin, end] for nid-node in dst across threads - void ReduceHist(size_t nid, size_t begin, size_t end) { + void ReduceHist(size_t nid, size_t begin, size_t end) const { CHECK_GT(end, begin); CHECK_LT(nid, nodes_); @@ -467,7 +467,6 @@ class ParallelGHistBuilder { } } - protected: void MatchThreadsToNodes(const BlockedSpace2d& space) { const size_t space_size = space.Size(); const size_t chunck_size = space_size / nthreads_ + !!(space_size % nthreads_); @@ -514,6 +513,7 @@ class ParallelGHistBuilder { } } + private: void MatchNodeNidPairToHist() { size_t hist_allocated_additionally = 0; @@ -574,7 +574,7 @@ class GHistBuilder { void BuildHist(const std::vector& gpair, const RowSetCollection::Elem row_indices, const GHistIndexMatrix& gmat, - GHistRowT hist); + GHistRowT hist) const; // construct a histogram via subtraction trick void SubtractionTrick(GHistRowT self, GHistRowT sibling, diff --git a/src/common/partition_builder.h b/src/common/partition_builder.h index 98612359ec73..5ffe34988968 100644 --- a/src/common/partition_builder.h +++ b/src/common/partition_builder.h @@ -1,228 +1,267 @@ - -/*! 
- * Copyright 2021 by Contributors - * \file row_set.h - * \brief Quick Utility to compute subset of rows - * \author Philip Cho, Tianqi Chen - */ -#ifndef XGBOOST_COMMON_PARTITION_BUILDER_H_ -#define XGBOOST_COMMON_PARTITION_BUILDER_H_ - -#include -#include -#include -#include -#include -#include "xgboost/tree_model.h" -#include "../common/column_matrix.h" - -namespace xgboost { -namespace common { - -// The builder is required for samples partition to left and rights children for set of nodes -// Responsible for: -// 1) Effective memory allocation for intermediate results for multi-thread work -// 2) Merging partial results produced by threads into original row set (row_set_collection_) -// BlockSize is template to enable memory alignment easily with C++11 'alignas()' feature -template -class PartitionBuilder { - public: - template - void Init(const size_t n_tasks, size_t n_nodes, Func funcNTask) { - left_right_nodes_sizes_.resize(n_nodes); - blocks_offsets_.resize(n_nodes+1); - - blocks_offsets_[0] = 0; - for (size_t i = 1; i < n_nodes+1; ++i) { - blocks_offsets_[i] = blocks_offsets_[i-1] + funcNTask(i-1); - } - - if (n_tasks > max_n_tasks_) { - mem_blocks_.resize(n_tasks); - max_n_tasks_ = n_tasks; - } - } - - // split row indexes (rid_span) to 2 parts (left_part, right_part) depending - // on comparison of indexes values (idx_span) and split point (split_cond) - // Handle dense columns - // Analog of std::stable_partition, but in no-inplace manner - template - inline std::pair PartitionKernel(const ColumnType& column, - common::Span rid_span, const int32_t split_cond, - common::Span left_part, common::Span right_part) { - size_t* p_left_part = left_part.data(); - size_t* p_right_part = right_part.data(); - size_t nleft_elems = 0; - size_t nright_elems = 0; - auto state = column.GetInitialState(rid_span.front()); - - for (auto rid : rid_span) { - const int32_t bin_id = column.GetBinIdx(rid, &state); - if (any_missing && bin_id == ColumnType::kMissingId) { - if (default_left) { - p_left_part[nleft_elems++] = rid; - } else { - p_right_part[nright_elems++] = rid; - } - } else { - if (bin_id <= split_cond) { - p_left_part[nleft_elems++] = rid; - } else { - p_right_part[nright_elems++] = rid; - } - } - } - - return {nleft_elems, nright_elems}; - } - - - template - void Partition(const size_t node_in_set, const size_t nid, const common::Range1d range, - const int32_t split_cond, - const ColumnMatrix& column_matrix, const RegTree& tree, const size_t* rid) { - common::Span rid_span(rid + range.begin(), rid + range.end()); - common::Span left = GetLeftBuffer(node_in_set, - range.begin(), range.end()); - common::Span right = GetRightBuffer(node_in_set, - range.begin(), range.end()); - const bst_uint fid = tree[nid].SplitIndex(); - const bool default_left = tree[nid].DefaultLeft(); - const auto column_ptr = column_matrix.GetColumn(fid); - - std::pair child_nodes_sizes; - - if (column_ptr->GetType() == xgboost::common::kDenseColumn) { - const common::DenseColumn& column = - static_cast& >(*(column_ptr.get())); - if (default_left) { - child_nodes_sizes = PartitionKernel(column, rid_span, - split_cond, left, right); - } else { - child_nodes_sizes = PartitionKernel(column, rid_span, - split_cond, left, right); - } - } else { - CHECK_EQ(any_missing, true); - const common::SparseColumn& column - = static_cast& >(*(column_ptr.get())); - if (default_left) { - child_nodes_sizes = PartitionKernel(column, rid_span, - split_cond, left, right); - } else { - child_nodes_sizes = PartitionKernel(column, 
rid_span, - split_cond, left, right); - } - } - - const size_t n_left = child_nodes_sizes.first; - const size_t n_right = child_nodes_sizes.second; - - SetNLeftElems(node_in_set, range.begin(), range.end(), n_left); - SetNRightElems(node_in_set, range.begin(), range.end(), n_right); - } - - - // allocate thread local memory, should be called for each specific task - void AllocateForTask(size_t id) { - if (mem_blocks_[id].get() == nullptr) { - BlockInfo* local_block_ptr = new BlockInfo; - CHECK_NE(local_block_ptr, (BlockInfo*)nullptr); - mem_blocks_[id].reset(local_block_ptr); - } - } - - common::Span GetLeftBuffer(int nid, size_t begin, size_t end) { - const size_t task_idx = GetTaskIdx(nid, begin); - return { mem_blocks_.at(task_idx)->Left(), end - begin }; - } - - common::Span GetRightBuffer(int nid, size_t begin, size_t end) { - const size_t task_idx = GetTaskIdx(nid, begin); - return { mem_blocks_.at(task_idx)->Right(), end - begin }; - } - - void SetNLeftElems(int nid, size_t begin, size_t end, size_t n_left) { - size_t task_idx = GetTaskIdx(nid, begin); - mem_blocks_.at(task_idx)->n_left = n_left; - } - - void SetNRightElems(int nid, size_t begin, size_t end, size_t n_right) { - size_t task_idx = GetTaskIdx(nid, begin); - mem_blocks_.at(task_idx)->n_right = n_right; - } - - - size_t GetNLeftElems(int nid) const { - return left_right_nodes_sizes_[nid].first; - } - - size_t GetNRightElems(int nid) const { - return left_right_nodes_sizes_[nid].second; - } - - // Each thread has partial results for some set of tree-nodes - // The function decides order of merging partial results into final row set - void CalculateRowOffsets() { - for (size_t i = 0; i < blocks_offsets_.size()-1; ++i) { - size_t n_left = 0; - for (size_t j = blocks_offsets_[i]; j < blocks_offsets_[i+1]; ++j) { - mem_blocks_[j]->n_offset_left = n_left; - n_left += mem_blocks_[j]->n_left; - } - size_t n_right = 0; - for (size_t j = blocks_offsets_[i]; j < blocks_offsets_[i+1]; ++j) { - mem_blocks_[j]->n_offset_right = n_left + n_right; - n_right += mem_blocks_[j]->n_right; - } - left_right_nodes_sizes_[i] = {n_left, n_right}; - } - } - - void MergeToArray(int nid, size_t begin, size_t* rows_indexes) { - size_t task_idx = GetTaskIdx(nid, begin); - - size_t* left_result = rows_indexes + mem_blocks_[task_idx]->n_offset_left; - size_t* right_result = rows_indexes + mem_blocks_[task_idx]->n_offset_right; - - const size_t* left = mem_blocks_[task_idx]->Left(); - const size_t* right = mem_blocks_[task_idx]->Right(); - - std::copy_n(left, mem_blocks_[task_idx]->n_left, left_result); - std::copy_n(right, mem_blocks_[task_idx]->n_right, right_result); - } - - size_t GetTaskIdx(int nid, size_t begin) { - return blocks_offsets_[nid] + begin / BlockSize; - } - - protected: - struct BlockInfo{ - size_t n_left; - size_t n_right; - - size_t n_offset_left; - size_t n_offset_right; - - size_t* Left() { - return &left_data_[0]; - } - - size_t* Right() { - return &right_data_[0]; - } - private: - size_t left_data_[BlockSize]; - size_t right_data_[BlockSize]; - }; - std::vector> left_right_nodes_sizes_; - std::vector blocks_offsets_; - std::vector> mem_blocks_; - size_t max_n_tasks_ = 0; -}; - -} // namespace common -} // namespace xgboost - -#endif // XGBOOST_COMMON_PARTITION_BUILDER_H_ + +/*! 
+ * Copyright 2021 by Contributors
+ * \file partition_builder.h
+ * \brief Quick utility to compute a subset of rows
+ * \author Philip Cho, Tianqi Chen
+ */
+#ifndef XGBOOST_COMMON_PARTITION_BUILDER_H_
+#define XGBOOST_COMMON_PARTITION_BUILDER_H_
+
+#include
+#include
+#include
+#include
+#include
+#include "xgboost/tree_model.h"
+#include "../common/column_matrix.h"
+
+namespace xgboost {
+namespace common {
+
+// The builder is required for partitioning samples into left and right children for a set of nodes
+// Responsible for:
+// 1) Efficient memory allocation for intermediate results of multi-thread work
+// 2) Merging partial results produced by threads into the original row set (row_set_collection_)
+// BlockSize is a template parameter to enable easy memory alignment with the C++11
+// 'alignas()' feature
+template
+class PartitionBuilder {
+ public:
+  template
+  void Init(const size_t n_tasks, size_t n_nodes, Func funcNTask) {
+    left_right_nodes_sizes_.resize(n_nodes);
+    blocks_offsets_.resize(n_nodes+1);
+
+    blocks_offsets_[0] = 0;
+    for (size_t i = 1; i < n_nodes+1; ++i) {
+      blocks_offsets_[i] = blocks_offsets_[i-1] + funcNTask(i-1);
+    }
+
+    if (n_tasks > max_n_tasks_) {
+      mem_blocks_.resize(n_tasks);
+      max_n_tasks_ = n_tasks;
+    }
+  }
+
+  // Split row indices (rid_span) into 2 parts (left_part, right_part) depending
+  // on the comparison between index values (idx_span) and the split point (split_cond).
+  // Handles dense columns.
+  // An analogue of std::stable_partition, but not in-place.
+  template
+  inline std::pair PartitionKernel(const ColumnType& column,
+        common::Span rid_span, const int32_t split_cond,
+        common::Span left_part, common::Span right_part) {
+    size_t* p_left_part = left_part.data();
+    size_t* p_right_part = right_part.data();
+    size_t nleft_elems = 0;
+    size_t nright_elems = 0;
+    auto state = column.GetInitialState(rid_span.front());
+
+    for (auto rid : rid_span) {
+      const int32_t bin_id = column.GetBinIdx(rid, &state);
+      if (any_missing && bin_id == ColumnType::kMissingId) {
+        if (default_left) {
+          p_left_part[nleft_elems++] = rid;
+        } else {
+          p_right_part[nright_elems++] = rid;
+        }
+      } else {
+        if (bin_id <= split_cond) {
+          p_left_part[nleft_elems++] = rid;
+        } else {
+          p_right_part[nright_elems++] = rid;
+        }
+      }
+    }
+
+    return {nleft_elems, nright_elems};
+  }
+
+  template
+  inline std::pair
+  PartitionRangeKernel(common::Span ridx,
+                       common::Span left_part,
+                       common::Span right_part, Pred pred) {
+    size_t *p_left_part = left_part.data();
+    size_t *p_right_part = right_part.data();
+    size_t nleft_elems = 0;
+    size_t nright_elems = 0;
+    for (auto row_id : ridx) {
+      if (pred(row_id)) {
+        p_left_part[nleft_elems++] = row_id;
+      } else {
+        p_right_part[nright_elems++] = row_id;
+      }
+    }
+    return {nleft_elems, nright_elems};
+  }
+
+  template
+  void Partition(const size_t node_in_set, const size_t nid, const common::Range1d range,
+                 const int32_t split_cond,
+                 const ColumnMatrix& column_matrix, const RegTree& tree, const size_t* rid) {
+    common::Span rid_span(rid + range.begin(), rid + range.end());
+    common::Span left = GetLeftBuffer(node_in_set,
+                                      range.begin(), range.end());
+    common::Span right = GetRightBuffer(node_in_set,
+                                        range.begin(), range.end());
+    const bst_uint fid = tree[nid].SplitIndex();
+    const bool default_left = tree[nid].DefaultLeft();
+    const auto column_ptr = column_matrix.GetColumn(fid);
+
+    std::pair child_nodes_sizes;
+
+    if (column_ptr->GetType() == xgboost::common::kDenseColumn) {
+      const common::DenseColumn& column =
+            static_cast& >(*(column_ptr.get()));
+      if (default_left) {
child_nodes_sizes = PartitionKernel(column, rid_span, + split_cond, left, right); + } else { + child_nodes_sizes = PartitionKernel(column, rid_span, + split_cond, left, right); + } + } else { + CHECK_EQ(any_missing, true); + const common::SparseColumn& column + = static_cast& >(*(column_ptr.get())); + if (default_left) { + child_nodes_sizes = PartitionKernel(column, rid_span, + split_cond, left, right); + } else { + child_nodes_sizes = PartitionKernel(column, rid_span, + split_cond, left, right); + } + } + + const size_t n_left = child_nodes_sizes.first; + const size_t n_right = child_nodes_sizes.second; + + SetNLeftElems(node_in_set, range.begin(), range.end(), n_left); + SetNRightElems(node_in_set, range.begin(), range.end(), n_right); + } + + template + void PartitionRange(const size_t node_in_set, const size_t nid, + common::Range1d range, bst_feature_t fidx, + common::RowSetCollection *p_row_set_collection, + Pred pred) { + auto &row_set_collection = *p_row_set_collection; + const size_t *p_ridx = row_set_collection[nid].begin; + common::Span ridx(p_ridx + range.begin(), p_ridx + range.end()); + common::Span left = + this->GetLeftBuffer(node_in_set, range.begin(), range.end()); + common::Span right = + this->GetRightBuffer(node_in_set, range.begin(), range.end()); + std::pair child_nodes_sizes = + PartitionRangeKernel(ridx, left, right, pred); + + const size_t n_left = child_nodes_sizes.first; + const size_t n_right = child_nodes_sizes.second; + + this->SetNLeftElems(node_in_set, range.begin(), range.end(), n_left); + this->SetNRightElems(node_in_set, range.begin(), range.end(), n_right); + } + + // allocate thread local memory, should be called for each specific task + void AllocateForTask(size_t id) { + if (mem_blocks_[id].get() == nullptr) { + BlockInfo* local_block_ptr = new BlockInfo; + CHECK_NE(local_block_ptr, (BlockInfo*)nullptr); + mem_blocks_[id].reset(local_block_ptr); + } + } + + common::Span GetLeftBuffer(int nid, size_t begin, size_t end) { + const size_t task_idx = GetTaskIdx(nid, begin); + return { mem_blocks_.at(task_idx)->Left(), end - begin }; + } + + common::Span GetRightBuffer(int nid, size_t begin, size_t end) { + const size_t task_idx = GetTaskIdx(nid, begin); + return { mem_blocks_.at(task_idx)->Right(), end - begin }; + } + + void SetNLeftElems(int nid, size_t begin, size_t end, size_t n_left) { + size_t task_idx = GetTaskIdx(nid, begin); + mem_blocks_.at(task_idx)->n_left = n_left; + } + + void SetNRightElems(int nid, size_t begin, size_t end, size_t n_right) { + size_t task_idx = GetTaskIdx(nid, begin); + mem_blocks_.at(task_idx)->n_right = n_right; + } + + + size_t GetNLeftElems(int nid) const { + return left_right_nodes_sizes_[nid].first; + } + + size_t GetNRightElems(int nid) const { + return left_right_nodes_sizes_[nid].second; + } + + // Each thread has partial results for some set of tree-nodes + // The function decides order of merging partial results into final row set + void CalculateRowOffsets() { + for (size_t i = 0; i < blocks_offsets_.size()-1; ++i) { + size_t n_left = 0; + for (size_t j = blocks_offsets_[i]; j < blocks_offsets_[i+1]; ++j) { + mem_blocks_[j]->n_offset_left = n_left; + n_left += mem_blocks_[j]->n_left; + } + size_t n_right = 0; + for (size_t j = blocks_offsets_[i]; j < blocks_offsets_[i+1]; ++j) { + mem_blocks_[j]->n_offset_right = n_left + n_right; + n_right += mem_blocks_[j]->n_right; + } + left_right_nodes_sizes_[i] = {n_left, n_right}; + } + } + + void MergeToArray(int nid, size_t begin, size_t* rows_indexes) { + size_t 
task_idx = GetTaskIdx(nid, begin); + + size_t* left_result = rows_indexes + mem_blocks_[task_idx]->n_offset_left; + size_t* right_result = rows_indexes + mem_blocks_[task_idx]->n_offset_right; + + const size_t* left = mem_blocks_[task_idx]->Left(); + const size_t* right = mem_blocks_[task_idx]->Right(); + + std::copy_n(left, mem_blocks_[task_idx]->n_left, left_result); + std::copy_n(right, mem_blocks_[task_idx]->n_right, right_result); + } + + size_t GetTaskIdx(int nid, size_t begin) { + return blocks_offsets_[nid] + begin / BlockSize; + } + + protected: + struct BlockInfo{ + size_t n_left; + size_t n_right; + + size_t n_offset_left; + size_t n_offset_right; + + size_t* Left() { + return &left_data_[0]; + } + + size_t* Right() { + return &right_data_[0]; + } + private: + size_t left_data_[BlockSize]; + size_t right_data_[BlockSize]; + }; + std::vector> left_right_nodes_sizes_; + std::vector blocks_offsets_; + std::vector> mem_blocks_; + size_t max_n_tasks_ = 0; +}; + +} // namespace common +} // namespace xgboost + +#endif // XGBOOST_COMMON_PARTITION_BUILDER_H_ diff --git a/src/common/threading_utils.h b/src/common/threading_utils.h index ab3765f501fe..d6cd44b71466 100644 --- a/src/common/threading_utils.h +++ b/src/common/threading_utils.h @@ -251,6 +251,7 @@ inline int32_t OmpSetNumThreadsWithoutHT(int32_t* p_threads) { inline int32_t OmpGetNumThreads(int32_t n_threads) { if (n_threads <= 0) { n_threads = omp_get_num_procs(); + n_threads = std::min(n_threads, omp_get_max_threads()); } return n_threads; } diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index 859e5ba9d3ad..d3822e06c1d9 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -168,7 +168,8 @@ void GBTree::ConfigureUpdaters() { // calling this function. break; case TreeMethod::kApprox: - tparam_.updater_seq = "grow_histmaker,prune"; + // grow_histmaker,prune + tparam_.updater_seq = "grow_global_approx_histmaker"; break; case TreeMethod::kExact: tparam_.updater_seq = "grow_colmaker,prune"; diff --git a/src/tree/hist/evaluate_splits.h b/src/tree/hist/evaluate_splits.h index f78a7ed0974a..24b99ed4a8c7 100644 --- a/src/tree/hist/evaluate_splits.h +++ b/src/tree/hist/evaluate_splits.h @@ -56,15 +56,15 @@ template class HistEvaluator { // a non-missing value for the particular feature fid. template GradStats EnumerateSplit( - const GHistIndexMatrix &gmat, const common::GHistRow &hist, + common::HistogramCuts const &cut, const common::GHistRow &hist, const NodeEntry &snode, SplitEntry *p_best, bst_feature_t fidx, bst_node_t nidx, TreeEvaluator::SplitEvaluator const &evaluator) const { static_assert(d_step == +1 || d_step == -1, "Invalid step."); // aliases - const std::vector &cut_ptr = gmat.cut.Ptrs(); - const std::vector &cut_val = gmat.cut.Values(); + const std::vector &cut_ptr = cut.Ptrs(); + const std::vector &cut_val = cut.Values(); // statistics on both sides of split GradStats c; @@ -116,7 +116,7 @@ template class HistEvaluator { snode.root_gain); if (i == imin) { // for leftmost bin, left bound is the smallest feature value - split_pt = gmat.cut.MinValues()[fidx]; + split_pt = cut.MinValues()[fidx]; } else { split_pt = cut_val[i - 1]; } @@ -132,7 +132,7 @@ template class HistEvaluator { public: void EvaluateSplits(const common::HistCollection &hist, - GHistIndexMatrix const &gidx, const RegTree &tree, + common::HistogramCuts const &cut, const RegTree &tree, std::vector* p_entries) { auto& entries = *p_entries; // All nodes are on the same level, so we can store the shared ptr. 
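
As context for the hunks around this point: `EnumerateSplit` walks the histogram bins of one feature between `cut_ptr[fidx]` and `cut_ptr[fidx + 1]`, accumulating gradient statistics on the left of each candidate cut and deriving the right side from the node total, then keeps the candidate with the best loss change. A minimal standalone sketch of that scan follows; `Stats`, `Best`, `EnumerateFeature`, and the plain L2-regularized gain below are simplified assumptions for illustration, not XGBoost's real types, and the real code additionally scans in reverse (`d_step == -1`) to choose the default direction for missing values.

    // Standalone sketch of cut-based split enumeration over one feature's
    // histogram bins.  All names are simplified stand-ins (assumptions).
    #include <cstddef>
    #include <vector>

    struct Stats { double grad{0.0}, hess{0.0}; };
    struct Best  { double loss_chg{-1.0}; float split_pt{0.0f}; };

    // hist:       one (grad, hess) sum per bin of a single feature
    // cut_values: upper bound of each bin, as in HistogramCuts::Values()
    // parent:     gradient statistics of the node being split
    inline Best EnumerateFeature(const std::vector<Stats> &hist,
                                 const std::vector<float> &cut_values,
                                 const Stats &parent, double lambda = 1.0) {
      auto gain = [&](const Stats &s) {
        return s.grad * s.grad / (s.hess + lambda);  // plain L2-regularized gain
      };
      Best best;
      Stats left;
      for (std::size_t i = 0; i + 1 < hist.size(); ++i) {  // last bin cannot split
        left.grad += hist[i].grad;
        left.hess += hist[i].hess;
        Stats right{parent.grad - left.grad, parent.hess - left.hess};
        double loss_chg = gain(left) + gain(right) - gain(parent);
        if (loss_chg > best.loss_chg) {
          best = Best{loss_chg, cut_values[i]};  // value <= split_pt goes left
        }
      }
      return best;
    }

The signature change in this file is exactly about that dependency: the scan only needs `HistogramCuts` (bin boundaries), not the whole `GHistIndexMatrix`, which is what lets the approx updater evaluate splits without keeping any quantized page around.
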
@@ -168,10 +168,10 @@ template class HistEvaluator { for (auto fidx_in_set = r.begin(); fidx_in_set < r.end(); fidx_in_set++) { auto fidx = features_set[fidx_in_set]; if (interaction_constraints_.Query(nidx, fidx)) { - auto grad_stats = EnumerateSplit<+1>(gidx, histogram, snode_[nidx], + auto grad_stats = EnumerateSplit<+1>(cut, histogram, snode_[nidx], best, fidx, nidx, evaluator); if (SplitContainsMissingValues(grad_stats, snode_[nidx])) { - EnumerateSplit<-1>(gidx, histogram, snode_[nidx], best, fidx, nidx, + EnumerateSplit<-1>(cut, histogram, snode_[nidx], best, fidx, nidx, evaluator); } } diff --git a/src/tree/hist/expand_entry.h b/src/tree/hist/expand_entry.h new file mode 100644 index 000000000000..d0edfbd379a6 --- /dev/null +++ b/src/tree/hist/expand_entry.h @@ -0,0 +1,64 @@ +/*! + * Copyright 2021 XGBoost contributors + */ +#ifndef XGBOOST_TREE_HIST_EXPAND_ENTRY_H_ +#define XGBOOST_TREE_HIST_EXPAND_ENTRY_H_ + +#include +#include "../param.h" + +namespace xgboost { +namespace tree { + +struct CPUExpandEntry { + int nid; + int depth; + SplitEntry split; + CPUExpandEntry() = default; + XGBOOST_DEVICE + CPUExpandEntry(int nid, int depth, SplitEntry split) + : nid(nid), depth(depth), split(std::move(split)) {} + CPUExpandEntry(int nid, int depth, float loss_chg) + : nid(nid), depth(depth) { + split.loss_chg = loss_chg; + } + + bool IsValid(const TrainParam& param, int num_leaves) const { + if (split.loss_chg <= kRtEps) return false; + if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) { + return false; + } + if (split.loss_chg < param.min_split_loss) { + return false; + } + if (param.max_depth > 0 && depth == param.max_depth) { + return false; + } + if (param.max_leaves > 0 && num_leaves == param.max_leaves) { + return false; + } + return true; + } + + float GetLossChange() const { return split.loss_chg; } + bst_node_t GetNodeId() const { return nid; } + + static bool ChildIsValid(const TrainParam& param, int depth, int num_leaves) { + if (param.max_depth > 0 && depth >= param.max_depth) return false; + if (param.max_leaves > 0 && num_leaves >= param.max_leaves) return false; + return true; + } + + friend std::ostream& operator<<(std::ostream& os, const CPUExpandEntry& e) { + os << "ExpandEntry: \n"; + os << "nidx: " << e.nid << "\n"; + os << "depth: " << e.depth << "\n"; + os << "loss: " << e.split.loss_chg << "\n"; + os << "left_sum: " << e.split.left_sum << "\n"; + os << "right_sum: " << e.split.right_sum << "\n"; + return os; + } +}; +} // namespace tree +} // namespace xgboost +#endif // XGBOOST_TREE_HIST_EXPAND_ENTRY_H_ diff --git a/src/tree/hist/histogram.h b/src/tree/hist/histogram.h index 70c756e765e6..2352310d9b9e 100644 --- a/src/tree/hist/histogram.h +++ b/src/tree/hist/histogram.h @@ -11,6 +11,8 @@ #include "rabit/rabit.h" #include "xgboost/tree_model.h" #include "../../common/hist_util.h" +#include "../../data/gradient_index.h" +#include "../../common/observer.h" namespace xgboost { namespace tree { @@ -25,8 +27,9 @@ template class HistogramBuilder { common::GHistBuilder builder_; common::ParallelGHistBuilder buffer_; rabit::Reducer reducer_; - int32_t max_bin_ {-1}; + BatchParam param_; int32_t n_threads_ {-1}; + size_t n_batches_ {0}; // Whether XGBoost is running in distributed environment. bool is_distributed_ {false}; @@ -39,12 +42,12 @@ template class HistogramBuilder { * \param is_distributed Mostly used for testing to allow injecting parameters instead * of using global rabit variable. 
*/ - void Reset(uint32_t total_bins, int32_t max_bin_per_feat, int32_t n_threads, - bool is_distributed = rabit::IsDistributed()) { + void Reset(uint32_t total_bins, BatchParam p, int32_t n_threads, + size_t n_batches, bool is_distributed) { CHECK_GE(n_threads, 1); n_threads_ = n_threads; - CHECK_GE(max_bin_per_feat, 2); - max_bin_ = max_bin_per_feat; + n_batches_ = n_batches; + param_ = p; hist_.Init(total_bins); hist_local_worker_.Init(total_bins); buffer_.Init(total_bins); @@ -53,45 +56,42 @@ template class HistogramBuilder { } template - void - BuildLocalHistograms(DMatrix *p_fmat, - std::vector nodes_for_explicit_hist_build, - common::RowSetCollection const &row_set_collection, - const std::vector &gpair_h) { + void BuildLocalHistograms( + size_t page_idx, + common::BlockedSpace2d space, + GHistIndexMatrix const &gidx, + std::vector const &nodes_for_explicit_hist_build, + common::RowSetCollection const &row_set_collection, + const std::vector &gpair_h) { const size_t n_nodes = nodes_for_explicit_hist_build.size(); - - // create space of size (# rows in each node) - common::BlockedSpace2d space( - n_nodes, - [&](size_t node) { - const int32_t nid = nodes_for_explicit_hist_build[node].nid; - return row_set_collection[nid].Size(); - }, - 256); + CHECK_GT(n_nodes, 0); std::vector target_hists(n_nodes); for (size_t i = 0; i < n_nodes; ++i) { const int32_t nid = nodes_for_explicit_hist_build[i].nid; target_hists[i] = hist_[nid]; } - buffer_.Reset(this->n_threads_, n_nodes, space, target_hists); + if (page_idx == 0) { + // FIXME: Handle different size of space. + buffer_.Reset(this->n_threads_, n_nodes, space, target_hists); + } // Parallel processing by nodes and data in each node - for (auto const &gmat : p_fmat->GetBatches( - BatchParam{GenericParameter::kCpuId, max_bin_})) { - common::ParallelFor2d( - space, this->n_threads_, [&](size_t nid_in_set, common::Range1d r) { - const auto tid = static_cast(omp_get_thread_num()); - const int32_t nid = nodes_for_explicit_hist_build[nid_in_set].nid; - - auto start_of_row_set = row_set_collection[nid].begin; - auto rid_set = common::RowSetCollection::Elem( - start_of_row_set + r.begin(), start_of_row_set + r.end(), nid); - builder_.template BuildHist( - gpair_h, rid_set, gmat, - buffer_.GetInitializedHist(tid, nid_in_set)); - }); - } + common::ParallelFor2d( + space, this->n_threads_, [&](size_t nid_in_set, common::Range1d r) { + const auto tid = static_cast(omp_get_thread_num()); + const int32_t nid = nodes_for_explicit_hist_build[nid_in_set].nid; + auto elem = row_set_collection[nid]; + auto start_of_row_set = std::min(r.begin(), elem.Size()); + auto end_of_row_set = std::min(r.end(), elem.Size()); + auto rid_set = common::RowSetCollection::Elem( + elem.begin + start_of_row_set, elem.begin + end_of_row_set, nid); + auto hist = buffer_.GetInitializedHist(tid, nid_in_set); + if (rid_set.Size() != 0) { + builder_.template BuildHist(gpair_h, rid_set, gidx, + hist); + } + }); } void @@ -110,24 +110,36 @@ template class HistogramBuilder { } } - /* Main entry point of this class, build histogram for tree nodes. */ - void BuildHist(DMatrix *p_fmat, RegTree *p_tree, + /** Main entry point of this class, build histogram for tree nodes. 
*/ + void BuildHist(size_t page_id, + common::BlockedSpace2d space, + GHistIndexMatrix const& gidx, RegTree *p_tree, common::RowSetCollection const &row_set_collection, std::vector const &nodes_for_explicit_hist_build, std::vector const &nodes_for_subtraction_trick, std::vector const &gpair) { int starting_index = std::numeric_limits::max(); int sync_count = 0; - this->AddHistRows(&starting_index, &sync_count, - nodes_for_explicit_hist_build, - nodes_for_subtraction_trick, p_tree); - if (p_fmat->IsDense()) { - BuildLocalHistograms(p_fmat, nodes_for_explicit_hist_build, - row_set_collection, gpair); + if (page_id == 0) { + this->AddHistRows(&starting_index, &sync_count, + nodes_for_explicit_hist_build, + nodes_for_subtraction_trick, p_tree); + } + if (gidx.IsDense()) { + this->BuildLocalHistograms(page_id, space, gidx, + nodes_for_explicit_hist_build, + row_set_collection, gpair); } else { - BuildLocalHistograms(p_fmat, nodes_for_explicit_hist_build, - row_set_collection, gpair); + this->BuildLocalHistograms(page_id, space, gidx, + nodes_for_explicit_hist_build, + row_set_collection, gpair); } + + CHECK_GE(n_batches_, 1); + if (page_id != n_batches_ - 1) { + return; + } + if (is_distributed_) { this->SyncHistogramDistributed(p_tree, nodes_for_explicit_hist_build, nodes_for_subtraction_trick, @@ -138,6 +150,25 @@ template class HistogramBuilder { sync_count); } } + /** same as the other build hist but handles only single batch data (in-core) */ + void BuildHist(size_t page_id, GHistIndexMatrix const &gidx, RegTree *p_tree, + common::RowSetCollection const &row_set_collection, + std::vector const &nodes_for_explicit_hist_build, + std::vector const &nodes_for_subtraction_trick, + std::vector const &gpair) { + const size_t n_nodes = nodes_for_explicit_hist_build.size(); + // create space of size (# rows in each node) + common::BlockedSpace2d space( + n_nodes, + [&](size_t nidx_in_set) { + const int32_t nidx = nodes_for_explicit_hist_build[nidx_in_set].nid; + return row_set_collection[nidx].Size(); + }, + 256); + this->BuildHist(page_id, space, gidx, p_tree, row_set_collection, + nodes_for_explicit_hist_build, nodes_for_subtraction_trick, + gpair); + } void SyncHistogramDistributed( RegTree *p_tree, diff --git a/src/tree/hist/param.h b/src/tree/hist/param.h new file mode 100644 index 000000000000..2fbee28c423b --- /dev/null +++ b/src/tree/hist/param.h @@ -0,0 +1,23 @@ +/*! 
+ * Copyright 2021 XGBoost contributors
+ */
+#ifndef XGBOOST_TREE_HIST_PARAM_H_
+#define XGBOOST_TREE_HIST_PARAM_H_
+#include "xgboost/parameter.h"
+
+namespace xgboost {
+namespace tree {
+// training parameters specific to this algorithm
+struct CPUHistMakerTrainParam
+    : public XGBoostParameter {
+  bool single_precision_histogram;
+  // declare parameters
+  DMLC_DECLARE_PARAMETER(CPUHistMakerTrainParam) {
+    DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe(
+        "Use single precision to build histograms.");
+  }
+};
+}  // namespace tree
+}  // namespace xgboost
+
+#endif  // XGBOOST_TREE_HIST_PARAM_H_
diff --git a/src/tree/tree_updater.cc b/src/tree/tree_updater.cc
index a619713e043a..6cfc95330c02 100644
--- a/src/tree/tree_updater.cc
+++ b/src/tree/tree_updater.cc
@@ -34,6 +34,7 @@ DMLC_REGISTRY_LINK_TAG(updater_refresh);
 DMLC_REGISTRY_LINK_TAG(updater_prune);
 DMLC_REGISTRY_LINK_TAG(updater_quantile_hist);
 DMLC_REGISTRY_LINK_TAG(updater_histmaker);
+DMLC_REGISTRY_LINK_TAG(updater_approx);
 DMLC_REGISTRY_LINK_TAG(updater_sync);
 #ifdef XGBOOST_USE_CUDA
 DMLC_REGISTRY_LINK_TAG(updater_gpu_hist);
diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc
new file mode 100644
index 000000000000..60ce8b380d12
--- /dev/null
+++ b/src/tree/updater_approx.cc
@@ -0,0 +1,371 @@
+/*!
+ * Copyright 2021 XGBoost contributors
+ *
+ * \brief Implementation for the approx tree method.
+ */
+#include
+#include
+#include
+
+#include "xgboost/tree_updater.h"
+#include "xgboost/base.h"
+#include "xgboost/json.h"
+
+#include "hist/evaluate_splits.h"
+#include "hist/histogram.h"
+#include "hist/param.h"
+
+#include "../common/random.h"
+#include "../data/gradient_index.h"
+
+#include "constraints.h"
+#include "driver.h"
+#include "param.h"
+#include "updater_approx.h"
+
+namespace xgboost {
+namespace tree {
+
+DMLC_REGISTRY_FILE_TAG(updater_approx);
+
+template class GlobalApproxBuilder {
+ protected:
+  TrainParam param_;
+  std::shared_ptr col_sampler_;
+  HistEvaluator evaluator_;
+  HistogramBuilder histogram_builder_;
+  GenericParameter const* ctx_;
+
+  std::vector partitioner_;
+  RegTree* p_last_tree_ {nullptr};
+  common::Monitor* monitor_;
+  size_t n_batches_ {0};
+  common::HistogramCuts feature_values_;
+
+ public:
+  void InitData(DMatrix *p_fmat, common::Span hess) {
+    monitor_->Start(__func__);
+    n_batches_ = 0;
+    int32_t n_total_bins = 0;
+    partitioner_.clear();
+    // Generating the GHistIndexMatrix is quite slow, is there a way to speed it up?
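+    // Each pass over GetBatches here re-quantizes the data: the hessian is
+    // passed in as the sketch weight, so the index has to be rebuilt whenever
+    // it changes (see also the note in UpdatePredictionCache below).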
+ for (auto const &page : p_fmat->GetBatches( + {GenericParameter::kCpuId, param_.max_bin, hess, true})) { + if (n_total_bins == 0) { + n_total_bins = page.cut.TotalBins(); + feature_values_ = page.cut; + } else { + CHECK_EQ(n_total_bins, page.cut.TotalBins()); + } + partitioner_.emplace_back(page.Size(), page.base_rowid); + n_batches_++; + } + + histogram_builder_.Reset( + n_total_bins, BatchParam{GenericParameter::kCpuId, param_.max_bin, hess}, + ctx_->Threads(), n_batches_, rabit::IsDistributed()); + monitor_->Stop(__func__); + } + + CPUExpandEntry InitRoot(DMatrix *p_fmat, + std::vector const &gpair, + common::Span hess, RegTree *p_tree) { + monitor_->Start(__func__); + CPUExpandEntry best; + best.nid = RegTree::kRoot; + best.depth = 0; + GradStats root_sum; + for (auto const &g : gpair) { + root_sum.Add(g); + } + rabit::Allreduce(reinterpret_cast(&root_sum), 2); + std::vector nodes{best}; + size_t i = 0; + auto space = this->ConstructHistSpace(nodes); + for (auto const &page : p_fmat->GetBatches( + {GenericParameter::kCpuId, param_.max_bin, hess})) { + histogram_builder_.BuildHist(i, space, page, p_tree, + partitioner_.at(i).Partitions(), nodes, {}, + gpair); + i++; + } + + auto weight = evaluator_.InitRoot(root_sum); + p_tree->Stat(RegTree::kRoot).sum_hess = root_sum.GetHess(); + p_tree->Stat(RegTree::kRoot).base_weight = weight; + (*p_tree)[RegTree::kRoot].SetLeaf(param_.learning_rate * weight); + + auto const &histograms = histogram_builder_.Histogram(); + evaluator_.EvaluateSplits(histograms, feature_values_, *p_tree, &nodes); + monitor_->Stop(__func__); + + return nodes.front(); + } + + void UpdatePredictionCache(const DMatrix *data, + VectorView out_preds) { + monitor_->Start(__func__); + // Caching prediction seems redundant for approx tree method, as sketching takes up + // majority of training time. + CHECK_EQ(out_preds.Size(), data->Info().num_row_); + CHECK(p_last_tree_); + + size_t n_nodes = p_last_tree_->GetNodes().size(); + + auto evaluator = evaluator_.Evaluator(); + auto const& tree = *p_last_tree_; + auto const& snode = evaluator_.Stats(); + for (auto &part : partitioner_) { + CHECK_EQ(part.Size(), n_nodes); + common::BlockedSpace2d space( + part.Size(), [&](size_t node) { return part[node].Size(); }, 1024); + common::ParallelFor2d( + space, ctx_->Threads(), [&](size_t nidx, common::Range1d r) { + if (tree[nidx].IsLeaf()) { + const auto rowset = part[nidx]; + auto const &stats = snode.at(nidx); + auto leaf_value = + evaluator.CalcWeight(nidx, param_, GradStats{stats.stats}) * + param_.learning_rate; + for (const size_t *it = rowset.begin + r.begin(); + it < rowset.begin + r.end(); ++it) { + out_preds[*it] += leaf_value; + } + } + }); + } + monitor_->Stop(__func__); + } + + // Construct a work space for building histogram. Eventually we should move this + // function into histogram builder once hist tree method supports external memory. 
+  common::BlockedSpace2d
+  ConstructHistSpace(std::vector const &nodes_to_build) const {
+    std::vector partition_size(nodes_to_build.size(), 0);
+    for (auto const &partition : partitioner_) {
+      size_t k = 0;
+      for (auto node : nodes_to_build) {
+        auto n_rows_in_node = partition.Partitions()[node.nid].Size();
+        partition_size[k] = std::max(partition_size[k], n_rows_in_node);
+        k++;
+      }
+    }
+    common::BlockedSpace2d space{
+        nodes_to_build.size(),
+        [&](size_t nidx_in_set) { return partition_size[nidx_in_set]; }, 256};
+    return space;
+  }
+
+  void BuildHistogram(DMatrix *p_fmat, RegTree *p_tree,
+                      std::vector const &valid_candidates,
+                      std::vector const &gpair,
+                      common::Span hess) {
+    std::vector nodes_to_build;
+    std::vector nodes_to_sub;
+
+    for (auto const &c : valid_candidates) {
+      auto left_nidx = (*p_tree)[c.nid].LeftChild();
+      auto right_nidx = (*p_tree)[c.nid].RightChild();
+      auto fewer_right =
+          c.split.right_sum.GetHess() < c.split.left_sum.GetHess();
+
+      auto build_nidx = left_nidx;
+      auto subtract_nidx = right_nidx;
+      if (fewer_right) {
+        std::swap(build_nidx, subtract_nidx);
+      }
+      nodes_to_build.push_back(
+          CPUExpandEntry{build_nidx, p_tree->GetDepth(build_nidx), {}});
+      nodes_to_sub.push_back(
+          CPUExpandEntry{subtract_nidx, p_tree->GetDepth(subtract_nidx), {}});
+    }
+
+    size_t i = 0;
+    auto space = this->ConstructHistSpace(nodes_to_build);
+    for (auto const &page : p_fmat->GetBatches(
+             {GenericParameter::kCpuId, param_.max_bin, hess})) {
+      histogram_builder_.BuildHist(i, space, page, p_tree,
+                                   partitioner_.at(i).Partitions(),
+                                   nodes_to_build, nodes_to_sub, gpair);
+      i++;
+    }
+    auto histograms = histogram_builder_.Histogram();
+  }
+
+ public:
+  explicit GlobalApproxBuilder(
+      TrainParam param, MetaInfo const &info, GenericParameter const *ctx,
+      std::shared_ptr column_sampler,
+      common::Monitor *monitor)
+      : param_{std::move(param)}, col_sampler_{std::move(column_sampler)},
+        evaluator_{param_, info, ctx->Threads(), col_sampler_}, ctx_{ctx},
+        monitor_{monitor} {}
+
+  void UpdateTree(RegTree *p_tree, MetaInfo const& info,
+                  std::vector const &gpair,
+                  common::Span hess,
+                  DMatrix* p_fmat) {
+    p_last_tree_ = p_tree;
+    this->InitData(p_fmat, hess);
+
+    Driver driver(
+        static_cast(param_.grow_policy));
+    auto &tree = *p_tree;
+    driver.Push({this->InitRoot(p_fmat, gpair, hess, p_tree)});
+    auto num_leaves = 1;
+    auto expand_set = driver.Pop();
+
+    while (!expand_set.empty()) {
+      // candidates that can be split further.
+      std::vector valid_candidates;
+      // candidates that can be applied.
+ std::vector applied; + for (auto const& candidate : expand_set) { + if (!candidate.IsValid(param_, num_leaves)) { + continue; + } + evaluator_.ApplyTreeSplit(candidate, p_tree); + applied.push_back(candidate); + num_leaves++; + int left_child_nidx = tree[candidate.nid].LeftChild(); + if (CPUExpandEntry::ChildIsValid( + param_, p_tree->GetDepth(left_child_nidx), num_leaves)) { + valid_candidates.emplace_back(candidate); + } + } + size_t i = 0; + for (auto const &page : p_fmat->GetBatches( + {GenericParameter::kCpuId, param_.max_bin, hess})) { + partitioner_.at(i).UpdatePosition(ctx_, page, applied, p_tree); + i++; + } + + std::vector best_splits; + if (!valid_candidates.empty()) { + this->BuildHistogram(p_fmat, p_tree, valid_candidates, gpair, hess); + for (auto const& candidate : valid_candidates) { + int left_child_nidx = tree[candidate.nid].LeftChild(); + int right_child_nidx = tree[candidate.nid].RightChild(); + CPUExpandEntry l_best{ + left_child_nidx, tree.GetDepth(left_child_nidx), {}}; + CPUExpandEntry r_best{ + right_child_nidx, tree.GetDepth(right_child_nidx), {}}; + best_splits.push_back(l_best); + best_splits.push_back(r_best); + } + auto const &histograms = histogram_builder_.Histogram(); + evaluator_.EvaluateSplits(histograms, feature_values_, *p_tree, &best_splits); + } + driver.Push(best_splits.begin(), best_splits.end()); + expand_set = driver.Pop(); + } + } +}; + +class GlobalApproxUpdater : public TreeUpdater { + TrainParam param_; + common::Monitor monitor_; + CPUHistMakerTrainParam hist_param_; + + std::unique_ptr> f32_impl_; + std::unique_ptr> f64_impl_; + DMatrix *cached_{nullptr}; + std::shared_ptr column_sampler_ = + std::make_shared(); + + public: + GlobalApproxUpdater() { + monitor_.Init(__func__); + } + + void Configure(const Args& args) override { + param_.UpdateAllowUnknown(args); + hist_param_.UpdateAllowUnknown(args); + } + void LoadConfig(Json const& in) override { + auto const& config = get(in); + FromJson(config.at("train_param"), &this->param_); + FromJson(config.at("hist_param"), &this->hist_param_); + } + void SaveConfig(Json* p_out) const override { + auto& out = *p_out; + out["train_param"] = ToJson(param_); + out["hist_param"] = ToJson(hist_param_); + } + + void InitData(TrainParam const ¶m, HostDeviceVector *gpair, + std::vector *sampled) { + auto const &h_gpair = gpair->HostVector(); + sampled->resize(h_gpair.size()); + std::copy(h_gpair.cbegin(), h_gpair.cend(), sampled->begin()); + auto &rnd = common::GlobalRandom(); + if (param.subsample != 1.0) { + CHECK(param.sampling_method != TrainParam::kGradientBased) + << "Gradient based sampling is not supported for approx tree method."; + std::bernoulli_distribution coin_flip(param.subsample); + std::transform(sampled->begin(), sampled->end(), sampled->begin(), + [&](GradientPair &g) { + if (coin_flip(rnd)) { + return g; + } else { + return GradientPair{}; + } + }); + } + } + + char const *Name() const override { return "grow_global_approx_histmaker"; } + + void Update(HostDeviceVector *gpair, DMatrix *m, + const std::vector &trees) override { + float lr = param_.learning_rate; + param_.learning_rate = lr / trees.size(); + + if (hist_param_.single_precision_histogram) { + f32_impl_ = std::make_unique>( + param_, m->Info(), tparam_, column_sampler_, &monitor_); + } else { + f64_impl_ = std::make_unique>( + param_, m->Info(), tparam_, column_sampler_, &monitor_); + } + + std::vector h_gpair; + InitData(param_, gpair, &h_gpair); + std::vector hess(h_gpair.size()); + std::transform(h_gpair.begin(), 
h_gpair.end(), hess.begin(), + [](auto g) { return g.GetHess(); }); + + cached_ = m; + auto const &info = m->Info(); + + for (auto p_tree : trees) { + if (hist_param_.single_precision_histogram) { + this->f32_impl_->UpdateTree(p_tree, info, h_gpair, hess, m); + } else { + this->f64_impl_->UpdateTree(p_tree, info, h_gpair, hess, m); + } + } + param_.learning_rate = lr; + } + + bool + UpdatePredictionCache(const DMatrix *data, + VectorView out_preds) override { + if (data != cached_) { return false; } + + if (hist_param_.single_precision_histogram) { + this->f32_impl_->UpdatePredictionCache(data, out_preds); + } else { + this->f64_impl_->UpdatePredictionCache(data, out_preds); + } + return true; + } +}; + +DMLC_REGISTRY_FILE_TAG(grow_global_approx_histmaker); + +XGBOOST_REGISTER_TREE_UPDATER(GlobalHistMaker, "grow_global_approx_histmaker") + .describe("Tree constructor that uses approximate histogram construction " + "for each node.") + .set_body([]() { return new GlobalApproxUpdater(); }); +} // namespace tree +} // namespace xgboost diff --git a/src/tree/updater_approx.h b/src/tree/updater_approx.h new file mode 100644 index 000000000000..6e8a9f651b47 --- /dev/null +++ b/src/tree/updater_approx.h @@ -0,0 +1,152 @@ +/*! + * Copyright 2021 XGBoost contributors + * + * \brief Implementation for the approx tree method. + */ +#ifndef XGBOOST_TREE_UPDATER_APPROX_H_ +#define XGBOOST_TREE_UPDATER_APPROX_H_ + +#include +#include +#include + +#include "xgboost/tree_updater.h" +#include "xgboost/json.h" +#include "constraints.h" + +#include "../common/random.h" +#include "../common/partition_builder.h" + +#include "hist/param.h" +#include "hist/evaluate_splits.h" +#include "hist/evaluate_splits.h" +#include "hist/expand_entry.h" + +#include "driver.h" +#include "param.h" + +namespace xgboost { +namespace tree { +class ApproxRowPartitioner { + static constexpr size_t kPartitionBlockSize = 2048; + common::PartitionBuilder partition_builder_; + common::RowSetCollection row_set_collection_; + + public: + bst_row_t base_rowid = 0; + + static auto SearchCutValue(bst_row_t ridx, bst_feature_t fidx, + GHistIndexMatrix const &index, + std::vector const &cut_ptrs, + std::vector const &cut_values) { + int32_t gidx = -1; + auto const& row_ptr = index.row_ptr; + auto get_row_ptr = [&](size_t ridx) { + return row_ptr.at(ridx - index.base_rowid); + }; + + if (index.IsDense()) { + gidx = index.index[get_row_ptr(ridx) + fidx]; + } else { + auto begin = get_row_ptr(ridx); + auto end = get_row_ptr(ridx + 1); + auto f_begin = cut_ptrs[fidx]; + auto f_end = cut_ptrs[fidx + 1]; + gidx = common::BinarySearchBin(begin, end, index.index, f_begin, f_end); + } + if (gidx == -1) { + return std::numeric_limits::quiet_NaN(); + } + return cut_values[gidx]; + } + + public: + void UpdatePosition(GenericParameter const *ctx, + GHistIndexMatrix const &index, + std::vector const &candidates, + RegTree const *p_tree) { + size_t n_nodes = candidates.size(); + + auto const& cut_values = index.cut.Values(); + auto const& cut_ptrs = index.cut.Ptrs(); + + common::BlockedSpace2d space{n_nodes, + [&](size_t node_in_set) { + auto candidate = candidates[node_in_set]; + int32_t nid = candidate.nid; + return row_set_collection_[nid].Size(); + }, + kPartitionBlockSize}; + partition_builder_.Init(space.Size(), n_nodes, [&](size_t node_in_set) { + auto candidate = candidates[node_in_set]; + const int32_t nid = candidate.nid; + const size_t size = row_set_collection_[nid].Size(); + const size_t n_tasks = + size / kPartitionBlockSize + !!(size % 
kPartitionBlockSize); + return n_tasks; + }); + common::ParallelFor2d( + space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) { + auto candidate = candidates[node_in_set]; + const int32_t nid = candidate.nid; + auto fidx = candidate.split.SplitIndex(); + const size_t task_id = + partition_builder_.GetTaskIdx(node_in_set, r.begin()); + partition_builder_.AllocateForTask(task_id); + partition_builder_.PartitionRange( + node_in_set, nid, r, fidx, &row_set_collection_, + [&](size_t row_id) { + auto cut_value = + SearchCutValue(row_id, fidx, index, cut_ptrs, cut_values); + if (std::isnan(cut_value)) { + return candidate.split.DefaultLeft(); + } + return cut_value <= candidate.split.split_value; + }); + }); + + partition_builder_.CalculateRowOffsets(); + common::ParallelFor2d( + space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) { + auto candidate = candidates[node_in_set]; + const int32_t nid = candidate.nid; + partition_builder_.MergeToArray( + node_in_set, r.begin(), + const_cast(row_set_collection_[nid].begin)); + }); + for (size_t i = 0; i < candidates.size(); ++i) { + auto const& candidate = candidates[i]; + auto nidx = candidate.nid; + auto n_left = partition_builder_.GetNLeftElems(i); + auto n_right = partition_builder_.GetNRightElems(i); + CHECK_EQ(n_left + n_right, row_set_collection_[nidx].Size()); + bst_node_t left_nidx = (*p_tree)[nidx].LeftChild(); + bst_node_t right_nidx = (*p_tree)[nidx].RightChild(); + row_set_collection_.AddSplit(nidx, left_nidx, right_nidx, n_left, + n_right); + } + } + + auto const& Partitions() const { return row_set_collection_; } + + auto operator[](bst_node_t nidx) { return row_set_collection_[nidx]; } + auto const& operator[](bst_node_t nidx) const { return row_set_collection_[nidx]; } + + size_t Size() const { + return std::distance(row_set_collection_.begin(), + row_set_collection_.end()); + } + + ApproxRowPartitioner() = default; + explicit ApproxRowPartitioner(bst_row_t num_row, bst_row_t _base_rowid) + : base_rowid{_base_rowid} { + row_set_collection_.Clear(); + auto p_positions = row_set_collection_.Data(); + p_positions->resize(num_row); + std::iota(p_positions->begin(), p_positions->end(), base_rowid); + row_set_collection_.Init(); + } +}; +} // namespace tree +} // namespace xgboost +#endif // XGBOOST_TREE_UPDATER_APPROX_H_ diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 3824691948cf..400fb826ee3c 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -695,7 +695,7 @@ struct GPUHistMakerDevice { int right_child_nidx = tree[candidate.nid].RightChild(); // Only create child entries if needed if (GPUExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx), - num_leaves)) { + num_leaves)) { monitor.Start("UpdatePosition"); this->UpdatePosition(candidate.nid, p_tree); monitor.Stop("UpdatePosition"); diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index bc894b4646b6..2eec9f0653bd 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -124,18 +124,23 @@ template void QuantileHistMaker::Builder::InitRoot( DMatrix *p_fmat, RegTree *p_tree, const std::vector &gpair_h, int *num_leaves, std::vector *expand) { - CPUExpandEntry node(CPUExpandEntry::kRootNid, p_tree->GetDepth(0), 0.0f); + CPUExpandEntry node(RegTree::kRoot, p_tree->GetDepth(0), 0.0f); nodes_for_explicit_hist_build_.clear(); nodes_for_subtraction_trick_.clear(); nodes_for_explicit_hist_build_.push_back(node); - 
this->histogram_builder_->BuildHist(p_fmat, p_tree, row_set_collection_, - nodes_for_explicit_hist_build_, - nodes_for_subtraction_trick_, gpair_h); + size_t page_id = 0; + for (auto const &gidx : p_fmat->GetBatches( + {GenericParameter::kCpuId, param_.max_bin})) { + this->histogram_builder_->BuildHist( + page_id, gidx, p_tree, row_set_collection_, + nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_, gpair_h); + ++page_id; + } { - auto nid = CPUExpandEntry::kRootNid; + auto nid = RegTree::kRoot; GHistRowT hist = this->histogram_builder_->Histogram()[nid]; GradientPairT grad_stat; if (data_layout_ == DataLayout::kDenseDataZeroBased || @@ -170,7 +175,8 @@ void QuantileHistMaker::Builder::InitRoot( builder_monitor_.Start("EvaluateSplits"); for (auto const &gmat : p_fmat->GetBatches( BatchParam{GenericParameter::kCpuId, param_.max_bin})) { - evaluator_->EvaluateSplits(histogram_builder_->Histogram(), gmat, *p_tree, &entries); + evaluator_->EvaluateSplits(histogram_builder_->Histogram(), gmat.cut, *p_tree, &entries); + break; } builder_monitor_.Stop("EvaluateSplits"); node = entries.front(); @@ -259,9 +265,15 @@ void QuantileHistMaker::Builder::ExpandTree( SplitSiblings(nodes_for_apply_split, &nodes_to_evaluate, p_tree); if (depth < param_.max_depth) { - this->histogram_builder_->BuildHist( - p_fmat, p_tree, row_set_collection_, nodes_for_explicit_hist_build_, - nodes_for_subtraction_trick_, gpair_h); + size_t i = 0; + for (auto const &gidx : p_fmat->GetBatches( + {GenericParameter::kCpuId, param_.max_bin})) { + this->histogram_builder_->BuildHist( + i, gidx, p_tree, row_set_collection_, + nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_, + gpair_h); + ++i; + } } else { int starting_index = std::numeric_limits::max(); int sync_count = 0; @@ -271,7 +283,7 @@ void QuantileHistMaker::Builder::ExpandTree( } builder_monitor_.Start("EvaluateSplits"); - evaluator_->EvaluateSplits(this->histogram_builder_->Histogram(), gmat, + evaluator_->EvaluateSplits(this->histogram_builder_->Histogram(), gmat.cut, *p_tree, &nodes_to_evaluate); builder_monitor_.Stop("EvaluateSplits"); @@ -431,7 +443,9 @@ void QuantileHistMaker::Builder::InitData( }); } exc.Rethrow(); - this->histogram_builder_->Reset(nbins, param_.max_bin, this->nthread_); + this->histogram_builder_->Reset( + nbins, BatchParam{GenericParameter::kCpuId, param_.max_bin}, + this->nthread_, 1, rabit::IsDistributed()); std::vector& row_indices = *row_set_collection_.Data(); row_indices.resize(info.num_row_); diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h index 69e42b90db44..9654ab00a7c0 100644 --- a/src/tree/updater_quantile_hist.h +++ b/src/tree/updater_quantile_hist.h @@ -23,6 +23,9 @@ #include "hist/evaluate_splits.h" #include "hist/histogram.h" +#include "hist/expand_entry.h" +#include "hist/param.h" + #include "constraints.h" #include "./param.h" #include "./driver.h" @@ -89,51 +92,6 @@ using xgboost::common::GHistBuilder; using xgboost::common::ColumnMatrix; using xgboost::common::Column; -// training parameters specific to this algorithm -struct CPUHistMakerTrainParam - : public XGBoostParameter { - bool single_precision_histogram = false; - // declare parameters - DMLC_DECLARE_PARAMETER(CPUHistMakerTrainParam) { - DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe( - "Use single precision to build histograms."); - } -}; - -/* tree growing policies */ -struct CPUExpandEntry { - static const int kRootNid = 0; - static const int kEmptyNid = -1; - int nid; - int depth; - 
diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h
index 69e42b90db44..9654ab00a7c0 100644
--- a/src/tree/updater_quantile_hist.h
+++ b/src/tree/updater_quantile_hist.h
@@ -23,6 +23,9 @@
 #include "hist/evaluate_splits.h"
 #include "hist/histogram.h"
+#include "hist/expand_entry.h"
+#include "hist/param.h"
+
 #include "constraints.h"
 #include "./param.h"
 #include "./driver.h"
@@ -89,51 +92,6 @@ using xgboost::common::GHistBuilder;
 using xgboost::common::ColumnMatrix;
 using xgboost::common::Column;
 
-// training parameters specific to this algorithm
-struct CPUHistMakerTrainParam
-    : public XGBoostParameter<CPUHistMakerTrainParam> {
-  bool single_precision_histogram = false;
-  // declare parameters
-  DMLC_DECLARE_PARAMETER(CPUHistMakerTrainParam) {
-    DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe(
-        "Use single precision to build histograms.");
-  }
-};
-
-/* tree growing policies */
-struct CPUExpandEntry {
-  static const int kRootNid = 0;
-  static const int kEmptyNid = -1;
-  int nid;
-  int depth;
-  SplitEntry split;
-
-  CPUExpandEntry() = default;
-  CPUExpandEntry(int nid, int depth, bst_float loss_chg)
-      : nid(nid), depth(depth) {
-    split.loss_chg = loss_chg;
-  }
-
-  bool IsValid(TrainParam const &param, int32_t num_leaves) const {
-    bool invalid = split.loss_chg <= kRtEps ||
-                   (param.max_depth > 0 && this->depth == param.max_depth) ||
-                   (param.max_leaves > 0 && num_leaves == param.max_leaves);
-    return !invalid;
-  }
-
-  bst_float GetLossChange() const {
-    return split.loss_chg;
-  }
-
-  int GetNodeId() const {
-    return nid;
-  }
-
-  int GetDepth() const {
-    return depth;
-  }
-};
-
 /*! \brief construct a tree using quantized feature values */
 class QuantileHistMaker: public TreeUpdater {
  public:
diff --git a/tests/cpp/data/test_gradient_index.cc b/tests/cpp/data/test_gradient_index.cc
index 4bdf34ab2f66..2c19b9e58c9b 100644
--- a/tests/cpp/data/test_gradient_index.cc
+++ b/tests/cpp/data/test_gradient_index.cc
@@ -13,7 +13,8 @@ TEST(GradientIndex, ExternalMemory) {
   std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(10000);
   std::vector<size_t> base_rowids;
   std::vector<float> hessian(dmat->Info().num_row_, 1);
-  for (auto const& page : dmat->GetBatches<GHistIndexMatrix>({0, 64, hessian})) {
+  for (auto const &page : dmat->GetBatches<GHistIndexMatrix>(
+           {GenericParameter::kCpuId, 64, hessian})) {
     base_rowids.push_back(page.base_rowid);
   }
   size_t i = 0;
diff --git a/tests/cpp/gbm/test_gbtree.cc b/tests/cpp/gbm/test_gbtree.cc
index 9255bf2c32dc..00fa56278a0c 100644
--- a/tests/cpp/gbm/test_gbtree.cc
+++ b/tests/cpp/gbm/test_gbtree.cc
@@ -35,7 +35,7 @@ TEST(GBTree, SelectTreeMethod) {
   gbtree.Configure(args);
   auto const& tparam = gbtree.GetTrainParam();
   gbtree.Configure({{"tree_method", "approx"}});
-  ASSERT_EQ(tparam.updater_seq, "grow_histmaker,prune");
+  ASSERT_EQ(tparam.updater_seq, "grow_global_approx_histmaker");
   gbtree.Configure({{"tree_method", "exact"}});
   ASSERT_EQ(tparam.updater_seq, "grow_colmaker,prune");
   gbtree.Configure({{"tree_method", "hist"}});
diff --git a/tests/cpp/tree/hist/test_evaluate_splits.cc b/tests/cpp/tree/hist/test_evaluate_splits.cc
index c9228edf992d..cb0171269305 100644
--- a/tests/cpp/tree/hist/test_evaluate_splits.cc
+++ b/tests/cpp/tree/hist/test_evaluate_splits.cc
@@ -58,7 +58,7 @@ template <typename GradientSumT> void TestEvaluateSplits() {
   entries.front().depth = 0;
   evaluator.InitRoot(GradStats{total_gpair});
-  evaluator.EvaluateSplits(hist, gmat, tree, &entries);
+  evaluator.EvaluateSplits(hist, gmat.cut, tree, &entries);
 
   auto best_loss_chg =
       evaluator.Evaluator().CalcSplitGain(
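`CPUExpandEntry` is removed from this header because it moves into the shared `hist/expand_entry.h` (now included above), not because the logic changes. Its `IsValid` stopping rule is worth restating on its own. The sketch below reproduces exactly the predicate from the removed code, with stand-in `Param`/`Split` aggregates in place of `TrainParam` and `SplitEntry`:

```cpp
#include <cassert>

constexpr float kRtEps = 1e-6f;  // same role as xgboost's kRtEps

// Stand-ins for TrainParam and SplitEntry; fields mirror what IsValid reads.
struct Param { int max_depth{0}; int max_leaves{0}; };
struct Split { float loss_chg{0.0f}; };

struct ExpandEntry {
  int nid{0};
  int depth{0};
  Split split;

  // A candidate is expandable unless its gain is negligible or expanding it
  // would violate the depth/leaf budget -- the same three conditions as the
  // removed CPUExpandEntry::IsValid.
  bool IsValid(Param const &param, int num_leaves) const {
    bool invalid = split.loss_chg <= kRtEps ||
                   (param.max_depth > 0 && depth == param.max_depth) ||
                   (param.max_leaves > 0 && num_leaves == param.max_leaves);
    return !invalid;
  }
};

int main() {
  Param param;
  param.max_depth = 3;
  ExpandEntry e{/*nid=*/0, /*depth=*/3, Split{0.5f}};
  assert(!e.IsValid(param, /*num_leaves=*/4));  // already at the depth limit
  e.depth = 2;
  assert(e.IsValid(param, /*num_leaves=*/4));   // gain > eps, within budget
  return 0;
}
```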
diff --git a/tests/cpp/tree/hist/test_histogram.cc b/tests/cpp/tree/hist/test_histogram.cc
index a75ce70d4843..1ccd1529a207 100644
--- a/tests/cpp/tree/hist/test_histogram.cc
+++ b/tests/cpp/tree/hist/test_histogram.cc
@@ -35,8 +35,9 @@ void TestAddHistRows(bool is_distributed) {
   nodes_for_subtraction_trick_.emplace_back(6, tree.GetDepth(6), 0.0f);
 
   HistogramBuilder<GradientSumT, CPUExpandEntry> histogram_builder;
-  histogram_builder.Reset(gmat.cut.TotalBins(), kMaxBins, omp_get_max_threads(),
-                          is_distributed);
+  histogram_builder.Reset(gmat.cut.TotalBins(),
+                          {GenericParameter::kCpuId, kMaxBins},
+                          omp_get_max_threads(), 1, is_distributed);
   histogram_builder.AddHistRows(&starting_index, &sync_count,
                                 nodes_for_explicit_hist_build_,
                                 nodes_for_subtraction_trick_, &tree);
@@ -81,7 +82,8 @@ void TestSyncHist(bool is_distributed) {
 
   HistogramBuilder<GradientSumT, CPUExpandEntry> histogram;
   uint32_t total_bins = gmat.cut.Ptrs().back();
-  histogram.Reset(total_bins, kMaxBins, omp_get_max_threads(), is_distributed);
+  histogram.Reset(total_bins, {GenericParameter::kCpuId, kMaxBins},
+                  omp_get_max_threads(), 1, is_distributed);
 
   RowSetCollection row_set_collection_;
   {
@@ -247,7 +249,8 @@ void TestBuildHistogram(bool is_distributed) {
   bst_node_t nid = 0;
   HistogramBuilder<GradientSumT, CPUExpandEntry> histogram;
-  histogram.Reset(total_bins, kMaxBins, omp_get_max_threads(), is_distributed);
+  histogram.Reset(total_bins, {GenericParameter::kCpuId, kMaxBins},
+                  omp_get_max_threads(), 1, is_distributed);
 
   RegTree tree;
 
@@ -258,11 +261,14 @@ void TestBuildHistogram(bool is_distributed) {
   std::iota(row_indices.begin(), row_indices.end(), 0);
   row_set_collection_.Init();
 
-  CPUExpandEntry node(CPUExpandEntry::kRootNid, tree.GetDepth(0), 0.0f);
+  CPUExpandEntry node(RegTree::kRoot, tree.GetDepth(0), 0.0f);
   std::vector<CPUExpandEntry> nodes_for_explicit_hist_build_;
   nodes_for_explicit_hist_build_.push_back(node);
-  histogram.BuildHist(p_fmat.get(), &tree, row_set_collection_,
-                      nodes_for_explicit_hist_build_, {}, gpair);
+  for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(
+           {GenericParameter::kCpuId, kMaxBins})) {
+    histogram.BuildHist(0, gidx, &tree, row_set_collection_,
+                        nodes_for_explicit_hist_build_, {}, gpair);
+  }
 
   // Check if number of histogram bins is correct
   ASSERT_EQ(histogram.Histogram()[nid].size(), gmat.cut.Ptrs().back());
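`Reset` now takes a full `BatchParam` plus a page count and the distributed flag, but the tests' structure is unchanged: one sibling's histogram is built explicitly and the other is derived via the subtraction trick (`nodes_for_subtraction_trick_`). A minimal, self-contained statement of that identity, independent of the `HistogramBuilder` internals: because every parent row lands in exactly one child, the sibling histogram is simply parent minus built child, bin by bin.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

struct GradPair { double grad{0}, hess{0}; };

// hist_sibling = hist_parent - hist_built_child, computed without a second
// scan over the rows -- the essence of the subtraction trick.
std::vector<GradPair> SubtractionTrick(const std::vector<GradPair> &parent,
                                       const std::vector<GradPair> &child) {
  std::vector<GradPair> sibling(parent.size());
  for (std::size_t i = 0; i < parent.size(); ++i) {
    sibling[i].grad = parent[i].grad - child[i].grad;
    sibling[i].hess = parent[i].hess - child[i].hess;
  }
  return sibling;
}

int main() {
  std::vector<GradPair> parent{{10, 4}, {6, 2}};
  std::vector<GradPair> left{{7, 3}, {1, 1}};   // built explicitly
  auto right = SubtractionTrick(parent, left);  // derived for free
  assert(right[0].grad == 3 && right[1].hess == 1);
  return 0;
}
```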
+        auto value = page.cut.Values().at(page.index[*it]);
+        ASSERT_LE(value, split_value);
+      }
+      auto right_nidx = tree[RegTree::kRoot].RightChild();
+      elem = partitioner[right_nidx];
+      for (auto it = elem.begin; it != elem.end; ++it) {
+        auto value = page.cut.Values().at(page.index[*it]);
+        ASSERT_GT(value, split_value) << *it;
+      }
+    }
+  }
+}
+
+TEST(Approx, PredictionCache) {
+  size_t n_samples = 2048, n_features = 13;
+  auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
+
+  {
+    GenericParameter ctx;
+    ctx.InitAllowUnknown(Args{});
+    std::unique_ptr<TreeUpdater> approx{
+        TreeUpdater::Create("grow_global_approx_histmaker", &ctx)};
+    RegTree tree;
+    std::vector<RegTree *> trees{&tree};
+    auto gpair = GenerateRandomGradients(n_samples);
+    approx->Configure(Args{{"max_bin", "64"}});
+    approx->Update(&gpair, Xy.get(), trees);
+    HostDeviceVector<float> out_prediction_cached;
+    out_prediction_cached.Resize(n_samples);
+    MatrixView<float> m(&out_prediction_cached, {n_samples, 1},
+                        GenericParameter::kCpuId);
+    VectorView<float> v(m, 0);
+    ASSERT_TRUE(approx->UpdatePredictionCache(Xy.get(), v));
+  }
+
+  std::unique_ptr<Learner> learner{Learner::Create({Xy})};
+  learner->SetParam("tree_method", "approx");
+  learner->SetParam("nthread", "0");
+  learner->Configure();
+
+  for (size_t i = 0; i < 8; ++i) {
+    learner->UpdateOneIter(i, Xy);
+  }
+
+  HostDeviceVector<float> out_prediction_cached;
+  learner->Predict(Xy, false, &out_prediction_cached, 0, 0);
+
+  Json model{Object()};
+  learner->SaveModel(&model);
+
+  HostDeviceVector<float> out_prediction;
+  {
+    std::unique_ptr<Learner> learner{Learner::Create({Xy})};
+    learner->LoadModel(model);
+    learner->Predict(Xy, false, &out_prediction, 0, 0);
+  }
+
+  auto const h_predt_cached = out_prediction_cached.ConstHostSpan();
+  auto const h_predt = out_prediction.ConstHostSpan();
+
+  ASSERT_EQ(h_predt.size(), h_predt_cached.size());
+  for (size_t i = 0; i < h_predt.size(); ++i) {
+    ASSERT_NEAR(h_predt[i], h_predt_cached[i], kRtEps);
+  }
+}
+}  // namespace tree
+}  // namespace xgboost
diff --git a/tests/cpp/tree/test_tree_policy.cc b/tests/cpp/tree/test_tree_policy.cc
index 68a720a8fba6..65dc975f2319 100644
--- a/tests/cpp/tree/test_tree_policy.cc
+++ b/tests/cpp/tree/test_tree_policy.cc
@@ -61,7 +61,7 @@ class TestGrowPolicy : public ::testing::Test {
   }
 };
 
-TEST_F(TestGrowPolicy, DISABLED_Approx) {
+TEST_F(TestGrowPolicy, Approx) {
   this->TestTreeGrowPolicy("approx", "depthwise");
   this->TestTreeGrowPolicy("approx", "lossguide");
 }
diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py
index 343eff97b984..035422854c78 100644
--- a/tests/python/test_with_dask.py
+++ b/tests/python/test_with_dask.py
@@ -1207,13 +1207,17 @@ def test_feature_weights(self, client: "Client") -> None:
         for i in range(kCols):
             fw[i] *= float(i)
         fw = da.from_array(fw)
-        poly_increasing = run_feature_weights(X, y, fw, model=xgb.dask.DaskXGBRegressor)
+        poly_increasing = run_feature_weights(
+            X, y, fw, "approx", model=xgb.dask.DaskXGBRegressor
+        )
 
         fw = np.ones(shape=(kCols,))
         for i in range(kCols):
             fw[i] *= float(kCols - i)
         fw = da.from_array(fw)
-        poly_decreasing = run_feature_weights(X, y, fw, model=xgb.dask.DaskXGBRegressor)
+        poly_decreasing = run_feature_weights(
+            X, y, fw, "approx", model=xgb.dask.DaskXGBRegressor
+        )
 
         # Approxmated test, this is dependent on the implementation of random
         # number generator in std library.
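The `Partitioner` test above pins down the contract of `UpdatePosition`: after applying a split, every row in the left child has a quantized feature value at or below the split value, every right-child row is above it, and missing values follow the default direction (the tests expand with `/*default_left=*/true`). A toy partition step under that same convention, using a hypothetical helper rather than the real `ApproxRowPartitioner`:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Splits row indices by comparing one feature against split_value.
// NaN marks a missing value and follows default_left, mirroring the
// default-left expansion used in the test.
void UpdatePosition(const std::vector<float> &feature,
                    const std::vector<std::size_t> &rows, float split_value,
                    bool default_left, std::vector<std::size_t> *left,
                    std::vector<std::size_t> *right) {
  for (auto r : rows) {
    float v = feature[r];
    bool go_left = std::isnan(v) ? default_left : v <= split_value;
    (go_left ? left : right)->push_back(r);
  }
}

int main() {
  std::vector<float> feature{0.1f, 0.9f, NAN, 0.4f};
  std::vector<std::size_t> rows{0, 1, 2, 3}, left, right;
  UpdatePosition(feature, rows, /*split_value=*/0.5f,
                 /*default_left=*/true, &left, &right);
  // Rows 0 and 3 satisfy v <= 0.5; row 2 (missing) defaults left.
  assert(left.size() == 3 && right.size() == 1 && right[0] == 1);
  return 0;
}
```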
diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py
index 75020eec16c7..fd94e27ead7e 100644
--- a/tests/python/test_with_sklearn.py
+++ b/tests/python/test_with_sklearn.py
@@ -1091,10 +1091,10 @@ def test_pandas_input():
                                   np.array([0, 1]))
 
 
-def run_feature_weights(X, y, fw, model=xgb.XGBRegressor):
+def run_feature_weights(X, y, fw, tree_method, model=xgb.XGBRegressor):
     with TemporaryDirectory() as tmpdir:
         colsample_bynode = 0.5
-        reg = model(tree_method='hist', colsample_bynode=colsample_bynode)
+        reg = model(tree_method=tree_method, colsample_bynode=colsample_bynode)
 
         reg.fit(X, y, feature_weights=fw)
         model_path = os.path.join(tmpdir, 'model.json')
@@ -1129,7 +1129,8 @@ def run_feature_weights(X, y, fw, model=xgb.XGBRegressor):
     return w
 
 
-def test_feature_weights():
+@pytest.mark.parametrize("tree_method", ["approx", "hist"])
+def test_feature_weights(tree_method):
     kRows = 512
     kCols = 64
     X = rng.randn(kRows, kCols)
@@ -1138,12 +1139,12 @@
     fw = np.ones(shape=(kCols,))
     for i in range(kCols):
         fw[i] *= float(i)
-    poly_increasing = run_feature_weights(X, y, fw, xgb.XGBRegressor)
+    poly_increasing = run_feature_weights(X, y, fw, tree_method, xgb.XGBRegressor)
 
     fw = np.ones(shape=(kCols,))
     for i in range(kCols):
         fw[i] *= float(kCols - i)
-    poly_decreasing = run_feature_weights(X, y, fw, xgb.XGBRegressor)
+    poly_decreasing = run_feature_weights(X, y, fw, tree_method, xgb.XGBRegressor)
 
     # Approxmated test, this is dependent on the implementation of random
     # number generator in std library.