From 72cd1517d6b1d145c34e13a063fadd31b507b01d Mon Sep 17 00:00:00 2001
From: Andy Adinets
Date: Thu, 30 Aug 2018 04:28:47 +0200
Subject: [PATCH] Replaced std::vector with HostDeviceVector in MetaInfo and
 SparsePage. (#3446)

* Replaced std::vector with HostDeviceVector in MetaInfo and SparsePage.

- added distributions to HostDeviceVector
- using HostDeviceVector for labels, weights and base margins in MetaInfo
- using HostDeviceVector for offset and data in SparsePage
- other necessary refactoring

* Added const version of HostDeviceVector API calls.

- const versions added to calls that can trigger data transfers, e.g. DevicePointer()
- updated the code that uses HostDeviceVector
- objective functions now accept const HostDeviceVector& for predictions

* Updated src/linear/updater_gpu_coordinate.cu.

* Added read-only state for HostDeviceVector sync.

- this means no copies are performed if both host and devices access the HostDeviceVector read-only

* Fixed linter and test errors.

- updated the lz4 plugin
- added ConstDeviceSpan to HostDeviceVector
- using device % dh::NVisibleDevices() for the physical device number, e.g. in calls to cudaSetDevice()

* Fixed explicit template instantiation errors for HostDeviceVector.

- replaced HostDeviceVector with HostDeviceVector

* Fixed HostDeviceVector tests that require multiple GPUs.

- added a mock set device handler; when set, it is called instead of cudaSetDevice()
---
 include/xgboost/data.h                      |  84 ++--
 include/xgboost/objective.h                 |   2 +-
 plugin/example/custom_obj.cc                |  11 +-
 plugin/lz4/sparse_page_lz4_format.cc        |  51 +--
 src/c_api/c_api.cc                          | 103 +++--
 src/cli_main.cc                             |   2 +-
 src/common/hist_util.cc                     |   4 +-
 src/common/hist_util.cu                     |  26 +-
 src/common/host_device_vector.cc            |  81 +++-
 src/common/host_device_vector.cu            | 426 ++++++++++++++------
 src/common/host_device_vector.h             | 209 ++++++++--
 src/data/data.cc                            |  43 +-
 src/data/simple_csr_source.cc               |  24 +-
 src/data/simple_dmatrix.cc                  |  10 +-
 src/data/sparse_page_dmatrix.cc             |  23 +-
 src/data/sparse_page_raw_format.cc          |  42 +-
 src/data/sparse_page_source.cc              |  10 +-
 src/gbm/gblinear.cc                         |   4 +-
 src/gbm/gbtree.cc                           |   7 +-
 src/learner.cc                              |   8 +-
 src/linear/updater_coordinate.cc            |  10 +-
 src/linear/updater_gpu_coordinate.cu        |   6 +-
 src/linear/updater_shotgun.cc               |  14 +-
 src/metric/elementwise_metric.cc            |  12 +-
 src/metric/multiclass_metric.cc             |  16 +-
 src/metric/rank_metric.cc                   |  45 ++-
 src/objective/hinge.cc                      |  18 +-
 src/objective/multiclass_obj.cc             |  11 +-
 src/objective/rank_obj.cc                   |  13 +-
 src/objective/regression_obj.cc             |  74 ++--
 src/objective/regression_obj_gpu.cu         |  60 +--
 src/predictor/cpu_predictor.cc              |   6 +-
 src/predictor/gpu_predictor.cu              |  24 +-
 src/tree/updater_colmaker.cc                |   4 +-
 src/tree/updater_fast_hist.cc               |   2 +-
 src/tree/updater_gpu.cu                     |   2 +-
 src/tree/updater_gpu_hist.cu                |  21 +-
 src/tree/updater_histmaker.cc               |   2 +-
 src/tree/updater_refresh.cc                 |   2 +-
 src/tree/updater_skmaker.cc                 |   2 +-
 tests/cpp/common/test_host_device_vector.cu | 152 ++++++-
 tests/cpp/data/test_metainfo.cc             |  14 +-
 tests/cpp/data/test_simple_dmatrix.cc       |   2 +-
 tests/cpp/data/test_sparse_page_dmatrix.cc  |   4 +-
 tests/cpp/helpers.cc                        |  15 +-
 45 files changed, 1141 insertions(+), 560 deletions(-)

diff --git a/include/xgboost/data.h b/include/xgboost/data.h
index 36e872ef1850..436799fe2b67 100644
--- a/include/xgboost/data.h
+++ b/include/xgboost/data.h
@@ -17,6 +17,8 @@
 #include "./base.h"
 #include "../../src/common/span.h"
+#include "../../src/common/host_device_vector.h"
+
 namespace xgboost {
 // forward declare learner.
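The hunks that follow replace the std::vector members of MetaInfo and SparsePage with
HostDeviceVector, so host-side code now obtains read-only views through ConstHostVector()
and mutable references through HostVector(). A minimal sketch of the resulting access
pattern (illustration only, not part of the patch; it assumes `info` is a MetaInfo,
`page` is a SparsePage, and the element types bst_float, size_t and Entry used in the
hunks below):

  // Read-only host views: read access does not invalidate an up-to-date device copy.
  const std::vector<xgboost::bst_float>& labels = info.labels_.ConstHostVector();
  const std::vector<size_t>& offset = page.offset.ConstHostVector();
  const std::vector<xgboost::Entry>& data = page.data.ConstHostVector();
  for (size_t i = 0; i + 1 < offset.size(); ++i) {        // one row per offset gap
    for (size_t j = offset[i]; j < offset[i + 1]; ++j) {
      // data[j].index / data[j].fvalue: feature index and value of row i
    }
  }
  // Mutable host view: any device copies become stale until the next sync.
  std::vector<xgboost::Entry>& writable = page.data.HostVector();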
class LearnerImpl; @@ -41,7 +43,7 @@ class MetaInfo { /*! \brief number of nonzero entries in the data */ uint64_t num_nonzero_{0}; /*! \brief label of each instance */ - std::vector labels_; + HostDeviceVector labels_; /*! * \brief specified root index of each instance, * can be used for multi task setting @@ -53,7 +55,7 @@ class MetaInfo { */ std::vector group_ptr_; /*! \brief weights of each instance, optional */ - std::vector weights_; + HostDeviceVector weights_; /*! \brief session-id of each instance, optional */ std::vector qids_; /*! @@ -61,7 +63,7 @@ class MetaInfo { * if specified, xgboost will start from this init margin * can be used to specify initial prediction to boost from. */ - std::vector base_margin_; + HostDeviceVector base_margin_; /*! \brief version flag, used to check version of this info */ static const int kVersion = 2; /*! \brief version that introduced qid field */ @@ -74,7 +76,7 @@ class MetaInfo { * \return The weight. */ inline bst_float GetWeight(size_t i) const { - return weights_.size() != 0 ? weights_[i] : 1.0f; + return weights_.Size() != 0 ? weights_.HostVector()[i] : 1.0f; } /*! * \brief Get the root index of i-th instance. @@ -86,12 +88,12 @@ class MetaInfo { } /*! \brief get sorted indexes (argsort) of labels by absolute value (used by cox loss) */ inline const std::vector& LabelAbsSort() const { - if (label_order_cache_.size() == labels_.size()) { + if (label_order_cache_.size() == labels_.Size()) { return label_order_cache_; } - label_order_cache_.resize(labels_.size()); + label_order_cache_.resize(labels_.Size()); std::iota(label_order_cache_.begin(), label_order_cache_.end(), 0); - const auto l = labels_; + const auto& l = labels_.HostVector(); XGBOOST_PARALLEL_SORT(label_order_cache_.begin(), label_order_cache_.end(), [&l](size_t i1, size_t i2) {return std::abs(l[i1]) < std::abs(l[i2]);}); @@ -151,9 +153,9 @@ struct Entry { */ class SparsePage { public: - std::vector offset; + HostDeviceVector offset; /*! \brief the data of the segments */ - std::vector data; + HostDeviceVector data; size_t base_rowid; @@ -162,8 +164,10 @@ class SparsePage { /*! \brief get i-th row from the batch */ inline Inst operator[](size_t i) const { - return {data.data() + offset[i], - static_cast(offset[i + 1] - offset[i])}; + const auto& data_vec = data.HostVector(); + const auto& offset_vec = offset.HostVector(); + return {data_vec.data() + offset_vec[i], + static_cast(offset_vec[i + 1] - offset_vec[i])}; } /*! \brief constructor */ @@ -172,18 +176,19 @@ class SparsePage { } /*! \return number of instance in the page */ inline size_t Size() const { - return offset.size() - 1; + return offset.Size() - 1; } /*! \return estimation of memory cost of this page */ inline size_t MemCostBytes() const { - return offset.size() * sizeof(size_t) + data.size() * sizeof(Entry); + return offset.Size() * sizeof(size_t) + data.Size() * sizeof(Entry); } /*! \brief clear the page */ inline void Clear() { base_rowid = 0; - offset.clear(); - offset.push_back(0); - data.clear(); + auto& offset_vec = offset.HostVector(); + offset_vec.clear(); + offset_vec.push_back(0); + data.HostVector().clear(); } /*! @@ -191,33 +196,39 @@ class SparsePage { * \param batch the row batch. 
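 *        (rows are appended to the page's host-side offset and data vectors,
 *        obtained via HostVector())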
*/ inline void Push(const dmlc::RowBlock& batch) { - data.reserve(data.size() + batch.offset[batch.size] - batch.offset[0]); - offset.reserve(offset.size() + batch.size); + auto& data_vec = data.HostVector(); + auto& offset_vec = offset.HostVector(); + data_vec.reserve(data.Size() + batch.offset[batch.size] - batch.offset[0]); + offset_vec.reserve(offset.Size() + batch.size); CHECK(batch.index != nullptr); for (size_t i = 0; i < batch.size; ++i) { - offset.push_back(offset.back() + batch.offset[i + 1] - batch.offset[i]); + offset_vec.push_back(offset_vec.back() + batch.offset[i + 1] - batch.offset[i]); } for (size_t i = batch.offset[0]; i < batch.offset[batch.size]; ++i) { uint32_t index = batch.index[i]; bst_float fvalue = batch.value == nullptr ? 1.0f : batch.value[i]; - data.emplace_back(index, fvalue); + data_vec.emplace_back(index, fvalue); } - CHECK_EQ(offset.back(), data.size()); + CHECK_EQ(offset_vec.back(), data.Size()); } /*! * \brief Push a sparse page * \param batch the row page */ inline void Push(const SparsePage &batch) { - size_t top = offset.back(); - data.resize(top + batch.data.size()); - std::memcpy(dmlc::BeginPtr(data) + top, - dmlc::BeginPtr(batch.data), - sizeof(Entry) * batch.data.size()); - size_t begin = offset.size(); - offset.resize(begin + batch.Size()); + auto& data_vec = data.HostVector(); + auto& offset_vec = offset.HostVector(); + const auto& batch_offset_vec = batch.offset.HostVector(); + const auto& batch_data_vec = batch.data.HostVector(); + size_t top = offset_vec.back(); + data_vec.resize(top + batch.data.Size()); + std::memcpy(dmlc::BeginPtr(data_vec) + top, + dmlc::BeginPtr(batch_data_vec), + sizeof(Entry) * batch.data.Size()); + size_t begin = offset.Size(); + offset_vec.resize(begin + batch.Size()); for (size_t i = 0; i < batch.Size(); ++i) { - offset[i + begin] = top + batch.offset[i + 1]; + offset_vec[i + begin] = top + batch_offset_vec[i + 1]; } } /*! @@ -225,20 +236,21 @@ class SparsePage { * \param inst an instance row */ inline void Push(const Inst &inst) { - offset.push_back(offset.back() + inst.size()); - size_t begin = data.size(); - data.resize(begin + inst.size()); + auto& data_vec = data.HostVector(); + auto& offset_vec = offset.HostVector(); + offset_vec.push_back(offset_vec.back() + inst.size()); + + size_t begin = data_vec.size(); + data_vec.resize(begin + inst.size()); if (inst.size() != 0) { - std::memcpy(dmlc::BeginPtr(data) + begin, inst.data(), + std::memcpy(dmlc::BeginPtr(data_vec) + begin, inst.data(), sizeof(Entry) * inst.size()); } } - size_t Size() { return offset.size() - 1; } + size_t Size() { return offset.Size() - 1; } }; - - /*! * \brief This is data structure that user can pass to DMatrix::Create * to create a DMatrix for training, user can create this data structure diff --git a/include/xgboost/objective.h b/include/xgboost/objective.h index fa536e7e62a7..9b1738f22445 100644 --- a/include/xgboost/objective.h +++ b/include/xgboost/objective.h @@ -44,7 +44,7 @@ class ObjFunction { * \param iteration current iteration number. 
* \param out_gpair output of get gradient, saves gradient and second order gradient in */ - virtual void GetGradient(HostDeviceVector* preds, + virtual void GetGradient(const HostDeviceVector& preds, const MetaInfo& info, int iteration, HostDeviceVector* out_gpair) = 0; diff --git a/plugin/example/custom_obj.cc b/plugin/example/custom_obj.cc index e2e502b3e6fb..6ce653875d19 100644 --- a/plugin/example/custom_obj.cc +++ b/plugin/example/custom_obj.cc @@ -33,21 +33,22 @@ class MyLogistic : public ObjFunction { void Configure(const std::vector >& args) override { param_.InitAllowUnknown(args); } - void GetGradient(HostDeviceVector *preds, + void GetGradient(const HostDeviceVector &preds, const MetaInfo &info, int iter, HostDeviceVector *out_gpair) override { - out_gpair->Resize(preds->Size()); - std::vector& preds_h = preds->HostVector(); + out_gpair->Resize(preds.Size()); + const std::vector& preds_h = preds.HostVector(); std::vector& out_gpair_h = out_gpair->HostVector(); + const std::vector& labels_h = info.labels_.HostVector(); for (size_t i = 0; i < preds_h.size(); ++i) { bst_float w = info.GetWeight(i); // scale the negative examples! - if (info.labels_[i] == 0.0f) w *= param_.scale_neg_weight; + if (labels_h[i] == 0.0f) w *= param_.scale_neg_weight; // logistic transformation bst_float p = 1.0f / (1.0f + std::exp(-preds_h[i])); // this is the gradient - bst_float grad = (p - info.labels_[i]) * w; + bst_float grad = (p - labels_h[i]) * w; // this is the second order gradient bst_float hess = p * (1.0f - p) * w; out_gpair_h.at(i) = GradientPair(grad, hess); diff --git a/plugin/lz4/sparse_page_lz4_format.cc b/plugin/lz4/sparse_page_lz4_format.cc index a1757d5781bf..bf4132161449 100644 --- a/plugin/lz4/sparse_page_lz4_format.cc +++ b/plugin/lz4/sparse_page_lz4_format.cc @@ -177,15 +177,17 @@ class SparsePageLZ4Format : public SparsePageFormat { } bool Read(SparsePage* page, dmlc::SeekStream* fi) override { - if (!fi->Read(&(page->offset))) return false; - CHECK_NE(page->offset.size(), 0) << "Invalid SparsePage file"; + auto& offset_vec = page->offset.HostVector(); + auto& data_vec = page->data.HostVector(); + if (!fi->Read(&(offset_vec))) return false; + CHECK_NE(offset_vec.size(), 0) << "Invalid SparsePage file"; this->LoadIndexValue(fi); - page->data.resize(page->offset.back()); + data_vec.resize(offset_vec.back()); CHECK_EQ(index_.data.size(), value_.data.size()); - CHECK_EQ(index_.data.size(), page->data.size()); - for (size_t i = 0; i < page->data.size(); ++i) { - page->data[i] = Entry(index_.data[i] + min_index_, value_.data[i]); + CHECK_EQ(index_.data.size(), data_vec.size()); + for (size_t i = 0; i < data_vec.size(); ++i) { + data_vec[i] = Entry(index_.data[i] + min_index_, value_.data[i]); } return true; } @@ -195,24 +197,25 @@ class SparsePageLZ4Format : public SparsePageFormat { const std::vector& sorted_index_set) override { if (!fi->Read(&disk_offset_)) return false; this->LoadIndexValue(fi); - - page->offset.clear(); - page->offset.push_back(0); + auto& offset_vec = page->offset.HostVector(); + auto& data_vec = page->data.HostVector(); + offset_vec.clear(); + offset_vec.push_back(0); for (bst_uint cid : sorted_index_set) { - page->offset.push_back( - page->offset.back() + disk_offset_[cid + 1] - disk_offset_[cid]); + offset_vec.push_back( + offset_vec.back() + disk_offset_[cid + 1] - disk_offset_[cid]); } - page->data.resize(page->offset.back()); + data_vec.resize(offset_vec.back()); CHECK_EQ(index_.data.size(), value_.data.size()); CHECK_EQ(index_.data.size(), 
disk_offset_.back()); for (size_t i = 0; i < sorted_index_set.size(); ++i) { bst_uint cid = sorted_index_set[i]; - size_t dst_begin = page->offset[i]; + size_t dst_begin = offset_vec[i]; size_t src_begin = disk_offset_[cid]; size_t num = disk_offset_[cid + 1] - disk_offset_[cid]; for (size_t j = 0; j < num; ++j) { - page->data[dst_begin + j] = Entry( + data_vec[dst_begin + j] = Entry( index_.data[src_begin + j] + min_index_, value_.data[src_begin + j]); } } @@ -220,22 +223,24 @@ class SparsePageLZ4Format : public SparsePageFormat { } void Write(const SparsePage& page, dmlc::Stream* fo) override { - CHECK(page.offset.size() != 0 && page.offset[0] == 0); - CHECK_EQ(page.offset.back(), page.data.size()); - fo->Write(page.offset); + const auto& offset_vec = page.offset.HostVector(); + const auto& data_vec = page.data.HostVector(); + CHECK(offset_vec.size() != 0 && offset_vec[0] == 0); + CHECK_EQ(offset_vec.back(), data_vec.size()); + fo->Write(offset_vec); min_index_ = page.base_rowid; fo->Write(&min_index_, sizeof(min_index_)); - index_.data.resize(page.data.size()); - value_.data.resize(page.data.size()); + index_.data.resize(data_vec.size()); + value_.data.resize(data_vec.size()); - for (size_t i = 0; i < page.data.size(); ++i) { - bst_uint idx = page.data[i].index - min_index_; + for (size_t i = 0; i < data_vec.size(); ++i) { + bst_uint idx = data_vec[i].index - min_index_; CHECK_LE(idx, static_cast(std::numeric_limits::max())) << "The storage index is chosen to limited to smaller equal than " << std::numeric_limits::max() << "min_index=" << min_index_; index_.data[i] = static_cast(idx); - value_.data[i] = page.data[i].fvalue; + value_.data[i] = data_vec[i].fvalue; } index_.InitCompressChunks(kChunkSize, kMaxChunk); @@ -259,7 +264,7 @@ class SparsePageLZ4Format : public SparsePageFormat { raw_bytes_value_ += value_.RawBytes(); encoded_bytes_index_ += index_.EncodedBytes(); encoded_bytes_value_ += value_.EncodedBytes(); - raw_bytes_ += page.offset.size() * sizeof(size_t); + raw_bytes_ += offset_vec.size() * sizeof(size_t); } inline void LoadIndexValue(dmlc::SeekStream* fi) { diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 506f52b41d8d..40939b51cca5 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -250,20 +250,22 @@ XGB_DLL int XGDMatrixCreateFromCSREx(const size_t* indptr, API_BEGIN(); data::SimpleCSRSource& mat = *source; - mat.page_.offset.reserve(nindptr); - mat.page_.data.reserve(nelem); - mat.page_.offset.resize(1); - mat.page_.offset[0] = 0; + auto& offset_vec = mat.page_.offset.HostVector(); + auto& data_vec = mat.page_.data.HostVector(); + offset_vec.reserve(nindptr); + data_vec.reserve(nelem); + offset_vec.resize(1); + offset_vec[0] = 0; size_t num_column = 0; for (size_t i = 1; i < nindptr; ++i) { for (size_t j = indptr[i - 1]; j < indptr[i]; ++j) { if (!common::CheckNAN(data[j])) { // automatically skip nan. 
- mat.page_.data.emplace_back(Entry(indices[j], data[j])); + data_vec.emplace_back(Entry(indices[j], data[j])); num_column = std::max(num_column, static_cast(indices[j] + 1)); } } - mat.page_.offset.push_back(mat.page_.data.size()); + offset_vec.push_back(mat.page_.data.Size()); } mat.info.num_col_ = num_column; @@ -273,7 +275,7 @@ XGB_DLL int XGDMatrixCreateFromCSREx(const size_t* indptr, mat.info.num_col_ = num_col; } mat.info.num_row_ = nindptr - 1; - mat.info.num_nonzero_ = mat.page_.data.size(); + mat.info.num_nonzero_ = mat.page_.data.Size(); *out = new std::shared_ptr(DMatrix::Create(std::move(source))); API_END(); } @@ -305,7 +307,9 @@ XGB_DLL int XGDMatrixCreateFromCSCEx(const size_t* col_ptr, // FIXME: User should be able to control number of threads const int nthread = omp_get_max_threads(); data::SimpleCSRSource& mat = *source; - common::ParallelGroupBuilder builder(&mat.page_.offset, &mat.page_.data); + auto& offset_vec = mat.page_.offset.HostVector(); + auto& data_vec = mat.page_.data.HostVector(); + common::ParallelGroupBuilder builder(&offset_vec, &data_vec); builder.InitBudget(0, nthread); size_t ncol = nindptr - 1; // NOLINT(*) #pragma omp parallel for schedule(static) @@ -329,15 +333,16 @@ XGB_DLL int XGDMatrixCreateFromCSCEx(const size_t* col_ptr, } } } - mat.info.num_row_ = mat.page_.offset.size() - 1; + mat.info.num_row_ = mat.page_.offset.Size() - 1; if (num_row > 0) { CHECK_LE(mat.info.num_row_, num_row); // provision for empty rows at the bottom of matrix + auto& offset_vec = mat.page_.offset.HostVector(); for (uint64_t i = mat.info.num_row_; i < static_cast(num_row); ++i) { - mat.page_.offset.push_back(mat.page_.offset.back()); + offset_vec.push_back(offset_vec.back()); } mat.info.num_row_ = num_row; - CHECK_EQ(mat.info.num_row_, mat.page_.offset.size() - 1); // sanity check + CHECK_EQ(mat.info.num_row_, offset_vec.size() - 1); // sanity check } mat.info.num_col_ = ncol; mat.info.num_nonzero_ = nelem; @@ -368,7 +373,9 @@ XGB_DLL int XGDMatrixCreateFromMat(const bst_float* data, API_BEGIN(); data::SimpleCSRSource& mat = *source; - mat.page_.offset.resize(1+nrow); + auto& offset_vec = mat.page_.offset.HostVector(); + auto& data_vec = mat.page_.data.HostVector(); + offset_vec.resize(1+nrow); bool nan_missing = common::CheckNAN(missing); mat.info.num_row_ = nrow; mat.info.num_col_ = ncol; @@ -388,9 +395,9 @@ XGB_DLL int XGDMatrixCreateFromMat(const bst_float* data, } } } - mat.page_.offset[i+1] = mat.page_.offset[i] + nelem; + offset_vec[i+1] = offset_vec[i] + nelem; } - mat.page_.data.resize(mat.page_.data.size() + mat.page_.offset.back()); + data_vec.resize(mat.page_.data.Size() + offset_vec.back()); data = data0; for (xgboost::bst_ulong i = 0; i < nrow; ++i, data += ncol) { @@ -399,14 +406,14 @@ XGB_DLL int XGDMatrixCreateFromMat(const bst_float* data, if (common::CheckNAN(data[j])) { } else { if (nan_missing || data[j] != missing) { - mat.page_.data[mat.page_.offset[i] + matj] = Entry(j, data[j]); + data_vec[offset_vec[i] + matj] = Entry(j, data[j]); ++matj; } } } } - mat.info.num_nonzero_ = mat.page_.data.size(); + mat.info.num_nonzero_ = mat.page_.data.Size(); *out = new std::shared_ptr(DMatrix::Create(std::move(source))); API_END(); } @@ -461,7 +468,9 @@ XGB_DLL int XGDMatrixCreateFromMat_omp(const bst_float* data, // NOLINT std::unique_ptr source(new data::SimpleCSRSource()); data::SimpleCSRSource& mat = *source; - mat.page_.offset.resize(1+nrow); + auto& offset_vec = mat.page_.offset.HostVector(); + auto& data_vec = mat.page_.data.HostVector(); + 
offset_vec.resize(1+nrow); mat.info.num_row_ = nrow; mat.info.num_col_ = ncol; @@ -487,7 +496,7 @@ XGB_DLL int XGDMatrixCreateFromMat_omp(const bst_float* data, // NOLINT ++nelem; } } - mat.page_.offset[i+1] = nelem; + offset_vec[i+1] = nelem; } } // Inform about any NaNs and resize data matrix @@ -496,8 +505,8 @@ XGB_DLL int XGDMatrixCreateFromMat_omp(const bst_float* data, // NOLINT } // do cumulative sum (to avoid otherwise need to copy) - PrefixSum(&mat.page_.offset[0], mat.page_.offset.size()); - mat.page_.data.resize(mat.page_.data.size() + mat.page_.offset.back()); + PrefixSum(&offset_vec[0], offset_vec.size()); + data_vec.resize(mat.page_.data.Size() + offset_vec.back()); // Fill data matrix (now that know size, no need for slow push_back()) #pragma omp parallel num_threads(nthread) @@ -508,7 +517,7 @@ XGB_DLL int XGDMatrixCreateFromMat_omp(const bst_float* data, // NOLINT for (xgboost::bst_ulong j = 0; j < ncol; ++j) { if (common::CheckNAN(data[ncol * i + j])) { } else if (nan_missing || data[ncol * i + j] != missing) { - mat.page_.data[mat.page_.offset[i] + matj] = + data_vec[offset_vec[i] + matj] = Entry(j, data[ncol * i + j]); ++matj; } @@ -518,7 +527,7 @@ XGB_DLL int XGDMatrixCreateFromMat_omp(const bst_float* data, // NOLINT // restore omp state omp_set_num_threads(nthread_orig); - mat.info.num_nonzero_ = mat.page_.data.size(); + mat.info.num_nonzero_ = mat.page_.data.Size(); *out = new std::shared_ptr(DMatrix::Create(std::move(source))); API_END(); } @@ -611,10 +620,11 @@ XGB_DLL int XGDMatrixCreateFromDT(void** data, const char** feature_stypes, std::unique_ptr source(new data::SimpleCSRSource()); data::SimpleCSRSource& mat = *source; - mat.page_.offset.resize(1 + nrow); + mat.page_.offset.Resize(1 + nrow); mat.info.num_row_ = nrow; mat.info.num_col_ = ncol; + auto& page_offset = mat.page_.offset.HostVector(); #pragma omp parallel num_threads(nthread) { // Count elements per row, column by column @@ -624,15 +634,17 @@ XGB_DLL int XGDMatrixCreateFromDT(void** data, const char** feature_stypes, for (omp_ulong i = 0; i < nrow; ++i) { float val = DTGetValue(data[j], dtype, i); if (!std::isnan(val)) { - mat.page_.offset[i + 1]++; + page_offset[i + 1]++; } } } } // do cumulative sum (to avoid otherwise need to copy) - PrefixSum(&mat.page_.offset[0], mat.page_.offset.size()); + PrefixSum(&page_offset[0], page_offset.size()); - mat.page_.data.resize(mat.page_.data.size() + mat.page_.offset.back()); + mat.page_.data.Resize(mat.page_.data.Size() + page_offset.back()); + + auto& page_data = mat.page_.data.HostVector(); // Fill data matrix (now that know size, no need for slow push_back()) std::vector position(nrow); @@ -644,7 +656,7 @@ XGB_DLL int XGDMatrixCreateFromDT(void** data, const char** feature_stypes, for (omp_ulong i = 0; i < nrow; ++i) { float val = DTGetValue(data[j], dtype, i); if (!std::isnan(val)) { - mat.page_.data[mat.page_.offset[i] + position[i]] = Entry(j, val); + page_data[page_offset[i] + position[i]] = Entry(j, val); position[i]++; } } @@ -654,7 +666,7 @@ XGB_DLL int XGDMatrixCreateFromDT(void** data, const char** feature_stypes, // restore omp state omp_set_num_threads(nthread_orig); - mat.info.num_nonzero_ = mat.page_.data.size(); + mat.info.num_nonzero_ = mat.page_.data.Size(); *out = new std::shared_ptr(DMatrix::Create(std::move(source))); API_END(); } @@ -682,24 +694,33 @@ XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle, iter->BeforeFirst(); CHECK(iter->Next()); - const auto& batch = iter->Value(); + const auto& batch = iter->Value(); + const auto& 
src_labels = src.info.labels_.ConstHostVector(); + const auto& src_weights = src.info.weights_.ConstHostVector(); + const auto& src_base_margin = src.info.base_margin_.ConstHostVector(); + auto& ret_labels = ret.info.labels_.HostVector(); + auto& ret_weights = ret.info.weights_.HostVector(); + auto& ret_base_margin = ret.info.base_margin_.HostVector(); + auto& offset_vec = ret.page_.offset.HostVector(); + auto& data_vec = ret.page_.data.HostVector(); + for (xgboost::bst_ulong i = 0; i < len; ++i) { const int ridx = idxset[i]; auto inst = batch[ridx]; CHECK_LT(static_cast(ridx), batch.Size()); - ret.page_.data.insert(ret.page_.data.end(), inst.data(), - inst.data() + inst.size()); - ret.page_.offset.push_back(ret.page_.offset.back() + inst.size()); + data_vec.insert(data_vec.end(), inst.data(), + inst.data() + inst.size()); + offset_vec.push_back(offset_vec.back() + inst.size()); ret.info.num_nonzero_ += inst.size(); - if (src.info.labels_.size() != 0) { - ret.info.labels_.push_back(src.info.labels_[ridx]); + if (src_labels.size() != 0) { + ret_labels.push_back(src_labels[ridx]); } - if (src.info.weights_.size() != 0) { - ret.info.weights_.push_back(src.info.weights_[ridx]); + if (src_weights.size() != 0) { + ret_weights.push_back(src_weights[ridx]); } - if (src.info.base_margin_.size() != 0) { - ret.info.base_margin_.push_back(src.info.base_margin_[ridx]); + if (src_base_margin.size() != 0) { + ret_base_margin.push_back(src_base_margin[ridx]); } if (src.info.root_index_.size() != 0) { ret.info.root_index_.push_back(src.info.root_index_[ridx]); @@ -771,11 +792,11 @@ XGB_DLL int XGDMatrixGetFloatInfo(const DMatrixHandle handle, const MetaInfo& info = static_cast*>(handle)->get()->Info(); const std::vector* vec = nullptr; if (!std::strcmp(field, "label")) { - vec = &info.labels_; + vec = &info.labels_.HostVector(); } else if (!std::strcmp(field, "weight")) { - vec = &info.weights_; + vec = &info.weights_.HostVector(); } else if (!std::strcmp(field, "base_margin")) { - vec = &info.base_margin_; + vec = &info.base_margin_.HostVector(); } else { LOG(FATAL) << "Unknown float field name " << field; } diff --git a/src/cli_main.cc b/src/cli_main.cc index 1bb39aaf027a..eb27191943dd 100644 --- a/src/cli_main.cc +++ b/src/cli_main.cc @@ -332,7 +332,7 @@ void CLIPredict(const CLIParam& param) { std::unique_ptr fo( dmlc::Stream::Create(param.name_pred.c_str(), "w")); dmlc::ostream os(fo.get()); - for (bst_float p : preds.HostVector()) { + for (bst_float p : preds.ConstHostVector()) { os << std::setprecision(std::numeric_limits::max_digits10 + 2) << p << '\n'; } diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc index 028d0540c5b9..c9b3aa49738e 100644 --- a/src/common/hist_util.cc +++ b/src/common/hist_util.cc @@ -35,6 +35,7 @@ void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) { auto iter = p_fmat->RowIterator(); iter->BeforeFirst(); + const auto& weights = info.weights_.HostVector(); while (iter->Next()) { auto &batch = iter->Value(); #pragma omp parallel num_threads(nthread) @@ -50,7 +51,8 @@ void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) { SparsePage::Inst inst = batch[i]; for (auto& ins : inst) { if (ins.index >= begin && ins.index < end) { - sketchs[ins.index].Push(ins.fvalue, info.GetWeight(ridx)); + sketchs[ins.index].Push(ins.fvalue, + weights.size() > 0 ? 
weights[ridx] : 1.0f); } } } diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index 0beb4fb55776..0676d17ba4e3 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -118,7 +118,7 @@ struct GPUSketcher { void Init(const SparsePage& row_batch, const MetaInfo& info) { num_cols_ = info.num_col_; - has_weights_ = info.weights_.size() > 0; + has_weights_ = info.weights_.Size() > 0; // find the batch size if (param_.gpu_batch_nrows == 0) { @@ -282,19 +282,23 @@ struct GPUSketcher { size_t batch_row_end = std::min((gpu_batch + 1) * gpu_batch_nrows_, static_cast(n_rows_)); size_t batch_nrows = batch_row_end - batch_row_begin; - size_t n_entries = - row_batch.offset[row_begin_ + batch_row_end] - - row_batch.offset[row_begin_ + batch_row_begin]; + + const auto& offset_vec = row_batch.offset.HostVector(); + const auto& data_vec = row_batch.data.HostVector(); + + size_t n_entries = offset_vec[row_begin_ + batch_row_end] - + offset_vec[row_begin_ + batch_row_begin]; // copy the batch to the GPU dh::safe_cuda (cudaMemcpy(entries_.data().get(), - &row_batch.data[row_batch.offset[row_begin_ + batch_row_begin]], + data_vec.data() + offset_vec[row_begin_ + batch_row_begin], n_entries * sizeof(Entry), cudaMemcpyDefault)); // copy the weights if necessary if (has_weights_) { + const auto& weights_vec = info.weights_.HostVector(); dh::safe_cuda (cudaMemcpy(weights_.data().get(), - info.weights_.data() + row_begin_ + batch_row_begin, + weights_vec.data() + row_begin_ + batch_row_begin, batch_nrows * sizeof(bst_float), cudaMemcpyDefault)); } @@ -310,7 +314,7 @@ struct GPUSketcher { row_ptrs_.data().get() + batch_row_begin, has_weights_ ? weights_.data().get() : nullptr, entries_.data().get(), gpu_batch_nrows_, num_cols_, - row_batch.offset[row_begin_ + batch_row_begin], batch_nrows); + offset_vec[row_begin_ + batch_row_begin], batch_nrows); dh::safe_cuda(cudaGetLastError()); // NOLINT dh::safe_cuda(cudaDeviceSynchronize()); // NOLINT @@ -331,13 +335,11 @@ struct GPUSketcher { void Sketch(const SparsePage& row_batch, const MetaInfo& info) { // copy rows to the device dh::safe_cuda(cudaSetDevice(device_)); + const auto& offset_vec = row_batch.offset.HostVector(); row_ptrs_.resize(n_rows_ + 1); - thrust::copy(row_batch.offset.data() + row_begin_, - row_batch.offset.data() + row_end_ + 1, - row_ptrs_.begin()); - + thrust::copy(offset_vec.data() + row_begin_, + offset_vec.data() + row_end_ + 1, row_ptrs_.begin()); size_t gpu_nbatches = dh::DivRoundUp(n_rows_, gpu_batch_nrows_); - for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) { SketchBatch(row_batch, info, gpu_batch); } diff --git a/src/common/host_device_vector.cc b/src/common/host_device_vector.cc index e306119f0ee1..ac7f3860e16b 100644 --- a/src/common/host_device_vector.cc +++ b/src/common/host_device_vector.cc @@ -6,7 +6,8 @@ // dummy implementation of HostDeviceVector in case CUDA is not used #include - +#include +#include #include #include "./host_device_vector.h" @@ -14,25 +15,27 @@ namespace xgboost { template struct HostDeviceVectorImpl { - explicit HostDeviceVectorImpl(size_t size, T v) : data_h_(size, v) {} - HostDeviceVectorImpl(std::initializer_list init) : data_h_(init) {} - explicit HostDeviceVectorImpl(std::vector init) : data_h_(std::move(init)) {} + explicit HostDeviceVectorImpl(size_t size, T v) : data_h_(size, v), distribution_() {} + HostDeviceVectorImpl(std::initializer_list init) : data_h_(init), distribution_() {} + explicit HostDeviceVectorImpl(std::vector init) : data_h_(std::move(init)), 
distribution_() {} std::vector data_h_; + GPUDistribution distribution_; }; template -HostDeviceVector::HostDeviceVector(size_t size, T v, GPUSet devices) : impl_(nullptr) { +HostDeviceVector::HostDeviceVector(size_t size, T v, GPUDistribution distribution) + : impl_(nullptr) { impl_ = new HostDeviceVectorImpl(size, v); } template -HostDeviceVector::HostDeviceVector(std::initializer_list init, GPUSet devices) +HostDeviceVector::HostDeviceVector(std::initializer_list init, GPUDistribution distribution) : impl_(nullptr) { impl_ = new HostDeviceVectorImpl(init); } template -HostDeviceVector::HostDeviceVector(const std::vector& init, GPUSet devices) +HostDeviceVector::HostDeviceVector(const std::vector& init, GPUDistribution distribution) : impl_(nullptr) { impl_ = new HostDeviceVectorImpl(init); } @@ -44,33 +47,69 @@ HostDeviceVector::~HostDeviceVector() { delete tmp; } +template +HostDeviceVector::HostDeviceVector(const HostDeviceVector& other) + : impl_(nullptr) { + impl_ = new HostDeviceVectorImpl(*other.impl_); +} + +template +HostDeviceVector& HostDeviceVector::operator=(const HostDeviceVector& other) { + if (this == &other) { + return *this; + } + delete impl_; + impl_ = new HostDeviceVectorImpl(*other.impl_); + return *this; +} + template size_t HostDeviceVector::Size() const { return impl_->data_h_.size(); } template GPUSet HostDeviceVector::Devices() const { return GPUSet::Empty(); } +template +const GPUDistribution& HostDeviceVector::Distribution() const { + return impl_->distribution_; +} + template T* HostDeviceVector::DevicePointer(int device) { return nullptr; } +template +const T* HostDeviceVector::ConstDevicePointer(int device) const { + return nullptr; +} + template common::Span HostDeviceVector::DeviceSpan(int device) { return common::Span(); } +template +common::Span HostDeviceVector::ConstDeviceSpan(int device) const { + return common::Span(); +} + template std::vector& HostDeviceVector::HostVector() { return impl_->data_h_; } +template +const std::vector& HostDeviceVector::ConstHostVector() const { + return impl_->data_h_; +} + template void HostDeviceVector::Resize(size_t new_size, T v) { impl_->data_h_.resize(new_size, v); } template -size_t HostDeviceVector::DeviceStart(int device) { return 0; } +size_t HostDeviceVector::DeviceStart(int device) const { return 0; } template -size_t HostDeviceVector::DeviceSize(int device) { return 0; } +size_t HostDeviceVector::DeviceSize(int device) const { return 0; } template void HostDeviceVector::Fill(T v) { @@ -78,9 +117,9 @@ void HostDeviceVector::Fill(T v) { } template -void HostDeviceVector::Copy(HostDeviceVector* other) { - CHECK_EQ(Size(), other->Size()); - std::copy(other->HostVector().begin(), other->HostVector().end(), HostVector().begin()); +void HostDeviceVector::Copy(const HostDeviceVector& other) { + CHECK_EQ(Size(), other.Size()); + std::copy(other.HostVector().begin(), other.HostVector().end(), HostVector().begin()); } template @@ -96,13 +135,27 @@ void HostDeviceVector::Copy(std::initializer_list other) { } template -void HostDeviceVector::Reshard(GPUSet devices) { } +bool HostDeviceVector::HostCanAccess(GPUAccess access) const { + return true; +} + +template +bool HostDeviceVector::DeviceCanAccess(int device, GPUAccess access) const { + return false; +} + +template +void HostDeviceVector::Reshard(const GPUDistribution& distribution) const { } + +template +void HostDeviceVector::Reshard(GPUSet devices) const { } // explicit instantiations are required, as HostDeviceVector isn't header-only template class 
HostDeviceVector; template class HostDeviceVector; -template class HostDeviceVector; template class HostDeviceVector; +template class HostDeviceVector; +template class HostDeviceVector; } // namespace xgboost diff --git a/src/common/host_device_vector.cu b/src/common/host_device_vector.cu index d8c7755486d6..04130aacb26c 100644 --- a/src/common/host_device_vector.cu +++ b/src/common/host_device_vector.cu @@ -2,119 +2,159 @@ * Copyright 2017 XGBoost contributors */ - -#include #include "./host_device_vector.h" +#include +#include +#include +#include +#include #include "./device_helpers.cuh" + namespace xgboost { +// the handler to call instead of cudaSetDevice; only used for testing +static void (*cudaSetDeviceHandler)(int) = nullptr; // NOLINT + +void SetCudaSetDeviceHandler(void (*handler)(int)) { + cudaSetDeviceHandler = handler; +} + +// wrapper over access with useful methods +class Permissions { + GPUAccess access_; + explicit Permissions(GPUAccess access) : access_(access) {} + + public: + Permissions() : access_(GPUAccess::kNone) {} + explicit Permissions(bool perm) + : access_(perm ? GPUAccess::kWrite : GPUAccess::kNone) {} + + bool CanRead() const { return access_ >= kRead; } + bool CanWrite() const { return access_ == kWrite; } + bool CanAccess(GPUAccess access) const { return access_ >= access; } + void Grant(GPUAccess access) { access_ = std::max(access_, access); } + void DenyComplementary(GPUAccess compl_access) { + access_ = std::min(access_, GPUAccess::kWrite - compl_access); + } + Permissions Complementary() const { + return Permissions(GPUAccess::kWrite - access_); + } +}; template struct HostDeviceVectorImpl { struct DeviceShard { - DeviceShard() : index_(-1), device_(-1), start_(0), on_d_(false), vec_(nullptr) {} - - static size_t ShardStart(size_t size, int ndevices, int index) { - size_t portion = dh::DivRoundUp(size, ndevices); - size_t begin = index * portion; - begin = begin > size ? size : begin; - return begin; - } - - static size_t ShardSize(size_t size, int ndevices, int index) { - size_t portion = dh::DivRoundUp(size, ndevices); - size_t begin = index * portion, end = (index + 1) * portion; - begin = begin > size ? size : begin; - end = end > size ? 
size : end; - return end - begin; - } + DeviceShard() + : index_(-1), proper_size_(0), device_(-1), start_(0), perm_d_(false), + cached_size_(~0), vec_(nullptr) {} void Init(HostDeviceVectorImpl* vec, int device) { if (vec_ == nullptr) { vec_ = vec; } CHECK_EQ(vec, vec_); device_ = device; - index_ = vec_->devices_.Index(device); - size_t size_h = vec_->Size(); - int ndevices = vec_->devices_.Size(); - start_ = ShardStart(size_h, ndevices, index_); - size_t size_d = ShardSize(size_h, ndevices, index_); - dh::safe_cuda(cudaSetDevice(device_)); - data_.resize(size_d); - on_d_ = !vec_->on_h_; + index_ = vec_->distribution_.devices_.Index(device); + LazyResize(vec_->Size()); + perm_d_ = vec_->perm_h_.Complementary(); } void ScatterFrom(const T* begin) { // TODO(canonizer): avoid full copy of host data - LazySyncDevice(); - dh::safe_cuda(cudaSetDevice(device_)); + LazySyncDevice(GPUAccess::kWrite); + SetDevice(); dh::safe_cuda(cudaMemcpy(data_.data().get(), begin + start_, data_.size() * sizeof(T), cudaMemcpyDefault)); } void GatherTo(thrust::device_ptr begin) { - LazySyncDevice(); - dh::safe_cuda(cudaSetDevice(device_)); + LazySyncDevice(GPUAccess::kRead); + SetDevice(); dh::safe_cuda(cudaMemcpy(begin.get() + start_, data_.data().get(), - data_.size() * sizeof(T), cudaMemcpyDefault)); + proper_size_ * sizeof(T), cudaMemcpyDefault)); } void Fill(T v) { // TODO(canonizer): avoid full copy of host data - LazySyncDevice(); - dh::safe_cuda(cudaSetDevice(device_)); + LazySyncDevice(GPUAccess::kWrite); + SetDevice(); thrust::fill(data_.begin(), data_.end(), v); } void Copy(DeviceShard* other) { // TODO(canonizer): avoid full copy of host data for this (but not for other) - LazySyncDevice(); - other->LazySyncDevice(); - dh::safe_cuda(cudaSetDevice(device_)); + LazySyncDevice(GPUAccess::kWrite); + other->LazySyncDevice(GPUAccess::kRead); + SetDevice(); dh::safe_cuda(cudaMemcpy(data_.data().get(), other->data_.data().get(), data_.size() * sizeof(T), cudaMemcpyDefault)); } - void LazySyncHost() { - dh::safe_cuda(cudaSetDevice(device_)); + void LazySyncHost(GPUAccess access) { + SetDevice(); dh::safe_cuda(cudaMemcpy(vec_->data_h_.data() + start_, - data_.data().get(), data_.size() * sizeof(T), + data_.data().get(), proper_size_ * sizeof(T), cudaMemcpyDeviceToHost)); - on_d_ = false; + perm_d_.DenyComplementary(access); } - void LazySyncDevice() { - if (on_d_) { return; } + void LazyResize(size_t new_size) { + if (new_size == cached_size_) { return; } + // resize is required + int ndevices = vec_->distribution_.devices_.Size(); + start_ = vec_->distribution_.ShardStart(new_size, index_); + proper_size_ = vec_->distribution_.ShardProperSize(new_size, index_); + size_t size_d = vec_->distribution_.ShardSize(new_size, index_); + SetDevice(); + data_.resize(size_d); + cached_size_ = new_size; + } + + void LazySyncDevice(GPUAccess access) { + if (perm_d_.CanAccess(access)) { return; } + if (perm_d_.CanRead()) { + // deny read to the host + perm_d_.Grant(access); + std::lock_guard lock(vec_->mutex_); + vec_->perm_h_.DenyComplementary(access); + return; + } // data is on the host size_t size_h = vec_->data_h_.size(); - int ndevices = vec_->devices_.Size(); - start_ = ShardStart(size_h, ndevices, index_); - size_t size_d = ShardSize(size_h, ndevices, index_); - dh::safe_cuda(cudaSetDevice(device_)); - data_.resize(size_d); - dh::safe_cuda(cudaMemcpy(data_.data().get(), - vec_->data_h_.data() + start_, - size_d * sizeof(T), cudaMemcpyHostToDevice)); - on_d_ = true; - // this may cause a race condition if 
LazySyncDevice() is called - // from multiple threads in parallel; - // however, the race condition is benign, and will not cause problems - vec_->on_h_ = false; - vec_->size_d_ = vec_->data_h_.size(); + LazyResize(size_h); + SetDevice(); + dh::safe_cuda( + cudaMemcpy(data_.data().get(), vec_->data_h_.data() + start_, + data_.size() * sizeof(T), cudaMemcpyHostToDevice)); + perm_d_.Grant(access); + + std::lock_guard lock(vec_->mutex_); + vec_->perm_h_.DenyComplementary(access); + vec_->size_d_ = size_h; + } + + void SetDevice() { + if (cudaSetDeviceHandler == nullptr) { + dh::safe_cuda(cudaSetDevice(device_)); + } else { + (*cudaSetDeviceHandler)(device_); + } } int index_; int device_; thrust::device_vector data_; + // cached vector size + size_t cached_size_; size_t start_; - // true if there is an up-to-date copy of data on device, false otherwise - bool on_d_; + // size of the portion to copy back to the host + size_t proper_size_; + Permissions perm_d_; HostDeviceVectorImpl* vec_; }; - HostDeviceVectorImpl(size_t size, T v, GPUSet devices) - : devices_(devices), on_h_(devices.IsEmpty()), size_d_(0) { - if (!devices.IsEmpty()) { + HostDeviceVectorImpl(size_t size, T v, GPUDistribution distribution) + : distribution_(distribution), perm_h_(distribution.IsEmpty()), size_d_(0) { + if (!distribution_.IsEmpty()) { size_d_ = size; InitShards(); Fill(v); @@ -123,11 +163,16 @@ struct HostDeviceVectorImpl { } } + // required, as a new std::mutex has to be created + HostDeviceVectorImpl(const HostDeviceVectorImpl& other) + : data_h_(other.data_h_), perm_h_(other.perm_h_), size_d_(other.size_d_), + distribution_(other.distribution_), mutex_(), shards_(other.shards_) {} + // Init can be std::vector or std::initializer_list template - HostDeviceVectorImpl(const Init& init, GPUSet devices) - : devices_(devices), on_h_(devices.IsEmpty()), size_d_(0) { - if (!devices.IsEmpty()) { + HostDeviceVectorImpl(const Init& init, GPUDistribution distribution) + : distribution_(distribution), perm_h_(distribution.IsEmpty()), size_d_(0) { + if (!distribution_.IsEmpty()) { size_d_ = init.size(); InitShards(); Copy(init); @@ -137,58 +182,78 @@ struct HostDeviceVectorImpl { } void InitShards() { - int ndevices = devices_.Size(); + int ndevices = distribution_.devices_.Size(); shards_.resize(ndevices); dh::ExecuteIndexShards(&shards_, [&](int i, DeviceShard& shard) { - shard.Init(this, devices_[i]); + shard.Init(this, distribution_.devices_[i]); }); } - HostDeviceVectorImpl(const HostDeviceVectorImpl&) = delete; - HostDeviceVectorImpl(HostDeviceVectorImpl&&) = delete; - void operator=(const HostDeviceVectorImpl&) = delete; - void operator=(HostDeviceVectorImpl&&) = delete; + size_t Size() const { return perm_h_.CanRead() ? data_h_.size() : size_d_; } - size_t Size() const { return on_h_ ? 
data_h_.size() : size_d_; } + GPUSet Devices() const { return distribution_.devices_; } - GPUSet Devices() const { return devices_; } + const GPUDistribution& Distribution() const { return distribution_; } T* DevicePointer(int device) { - CHECK(devices_.Contains(device)); - LazySyncDevice(device); - return shards_[devices_.Index(device)].data_.data().get(); + CHECK(distribution_.devices_.Contains(device)); + LazySyncDevice(device, GPUAccess::kWrite); + return shards_[distribution_.devices_.Index(device)].data_.data().get(); + } + + const T* ConstDevicePointer(int device) { + CHECK(distribution_.devices_.Contains(device)); + LazySyncDevice(device, GPUAccess::kRead); + return shards_[distribution_.devices_.Index(device)].data_.data().get(); } common::Span DeviceSpan(int device) { - CHECK(devices_.Contains(device)); - LazySyncDevice(device); - return { shards_[devices_.Index(device)].data_.data().get(), - static_cast::index_type>(Size()) }; + GPUSet devices = distribution_.devices_; + CHECK(devices.Contains(device)); + LazySyncDevice(device, GPUAccess::kWrite); + return {shards_[devices.Index(device)].data_.data().get(), + static_cast::index_type>(Size())}; + } + + common::Span ConstDeviceSpan(int device) { + GPUSet devices = distribution_.devices_; + CHECK(devices.Contains(device)); + LazySyncDevice(device, GPUAccess::kRead); + return {shards_[devices.Index(device)].data_.data().get(), + static_cast::index_type>(Size())}; } size_t DeviceSize(int device) { - CHECK(devices_.Contains(device)); - LazySyncDevice(device); - return shards_[devices_.Index(device)].data_.size(); + CHECK(distribution_.devices_.Contains(device)); + LazySyncDevice(device, GPUAccess::kRead); + return shards_[distribution_.devices_.Index(device)].data_.size(); } size_t DeviceStart(int device) { - CHECK(devices_.Contains(device)); - LazySyncDevice(device); - return shards_[devices_.Index(device)].start_; + CHECK(distribution_.devices_.Contains(device)); + LazySyncDevice(device, GPUAccess::kRead); + return shards_[distribution_.devices_.Index(device)].start_; } thrust::device_ptr tbegin(int device) { // NOLINT return thrust::device_ptr(DevicePointer(device)); } + thrust::device_ptr tcbegin(int device) { // NOLINT + return thrust::device_ptr(ConstDevicePointer(device)); + } + thrust::device_ptr tend(int device) { // NOLINT return tbegin(device) + DeviceSize(device); } - void ScatterFrom(thrust::device_ptr begin, thrust::device_ptr end) { + thrust::device_ptr tcend(int device) { // NOLINT + return tcbegin(device) + DeviceSize(device); + } + + void ScatterFrom(thrust::device_ptr begin, thrust::device_ptr end) { CHECK_EQ(end - begin, Size()); - if (on_h_) { + if (perm_h_.CanWrite()) { dh::safe_cuda(cudaMemcpy(data_h_.data(), begin.get(), (end - begin) * sizeof(T), cudaMemcpyDeviceToHost)); @@ -201,7 +266,7 @@ struct HostDeviceVectorImpl { void GatherTo(thrust::device_ptr begin, thrust::device_ptr end) { CHECK_EQ(end - begin, Size()); - if (on_h_) { + if (perm_h_.CanWrite()) { dh::safe_cuda(cudaMemcpy(begin.get(), data_h_.data(), data_h_.size() * sizeof(T), cudaMemcpyHostToDevice)); @@ -211,7 +276,7 @@ struct HostDeviceVectorImpl { } void Fill(T v) { - if (on_h_) { + if (perm_h_.CanWrite()) { std::fill(data_h_.begin(), data_h_.end(), v); } else { dh::ExecuteShards(&shards_, [&](DeviceShard& shard) { shard.Fill(v); }); @@ -220,10 +285,10 @@ struct HostDeviceVectorImpl { void Copy(HostDeviceVectorImpl* other) { CHECK_EQ(Size(), other->Size()); - if (on_h_ && other->on_h_) { + if (perm_h_.CanWrite() && other->perm_h_.CanWrite()) { 
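      // both vectors currently have a writable host copy: copy element-wise on the host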
std::copy(other->data_h_.begin(), other->data_h_.end(), data_h_.begin()); } else { - CHECK(devices_ == other->devices_); + CHECK(distribution_ == other->distribution_); dh::ExecuteIndexShards(&shards_, [&](int i, DeviceShard& shard) { shard.Copy(&other->shards_[i]); }); @@ -232,7 +297,7 @@ struct HostDeviceVectorImpl { void Copy(const std::vector& other) { CHECK_EQ(Size(), other.size()); - if (on_h_) { + if (perm_h_.CanWrite()) { std::copy(other.begin(), other.end(), data_h_.begin()); } else { dh::ExecuteShards(&shards_, [&](DeviceShard& shard) { @@ -243,7 +308,7 @@ struct HostDeviceVectorImpl { void Copy(std::initializer_list other) { CHECK_EQ(Size(), other.size()); - if (on_h_) { + if (perm_h_.CanWrite()) { std::copy(other.begin(), other.end(), data_h_.begin()); } else { dh::ExecuteShards(&shards_, [&](DeviceShard& shard) { @@ -253,72 +318,117 @@ struct HostDeviceVectorImpl { } std::vector& HostVector() { - LazySyncHost(); + LazySyncHost(GPUAccess::kWrite); return data_h_; } - void Reshard(GPUSet new_devices) { - if (devices_ == new_devices) - return; - CHECK(devices_.IsEmpty()); - devices_ = new_devices; + const std::vector& ConstHostVector() { + LazySyncHost(GPUAccess::kRead); + return data_h_; + } + + void Reshard(const GPUDistribution& distribution) { + if (distribution_ == distribution) { return; } + CHECK(distribution_.IsEmpty()); + distribution_ = distribution; InitShards(); } + void Reshard(GPUSet new_devices) { + if (distribution_.Devices() == new_devices) { return; } + Reshard(GPUDistribution::Block(new_devices)); + } + void Resize(size_t new_size, T v) { - if (new_size == Size()) - return; - if (Size() == 0 && !devices_.IsEmpty()) { + if (new_size == Size()) { return; } + if (distribution_.IsFixedSize()) { + CHECK_EQ(new_size, distribution_.offsets_.back()); + } + if (Size() == 0 && !distribution_.IsEmpty()) { // fast on-device resize - on_h_ = false; + perm_h_ = Permissions(false); size_d_ = new_size; InitShards(); Fill(v); } else { // resize on host - LazySyncHost(); + LazySyncHost(GPUAccess::kWrite); data_h_.resize(new_size, v); } } - void LazySyncHost() { - if (on_h_) + void LazySyncHost(GPUAccess access) { + if (perm_h_.CanAccess(access)) { return; } + if (perm_h_.CanRead()) { + // data is present, just need to deny access to the device + dh::ExecuteShards(&shards_, [&](DeviceShard& shard) { + shard.perm_d_.DenyComplementary(access); + }); + perm_h_.Grant(access); return; - if (data_h_.size() != size_d_) - data_h_.resize(size_d_); - dh::ExecuteShards(&shards_, [&](DeviceShard& shard) { shard.LazySyncHost(); }); - on_h_ = true; + } + if (data_h_.size() != size_d_) { data_h_.resize(size_d_); } + dh::ExecuteShards(&shards_, [&](DeviceShard& shard) { + shard.LazySyncHost(access); + }); + perm_h_.Grant(access); + } + + void LazySyncDevice(int device, GPUAccess access) { + GPUSet devices = distribution_.Devices(); + CHECK(devices.Contains(device)); + shards_[devices.Index(device)].LazySyncDevice(access); } - void LazySyncDevice(int device) { - CHECK(devices_.Contains(device)); - shards_[devices_.Index(device)].LazySyncDevice(); + bool HostCanAccess(GPUAccess access) { return perm_h_.CanAccess(access); } + + bool DeviceCanAccess(int device, GPUAccess access) { + GPUSet devices = distribution_.Devices(); + if (!devices.Contains(device)) { return false; } + return shards_[devices.Index(device)].perm_d_.CanAccess(access); } std::vector data_h_; - bool on_h_; + Permissions perm_h_; // the total size of the data stored on the devices size_t size_d_; - GPUSet devices_; + 
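  // Illustration of the permission flow implemented by LazySyncHost()/LazySyncDevice()
  // above (added for exposition, not part of the patch; assumes a single GPU and that
  // GPUSet::Range(0, 1) denotes device 0):
  //
  //   HostDeviceVector<int> v(16, 1);            // data starts on the host
  //   v.Reshard(GPUSet::Range(0, 1));
  //   v.ConstDevicePointer(0);                   // read-only sync to device 0
  //   v.HostCanAccess(GPUAccess::kRead);         // true  - host keeps read access
  //   v.DeviceCanAccess(0, GPUAccess::kRead);    // true  - no further copies needed
  //   v.HostVector()[0] = 42;                    // write access on the host ...
  //   v.DeviceCanAccess(0, GPUAccess::kRead);    // false - device copy is stale again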
GPUDistribution distribution_; + // protects size_d_ and perm_h_ when updated from multiple threads + std::mutex mutex_; std::vector shards_; }; template -HostDeviceVector::HostDeviceVector(size_t size, T v, GPUSet devices) - : impl_(nullptr) { - impl_ = new HostDeviceVectorImpl(size, v, devices); +HostDeviceVector::HostDeviceVector +(size_t size, T v, GPUDistribution distribution) : impl_(nullptr) { + impl_ = new HostDeviceVectorImpl(size, v, distribution); } template -HostDeviceVector::HostDeviceVector(std::initializer_list init, GPUSet devices) - : impl_(nullptr) { - impl_ = new HostDeviceVectorImpl(init, devices); +HostDeviceVector::HostDeviceVector +(std::initializer_list init, GPUDistribution distribution) : impl_(nullptr) { + impl_ = new HostDeviceVectorImpl(init, distribution); } template -HostDeviceVector::HostDeviceVector(const std::vector& init, GPUSet devices) +HostDeviceVector::HostDeviceVector +(const std::vector& init, GPUDistribution distribution) : impl_(nullptr) { + impl_ = new HostDeviceVectorImpl(init, distribution); +} + +template +HostDeviceVector::HostDeviceVector(const HostDeviceVector& other) : impl_(nullptr) { - impl_ = new HostDeviceVectorImpl(init, devices); + impl_ = new HostDeviceVectorImpl(*other.impl_); +} + +template +HostDeviceVector& HostDeviceVector::operator= +(const HostDeviceVector& other) { + if (this == &other) { return *this; } + delete impl_; + impl_ = new HostDeviceVectorImpl(*other.impl_); + return *this; } template @@ -335,7 +445,19 @@ template GPUSet HostDeviceVector::Devices() const { return impl_->Devices(); } template -T* HostDeviceVector::DevicePointer(int device) { return impl_->DevicePointer(device); } +const GPUDistribution& HostDeviceVector::Distribution() const { + return impl_->Distribution(); +} + +template +T* HostDeviceVector::DevicePointer(int device) { + return impl_->DevicePointer(device); +} + +template +const T* HostDeviceVector::ConstDevicePointer(int device) const { + return impl_->ConstDevicePointer(device); +} template common::Span HostDeviceVector::DeviceSpan(int device) { @@ -343,30 +465,49 @@ common::Span HostDeviceVector::DeviceSpan(int device) { } template -size_t HostDeviceVector::DeviceStart(int device) { return impl_->DeviceStart(device); } +common::Span HostDeviceVector::ConstDeviceSpan(int device) const { + return impl_->ConstDeviceSpan(device); +} template -size_t HostDeviceVector::DeviceSize(int device) { return impl_->DeviceSize(device); } +size_t HostDeviceVector::DeviceStart(int device) const { + return impl_->DeviceStart(device); +} + +template +size_t HostDeviceVector::DeviceSize(int device) const { + return impl_->DeviceSize(device); +} template thrust::device_ptr HostDeviceVector::tbegin(int device) { // NOLINT return impl_->tbegin(device); } +template +thrust::device_ptr HostDeviceVector::tcbegin(int device) const { // NOLINT + return impl_->tcbegin(device); +} + template thrust::device_ptr HostDeviceVector::tend(int device) { // NOLINT return impl_->tend(device); } +template +thrust::device_ptr HostDeviceVector::tcend(int device) const { // NOLINT + return impl_->tcend(device); +} + template void HostDeviceVector::ScatterFrom -(thrust::device_ptr begin, thrust::device_ptr end) { +(thrust::device_ptr begin, thrust::device_ptr end) { impl_->ScatterFrom(begin, end); } template void HostDeviceVector::GatherTo -(thrust::device_ptr begin, thrust::device_ptr end) { +(thrust::device_ptr begin, thrust::device_ptr end) const { impl_->GatherTo(begin, end); } @@ -376,8 +517,8 @@ void HostDeviceVector::Fill(T v) { 
} template -void HostDeviceVector::Copy(HostDeviceVector* other) { - impl_->Copy(other->impl_); +void HostDeviceVector::Copy(const HostDeviceVector& other) { + impl_->Copy(other.impl_); } template @@ -394,10 +535,30 @@ template std::vector& HostDeviceVector::HostVector() { return impl_->HostVector(); } template -void HostDeviceVector::Reshard(GPUSet new_devices) { +const std::vector& HostDeviceVector::ConstHostVector() const { + return impl_->ConstHostVector(); +} + +template +bool HostDeviceVector::HostCanAccess(GPUAccess access) const { + return impl_->HostCanAccess(access); +} + +template +bool HostDeviceVector::DeviceCanAccess(int device, GPUAccess access) const { + return impl_->DeviceCanAccess(device, access); +} + +template +void HostDeviceVector::Reshard(GPUSet new_devices) const { impl_->Reshard(new_devices); } +template +void HostDeviceVector::Reshard(const GPUDistribution& distribution) const { + impl_->Reshard(distribution); +} + template void HostDeviceVector::Resize(size_t new_size, T v) { impl_->Resize(new_size, v); @@ -406,7 +567,8 @@ void HostDeviceVector::Resize(size_t new_size, T v) { // explicit instantiations are required, as HostDeviceVector isn't header-only template class HostDeviceVector; template class HostDeviceVector; -template class HostDeviceVector; template class HostDeviceVector; +template class HostDeviceVector; +template class HostDeviceVector; } // namespace xgboost diff --git a/src/common/host_device_vector.h b/src/common/host_device_vector.h index 3bab1009541d..a3ef3082b43e 100644 --- a/src/common/host_device_vector.h +++ b/src/common/host_device_vector.h @@ -1,28 +1,6 @@ /*! * Copyright 2017 XGBoost contributors */ -#ifndef XGBOOST_COMMON_HOST_DEVICE_VECTOR_H_ -#define XGBOOST_COMMON_HOST_DEVICE_VECTOR_H_ - -#include - -#include -#include -#include -#include - -#include "gpu_set.h" -#include "span.h" - -// only include thrust-related files if host_device_vector.h -// is included from a .cu file -#ifdef __CUDACC__ -#include -#endif - -namespace xgboost { - -template struct HostDeviceVectorImpl; /** * @file host_device_vector.h @@ -70,44 +48,203 @@ template struct HostDeviceVectorImpl; * if different threads call these methods with different values of the device argument. * All other methods are not thread safe. */ + +#ifndef XGBOOST_COMMON_HOST_DEVICE_VECTOR_H_ +#define XGBOOST_COMMON_HOST_DEVICE_VECTOR_H_ + +#include + +#include +#include +#include +#include + +#include "gpu_set.h" +#include "span.h" + +// only include thrust-related files if host_device_vector.h +// is included from a .cu file +#ifdef __CUDACC__ +#include +#endif + +namespace xgboost { + +#ifdef __CUDACC__ +// Sets a function to call instead of cudaSetDevice(); +// only added for testing +void SetCudaSetDeviceHandler(void (*handler)(int)); +#endif + +template struct HostDeviceVectorImpl; + +// Distribution for the HostDeviceVector; it specifies such aspects as the devices it is +// distributed on, whether there are copies of elements from other GPUs as well as the granularity +// of splitting. It may also specify explicit boundaries for devices, in which case the size of the +// array cannot be changed. 
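// For illustration (added here, not in the original patch), the factory methods of the
// class below are meant to be used roughly as follows:
//   GPUDistribution::Block(devices)              - contiguous, roughly equal shards
//   GPUDistribution::Overlap(devices, 1)         - adjacent shards share one boundary element
//   GPUDistribution::Granular(devices, 32)       - shard sizes rounded to a multiple of 32
//   GPUDistribution::Explicit(devices, offsets)  - caller-supplied shard boundaries;
//                                                  the vector size then becomes fixed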
+class GPUDistribution { + template friend struct HostDeviceVectorImpl; + + public: + explicit GPUDistribution(GPUSet devices = GPUSet::Empty()) + : devices_(devices), granularity_(1), overlap_(0) {} + + private: + GPUDistribution(GPUSet devices, int granularity, int overlap, + std::vector offsets) + : devices_(devices), granularity_(granularity), overlap_(overlap), + offsets_(std::move(offsets)) {} + + public: + static GPUDistribution Block(GPUSet devices) { return GPUDistribution(devices); } + + static GPUDistribution Overlap(GPUSet devices, int overlap) { + return GPUDistribution(devices, 1, overlap, std::vector()); + } + + static GPUDistribution Granular(GPUSet devices, int granularity) { + return GPUDistribution(devices, granularity, 0, std::vector()); + } + + static GPUDistribution Explicit(GPUSet devices, std::vector offsets) { + return GPUDistribution(devices, 1, 0, offsets); + } + + friend bool operator==(const GPUDistribution& a, const GPUDistribution& b) { + return a.devices_ == b.devices_ && a.granularity_ == b.granularity_ && + a.overlap_ == b.overlap_ && a.offsets_ == b.offsets_; + } + + friend bool operator!=(const GPUDistribution& a, const GPUDistribution& b) { + return !(a == b); + } + + GPUSet Devices() const { return devices_; } + + bool IsEmpty() const { return devices_.IsEmpty(); } + + size_t ShardStart(size_t size, int index) const { + if (size == 0) { return 0; } + if (offsets_.size() > 0) { + // explicit offsets are provided + CHECK_EQ(offsets_.back(), size); + return offsets_.at(index); + } + // no explicit offsets + size_t begin = std::min(index * Portion(size), size); + begin = begin > size ? size : begin; + return begin; + } + + size_t ShardSize(size_t size, int index) const { + if (size == 0) { return 0; } + if (offsets_.size() > 0) { + // explicit offsets are provided + CHECK_EQ(offsets_.back(), size); + return offsets_.at(index + 1) - offsets_.at(index) + + (index == devices_.Size() - 1 ? overlap_ : 0); + } + size_t portion = Portion(size); + size_t begin = std::min(index * portion, size); + size_t end = std::min((index + 1) * portion + overlap_ * granularity_, size); + return end - begin; + } + + size_t ShardProperSize(size_t size, int index) const { + if (size == 0) { return 0; } + return ShardSize(size, index) - (devices_.Size() - 1 > index ? 
overlap_ : 0); + } + + bool IsFixedSize() const { return !offsets_.empty(); } + + private: + static size_t DivRoundUp(size_t a, size_t b) { return (a + b - 1) / b; } + static size_t RoundUp(size_t a, size_t b) { return DivRoundUp(a, b) * b; } + + size_t Portion(size_t size) const { + return RoundUp + (DivRoundUp + (std::max(static_cast(size - overlap_ * granularity_), + static_cast(1)), + devices_.Size()), granularity_); + } + + GPUSet devices_; + int granularity_; + int overlap_; + // explicit offsets for the GPU parts, if any + std::vector offsets_; +}; + +enum GPUAccess { + kNone, kRead, + // write implies read + kWrite +}; + +inline GPUAccess operator-(GPUAccess a, GPUAccess b) { + return static_cast(static_cast(a) - static_cast(b)); +} + template class HostDeviceVector { public: explicit HostDeviceVector(size_t size = 0, T v = T(), - GPUSet devices = GPUSet::Empty()); - HostDeviceVector(std::initializer_list init, GPUSet devices = GPUSet::Empty()); + GPUDistribution distribution = GPUDistribution()); + HostDeviceVector(std::initializer_list init, + GPUDistribution distribution = GPUDistribution()); explicit HostDeviceVector(const std::vector& init, - GPUSet devices = GPUSet::Empty()); + GPUDistribution distribution = GPUDistribution()); ~HostDeviceVector(); - HostDeviceVector(const HostDeviceVector&) = delete; - HostDeviceVector(HostDeviceVector&&) = delete; - void operator=(const HostDeviceVector&) = delete; - void operator=(HostDeviceVector&&) = delete; + HostDeviceVector(const HostDeviceVector&); + HostDeviceVector& operator=(const HostDeviceVector&); size_t Size() const; GPUSet Devices() const; - T* DevicePointer(int device); + const GPUDistribution& Distribution() const; common::Span DeviceSpan(int device); + common::Span ConstDeviceSpan(int device) const; + common::Span DeviceSpan(int device) const { return ConstDeviceSpan(device); } + T* DevicePointer(int device); + const T* ConstDevicePointer(int device) const; + const T* DevicePointer(int device) const { return ConstDevicePointer(device); } T* HostPointer() { return HostVector().data(); } - size_t DeviceStart(int device); - size_t DeviceSize(int device); + const T* ConstHostPointer() const { return ConstHostVector().data(); } + const T* HostPointer() const { return ConstHostPointer(); } + + size_t DeviceStart(int device) const; + size_t DeviceSize(int device) const; // only define functions returning device_ptr // if HostDeviceVector.h is included from a .cu file #ifdef __CUDACC__ thrust::device_ptr tbegin(int device); // NOLINT thrust::device_ptr tend(int device); // NOLINT - void ScatterFrom(thrust::device_ptr begin, thrust::device_ptr end); - void GatherTo(thrust::device_ptr begin, thrust::device_ptr end); + thrust::device_ptr tcbegin(int device) const; // NOLINT + thrust::device_ptr tcend(int device) const; // NOLINT + thrust::device_ptr tbegin(int device) const { // NOLINT + return tcbegin(device); + } + thrust::device_ptr tend(int device) const { return tcend(device); } // NOLINT + + void ScatterFrom(thrust::device_ptr begin, thrust::device_ptr end); + void GatherTo(thrust::device_ptr begin, thrust::device_ptr end) const; #endif void Fill(T v); - void Copy(HostDeviceVector* other); + void Copy(const HostDeviceVector& other); void Copy(const std::vector& other); void Copy(std::initializer_list other); std::vector& HostVector(); - void Reshard(GPUSet devices); + const std::vector& ConstHostVector() const; + const std::vector& HostVector() const {return ConstHostVector(); } + + bool HostCanAccess(GPUAccess access) const; + 
bool DeviceCanAccess(int device, GPUAccess access) const; + + void Reshard(const GPUDistribution& distribution) const; + void Reshard(GPUSet devices) const; void Resize(size_t new_size, T v = T()); private: diff --git a/src/data/data.cc b/src/data/data.cc index a8bf41c46655..4b24d5da88b5 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -25,12 +25,12 @@ namespace xgboost { // implementation of inline functions void MetaInfo::Clear() { num_row_ = num_col_ = num_nonzero_ = 0; - labels_.clear(); + labels_.HostVector().clear(); root_index_.clear(); group_ptr_.clear(); qids_.clear(); - weights_.clear(); - base_margin_.clear(); + weights_.HostVector().clear(); + base_margin_.HostVector().clear(); } void MetaInfo::SaveBinary(dmlc::Stream *fo) const { @@ -39,12 +39,12 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const { fo->Write(&num_row_, sizeof(num_row_)); fo->Write(&num_col_, sizeof(num_col_)); fo->Write(&num_nonzero_, sizeof(num_nonzero_)); - fo->Write(labels_); + fo->Write(labels_.HostVector()); fo->Write(group_ptr_); fo->Write(qids_); - fo->Write(weights_); + fo->Write(weights_.HostVector()); fo->Write(root_index_); - fo->Write(base_margin_); + fo->Write(base_margin_.HostVector()); } void MetaInfo::LoadBinary(dmlc::Stream *fi) { @@ -55,16 +55,16 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) { CHECK(fi->Read(&num_col_, sizeof(num_col_)) == sizeof(num_col_)) << "MetaInfo: invalid format"; CHECK(fi->Read(&num_nonzero_, sizeof(num_nonzero_)) == sizeof(num_nonzero_)) << "MetaInfo: invalid format"; - CHECK(fi->Read(&labels_)) << "MetaInfo: invalid format"; + CHECK(fi->Read(&labels_.HostVector())) << "MetaInfo: invalid format"; CHECK(fi->Read(&group_ptr_)) << "MetaInfo: invalid format"; if (version >= kVersionQidAdded) { CHECK(fi->Read(&qids_)) << "MetaInfo: invalid format"; } else { // old format doesn't contain qid field qids_.clear(); } - CHECK(fi->Read(&weights_)) << "MetaInfo: invalid format"; + CHECK(fi->Read(&weights_.HostVector())) << "MetaInfo: invalid format"; CHECK(fi->Read(&root_index_)) << "MetaInfo: invalid format"; - CHECK(fi->Read(&base_margin_)) << "MetaInfo: invalid format"; + CHECK(fi->Read(&base_margin_.HostVector())) << "MetaInfo: invalid format"; } // try to load group information from file, if exists @@ -121,17 +121,20 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, std::copy(cast_dptr, cast_dptr + num, root_index_.begin())); } else if (!std::strcmp(key, "label")) { - labels_.resize(num); + auto& labels = labels_.HostVector(); + labels.resize(num); DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, - std::copy(cast_dptr, cast_dptr + num, labels_.begin())); + std::copy(cast_dptr, cast_dptr + num, labels.begin())); } else if (!std::strcmp(key, "weight")) { - weights_.resize(num); + auto& weights = weights_.HostVector(); + weights.resize(num); DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, - std::copy(cast_dptr, cast_dptr + num, weights_.begin())); + std::copy(cast_dptr, cast_dptr + num, weights.begin())); } else if (!std::strcmp(key, "base_margin")) { - base_margin_.resize(num); + auto& base_margin = base_margin_.HostVector(); + base_margin.resize(num); DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, - std::copy(cast_dptr, cast_dptr + num, base_margin_.begin())); + std::copy(cast_dptr, cast_dptr + num, base_margin.begin())); } else if (!std::strcmp(key, "group")) { group_ptr_.resize(num + 1); DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, @@ -230,12 +233,14 @@ DMatrix* DMatrix::Load(const std::string& uri, 
LOG(CONSOLE) << info.group_ptr_.size() - 1 << " groups are loaded from " << fname << ".group"; } - if (MetaTryLoadFloatInfo(fname + ".base_margin", &info.base_margin_) && !silent) { - LOG(CONSOLE) << info.base_margin_.size() + if (MetaTryLoadFloatInfo + (fname + ".base_margin", &info.base_margin_.HostVector()) && !silent) { + LOG(CONSOLE) << info.base_margin_.Size() << " base_margin are loaded from " << fname << ".base_margin"; } - if (MetaTryLoadFloatInfo(fname + ".weight", &info.weights_) && !silent) { - LOG(CONSOLE) << info.weights_.size() + if (MetaTryLoadFloatInfo + (fname + ".weight", &info.weights_.HostVector()) && !silent) { + LOG(CONSOLE) << info.weights_.Size() << " weights are loaded from " << fname << ".weight"; } } diff --git a/src/data/simple_csr_source.cc b/src/data/simple_csr_source.cc index 816404fad676..c13d96d678dc 100644 --- a/src/data/simple_csr_source.cc +++ b/src/data/simple_csr_source.cc @@ -35,10 +35,12 @@ void SimpleCSRSource::CopyFrom(dmlc::Parser* parser) { while (parser->Next()) { const dmlc::RowBlock& batch = parser->Value(); if (batch.label != nullptr) { - info.labels_.insert(info.labels_.end(), batch.label, batch.label + batch.size); + auto& labels = info.labels_.HostVector(); + labels.insert(labels.end(), batch.label, batch.label + batch.size); } if (batch.weight != nullptr) { - info.weights_.insert(info.weights_.end(), batch.weight, batch.weight + batch.size); + auto& weights = info.weights_.HostVector(); + weights.insert(weights.end(), batch.weight, batch.weight + batch.size); } if (batch.qid != nullptr) { info.qids_.insert(info.qids_.end(), batch.qid, batch.qid + batch.size); @@ -62,16 +64,18 @@ void SimpleCSRSource::CopyFrom(dmlc::Parser* parser) { // update information this->info.num_row_ += batch.size; // copy the data over + auto& data_vec = page_.data.HostVector(); + auto& offset_vec = page_.offset.HostVector(); for (size_t i = batch.offset[0]; i < batch.offset[batch.size]; ++i) { uint32_t index = batch.index[i]; bst_float fvalue = batch.value == nullptr ? 
1.0f : batch.value[i]; - page_.data.emplace_back(index, fvalue); + data_vec.emplace_back(index, fvalue); this->info.num_col_ = std::max(this->info.num_col_, static_cast(index + 1)); } - size_t top = page_.offset.size(); + size_t top = page_.offset.Size(); for (size_t i = 0; i < batch.size; ++i) { - page_.offset.push_back(page_.offset[top - 1] + batch.offset[i + 1] - batch.offset[0]); + offset_vec.push_back(offset_vec[top - 1] + batch.offset[i + 1] - batch.offset[0]); } } if (last_group_id != default_max) { @@ -79,7 +83,7 @@ void SimpleCSRSource::CopyFrom(dmlc::Parser* parser) { info.group_ptr_.push_back(group_size); } } - this->info.num_nonzero_ = static_cast(page_.data.size()); + this->info.num_nonzero_ = static_cast(page_.data.Size()); // Either every row has query ID or none at all CHECK(info.qids_.empty() || info.qids_.size() == info.num_row_); } @@ -89,16 +93,16 @@ void SimpleCSRSource::LoadBinary(dmlc::Stream* fi) { CHECK(fi->Read(&tmagic, sizeof(tmagic)) == sizeof(tmagic)) << "invalid input file format"; CHECK_EQ(tmagic, kMagic) << "invalid format, magic number mismatch"; info.LoadBinary(fi); - fi->Read(&page_.offset); - fi->Read(&page_.data); + fi->Read(&page_.offset.HostVector()); + fi->Read(&page_.data.HostVector()); } void SimpleCSRSource::SaveBinary(dmlc::Stream* fo) const { int tmagic = kMagic; fo->Write(&tmagic, sizeof(tmagic)); info.SaveBinary(fo); - fo->Write(page_.offset); - fo->Write(page_.data); + fo->Write(page_.offset.HostVector()); + fo->Write(page_.data.HostVector()); } void SimpleCSRSource::BeforeFirst() { diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index c14faf0ced4e..98a29396750c 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -41,8 +41,10 @@ void SimpleDMatrix::MakeOneBatch(SparsePage* pcol, bool sorted) { // bit map const int nthread = omp_get_max_threads(); pcol->Clear(); + auto& pcol_offset_vec = pcol->offset.HostVector(); + auto& pcol_data_vec = pcol->data.HostVector(); common::ParallelGroupBuilder - builder(&pcol->offset, &pcol->data); + builder(&pcol_offset_vec, &pcol_data_vec); builder.InitBudget(Info().num_col_, nthread); // start working auto iter = this->RowIterator(); @@ -88,9 +90,9 @@ void SimpleDMatrix::MakeOneBatch(SparsePage* pcol, bool sorted) { auto ncol = static_cast(pcol->Size()); #pragma omp parallel for schedule(dynamic, 1) num_threads(nthread) for (bst_omp_uint i = 0; i < ncol; ++i) { - if (pcol->offset[i] < pcol->offset[i + 1]) { - std::sort(dmlc::BeginPtr(pcol->data) + pcol->offset[i], - dmlc::BeginPtr(pcol->data) + pcol->offset[i + 1], + if (pcol_offset_vec[i] < pcol_offset_vec[i + 1]) { + std::sort(dmlc::BeginPtr(pcol_data_vec) + pcol_offset_vec[i], + dmlc::BeginPtr(pcol_data_vec) + pcol_offset_vec[i + 1], Entry::CmpValue); } } diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc index 55e078d847b2..06c195e140be 100644 --- a/src/data/sparse_page_dmatrix.cc +++ b/src/data/sparse_page_dmatrix.cc @@ -141,15 +141,19 @@ void SparsePageDMatrix::InitColAccess( pcol->Clear(); pcol->base_rowid = buffered_rowset_[begin]; const int nthread = std::max(omp_get_max_threads(), std::max(omp_get_num_procs() / 2 - 1, 1)); + auto& offset_vec = pcol->offset.HostVector(); + auto& data_vec = pcol->data.HostVector(); common::ParallelGroupBuilder - builder(&pcol->offset, &pcol->data); + builder(&offset_vec, &data_vec); builder.InitBudget(info.num_col_, nthread); bst_omp_uint ndata = static_cast(prow.Size()); + const auto& prow_offset_vec = prow.offset.HostVector(); + const auto& 
prow_data_vec = prow.data.HostVector(); #pragma omp parallel for schedule(static) num_threads(nthread) for (bst_omp_uint i = 0; i < ndata; ++i) { int tid = omp_get_thread_num(); - for (size_t j = prow.offset[i]; j < prow.offset[i+1]; ++j) { - const auto e = prow.data[j]; + for (size_t j = prow_offset_vec[i]; j < prow_offset_vec[i+1]; ++j) { + const auto e = prow_data_vec[j]; builder.AddBudget(e.index, tid); } } @@ -157,8 +161,8 @@ void SparsePageDMatrix::InitColAccess( #pragma omp parallel for schedule(static) num_threads(nthread) for (bst_omp_uint i = 0; i < ndata; ++i) { int tid = omp_get_thread_num(); - for (size_t j = prow.offset[i]; j < prow.offset[i+1]; ++j) { - const Entry &e = prow.data[j]; + for (size_t j = prow_offset_vec[i]; j < prow_offset_vec[i+1]; ++j) { + const Entry &e = prow_data_vec[j]; builder.Push(e.index, Entry(buffered_rowset_[i + begin], e.fvalue), tid); @@ -170,9 +174,9 @@ void SparsePageDMatrix::InitColAccess( auto ncol = static_cast(pcol->Size()); #pragma omp parallel for schedule(dynamic, 1) num_threads(nthread) for (bst_omp_uint i = 0; i < ncol; ++i) { - if (pcol->offset[i] < pcol->offset[i + 1]) { - std::sort(dmlc::BeginPtr(pcol->data) + pcol->offset[i], - dmlc::BeginPtr(pcol->data) + pcol->offset[i + 1], + if (offset_vec[i] < offset_vec[i + 1]) { + std::sort(dmlc::BeginPtr(data_vec) + offset_vec[i], + dmlc::BeginPtr(data_vec) + offset_vec[i + 1], Entry::CmpValue); } } @@ -233,8 +237,9 @@ void SparsePageDMatrix::InitColAccess( size_t tick_expected = kStep; while (make_next_col(page.get())) { + const auto& page_offset_vec = page->offset.ConstHostVector(); for (size_t i = 0; i < page->Size(); ++i) { - col_size_[i] += page->offset[i + 1] - page->offset[i]; + col_size_[i] += page_offset_vec[i + 1] - page_offset_vec[i]; } bytes_write += page->MemCostBytes(); diff --git a/src/data/sparse_page_raw_format.cc b/src/data/sparse_page_raw_format.cc index 053bfde88cec..0cb3b6ebed0a 100644 --- a/src/data/sparse_page_raw_format.cc +++ b/src/data/sparse_page_raw_format.cc @@ -15,13 +15,15 @@ DMLC_REGISTRY_FILE_TAG(sparse_page_raw_format); class SparsePageRawFormat : public SparsePageFormat { public: bool Read(SparsePage* page, dmlc::SeekStream* fi) override { - if (!fi->Read(&(page->offset))) return false; - CHECK_NE(page->offset.size(), 0U) << "Invalid SparsePage file"; - page->data.resize(page->offset.back()); - if (page->data.size() != 0) { - CHECK_EQ(fi->Read(dmlc::BeginPtr(page->data), - (page->data).size() * sizeof(Entry)), - (page->data).size() * sizeof(Entry)) + auto& offset_vec = page->offset.HostVector(); + if (!fi->Read(&offset_vec)) return false; + auto& data_vec = page->data.HostVector(); + CHECK_NE(page->offset.Size(), 0U) << "Invalid SparsePage file"; + data_vec.resize(offset_vec.back()); + if (page->data.Size() != 0) { + CHECK_EQ(fi->Read(dmlc::BeginPtr(data_vec), + (page->data).Size() * sizeof(Entry)), + (page->data).Size() * sizeof(Entry)) << "Invalid SparsePage file"; } return true; @@ -31,15 +33,17 @@ class SparsePageRawFormat : public SparsePageFormat { dmlc::SeekStream* fi, const std::vector& sorted_index_set) override { if (!fi->Read(&disk_offset_)) return false; + auto& offset_vec = page->offset.HostVector(); + auto& data_vec = page->data.HostVector(); // setup the offset - page->offset.clear(); - page->offset.push_back(0); + offset_vec.clear(); + offset_vec.push_back(0); for (unsigned int fid : sorted_index_set) { CHECK_LT(fid + 1, disk_offset_.size()); size_t size = disk_offset_[fid + 1] - disk_offset_[fid]; - 
page->offset.push_back(page->offset.back() + size); + offset_vec.push_back(offset_vec.back() + size); } - page->data.resize(page->offset.back()); + data_vec.resize(offset_vec.back()); // read in the data size_t begin = fi->Tell(); size_t curr_offset = 0; @@ -53,14 +57,14 @@ class SparsePageRawFormat : public SparsePageFormat { size_t j, size_to_read = 0; for (j = i; j < sorted_index_set.size(); ++j) { if (disk_offset_[sorted_index_set[j]] == disk_offset_[fid] + size_to_read) { - size_to_read += page->offset[j + 1] - page->offset[j]; + size_to_read += offset_vec[j + 1] - offset_vec[j]; } else { break; } } if (size_to_read != 0) { - CHECK_EQ(fi->Read(dmlc::BeginPtr(page->data) + page->offset[i], + CHECK_EQ(fi->Read(dmlc::BeginPtr(data_vec) + offset_vec[i], size_to_read * sizeof(Entry)), size_to_read * sizeof(Entry)) << "Invalid SparsePage file"; @@ -76,11 +80,13 @@ class SparsePageRawFormat : public SparsePageFormat { } void Write(const SparsePage& page, dmlc::Stream* fo) override { - CHECK(page.offset.size() != 0 && page.offset[0] == 0); - CHECK_EQ(page.offset.back(), page.data.size()); - fo->Write(page.offset); - if (page.data.size() != 0) { - fo->Write(dmlc::BeginPtr(page.data), page.data.size() * sizeof(Entry)); + const auto& offset_vec = page.offset.HostVector(); + const auto& data_vec = page.data.HostVector(); + CHECK(page.offset.Size() != 0 && offset_vec[0] == 0); + CHECK_EQ(offset_vec.back(), page.data.Size()); + fo->Write(offset_vec); + if (page.data.Size() != 0) { + fo->Write(dmlc::BeginPtr(data_vec), page.data.Size() * sizeof(Entry)); } } diff --git a/src/data/sparse_page_source.cc b/src/data/sparse_page_source.cc index ddac4a9415f0..7d47dedd9f92 100644 --- a/src/data/sparse_page_source.cc +++ b/src/data/sparse_page_source.cc @@ -129,10 +129,12 @@ void SparsePageSource::Create(dmlc::Parser* src, while (src->Next()) { const dmlc::RowBlock& batch = src->Value(); if (batch.label != nullptr) { - info.labels_.insert(info.labels_.end(), batch.label, batch.label + batch.size); + auto& labels = info.labels_.HostVector(); + labels.insert(labels.end(), batch.label, batch.label + batch.size); } if (batch.weight != nullptr) { - info.weights_.insert(info.weights_.end(), batch.weight, batch.weight + batch.size); + auto& weights = info.weights_.HostVector(); + weights.insert(weights.end(), batch.weight, batch.weight + batch.size); } if (batch.qid != nullptr) { info.qids_.insert(info.qids_.end(), batch.qid, batch.qid + batch.size); @@ -175,7 +177,7 @@ void SparsePageSource::Create(dmlc::Parser* src, } } - if (page->data.size() != 0) { + if (page->data.Size() != 0) { writer.PushWrite(std::move(page)); } @@ -224,7 +226,7 @@ void SparsePageSource::Create(DMatrix* src, << (bytes_write >> 20UL) << " written"; } } - if (page->data.size() != 0) { + if (page->data.Size() != 0) { writer.PushWrite(std::move(page)); } diff --git a/src/gbm/gblinear.cc b/src/gbm/gblinear.cc index 6a432c057658..471134575bf8 100644 --- a/src/gbm/gblinear.cc +++ b/src/gbm/gblinear.cc @@ -143,7 +143,7 @@ class GBLinear : public GradientBooster { model_.LazyInitModel(); CHECK_EQ(ntree_limit, 0U) << "GBLinear::PredictContribution: ntrees is only valid for gbtree predictor"; - const std::vector& base_margin = p_fmat->Info().base_margin_; + const auto& base_margin = p_fmat->Info().base_margin_.ConstHostVector(); const int ngroup = model_.param.num_output_group; const size_t ncolumns = model_.param.num_feature + 1; // allocate space for (#features + bias) times #groups times #rows @@ -201,7 +201,7 @@ class GBLinear : public 
GradientBooster { monitor_.Start("PredictBatchInternal"); model_.LazyInitModel(); std::vector &preds = *out_preds; - const std::vector& base_margin = p_fmat->Info().base_margin_; + const auto& base_margin = p_fmat->Info().base_margin_.ConstHostVector(); // start collecting the prediction auto iter = p_fmat->RowIterator(); const int ngroup = model_.param.num_output_group; diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index 739acde3b6be..319b5188f804 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -195,8 +195,8 @@ class GBTree : public GradientBooster { << "must have exactly ngroup*nrow gpairs"; // TODO(canonizer): perform this on GPU if HostDeviceVector has device set. HostDeviceVector tmp(in_gpair->Size() / ngroup, - GradientPair(), in_gpair->Devices()); - std::vector& gpair_h = in_gpair->HostVector(); + GradientPair(), in_gpair->Distribution()); + const auto& gpair_h = in_gpair->ConstHostVector(); auto nsize = static_cast(tmp.Size()); for (int gid = 0; gid < ngroup; ++gid) { std::vector& tmp_h = tmp.HostVector(); @@ -402,7 +402,8 @@ class Dart : public GBTree { if (init_out_preds) { size_t n = num_group * p_fmat->Info().num_row_; - const std::vector& base_margin = p_fmat->Info().base_margin_; + const auto& base_margin = + p_fmat->Info().base_margin_.ConstHostVector(); out_preds->resize(n); if (base_margin.size() != 0) { CHECK_EQ(out_preds->size(), n); diff --git a/src/learner.cc b/src/learner.cc index a1d3b469e658..dfbc1ede2e3f 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -386,7 +386,7 @@ class LearnerImpl : public Learner { this->PredictRaw(train, &preds_); monitor_.Stop("PredictRaw"); monitor_.Start("GetGradient"); - obj_->GetGradient(&preds_, train->Info(), iter, &gpair_); + obj_->GetGradient(preds_, train->Info(), iter, &gpair_); monitor_.Stop("GetGradient"); gbm_->DoBoost(train, &gpair_, obj_.get()); monitor_.Stop("UpdateOneIter"); @@ -416,7 +416,8 @@ class LearnerImpl : public Learner { obj_->EvalTransform(&preds_); for (auto& ev : metrics_) { os << '\t' << data_names[i] << '-' << ev->Name() << ':' - << ev->Eval(preds_.HostVector(), data_sets[i]->Info(), tparam_.dsplit == 2); + << ev->Eval(preds_.ConstHostVector(), data_sets[i]->Info(), + tparam_.dsplit == 2); } } @@ -459,7 +460,8 @@ class LearnerImpl : public Learner { this->PredictRaw(data, &preds_); obj_->EvalTransform(&preds_); return std::make_pair(metric, - ev->Eval(preds_.HostVector(), data->Info(), tparam_.dsplit == 2)); + ev->Eval(preds_.ConstHostVector(), data->Info(), + tparam_.dsplit == 2)); } void Predict(DMatrix* data, bool output_margin, diff --git a/src/linear/updater_coordinate.cc b/src/linear/updater_coordinate.cc index 468911d19e5b..d8b62aad4d07 100644 --- a/src/linear/updater_coordinate.cc +++ b/src/linear/updater_coordinate.cc @@ -90,7 +90,8 @@ class CoordinateUpdater : public LinearUpdater { const int ngroup = model->param.num_output_group; // update bias for (int group_idx = 0; group_idx < ngroup; ++group_idx) { - auto grad = GetBiasGradientParallel(group_idx, ngroup, in_gpair->HostVector(), p_fmat); + auto grad = GetBiasGradientParallel(group_idx, ngroup, + in_gpair->ConstHostVector(), p_fmat); auto dbias = static_cast(param.learning_rate * CoordinateDeltaBias(grad.first, grad.second)); model->bias()[group_idx] += dbias; @@ -98,13 +99,14 @@ class CoordinateUpdater : public LinearUpdater { dbias, &in_gpair->HostVector(), p_fmat); } // prepare for updating the weights - selector->Setup(*model, in_gpair->HostVector(), p_fmat, param.reg_alpha_denorm, + selector->Setup(*model, 
in_gpair->ConstHostVector(), p_fmat, param.reg_alpha_denorm, param.reg_lambda_denorm, param.top_k); // update weights for (int group_idx = 0; group_idx < ngroup; ++group_idx) { for (unsigned i = 0U; i < model->param.num_feature; i++) { - int fidx = selector->NextFeature(i, *model, group_idx, in_gpair->HostVector(), p_fmat, - param.reg_alpha_denorm, param.reg_lambda_denorm); + int fidx = selector->NextFeature + (i, *model, group_idx, in_gpair->ConstHostVector(), p_fmat, + param.reg_alpha_denorm, param.reg_lambda_denorm); if (fidx < 0) break; this->UpdateFeature(fidx, group_idx, &in_gpair->HostVector(), p_fmat, model); } diff --git a/src/linear/updater_gpu_coordinate.cu b/src/linear/updater_gpu_coordinate.cu index 84761caaabdc..fe1f5b5fc971 100644 --- a/src/linear/updater_gpu_coordinate.cu +++ b/src/linear/updater_gpu_coordinate.cu @@ -259,7 +259,7 @@ class GPUCoordinateUpdater : public LinearUpdater { monitor.Start("UpdateGpair"); // Update gpair dh::ExecuteShards(&shards, [&](std::unique_ptr &shard) { - shard->UpdateGpair(in_gpair->HostVector(), model->param); + shard->UpdateGpair(in_gpair->ConstHostVector(), model->param); }); monitor.Stop("UpdateGpair"); @@ -267,7 +267,7 @@ class GPUCoordinateUpdater : public LinearUpdater { this->UpdateBias(p_fmat, model); monitor.Stop("UpdateBias"); // prepare for updating the weights - selector->Setup(*model, in_gpair->HostVector(), p_fmat, + selector->Setup(*model, in_gpair->ConstHostVector(), p_fmat, param.reg_alpha_denorm, param.reg_lambda_denorm, param.top_k); monitor.Start("UpdateFeature"); @@ -275,7 +275,7 @@ class GPUCoordinateUpdater : public LinearUpdater { ++group_idx) { for (auto i = 0U; i < model->param.num_feature; i++) { auto fidx = selector->NextFeature( - i, *model, group_idx, in_gpair->HostVector(), p_fmat, + i, *model, group_idx, in_gpair->ConstHostVector(), p_fmat, param.reg_alpha_denorm, param.reg_lambda_denorm); if (fidx < 0) break; this->UpdateFeature(fidx, group_idx, &in_gpair->HostVector(), model); diff --git a/src/linear/updater_shotgun.cc b/src/linear/updater_shotgun.cc index fc666cfa1d43..2b760f89fc34 100644 --- a/src/linear/updater_shotgun.cc +++ b/src/linear/updater_shotgun.cc @@ -63,13 +63,14 @@ class ShotgunUpdater : public LinearUpdater { } void Update(HostDeviceVector *in_gpair, DMatrix *p_fmat, gbm::GBLinearModel *model, double sum_instance_weight) override { - std::vector &gpair = in_gpair->HostVector(); + auto &gpair = in_gpair->HostVector(); param_.DenormalizePenalties(sum_instance_weight); const int ngroup = model->param.num_output_group; // update bias for (int gid = 0; gid < ngroup; ++gid) { - auto grad = GetBiasGradientParallel(gid, ngroup, in_gpair->HostVector(), p_fmat); + auto grad = GetBiasGradientParallel(gid, ngroup, + in_gpair->ConstHostVector(), p_fmat); auto dbias = static_cast(param_.learning_rate * CoordinateDeltaBias(grad.first, grad.second)); model->bias()[gid] += dbias; @@ -77,7 +78,7 @@ class ShotgunUpdater : public LinearUpdater { } // lock-free parallel updates of weights - selector_->Setup(*model, in_gpair->HostVector(), p_fmat, + selector_->Setup(*model, in_gpair->ConstHostVector(), p_fmat, param_.reg_alpha_denorm, param_.reg_lambda_denorm, 0); auto iter = p_fmat->ColIterator(); while (iter->Next()) { @@ -85,15 +86,16 @@ class ShotgunUpdater : public LinearUpdater { const auto nfeat = static_cast(batch.Size()); #pragma omp parallel for schedule(static) for (bst_omp_uint i = 0; i < nfeat; ++i) { - int ii = selector_->NextFeature(i, *model, 0, in_gpair->HostVector(), p_fmat, - 
param_.reg_alpha_denorm, param_.reg_lambda_denorm); + int ii = selector_->NextFeature + (i, *model, 0, in_gpair->ConstHostVector(), p_fmat, param_.reg_alpha_denorm, + param_.reg_lambda_denorm); if (ii < 0) continue; const bst_uint fid = ii; auto col = batch[ii]; for (int gid = 0; gid < ngroup; ++gid) { double sum_grad = 0.0, sum_hess = 0.0; for (auto& c : col) { - GradientPair &p = gpair[c.index * ngroup + gid]; + const GradientPair &p = gpair[c.index * ngroup + gid]; if (p.GetHess() < 0.0f) continue; const bst_float v = c.fvalue; sum_grad += p.GetGrad() * v; diff --git a/src/metric/elementwise_metric.cc b/src/metric/elementwise_metric.cc index 06d8df61e8ce..a9df69e11526 100644 --- a/src/metric/elementwise_metric.cc +++ b/src/metric/elementwise_metric.cc @@ -24,16 +24,18 @@ struct EvalEWiseBase : public Metric { bst_float Eval(const std::vector& preds, const MetaInfo& info, bool distributed) const override { - CHECK_NE(info.labels_.size(), 0U) << "label set cannot be empty"; - CHECK_EQ(preds.size(), info.labels_.size()) + CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; + CHECK_EQ(preds.size(), info.labels_.Size()) << "label and prediction size not match, " << "hint: use merror or mlogloss for multi-class classification"; - const auto ndata = static_cast(info.labels_.size()); + const auto ndata = static_cast(info.labels_.Size()); double sum = 0.0, wsum = 0.0; + const auto& labels = info.labels_.HostVector(); + const auto& weights = info.weights_.HostVector(); #pragma omp parallel for reduction(+: sum, wsum) schedule(static) for (omp_ulong i = 0; i < ndata; ++i) { - const bst_float wt = info.GetWeight(i); - sum += static_cast(this)->EvalRow(info.labels_[i], preds[i]) * wt; + const bst_float wt = weights.size() > 0 ? weights[i] : 1.0f; + sum += static_cast(this)->EvalRow(labels[i], preds[i]) * wt; wsum += wt; } double dat[2]; dat[0] = sum, dat[1] = wsum; diff --git a/src/metric/multiclass_metric.cc b/src/metric/multiclass_metric.cc index 312dc76b5617..c68ebc25b8bc 100644 --- a/src/metric/multiclass_metric.cc +++ b/src/metric/multiclass_metric.cc @@ -23,20 +23,24 @@ struct EvalMClassBase : public Metric { bst_float Eval(const std::vector &preds, const MetaInfo &info, bool distributed) const override { - CHECK_NE(info.labels_.size(), 0U) << "label set cannot be empty"; - CHECK(preds.size() % info.labels_.size() == 0) + CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; + CHECK(preds.size() % info.labels_.Size() == 0) << "label and prediction size not match"; - const size_t nclass = preds.size() / info.labels_.size(); + const size_t nclass = preds.size() / info.labels_.Size(); CHECK_GE(nclass, 1U) << "mlogloss and merror are only used for multi-class classification," << " use logloss for binary classification"; - const auto ndata = static_cast(info.labels_.size()); + const auto ndata = static_cast(info.labels_.Size()); double sum = 0.0, wsum = 0.0; int label_error = 0; + + const auto& labels = info.labels_.HostVector(); + const auto& weights = info.weights_.HostVector(); + #pragma omp parallel for reduction(+: sum, wsum) schedule(static) for (bst_omp_uint i = 0; i < ndata; ++i) { - const bst_float wt = info.GetWeight(i); - auto label = static_cast(info.labels_[i]); + const bst_float wt = weights.size() > 0 ? 
weights[i] : 1.0f; + auto label = static_cast(labels[i]); if (label >= 0 && label < static_cast(nclass)) { sum += Derived::EvalRow(label, preds.data() + i * nclass, diff --git a/src/metric/rank_metric.cc b/src/metric/rank_metric.cc index f4c2a5300a26..15cbfccd8a11 100644 --- a/src/metric/rank_metric.cc +++ b/src/metric/rank_metric.cc @@ -32,7 +32,7 @@ struct EvalAMS : public Metric { CHECK(!distributed) << "metric AMS do not support distributed evaluation"; using namespace std; // NOLINT(*) - const auto ndata = static_cast(info.labels_.size()); + const auto ndata = static_cast(info.labels_.Size()); std::vector > rec(ndata); #pragma omp parallel for schedule(static) @@ -45,10 +45,11 @@ struct EvalAMS : public Metric { const double br = 10.0; unsigned thresindex = 0; double s_tp = 0.0, b_fp = 0.0, tams = 0.0; + const auto& labels = info.labels_.HostVector(); for (unsigned i = 0; i < static_cast(ndata-1) && i < ntop; ++i) { const unsigned ridx = rec[i].second; const bst_float wt = info.GetWeight(ridx); - if (info.labels_[ridx] > 0.5f) { + if (labels[ridx] > 0.5f) { s_tp += wt; } else { b_fp += wt; @@ -84,14 +85,14 @@ struct EvalAuc : public Metric { bst_float Eval(const std::vector &preds, const MetaInfo &info, bool distributed) const override { - CHECK_NE(info.labels_.size(), 0U) << "label set cannot be empty"; - CHECK_EQ(preds.size(), info.labels_.size()) + CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; + CHECK_EQ(preds.size(), info.labels_.Size()) << "label size predict size not match"; std::vector tgptr(2, 0); - tgptr[1] = static_cast(info.labels_.size()); + tgptr[1] = static_cast(info.labels_.Size()); const std::vector &gptr = info.group_ptr_.size() == 0 ? tgptr : info.group_ptr_; - CHECK_EQ(gptr.back(), info.labels_.size()) + CHECK_EQ(gptr.back(), info.labels_.Size()) << "EvalAuc: group structure must match number of prediction"; const auto ngroup = static_cast(gptr.size() - 1); // sum statistics @@ -99,6 +100,7 @@ struct EvalAuc : public Metric { int auc_error = 0; // each thread takes a local rec std::vector< std::pair > rec; + const auto& labels = info.labels_.HostVector(); for (bst_omp_uint k = 0; k < ngroup; ++k) { rec.clear(); for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) { @@ -110,7 +112,7 @@ struct EvalAuc : public Metric { double sum_npos = 0.0, sum_nneg = 0.0, buf_pos = 0.0, buf_neg = 0.0; for (size_t j = 0; j < rec.size(); ++j) { const bst_float wt = info.GetWeight(rec[j].second); - const bst_float ctr = info.labels_[rec[j].second]; + const bst_float ctr = labels[rec[j].second]; // keep bucketing predictions in same bucket if (j != 0 && rec[j].first != rec[j - 1].first) { sum_pospair += buf_neg * (sum_npos + buf_pos *0.5); @@ -156,7 +158,7 @@ struct EvalRankList : public Metric { bst_float Eval(const std::vector &preds, const MetaInfo &info, bool distributed) const override { - CHECK_EQ(preds.size(), info.labels_.size()) + CHECK_EQ(preds.size(), info.labels_.Size()) << "label size predict size not match"; // quick consistency when group is not available std::vector tgptr(2, 0); @@ -168,6 +170,7 @@ struct EvalRankList : public Metric { const auto ngroup = static_cast(gptr.size() - 1); // sum statistics double sum_metric = 0.0f; + const auto& labels = info.labels_.HostVector(); #pragma omp parallel reduction(+:sum_metric) { // each thread takes a local rec @@ -176,7 +179,7 @@ struct EvalRankList : public Metric { for (bst_omp_uint k = 0; k < ngroup; ++k) { rec.clear(); for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) { - rec.emplace_back(preds[j], 
static_cast(info.labels_[j])); + rec.emplace_back(preds[j], static_cast(labels[j])); } sum_metric += this->EvalMetric(rec); } @@ -314,7 +317,7 @@ struct EvalCox : public Metric { CHECK(!distributed) << "Cox metric does not support distributed evaluation"; using namespace std; // NOLINT(*) - const auto ndata = static_cast(info.labels_.size()); + const auto ndata = static_cast(info.labels_.Size()); const std::vector &label_order = info.LabelAbsSort(); // pre-compute a sum for the denominator @@ -326,9 +329,10 @@ struct EvalCox : public Metric { double out = 0; double accumulated_sum = 0; bst_omp_uint num_events = 0; + const auto& labels = info.labels_.HostVector(); for (bst_omp_uint i = 0; i < ndata; ++i) { const size_t ind = label_order[i]; - const auto label = info.labels_[ind]; + const auto label = labels[ind]; if (label > 0) { out -= log(preds[ind]) - log(exp_p_sum); ++num_events; @@ -336,7 +340,7 @@ struct EvalCox : public Metric { // only update the denominator after we move forward in time (labels are sorted) accumulated_sum += preds[ind]; - if (i == ndata - 1 || std::abs(label) < std::abs(info.labels_[label_order[i + 1]])) { + if (i == ndata - 1 || std::abs(label) < std::abs(labels[label_order[i + 1]])) { exp_p_sum -= accumulated_sum; accumulated_sum = 0; } @@ -358,14 +362,14 @@ struct EvalAucPR : public Metric { bst_float Eval(const std::vector &preds, const MetaInfo &info, bool distributed) const override { - CHECK_NE(info.labels_.size(), 0U) << "label set cannot be empty"; - CHECK_EQ(preds.size(), info.labels_.size()) + CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; + CHECK_EQ(preds.size(), info.labels_.Size()) << "label size predict size not match"; std::vector tgptr(2, 0); - tgptr[1] = static_cast(info.labels_.size()); + tgptr[1] = static_cast(info.labels_.Size()); const std::vector &gptr = info.group_ptr_.size() == 0 ? 
tgptr : info.group_ptr_; - CHECK_EQ(gptr.back(), info.labels_.size()) + CHECK_EQ(gptr.back(), info.labels_.Size()) << "EvalAucPR: group structure must match number of prediction"; const auto ngroup = static_cast(gptr.size() - 1); // sum statistics @@ -373,13 +377,14 @@ struct EvalAucPR : public Metric { int auc_error = 0, auc_gt_one = 0; // each thread takes a local rec std::vector> rec; + const auto& labels = info.labels_.HostVector(); for (bst_omp_uint k = 0; k < ngroup; ++k) { double total_pos = 0.0; double total_neg = 0.0; rec.clear(); for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) { - total_pos += info.GetWeight(j) * info.labels_[j]; - total_neg += info.GetWeight(j) * (1.0f - info.labels_[j]); + total_pos += info.GetWeight(j) * labels[j]; + total_neg += info.GetWeight(j) * (1.0f - labels[j]); rec.emplace_back(preds[j], j); } XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst); @@ -390,8 +395,8 @@ struct EvalAucPR : public Metric { // calculate AUC double tp = 0.0, prevtp = 0.0, fp = 0.0, prevfp = 0.0, h = 0.0, a = 0.0, b = 0.0; for (size_t j = 0; j < rec.size(); ++j) { - tp += info.GetWeight(rec[j].second) * info.labels_[rec[j].second]; - fp += info.GetWeight(rec[j].second) * (1.0f - info.labels_[rec[j].second]); + tp += info.GetWeight(rec[j].second) * labels[rec[j].second]; + fp += info.GetWeight(rec[j].second) * (1.0f - labels[rec[j].second]); if ((j < rec.size() - 1 && rec[j].first != rec[j + 1].first) || j == rec.size() - 1) { if (tp == prevtp) { a = 1.0; diff --git a/src/objective/hinge.cc b/src/objective/hinge.cc index 5b04d215a246..503cd1e924e2 100644 --- a/src/objective/hinge.cc +++ b/src/objective/hinge.cc @@ -21,24 +21,26 @@ class HingeObj : public ObjFunction { // This objective does not take any parameters } - void GetGradient(HostDeviceVector *preds, + void GetGradient(const HostDeviceVector &preds, const MetaInfo &info, int iter, HostDeviceVector *out_gpair) override { - CHECK_NE(info.labels_.size(), 0U) << "label set cannot be empty"; - CHECK_EQ(preds->Size(), info.labels_.size()) + CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; + CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided" - << "preds.size=" << preds->Size() - << ", label.size=" << info.labels_.size(); - auto& preds_h = preds->HostVector(); + << "preds.size=" << preds.Size() + << ", label.size=" << info.labels_.Size(); + const auto& preds_h = preds.HostVector(); + const auto& labels_h = info.labels_.HostVector(); + const auto& weights_h = info.weights_.HostVector(); out_gpair->Resize(preds_h.size()); auto& gpair = out_gpair->HostVector(); for (size_t i = 0; i < preds_h.size(); ++i) { - auto y = info.labels_[i] * 2.0 - 1.0; + auto y = labels_h[i] * 2.0 - 1.0; bst_float p = preds_h[i]; - bst_float w = info.GetWeight(i); + bst_float w = weights_h.size() > 0 ? 
weights_h[i] : 1.0f; bst_float g, h; if (p * y < 1.0) { g = -y * w; diff --git a/src/objective/multiclass_obj.cc b/src/objective/multiclass_obj.cc index af212d25a0af..dc43f932764c 100644 --- a/src/objective/multiclass_obj.cc +++ b/src/objective/multiclass_obj.cc @@ -35,19 +35,20 @@ class SoftmaxMultiClassObj : public ObjFunction { void Configure(const std::vector >& args) override { param_.InitAllowUnknown(args); } - void GetGradient(HostDeviceVector* preds, + void GetGradient(const HostDeviceVector& preds, const MetaInfo& info, int iter, HostDeviceVector* out_gpair) override { - CHECK_NE(info.labels_.size(), 0U) << "label set cannot be empty"; - CHECK(preds->Size() == (static_cast(param_.num_class) * info.labels_.size())) + CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; + CHECK(preds.Size() == (static_cast(param_.num_class) * info.labels_.Size())) << "SoftmaxMultiClassObj: label size and pred size does not match"; - std::vector& preds_h = preds->HostVector(); + const std::vector& preds_h = preds.HostVector(); out_gpair->Resize(preds_h.size()); std::vector& gpair = out_gpair->HostVector(); const int nclass = param_.num_class; const auto ndata = static_cast(preds_h.size() / nclass); + const auto& labels = info.labels_.HostVector(); int label_error = 0; #pragma omp parallel { @@ -58,7 +59,7 @@ class SoftmaxMultiClassObj : public ObjFunction { rec[k] = preds_h[i * nclass + k]; } common::Softmax(&rec); - auto label = static_cast(info.labels_[i]); + auto label = static_cast(labels[i]); if (label < 0 || label >= nclass) { label_error = label; label = 0; } diff --git a/src/objective/rank_obj.cc b/src/objective/rank_obj.cc index ed18f13c0646..cb186a4f4e03 100644 --- a/src/objective/rank_obj.cc +++ b/src/objective/rank_obj.cc @@ -38,18 +38,18 @@ class LambdaRankObj : public ObjFunction { param_.InitAllowUnknown(args); } - void GetGradient(HostDeviceVector* preds, + void GetGradient(const HostDeviceVector& preds, const MetaInfo& info, int iter, HostDeviceVector* out_gpair) override { - CHECK_EQ(preds->Size(), info.labels_.size()) << "label size predict size not match"; - auto& preds_h = preds->HostVector(); + CHECK_EQ(preds.Size(), info.labels_.Size()) << "label size predict size not match"; + const auto& preds_h = preds.HostVector(); out_gpair->Resize(preds_h.size()); std::vector& gpair = out_gpair->HostVector(); // quick consistency when group is not available - std::vector tgptr(2, 0); tgptr[1] = static_cast(info.labels_.size()); + std::vector tgptr(2, 0); tgptr[1] = static_cast(info.labels_.Size()); const std::vector &gptr = info.group_ptr_.size() == 0 ? 
tgptr : info.group_ptr_; - CHECK(gptr.size() != 0 && gptr.back() == info.labels_.size()) + CHECK(gptr.size() != 0 && gptr.back() == info.labels_.Size()) << "group structure not consistent with #rows"; const auto ngroup = static_cast(gptr.size() - 1); @@ -67,11 +67,12 @@ class LambdaRankObj : public ObjFunction { sum_weights += info.GetWeight(k); } bst_float weight_normalization_factor = ngroup/sum_weights; + const auto& labels = info.labels_.HostVector(); #pragma omp for schedule(static) for (bst_omp_uint k = 0; k < ngroup; ++k) { lst.clear(); pairs.clear(); for (unsigned j = gptr[k]; j < gptr[k+1]; ++j) { - lst.emplace_back(preds_h[j], info.labels_[j], j); + lst.emplace_back(preds_h[j], labels[j], j); gpair[j] = GradientPair(0.0f, 0.0f); } std::sort(lst.begin(), lst.end(), ListEntry::CmpPred); diff --git a/src/objective/regression_obj.cc b/src/objective/regression_obj.cc index 6b793d59c63a..5a69e3825611 100644 --- a/src/objective/regression_obj.cc +++ b/src/objective/regression_obj.cc @@ -38,16 +38,18 @@ class RegLossObj : public ObjFunction { const std::vector > &args) override { param_.InitAllowUnknown(args); } - void GetGradient(HostDeviceVector *preds, const MetaInfo &info, + void GetGradient(const HostDeviceVector &preds, const MetaInfo &info, int iter, HostDeviceVector *out_gpair) override { - CHECK_NE(info.labels_.size(), 0U) << "label set cannot be empty"; - CHECK_EQ(preds->Size(), info.labels_.size()) + CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; + CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided" - << "preds.size=" << preds->Size() - << ", label.size=" << info.labels_.size(); - auto& preds_h = preds->HostVector(); + << "preds.size=" << preds.Size() + << ", label.size=" << info.labels_.Size(); + const auto& preds_h = preds.HostVector(); + const auto& labels = info.labels_.HostVector(); + const auto& weights = info.weights_.HostVector(); - this->LazyCheckLabels(info.labels_); + this->LazyCheckLabels(labels); out_gpair->Resize(preds_h.size()); auto& gpair = out_gpair->HostVector(); const auto n = static_cast(preds_h.size()); @@ -57,10 +59,10 @@ class RegLossObj : public ObjFunction { const omp_ulong remainder = n % 8; #pragma omp parallel for schedule(static) for (omp_ulong i = 0; i < n - remainder; i += 8) { - avx::Float8 y(&info.labels_[i]); + avx::Float8 y(&labels[i]); avx::Float8 p = Loss::PredTransform(avx::Float8(&preds_h[i])); - avx::Float8 w = info.weights_.empty() ? avx::Float8(1.0f) - : avx::Float8(&info.weights_[i]); + avx::Float8 w = weights.empty() ? 
avx::Float8(1.0f) + : avx::Float8(&weights[i]); // Adjust weight w += y * (scale * w - w); avx::Float8 grad = Loss::FirstOrderGradient(p, y); @@ -68,7 +70,7 @@ class RegLossObj : public ObjFunction { avx::StoreGpair(gpair_ptr + i, grad * w, hess * w); } for (omp_ulong i = n - remainder; i < n; ++i) { - auto y = info.labels_[i]; + auto y = labels[i]; bst_float p = Loss::PredTransform(preds_h[i]); bst_float w = info.GetWeight(i); w += y * ((param_.scale_pos_weight * w) - w); @@ -140,15 +142,16 @@ class PoissonRegression : public ObjFunction { param_.InitAllowUnknown(args); } - void GetGradient(HostDeviceVector *preds, + void GetGradient(const HostDeviceVector &preds, const MetaInfo &info, int iter, HostDeviceVector *out_gpair) override { - CHECK_NE(info.labels_.size(), 0U) << "label set cannot be empty"; - CHECK_EQ(preds->Size(), info.labels_.size()) << "labels are not correctly provided"; - auto& preds_h = preds->HostVector(); - out_gpair->Resize(preds->Size()); + CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; + CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided"; + const auto& preds_h = preds.HostVector(); + out_gpair->Resize(preds.Size()); auto& gpair = out_gpair->HostVector(); + const auto& labels = info.labels_.HostVector(); // check if label in range bool label_correct = true; // start calculating gradient @@ -157,7 +160,7 @@ class PoissonRegression : public ObjFunction { for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*) bst_float p = preds_h[i]; bst_float w = info.GetWeight(i); - bst_float y = info.labels_[i]; + bst_float y = labels[i]; if (y >= 0.0f) { gpair[i] = GradientPair((std::exp(p) - y) * w, std::exp(p + param_.max_delta_step) * w); @@ -201,13 +204,13 @@ class CoxRegression : public ObjFunction { public: // declare functions void Configure(const std::vector >& args) override {} - void GetGradient(HostDeviceVector *preds, + void GetGradient(const HostDeviceVector &preds, const MetaInfo &info, int iter, HostDeviceVector *out_gpair) override { - CHECK_NE(info.labels_.size(), 0U) << "label set cannot be empty"; - CHECK_EQ(preds->Size(), info.labels_.size()) << "labels are not correctly provided"; - auto& preds_h = preds->HostVector(); + CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; + CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided"; + const auto& preds_h = preds.HostVector(); out_gpair->Resize(preds_h.size()); auto& gpair = out_gpair->HostVector(); const std::vector &label_order = info.LabelAbsSort(); @@ -221,6 +224,7 @@ class CoxRegression : public ObjFunction { } // start calculating grad and hess + const auto& labels = info.labels_.HostVector(); double r_k = 0; double s_k = 0; double last_exp_p = 0.0; @@ -231,7 +235,7 @@ class CoxRegression : public ObjFunction { const double p = preds_h[ind]; const double exp_p = std::exp(p); const double w = info.GetWeight(ind); - const double y = info.labels_[ind]; + const double y = labels[ind]; const double abs_y = std::abs(y); // only update the denominator after we move forward in time (labels are sorted) @@ -289,15 +293,16 @@ class GammaRegression : public ObjFunction { void Configure(const std::vector >& args) override { } - void GetGradient(HostDeviceVector *preds, + void GetGradient(const HostDeviceVector &preds, const MetaInfo &info, int iter, HostDeviceVector *out_gpair) override { - CHECK_NE(info.labels_.size(), 0U) << "label set cannot be empty"; - CHECK_EQ(preds->Size(), info.labels_.size()) << "labels are not correctly 
provided"; - auto& preds_h = preds->HostVector(); + CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; + CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided"; + const auto& preds_h = preds.HostVector(); out_gpair->Resize(preds_h.size()); auto& gpair = out_gpair->HostVector(); + const auto& labels = info.labels_.HostVector(); // check if label in range bool label_correct = true; // start calculating gradient @@ -306,7 +311,7 @@ class GammaRegression : public ObjFunction { for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*) bst_float p = preds_h[i]; bst_float w = info.GetWeight(i); - bst_float y = info.labels_[i]; + bst_float y = labels[i]; if (y >= 0.0f) { gpair[i] = GradientPair((1 - y / std::exp(p)) * w, y / std::exp(p) * w); } else { @@ -356,24 +361,25 @@ class TweedieRegression : public ObjFunction { param_.InitAllowUnknown(args); } - void GetGradient(HostDeviceVector *preds, + void GetGradient(const HostDeviceVector &preds, const MetaInfo &info, int iter, HostDeviceVector *out_gpair) override { - CHECK_NE(info.labels_.size(), 0U) << "label set cannot be empty"; - CHECK_EQ(preds->Size(), info.labels_.size()) << "labels are not correctly provided"; - auto& preds_h = preds->HostVector(); - out_gpair->Resize(preds->Size()); + CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; + CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided"; + const auto& preds_h = preds.HostVector(); + out_gpair->Resize(preds.Size()); auto& gpair = out_gpair->HostVector(); + const auto& labels = info.labels_.HostVector(); // check if label in range bool label_correct = true; // start calculating gradient - const omp_ulong ndata = static_cast(preds->Size()); // NOLINT(*) + const omp_ulong ndata = static_cast(preds.Size()); // NOLINT(*) #pragma omp parallel for schedule(static) for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*) bst_float p = preds_h[i]; bst_float w = info.GetWeight(i); - bst_float y = info.labels_[i]; + bst_float y = labels[i]; float rho = param_.tweedie_variance_power; if (y >= 0.0f) { bst_float grad = -y * std::exp((1 - rho) * p) + std::exp((2 - rho) * p); diff --git a/src/objective/regression_obj_gpu.cu b/src/objective/regression_obj_gpu.cu index 2525fb80276d..ab1a11a72766 100644 --- a/src/objective/regression_obj_gpu.cu +++ b/src/objective/regression_obj_gpu.cu @@ -45,7 +45,7 @@ struct GPURegLossParam : public dmlc::Parameter { // GPU kernel for gradient computation template __global__ void get_gradient_k -(common::Span out_gpair, common::Span label_correct, +(common::Span out_gpair, common::Span label_correct, common::Span preds, common::Span labels, const float * __restrict__ weights, int n, float scale_pos_weight) { int i = threadIdx.x + blockIdx.x * blockDim.x; @@ -75,66 +75,46 @@ __global__ void pred_transform_k(common::Span preds, int n) { template class GPURegLossObj : public ObjFunction { protected: - bool copied_; - HostDeviceVector labels_, weights_; - HostDeviceVector label_correct_; + HostDeviceVector label_correct_; // allocate device data for n elements, do nothing if memory is allocated already - void LazyResize(size_t n, size_t n_weights) { - if (labels_.Size() == n && weights_.Size() == n_weights) - return; - copied_ = false; - - labels_.Reshard(devices_); - weights_.Reshard(devices_); - label_correct_.Reshard(devices_); - - if (labels_.Size() != n) { - labels_.Resize(n); - label_correct_.Resize(devices_.Size()); - } - if (weights_.Size() != n_weights) - weights_.Resize(n_weights); + void 
LazyResize() { } public: - GPURegLossObj() : copied_(false) {} + GPURegLossObj() {} void Configure(const std::vector >& args) override { param_.InitAllowUnknown(args); - // CHECK(param_.n_gpus != 0) << "Must have at least one device"; + CHECK(param_.n_gpus != 0) << "Must have at least one device"; devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id); + label_correct_.Reshard(devices_); + label_correct_.Resize(devices_.Size()); } - void GetGradient(HostDeviceVector* preds, + void GetGradient(const HostDeviceVector &preds, const MetaInfo &info, int iter, HostDeviceVector* out_gpair) override { - CHECK_NE(info.labels_.size(), 0U) << "label set cannot be empty"; - CHECK_EQ(preds->Size(), info.labels_.size()) + CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty"; + CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided" - << "preds.size=" << preds->Size() << ", label.size=" << info.labels_.size(); - size_t ndata = preds->Size(); - preds->Reshard(devices_); + << "preds.size=" << preds.Size() << ", label.size=" << info.labels_.Size(); + size_t ndata = preds.Size(); + preds.Reshard(devices_); + info.labels_.Reshard(devices_); + info.weights_.Reshard(devices_); out_gpair->Reshard(devices_); out_gpair->Resize(ndata); - LazyResize(ndata, info.weights_.size()); GetGradientDevice(preds, info, iter, out_gpair); } private: - void GetGradientDevice(HostDeviceVector* preds, + void GetGradientDevice(const HostDeviceVector& preds, const MetaInfo &info, int iter, HostDeviceVector* out_gpair) { label_correct_.Fill(1); - // only copy the labels and weights once, similar to how the data is copied - if (!copied_) { - labels_.Copy(info.labels_); - if (info.weights_.size() > 0) - weights_.Copy(info.weights_); - copied_ = true; - } // run the kernel #pragma omp parallel for schedule(static, 1) if (devices_.Size() > 1) @@ -142,12 +122,12 @@ class GPURegLossObj : public ObjFunction { int d = devices_[i]; dh::safe_cuda(cudaSetDevice(d)); const int block = 256; - size_t n = preds->DeviceSize(d); + size_t n = preds.DeviceSize(d); if (n > 0) { get_gradient_k<<>> (out_gpair->DeviceSpan(d), label_correct_.DeviceSpan(d), - preds->DeviceSpan(d), labels_.DeviceSpan(d), - info.weights_.size() > 0 ? weights_.DevicePointer(d) : nullptr, + preds.DeviceSpan(d), info.labels_.DeviceSpan(d), + info.weights_.Size() > 0 ? 
       info.weights_.DevicePointer(d) : nullptr, n, param_.scale_pos_weight);
     dh::safe_cuda(cudaGetLastError());
   }
@@ -155,7 +135,7 @@ class GPURegLossObj : public ObjFunction {
     }
     // copy "label correct" flags back to host
-    std::vector& label_correct_h = label_correct_.HostVector();
+    std::vector& label_correct_h = label_correct_.HostVector();
     for (int i = 0; i < devices_.Size(); ++i) {
       if (label_correct_h[i] == 0)
         LOG(FATAL) << Loss::LabelErrorMsg();
diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc
index 964bbaa0d5ae..22a31425ce49 100644
--- a/src/predictor/cpu_predictor.cc
+++ b/src/predictor/cpu_predictor.cc
@@ -112,7 +112,7 @@ class CPUPredictor : public Predictor {
         ntree_limit * model.param.num_output_group >= model.trees.size()) {
       auto it = cache_.find(dmat);
       if (it != cache_.end()) {
-        HostDeviceVector& y = it->second.predictions;
+        const HostDeviceVector& y = it->second.predictions;
         if (y.Size() != 0) {
           out_preds->Resize(y.Size());
           std::copy(y.HostVector().begin(), y.HostVector().end(),
@@ -128,7 +128,7 @@ class CPUPredictor : public Predictor {
                            HostDeviceVector* out_preds,
                            const gbm::GBTreeModel& model) const {
     size_t n = model.param.num_output_group * info.num_row_;
-    const std::vector& base_margin = info.base_margin_;
+    const auto& base_margin = info.base_margin_.HostVector();
     out_preds->Resize(n);
     std::vector& out_preds_h = out_preds->HostVector();
     if (base_margin.size() == n) {
@@ -282,7 +282,7 @@ class CPUPredictor : public Predictor {
     }
     // start collecting the contributions
     auto iter = p_fmat->RowIterator();
-    const std::vector& base_margin = info.base_margin_;
+    const auto& base_margin = info.base_margin_.HostVector();
     iter->BeforeFirst();
     while (iter->Next()) {
       auto &batch = iter->Value();
diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu
index 1fba61656abf..b59564f96d8a 100644
--- a/src/predictor/gpu_predictor.cu
+++ b/src/predictor/gpu_predictor.cu
@@ -58,28 +58,30 @@ struct DeviceMatrix {
   DeviceMatrix(DMatrix* dmat, int device_idx, bool silent) : p_mat(dmat) {
     dh::safe_cuda(cudaSetDevice(device_idx));
-    auto info = dmat->Info();
+    const auto& info = dmat->Info();
     ba.Allocate(device_idx, silent, &row_ptr, info.num_row_ + 1, &data,
                 info.num_nonzero_);
     auto iter = dmat->RowIterator();
     iter->BeforeFirst();
     size_t data_offset = 0;
     while (iter->Next()) {
-      auto &batch = iter->Value();
+      const auto& batch = iter->Value();
+      const auto& offset_vec = batch.offset.HostVector();
+      const auto& data_vec = batch.data.HostVector();
       // Copy row ptr
       dh::safe_cuda(cudaMemcpy(
-          row_ptr.Data() + batch.base_rowid, batch.offset.data(),
-          sizeof(size_t) * batch.offset.size(), cudaMemcpyHostToDevice));
+          row_ptr.Data() + batch.base_rowid, offset_vec.data(),
+          sizeof(size_t) * offset_vec.size(), cudaMemcpyHostToDevice));
       if (batch.base_rowid > 0) {
         auto begin_itr = row_ptr.tbegin() + batch.base_rowid;
         auto end_itr = begin_itr + batch.Size() + 1;
         IncrementOffset(begin_itr, end_itr, batch.base_rowid);
       }
-      dh::safe_cuda(cudaMemcpy(data.Data() + data_offset, batch.data.data(),
-                               sizeof(Entry) * batch.data.size(),
+      dh::safe_cuda(cudaMemcpy(data.Data() + data_offset, data_vec.data(),
+                               sizeof(Entry) * data_vec.size(),
                                cudaMemcpyHostToDevice));  // Copy data
-      data_offset += batch.data.size();
+      data_offset += batch.data.Size();
     }
   }
 };
@@ -374,10 +376,10 @@ class GPUPredictor : public xgboost::Predictor {
                            HostDeviceVector* out_preds,
                            const gbm::GBTreeModel& model) const {
     size_t n = model.param.num_output_group * info.num_row_;
-    const std::vector& base_margin = info.base_margin_;
+    const HostDeviceVector& base_margin = info.base_margin_;
     out_preds->Reshard(devices);
     out_preds->Resize(n);
-    if (base_margin.size() != 0) {
+    if (base_margin.Size() != 0) {
       CHECK_EQ(out_preds->Size(), n);
       out_preds->Copy(base_margin);
     } else {
@@ -391,11 +393,11 @@ class GPUPredictor : public xgboost::Predictor {
         ntree_limit * model.param.num_output_group >= model.trees.size()) {
       auto it = cache_.find(dmat);
       if (it != cache_.end()) {
-        HostDeviceVector& y = it->second.predictions;
+        const HostDeviceVector& y = it->second.predictions;
         if (y.Size() != 0) {
           out_preds->Reshard(devices);
           out_preds->Resize(y.Size());
-          out_preds->Copy(&y);
+          out_preds->Copy(y);
           return true;
         }
       }
diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc
index b28a87ff00d5..42a681b4599f 100644
--- a/src/tree/updater_colmaker.cc
+++ b/src/tree/updater_colmaker.cc
@@ -41,7 +41,7 @@ class ColMaker: public TreeUpdater {
       Builder builder(
         param_,
         std::unique_ptr(spliteval_->GetHostClone()));
-      builder.Update(gpair->HostVector(), dmat, tree);
+      builder.Update(gpair->ConstHostVector(), dmat, tree);
     }
     param_.learning_rate = lr;
   }
@@ -784,7 +784,7 @@ class DistColMaker : public ColMaker {
       param_,
       std::unique_ptr(spliteval_->GetHostClone()));
     // build the tree
-    builder.Update(gpair->HostVector(), dmat, trees[0]);
+    builder.Update(gpair->ConstHostVector(), dmat, trees[0]);
     //// prune the tree, note that pruner will sync the tree
     pruner_->Update(gpair, dmat, trees);
     // update position after the tree is pruned
diff --git a/src/tree/updater_fast_hist.cc b/src/tree/updater_fast_hist.cc
index cda6c30b47d5..dff16471482d 100644
--- a/src/tree/updater_fast_hist.cc
+++ b/src/tree/updater_fast_hist.cc
@@ -164,7 +164,7 @@ class FastHistMaker: public TreeUpdater {
     double time_evaluate_split = 0;
     double time_apply_split = 0;
-    std::vector& gpair_h = gpair->HostVector();
+    const std::vector& gpair_h = gpair->ConstHostVector();
     spliteval_->Reset();
diff --git a/src/tree/updater_gpu.cu b/src/tree/updater_gpu.cu
index f549f1922d56..e92c3545e55f 100644
--- a/src/tree/updater_gpu.cu
+++ b/src/tree/updater_gpu.cu
@@ -650,7 +650,7 @@ class GPUMaker : public TreeUpdater {
   void convertToCsc(DMatrix* dmat, std::vector* fval,
                     std::vector* fId, std::vector* offset) {
-    MetaInfo info = dmat->Info();
+    const MetaInfo& info = dmat->Info();
     CHECK(info.num_col_ < std::numeric_limits::max());
     CHECK(info.num_row_ < std::numeric_limits::max());
     nRows = static_cast(info.num_row_);
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 69c0b2ab4bfa..8231182132e9 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -387,11 +387,13 @@ struct DeviceShard {
   void InitRowPtrs(const SparsePage& row_batch) {
     dh::safe_cuda(cudaSetDevice(device_idx));
+    const auto& offset_vec = row_batch.offset.HostVector();
     row_ptrs.resize(n_rows + 1);
-    thrust::copy(row_batch.offset.data() + row_begin_idx,
-                 row_batch.offset.data() + row_end_idx + 1,
+    thrust::copy(offset_vec.data() + row_begin_idx,
+                 offset_vec.data() + row_end_idx + 1,
                  row_ptrs.begin());
     auto row_iter = row_ptrs.begin();
+    // find the maximum row size
     auto get_size = [=] __device__(size_t row) {
       return row_iter[row + 1] - row_iter[row]; }; // NOLINT
@@ -432,9 +434,12 @@ struct DeviceShard {
      (dh::TotalMemory(device_idx) / (16 * row_stride * sizeof(Entry)),
       static_cast(n_rows));
-    thrust::device_vector entries_d(gpu_batch_nrows * row_stride);
+    const auto& offset_vec = row_batch.offset.HostVector();
+    const auto& data_vec = row_batch.data.HostVector();
+    thrust::device_vector entries_d(gpu_batch_nrows * row_stride);
     size_t gpu_nbatches = dh::DivRoundUp(n_rows, gpu_batch_nrows);
+
     for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
       size_t batch_row_begin = gpu_batch * gpu_batch_nrows;
       size_t batch_row_end = (gpu_batch + 1) * gpu_batch_nrows;
@@ -443,12 +448,12 @@ struct DeviceShard {
       }
       size_t batch_nrows = batch_row_end - batch_row_begin;
       size_t n_entries =
-        row_batch.offset[row_begin_idx + batch_row_end] -
-        row_batch.offset[row_begin_idx + batch_row_begin];
+        offset_vec[row_begin_idx + batch_row_end] -
+        offset_vec[row_begin_idx + batch_row_begin];
       dh::safe_cuda
         (cudaMemcpy
          (entries_d.data().get(),
-          &row_batch.data[row_batch.offset[row_begin_idx + batch_row_begin]],
+          data_vec.data() + offset_vec[row_begin_idx + batch_row_begin],
           n_entries * sizeof(Entry), cudaMemcpyDefault));
       dim3 block3(32, 8, 1);
       dim3 grid3(dh::DivRoundUp(n_rows, block3.x),
@@ -458,7 +463,7 @@ struct DeviceShard {
         row_ptrs.data().get() + batch_row_begin,
         entries_d.data().get(), cuts_d.data().get(), cut_row_ptrs_d.data().get(),
         batch_row_begin, batch_nrows,
-        row_batch.offset[row_begin_idx + batch_row_begin],
+        offset_vec[row_begin_idx + batch_row_begin],
         row_stride, null_gidx_value);
       dh::safe_cuda(cudaGetLastError());
@@ -538,7 +543,7 @@ struct DeviceShard {
     std::fill(ridx_segments.begin(), ridx_segments.end(), Segment(0, 0));
     ridx_segments.front() = Segment(0, ridx.Size());
-    this->gpair.copy(dh_gpair->tbegin(device_idx), dh_gpair->tend(device_idx));
+    this->gpair.copy(dh_gpair->tcbegin(device_idx), dh_gpair->tcend(device_idx));
     SubsampleGradientPair(&gpair, param.subsample, row_begin_idx);
     hist.Reset();
   }
diff --git a/src/tree/updater_histmaker.cc b/src/tree/updater_histmaker.cc
index 62b5b13e19cd..7c67afe93f2b 100644
--- a/src/tree/updater_histmaker.cc
+++ b/src/tree/updater_histmaker.cc
@@ -30,7 +30,7 @@ class HistMaker: public BaseMaker {
     param_.learning_rate = lr / trees.size();
     // build tree
     for (auto tree : trees) {
-      this->Update(gpair->HostVector(), p_fmat, tree);
+      this->Update(gpair->ConstHostVector(), p_fmat, tree);
     }
     param_.learning_rate = lr;
   }
diff --git a/src/tree/updater_refresh.cc b/src/tree/updater_refresh.cc
index b14fa248d51b..3df4baeea994 100644
--- a/src/tree/updater_refresh.cc
+++ b/src/tree/updater_refresh.cc
@@ -29,7 +29,7 @@ class TreeRefresher: public TreeUpdater {
               DMatrix *p_fmat,
               const std::vector &trees) override {
     if (trees.size() == 0) return;
-    std::vector &gpair_h = gpair->HostVector();
+    const std::vector &gpair_h = gpair->ConstHostVector();
     // number of threads
     // thread temporal space
     std::vector > stemp;
diff --git a/src/tree/updater_skmaker.cc b/src/tree/updater_skmaker.cc
index 50f1a56c407f..bf27e2c94e2a 100644
--- a/src/tree/updater_skmaker.cc
+++ b/src/tree/updater_skmaker.cc
@@ -30,7 +30,7 @@ class SketchMaker: public BaseMaker {
     param_.learning_rate = lr / trees.size();
     // build tree
     for (auto tree : trees) {
-      this->Update(gpair->HostVector(), p_fmat, tree);
+      this->Update(gpair->ConstHostVector(), p_fmat, tree);
     }
     param_.learning_rate = lr;
   }
diff --git a/tests/cpp/common/test_host_device_vector.cu b/tests/cpp/common/test_host_device_vector.cu
index da3192600853..e471e785425c 100644
--- a/tests/cpp/common/test_host_device_vector.cu
+++ b/tests/cpp/common/test_host_device_vector.cu
@@ -3,20 +3,168 @@
 */
 #include 
-#include "../../../src/common/host_device_vector.h"
+#include 
+#include 
+
 #include "../../../src/common/device_helpers.cuh"
+#include "../../../src/common/host_device_vector.h"
 namespace xgboost {
 namespace common {
+void SetDevice(int device) {
+  int n_devices;
+  dh::safe_cuda(cudaGetDeviceCount(&n_devices));
+  device %= n_devices;
+  dh::safe_cuda(cudaSetDevice(device));
+}
+
+void InitHostDeviceVector(size_t n, const GPUDistribution& distribution,
+                          HostDeviceVector *v) {
+  // create the vector
+  GPUSet devices = distribution.Devices();
+  v->Reshard(distribution);
+  v->Resize(n);
+
+  ASSERT_EQ(v->Size(), n);
+  ASSERT_TRUE(v->Distribution() == distribution);
+  ASSERT_TRUE(v->Devices() == devices);
+  // ensure that the devices have read-write access
+  for (int i = 0; i < devices.Size(); ++i) {
+    ASSERT_TRUE(v->DeviceCanAccess(i, GPUAccess::kRead));
+    ASSERT_TRUE(v->DeviceCanAccess(i, GPUAccess::kWrite));
+  }
+  // ensure that the host has no access
+  ASSERT_FALSE(v->HostCanAccess(GPUAccess::kWrite));
+  ASSERT_FALSE(v->HostCanAccess(GPUAccess::kRead));
+
+  // fill in the data on the host
+  std::vector& data_h = v->HostVector();
+  // ensure that the host has full access, while the devices have none
+  ASSERT_TRUE(v->HostCanAccess(GPUAccess::kRead));
+  ASSERT_TRUE(v->HostCanAccess(GPUAccess::kWrite));
+  for (int i = 0; i < devices.Size(); ++i) {
+    ASSERT_FALSE(v->DeviceCanAccess(i, GPUAccess::kRead));
+    ASSERT_FALSE(v->DeviceCanAccess(i, GPUAccess::kWrite));
+  }
+  ASSERT_EQ(data_h.size(), n);
+  std::copy_n(thrust::make_counting_iterator(0), n, data_h.begin());
+}
+
+void PlusOne(HostDeviceVector *v) {
+  int n_devices = v->Devices().Size();
+  for (int i = 0; i < n_devices; ++i) {
+    SetDevice(i);
+    thrust::transform(v->tbegin(i), v->tend(i), v->tbegin(i),
+                      [=]__device__(unsigned int a){ return a + 1; });
+  }
+}
+
+void CheckDevice(HostDeviceVector *v,
+                 const std::vector& starts,
+                 const std::vector& sizes,
+                 unsigned int first, GPUAccess access) {
+  int n_devices = sizes.size();
+  ASSERT_EQ(v->Devices().Size(), n_devices);
+  for (int i = 0; i < n_devices; ++i) {
+    ASSERT_EQ(v->DeviceSize(i), sizes.at(i));
+    SetDevice(i);
+    ASSERT_TRUE(thrust::equal(v->tcbegin(i), v->tcend(i),
+                              thrust::make_counting_iterator(first + starts[i])));
+    ASSERT_TRUE(v->DeviceCanAccess(i, GPUAccess::kRead));
+    // ensure that the device has at most the access specified by access
+    ASSERT_EQ(v->DeviceCanAccess(i, GPUAccess::kWrite), access == GPUAccess::kWrite);
+  }
+  ASSERT_EQ(v->HostCanAccess(GPUAccess::kRead), access == GPUAccess::kRead);
+  ASSERT_FALSE(v->HostCanAccess(GPUAccess::kWrite));
+  for (int i = 0; i < n_devices; ++i) {
+    SetDevice(i);
+    ASSERT_TRUE(thrust::equal(v->tbegin(i), v->tend(i),
+                              thrust::make_counting_iterator(first + starts[i])));
+    ASSERT_TRUE(v->DeviceCanAccess(i, GPUAccess::kRead));
+    ASSERT_TRUE(v->DeviceCanAccess(i, GPUAccess::kWrite));
+  }
+  ASSERT_FALSE(v->HostCanAccess(GPUAccess::kRead));
+  ASSERT_FALSE(v->HostCanAccess(GPUAccess::kWrite));
+}
+
+void CheckHost(HostDeviceVector *v, GPUAccess access) {
+  const std::vector& data_h = access == GPUAccess::kWrite ?
+    v->HostVector() : v->ConstHostVector();
+  for (size_t i = 0; i < v->Size(); ++i) {
+    ASSERT_EQ(data_h.at(i), i + 1);
+  }
+  ASSERT_TRUE(v->HostCanAccess(GPUAccess::kRead));
+  ASSERT_EQ(v->HostCanAccess(GPUAccess::kWrite), access == GPUAccess::kWrite);
+  size_t n_devices = v->Devices().Size();
+  for (int i = 0; i < n_devices; ++i) {
+    ASSERT_EQ(v->DeviceCanAccess(i, GPUAccess::kRead), access == GPUAccess::kRead);
+    // the devices should have no write access
+    ASSERT_FALSE(v->DeviceCanAccess(i, GPUAccess::kWrite));
+  }
+}
+
+void TestHostDeviceVector
+(size_t n, const GPUDistribution& distribution,
+ const std::vector& starts, const std::vector& sizes) {
+  SetCudaSetDeviceHandler(SetDevice);
+  HostDeviceVector v;
+  InitHostDeviceVector(n, distribution, &v);
+  CheckDevice(&v, starts, sizes, 0, GPUAccess::kRead);
+  PlusOne(&v);
+  CheckDevice(&v, starts, sizes, 1, GPUAccess::kWrite);
+  CheckHost(&v, GPUAccess::kRead);
+  CheckHost(&v, GPUAccess::kWrite);
+  SetCudaSetDeviceHandler(nullptr);
+}
+
+TEST(HostDeviceVector, TestBlock) {
+  size_t n = 1001;
+  int n_devices = 2;
+  auto distribution = GPUDistribution::Block(GPUSet::Range(0, n_devices));
+  std::vector starts{0, 501};
+  std::vector sizes{501, 500};
+  TestHostDeviceVector(n, distribution, starts, sizes);
+}
+
+TEST(HostDeviceVector, TestGranular) {
+  size_t n = 3003;
+  int n_devices = 2;
+  auto distribution = GPUDistribution::Granular(GPUSet::Range(0, n_devices), 3);
+  std::vector starts{0, 1503};
+  std::vector sizes{1503, 1500};
+  TestHostDeviceVector(n, distribution, starts, sizes);
+}
+
+TEST(HostDeviceVector, TestOverlap) {
+  size_t n = 1001;
+  int n_devices = 2;
+  auto distribution = GPUDistribution::Overlap(GPUSet::Range(0, n_devices), 1);
+  std::vector starts{0, 500};
+  std::vector sizes{501, 501};
+  TestHostDeviceVector(n, distribution, starts, sizes);
+}
+
+TEST(HostDeviceVector, TestExplicit) {
+  size_t n = 1001;
+  int n_devices = 2;
+  std::vector offsets{0, 550, 1001};
+  auto distribution = GPUDistribution::Explicit(GPUSet::Range(0, n_devices), offsets);
+  std::vector starts{0, 550};
+  std::vector sizes{550, 451};
+  TestHostDeviceVector(n, distribution, starts, sizes);
+}
+
 TEST(HostDeviceVector, Span) {
   HostDeviceVector vec {1.0f, 2.0f, 3.0f, 4.0f};
   vec.Reshard(GPUSet{0, 1});
   auto span = vec.DeviceSpan(0);
   ASSERT_EQ(vec.Size(), span.size());
   ASSERT_EQ(vec.DevicePointer(0), span.data());
+  auto const_span = vec.ConstDeviceSpan(0);
+  ASSERT_EQ(vec.Size(), span.size());
+  ASSERT_EQ(vec.ConstDevicePointer(0), span.data());
 }
 }  // namespace common
 }  // namespace xgboost
-
diff --git a/tests/cpp/data/test_metainfo.cc b/tests/cpp/data/test_metainfo.cc
index 0c0f028940b1..7231e880dac1 100644
--- a/tests/cpp/data/test_metainfo.cc
+++ b/tests/cpp/data/test_metainfo.cc
@@ -16,9 +16,9 @@ TEST(MetaInfo, GetSet) {
   info.SetInfo("root_index", double2, xgboost::kDouble, 2);
   EXPECT_EQ(info.GetRoot(1), 2.0f);
-  EXPECT_EQ(info.labels_.size(), 0);
+  EXPECT_EQ(info.labels_.Size(), 0);
   info.SetInfo("label", double2, xgboost::kFloat32, 2);
-  EXPECT_EQ(info.labels_.size(), 2);
+  EXPECT_EQ(info.labels_.Size(), 2);
   float float2[2] = {1.0f, 2.0f};
   EXPECT_EQ(info.GetWeight(1), 1.0f)
@@ -27,9 +27,9 @@ TEST(MetaInfo, GetSet) {
   EXPECT_EQ(info.GetWeight(1), 2.0f);
   uint32_t uint32_t2[2] = {1U, 2U};
-  EXPECT_EQ(info.base_margin_.size(), 0);
+  EXPECT_EQ(info.base_margin_.Size(), 0);
   info.SetInfo("base_margin", uint32_t2, xgboost::kUInt32, 2);
-  EXPECT_EQ(info.base_margin_.size(), 2);
+  EXPECT_EQ(info.base_margin_.Size(), 2);
   uint64_t uint64_t2[2] = {1U, 2U};
   EXPECT_EQ(info.group_ptr_.size(), 0);
@@ -59,7 +59,7 @@ TEST(MetaInfo, SaveLoadBinary) {
   fs = dmlc::Stream::Create(tmp_file.c_str(), "r");
   xgboost::MetaInfo inforead;
   inforead.LoadBinary(fs);
-  EXPECT_EQ(inforead.labels_, info.labels_);
+  EXPECT_EQ(inforead.labels_.HostVector(), info.labels_.HostVector());
   EXPECT_EQ(inforead.num_col_, info.num_col_);
   EXPECT_EQ(inforead.num_row_, info.num_row_);
@@ -128,7 +128,7 @@ TEST(MetaInfo, LoadQid) {
   CHECK(iter->Next());
   const xgboost::SparsePage& batch = iter->Value();
   CHECK_EQ(batch.base_rowid, 0);
-  CHECK(batch.offset == expected_offset);
-  CHECK(batch.data == expected_data);
+  CHECK(batch.offset.HostVector() == expected_offset);
+  CHECK(batch.data.HostVector() == expected_data);
   CHECK(!iter->Next());
 }
diff --git a/tests/cpp/data/test_simple_dmatrix.cc b/tests/cpp/data/test_simple_dmatrix.cc
index f88a0295c2c5..2d3ae03325d1 100644
--- a/tests/cpp/data/test_simple_dmatrix.cc
+++ b/tests/cpp/data/test_simple_dmatrix.cc
@@ -13,7 +13,7 @@ TEST(SimpleDMatrix, MetaInfo) {
   EXPECT_EQ(dmat->Info().num_row_, 2);
   EXPECT_EQ(dmat->Info().num_col_, 5);
   EXPECT_EQ(dmat->Info().num_nonzero_, 6);
-  EXPECT_EQ(dmat->Info().labels_.size(), dmat->Info().num_row_);
+  EXPECT_EQ(dmat->Info().labels_.Size(), dmat->Info().num_row_);
   delete dmat;
 }
diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cc b/tests/cpp/data/test_sparse_page_dmatrix.cc
index 9a5db8a89668..209c033d60d1 100644
--- a/tests/cpp/data/test_sparse_page_dmatrix.cc
+++ b/tests/cpp/data/test_sparse_page_dmatrix.cc
@@ -16,7 +16,7 @@ TEST(SparsePageDMatrix, MetaInfo) {
   EXPECT_EQ(dmat->Info().num_row_, 2);
   EXPECT_EQ(dmat->Info().num_col_, 5);
   EXPECT_EQ(dmat->Info().num_nonzero_, 6);
-  EXPECT_EQ(dmat->Info().labels_.size(), dmat->Info().num_row_);
+  EXPECT_EQ(dmat->Info().labels_.Size(), dmat->Info().num_row_);
   // Clean up of external memory files
   std::remove((tmp_file + ".cache").c_str());
@@ -54,7 +54,7 @@ TEST(SparsePageDMatrix, RowAccess) {
   delete dmat;
 }
-TEST(SparsePageDMatrix, ColAcess) {
+TEST(SparsePageDMatrix, ColAccess) {
   std::string tmp_file = CreateSimpleTestData();
   xgboost::DMatrix * dmat = xgboost::DMatrix::Load(
     tmp_file + "#" + tmp_file + ".cache", true, false);
diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc
index cb9b12f49f4f..9f13b3868432 100644
--- a/tests/cpp/helpers.cc
+++ b/tests/cpp/helpers.cc
@@ -49,9 +49,8 @@ void _CheckObjFunction(xgboost::ObjFunction * obj,
                        std::vector out_grad,
                        std::vector out_hess) {
   xgboost::HostDeviceVector in_preds(preds);
-  xgboost::HostDeviceVector out_gpair;
-  obj->GetGradient(&in_preds, info, 1, &out_gpair);
+  obj->GetGradient(in_preds, info, 1, &out_gpair);
   std::vector& gpair = out_gpair.HostVector();
   ASSERT_EQ(gpair.size(), in_preds.Size());
@@ -73,8 +72,8 @@ void CheckObjFunction(xgboost::ObjFunction * obj,
                       std::vector out_hess) {
   xgboost::MetaInfo info;
   info.num_row_ = labels.size();
-  info.labels_ = labels;
-  info.weights_ = weights;
+  info.labels_.HostVector() = labels;
+  info.weights_.HostVector() = weights;
   _CheckObjFunction(obj, preds, labels, weights, info, out_grad, out_hess);
 }
@@ -88,8 +87,8 @@ void CheckRankingObjFunction(xgboost::ObjFunction * obj,
                              std::vector out_hess) {
   xgboost::MetaInfo info;
   info.num_row_ = labels.size();
-  info.labels_ = labels;
-  info.weights_ = weights;
+  info.labels_.HostVector() = labels;
+  info.weights_.HostVector() = weights;
   info.group_ptr_ = groups;
   _CheckObjFunction(obj, preds, labels, weights, info, out_grad, out_hess);
@@ -102,8 +101,8 @@ xgboost::bst_float GetMetricEval(xgboost::Metric * metric,
                                  std::vector weights) {
   xgboost::MetaInfo info;
   info.num_row_ = labels.size();
-  info.labels_ = labels;
-  info.weights_ = weights;
+  info.labels_.HostVector() = labels;
+  info.weights_.HostVector() = weights;
   return metric->Eval(preds, info, false);
 }
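
Note (not part of the patch): the helper changes above show the access pattern callers are expected to follow after this commit: populate MetaInfo fields through HostVector(), and read them through ConstHostVector() when no modification is needed, so the vector can stay in its read-only state. The sketch below is a minimal illustration, not code from this change; it assumes it is compiled inside the xgboost source tree with the same include paths and link setup as the unit tests, and the file name and main() wrapper are hypothetical.

  // hdv_sketch.cc -- hypothetical standalone example, assumes xgboost headers
  // and the common/host_device_vector implementation are available.
  #include <cstddef>
  #include <iostream>
  #include <vector>
  #include "xgboost/data.h"   // MetaInfo; pulls in HostDeviceVector per this patch

  int main() {
    xgboost::MetaInfo info;

    // Writing goes through HostVector(): the host gets read-write access,
    // so any device copies would be invalidated at this point.
    info.labels_.HostVector()  = {0.0f, 1.0f, 1.0f};
    info.weights_.HostVector() = {1.0f, 2.0f, 0.5f};
    info.num_row_ = info.labels_.Size();

    // Reading goes through ConstHostVector(): read-only access, which is the
    // path the updated metrics/objectives use so they do not force transfers.
    const std::vector<xgboost::bst_float>& labels = info.labels_.ConstHostVector();
    for (std::size_t i = 0; i < labels.size(); ++i) {
      std::cout << "label[" << i << "] = " << labels[i]
                << ", weight = " << info.GetWeight(i) << "\n";
    }
    return 0;
  }

The test functions InitHostDeviceVector/CheckDevice/CheckHost in the patch assert exactly this state machine: after HostVector() only the host can read and write, while tcbegin()/tcend() on a device keeps that device's access read-only.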