diff --git a/amalgamation/xgboost-all0.cc b/amalgamation/xgboost-all0.cc index 7a245ebfe5b0..1241ced409cd 100644 --- a/amalgamation/xgboost-all0.cc +++ b/amalgamation/xgboost-all0.cc @@ -38,6 +38,8 @@ #include "../src/data/sparse_page_raw_format.cc" #include "../src/data/ellpack_page.cc" #include "../src/data/gradient_index.cc" +#include "../src/data/gradient_index_page_source.cc" +#include "../src/data/gradient_index_format.cc" #include "../src/data/sparse_page_dmatrix.cc" #include "../src/data/proxy_dmatrix.cc" diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 3a403d541b56..cd000e371332 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -1,5 +1,5 @@ /*! - * Copyright (c) 2015 by Contributors + * Copyright (c) 2015-2021 by Contributors * \file data.h * \brief The input data structure of xgboost. * \author Tianqi Chen @@ -214,12 +214,27 @@ struct BatchParam { int gpu_id; /*! \brief Maximum number of bins per feature for histograms. */ int max_bin{0}; + /*! \brief Hessian, used for sketching with future approx implementation. */ + common::Span hess; + /*! \brief Whether should DMatrix regenerate the batch. Only used for GHistIndex. */ + bool regen {false}; + BatchParam() = default; BatchParam(int32_t device, int32_t max_bin) : gpu_id{device}, max_bin{max_bin} {} + /** + * \brief Get batch with sketch weighted by hessian. The batch will be regenerated if + * the span is changed, so caller should keep the span for each iteration. + */ + BatchParam(int32_t device, int32_t max_bin, common::Span hessian, + bool regenerate = false) + : gpu_id{device}, max_bin{max_bin}, hess{hessian}, regen{regenerate} {} bool operator!=(const BatchParam& other) const { - return gpu_id != other.gpu_id || max_bin != other.max_bin; + if (hess.empty() && other.hess.empty()) { + return gpu_id != other.gpu_id || max_bin != other.max_bin; + } + return gpu_id != other.gpu_id || max_bin != other.max_bin || hess.data() != other.hess.data(); } }; diff --git a/src/common/hist_util.h b/src/common/hist_util.h index 04dd3e3d92da..fc49148b0765 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -111,7 +111,7 @@ class HistogramCuts { }; inline HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, - std::vector const &hessian = {}) { + Span const hessian = {}) { HistogramCuts out; auto const& info = m->Info(); const auto threads = omp_get_max_threads(); @@ -136,7 +136,7 @@ inline HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, return out; } -enum BinTypeSize { +enum BinTypeSize : uint32_t { kUint8BinsTypeSize = 1, kUint16BinsTypeSize = 2, kUint32BinsTypeSize = 4 @@ -207,6 +207,13 @@ struct Index { return data_.end(); } + std::vector::iterator begin() { // NOLINT + return data_.begin(); + } + std::vector::iterator end() { // NOLINT + return data_.end(); + } + private: static uint32_t GetValueFromUint8(void *t, size_t i) { return reinterpret_cast(t)[i]; diff --git a/src/common/quantile.cc b/src/common/quantile.cc index fcbd76e52b51..a50602b152c0 100644 --- a/src/common/quantile.cc +++ b/src/common/quantile.cc @@ -94,26 +94,26 @@ std::vector HostSketchContainer::LoadBalance( namespace { // Function to merge hessian and sample weights std::vector MergeWeights(MetaInfo const &info, - std::vector const &hessian, + Span const hessian, bool use_group, int32_t n_threads) { CHECK_EQ(hessian.size(), info.num_row_); std::vector results(hessian.size()); auto const &group_ptr = info.group_ptr_; + auto const& weights = info.weights_.HostVector(); + auto get_weight = [&](size_t i) { return weights.empty() ? 1.0f : weights[i]; }; if (use_group) { - auto const &group_weights = info.weights_.HostVector(); CHECK_GE(group_ptr.size(), 2); CHECK_EQ(group_ptr.back(), hessian.size()); size_t cur_group = 0; for (size_t i = 0; i < hessian.size(); ++i) { - results[i] = hessian[i] * group_weights[cur_group]; + results[i] = hessian[i] * get_weight(cur_group); if (i == group_ptr[cur_group + 1]) { cur_group++; } } } else { - auto const &sample_weights = info.weights_.HostVector(); ParallelFor(hessian.size(), n_threads, Sched::Auto(), - [&](auto i) { results[i] = hessian[i] * sample_weights[i]; }); + [&](auto i) { results[i] = hessian[i] * get_weight(i); }); } return results; } @@ -141,7 +141,7 @@ std::vector UnrollGroupWeights(MetaInfo const &info) { } // anonymous namespace void HostSketchContainer::PushRowPage( - SparsePage const &page, MetaInfo const &info, std::vector const &hessian) { + SparsePage const &page, MetaInfo const &info, Span hessian) { monitor_.Start(__func__); bst_feature_t n_columns = info.num_col_; auto is_dense = info.num_nonzero_ == info.num_col_ * info.num_row_; diff --git a/src/common/quantile.h b/src/common/quantile.h index bdfe387c76dd..c72f5f39160f 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -760,7 +760,7 @@ class HostSketchContainer { /* \brief Push a CSR matrix. */ void PushRowPage(SparsePage const &page, MetaInfo const &info, - std::vector const &hessian = {}); + Span const hessian = {}); void MakeCuts(HistogramCuts* cuts); }; diff --git a/src/data/data.cc b/src/data/data.cc index 741a84ec29ab..a0504f4d5c8b 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -32,6 +32,7 @@ DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::SparsePage> DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::CSCPage>); DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::SortedCSCPage>); DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::EllpackPage>); +DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::GHistIndexMatrix>); } // namespace dmlc namespace { @@ -1089,5 +1090,6 @@ namespace data { // List of files that will be force linked in static links. DMLC_REGISTRY_LINK_TAG(sparse_page_raw_format); +DMLC_REGISTRY_LINK_TAG(gradient_index_format); } // namespace data } // namespace xgboost diff --git a/src/data/ellpack_page_raw_format.cu b/src/data/ellpack_page_raw_format.cu index 13a0a1766d32..2f54b91c9bbc 100644 --- a/src/data/ellpack_page_raw_format.cu +++ b/src/data/ellpack_page_raw_format.cu @@ -4,8 +4,9 @@ #include #include -#include "./ellpack_page.cuh" -#include "./sparse_page_writer.h" +#include "ellpack_page.cuh" +#include "sparse_page_writer.h" +#include "histogram_cut_format.h" namespace xgboost { namespace data { @@ -17,9 +18,9 @@ class EllpackPageRawFormat : public SparsePageFormat { public: bool Read(EllpackPage* page, dmlc::SeekStream* fi) override { auto* impl = page->Impl(); - fi->Read(&impl->Cuts().cut_values_.HostVector()); - fi->Read(&impl->Cuts().cut_ptrs_.HostVector()); - fi->Read(&impl->Cuts().min_vals_.HostVector()); + if (!ReadHistogramCuts(&impl->Cuts(), fi)) { + return false; + } fi->Read(&impl->n_rows); fi->Read(&impl->is_dense); fi->Read(&impl->row_stride); @@ -33,12 +34,7 @@ class EllpackPageRawFormat : public SparsePageFormat { size_t Write(const EllpackPage& page, dmlc::Stream* fo) override { size_t bytes = 0; auto* impl = page.Impl(); - fo->Write(impl->Cuts().cut_values_.ConstHostVector()); - bytes += impl->Cuts().cut_values_.ConstHostSpan().size_bytes() + sizeof(uint64_t); - fo->Write(impl->Cuts().cut_ptrs_.ConstHostVector()); - bytes += impl->Cuts().cut_ptrs_.ConstHostSpan().size_bytes() + sizeof(uint64_t); - fo->Write(impl->Cuts().min_vals_.ConstHostVector()); - bytes += impl->Cuts().min_vals_.ConstHostSpan().size_bytes() + sizeof(uint64_t); + bytes += WriteHistogramCuts(impl->Cuts(), fo); fo->Write(impl->n_rows); bytes += sizeof(impl->n_rows); fo->Write(impl->is_dense); diff --git a/src/data/ellpack_page_source.h b/src/data/ellpack_page_source.h index a43ac9881e07..eca45efc7463 100644 --- a/src/data/ellpack_page_source.h +++ b/src/data/ellpack_page_source.h @@ -32,7 +32,7 @@ class EllpackPageSource : public PageSourceIncMixIn { size_t row_stride, common::Span feature_types, std::shared_ptr source) : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache), - is_dense_{is_dense}, row_stride_{row_stride}, param_{param}, + is_dense_{is_dense}, row_stride_{row_stride}, param_{std::move(param)}, feature_types_{feature_types}, cuts_{std::move(cuts)} { this->source_ = source; this->Fetch(); diff --git a/src/data/gradient_index.cc b/src/data/gradient_index.cc index 7836d2a19469..f2e14882e80b 100644 --- a/src/data/gradient_index.cc +++ b/src/data/gradient_index.cc @@ -8,8 +8,125 @@ #include "../common/hist_util.h" namespace xgboost { -void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_bins) { - cut = common::SketchOnDMatrix(p_fmat, max_bins); + +void GHistIndexMatrix::PushBatch(SparsePage const &batch, size_t rbegin, + size_t prev_sum, uint32_t nbins, + int32_t n_threads) { + // The number of threads is pegged to the batch size. If the OMP + // block is parallelized on anything other than the batch/block size, + // it should be reassigned + const size_t batch_threads = + std::max(size_t(1), std::min(batch.Size(), + static_cast(n_threads))); + auto page = batch.GetView(); + common::MemStackAllocator partial_sums(batch_threads); + size_t *p_part = partial_sums.Get(); + + size_t block_size = batch.Size() / batch_threads; + + dmlc::OMPException exc; +#pragma omp parallel num_threads(batch_threads) + { +#pragma omp for + for (omp_ulong tid = 0; tid < batch_threads; ++tid) { + exc.Run([&]() { + size_t ibegin = block_size * tid; + size_t iend = (tid == (batch_threads - 1) ? batch.Size() + : (block_size * (tid + 1))); + + size_t sum = 0; + for (size_t i = ibegin; i < iend; ++i) { + sum += page[i].size(); + row_ptr[rbegin + 1 + i] = sum; + } + }); + } + +#pragma omp single + { + exc.Run([&]() { + p_part[0] = prev_sum; + for (size_t i = 1; i < batch_threads; ++i) { + p_part[i] = p_part[i - 1] + row_ptr[rbegin + i * block_size]; + } + }); + } + +#pragma omp for + for (omp_ulong tid = 0; tid < batch_threads; ++tid) { + exc.Run([&]() { + size_t ibegin = block_size * tid; + size_t iend = (tid == (batch_threads - 1) ? batch.Size() + : (block_size * (tid + 1))); + + for (size_t i = ibegin; i < iend; ++i) { + row_ptr[rbegin + 1 + i] += p_part[tid]; + } + }); + } + } + exc.Rethrow(); + + const size_t n_offsets = cut.Ptrs().size() - 1; + const size_t n_index = row_ptr[rbegin + batch.Size()]; + ResizeIndex(n_index, isDense_); + + CHECK_GT(cut.Values().size(), 0U); + + uint32_t *offsets = nullptr; + if (isDense_) { + index.ResizeOffset(n_offsets); + offsets = index.Offset(); + for (size_t i = 0; i < n_offsets; ++i) { + offsets[i] = cut.Ptrs()[i]; + } + } + + if (isDense_) { + common::BinTypeSize curent_bin_size = index.GetBinTypeSize(); + if (curent_bin_size == common::kUint8BinsTypeSize) { + common::Span index_data_span = {index.data(), n_index}; + SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins, + [offsets](auto idx, auto j) { + return static_cast(idx - offsets[j]); + }); + + } else if (curent_bin_size == common::kUint16BinsTypeSize) { + common::Span index_data_span = {index.data(), + n_index}; + SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins, + [offsets](auto idx, auto j) { + return static_cast(idx - offsets[j]); + }); + } else { + CHECK_EQ(curent_bin_size, common::kUint32BinsTypeSize); + common::Span index_data_span = {index.data(), + n_index}; + SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins, + [offsets](auto idx, auto j) { + return static_cast(idx - offsets[j]); + }); + } + + /* For sparse DMatrix we have to store index of feature for each bin + in index field to chose right offset. So offset is nullptr and index is + not reduced */ + } else { + common::Span index_data_span = {index.data(), n_index}; + SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins, + [](auto idx, auto) { return idx; }); + } + + common::ParallelFor(bst_omp_uint(nbins), n_threads, [&](bst_omp_uint idx) { + for (int32_t tid = 0; tid < n_threads; ++tid) { + hit_count[idx] += hit_count_tloc_[tid * nbins + idx]; + hit_count_tloc_[tid * nbins + idx] = 0; // reset for next batch + } + }); +} + +void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_bins, common::Span hess) { + cut = common::SketchOnDMatrix(p_fmat, max_bins, hess); max_num_bins = max_bins; const int32_t nthread = omp_get_max_threads(); @@ -32,121 +149,35 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_bins) { this->isDense_ = isDense; for (const auto &batch : p_fmat->GetBatches()) { - // The number of threads is pegged to the batch size. If the OMP - // block is parallelized on anything other than the batch/block size, - // it should be reassigned - const size_t batch_threads = std::max( - size_t(1), - std::min(batch.Size(), static_cast(omp_get_max_threads()))); - auto page = batch.GetView(); - common::MemStackAllocator partial_sums(batch_threads); - size_t* p_part = partial_sums.Get(); - - size_t block_size = batch.Size() / batch_threads; - - dmlc::OMPException exc; - #pragma omp parallel num_threads(batch_threads) - { - #pragma omp for - for (omp_ulong tid = 0; tid < batch_threads; ++tid) { - exc.Run([&]() { - size_t ibegin = block_size * tid; - size_t iend = (tid == (batch_threads-1) ? batch.Size() : (block_size * (tid+1))); - - size_t sum = 0; - for (size_t i = ibegin; i < iend; ++i) { - sum += page[i].size(); - row_ptr[rbegin + 1 + i] = sum; - } - }); - } - - #pragma omp single - { - exc.Run([&]() { - p_part[0] = prev_sum; - for (size_t i = 1; i < batch_threads; ++i) { - p_part[i] = p_part[i - 1] + row_ptr[rbegin + i*block_size]; - } - }); - } - - #pragma omp for - for (omp_ulong tid = 0; tid < batch_threads; ++tid) { - exc.Run([&]() { - size_t ibegin = block_size * tid; - size_t iend = (tid == (batch_threads-1) ? batch.Size() : (block_size * (tid+1))); - - for (size_t i = ibegin; i < iend; ++i) { - row_ptr[rbegin + 1 + i] += p_part[tid]; - } - }); - } - } - exc.Rethrow(); - - const size_t n_offsets = cut.Ptrs().size() - 1; - const size_t n_index = row_ptr[rbegin + batch.Size()]; - ResizeIndex(n_index, isDense); - - CHECK_GT(cut.Values().size(), 0U); - - uint32_t* offsets = nullptr; - if (isDense) { - index.ResizeOffset(n_offsets); - offsets = index.Offset(); - for (size_t i = 0; i < n_offsets; ++i) { - offsets[i] = cut.Ptrs()[i]; - } - } - - if (isDense) { - common::BinTypeSize curent_bin_size = index.GetBinTypeSize(); - if (curent_bin_size == common::kUint8BinsTypeSize) { - common::Span index_data_span = {index.data(), - n_index}; - SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins, - [offsets](auto idx, auto j) { - return static_cast(idx - offsets[j]); - }); - - } else if (curent_bin_size == common::kUint16BinsTypeSize) { - common::Span index_data_span = {index.data(), - n_index}; - SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins, - [offsets](auto idx, auto j) { - return static_cast(idx - offsets[j]); - }); - } else { - CHECK_EQ(curent_bin_size, common::kUint32BinsTypeSize); - common::Span index_data_span = {index.data(), - n_index}; - SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins, - [offsets](auto idx, auto j) { - return static_cast(idx - offsets[j]); - }); - } - - /* For sparse DMatrix we have to store index of feature for each bin - in index field to chose right offset. So offset is nullptr and index is not reduced */ - } else { - common::Span index_data_span = {index.data(), n_index}; - SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins, - [](auto idx, auto) { return idx; }); - } - - common::ParallelFor(bst_omp_uint(nbins), nthread, [&](bst_omp_uint idx) { - for (int32_t tid = 0; tid < nthread; ++tid) { - hit_count[idx] += hit_count_tloc_[tid * nbins + idx]; - hit_count_tloc_[tid * nbins + idx] = 0; // reset for next batch - } - }); - + this->PushBatch(batch, rbegin, prev_sum, nbins, nthread); prev_sum = row_ptr[rbegin + batch.Size()]; rbegin += batch.Size(); } } +void GHistIndexMatrix::Init(SparsePage const &batch, + common::HistogramCuts const &cuts, + int32_t max_bins_per_feat, bool isDense, + int32_t n_threads) { + CHECK_GE(n_threads, 1); + base_rowid = batch.base_rowid; + isDense_ = isDense; + cut = cuts; + max_num_bins = max_bins_per_feat; + CHECK_EQ(row_ptr.size(), 0); + // The number of threads is pegged to the batch size. If the OMP + // block is parallelized on anything other than the batch/block size, + // it should be reassigned + row_ptr.resize(batch.Size() + 1, 0); + const uint32_t nbins = cut.Ptrs().back(); + hit_count.resize(nbins, 0); + hit_count_tloc_.resize(n_threads * nbins, 0); + + size_t rbegin = 0; + size_t prev_sum = 0; + + this->PushBatch(batch, rbegin, prev_sum, nbins, n_threads); +} void GHistIndexMatrix::ResizeIndex(const size_t n_index, const bool isDense) { diff --git a/src/data/gradient_index.h b/src/data/gradient_index.h index d42f596bc896..971e82d4f081 100644 --- a/src/data/gradient_index.h +++ b/src/data/gradient_index.h @@ -18,6 +18,9 @@ namespace xgboost { * index for CPU histogram. On GPU ellpack page is used. */ class GHistIndexMatrix { + void PushBatch(SparsePage const &batch, size_t rbegin, size_t prev_sum, + uint32_t nbins, int32_t n_threads); + public: /*! \brief row pointer to rows by element position */ std::vector row_ptr; @@ -29,12 +32,16 @@ class GHistIndexMatrix { common::HistogramCuts cut; DMatrix* p_fmat; size_t max_num_bins; + size_t base_rowid{0}; - GHistIndexMatrix(DMatrix* x, int32_t max_bin) { - this->Init(x, max_bin); + GHistIndexMatrix() = default; + GHistIndexMatrix(DMatrix* x, int32_t max_bin, common::Span hess = {}) { + this->Init(x, max_bin, hess); } // Create a global histogram matrix, given cut - void Init(DMatrix* p_fmat, int max_num_bins); + void Init(DMatrix* p_fmat, int max_num_bins, common::Span hess); + void Init(SparsePage const &page, common::HistogramCuts const &cuts, + int32_t max_bins_per_feat, bool is_dense, int32_t n_threads); // specific method for sparse data as no possibility to reduce allocated memory template @@ -77,6 +84,11 @@ class GHistIndexMatrix { inline bool IsDense() const { return isDense_; } + void SetDense(bool is_dense) { isDense_ = is_dense; } + + bst_row_t Size() const { + return row_ptr.empty() ? 0 : row_ptr.size() - 1; + } private: std::vector hit_count_tloc_; diff --git a/src/data/gradient_index_format.cc b/src/data/gradient_index_format.cc new file mode 100644 index 000000000000..19baeb406414 --- /dev/null +++ b/src/data/gradient_index_format.cc @@ -0,0 +1,107 @@ +/*! + * Copyright 2021 XGBoost contributors + */ +#include "sparse_page_writer.h" +#include "gradient_index.h" +#include "histogram_cut_format.h" + +namespace xgboost { +namespace data { + +class GHistIndexRawFormat : public SparsePageFormat { + public: + bool Read(GHistIndexMatrix* page, dmlc::SeekStream* fi) override { + if (!ReadHistogramCuts(&page->cut, fi)) { + return false; + } + // indptr + fi->Read(&page->row_ptr); + // offset + using OffsetT = std::iterator_traitsindex.Offset())>::value_type; + std::vector offset; + if (!fi->Read(&offset)) { + return false; + } + page->index.ResizeOffset(offset.size()); + std::copy(offset.begin(), offset.end(), page->index.Offset()); + // data + std::vector data; + if (!fi->Read(&data)) { + return false; + } + page->index.Resize(data.size()); + std::copy(data.cbegin(), data.cend(), page->index.begin()); + // bin type + // Old gcc doesn't support reading from enum. + std::underlying_type_t uint_bin_type{0}; + if (!fi->Read(&uint_bin_type)) { + return false; + } + common::BinTypeSize size_type = + static_cast(uint_bin_type); + page->index.SetBinTypeSize(size_type); + // hit count + if (!fi->Read(&page->hit_count)) { + return false; + } + if (!fi->Read(&page->max_num_bins)) { + return false; + } + if (!fi->Read(&page->base_rowid)) { + return false; + } + bool is_dense = false; + if (!fi->Read(&is_dense)) { + return false; + } + page->SetDense(is_dense); + return true; + } + + size_t Write(GHistIndexMatrix const &page, dmlc::Stream *fo) override { + size_t bytes = 0; + bytes += WriteHistogramCuts(page.cut, fo); + // indptr + fo->Write(page.row_ptr); + bytes += page.row_ptr.size() * sizeof(decltype(page.row_ptr)::value_type) + + sizeof(uint64_t); + // offset + using OffsetT = std::iterator_traits::value_type; + std::vector offset(page.index.OffsetSize()); + std::copy(page.index.Offset(), + page.index.Offset() + page.index.OffsetSize(), offset.begin()); + fo->Write(offset); + bytes += page.index.OffsetSize() * sizeof(OffsetT) + sizeof(uint64_t); + // data + std::vector data(page.index.begin(), page.index.end()); + fo->Write(data); + bytes += data.size() * sizeof(decltype(data)::value_type) + sizeof(uint64_t); + // bin type + std::underlying_type_t uint_bin_type = + page.index.GetBinTypeSize(); + fo->Write(uint_bin_type); + bytes += sizeof(page.index.GetBinTypeSize()); + // hit count + fo->Write(page.hit_count); + bytes += + page.hit_count.size() * sizeof(decltype(page.hit_count)::value_type) + + sizeof(uint64_t); + // max_bins, base row, is_dense + fo->Write(page.max_num_bins); + bytes += sizeof(page.max_num_bins); + fo->Write(page.base_rowid); + bytes += sizeof(page.base_rowid); + fo->Write(page.IsDense()); + bytes += sizeof(page.IsDense()); + return bytes; + } +}; + +DMLC_REGISTRY_FILE_TAG(gradient_index_format); + +XGBOOST_REGISTER_GHIST_INDEX_PAGE_FORMAT(raw) + .describe("Raw GHistIndex binary data format.") + .set_body([]() { return new GHistIndexRawFormat(); }); + +} // namespace data +} // namespace xgboost diff --git a/src/data/gradient_index_page_source.cc b/src/data/gradient_index_page_source.cc new file mode 100644 index 000000000000..e35970bf3e4e --- /dev/null +++ b/src/data/gradient_index_page_source.cc @@ -0,0 +1,18 @@ +/*! + * Copyright 2021 by XGBoost Contributors + */ +#include "gradient_index_page_source.h" + +namespace xgboost { +namespace data { +void GradientIndexPageSource::Fetch() { + if (!this->ReadCache()) { + auto const& csr = source_->Page(); + this->page_.reset(new GHistIndexMatrix()); + CHECK_NE(cuts_.Values().size(), 0); + this->page_->Init(*csr, cuts_, max_bin_per_feat_, is_dense_, nthreads_); + this->WriteCache(); + } +} +} // namespace data +} // namespace xgboost diff --git a/src/data/gradient_index_page_source.h b/src/data/gradient_index_page_source.h new file mode 100644 index 000000000000..db66a1cda02f --- /dev/null +++ b/src/data/gradient_index_page_source.h @@ -0,0 +1,37 @@ +/*! + * Copyright 2021 by XGBoost Contributors + */ +#ifndef XGBOOST_DATA_GRADIENT_INDEX_PAGE_SOURCE_H_ +#define XGBOOST_DATA_GRADIENT_INDEX_PAGE_SOURCE_H_ + +#include +#include + +#include "sparse_page_source.h" +#include "gradient_index.h" + +namespace xgboost { +namespace data { +class GradientIndexPageSource : public PageSourceIncMixIn { + common::HistogramCuts cuts_; + bool is_dense_; + int32_t max_bin_per_feat_; + + public: + GradientIndexPageSource(float missing, int nthreads, bst_feature_t n_features, + size_t n_batches, std::shared_ptr cache, + BatchParam param, common::HistogramCuts cuts, + bool is_dense, int32_t max_bin_per_feat, + std::shared_ptr source) + : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache), + cuts_{std::move(cuts)}, is_dense_{is_dense}, max_bin_per_feat_{ + max_bin_per_feat} { + this->source_ = source; + this->Fetch(); + } + + void Fetch() final; +}; +} // namespace data +} // namespace xgboost +#endif // XGBOOST_DATA_GRADIENT_INDEX_PAGE_SOURCE_H_ diff --git a/src/data/histogram_cut_format.h b/src/data/histogram_cut_format.h new file mode 100644 index 000000000000..39961c4a2bbe --- /dev/null +++ b/src/data/histogram_cut_format.h @@ -0,0 +1,36 @@ +/*! + * Copyright 2021 XGBoost contributors + */ +#ifndef XGBOOST_DATA_HISTOGRAM_CUT_FORMAT_H_ +#define XGBOOST_DATA_HISTOGRAM_CUT_FORMAT_H_ + +#include "../common/hist_util.h" + +namespace xgboost { +namespace data { +inline bool ReadHistogramCuts(common::HistogramCuts *cuts, dmlc::SeekStream *fi) { + if (!fi->Read(&cuts->cut_values_.HostVector())) { + return false; + } + if (!fi->Read(&cuts->cut_ptrs_.HostVector())) { + return false; + } + if (!fi->Read(&cuts->min_vals_.HostVector())) { + return false; + } + return true; +} + +inline size_t WriteHistogramCuts(common::HistogramCuts const &cuts, dmlc::Stream *fo) { + size_t bytes = 0; + fo->Write(cuts.cut_values_.ConstHostVector()); + bytes += cuts.cut_values_.ConstHostSpan().size_bytes() + sizeof(uint64_t); + fo->Write(cuts.cut_ptrs_.ConstHostVector()); + bytes += cuts.cut_ptrs_.ConstHostSpan().size_bytes() + sizeof(uint64_t); + fo->Write(cuts.min_vals_.ConstHostVector()); + bytes += cuts.min_vals_.ConstHostSpan().size_bytes() + sizeof(uint64_t); + return bytes; +} +} // namespace data +} // namespace xgboost +#endif // XGBOOST_DATA_HISTOGRAM_CUT_FORMAT_H_ diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index a737c6d59071..44a8a3f8fe7c 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -94,10 +94,12 @@ BatchSet SimpleDMatrix::GetGradientIndex(const BatchParam& par if (!(batch_param_ != BatchParam{})) { CHECK(param != BatchParam{}) << "Batch parameter is not initialized."; } - if (!gradient_index_ || (batch_param_ != param && param != BatchParam{})) { + if (!gradient_index_ || (batch_param_ != param && param != BatchParam{}) || param.regen) { CHECK_GE(param.max_bin, 2); - gradient_index_.reset(new GHistIndexMatrix(this, param.max_bin)); + CHECK_EQ(param.gpu_id, -1); + gradient_index_.reset(new GHistIndexMatrix(this, param.max_bin, param.hess)); batch_param_ = param; + CHECK_EQ(batch_param_.hess.data(), param.hess.data()); } auto begin_iter = BatchIterator( new SimpleBatchIteratorImpl(gradient_index_)); diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc index e0502675ebb9..18c81a654ad3 100644 --- a/src/data/sparse_page_dmatrix.cc +++ b/src/data/sparse_page_dmatrix.cc @@ -43,7 +43,8 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p XGDMatrixCallbackNext *next, float missing, int32_t nthreads, std::string cache_prefix) : proxy_{proxy_handle}, iter_{iter_handle}, reset_{reset}, next_{next}, missing_{missing}, - nthreads_{nthreads}, cache_prefix_{std::move(cache_prefix)} { + cache_prefix_{std::move(cache_prefix)} { + ctx_.nthread = nthreads; cache_prefix_ = cache_prefix_.empty() ? "DMatrix" : cache_prefix_; if (rabit::IsDistributed()) { cache_prefix_ += ("-r" + std::to_string(rabit::GetRank())); @@ -112,7 +113,7 @@ void SparsePageDMatrix::InitializeSparsePage() { DMatrixProxy *proxy = MakeProxy(proxy_); sparse_page_source_.reset(); // clear before creating new one to prevent conflicts. sparse_page_source_ = std::make_shared( - iter, proxy, this->missing_, this->nthreads_, this->info_.num_col_, + iter, proxy, this->missing_, this->ctx_.Threads(), this->info_.num_col_, this->n_batches_, cache_info_.at(id)); } @@ -132,7 +133,7 @@ BatchSet SparsePageDMatrix::GetColumnBatches() { this->InitializeSparsePage(); if (!column_source_) { column_source_ = std::make_shared( - this->missing_, this->nthreads_, this->Info().num_col_, + this->missing_, this->ctx_.Threads(), this->Info().num_col_, this->n_batches_, cache_info_.at(id), sparse_page_source_); } else { column_source_->Reset(); @@ -147,7 +148,7 @@ BatchSet SparsePageDMatrix::GetSortedColumnBatches() { this->InitializeSparsePage(); if (!sorted_column_source_) { sorted_column_source_ = std::make_shared( - this->missing_, this->nthreads_, this->Info().num_col_, + this->missing_, this->ctx_.Threads(), this->Info().num_col_, this->n_batches_, cache_info_.at(id), sparse_page_source_); } else { sorted_column_source_->Reset(); @@ -158,16 +159,41 @@ BatchSet SparsePageDMatrix::GetSortedColumnBatches() { BatchSet SparsePageDMatrix::GetGradientIndex(const BatchParam& param) { CHECK_GE(param.max_bin, 2); - // External memory is not support - if (!ghist_index_source_ || (param != batch_param_ && param != BatchParam{})) { - this->InitializeSparsePage(); - ghist_index_source_.reset(new GHistIndexMatrix{this, param.max_bin}); - batch_param_ = param; + if (param.hess.empty()) { + // hist method doesn't support full external memory implementation, so we concatenate + // all index here. + if (!ghist_index_page_ || (param != batch_param_ && param != BatchParam{})) { + this->InitializeSparsePage(); + ghist_index_page_.reset(new GHistIndexMatrix{this, param.max_bin}); + this->InitializeSparsePage(); + batch_param_ = param; + } + auto begin_iter = BatchIterator( + new SimpleBatchIteratorImpl(ghist_index_page_)); + return BatchSet(begin_iter); } + + auto id = MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_); this->InitializeSparsePage(); - auto begin_iter = BatchIterator( - new SimpleBatchIteratorImpl(ghist_index_source_)); - return BatchSet(begin_iter); + if (!cache_info_.at(id)->written || (batch_param_ != param && param != BatchParam{})) { + cache_info_.erase(id); + MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_); + auto cuts = common::SketchOnDMatrix(this, param.max_bin, param.hess); + this->InitializeSparsePage(); // reset after use. + + batch_param_ = param; + ghist_index_source_.reset(); + CHECK_NE(cuts.Values().size(), 0); + ghist_index_source_.reset(new GradientIndexPageSource( + this->missing_, this->ctx_.Threads(), this->Info().num_col_, + this->n_batches_, cache_info_.at(id), param, std::move(cuts), + this->IsDense(), param.max_bin, sparse_page_source_)); + } else { + CHECK(ghist_index_source_); + ghist_index_source_->Reset(); + } + auto begin_iter = BatchIterator(ghist_index_source_); + return BatchSet(BatchIterator(begin_iter)); } #if !defined(XGBOOST_USE_CUDA) diff --git a/src/data/sparse_page_dmatrix.cu b/src/data/sparse_page_dmatrix.cu index 176cdc75b407..0ffc4c45a91b 100644 --- a/src/data/sparse_page_dmatrix.cu +++ b/src/data/sparse_page_dmatrix.cu @@ -31,7 +31,7 @@ BatchSet SparsePageDMatrix::GetEllpackBatches(const BatchParam& par auto ft = this->info_.feature_types.ConstDeviceSpan(); ellpack_page_source_.reset(); // release resources. ellpack_page_source_.reset(new EllpackPageSource( - this->missing_, this->nthreads_, this->Info().num_col_, + this->missing_, this->ctx_.Threads(), this->Info().num_col_, this->n_batches_, cache_info_.at(id), param, std::move(cuts), this->IsDense(), row_stride, ft, sparse_page_source_)); } else { diff --git a/src/data/sparse_page_dmatrix.h b/src/data/sparse_page_dmatrix.h index 3164a3ee3bb0..d1dfa54ccf15 100644 --- a/src/data/sparse_page_dmatrix.h +++ b/src/data/sparse_page_dmatrix.h @@ -8,6 +8,7 @@ #define XGBOOST_DATA_SPARSE_PAGE_DMATRIX_H_ #include +#include #include #include #include @@ -16,6 +17,7 @@ #include #include "ellpack_page_source.h" +#include "gradient_index_page_source.h" #include "sparse_page_source.h" namespace xgboost { @@ -67,7 +69,7 @@ class SparsePageDMatrix : public DMatrix { XGDMatrixCallbackNext *next_; float missing_; - int nthreads_; + GenericParameter ctx_; std::string cache_prefix_; uint32_t n_batches_ {0}; // sparse page is the source to other page types, we make a special member function. @@ -118,7 +120,8 @@ class SparsePageDMatrix : public DMatrix { std::shared_ptr ellpack_page_source_; std::shared_ptr column_source_; std::shared_ptr sorted_column_source_; - std::shared_ptr ghist_index_source_; + std::shared_ptr ghist_index_page_; // hist + std::shared_ptr ghist_index_source_; bool EllpackExists() const override { return static_cast(ellpack_page_source_); @@ -143,6 +146,7 @@ MakeCache(SparsePageDMatrix *ptr, std::string format, std::string prefix, auto it = cache_info.find(id); if (it == cache_info.cend()) { cache_info[id].reset(new Cache{false, name, format}); + LOG(INFO) << "Make cache:" << name << std::endl; } return id; } diff --git a/src/data/sparse_page_writer.h b/src/data/sparse_page_writer.h index eafbaa652de4..91a6504fe86a 100644 --- a/src/data/sparse_page_writer.h +++ b/src/data/sparse_page_writer.h @@ -98,7 +98,12 @@ struct SparsePageFormatReg #define EllpackPageFmt SparsePageFormat #define XGBOOST_REGISTER_ELLPACK_PAGE_FORMAT(Name) \ - DMLC_REGISTRY_REGISTER(SparsePageFormatReg, EllpackPageFm, Name) + DMLC_REGISTRY_REGISTER(SparsePageFormatReg, EllpackPageFmt, Name) + +#define GHistIndexPageFmt SparsePageFormat +#define XGBOOST_REGISTER_GHIST_INDEX_PAGE_FORMAT(Name) \ + DMLC_REGISTRY_REGISTER(SparsePageFormatReg, \ + GHistIndexPageFmt, Name) } // namespace data } // namespace xgboost diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index 843341c3631b..3b20e54a7c35 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -9,6 +9,7 @@ #include #include +#include #include "../../common/compressed_iterator.h" #include "../../common/random.h" @@ -185,10 +186,10 @@ GradientBasedSample UniformSampling::Sample(common::Span gpair, DM ExternalMemoryUniformSampling::ExternalMemoryUniformSampling(EllpackPageImpl const* page, size_t n_rows, - const BatchParam& batch_param, + BatchParam batch_param, float subsample) : original_page_(page), - batch_param_(batch_param), + batch_param_(std::move(batch_param)), subsample_(subsample), sample_row_index_(n_rows) {} @@ -259,10 +260,10 @@ GradientBasedSample GradientBasedSampling::Sample(common::Span gpa ExternalMemoryGradientBasedSampling::ExternalMemoryGradientBasedSampling( EllpackPageImpl const* page, size_t n_rows, - const BatchParam& batch_param, + BatchParam batch_param, float subsample) : original_page_(page), - batch_param_(batch_param), + batch_param_(std::move(batch_param)), subsample_(subsample), threshold_(n_rows + 1, 0.0f), grad_sum_(n_rows, 0.0f), diff --git a/src/tree/gpu_hist/gradient_based_sampler.cuh b/src/tree/gpu_hist/gradient_based_sampler.cuh index bb578995e353..989d994f9e09 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cuh +++ b/src/tree/gpu_hist/gradient_based_sampler.cuh @@ -68,7 +68,7 @@ class ExternalMemoryUniformSampling : public SamplingStrategy { public: ExternalMemoryUniformSampling(EllpackPageImpl const* page, size_t n_rows, - const BatchParam& batch_param, + BatchParam batch_param, float subsample); GradientBasedSample Sample(common::Span gpair, DMatrix* dmat) override; @@ -102,7 +102,7 @@ class ExternalMemoryGradientBasedSampling : public SamplingStrategy { public: ExternalMemoryGradientBasedSampling(EllpackPageImpl const* page, size_t n_rows, - const BatchParam& batch_param, + BatchParam batch_param, float subsample); GradientBasedSample Sample(common::Span gpair, DMatrix* dmat) override; diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 7499293e7860..6cd66d21fcaa 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -209,7 +209,7 @@ struct GPUHistMakerDevice { tree_evaluator(param, n_features, _device_id), column_sampler(column_sampler_seed), interaction_constraints(param, n_features), - batch_param(_batch_param) { + batch_param(std::move(_batch_param)) { sampler.reset(new GradientBasedSampler( page, _n_rows, batch_param, param.subsample, param.sampling_method)); if (!param.monotone_constraints.empty()) { diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 9c946fc5f664..bc894b4646b6 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -69,13 +69,13 @@ void QuantileHistMaker::CallBuilderUpdate(const std::unique_ptr *gpair, DMatrix *dmat, const std::vector &trees) { - auto const &gmat = - *(dmat->GetBatches( - BatchParam{GenericParameter::kCpuId, param_.max_bin}) - .begin()); + auto it = dmat->GetBatches( + BatchParam{GenericParameter::kCpuId, param_.max_bin}) + .begin(); + auto p_gmat = it.Page(); if (dmat != p_last_dmat_ || is_gmat_initialized_ == false) { updater_monitor_.Start("GmatInitialization"); - column_matrix_.Init(gmat, param_.sparse_threshold); + column_matrix_.Init(*p_gmat, param_.sparse_threshold); updater_monitor_.Stop("GmatInitialization"); // A proper solution is puting cut matrix in DMatrix, see: // https://github.com/dmlc/xgboost/issues/5143 @@ -91,12 +91,12 @@ void QuantileHistMaker::Update(HostDeviceVector *gpair, if (!float_builder_) { this->SetBuilder(n_trees, &float_builder_, dmat); } - CallBuilderUpdate(float_builder_, gpair, dmat, gmat, trees); + CallBuilderUpdate(float_builder_, gpair, dmat, *p_gmat, trees); } else { if (!double_builder_) { SetBuilder(n_trees, &double_builder_, dmat); } - CallBuilderUpdate(double_builder_, gpair, dmat, gmat, trees); + CallBuilderUpdate(double_builder_, gpair, dmat, *p_gmat, trees); } param_.learning_rate = lr; diff --git a/tests/cpp/data/test_gradient_index.cc b/tests/cpp/data/test_gradient_index.cc new file mode 100644 index 000000000000..4bdf34ab2f66 --- /dev/null +++ b/tests/cpp/data/test_gradient_index.cc @@ -0,0 +1,26 @@ +/*! + * Copyright 2021 XGBoost contributors + */ +#include +#include + +#include "../helpers.h" +#include "../../../src/data/gradient_index.h" + +namespace xgboost { +namespace data { +TEST(GradientIndex, ExternalMemory) { + std::unique_ptr dmat = CreateSparsePageDMatrix(10000); + std::vector base_rowids; + std::vector hessian(dmat->Info().num_row_, 1); + for (auto const& page : dmat->GetBatches({0, 64, hessian})) { + base_rowids.push_back(page.base_rowid); + } + size_t i = 0; + for (auto const& page : dmat->GetBatches()) { + ASSERT_EQ(base_rowids[i], page.base_rowid); + ++i; + } +} +} // namespace data +} // namespace xgboost diff --git a/tests/cpp/data/test_gradient_index_page_raw_format.cc b/tests/cpp/data/test_gradient_index_page_raw_format.cc new file mode 100644 index 000000000000..b24ee8770b8d --- /dev/null +++ b/tests/cpp/data/test_gradient_index_page_raw_format.cc @@ -0,0 +1,48 @@ +/*! + * Copyright 2021 XGBoost contributors + */ +#include + +#include "../../../src/data/gradient_index.h" +#include "../../../src/data/sparse_page_source.h" +#include "../helpers.h" + +namespace xgboost { +namespace data { +TEST(GHistIndexPageRawFormat, IO) { + std::unique_ptr> format{ + CreatePageFormat("raw")}; + auto m = RandomDataGenerator{100, 14, 0.5}.GenerateDMatrix(); + dmlc::TemporaryDirectory tmpdir; + std::string path = tmpdir.path + "/ghistindex.page"; + + { + std::unique_ptr fo{dmlc::Stream::Create(path.c_str(), "w")}; + for (auto const &index : + m->GetBatches({GenericParameter::kCpuId, 256})) { + format->Write(index, fo.get()); + } + } + + GHistIndexMatrix page; + std::unique_ptr fi{ + dmlc::SeekStream::CreateForRead(path.c_str())}; + format->Read(&page, fi.get()); + + for (auto const &gidx : + m->GetBatches({GenericParameter::kCpuId, 256})) { + auto const &loaded = gidx; + ASSERT_EQ(loaded.cut.Ptrs(), page.cut.Ptrs()); + ASSERT_EQ(loaded.cut.MinValues(), page.cut.MinValues()); + ASSERT_EQ(loaded.cut.Values(), page.cut.Values()); + ASSERT_EQ(loaded.base_rowid, page.base_rowid); + ASSERT_EQ(loaded.IsDense(), page.IsDense()); + ASSERT_TRUE(std::equal(loaded.index.begin(), loaded.index.end(), + page.index.begin())); + ASSERT_TRUE(std::equal(loaded.index.Offset(), + loaded.index.Offset() + loaded.index.OffsetSize(), + page.index.Offset())); + } +} +} // namespace data +} // namespace xgboost diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index 5f791eeef9a2..94516be04b9a 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -370,7 +370,7 @@ std::unique_ptr CreateSparsePageDMatrix(size_t n_entries, std::unique_ptr dmat{DMatrix::Create( static_cast(&iter), iter.Proxy(), Reset, Next, - std::numeric_limits::quiet_NaN(), 1, prefix)}; + std::numeric_limits::quiet_NaN(), omp_get_max_threads(), prefix)}; auto row_page_path = data::MakeId(prefix, dynamic_cast(dmat.get())) +