Fix test failure
RAMitchell committed Dec 10, 2018
1 parent ae0f5c9 commit 66be465
Showing 8 changed files with 44 additions and 32 deletions.
12 changes: 9 additions & 3 deletions include/xgboost/tree_model.h
@@ -240,8 +240,6 @@ class RegressionTree {
     ++param.num_deleted;
   }
 
-  bst_float FillNodeMeanValue(int nid);
-
  public:
   /*! \brief model parameter */
   TreeParam param;
@@ -382,7 +380,15 @@ class RegressionTree {
   /*!
   * \brief Get the mean value for node, required for feature contributions
   */
-  float GetNodeMeanValue(int nid);
+  float GetNodeMeanValue(int nid) const;
+
+  /**
+   * \brief Generate node mean values lazily.
+   * \param nid The nid.
+   * \return A bst_float.
+   */
+  bst_float FillNodeMeanValue(int nid = 0);
+
   /*!
   * \brief calculate the feature contributions (https://arxiv.org/abs/1706.06060) for the tree
   * \param feat dense feature vector, if the feature is missing the field is set to NaN
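Note: taken together, these two hunks turn GetNodeMeanValue into a pure const lookup and promote FillNodeMeanValue to the public interface as the one place where the mean-value cache gets built. A minimal self-contained sketch of that contract — fill once, then read without side effects — using a toy stand-in rather than the real RegressionTree:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Toy stand-in mirroring the declarations above, not the real class.
    class ToyTree {
     public:
      // const accessor: a pure read, no hidden resize or lazy fill.
      float GetNodeMeanValue(int nid) const {
        assert(static_cast<std::size_t>(nid) < node_mean_values_.size());
        return node_mean_values_[nid];
      }
      // idempotent fill: early-outs once the cache is already populated.
      float FillNodeMeanValue(int nid = 0) {
        if (node_mean_values_.size() != kNumNodes) {
          node_mean_values_.assign(kNumNodes, 0.5f);  // real code recurses
        }
        return node_mean_values_[nid];
      }
     private:
      static constexpr std::size_t kNumNodes = 3;
      std::vector<float> node_mean_values_;
    };

    int main() {
      ToyTree tree;
      tree.FillNodeMeanValue();  // build the cache once up front
      return tree.GetNodeMeanValue(0) == 0.5f ? 0 : 1;  // then read it const-ly
    }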
4 changes: 4 additions & 0 deletions src/predictor/cpu_predictor.cc
@@ -270,6 +270,10 @@ class CPUPredictor : public Predictor {
       // allocated one
       std::fill(contribs.begin(), contribs.end(), 0);
     // initialize tree node mean values
+    #pragma omp parallel for schedule(static)
+    for (bst_omp_uint i = 0; i < ntree_limit; ++i) {
+      model.trees[i]->FillNodeMeanValue();
+    }
     const std::vector<bst_float>& base_margin = info.base_margin_.HostVector();
     // start collecting the contributions
     for (const auto &batch : p_fmat->GetRowBatches()) {
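Note: FillNodeMeanValue writes to the tree's mean-value cache, so it must not be triggered lazily from inside the parallel per-row contribution loop that follows; this hunk pre-fills every tree up front, one tree per OpenMP iteration, so there are no shared writes and the later const reads are race-free. A self-contained sketch of that two-phase pattern (the Tree type below is a stand-in, not the predictor's real model):

    #include <cstddef>
    #include <vector>

    // Stand-in for the model's trees; only the fill/read split matters here.
    struct Tree {
      std::vector<float> leaf_values;       // per-node values (toy data)
      std::vector<float> node_mean_values;  // cache: written once, read often

      void FillNodeMeanValue() {            // mutating: run before any reads
        node_mean_values = leaf_values;     // real code computes weighted means
      }
      float GetNodeMeanValue(int nid) const {  // pure read, safe to share
        return node_mean_values[static_cast<std::size_t>(nid)];
      }
    };

    int main() {
      std::vector<Tree> trees(4, Tree{{1.f, 2.f, 3.f}, {}});
      // Phase 1: each iteration mutates exactly one tree, so the parallel
      // loop has no shared writes (the pragma is a no-op without -fopenmp).
    #pragma omp parallel for schedule(static)
      for (int i = 0; i < static_cast<int>(trees.size()); ++i) {
        trees[i].FillNodeMeanValue();
      }
      // Phase 2: const reads from any thread are now race-free.
      float sum = 0.f;
      for (const Tree &t : trees) sum += t.GetNodeMeanValue(0);
      return sum == 4.f ? 0 : 1;
    }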
35 changes: 19 additions & 16 deletions src/tree/tree_model.cc
@@ -192,16 +192,17 @@ int RegressionTree::GetLeafIndex(const DenseFeatureVector& feat, unsigned root_id)
   return pid;
 }
 
-float RegressionTree::GetNodeMeanValue(int nid) {
-  size_t num_nodes = this->param.num_nodes;
-  if (this->node_mean_values_.size() != num_nodes) {
-    this->node_mean_values_.resize(num_nodes);
-    this->FillNodeMeanValue(0);
-  }
+float RegressionTree::GetNodeMeanValue(int nid) const {
+  CHECK_LT(nid, node_mean_values_.size());
   return this->node_mean_values_[nid];
 }
 
 bst_float RegressionTree::FillNodeMeanValue(int nid) {
+  size_t num_nodes = this->param.num_nodes;
+  if (this->node_mean_values_.size() == num_nodes) {
+    return this->node_mean_values_[nid];
+  }
+  this->node_mean_values_.resize(num_nodes);
   bst_float result;
   const auto& node = (*this).GetNode(nid);
   if (node.IsLeaf()) {
@@ -215,8 +216,9 @@ bst_float RegressionTree::FillNodeMeanValue(int nid) {
   return result;
 }
 
-void RegressionTree::CalculateContributionsApprox(const DenseFeatureVector& feat, unsigned root_id,
-                                                  bst_float *out_contribs) {
+void RegressionTree::CalculateContributionsApprox(
+    const DenseFeatureVector &feat, unsigned root_id, bst_float *out_contribs) {
+  CHECK_GT(this->node_mean_values_.size(), 0U);
   // this follows the idea of http://blog.datadive.net/interpreting-random-forests/
   unsigned split_index = 0;
   auto pid = static_cast<int>(root_id);
@@ -407,6 +409,7 @@ void RegressionTree::CalculateContributions(const DenseFeatureVector& feat, unsigned root_id,
                                             bst_float *out_contribs,
                                             int condition,
                                             unsigned condition_feature) {
+  CHECK_GT(this->node_mean_values_.size(), 0U);
   // find the expected value of the tree's predictions
   if (condition == 0) {
     bst_float node_value = this->GetNodeMeanValue(root_id);
@@ -429,17 +432,17 @@ void DenseFeatureVector::Init(size_t size) {
   std::fill(data_.begin(), data_.end(), e);
 }
 
-void DenseFeatureVector::Fill(const SparsePage::Inst& inst) {
-  for (bst_uint i = 0; i < inst.size(); ++i) {
-    if (inst[i].index >= data_.size()) continue;
-    data_[inst[i].index].fvalue = inst[i].fvalue;
+void DenseFeatureVector::Fill(const SparsePage::Inst &inst) {
+  for (const auto &elem : inst) {
+    if (elem.index >= data_.size()) continue;
+    data_[elem.index].fvalue = elem.fvalue;
   }
 }
 
-void DenseFeatureVector::Drop(const SparsePage::Inst& inst) {
-  for (bst_uint i = 0; i < inst.size(); ++i) {
-    if (inst[i].index >= data_.size()) continue;
-    data_[inst[i].index].flag = -1;
+void DenseFeatureVector::Drop(const SparsePage::Inst &inst) {
+  for (const auto &elem : inst) {
+    if (elem.index >= data_.size()) continue;
+    data_[elem.index].flag = -1;
   }
 }
 
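Note: the recursive body of FillNodeMeanValue is mostly elided above; judging from the IsLeaf() branch that is visible and the usual reading of such trees, a node's cached mean is its leaf value, or the cover-weighted (sum-of-hessians) average of its children's means. A self-contained sketch under that assumption, with stand-in types rather than the repository's:

    #include <cstddef>
    #include <vector>

    struct ToyNode {
      bool is_leaf;
      float leaf_value;  // meaningful only when is_leaf
      float sum_hess;    // "cover" accumulated during training
      int left, right;   // child indices, -1 for leaves
    };

    // Post-order fill: children first, then the parent's weighted average.
    float FillMean(const std::vector<ToyNode> &nodes,
                   std::vector<float> &means, int nid) {
      const ToyNode &n = nodes[static_cast<std::size_t>(nid)];
      float result;
      if (n.is_leaf) {
        result = n.leaf_value;
      } else {
        result = FillMean(nodes, means, n.left) * nodes[n.left].sum_hess +
                 FillMean(nodes, means, n.right) * nodes[n.right].sum_hess;
        result /= n.sum_hess;  // children's cover sums to the parent's
      }
      means[static_cast<std::size_t>(nid)] = result;
      return result;
    }

    int main() {
      // Root (0) splits into leaves 1 and 2 with covers 2.0 and 6.0.
      std::vector<ToyNode> nodes = {{false, 0.f, 8.f, 1, 2},
                                    {true, -1.f, 2.f, -1, -1},
                                    {true, 0.5f, 6.f, -1, -1}};
      std::vector<float> means(nodes.size(), 0.f);
      FillMean(nodes, means, 0);
      // Root mean = (-1 * 2 + 0.5 * 6) / 8 = 0.125.
      return means[0] == 0.125f ? 0 : 1;
    }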
2 changes: 1 addition & 1 deletion src/tree/updater_basemaker-inl.h
@@ -41,7 +41,7 @@ class BaseMaker: public TreeUpdater {
   /*! \brief find type of each feature, use column format */
   inline void InitByCol(DMatrix* p_fmat,
                         const RegressionTree& tree) {
-    fminmax_.resize(tree.param.num_feature * 2);
+    fminmax_.resize(p_fmat->Info().num_col_ * 2);
     std::fill(fminmax_.begin(), fminmax_.end(),
               -std::numeric_limits<bst_float>::max());
     // start accumulating statistics
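Note: this substitution appears to be what gives the commit its title — work buffers were sized from the tree's recorded num_feature, which is unset (zero) on a freshly constructed tree such as the one in the test at the bottom of this page, whereas the DMatrix always knows its true column count. The same swap recurs in updater_histmaker.cc, updater_refresh.cc, and updater_skmaker.cc below. A self-contained sketch of the failure mode being avoided (types are illustrative stand-ins):

    #include <cstddef>
    #include <vector>

    struct ToyTreeParam { unsigned num_feature = 0; };  // fresh tree: unset
    struct ToyInfo { std::size_t num_col_ = 10; };      // data's real width

    int main() {
      ToyTreeParam param;
      ToyInfo info;
      (void)param;  // kept only to mirror the old code path below
      std::vector<float> fminmax;
      // Before: fminmax.resize(param.num_feature * 2);  // size 0 -> later
      // accesses indexed by a real feature id would run out of bounds.
      fminmax.resize(info.num_col_ * 2);  // after: sized from the data
      return fminmax.size() == 20 ? 0 : 1;
    }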
8 changes: 4 additions & 4 deletions src/tree/updater_histmaker.cc
@@ -160,7 +160,7 @@ class HistMaker: public BaseMaker {
   virtual void InitWorkSet(DMatrix *p_fmat,
                            const RegressionTree &tree,
                            std::vector<bst_uint> *p_fset) {
-    p_fset->resize(tree.param.num_feature);
+    p_fset->resize(p_fmat->Info().num_col_);
     for (size_t i = 0; i < p_fset->size(); ++i) {
       (*p_fset)[i] = static_cast<unsigned>(i);
     }
@@ -327,7 +327,7 @@ class CQHistMaker: public HistMaker<TStats> {
                  const RegressionTree &tree) override {
     const MetaInfo &info = p_fmat->Info();
     // fill in reverse map
-    feat2workindex_.resize(tree.param.num_feature);
+    feat2workindex_.resize(p_fmat->Info().num_col_);
     std::fill(feat2workindex_.begin(), feat2workindex_.end(), -1);
     for (size_t i = 0; i < fset.size(); ++i) {
       feat2workindex_[fset[i]] = static_cast<int>(i);
@@ -386,7 +386,7 @@ class CQHistMaker: public HistMaker<TStats> {
                  const RegressionTree &tree) override {
     const MetaInfo &info = p_fmat->Info();
     // fill in reverse map
-    feat2workindex_.resize(tree.param.num_feature);
+    feat2workindex_.resize(p_fmat->Info().num_col_);
     std::fill(feat2workindex_.begin(), feat2workindex_.end(), -1);
     work_set_.clear();
     for (auto fidx : fset) {
@@ -685,7 +685,7 @@ class GlobalProposalHistMaker: public CQHistMaker<TStats> {
                  const RegressionTree &tree) override {
     const MetaInfo &info = p_fmat->Info();
     // fill in reverse map
-    this->feat2workindex_.resize(tree.param.num_feature);
+    this->feat2workindex_.resize(p_fmat->Info().num_col_);
     this->work_set_ = fset;
     std::fill(this->feat2workindex_.begin(), this->feat2workindex_.end(), -1);
     for (size_t i = 0; i < fset.size(); ++i) {
2 changes: 1 addition & 1 deletion src/tree/updater_refresh.cc
@@ -48,7 +48,7 @@ class TreeRefresher: public TreeUpdater {
     }
     stemp[tid].resize(num_nodes, TStats(param_));
     std::fill(stemp[tid].begin(), stemp[tid].end(), TStats(param_));
-    fvec_temp[tid].Init(trees[0]->param.num_feature);
+    fvec_temp[tid].Init(p_fmat->Info().num_col_);
   }
   // if it is C++11, use lazy evaluation for Allreduce,
   // to gain speedup in recovery
12 changes: 6 additions & 6 deletions src/tree/updater_skmaker.cc
@@ -133,7 +133,7 @@ class SketchMaker: public BaseMaker {
                DMatrix *p_fmat,
                const RegressionTree &tree) {
     const MetaInfo& info = p_fmat->Info();
-    sketchs_.resize(this->qexpand_.size() * tree.param.num_feature * 3);
+    sketchs_.resize(this->qexpand_.size() * p_fmat->Info().num_col_ * 3);
     for (auto & sketch : sketchs_) {
       sketch.Init(info.num_row_, this->param_.sketch_eps);
     }
@@ -146,7 +146,7 @@ class SketchMaker: public BaseMaker {
       const auto nsize = static_cast<bst_omp_uint>(batch.Size());
       #pragma omp parallel for schedule(dynamic, 1)
       for (bst_omp_uint fidx = 0; fidx < nsize; ++fidx) {
-        this->UpdateSketchCol(gpair, batch[fidx], tree,
+        this->UpdateSketchCol(p_fmat, gpair, batch[fidx], tree,
                               node_stats_,
                               fidx,
                               batch[fidx].size() == nrows,
@@ -167,7 +167,7 @@ class SketchMaker: public BaseMaker {
     sketch_reducer_.Allreduce(dmlc::BeginPtr(summary_array_), nbytes, summary_array_.size());
   }
   // update sketch information in column fid
-  inline void UpdateSketchCol(const std::vector<GradientPair> &gpair,
+  inline void UpdateSketchCol(DMatrix *p_fmat, const std::vector<GradientPair> &gpair,
                               const SparsePage::Inst &col,
                               const RegressionTree &tree,
                               const std::vector<SKStats> &nstats,
@@ -182,7 +182,7 @@ class SketchMaker: public BaseMaker {
       const unsigned wid = this->node2workindex_[nid];
       for (int k = 0; k < 3; ++k) {
         sbuilder[3 * nid + k].sum_total = 0.0f;
-        sbuilder[3 * nid + k].sketch = &sketchs_[(wid * tree.param.num_feature + fid) * 3 + k];
+        sbuilder[3 * nid + k].sketch = &sketchs_[(wid * p_fmat->Info().num_col_ + fid) * 3 + k];
       }
     }
     if (!col_full) {
@@ -259,7 +259,7 @@ class SketchMaker: public BaseMaker {
                  const std::vector<GradientPair> &gpair,
                  DMatrix *p_fmat,
                  RegressionTree *p_tree) {
-    const bst_uint num_feature = p_tree->param.num_feature;
+    const bst_uint num_feature = p_fmat->Info().num_col_;
     // get the best split condition for each node
     std::vector<SplitEntry> sol(qexpand_.size());
     auto nexpand = static_cast<bst_omp_uint>(qexpand_.size());
@@ -269,7 +269,7 @@ class SketchMaker: public BaseMaker {
       CHECK_EQ(node2workindex_[nid], static_cast<int>(wid));
       SplitEntry &best = sol[wid];
       for (bst_uint fid = 0; fid < num_feature; ++fid) {
-        unsigned base = (wid * p_tree->param.num_feature + fid) * 3;
+        unsigned base = (wid * p_fmat->Info().num_col_ + fid) * 3;
         EnumerateSplit(summary_array_[base + 0],
                        summary_array_[base + 1],
                        summary_array_[base + 2],
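Note: sketchs_ is laid out as [work unit][feature][three statistics], so the resize in the -133 hunk and every index computed in UpdateSketchCol and the split enumeration must agree on one feature count; threading p_fmat through UpdateSketchCol is what lets all of them read it from the DMatrix. A tiny sketch of the index math (the helper name is ours, not the repository's):

    #include <cstddef>

    // k in {0, 1, 2} selects one of the three per-feature statistics.
    inline std::size_t SketchIndex(std::size_t wid, std::size_t fid, int k,
                                   std::size_t num_col) {
      return (wid * num_col + fid) * 3 + static_cast<std::size_t>(k);
    }

    int main() {
      // With 10 columns: work unit 2, feature 4, stat 1 -> (2 * 10 + 4) * 3 + 1.
      return SketchIndex(2, 4, 1, 10) == 73 ? 0 : 1;
    }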
1 change: 0 additions & 1 deletion tests/cpp/tree/test_refresh.cc
@@ -25,7 +25,6 @@ TEST(Updater, Refresh) {
                                  {"reg_lambda", "1"}};
 
   RegressionTree tree = RegressionTree();
-  tree.param.InitAllowUnknown(cfg);
   std::vector<RegressionTree*> trees {&tree};
   std::unique_ptr<TreeUpdater> refresher(TreeUpdater::Create("refresh"));
 

0 comments on commit 66be465

Please sign in to comment.