Skip to content

Commit

Permalink
Cleanup.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Dec 8, 2021
1 parent d875445 commit c089efc
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 206 deletions.
77 changes: 0 additions & 77 deletions src/common/quantile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -340,18 +340,6 @@ void SketchContainerImpl<WQSketch>::AllReduce(
monitor_.Stop(__func__);
}

// void AddCutPoint(WQuantileSketch<float, float>::SummaryContainer const &summary, int max_bin,
// HistogramCuts *cuts) {
// size_t required_cuts = std::min(summary.size, static_cast<size_t>(max_bin));
// auto &cut_values = cuts->cut_values_.HostVector();
// for (size_t i = 1; i < required_cuts; ++i) {
// bst_float cpt = summary.data[i].value;
// if (i == 1 || cpt > cut_values.back()) {
// cut_values.push_back(cpt);
// }
// }
// }

template <typename SketchType>
void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_bin,
HistogramCuts *cuts) {
Expand Down Expand Up @@ -466,70 +454,5 @@ void SortedSketchContainer::PushColPage(SparsePage const &page, MetaInfo const &
});
monitor_.Stop(__func__);
}

// void SortedSketchContainer::MakeCuts(HistogramCuts* cuts) {
// monitor_.Start(__func__);
// std::vector<WXQSketch::SummaryContainer> reduced;
// std::vector<int32_t> num_cuts;
// this->AllReduce(&reduced, &num_cuts);

// cuts->min_vals_.HostVector().resize(sketches_.size(), 0.0f);
// std::vector<WXQSketch::SummaryContainer> final_summaries(reduced.size());

// ParallelFor(reduced.size(), n_threads_, Sched::Guided(), [&](size_t fidx) {
// if (IsCat(feature_types_, fidx)) {
// return;
// }
// WXQSketch::SummaryContainer &a = final_summaries[fidx];
// size_t max_num_bins = std::min(num_cuts[fidx], max_bins_);
// a.Reserve(max_num_bins + 1);
// CHECK(a.data);
// if (num_cuts[fidx] != 0) {
// a.SetPrune(reduced[fidx], max_num_bins + 1);
// CHECK(a.data && reduced[fidx].data);
// const bst_float mval = a.data[0].value;
// cuts->min_vals_.HostVector()[fidx] = mval - fabs(mval) - 1e-5f;
// } else {
// // Empty column.
// const float mval = 1e-5f;
// cuts->min_vals_.HostVector()[fidx] = mval;
// }
// });

// for (size_t fid = 0; fid < reduced.size(); ++fid) {
// size_t max_num_bins = std::min(num_cuts[fid], max_bins_);
// WXQSketch::SummaryContainer const& a = final_summaries[fid];
// if (IsCat(feature_types_, fid)) {
// // AddCategories(categories_.at(fid), cuts);
// } else {
// AddCutPoint<WXQSketch>(a, max_num_bins, cuts);
// // push a value that is greater than anything
// const bst_float cpt =
// (a.size > 0) ? a.data[a.size - 1].value : cuts->min_vals_.HostVector()[fid];
// // this must be bigger than last value in a scale
// const bst_float last = cpt + (fabs(cpt) + kRtEps);
// cuts->cut_values_.HostVector().push_back(last);
// }

// // Ensure that every feature gets at least one quantile point
// CHECK_LE(cuts->cut_values_.HostVector().size(), std::numeric_limits<uint32_t>::max());
// auto cut_size = static_cast<uint32_t>(cuts->cut_values_.HostVector().size());
// CHECK_GT(cut_size, cuts->cut_ptrs_.HostVector().back());
// cuts->cut_ptrs_.HostVector().push_back(cut_size);
// }

// // std::cout << "new cuts" << std::endl;
// // for (size_t i = 0; i < cuts->Ptrs().size(); ++i) {
// // auto beg = cuts->Ptrs()[i - 1];
// // auto end = cuts->Ptrs()[i];
// // for (size_t j = beg; j < end; ++j) {
// // std::cout << cuts->Values()[j] << ", ";
// // }
// // std::cout << std::endl;
// // }
// // std::cout << std::endl;

// monitor_.Stop(__func__);
// }
} // namespace common
} // namespace xgboost
131 changes: 2 additions & 129 deletions src/tree/updater_histmaker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -487,18 +487,8 @@ class CQHistMaker: public HistMaker {
this->wspace_.cut.push_back(0.0f);
this->wspace_.rptr.push_back(static_cast<unsigned>(this->wspace_.cut.size()));
}
CHECK_EQ(this->wspace_.rptr.size(), (fset.size() + 1) * this->qexpand_.size() + 1);

std::cout << "Maker cuts" << std::endl;
for (size_t i = 0; i < this->wspace_.rptr.size(); ++i) {
auto beg = this->wspace_.rptr[i - 1];
auto end = this->wspace_.rptr[i];
for (size_t j = beg; j < end; ++j) {
std::cout << this->wspace_.cut[j] << ", ";
}
std::cout << std::endl;
}
std::cout << std::endl;
CHECK_EQ(this->wspace_.rptr.size(),
(fset.size() + 1) * this->qexpand_.size() + 1);
}

