Skip to content

Commit

Permalink
double to float
Browse files Browse the repository at this point in the history
  • Loading branch information
SHVETS, KIRILL committed Apr 16, 2020
1 parent a8e9cf0 commit 0438212
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 18 deletions.
12 changes: 6 additions & 6 deletions src/common/hist_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -831,14 +831,14 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat,
* \brief fill a histogram by zeros in range [begin, end)
*/
void InitilizeHistByZeroes(GHistRow hist, size_t begin, size_t end) {
memset(hist.data() + begin, '\0', (end-begin)*sizeof(tree::GradStats));
memset(hist.data() + begin, '\0', (end-begin)*sizeof(GradStatHist));
}

/*!
* \brief Increment hist as dst += add in range [begin, end)
*/
void IncrementHist(GHistRow dst, const GHistRow add, size_t begin, size_t end) {
using FPType = decltype(tree::GradStats::sum_grad);
using FPType = decltype(GradStatHist::sum_grad);
FPType* pdst = reinterpret_cast<FPType*>(dst.data());
const FPType* padd = reinterpret_cast<const FPType*>(add.data());

Expand All @@ -851,7 +851,7 @@ void IncrementHist(GHistRow dst, const GHistRow add, size_t begin, size_t end) {
* \brief Copy hist from src to dst in range [begin, end)
*/
void CopyHist(GHistRow dst, const GHistRow src, size_t begin, size_t end) {
using FPType = decltype(tree::GradStats::sum_grad);
using FPType = decltype(GradStatHist::sum_grad);
FPType* pdst = reinterpret_cast<FPType*>(dst.data());
const FPType* psrc = reinterpret_cast<const FPType*>(src.data());

Expand All @@ -865,7 +865,7 @@ void CopyHist(GHistRow dst, const GHistRow src, size_t begin, size_t end) {
*/
void SubtractionHist(GHistRow dst, const GHistRow src1, const GHistRow src2,
size_t begin, size_t end) {
using FPType = decltype(tree::GradStats::sum_grad);
using FPType = decltype(GradStatHist::sum_grad);
FPType* pdst = reinterpret_cast<FPType*>(dst.data());
const FPType* psrc1 = reinterpret_cast<const FPType*>(src1.data());
const FPType* psrc2 = reinterpret_cast<const FPType*>(src2.data());
Expand Down Expand Up @@ -1023,7 +1023,7 @@ void GHistBuilder::BuildHist(const std::vector<GradientPair>& gpair,
const GHistIndexMatrix& gmat,
GHistRow hist,
bool isDense) {
using FPType = decltype(tree::GradStats::sum_grad);
using FPType = decltype(GradStatHist::sum_grad);
const size_t nrows = row_indices.Size();
const size_t no_prefetch_size = Prefetch::NoPrefetchSize(nrows);

Expand Down Expand Up @@ -1054,7 +1054,7 @@ void GHistBuilder::BuildBlockHist(const std::vector<GradientPair>& gpair,
#if defined(_OPENMP)
const auto nthread = static_cast<bst_omp_uint>(this->nthread_); // NOLINT
#endif // defined(_OPENMP)
tree::GradStats* p_hist = hist.data();
GradStatHist* p_hist = hist.data();

#pragma omp parallel for num_threads(nthread) schedule(guided)
for (bst_omp_uint bid = 0; bid < nblock; ++bid) {
Expand Down
62 changes: 58 additions & 4 deletions src/common/hist_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,61 @@ class GHistIndexBlockMatrix {
* for that particular bin
* Uses global bin id so as to represent all features simultaneously
*/
using GHistRow = Span<tree::GradStats>;
//using GHistRow = Span<tree::GradStats>;

struct GradStatHist {
using GradType = float;
/*! \brief sum gradient statistics */
GradType sum_grad;
/*! \brief sum hessian statistics */
GradType sum_hess;

GradStatHist() : sum_grad{0}, sum_hess{0} {
static_assert(sizeof(GradStatHist) == 8,
"Size of GradStatHist is not 8 bytes.");
}

inline void Add(const GradStatHist& b) {
sum_grad += b.sum_grad;
sum_hess += b.sum_hess;
}

inline void Add(const tree::GradStats& b) {
sum_grad += b.sum_grad;
sum_hess += b.sum_hess;
}

inline void Add(const GradientPair& p) {
this->Add(p.GetGrad(), p.GetHess());
}

inline void Add(const GradType& grad, const GradType& hess) {
sum_grad += grad;
sum_hess += hess;
}

inline tree::GradStats ToGradStat() const {
return tree::GradStats(sum_grad, sum_hess);
}

inline void SetSubstract(const GradStatHist& a, const GradStatHist& b) {
sum_grad = a.sum_grad - b.sum_grad;
sum_hess = a.sum_hess - b.sum_hess;
}

inline void SetSubstract(const tree::GradStats& a, const GradStatHist& b) {
sum_grad = a.sum_grad - b.sum_grad;
sum_hess = a.sum_hess - b.sum_hess;
}

inline GradType GetGrad() const { return sum_grad; }
inline GradType GetHess() const { return sum_hess; }
inline static void Reduce(GradStatHist& a, const GradStatHist& b) { // NOLINT(*)
a.Add(b);
}
};

using GHistRow = Span<GradStatHist>;

/*!
* \brief fill a histogram by zeros
Expand Down Expand Up @@ -439,8 +493,8 @@ class HistCollection {
GHistRow operator[](bst_uint nid) const {
constexpr uint32_t kMax = std::numeric_limits<uint32_t>::max();
CHECK_NE(row_ptr_[nid], kMax);
tree::GradStats* ptr =
const_cast<tree::GradStats*>(dmlc::BeginPtr(data_) + row_ptr_[nid]);
GradStatHist* ptr =
const_cast<GradStatHist*>(dmlc::BeginPtr(data_) + row_ptr_[nid]);
return {ptr, nbins_};
}

Expand Down Expand Up @@ -483,7 +537,7 @@ class HistCollection {
/*! \brief amount of active nodes in hist collection */
uint32_t n_nodes_added_ = 0;

std::vector<tree::GradStats> data_;
std::vector<GradStatHist> data_;

/*! \brief row_ptr_[nid] locates bin for histogram of node nid */
std::vector<size_t> row_ptr_;
Expand Down
6 changes: 3 additions & 3 deletions src/tree/param.h
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,8 @@ XGBOOST_DEVICE inline float CalcWeight(const TrainingParams &p, GpairT sum_grad)
}

/*! \brief core statistics used for tree construction */
struct XGBOOST_ALIGNAS(8) GradStats {
using GradType = float;
struct XGBOOST_ALIGNAS(16) GradStats {
using GradType = double;
/*! \brief sum gradient statistics */
GradType sum_grad { 0 };
/*! \brief sum hessian statistics */
Expand All @@ -334,7 +334,7 @@ struct XGBOOST_ALIGNAS(8) GradStats {
}

XGBOOST_DEVICE GradStats() {
static_assert(sizeof(GradStats) == 8,
static_assert(sizeof(GradStats) == 16,
"Size of GradStats is not 16 bytes.");
}

Expand Down
10 changes: 6 additions & 4 deletions src/tree/updater_quantile_hist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1026,23 +1026,25 @@ void QuantileHistMaker::Builder::InitNewNode(int nid,
{
auto& stats = snode_[nid].stats;
GHistRow hist = hist_[nid];
common::GradStatHist grad_stat;
if (tree[nid].IsRoot()) {
if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
const std::vector<uint32_t>& row_ptr = gmat.cut.Ptrs();
const uint32_t ibegin = row_ptr[fid_least_bins_];
const uint32_t iend = row_ptr[fid_least_bins_ + 1];
auto begin = hist.data();
for (uint32_t i = ibegin; i < iend; ++i) {
const GradStats et = begin[i];
stats.Add(et.sum_grad, et.sum_hess);
const common::GradStatHist et = begin[i];
grad_stat.Add(et.sum_grad, et.sum_hess);
}
} else {
const RowSetCollection::Elem e = row_set_collection_[nid];
for (const size_t* it = e.begin; it < e.end; ++it) {
stats.Add(gpair[*it]);
grad_stat.Add(gpair[*it]);
}
}
histred_.Allreduce(&snode_[nid].stats, 1);
histred_.Allreduce(&grad_stat, 1);
snode_[nid].stats = grad_stat.ToGradStat();
} else {
int parent_id = tree[nid].Parent();
if (tree[nid].IsLeftChild()) {
Expand Down
2 changes: 1 addition & 1 deletion src/tree/updater_quantile_hist.h
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ class QuantileHistMaker: public TreeUpdater {

common::Monitor builder_monitor_;
common::ParallelGHistBuilder hist_buffer_;
rabit::Reducer<GradStats, GradStats::Reduce> histred_;
rabit::Reducer<common::GradStatHist, common::GradStatHist::Reduce> histred_;
};

std::unique_ptr<Builder> builder_;
Expand Down

0 comments on commit 0438212

Please sign in to comment.