From 288c52596c9a68ba005d6f0664cb477b8933d88e Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 29 Apr 2022 19:41:39 +0800 Subject: [PATCH] Define bin type. (#7850) --- include/xgboost/base.h | 2 ++ src/common/column_matrix.h | 10 +++++----- src/common/hist_util.h | 24 ++++++++++-------------- src/data/gradient_index.h | 12 ++++++++---- 4 files changed, 25 insertions(+), 23 deletions(-) diff --git a/include/xgboost/base.h b/include/xgboost/base.h index 4ea2de31a436..04cb6ddb7609 100644 --- a/include/xgboost/base.h +++ b/include/xgboost/base.h @@ -121,6 +121,8 @@ using bst_float = float; // NOLINT using bst_cat_t = int32_t; // NOLINT /*! \brief Type for data column (feature) index. */ using bst_feature_t = uint32_t; // NOLINT +/*! \brief Type for histogram bin index. */ +using bst_bin_t = int32_t; // NOLINT /*! \brief Type for data row index. * * Be careful `std::size_t' is implementation-defined. Meaning that the binary diff --git a/src/common/column_matrix.h b/src/common/column_matrix.h index d289db05e279..57e602114453 100644 --- a/src/common/column_matrix.h +++ b/src/common/column_matrix.h @@ -34,8 +34,8 @@ class Column { public: static constexpr int32_t kMissingId = -1; - Column(ColumnType type, common::Span index, const uint32_t index_base) - : type_(type), index_(index), index_base_(index_base) {} + Column(ColumnType type, common::Span index, const bst_bin_t index_base) + : type_(type), index_(index), index_base_{index_base} {} virtual ~Column() = default; @@ -60,19 +60,19 @@ class Column { /* bin indexes in range [0, max_bins - 1] */ common::Span index_; /* bin index offset for specific feature */ - const uint32_t index_base_; + bst_bin_t const index_base_; }; template class SparseColumn : public Column { public: - SparseColumn(ColumnType type, common::Span index, uint32_t index_base, + SparseColumn(ColumnType type, common::Span index, bst_bin_t index_base, common::Span row_ind) : Column(type, index, index_base), row_ind_(row_ind) {} const size_t* GetRowData() const { return row_ind_.data(); } - int32_t GetBinIdx(size_t rid, size_t* state) const { + bst_bin_t GetBinIdx(size_t rid, size_t* state) const { const size_t column_size = this->Size(); if (!((*state) < column_size)) { return this->kMissingId; diff --git a/src/common/hist_util.h b/src/common/hist_util.h index 442bddfcdaab..fad082c2c596 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -40,8 +40,6 @@ class HistogramCuts { float max_cat_{-1.0f}; protected: - using BinIdx = uint32_t; - void Swap(HistogramCuts&& that) noexcept(true) { std::swap(cut_values_, that.cut_values_); std::swap(cut_ptrs_, that.cut_ptrs_); @@ -110,31 +108,31 @@ class HistogramCuts { // Return the index of a cut point that is strictly greater than the input // value, or the last available index if none exists - BinIdx SearchBin(float value, bst_feature_t column_id, std::vector const& ptrs, - std::vector const& values) const { + bst_bin_t SearchBin(float value, bst_feature_t column_id, std::vector const& ptrs, + std::vector const& values) const { auto end = ptrs[column_id + 1]; auto beg = ptrs[column_id]; auto it = std::upper_bound(values.cbegin() + beg, values.cbegin() + end, value); - BinIdx idx = it - values.cbegin(); + bst_bin_t idx = it - values.cbegin(); idx -= !!(idx == end); return idx; } - BinIdx SearchBin(float value, bst_feature_t column_id) const { + bst_bin_t SearchBin(float value, bst_feature_t column_id) const { return this->SearchBin(value, column_id, Ptrs(), Values()); } /** * \brief Search the bin index for numerical feature. */ - BinIdx SearchBin(Entry const& e) const { + bst_bin_t SearchBin(Entry const& e) const { return SearchBin(e.fvalue, e.index); } /** * \brief Search the bin index for categorical feature. */ - BinIdx SearchCatBin(Entry const &e) const { + bst_bin_t SearchCatBin(Entry const &e) const { auto const &ptrs = this->Ptrs(); auto const &vals = this->Values(); auto end = ptrs.at(e.index + 1) + vals.cbegin(); @@ -296,10 +294,10 @@ struct Index { }; template -int32_t XGBOOST_HOST_DEV_INLINE BinarySearchBin(size_t begin, size_t end, - GradientIndex const &data, - uint32_t const fidx_begin, - uint32_t const fidx_end) { +bst_bin_t XGBOOST_HOST_DEV_INLINE BinarySearchBin(size_t begin, size_t end, + GradientIndex const& data, + uint32_t const fidx_begin, + uint32_t const fidx_end) { size_t previous_middle = std::numeric_limits::max(); while (end != begin) { size_t middle = begin + (end - begin) / 2; @@ -324,8 +322,6 @@ int32_t XGBOOST_HOST_DEV_INLINE BinarySearchBin(size_t begin, size_t end, return -1; } -class ColumnMatrix; - template using GHistRow = Span >; diff --git a/src/data/gradient_index.h b/src/data/gradient_index.h index 5a41d7b2a52c..3d179e0fd3ad 100644 --- a/src/data/gradient_index.h +++ b/src/data/gradient_index.h @@ -14,6 +14,10 @@ #include "xgboost/data.h" namespace xgboost { +namespace common { +class ColumnMatrix; +} // namespace common + /*! * \brief preprocessed global index matrix, in CSR format * @@ -80,13 +84,13 @@ class GHistIndexMatrix { for (bst_uint j = 0; j < inst.size(); ++j) { auto e = inst[j]; if (common::IsCat(ft, e.index)) { - auto bin_idx = cut.SearchCatBin(e); + bst_bin_t bin_idx = cut.SearchCatBin(e); index_data[ibegin + j] = get_offset(bin_idx, j); ++hit_count_tloc_[tid * nbins + bin_idx]; } else { - uint32_t idx = cut.SearchBin(e.fvalue, e.index, ptrs, values); - index_data[ibegin + j] = get_offset(idx, j); - ++hit_count_tloc_[tid * nbins + idx]; + bst_bin_t bin_idx = cut.SearchBin(e.fvalue, e.index, ptrs, values); + index_data[ibegin + j] = get_offset(bin_idx, j); + ++hit_count_tloc_[tid * nbins + bin_idx]; } } });