inline void UpdateHistCol(const std::vector<GradientPair> &gpair,
Expand Down Expand Up @@ -585,7 +575,6 @@ class CQHistMaker: public HistMaker {
}
// two pass scan
unsigned max_size = this->param_.MaxSketchSize();
std::cout << "max_size:" << max_size << std::endl;
for (int const nid : this->qexpand_) {
sbuilder[nid].Init(max_size);
}
Expand Down Expand Up @@ -652,126 +641,10 @@ class CQHistMaker: public HistMaker {
std::vector<common::WXQuantileSketch<bst_float, bst_float> > sketchs_;
};

// global proposal
class GlobalProposalHistMaker: public CQHistMaker {
public:
char const* Name() const override {
return "grow_global_histmaker";
}

protected:
void ResetPosAndPropose(const std::vector<GradientPair> &gpair,
DMatrix *p_fmat,
const std::vector<bst_feature_t> &fset,
const RegTree &tree) override {
if (this->qexpand_.size() == 1) {
cached_rptr_.clear();
cached_cut_.clear();
}
if (cached_rptr_.size() == 0) {
CHECK_EQ(this->qexpand_.size(), 1U);
CQHistMaker::ResetPosAndPropose(gpair, p_fmat, fset, tree);
cached_rptr_ = this->wspace_.rptr;
cached_cut_ = this->wspace_.cut;
} else {
this->wspace_.cut.clear();
this->wspace_.rptr.clear();
this->wspace_.rptr.push_back(0);
for (size_t i = 0; i < this->qexpand_.size(); ++i) {
for (size_t j = 0; j < cached_rptr_.size() - 1; ++j) {
this->wspace_.rptr.push_back(
this->wspace_.rptr.back() + cached_rptr_[j + 1] - cached_rptr_[j]);
}
this->wspace_.cut.insert(this->wspace_.cut.end(), cached_cut_.begin(), cached_cut_.end());
}
CHECK_EQ(this->wspace_.rptr.size(),
(fset.size() + 1) * this->qexpand_.size() + 1);
CHECK_EQ(this->wspace_.rptr.back(), this->wspace_.cut.size());
}
}

// code to create histogram
void CreateHist(const std::vector<GradientPair> &gpair,
DMatrix *p_fmat,
const std::vector<bst_feature_t> &fset,
const RegTree &tree) override {
const MetaInfo &info = p_fmat->Info();
// fill in reverse map
this->feat2workindex_.resize(tree.param.num_feature);
this->work_set_ = fset;
std::fill(this->feat2workindex_.begin(), this->feat2workindex_.end(), -1);
for (size_t i = 0; i < fset.size(); ++i) {
this->feat2workindex_[fset[i]] = static_cast<int>(i);
}
// start to work
this->wspace_.Configure(1);
// to gain speedup in recovery
{
this->thread_hist_.resize(omp_get_max_threads());

// TWOPASS: use the real set + split set in the column iteration.
this->SetDefaultPostion(p_fmat, tree);
this->work_set_.insert(this->work_set_.end(), this->fsplit_set_.begin(),
this->fsplit_set_.end());
XGBOOST_PARALLEL_SORT(this->work_set_.begin(), this->work_set_.end(),
std::less<>{});
this->work_set_.resize(
std::unique(this->work_set_.begin(), this->work_set_.end()) - this->work_set_.begin());

// start accumulating statistics
for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
// TWOPASS: use the real set + split set in the column iteration.
this->CorrectNonDefaultPositionByBatch(batch, this->fsplit_set_, tree);
auto page = batch.GetView();

// start enumeration
const auto nsize = static_cast<bst_omp_uint>(this->work_set_.size());
dmlc::OMPException exc;
#pragma omp parallel for schedule(dynamic, 1)
for (bst_omp_uint i = 0; i < nsize; ++i) {
exc.Run([&]() {
int fid = this->work_set_[i];
int offset = this->feat2workindex_[fid];
if (offset >= 0) {
this->UpdateHistCol(gpair, page[fid], info, tree,
fset, offset,
&this->thread_hist_[omp_get_thread_num()]);
}
});
}
exc.Rethrow();
}

// update node statistics.
this->GetNodeStats(gpair, *p_fmat, tree,
&(this->thread_stats_), &(this->node_stats_));
for (const int nid : this->qexpand_) {
const int wid = this->node2workindex_[nid];
this->wspace_.hset[0][fset.size() + wid * (fset.size()+1)]
.data[0] = this->node_stats_[nid];
}
}
this->histred_.Allreduce(dmlc::BeginPtr(this->wspace_.hset[0].data),
this->wspace_.hset[0].data.size());
}

// cached unit pointer
std::vector<unsigned> cached_rptr_;
// cached cut value.
std::vector<bst_float> cached_cut_;
};

XGBOOST_REGISTER_TREE_UPDATER(LocalHistMaker, "grow_local_histmaker")
.describe("Tree constructor that uses approximate histogram construction.")
.set_body([](ObjInfo) {
return new CQHistMaker();
});

// The updater for approx tree method.
XGBOOST_REGISTER_TREE_UPDATER(HistMaker, "grow_global_histmaker")
.describe("Tree constructor that uses approximate global of histogram construction.")
.set_body([](ObjInfo) {
return new GlobalProposalHistMaker();
});
} // namespace tree
} // namespace xgboost

0 comments on commit c089efc

Please sign in to comment.