Add the missing max_delta_step (#3668)
* add max_delta_step to SplitEvaluator

* test for max_delta_step

* missing x2 factor for L1 term

* remove gamma from ElasticNet
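
A note on the math behind the change (notation mine, not part of the commit message): write G and H for the sums of gradients and hessians in a node, α = reg_alpha, λ = reg_lambda, δ = max_delta_step. The evaluator's leaf weight is the soft-thresholded Newton step, now clipped to ±δ:

$$ w^{*} = \mathrm{clip}\!\left(-\frac{T_\alpha(G)}{H+\lambda},\; [-\delta,\,\delta]\right), \qquad T_\alpha(G) = \mathrm{sign}(G)\,\max(|G|-\alpha,\,0). $$

The score reported for a weight w is minus twice the penalized objective,

$$ \mathrm{score}(w) = -\bigl(2Gw + (H+\lambda)w^{2} + 2\alpha\,|w|\bigr), $$

which at the unclipped optimum collapses to T_α(G)²/(H+λ). The factor 2 on the gradient term is why the L1 term needed the matching 2.0 (the third bullet above), and a clipped weight no longer sits at that optimum, which is why ComputeScore now falls back to evaluating the objective at the clipped weight.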
khotilov authored Sep 12, 2018
1 parent d1e75d6 commit ad3a0bb
Showing 2 changed files with 35 additions and 10 deletions.
14 changes: 14 additions & 0 deletions R-package/tests/testthat/test_basic.R
@@ -223,3 +223,17 @@ test_that("train and predict with non-strict classes", {
   expect_error(pr <- predict(bst, train_dense), regexp = NA)
   expect_equal(pr0, pr)
 })
+
+test_that("max_delta_step works", {
+  dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label)
+  watchlist <- list(train = dtrain)
+  param <- list(objective = "binary:logistic", eval_metric = "logloss", max_depth = 2, nthread = 2, eta = 0.5)
+  nrounds <- 5
+  # model with no restriction on max_delta_step
+  bst1 <- xgb.train(param, dtrain, nrounds, watchlist, verbose = 1)
+  # model with restricted max_delta_step
+  bst2 <- xgb.train(param, dtrain, nrounds, watchlist, verbose = 1, max_delta_step = 1)
+  # the unrestricted model is expected to have consistently lower loss during the initial iterations
+  expect_true(all(bst1$evaluation_log$train_logloss < bst2$evaluation_log$train_logloss))
+  expect_lt(mean(bst1$evaluation_log$train_logloss) / mean(bst2$evaluation_log$train_logloss), 0.8)
+})
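
A rough sketch of why these assertions should hold (assuming, as in stock XGBoost, that the learning rate η scales the already-clipped leaf weight): per boosting round the capped model can move any leaf's margin by at most

$$ \eta\,\delta = 0.5 \times 1 = 0.5, $$

whereas the unrestricted model takes the full step $-\eta\,T_\alpha(G)/(H+\lambda)$. On the nearly separable agaricus data the unconstrained leaf weights are large, so the capped model approaches the optimum in smaller increments; the test encodes this as strictly higher per-round logloss and a mean-logloss ratio below 0.8.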
31 changes: 21 additions & 10 deletions src/tree/split_evaluator.cc
@@ -63,7 +63,10 @@ bst_float SplitEvaluator::ComputeSplitScore(bst_uint nodeid,
 struct ElasticNetParams : public dmlc::Parameter<ElasticNetParams> {
   bst_float reg_lambda;
   bst_float reg_alpha;
-  bst_float reg_gamma;
+  // maximum delta update we can add in weight estimation
+  // this parameter can be used to stabilize update
+  // default=0 means no constraint on weight delta
+  float max_delta_step;
 
   DMLC_DECLARE_PARAMETER(ElasticNetParams) {
     DMLC_DECLARE_FIELD(reg_lambda)
Expand All @@ -74,13 +77,13 @@ struct ElasticNetParams : public dmlc::Parameter<ElasticNetParams> {
.set_lower_bound(0.0)
.set_default(0.0)
.describe("L1 regularization on leaf weight");
DMLC_DECLARE_FIELD(reg_gamma)
.set_lower_bound(0.0)
.set_default(0.0)
.describe("Cost incurred by adding a new leaf node to the tree");
DMLC_DECLARE_FIELD(max_delta_step)
.set_lower_bound(0.0f)
.set_default(0.0f)
.describe("Maximum delta step we allow each tree's weight estimate to be. "\
"If the value is set to 0, it means there is no constraint");
DMLC_DECLARE_ALIAS(reg_lambda, lambda);
DMLC_DECLARE_ALIAS(reg_alpha, alpha);
DMLC_DECLARE_ALIAS(reg_gamma, gamma);
}
};

@@ -127,17 +130,25 @@ class ElasticNet final : public SplitEvaluator {
       const override {
     auto loss = weight * (2.0 * stats.sum_grad + stats.sum_hess * weight
                           + params_.reg_lambda * weight)
-        + params_.reg_alpha * std::abs(weight);
+        + 2.0 * params_.reg_alpha * std::abs(weight);
     return -loss;
   }
 
   bst_float ComputeScore(bst_uint parentID, const GradStats &stats) const {
-    return Sqr(ThresholdL1(stats.sum_grad)) / (stats.sum_hess + params_.reg_lambda);
+    if (params_.max_delta_step == 0.0f) {
+      return Sqr(ThresholdL1(stats.sum_grad)) / (stats.sum_hess + params_.reg_lambda);
+    } else {
+      return ComputeScore(parentID, stats, ComputeWeight(parentID, stats));
+    }
   }
 
   bst_float ComputeWeight(bst_uint parentID, const GradStats& stats)
       const override {
-    return -ThresholdL1(stats.sum_grad) / (stats.sum_hess + params_.reg_lambda);
+    bst_float w = -ThresholdL1(stats.sum_grad) / (stats.sum_hess + params_.reg_lambda);
+    if (params_.max_delta_step != 0.0f && std::abs(w) > params_.max_delta_step) {
+      w = std::copysign(params_.max_delta_step, w);
+    }
+    return w;
   }
 
  private:
Expand All @@ -155,7 +166,7 @@ class ElasticNet final : public SplitEvaluator {
};

XGBOOST_REGISTER_SPLIT_EVALUATOR(ElasticNet, "elastic_net")
.describe("Use an elastic net regulariser and a cost per leaf node")
.describe("Use an elastic net regulariser")
.set_body([](std::unique_ptr<SplitEvaluator> inner) {
return new ElasticNet(std::move(inner));
});
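
Below is a minimal standalone sketch of the logic this diff adds; it is not XGBoost source. GradStats and Params are simplified stand-ins for the real classes, ScoreAtWeight is my name for the three-argument ComputeScore, and the numbers in main are made up for illustration.

// clip_sketch.cc -- standalone sketch, not XGBoost source.
#include <cmath>
#include <cstdio>

struct GradStats {
  double sum_grad;  // G: sum of gradients in the node
  double sum_hess;  // H: sum of hessians in the node
};

struct Params {
  double reg_lambda;      // L2 penalty (lambda)
  double reg_alpha;       // L1 penalty (alpha)
  double max_delta_step;  // 0 means unconstrained
};

// Soft-thresholding for the L1 penalty (XGBoost's ThresholdL1).
double ThresholdL1(double g, double alpha) {
  if (g > alpha)  return g - alpha;
  if (g < -alpha) return g + alpha;
  return 0.0;
}

// Leaf weight: closed-form elastic-net step, clipped to +/- max_delta_step.
double ComputeWeight(const GradStats& s, const Params& p) {
  double w = -ThresholdL1(s.sum_grad, p.reg_alpha) / (s.sum_hess + p.reg_lambda);
  if (p.max_delta_step != 0.0 && std::abs(w) > p.max_delta_step) {
    w = std::copysign(p.max_delta_step, w);
  }
  return w;
}

// Score at an arbitrary weight: -(2*G*w + (H + lambda)*w^2 + 2*alpha*|w|),
// i.e. -2x the penalized objective, matching the three-argument ComputeScore.
double ScoreAtWeight(const GradStats& s, const Params& p, double w) {
  return -(w * (2.0 * s.sum_grad + s.sum_hess * w + p.reg_lambda * w)
           + 2.0 * p.reg_alpha * std::abs(w));
}

// Node score: closed form when unconstrained; otherwise evaluate the
// objective at the clipped weight, as the new branch in the diff does.
double ComputeScore(const GradStats& s, const Params& p) {
  if (p.max_delta_step == 0.0) {
    const double t = ThresholdL1(s.sum_grad, p.reg_alpha);
    return t * t / (s.sum_hess + p.reg_lambda);
  }
  return ScoreAtWeight(s, p, ComputeWeight(s, p));
}

int main() {
  const GradStats s{-8.0, 2.0};              // G = -8, H = 2
  const Params unconstrained{1.0, 0.0, 0.0}; // lambda = 1, no cap
  const Params capped{1.0, 0.0, 1.0};        // same, but max_delta_step = 1
  std::printf("weight: unconstrained = %.3f, capped = %.3f\n",
              ComputeWeight(s, unconstrained), ComputeWeight(s, capped));
  std::printf("score:  unconstrained = %.3f, capped = %.3f\n",
              ComputeScore(s, unconstrained), ComputeScore(s, capped));
  // Expected: weight 2.667 -> 1.000; score 21.333 -> 13.000.
}

The design point worth noting: Sqr(ThresholdL1(G))/(H+lambda) is the objective's value only at the unconstrained argmin, so once the weight is clipped that closed form would overstate the attainable gain, and the evaluator instead re-evaluates the objective at the clipped weight.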