diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 99015b90820a..5fbeab6d72ee 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -26,6 +26,8 @@ namespace xgboost { // forward declare learner. class LearnerImpl; +// forward declare dmatrix. +class DMatrix; /*! \brief data type accepted by xgboost interface */ enum DataType { @@ -86,7 +88,7 @@ class MetaInfo { * \return The pre-defined root index of i-th instance. */ inline unsigned GetRoot(size_t i) const { - return root_index_.size() != 0 ? root_index_[i] : 0U; + return !root_index_.empty() ? root_index_[i] : 0U; } /*! \brief get sorted indexes (argsort) of labels by absolute value (used by cox loss) */ inline const std::vector& LabelAbsSort() const { @@ -166,7 +168,7 @@ class SparsePage { /*! \brief the data of the segments */ HostDeviceVector data; - size_t base_rowid; + size_t base_rowid{}; /*! \brief an instance of sparse vector in the batch */ using Inst = common::Span; @@ -215,23 +217,23 @@ class SparsePage { const int nthread = omp_get_max_threads(); builder.InitBudget(num_columns, nthread); long batch_size = static_cast(this->Size()); // NOLINT(*) -#pragma omp parallel for schedule(static) +#pragma omp parallel for default(none) shared(batch_size, builder) schedule(static) for (long i = 0; i < batch_size; ++i) { // NOLINT(*) int tid = omp_get_thread_num(); auto inst = (*this)[i]; - for (bst_uint j = 0; j < inst.size(); ++j) { - builder.AddBudget(inst[j].index, tid); + for (const auto& entry : inst) { + builder.AddBudget(entry.index, tid); } } builder.InitStorage(); -#pragma omp parallel for schedule(static) +#pragma omp parallel for default(none) shared(batch_size, builder) schedule(static) for (long i = 0; i < batch_size; ++i) { // NOLINT(*) int tid = omp_get_thread_num(); auto inst = (*this)[i]; - for (bst_uint j = 0; j < inst.size(); ++j) { + for (const auto& entry : inst) { builder.Push( - inst[j].index, - Entry(static_cast(this->base_rowid + i), inst[j].fvalue), + entry.index, + Entry(static_cast(this->base_rowid + i), entry.fvalue), tid); } } @@ -240,7 +242,7 @@ class SparsePage { void SortRows() { auto ncol = static_cast(this->Size()); -#pragma omp parallel for schedule(dynamic, 1) +#pragma omp parallel for default(none) shared(ncol) schedule(dynamic, 1) for (bst_omp_uint i = 0; i < ncol; ++i) { if (this->offset.HostVector()[i] < this->offset.HostVector()[i + 1]) { std::sort( @@ -287,10 +289,30 @@ class SortedCSCPage : public SparsePage { explicit SortedCSCPage(SparsePage page) : SparsePage(std::move(page)) {} }; +class EllpackPageImpl; +class EllpackPage { + public: + explicit EllpackPage(DMatrix* dmat); + ~EllpackPage(); + EllpackPage() = delete; + // no copy + EllpackPage(const EllpackPage&) = delete; + EllpackPage& operator=(const EllpackPage&) = delete; + // movable + EllpackPage(EllpackPage&&) = default; + EllpackPage& operator=(EllpackPage&&) = default; + + const EllpackPageImpl* Impl() const { return impl_.get(); } + EllpackPageImpl* Impl() { return impl_.get(); } + + private: + std::unique_ptr impl_; +}; + template class BatchIteratorImpl { public: - virtual ~BatchIteratorImpl() {} + virtual ~BatchIteratorImpl() = default; virtual T& operator*() = 0; virtual const T& operator*() const = 0; virtual void operator++() = 0; @@ -412,7 +434,7 @@ class DMatrix { bool silent, bool load_row_split, const std::string& file_format = "auto", - const size_t page_size = kPageSize); + size_t page_size = kPageSize); /*! * \brief create a new DMatrix, by wrapping a row_iterator, and meta info. 
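The EllpackPage class declared above follows the pimpl idiom so that data.h stays free of CUDA types: the public class forward-declares EllpackPageImpl, owns it through std::unique_ptr, and is movable but not copyable. A minimal standalone sketch of the same pattern, with illustrative names (Widget/WidgetImpl are not part of this patch):

    // widget.h -- no implementation details leak into the header
    #include <memory>

    class WidgetImpl;                    // forward declaration only

    class Widget {
     public:
      Widget();
      ~Widget();                         // defined out of line, where WidgetImpl is complete
      Widget(const Widget&) = delete;    // non-copyable but movable, like EllpackPage
      Widget& operator=(const Widget&) = delete;
      Widget(Widget&&) = default;
      Widget& operator=(Widget&&) = default;
      WidgetImpl* Impl() { return impl_.get(); }

     private:
      std::unique_ptr<WidgetImpl> impl_;
    };

    // widget.cc -- the only translation unit that needs the full definition
    class WidgetImpl { /* device-specific state would live here */ };
    Widget::Widget() : impl_{new WidgetImpl} {}
    Widget::~Widget() = default;         // unique_ptr<WidgetImpl> is destroyed here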
@@ -438,7 +460,7 @@ class DMatrix { */ static DMatrix* Create(dmlc::Parser* parser, const std::string& cache_prefix = "", - const size_t page_size = kPageSize); + size_t page_size = kPageSize); /*! \brief page size 32 MB */ static const size_t kPageSize = 32UL << 20UL; @@ -447,6 +469,7 @@ class DMatrix { virtual BatchSet GetRowBatches() = 0; virtual BatchSet GetColumnBatches() = 0; virtual BatchSet GetSortedColumnBatches() = 0; + virtual BatchSet GetEllpackBatches() = 0; }; template<> @@ -463,6 +486,11 @@ template<> inline BatchSet DMatrix::GetBatches() { return GetSortedColumnBatches(); } + +template<> +inline BatchSet DMatrix::GetBatches() { + return GetEllpackBatches(); +} } // namespace xgboost namespace dmlc { diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index d1ef37df1094..b207d8b31f42 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -99,15 +99,15 @@ struct SketchContainer { std::vector col_locks_; // NOLINT static constexpr int kOmpNumColsParallelizeLimit = 1000; - SketchContainer(const tree::TrainParam ¶m, DMatrix *dmat) : + SketchContainer(int max_bin, DMatrix *dmat) : col_locks_(dmat->Info().num_col_) { const MetaInfo &info = dmat->Info(); // Initialize Sketches for this dmatrix sketches_.resize(info.num_col_); -#pragma omp parallel for default(none) shared(info, param) schedule(static) \ +#pragma omp parallel for default(none) shared(info, max_bin) schedule(static) \ if (info.num_col_ > kOmpNumColsParallelizeLimit) // NOLINT for (int icol = 0; icol < info.num_col_; ++icol) { // NOLINT - sketches_[icol].Init(info.num_row_, 1.0 / (8 * param.max_bin)); + sketches_[icol].Init(info.num_row_, 1.0 / (8 * max_bin)); } } @@ -130,7 +130,7 @@ struct GPUSketcher { bool has_weights_{false}; size_t row_stride_{0}; - tree::TrainParam param_; + const int max_bin_; SketchContainer *sketch_container_; dh::device_vector row_ptrs_{}; dh::device_vector entries_{}; @@ -148,11 +148,11 @@ struct GPUSketcher { public: DeviceShard(int device, bst_uint n_rows, - tree::TrainParam param, + int max_bin, SketchContainer* sketch_container) : device_(device), n_rows_(n_rows), - param_(std::move(param)), + max_bin_(max_bin), sketch_container_(sketch_container) { } @@ -183,7 +183,7 @@ struct GPUSketcher { } constexpr int kFactor = 8; - double eps = 1.0 / (kFactor * param_.max_bin); + double eps = 1.0 / (kFactor * max_bin_); size_t dummy_nlevel; WXQSketch::LimitSizeLevel(gpu_batch_nrows_, eps, &dummy_nlevel, &n_cuts_); @@ -362,7 +362,7 @@ struct GPUSketcher { // add cuts into sketches thrust::copy(cuts_d_.begin(), cuts_d_.end(), cuts_h_.begin()); #pragma omp parallel for default(none) schedule(static) \ - if (num_cols_ > SketchContainer::kOmpNumColsParallelizeLimit) // NOLINT +if (num_cols_ > SketchContainer::kOmpNumColsParallelizeLimit) // NOLINT for (int icol = 0; icol < num_cols_; ++icol) { WXQSketch::SummaryContainer summary; summary.Reserve(n_cuts_); @@ -403,10 +403,8 @@ struct GPUSketcher { }; void SketchBatch(const SparsePage &batch, const MetaInfo &info) { - auto device = generic_param_.gpu_id; - // create device shard - shard_.reset(new DeviceShard(device, batch.Size(), param_, sketch_container_.get())); + shard_.reset(new DeviceShard(device_, batch.Size(), max_bin_, sketch_container_.get())); // compute sketches for the shard shard_->Init(batch, info, gpu_batch_nrows_); @@ -417,9 +415,8 @@ struct GPUSketcher { row_stride_ = shard_->GetRowStride(); } - GPUSketcher(const tree::TrainParam ¶m, const GenericParameter &generic_param, int gpu_nrows) - : param_(param), 
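With the GetBatches<EllpackPage>() specialization shown here, callers request the ELLPACK representation through the same templated interface already used for SparsePage, CSCPage, and SortedCSCPage. A hedged usage sketch (the function name and loop body are illustrative, not part of the patch):

    #include <xgboost/data.h>

    void TouchEllpackBatches(xgboost::DMatrix* p_fmat) {
      // The template specialization routes this call to GetEllpackBatches().
      for (const auto& page : p_fmat->GetBatches<xgboost::EllpackPage>()) {
        // The public page only exposes its implementation handle; the device
        // data sits behind EllpackPageImpl in the .cu translation unit.
        const auto* impl = page.Impl();
        (void)impl;
      }
    }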
generic_param_(generic_param), gpu_batch_nrows_(gpu_nrows), row_stride_(0) { - } + GPUSketcher(int device, int max_bin, int gpu_nrows) + : device_(device), max_bin_(max_bin), gpu_batch_nrows_(gpu_nrows), row_stride_(0) {} /* Builds the sketches on the GPU for the dmatrix and returns the row stride * for the entire dataset */ @@ -427,29 +424,31 @@ struct GPUSketcher { const MetaInfo &info = dmat->Info(); row_stride_ = 0; - sketch_container_.reset(new SketchContainer(param_, dmat)); + sketch_container_.reset(new SketchContainer(max_bin_, dmat)); for (const auto &batch : dmat->GetBatches()) { this->SketchBatch(batch, info); } - hmat->Init(&sketch_container_->sketches_, param_.max_bin); + hmat->Init(&sketch_container_->sketches_, max_bin_); return row_stride_; } private: std::unique_ptr shard_; - const tree::TrainParam ¶m_; - const GenericParameter &generic_param_; + const int device_; + const int max_bin_; int gpu_batch_nrows_; size_t row_stride_; std::unique_ptr sketch_container_; }; -size_t DeviceSketch - (const tree::TrainParam ¶m, const GenericParameter &learner_param, int gpu_batch_nrows, - DMatrix *dmat, HistogramCuts *hmat) { - GPUSketcher sketcher(param, learner_param, gpu_batch_nrows); +size_t DeviceSketch(int device, + int max_bin, + int gpu_batch_nrows, + DMatrix* dmat, + HistogramCuts* hmat) { + GPUSketcher sketcher(device, max_bin, gpu_batch_nrows); // We only need to return the result in HistogramCuts container, so it is safe to // use a pointer of local HistogramCutsDense DenseCuts dense_cuts(hmat); diff --git a/src/common/hist_util.h b/src/common/hist_util.h index df765301f5be..1ae9389d4868 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -290,10 +290,11 @@ class DenseCuts : public CutsBuilder { * * \return The row stride across the entire dataset. */ -size_t DeviceSketch - (const tree::TrainParam& param, const GenericParameter &learner_param, int gpu_batch_nrows, - DMatrix* dmat, HistogramCuts* hmat); - +size_t DeviceSketch(int device, + int max_bin, + int gpu_batch_nrows, + DMatrix* dmat, + HistogramCuts* hmat); /*! * \brief preprocessed global index matrix, in CSR format diff --git a/src/data/ellpack_page.cc b/src/data/ellpack_page.cc new file mode 100644 index 000000000000..f84fc2bece77 --- /dev/null +++ b/src/data/ellpack_page.cc @@ -0,0 +1,25 @@ +/*! + * Copyright 2019 XGBoost contributors + * + * \file ellpack_page.cc + */ +#ifndef XGBOOST_USE_CUDA + +#include + +// dummy implementation of ELlpackPage in case CUDA is not used +namespace xgboost { + +class EllpackPageImpl {}; + +EllpackPage::EllpackPage(DMatrix* dmat) { + LOG(FATAL) << "Not implemented."; +} + +EllpackPage::~EllpackPage() { + LOG(FATAL) << "Not implemented."; +} + +} // namespace xgboost + +#endif // XGBOOST_USE_CUDA diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu new file mode 100644 index 000000000000..7ebd6c1f3343 --- /dev/null +++ b/src/data/ellpack_page.cu @@ -0,0 +1,210 @@ +/*! 
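DeviceSketch and GPUSketcher now take the device ordinal and max_bin directly instead of TrainParam and GenericParameter, so sketching no longer depends on the tree updater's parameter types. The summary accuracy is still derived from max_bin through eps = 1 / (kFactor * max_bin); a small host-only illustration of that relation (the bin counts are example values):

    #include <cstdio>

    // Mirrors the eps computation in SketchContainer and DeviceShard::Init:
    // a finer histogram (larger max_bin) demands a tighter quantile summary.
    int main() {
      constexpr int kFactor = 8;
      for (int max_bin : {16, 64, 256}) {
        double eps = 1.0 / (kFactor * max_bin);
        std::printf("max_bin=%3d -> eps=%g\n", max_bin, eps);
      }
      return 0;
    }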
+ * Copyright 2019 XGBoost contributors + * + * \file ellpack_page.cu + */ + +#include + +#include "./ellpack_page.cuh" +#include "../common/hist_util.h" +#include "../common/random.h" + +namespace xgboost { + +EllpackPage::EllpackPage(DMatrix* dmat) : impl_{new EllpackPageImpl(dmat)} {} + +EllpackPage::~EllpackPage() = default; + +EllpackPageImpl::EllpackPageImpl(DMatrix* dmat) : dmat_{dmat} {} + +template +void EllpackPageImpl::Init(int device, const tree::TrainParam& param, int gpu_batch_nrows) { + if (initialised_) return; + + device_ = device; + monitor_.Init("ellpack_page"); + + monitor_.StartCuda("Quantiles"); + // Create the quantile sketches for the dmatrix and initialize HistogramCuts. + size_t row_stride = common::DeviceSketch(device, param.max_bin, gpu_batch_nrows, dmat_, &hmat_); + monitor_.StopCuda("Quantiles"); + + const auto& info = dmat_->Info(); + auto is_dense = info.num_nonzero_ == info.num_row_ * info.num_col_; + + // Init global data for each shard + monitor_.StartCuda("InitCompressedData"); + dh::safe_cuda(cudaSetDevice(device)); + InitCompressedData(hmat_, param, row_stride, is_dense); + monitor_.StopCuda("InitCompressedData"); + + monitor_.StartCuda("BinningCompression"); + DeviceHistogramBuilderState hist_builder_row_state(info.num_row_); + for (const auto& batch : dmat_->GetBatches()) { + hist_builder_row_state.BeginBatch(batch); + + dh::safe_cuda(cudaSetDevice(device_)); + CreateHistIndices(batch, + hmat_, + hist_builder_row_state.GetRowStateOnDevice(), + gpu_batch_nrows); + + hist_builder_row_state.EndBatch(); + } + monitor_.StopCuda("BinningCompression"); + + initialised_ = true; +} + +template +void EllpackPageImpl::InitCompressedData(const common::HistogramCuts& hmat, + const tree::TrainParam& param, + size_t row_stride, + bool is_dense) { + n_bins = hmat.Ptrs().back(); + int null_gidx_value = hmat.Ptrs().back(); + + int num_symbols = n_bins + 1; + // Required buffer size for storing data matrix in ELLPack format. + size_t compressed_size_bytes = common::CompressedBufferWriter::CalculateBufferSize( + row_stride * dmat_->Info().num_row_, num_symbols); + + ba.Allocate(device_, + &feature_segments_, hmat.Ptrs().size(), + &gidx_fvalue_map_, hmat.Values().size(), + &min_fvalue_, hmat.MinValues().size(), + &gidx_buffer_, compressed_size_bytes); + + dh::CopyVectorToDeviceSpan(gidx_fvalue_map_, hmat.Values()); + dh::CopyVectorToDeviceSpan(min_fvalue_, hmat.MinValues()); + dh::CopyVectorToDeviceSpan(feature_segments_, hmat.Ptrs()); + thrust::fill( + thrust::device_pointer_cast(gidx_buffer_.data()), + thrust::device_pointer_cast(gidx_buffer_.data() + gidx_buffer_.size()), 0); + + ellpack_matrix.Init( + feature_segments_, min_fvalue_, + gidx_fvalue_map_, row_stride, + common::CompressedIterator(gidx_buffer_.data(), num_symbols), + is_dense, null_gidx_value); + // check if we can use shared memory for building histograms + // (assuming atleast we need 2 CTAs per SM to maintain decent latency + // hiding) + auto histogram_size = sizeof(GradientSumT) * hmat.Ptrs().back(); + auto max_smem = dh::MaxSharedMemory(device_); + if (histogram_size <= max_smem) { + use_shared_memory_histograms = true; + } +} + +// Bin each input data entry, store the bin indices in compressed form. 
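InitCompressedData above sizes the ELLPACK buffer for row_stride * num_row_ symbols drawn from an alphabet of n_bins + 1 values, the extra symbol being the null (missing) gidx. A hedged host-side sketch of the underlying bit-packing arithmetic; SymbolBits is an illustrative helper and the real CompressedBufferWriter::CalculateBufferSize may add alignment or header overhead:

    #include <cstddef>
    #include <cstdio>

    // Each packed entry stores one of num_symbols values, so it needs
    // ceil(log2(num_symbols)) bits.
    static std::size_t SymbolBits(std::size_t num_symbols) {
      std::size_t bits = 1;
      while ((std::size_t{1} << bits) < num_symbols) ++bits;
      return bits;
    }

    int main() {
      std::size_t n_rows = 1000, row_stride = 16, n_bins = 256;
      std::size_t num_symbols = n_bins + 1;             // +1 for the null gidx
      std::size_t total_symbols = n_rows * row_stride;  // padded ELLPACK layout
      std::size_t bits = SymbolBits(num_symbols);
      std::size_t bytes = (total_symbols * bits + 7) / 8;
      std::printf("%zu symbols x %zu bits ~= %zu bytes before any padding\n",
                  total_symbols, bits, bytes);
      return 0;
    }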
+template::type = 0> +__global__ void CompressBinEllpackKernel( + common::CompressedBufferWriter wr, + common::CompressedByteT* __restrict__ buffer, // gidx_buffer + const size_t* __restrict__ row_ptrs, // row offset of input data + const Entry* __restrict__ entries, // One batch of input data + const float* __restrict__ cuts, // HistogramCuts::cut + const uint32_t* __restrict__ cut_rows, // HistogramCuts::row_ptrs + size_t base_row, // batch_row_begin + size_t n_rows, + size_t row_stride, + unsigned int null_gidx_value) { + size_t irow = threadIdx.x + blockIdx.x * blockDim.x; + int ifeature = threadIdx.y + blockIdx.y * blockDim.y; + if (irow >= n_rows || ifeature >= row_stride) { + return; + } + int row_length = static_cast(row_ptrs[irow + 1] - row_ptrs[irow]); + unsigned int bin = null_gidx_value; + if (ifeature < row_length) { + Entry entry = entries[row_ptrs[irow] - row_ptrs[0] + ifeature]; + int feature = entry.index; + float fvalue = entry.fvalue; + // {feature_cuts, ncuts} forms the array of cuts of `feature'. + const float *feature_cuts = &cuts[cut_rows[feature]]; + int ncuts = cut_rows[feature + 1] - cut_rows[feature]; + // Assigning the bin in current entry. + // S.t.: fvalue < feature_cuts[bin] + bin = dh::UpperBound(feature_cuts, ncuts, fvalue); + if (bin >= ncuts) { + bin = ncuts - 1; + } + // Add the number of bins in previous features. + bin += cut_rows[feature]; + } + // Write to gidx buffer. + wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + ifeature); +} + +template +void EllpackPageImpl::CreateHistIndices( + const SparsePage &row_batch, + const common::HistogramCuts &hmat, + const RowStateOnDevice &device_row_state, + int rows_per_batch) { + // Has any been allocated for me in this batch? + if (!device_row_state.rows_to_process_from_batch) return; + + unsigned int null_gidx_value = hmat.Ptrs().back(); + size_t row_stride = this->ellpack_matrix.row_stride; + + const auto &offset_vec = row_batch.offset.ConstHostVector(); + + int num_symbols = n_bins + 1; + // bin and compress entries in batches of rows + size_t gpu_batch_nrows = std::min( + dh::TotalMemory(device_) / (16 * row_stride * sizeof(Entry)), + static_cast(device_row_state.rows_to_process_from_batch)); + const std::vector& data_vec = row_batch.data.ConstHostVector(); + + size_t gpu_nbatches = common::DivRoundUp(device_row_state.rows_to_process_from_batch, + gpu_batch_nrows); + + for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) { + size_t batch_row_begin = gpu_batch * gpu_batch_nrows; + size_t batch_row_end = (gpu_batch + 1) * gpu_batch_nrows; + if (batch_row_end > device_row_state.rows_to_process_from_batch) { + batch_row_end = device_row_state.rows_to_process_from_batch; + } + size_t batch_nrows = batch_row_end - batch_row_begin; + + const auto ent_cnt_begin = + offset_vec[device_row_state.row_offset_in_current_batch + batch_row_begin]; + const auto ent_cnt_end = + offset_vec[device_row_state.row_offset_in_current_batch + batch_row_end]; + + /*! \brief row offset in SparsePage (the input data). */ + dh::device_vector row_ptrs(batch_nrows+1); + thrust::copy( + offset_vec.data() + device_row_state.row_offset_in_current_batch + batch_row_begin, + offset_vec.data() + device_row_state.row_offset_in_current_batch + batch_row_end + 1, + row_ptrs.begin()); + + // number of entries in this batch. + size_t n_entries = ent_cnt_end - ent_cnt_begin; + dh::device_vector entries_d(n_entries); + // copy data entries to device. 
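CompressBinEllpackKernel assigns each feature value the index of the first cut greater than it (an upper bound), clamps to the feature's last cut, and then offsets by the feature's starting position in the global cut array to obtain the global gidx. A host-side sketch of the same rule using std::upper_bound (illustrative only; the device code uses dh::UpperBound over the raw cut pointers):

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // cuts: concatenated cut values; cut_rows: per-feature offsets into cuts.
    int BinForValue(const std::vector<float>& cuts,
                    const std::vector<unsigned>& cut_rows,
                    int feature, float fvalue) {
      const float* begin = cuts.data() + cut_rows[feature];
      int ncuts = static_cast<int>(cut_rows[feature + 1] - cut_rows[feature]);
      int bin = static_cast<int>(std::upper_bound(begin, begin + ncuts, fvalue) - begin);
      if (bin >= ncuts) bin = ncuts - 1;                   // clamp into the last bin
      return bin + static_cast<int>(cut_rows[feature]);    // global gidx
    }

    int main() {
      std::vector<float> cuts{0.5f, 1.5f, 2.5f,   // feature 0
                              10.f, 20.f};        // feature 1
      std::vector<unsigned> cut_rows{0, 3, 5};
      std::printf("%d %d\n",
                  BinForValue(cuts, cut_rows, 0, 1.0f),    // -> 1
                  BinForValue(cuts, cut_rows, 1, 25.f));   // -> 4 (clamped)
      return 0;
    }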
+ dh::safe_cuda + (cudaMemcpy + (entries_d.data().get(), data_vec.data() + ent_cnt_begin, + n_entries * sizeof(Entry), cudaMemcpyDefault)); + const dim3 block3(32, 8, 1); // 256 threads + const dim3 grid3(common::DivRoundUp(batch_nrows, block3.x), + common::DivRoundUp(row_stride, block3.y), 1); + CompressBinEllpackKernel<<>> + (common::CompressedBufferWriter(num_symbols), + gidx_buffer_.data(), + row_ptrs.data().get(), + entries_d.data().get(), + gidx_fvalue_map_.data(), + feature_segments_.data(), + device_row_state.total_rows_processed + batch_row_begin, + batch_nrows, + row_stride, + null_gidx_value); + } +} + +} // namespace xgboost diff --git a/src/data/ellpack_page.cuh b/src/data/ellpack_page.cuh new file mode 100644 index 000000000000..7b932eeabf01 --- /dev/null +++ b/src/data/ellpack_page.cuh @@ -0,0 +1,309 @@ +/*! + * Copyright 2019 by XGBoost Contributors + * + * \file ellpack_page.cuh + */ + +#ifndef XGBOOST_DATA_ELLPACK_PAGE_H_ +#define XGBOOST_DATA_ELLPACK_PAGE_H_ + +#include + +#include "../common/compressed_iterator.h" +#include "../common/device_helpers.cuh" +#include "../common/hist_util.h" + +namespace xgboost { + +// Find a gidx value for a given feature otherwise return -1 if not found +__forceinline__ __device__ int BinarySearchRow( + bst_uint begin, bst_uint end, + common::CompressedIterator data, + int const fidx_begin, int const fidx_end) { + bst_uint previous_middle = UINT32_MAX; + while (end != begin) { + auto middle = begin + (end - begin) / 2; + if (middle == previous_middle) { + break; + } + previous_middle = middle; + + auto gidx = data[middle]; + + if (gidx >= fidx_begin && gidx < fidx_end) { + return gidx; + } else if (gidx < fidx_begin) { + begin = middle; + } else { + end = middle; + } + } + // Value is missing + return -1; +} + +/** \brief Struct for accessing and manipulating an ellpack matrix on the + * device. Does not own underlying memory and may be trivially copied into + * kernels.*/ +struct ELLPackMatrix { + common::Span feature_segments; + /*! \brief minimum value for each feature. */ + common::Span min_fvalue; + /*! \brief Cut. */ + common::Span gidx_fvalue_map; + /*! \brief row length for ELLPack. */ + size_t row_stride{0}; + common::CompressedIterator gidx_iter; + bool is_dense; + int null_gidx_value; + + XGBOOST_DEVICE size_t BinCount() const { return gidx_fvalue_map.size(); } + + // Get a matrix element, uses binary search for look up Return NaN if missing + // Given a row index and a feature index, returns the corresponding cut value + __device__ bst_float GetElement(size_t ridx, size_t fidx) const { + auto row_begin = row_stride * ridx; + auto row_end = row_begin + row_stride; + auto gidx = -1; + if (is_dense) { + gidx = gidx_iter[row_begin + fidx]; + } else { + gidx = + BinarySearchRow(row_begin, row_end, gidx_iter, feature_segments[fidx], + feature_segments[fidx + 1]); + } + if (gidx == -1) { + return nan(""); + } + return gidx_fvalue_map[gidx]; + } + void Init(common::Span feature_segments, + common::Span min_fvalue, + common::Span gidx_fvalue_map, size_t row_stride, + common::CompressedIterator gidx_iter, bool is_dense, + int null_gidx_value) { + this->feature_segments = feature_segments; + this->min_fvalue = min_fvalue; + this->gidx_fvalue_map = gidx_fvalue_map; + this->row_stride = row_stride; + this->gidx_iter = gidx_iter; + this->is_dense = is_dense; + this->null_gidx_value = null_gidx_value; + } +}; + +/** + * \struct DeviceHistogram + * + * \summary Data storage for node histograms on device. Automatically expands. 
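ELLPackMatrix pads every row to the same row_stride, so a dense matrix can index an element directly at row_stride * ridx + fidx, while a sparse row needs the BinarySearchRow lookup above to find the gidx that falls inside the feature's cut segment. A small host-side sketch of the dense path (plain arrays stand in for the compressed iterator; the null handling is simplified):

    #include <cmath>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Dense ELLPACK lookup: one gidx per (row, feature) slot, decoded through
    // the global cut-value table; the null symbol marks a missing value.
    float GetElementDense(const std::vector<int>& gidx_buffer,
                          const std::vector<float>& gidx_fvalue_map,
                          std::size_t row_stride, int null_gidx_value,
                          std::size_t ridx, std::size_t fidx) {
      int gidx = gidx_buffer[ridx * row_stride + fidx];
      if (gidx == null_gidx_value) return std::nanf("");
      return gidx_fvalue_map[gidx];
    }

    int main() {
      std::vector<float> cut_values{0.5f, 1.5f, 10.f, 20.f};  // 4 bins in total
      int null_gidx = 4;                                      // n_bins acts as the null symbol
      std::vector<int> gidx{0, 3,                             // row 0
                            1, null_gidx};                    // row 1, feature 1 missing
      std::printf("%f\n", GetElementDense(gidx, cut_values, 2, null_gidx, 1, 0));  // 1.5
      return 0;
    }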
+ * + * \tparam GradientSumT histogram entry type. + * \tparam kStopGrowingSize Do not grow beyond this size + * + * \author Rory + * \date 28/07/2018 + */ +template +class DeviceHistogram { + private: + /*! \brief Map nidx to starting index of its histogram. */ + std::map nidx_map_; + dh::device_vector data_; + int n_bins_; + int device_id_; + static constexpr size_t kNumItemsInGradientSum = + sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT); + static_assert(kNumItemsInGradientSum == 2, + "Number of items in gradient type should be 2."); + + public: + void Init(int device_id, int n_bins) { + this->n_bins_ = n_bins; + this->device_id_ = device_id; + } + + void Reset() { + dh::safe_cuda(cudaMemsetAsync( + data_.data().get(), 0, + data_.size() * sizeof(typename decltype(data_)::value_type))); + nidx_map_.clear(); + } + bool HistogramExists(int nidx) const { + return nidx_map_.find(nidx) != nidx_map_.cend(); + } + size_t HistogramSize() const { + return n_bins_ * kNumItemsInGradientSum; + } + + dh::device_vector& Data() { + return data_; + } + + void AllocateHistogram(int nidx) { + if (HistogramExists(nidx)) return; + // Number of items currently used in data + const size_t used_size = nidx_map_.size() * HistogramSize(); + const size_t new_used_size = used_size + HistogramSize(); + dh::safe_cuda(cudaSetDevice(device_id_)); + if (data_.size() >= kStopGrowingSize) { + // Recycle histogram memory + if (new_used_size <= data_.size()) { + // no need to remove old node, just insert the new one. + nidx_map_[nidx] = used_size; + // memset histogram size in bytes + dh::safe_cuda(cudaMemsetAsync(data_.data().get() + used_size, 0, + n_bins_ * sizeof(GradientSumT))); + } else { + std::pair old_entry = *nidx_map_.begin(); + nidx_map_.erase(old_entry.first); + dh::safe_cuda(cudaMemsetAsync(data_.data().get() + old_entry.second, 0, + n_bins_ * sizeof(GradientSumT))); + nidx_map_[nidx] = old_entry.second; + } + } else { + // Append new node histogram + nidx_map_[nidx] = used_size; + size_t new_required_memory = std::max(data_.size() * 2, HistogramSize()); + if (data_.size() < new_required_memory) { + data_.resize(new_required_memory); + } + } + } + + /** + * \summary Return pointer to histogram memory for a given node. + * \param nidx Tree node index. + * \return hist pointer. + */ + common::Span GetNodeHistogram(int nidx) { + CHECK(this->HistogramExists(nidx)); + auto ptr = data_.data().get() + nidx_map_[nidx]; + return common::Span( + reinterpret_cast(ptr), n_bins_); + } +}; + +// Instances of this type are created while creating the histogram bins for the +// entire dataset across multiple sparse page batches. This keeps track of the number +// of rows to process from a batch and the position from which to process on each device. 
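DeviceHistogram grows its backing buffer (doubling) until it reaches kStopGrowingSize and then starts recycling: a new node either takes the next unused slot or evicts the oldest mapped node and zero-fills its range. A hedged host-only sketch of that allocation policy (std::vector stands in for device memory and std::fill_n for cudaMemsetAsync; HistogramPool is an illustrative name):

    #include <algorithm>
    #include <cstddef>
    #include <map>
    #include <vector>

    class HistogramPool {
     public:
      HistogramPool(std::size_t bins_per_node, std::size_t stop_growing_size)
          : bins_(bins_per_node), cap_(stop_growing_size) {}

      void Allocate(int nidx) {
        if (slots_.count(nidx)) return;
        std::size_t used = slots_.size() * bins_;
        if (data_.size() >= cap_ && used + bins_ > data_.size()) {
          // Recycle: evict the first mapped node and reuse its range.
          auto victim = *slots_.begin();
          slots_.erase(victim.first);
          std::fill_n(data_.begin() + victim.second, bins_, 0.0);
          slots_[nidx] = victim.second;
          return;
        }
        if (data_.size() < used + bins_) {
          data_.resize(std::max(data_.size() * 2, used + bins_), 0.0);
        }
        slots_[nidx] = used;                // append into the next free slot
      }

     private:
      std::size_t bins_, cap_;
      std::map<int, std::size_t> slots_;    // node index -> offset into data_
      std::vector<double> data_;
    };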
+struct RowStateOnDevice { + // Number of rows assigned to this device + size_t total_rows_assigned_to_device; + // Number of rows processed thus far + size_t total_rows_processed; + // Number of rows to process from the current sparse page batch + size_t rows_to_process_from_batch; + // Offset from the current sparse page batch to begin processing + size_t row_offset_in_current_batch; + + explicit RowStateOnDevice(size_t total_rows) + : total_rows_assigned_to_device(total_rows), total_rows_processed(0), + rows_to_process_from_batch(0), row_offset_in_current_batch(0) { + } + + explicit RowStateOnDevice(size_t total_rows, size_t batch_rows) + : total_rows_assigned_to_device(total_rows), total_rows_processed(0), + rows_to_process_from_batch(batch_rows), row_offset_in_current_batch(0) { + } + + // Advance the row state by the number of rows processed + void Advance() { + total_rows_processed += rows_to_process_from_batch; + CHECK_LE(total_rows_processed, total_rows_assigned_to_device); + rows_to_process_from_batch = row_offset_in_current_batch = 0; + } +}; + +// An instance of this type is created which keeps track of total number of rows to process, +// rows processed thus far, rows to process and the offset from the current sparse page batch +// to begin processing on each device +class DeviceHistogramBuilderState { + public: + explicit DeviceHistogramBuilderState(int n_rows) : device_row_state_(n_rows) {} + + const RowStateOnDevice& GetRowStateOnDevice() const { + return device_row_state_; + } + + // This method is invoked at the beginning of each sparse page batch. This distributes + // the rows in the sparse page to the device. + // TODO(sriramch): Think of a way to utilize *all* the GPUs to build the compressed bins. + void BeginBatch(const SparsePage &batch) { + size_t rem_rows = batch.Size(); + size_t row_offset_in_current_batch = 0; + + // Do we have anymore left to process from this batch on this device? 
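RowStateOnDevice tracks how many of the rows assigned to the device have been consumed as successive SparsePage batches arrive: BeginBatch clamps the work to whatever is still unassigned and Advance folds it into total_rows_processed. A hedged host-only walk-through of that bookkeeping (the struct and the batch sizes are illustrative):

    #include <algorithm>
    #include <cstddef>
    #include <cstdio>

    struct RowState {                      // simplified mirror of RowStateOnDevice
      std::size_t assigned, processed, in_batch;
      void Begin(std::size_t batch_rows) {
        in_batch = std::min(assigned - processed, batch_rows);
      }
      void End() { processed += in_batch; in_batch = 0; }
    };

    int main() {
      RowState state{100, 0, 0};           // 100 rows assigned to this device
      for (std::size_t batch_rows : {64, 64}) {   // two external-memory pages
        state.Begin(batch_rows);
        std::printf("process %zu rows (done so far: %zu)\n", state.in_batch, state.processed);
        state.End();
      }                                    // processes 64 rows, then the remaining 36
      return 0;
    }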
+ if (device_row_state_.total_rows_assigned_to_device > device_row_state_.total_rows_processed) { + // There are still some rows that needs to be assigned to this device + device_row_state_.rows_to_process_from_batch = + std::min( + device_row_state_.total_rows_assigned_to_device - device_row_state_.total_rows_processed, + rem_rows); + } else { + // All rows have been assigned to this device + device_row_state_.rows_to_process_from_batch = 0; + } + + device_row_state_.row_offset_in_current_batch = row_offset_in_current_batch; + row_offset_in_current_batch += device_row_state_.rows_to_process_from_batch; + rem_rows -= device_row_state_.rows_to_process_from_batch; + } + + // This method is invoked after completion of each sparse page batch + void EndBatch() { + device_row_state_.Advance(); + } + + private: + RowStateOnDevice device_row_state_{0}; +}; + +class EllpackPageImpl { + public: + explicit EllpackPageImpl(DMatrix* dmat); + + template + void Init(int device, const tree::TrainParam& param, int gpu_batch_nrows); + + private: + template + void InitCompressedData(const common::HistogramCuts& hmat, + const tree::TrainParam& param, + size_t row_stride, + bool is_dense); + + template + void CreateHistIndices( + const SparsePage& row_batch, const common::HistogramCuts& hmat, + const RowStateOnDevice& device_row_state, int rows_per_batch); + + bool initialised_{false}; + int device_{-1}; + int n_bins{}; + bool use_shared_memory_histograms {false}; + + DMatrix* dmat_; + common::HistogramCuts hmat_; + common::Monitor monitor_; + + dh::BulkAllocator ba; + ELLPackMatrix ellpack_matrix; + + /*! \brief row_ptr form HistogramCuts. */ + common::Span feature_segments_; + /*! \brief minimum value for each feature. */ + common::Span min_fvalue_; + /*! \brief Cut. */ + common::Span gidx_fvalue_map_; + /*! \brief global index of histogram, which is stored in ELLPack format. */ + common::Span gidx_buffer_; +}; + +// Total number of nodes in tree, given depth +XGBOOST_DEVICE inline int MaxNodesDepth(int depth) { + return (1 << (depth + 1)) - 1; +} + +} // namespace xgboost + +#endif // XGBOOST_DATA_ELLPACK_PAGE_H_ diff --git a/src/data/simple_batch_iterator.h b/src/data/simple_batch_iterator.h new file mode 100644 index 000000000000..53464c6fa6c5 --- /dev/null +++ b/src/data/simple_batch_iterator.h @@ -0,0 +1,33 @@ +/*! 
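MaxNodesDepth above is the closed form for the number of nodes in a complete binary tree of the given depth, 2^(depth+1) - 1. A tiny host-side check of the formula:

    #include <cstdio>

    int MaxNodesDepth(int depth) { return (1 << (depth + 1)) - 1; }

    int main() {
      // depth 0 -> 1 (root only), depth 1 -> 3, depth 2 -> 7, depth 6 -> 127
      for (int depth : {0, 1, 2, 6}) {
        std::printf("depth %d -> %d nodes\n", depth, MaxNodesDepth(depth));
      }
      return 0;
    }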
+ * Copyright 2019 XGBoost contributors + */ +#ifndef XGBOOST_DATA_SIMPLE_BATCH_ITERATOR_H_ +#define XGBOOST_DATA_SIMPLE_BATCH_ITERATOR_H_ + +#include + +namespace xgboost { +namespace data { + +template +class SimpleBatchIteratorImpl : public BatchIteratorImpl { + public: + explicit SimpleBatchIteratorImpl(T* page) : page_(page) {} + T& operator*() override { + CHECK(page_ != nullptr); + return *page_; + } + const T& operator*() const override { + CHECK(page_ != nullptr); + return *page_; + } + void operator++() override { page_ = nullptr; } + bool AtEnd() const override { return page_ == nullptr; } + + private: + T* page_{nullptr}; +}; + +} // namespace data +} // namespace xgboost +#endif // XGBOOST_DATA_SIMPLE_BATCH_ITERATOR_H_ diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index 9f75ab055bbd..65f639a13e48 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -6,6 +6,7 @@ */ #include "./simple_dmatrix.h" #include +#include "./simple_batch_iterator.h" #include "../common/random.h" namespace xgboost { @@ -29,25 +30,6 @@ float SimpleDMatrix::GetColDensity(size_t cidx) { return 1.0f - (static_cast(nmiss)) / this->Info().num_row_; } -template -class SimpleBatchIteratorImpl : public BatchIteratorImpl { - public: - explicit SimpleBatchIteratorImpl(T* page) : page_(page) {} - T& operator*() override { - CHECK(page_ != nullptr); - return *page_; - } - const T& operator*() const override { - CHECK(page_ != nullptr); - return *page_; - } - void operator++() override { page_ = nullptr; } - bool AtEnd() const override { return page_ == nullptr; } - - private: - T* page_{nullptr}; -}; - BatchSet SimpleDMatrix::GetRowBatches() { // since csr is the default data structure so `source_` is always available. auto cast = dynamic_cast(source_.get()); @@ -80,6 +62,16 @@ BatchSet SimpleDMatrix::GetSortedColumnBatches() { return BatchSet(begin_iter); } +BatchSet SimpleDMatrix::GetEllpackBatches() { + // ELLPACK page doesn't exist, generate it + if (!ellpack_page_) { + ellpack_page_.reset(new EllpackPage(this)); + } + auto begin_iter = + BatchIterator(new SimpleBatchIteratorImpl(ellpack_page_.get())); + return BatchSet(begin_iter); +} + bool SimpleDMatrix::SingleColBlock() const { return true; } } // namespace data } // namespace xgboost diff --git a/src/data/simple_dmatrix.h b/src/data/simple_dmatrix.h index a838fc960701..2c740924dcba 100644 --- a/src/data/simple_dmatrix.h +++ b/src/data/simple_dmatrix.h @@ -38,12 +38,14 @@ class SimpleDMatrix : public DMatrix { BatchSet GetRowBatches() override; BatchSet GetColumnBatches() override; BatchSet GetSortedColumnBatches() override; + BatchSet GetEllpackBatches() override; // source data pointer. 
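SimpleBatchIteratorImpl wraps a single in-memory page: dereferencing returns the page, the first increment nulls the pointer, and AtEnd() then reports the range as exhausted, so a range-for over the resulting BatchSet visits exactly one batch. A stripped-down sketch of that one-element iteration protocol (no BatchSet/BatchIterator wrappers, illustrative types only):

    #include <cassert>

    template <typename T>
    class OnePageIter {                    // illustrative analogue of SimpleBatchIteratorImpl
     public:
      explicit OnePageIter(T* page) : page_(page) {}
      T& operator*() { assert(page_ != nullptr); return *page_; }
      void operator++() { page_ = nullptr; }   // a single step exhausts the range
      bool AtEnd() const { return page_ == nullptr; }

     private:
      T* page_{nullptr};
    };

    int main() {
      int page = 42;                       // stands in for a SparsePage or EllpackPage
      OnePageIter<int> it(&page);
      int visited = 0;
      while (!it.AtEnd()) {
        ++visited;
        (void)*it;
        ++it;
      }
      assert(visited == 1);                // exactly one batch for an in-memory DMatrix
      return 0;
    }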
std::unique_ptr> source_; std::unique_ptr column_page_; std::unique_ptr sorted_column_page_; + std::unique_ptr ellpack_page_; }; } // namespace data } // namespace xgboost diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc index a1c73f1e6737..8a7fdb6eb30c 100644 --- a/src/data/sparse_page_dmatrix.cc +++ b/src/data/sparse_page_dmatrix.cc @@ -10,9 +10,13 @@ #if DMLC_ENABLE_STD_THREAD #include "./sparse_page_dmatrix.h" +#include "./simple_batch_iterator.h" + namespace xgboost { namespace data { +extern template class SimpleBatchIteratorImpl; + MetaInfo& SparsePageDMatrix::Info() { return row_source_->info; } @@ -72,6 +76,16 @@ BatchSet SparsePageDMatrix::GetSortedColumnBatches() { return BatchSet(begin_iter); } +BatchSet SparsePageDMatrix::GetEllpackBatches() { + // ELLPACK page doesn't exist, generate it + if (!ellpack_page_) { + ellpack_page_.reset(new EllpackPage(this)); + } + auto begin_iter = + BatchIterator(new SimpleBatchIteratorImpl(ellpack_page_.get())); + return BatchSet(begin_iter); +} + float SparsePageDMatrix::GetColDensity(size_t cidx) { // Finds densities if we don't already have them if (col_density_.empty()) { diff --git a/src/data/sparse_page_dmatrix.h b/src/data/sparse_page_dmatrix.h index c5a2401d228b..b8921ba95ef9 100644 --- a/src/data/sparse_page_dmatrix.h +++ b/src/data/sparse_page_dmatrix.h @@ -24,7 +24,7 @@ class SparsePageDMatrix : public DMatrix { explicit SparsePageDMatrix(std::unique_ptr>&& source, std::string cache_info) : row_source_(std::move(source)), cache_info_(std::move(cache_info)) {} - virtual ~SparsePageDMatrix() = default; + ~SparsePageDMatrix() override = default; MetaInfo& Info() override; @@ -38,11 +38,13 @@ class SparsePageDMatrix : public DMatrix { BatchSet GetRowBatches() override; BatchSet GetColumnBatches() override; BatchSet GetSortedColumnBatches() override; + BatchSet GetEllpackBatches() override; // source data pointers. 
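Both DMatrix implementations build the ELLPACK page lazily: the first GetEllpackBatches() call constructs it and every later call hands out the cached object, so repeated GetBatches<EllpackPage>() loops do not redo the work. A hedged sketch of that lazy-caching shape (ExpensivePage and Matrix are illustrative stand-ins for EllpackPage and the DMatrix classes):

    #include <memory>

    struct ExpensivePage {                 // stands in for EllpackPage
      explicit ExpensivePage(int source) : value(source) {}
      int value;
    };

    class Matrix {                         // stands in for SimpleDMatrix / SparsePageDMatrix
     public:
      ExpensivePage* GetEllpack() {
        if (!ellpack_) {                   // built once, on first request
          ellpack_.reset(new ExpensivePage(7));
        }
        return ellpack_.get();             // cached object returned afterwards
      }

     private:
      std::unique_ptr<ExpensivePage> ellpack_;
    };

Note that constructing the page is still cheap at this point; per EllpackPageImpl above, the device-side work happens later through its Init() method.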
std::unique_ptr> row_source_; std::unique_ptr> column_source_; std::unique_ptr> sorted_column_source_; + std::unique_ptr ellpack_page_; // the cache prefix std::string cache_info_; // Store column densities to avoid recalculating diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 0a06daa2a930..af71a3a16ed8 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -1329,7 +1329,7 @@ class GPUHistMakerSpecialised { monitor_.StartCuda("Quantiles"); // Create the quantile sketches for the dmatrix and initialize HistogramCuts - size_t row_stride = common::DeviceSketch(param_, *generic_param_, + size_t row_stride = common::DeviceSketch(generic_param_->gpu_id, param_.max_bin, hist_maker_param_.gpu_batch_nrows, dmat, &hmat_); monitor_.StopCuda("Quantiles"); @@ -1359,7 +1359,7 @@ class GPUHistMakerSpecialised { initialised_ = true; } - void InitData(HostDeviceVector* gpair, DMatrix* dmat) { + void InitData(DMatrix* dmat) { if (!initialised_) { monitor_.StartCuda("InitDataOnce"); this->InitDataOnce(dmat); @@ -1387,7 +1387,7 @@ class GPUHistMakerSpecialised { void UpdateTree(HostDeviceVector* gpair, DMatrix* p_fmat, RegTree* p_tree) { monitor_.StartCuda("InitData"); - this->InitData(gpair, p_fmat); + this->InitData(p_fmat); monitor_.StopCuda("InitData"); gpair->SetDevice(device_); diff --git a/tests/cpp/common/test_gpu_hist_util.cu b/tests/cpp/common/test_gpu_hist_util.cu index bbb1cc4bfe4d..cdbafec4ddac 100644 --- a/tests/cpp/common/test_gpu_hist_util.cu +++ b/tests/cpp/common/test_gpu_hist_util.cu @@ -43,18 +43,17 @@ void TestDeviceSketch(bool use_external_memory) { dmat = static_cast *>(dmat_handle); } - tree::TrainParam p; - p.max_bin = 20; - int gpu_batch_nrows = 0; + int device{0}; + int max_bin{20}; + int gpu_batch_nrows{0}; // find quantiles on the CPU HistogramCuts hmat_cpu; - hmat_cpu.Build((*dmat).get(), p.max_bin); + hmat_cpu.Build((*dmat).get(), max_bin); // find the cuts on the GPU HistogramCuts hmat_gpu; - size_t row_stride = DeviceSketch(p, CreateEmptyGenericParam(0), gpu_batch_nrows, - dmat->get(), &hmat_gpu); + size_t row_stride = DeviceSketch(device, max_bin, gpu_batch_nrows, dmat->get(), &hmat_gpu); // compare the row stride with the one obtained from the dmatrix size_t expected_row_stride = 0;
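Callers of DeviceSketch now pass only the device ordinal, the bin count, and the row-batching hint, exactly as the updater and the test above do; no TrainParam or GenericParameter is threaded through anymore. A hedged call sketch (the function, the include path, and the parameter values are illustrative):

    #include <xgboost/data.h>
    #include "../common/hist_util.h"       // path relative to src/, as in the sources above

    size_t SketchOnDevice(xgboost::DMatrix* dmat) {
      xgboost::common::HistogramCuts cuts;
      int device = 0;           // GPU ordinal, formerly taken from GenericParameter::gpu_id
      int max_bin = 256;        // formerly TrainParam::max_bin
      int gpu_batch_nrows = 0;  // same value the test above passes
      // Returns the row stride for the whole dataset and fills `cuts`.
      return xgboost::common::DeviceSketch(device, max_bin, gpu_batch_nrows, dmat, &cuts);
    }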