From 7115988465367a86b5e4f8d5eae34ea0f0143027 Mon Sep 17 00:00:00 2001
From: marugari
Date: Sun, 15 May 2016 20:21:20 +0900
Subject: [PATCH] add Dart booster.

---
 doc/parameter.md                  |  25 ++-
 src/gbm/gbtree.cc                 | 297 ++++++++++++++++++++++++++++++
 tests/python/test_basic_models.py |  26 +++
 3 files changed, 347 insertions(+), 1 deletion(-)

diff --git a/doc/parameter.md b/doc/parameter.md
index 3ca79077bc85..ea688498ae4e 100644
--- a/doc/parameter.md
+++ b/doc/parameter.md
@@ -13,7 +13,8 @@ In R-package, you can use .(dot) to replace under score in the parameters, for e
 General Parameters
 ------------------
 * booster [default=gbtree]
-  - which booster to use, can be gbtree or gblinear. gbtree uses tree based model while gblinear uses linear function.
+  - which booster to use, can be gbtree, gblinear or dart.
+    gbtree and dart use tree based models while gblinear uses a linear function.
 * silent [default=0]
   - 0 means printing running messages, 1 means silent mode.
 * nthread [default to maximum number of threads available if not set]
@@ -74,6 +75,28 @@ Parameters for Tree Booster
 * scale_pos_weight, [default=0]
   - Control the balance of positive and negative weights, useful for unbalanced classes. A typical value to consider: sum(negative cases) / sum(positive cases) See [Parameters Tuning](param_tuning.md) for more discussion. Also see Higgs Kaggle competition demo for examples: [R](../demo/kaggle-higgs/higgs-train.R ), [py1](../demo/kaggle-higgs/higgs-numpy.py ), [py2](../demo/kaggle-higgs/higgs-cv.py ), [py3](../demo/guide-python/cross_validation.py)
 
+Additional parameters for Dart
+------------------------------
+* samp_drop [default=0]
+  - type of sampling algorithm.
+  - 0: dropped trees are selected uniformly.
+  - 1: dropped trees are selected in proportion to their weight.
+* norm_drop [default=0]
+  - type of normalization algorithm (k is the number of dropped trees).
+  - 0: the new tree has weight 1 / (k + 1);
+       dropped trees are scaled by a factor of k / (k + 1).
+  - 1: the new tree has weight k / (k + 1);
+       dropped trees are scaled by a factor of k / (k + 1).
+  - 2: the new tree has weight k / (k + learning_rate);
+       dropped trees are scaled by a factor of k / (k + learning_rate).
+* rate_drop [default=0.0]
+  - dropout rate.
+  - range: [0.0, 1.0]
+* skip_drop [default=0.0]
+  - probability of skipping the dropout.
+    If a dropout is skipped, new trees are added in the same manner as gbtree.
+  - range: [0.0, 1.0]
+
 Parameters for Linear Booster
 -----------------------------
 * lambda [default=0]
diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc
index 7e58a060ab5a..ca043f1245d9 100644
--- a/src/gbm/gbtree.cc
+++ b/src/gbm/gbtree.cc
@@ -17,6 +17,8 @@
 #include
 #include "../common/common.h"
+#include "../common/random.h"
+
 namespace xgboost {
 namespace gbm {
@@ -47,6 +49,38 @@ struct GBTreeTrainParam : public dmlc::Parameter<GBTreeTrainParam> {
   }
 };
 
+/*! \brief training parameters for the Dart booster */
+struct DartTrainParam : public dmlc::Parameter<DartTrainParam> {
+  /*! \brief whether to not print info during training */
+  int silent;
+  /*! \brief type of sampling algorithm */
+  int samp_drop;
+  /*! \brief type of normalization algorithm */
+  int norm_drop;
+  /*! \brief fraction of trees to drop */
+  float rate_drop;
+  /*! \brief probability of skipping the dropout */
+  float skip_drop;
+  /*! \brief learning step size for a time */
+  float learning_rate;
+  // declare parameters
+  DMLC_DECLARE_PARAMETER(DartTrainParam) {
+    DMLC_DECLARE_FIELD(silent).set_default(0)
+        .describe("Do not print information during training.");
+    DMLC_DECLARE_FIELD(samp_drop).set_default(0)
+        .describe("Different types of sampling algorithm.");
+    DMLC_DECLARE_FIELD(norm_drop).set_default(0)
+        .describe("Different types of normalization algorithm.");
+    DMLC_DECLARE_FIELD(rate_drop).set_range(0.0f, 1.0f).set_default(0.0f)
+        .describe("Fraction of trees to drop during the dropout.");
+    DMLC_DECLARE_FIELD(skip_drop).set_range(0.0f, 1.0f).set_default(0.0f)
+        .describe("Probability of skipping the dropout in a boosting round.");
+    DMLC_DECLARE_FIELD(learning_rate).set_lower_bound(0.0f).set_default(0.3f)
+        .describe("Learning rate (step size) of update.");
+    DMLC_DECLARE_ALIAS(learning_rate, eta);
+  }
+};
+
 /*! \brief model parameters */
 struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
   /*! \brief number of trees */
@@ -475,14 +509,277 @@ class GBTree : public GradientBooster {
   std::vector<std::unique_ptr<TreeUpdater> > updaters;
 };
 
+// dart
+class Dart : public GBTree {
+ public:
+  Dart() {
+    num_pbuffer = 0;
+    weight_drop.clear();
+    idx_drop.clear();
+  }
+
+  void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) override {
+    GBTree::Configure(cfg);
+    if (trees.size() == 0) {
+      dparam.InitAllowUnknown(cfg);
+    }
+  }
+
+  void Load(dmlc::Stream* fi) override {
+    GBTree::Load(fi);
+    weight_drop.resize(mparam.num_trees);
+    if (mparam.num_trees != 0) {
+      CHECK_EQ(fi->Read(dmlc::BeginPtr(weight_drop), sizeof(float) * mparam.num_trees),
+               sizeof(float) * mparam.num_trees)
+          << "Dart: invalid model file";
+    }
+  }
+
+  void Save(dmlc::Stream* fo) const override {
+    GBTree::Save(fo);
+    if (weight_drop.size() != 0) {
+      fo->Write(dmlc::BeginPtr(weight_drop), sizeof(float) * tree_info.size());
+    }
+  }
+
+  void DoBoost(DMatrix* p_fmat,
+               int64_t buffer_offset,
+               std::vector<bst_gpair>* in_gpair) override {
+    const std::vector<bst_gpair>& gpair = *in_gpair;
+    std::vector<std::vector<std::unique_ptr<RegTree> > > new_trees;
+    if (mparam.num_output_group == 1) {
+      std::vector<std::unique_ptr<RegTree> > ret;
+      BoostNewTrees(gpair, p_fmat, buffer_offset, 0, &ret);
+      new_trees.push_back(std::move(ret));
+    } else {
+      const int ngroup = mparam.num_output_group;
+      CHECK_EQ(gpair.size() % ngroup, 0)
+          << "must have exactly ngroup*nrow gpairs";
+      std::vector<bst_gpair> tmp(gpair.size() / ngroup);
+      for (int gid = 0; gid < ngroup; ++gid) {
+        bst_omp_uint nsize = static_cast<bst_omp_uint>(tmp.size());
+        #pragma omp parallel for schedule(static)
+        for (bst_omp_uint i = 0; i < nsize; ++i) {
+          tmp[i] = gpair[i * ngroup + gid];
+        }
+        std::vector<std::unique_ptr<RegTree> > ret;
+        BoostNewTrees(tmp, p_fmat, buffer_offset, gid, &ret);
+        new_trees.push_back(std::move(ret));
+      }
+    }
+    for (int gid = 0; gid < mparam.num_output_group; ++gid) {
+      this->CommitModel(std::move(new_trees[gid]), gid);
+    }
+  }
+
+  void Predict(DMatrix* p_fmat,
+               int64_t buffer_offset,
+               std::vector<float>* out_preds,
+               unsigned ntree_limit) override {
+    DropTrees(ntree_limit);
+    const MetaInfo& info = p_fmat->info();
+    int nthread;
+    #pragma omp parallel
+    {
+      nthread = omp_get_num_threads();
+    }
+    InitThreadTemp(nthread);
+    std::vector<float> &preds = *out_preds;
+    const size_t stride = p_fmat->info().num_row * mparam.num_output_group;
+    preds.resize(stride * (mparam.size_leaf_vector+1));
+    // start collecting the prediction
+    dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
+
+    iter->BeforeFirst();
+    while (iter->Next()) {
+      const RowBatch &batch = iter->Value();
+      // parallel over local batch
+      const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
+      #pragma omp parallel for schedule(static)
+      for (bst_omp_uint i = 0; i < nsize; ++i) {
+        const int tid = omp_get_thread_num();
+        RegTree::FVec &feats = thread_temp[tid];
+        int64_t ridx = static_cast<int64_t>(batch.base_rowid + i);
+        CHECK_LT(static_cast<size_t>(ridx), info.num_row);
+        // loop over output groups
+        for (int gid = 0; gid < mparam.num_output_group; ++gid) {
+          this->Pred(batch[i],
+                     buffer_offset < 0 ? -1 : buffer_offset + ridx,
+                     gid, info.GetRoot(ridx), &feats,
+                     &preds[ridx * mparam.num_output_group + gid], stride,
+                     ntree_limit);
+        }
+      }
+    }
+  }
+
+  void Predict(const SparseBatch::Inst& inst,
+               std::vector<float>* out_preds,
+               unsigned ntree_limit,
+               unsigned root_index) override {
+    DropTrees(1);
+    if (thread_temp.size() == 0) {
+      thread_temp.resize(1, RegTree::FVec());
+      thread_temp[0].Init(mparam.num_feature);
+    }
+    out_preds->resize(mparam.num_output_group * (mparam.size_leaf_vector+1));
+    // loop over output groups
+    for (int gid = 0; gid < mparam.num_output_group; ++gid) {
+      this->Pred(inst, -1, gid, root_index, &thread_temp[0],
+                 &(*out_preds)[gid], mparam.num_output_group,
+                 ntree_limit);
+    }
+  }
+
+ protected:
+  // commit new trees all at once
+  inline void CommitModel(std::vector<std::unique_ptr<RegTree> >&& new_trees,
+                          int bst_group) {
+    for (size_t i = 0; i < new_trees.size(); ++i) {
+      trees.push_back(std::move(new_trees[i]));
+      tree_info.push_back(bst_group);
+    }
+    mparam.num_trees += static_cast<int>(new_trees.size());
+    size_t num_drop = NormalizeTrees(new_trees.size());
+    if (dparam.silent != 1) {
+      LOG(INFO) << "drop " << num_drop << " trees, "
+                << "weight = " << weight_drop.back();
+    }
+  }
+  // predict one instance, skipping the trees dropped in this round
+  inline void Pred(const RowBatch::Inst &inst,
+                   int64_t buffer_index,
+                   int bst_group,
+                   unsigned root_index,
+                   RegTree::FVec *p_feats,
+                   float *out_pred,
+                   size_t stride,
+                   unsigned ntree_limit) {
+    float psum = 0.0f;
+    // sum of leaf vector
+    std::vector<float> vec_psum(mparam.size_leaf_vector, 0.0f);
+    const int64_t bid = this->BufferOffset(buffer_index, bst_group);
+    p_feats->Fill(inst);
+    for (size_t i = 0; i < trees.size(); ++i) {
+      if (tree_info[i] == bst_group) {
+        bool drop = (std::find(idx_drop.begin(), idx_drop.end(), i) != idx_drop.end());
+        if (!drop) {
+          int tid = trees[i]->GetLeafIndex(*p_feats, root_index);
+          psum += weight_drop[i] * (*trees[i])[tid].leaf_value();
+          for (int j = 0; j < mparam.size_leaf_vector; ++j) {
+            vec_psum[j] += weight_drop[i] * trees[i]->leafvec(tid)[j];
+          }
+        }
+      }
+    }
+    p_feats->Drop(inst);
+    // update the buffered results
+    if (bid >= 0 && ntree_limit == 0) {
+      pred_counter[bid] = static_cast<unsigned>(trees.size());
+      pred_buffer[bid] = psum;
+      for (int i = 0; i < mparam.size_leaf_vector; ++i) {
+        pred_buffer[bid + i + 1] = vec_psum[i];
+      }
+    }
+    out_pred[0] = psum;
+    for (int i = 0; i < mparam.size_leaf_vector; ++i) {
+      out_pred[stride * (i + 1)] = vec_psum[i];
+    }
+  }
+
+  // select the trees to drop for this boosting round
+  inline void DropTrees(unsigned ntree_limit_drop) {
+    std::uniform_real_distribution<> runif(0.0, 1.0);
+    auto& rnd = common::GlobalRandom();
+    // reset
+    idx_drop.clear();
+    // sample drop trees
+    bool skip = false;
+    if (dparam.skip_drop > 0.0) skip = (runif(rnd) < dparam.skip_drop);
+    if (ntree_limit_drop == 0 && !skip) {
+      if (dparam.samp_drop == 1) {
+        // drop trees with probability proportional to their weight
+        float sum_weight = 0.0;
+        for (size_t i = 0; i < weight_drop.size(); ++i) {
+          sum_weight += weight_drop[i];
+        }
+        for (size_t i = 0; i < weight_drop.size(); ++i) {
+          if (runif(rnd) < dparam.rate_drop * weight_drop.size() * weight_drop[i] / sum_weight) {
+            idx_drop.push_back(i);
+          }
+        }
+      } else {
+        // drop trees uniformly at random
+        for (size_t i = 0; i < weight_drop.size(); ++i) {
+          if (runif(rnd) < dparam.rate_drop) {
+            idx_drop.push_back(i);
+          }
+        }
+      }
+    }
+  }
+  // normalize the weights of the dropped and new trees
+  inline size_t NormalizeTrees(size_t size_new_trees) {
+    size_t num_drop = idx_drop.size();
+    float factor = 1.0 * num_drop / (num_drop + 1.0);
+    if (num_drop == 0) {
+      for (size_t i = 0; i < size_new_trees; ++i) {
+        weight_drop.push_back(1.0);
+      }
+    } else {
+      if (dparam.norm_drop == 2) {
+        // norm_drop 2: scale by num_drop / (num_drop + learning_rate)
+        float fl = 1.0 * num_drop / (num_drop + dparam.learning_rate);
+        for (size_t i = 0; i < idx_drop.size(); ++i) {
+          weight_drop[idx_drop[i]] *= fl;
+        }
+        for (size_t i = 0; i < size_new_trees; ++i) {
+          weight_drop.push_back(fl);
+        }
+      } else if (dparam.norm_drop == 1) {
+        // norm_drop 1: new trees get the same factor as the dropped trees
+        for (size_t i = 0; i < idx_drop.size(); ++i) {
+          weight_drop[idx_drop[i]] *= factor;
+        }
+        for (size_t i = 0; i < size_new_trees; ++i) {
+          weight_drop.push_back(factor);
+        }
+      } else {
+        // norm_drop 0: new trees get weight 1 / (num_drop + 1)
+        for (size_t i = 0; i < idx_drop.size(); ++i) {
+          weight_drop[idx_drop[i]] *= factor;
+        }
+        for (size_t i = 0; i < size_new_trees; ++i) {
+          weight_drop.push_back(1.0 / (num_drop + 1.0));
+        }
+      }
+    }
+    // reset
+    idx_drop.clear();
+    return num_drop;
+  }
+
+  // --- data structure ---
+  // training parameter
+  DartTrainParam dparam;
+  /*! \brief per-tree scaling weights */
+  std::vector<float> weight_drop;
+  // indexes of dropped trees
+  std::vector<size_t> idx_drop;
+};
+
 // register the objective functions
 DMLC_REGISTER_PARAMETER(GBTreeModelParam);
 DMLC_REGISTER_PARAMETER(GBTreeTrainParam);
+DMLC_REGISTER_PARAMETER(DartTrainParam);
 
 XGBOOST_REGISTER_GBM(GBTree, "gbtree")
 .describe("Tree booster, gradient boosted trees.")
 .set_body([]() {
     return new GBTree();
   });
+XGBOOST_REGISTER_GBM(Dart, "dart")
+.describe("Tree booster, dart.")
+.set_body([]() {
+    return new Dart();
+  });
 }  // namespace gbm
 }  // namespace xgboost
diff --git a/tests/python/test_basic_models.py b/tests/python/test_basic_models.py
index c81935e9d937..18ec94e219ff 100644
--- a/tests/python/test_basic_models.py
+++ b/tests/python/test_basic_models.py
@@ -23,6 +23,32 @@ def test_glm(self):
                   if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
         assert err < 0.1
 
+    def test_dart(self):
+        dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
+        dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
+        param = {'max_depth': 5, 'objective': 'binary:logistic', 'booster': 'dart'}
+        # specify validations set to watch performance
+        watchlist = [(dtest, 'eval'), (dtrain, 'train')]
+        num_round = 2
+        bst = xgb.train(param, dtrain, num_round, watchlist)
+        # this is prediction
+        preds = bst.predict(dtest)
+        labels = dtest.get_label()
+        err = sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
+        # error must be smaller than 10%
+        assert err < 0.1
+
+        # save dmatrix into binary buffer
+        dtest.save_binary('dtest.buffer')
+        # save model
+        bst.save_model('xgb.model')
+        # load model and data in
+        bst2 = xgb.Booster(params=param, model_file='xgb.model')
+        dtest2 = xgb.DMatrix('dtest.buffer')
+        preds2 = bst2.predict(dtest2)
+        # assert they are the same
+        assert np.sum(np.abs(preds2 - preds)) == 0
+
     def test_eta_decay(self):
         watchlist = [(dtest, 'eval'), (dtrain, 'train')]
         num_round = 4
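
For reference, a minimal, hypothetical usage sketch of the booster added by this patch (not part of the diff): it uses the demo `agaricus` data referenced by the test above, the parameter values are illustrative only, and `samp_drop`, `norm_drop`, `rate_drop` and `skip_drop` are the names introduced in `doc/parameter.md` in this patch.

```python
import xgboost as xgb

# Illustrative values only; the dart-specific parameter names come from this patch.
params = {
    'booster': 'dart',               # use the new Dart booster
    'objective': 'binary:logistic',
    'max_depth': 5,
    'eta': 0.1,                      # alias of learning_rate
    'samp_drop': 1,                  # drop trees in proportion to their weight
    'norm_drop': 0,                  # new tree gets weight 1 / (k + 1)
    'rate_drop': 0.1,                # dropout rate
    'skip_drop': 0.5,                # probability of skipping dropout in a round
}

dtrain = xgb.DMatrix('../data/agaricus.txt.train')
dtest = xgb.DMatrix('../data/agaricus.txt.test')
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
bst = xgb.train(params, dtrain, num_boost_round=50, evals=watchlist)
preds = bst.predict(dtest)
```

Note that `Predict` in this patch calls `DropTrees(ntree_limit)`, so with the default `ntree_limit=0` dropout is also applied at prediction time; repeated calls to `predict` are therefore not deterministic unless a nonzero `ntree_limit` is passed.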
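
A small numeric illustration of the `norm_drop=0` rule in `NormalizeTrees` (again not part of the patch, and assuming the DART rationale that the new tree's raw output roughly matches the combined contribution of the dropped trees): the scaling keeps the ensemble's total contribution unchanged.

```python
k = 3                        # number of trees dropped in this round
S = 0.9                      # combined contribution of the k dropped trees to a prediction
new_tree_raw = S             # the new tree is fit on the gap left by dropping them

dropped_after = S * k / (k + 1.0)            # dropped trees scaled by k / (k + 1)
new_after = new_tree_raw * 1.0 / (k + 1.0)   # new tree added with weight 1 / (k + 1)

assert abs((dropped_after + new_after) - S) < 1e-12  # total contribution is still S
```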