Commit

Fix early stopping in the Python package (#4638)
* Fix #4630, #4421: Preserve correct ordering between metrics, and always use last metric for early stopping

* Clarify semantics of early stopping in presence of multiple valid sets and metrics

* Add a test

* Fix lint
hcho3 authored Jul 7, 2019
1 parent 562d9ae commit 1aaf4a6
Showing 3 changed files with 106 additions and 115 deletions.
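In short: when several validation sets and several metrics are supplied, the last entry of each is now the one early stopping watches. A minimal sketch of the fixed semantics, using synthetic data (the data and parameter values here are illustrative, not part of this commit):

    import numpy as np
    import xgboost as xgb

    rng = np.random.RandomState(42)
    X, y = rng.rand(500, 10), rng.randint(2, size=500)
    dtrain = xgb.DMatrix(X[:400], label=y[:400])
    dvalid = xgb.DMatrix(X[400:], label=y[400:])

    # 'logloss' is the last metric listed, so it drives early stopping.
    params = {'objective': 'binary:logistic', 'eval_metric': ['auc', 'logloss']}

    # 'valid' is the last eval set listed, so it is the one monitored.
    bst = xgb.train(params, dtrain, num_boost_round=1000,
                    evals=[(dtrain, 'train'), (dvalid, 'valid')],
                    early_stopping_rounds=10, verbose_eval=False)
    print(bst.best_iteration, bst.best_score, bst.best_ntree_limit)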
163 changes: 61 additions & 102 deletions python-package/xgboost/sklearn.py
@@ -309,8 +309,7 @@ def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
             early_stopping_rounds=None, verbose=True, xgb_model=None,
             sample_weight_eval_set=None, callbacks=None):
         # pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init
-        """
-        Fit the gradient boosting model
+        """Fit gradient boosting model

         Parameters
         ----------
@@ -321,34 +320,39 @@ def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
         sample_weight : array_like
             instance weights
         eval_set : list, optional
-            A list of (X, y) tuple pairs to use as a validation set for
-            early-stopping
+            A list of (X, y) tuple pairs to use as validation sets, for which
+            metrics will be computed.
+            Validation metrics will help us track the performance of the model.
         sample_weight_eval_set : list, optional
             A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
             instance weights on the i-th validation set.
-        eval_metric : str, callable, optional
+        eval_metric : str, list of str, or callable, optional
             If a str, should be a built-in evaluation metric to use. See
-            doc/parameter.rst. If callable, a custom evaluation metric. The call
-            signature is func(y_predicted, y_true) where y_true will be a
-            DMatrix object such that you may need to call the get_label
+            doc/parameter.rst.
+            If a list of str, should be the list of multiple built-in evaluation metrics
+            to use.
+            If callable, a custom evaluation metric. The call
+            signature is ``func(y_predicted, y_true)`` where ``y_true`` will be a
+            DMatrix object such that you may need to call the ``get_label``
             method. It must return a str, value pair where the str is a name
             for the evaluation and value is the value of the evaluation
-            function. This objective is always minimized.
+            function. The callable custom objective is always minimized.
         early_stopping_rounds : int
-            Activates early stopping. Validation error needs to decrease at
-            least every <early_stopping_rounds> round(s) to continue training.
-            Requires at least one item in evals. If there's more than one,
-            will use the last. Returns the model from the last iteration
-            (not the best one). If early stopping occurs, the model will
-            have three additional fields: bst.best_score, bst.best_iteration
-            and bst.best_ntree_limit.
-            (Use bst.best_ntree_limit to get the correct value if num_parallel_tree
-            and/or num_class appears in the parameters)
+            Activates early stopping. Validation metric needs to improve at least once in
+            every **early_stopping_rounds** round(s) to continue training.
+            Requires at least one item in **eval_set**.
+            The method returns the model from the last iteration (not the best one).
+            If there's more than one item in **eval_set**, the last entry will be used
+            for early stopping.
+            If there's more than one metric in **eval_metric**, the last metric will be
+            used for early stopping.
+            If early stopping occurs, the model will have three additional fields:
+            ``clf.best_score``, ``clf.best_iteration`` and ``clf.best_ntree_limit``.
         verbose : bool
             If `verbose` and an evaluation set is used, writes the evaluation
             metric measured on the validation set to stderr.
         xgb_model : str
-            file name of stored xgb model or 'Booster' instance Xgb model to be
+            file name of stored XGBoost model or 'Booster' instance XGBoost model to be
             loaded before training (allows training continuation).
         callbacks : list of callback functions
             List of callback functions that are applied at end of each iteration.
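A sketch of the same semantics through the scikit-learn wrapper documented above (synthetic data; the variable names are illustrative):

    import numpy as np
    from xgboost import XGBClassifier

    rng = np.random.RandomState(0)
    X, y = rng.rand(500, 10), rng.randint(2, size=500)

    clf = XGBClassifier(n_estimators=1000)
    # 'error' is the last metric and (X[400:], y[400:]) the last eval set,
    # so that pair is what early stopping watches.
    clf.fit(X[:400], y[:400],
            eval_set=[(X[:400], y[:400]), (X[400:], y[400:])],
            eval_metric=['auc', 'error'],
            early_stopping_rounds=10, verbose=False)
    print(clf.best_score, clf.best_iteration, clf.best_ntree_limit)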
@@ -629,56 +633,7 @@ def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
             early_stopping_rounds=None, verbose=True, xgb_model=None,
             sample_weight_eval_set=None, callbacks=None):
         # pylint: disable = attribute-defined-outside-init,arguments-differ
-        """
-        Fit gradient boosting classifier
-
-        Parameters
-        ----------
-        X : array_like
-            Feature matrix
-        y : array_like
-            Labels
-        sample_weight : array_like
-            Weight for each instance
-        eval_set : list, optional
-            A list of (X, y) pairs to use as a validation set for
-            early-stopping
-        sample_weight_eval_set : list, optional
-            A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
-            instance weights on the i-th validation set.
-        eval_metric : str, callable, optional
-            If a str, should be a built-in evaluation metric to use. See
-            doc/parameter.rst. If callable, a custom evaluation metric. The call
-            signature is func(y_predicted, y_true) where y_true will be a
-            DMatrix object such that you may need to call the get_label
-            method. It must return a str, value pair where the str is a name
-            for the evaluation and value is the value of the evaluation
-            function. This objective is always minimized.
-        early_stopping_rounds : int, optional
-            Activates early stopping. Validation error needs to decrease at
-            least every <early_stopping_rounds> round(s) to continue training.
-            Requires at least one item in evals. If there's more than one,
-            will use the last. If early stopping occurs, the model will have
-            three additional fields: bst.best_score, bst.best_iteration and
-            bst.best_ntree_limit (bst.best_ntree_limit is the ntree_limit parameter
-            default value in predict method if not any other value is specified).
-            (Use bst.best_ntree_limit to get the correct value if num_parallel_tree
-            and/or num_class appears in the parameters)
-        verbose : bool
-            If `verbose` and an evaluation set is used, writes the evaluation
-            metric measured on the validation set to stderr.
-        xgb_model : str
-            file name of stored xgb model or 'Booster' instance Xgb model to be
-            loaded before training (allows training continuation).
-        callbacks : list of callback functions
-            List of callback functions that are applied at end of each iteration.
-            It is possible to use predefined callbacks by using :ref:`callback_api`.
-            Example:
-
-            .. code-block:: python
-
-                [xgb.callback.reset_learning_rate(custom_rates)]
-        """
         evals_result = {}
         self.classes_ = np.unique(y)
         self.n_classes_ = len(self.classes_)
@@ -751,6 +706,9 @@ def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,

         return self

+    fit.__doc__ = XGBModel.fit.__doc__.replace('Fit gradient boosting model',
+                                               'Fit gradient boosting classifier', 1)
+
     def predict(self, data, output_margin=False, ntree_limit=None, validate_features=True):
         """
         Predict with `data`.
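The `fit.__doc__` assignment added above replaces the duplicated docstring; the same pattern in isolation (generic class names, not the ones in this file):

    class Base:
        def fit(self):
            """Fit gradient boosting model

            (shared parameter documentation lives here)
            """

    class Classifier(Base):
        def fit(self):
            ...
        # Reuse the parent docstring, swapping only the summary line.
        fit.__doc__ = Base.fit.__doc__.replace('Fit gradient boosting model',
                                               'Fit gradient boosting classifier', 1)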
@@ -1027,14 +985,15 @@ class XGBRanker(XGBModel):
     Note
     ----
     A custom objective function is currently not supported by XGBRanker.
+    Likewise, a custom metric function is not supported either.

     Note
     ----
-    Group information is required for ranking tasks.
+    Query group information is required for ranking tasks.

-    Before fitting the model, your data need to be sorted by group. When
+    Before fitting the model, your data need to be sorted by query group. When
     fitting the model, you need to provide an additional array that
-    contains the size of each group.
+    contains the size of each query group.

     For example, if your original data look like:
@@ -1086,7 +1045,7 @@ def fit(self, X, y, group, sample_weight=None, eval_set=None, sample_weight_eval
             verbose=False, xgb_model=None, callbacks=None):
         # pylint: disable = attribute-defined-outside-init,arguments-differ
         """
-        Fit the gradient boosting model
+        Fit gradient boosting ranker

         Parameters
         ----------
@@ -1095,57 +1054,57 @@ def fit(self, X, y, group, sample_weight=None, eval_set=None, sample_weight_eval
         y : array_like
             Labels
         group : array_like
-            group size of training data
+            Size of each query group of training data. Should have as many elements
+            as there are query groups in the training data.
         sample_weight : array_like
-            group weights
+            Query group weights

             .. note:: Weights are per-group for ranking tasks

-                In ranking task, one weight is assigned to each group (not each data
-                point). This is because we only care about the relative ordering of
+                In ranking task, one weight is assigned to each query group (not each
+                data point). This is because we only care about the relative ordering of
                 data points within each group, so it doesn't make sense to assign
                 weights to individual data points.
         eval_set : list, optional
-            A list of (X, y) tuple pairs to use as a validation set for
-            early-stopping
+            A list of (X, y) tuple pairs to use as validation sets, for which
+            metrics will be computed.
+            Validation metrics will help us track the performance of the model.
         sample_weight_eval_set : list, optional
             A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
             group weights on the i-th validation set.

             .. note:: Weights are per-group for ranking tasks

-                In ranking task, one weight is assigned to each group (not each data
-                point). This is because we only care about the relative ordering of
+                In ranking task, one weight is assigned to each query group (not each
+                data point). This is because we only care about the relative ordering of
                 data points within each group, so it doesn't make sense to assign
                 weights to individual data points.
         eval_group : list of arrays, optional
-            A list that contains the group size corresponds to each
-            (X, y) pair in eval_set
-        eval_metric : str, callable, optional
+            A list in which ``eval_group[i]`` is the list containing the sizes of all
+            query groups in the ``i``-th pair in **eval_set**.
+        eval_metric : str, list of str, optional
             If a str, should be a built-in evaluation metric to use. See
-            doc/parameter.rst. If callable, a custom evaluation metric. The call
-            signature is func(y_predicted, y_true) where y_true will be a
-            DMatrix object such that you may need to call the get_label
-            method. It must return a str, value pair where the str is a name
-            for the evaluation and value is the value of the evaluation
-            function. This objective is always minimized.
+            doc/parameter.rst.
+            If a list of str, should be the list of multiple built-in evaluation metrics
+            to use. The custom evaluation metric is not yet supported for the ranker.
         early_stopping_rounds : int
-            Activates early stopping. Validation error needs to decrease at
-            least every <early_stopping_rounds> round(s) to continue training.
-            Requires at least one item in evals. If there's more than one,
-            will use the last. Returns the model from the last iteration
-            (not the best one). If early stopping occurs, the model will
-            have three additional fields: bst.best_score, bst.best_iteration
-            and bst.best_ntree_limit.
-            (Use bst.best_ntree_limit to get the correct value if num_parallel_tree
-            and/or num_class appears in the parameters)
+            Activates early stopping. Validation metric needs to improve at least once in
+            every **early_stopping_rounds** round(s) to continue training.
+            Requires at least one item in **eval_set**.
+            The method returns the model from the last iteration (not the best one).
+            If there's more than one item in **eval_set**, the last entry will be used
+            for early stopping.
+            If there's more than one metric in **eval_metric**, the last metric will be
+            used for early stopping.
+            If early stopping occurs, the model will have three additional fields:
+            ``clf.best_score``, ``clf.best_iteration`` and ``clf.best_ntree_limit``.
         verbose : bool
             If `verbose` and an evaluation set is used, writes the evaluation
             metric measured on the validation set to stderr.
         xgb_model : str
-            file name of stored xgb model or 'Booster' instance Xgb model to be
+            file name of stored XGBoost model or 'Booster' instance XGBoost model to be
             loaded before training (allows training continuation).
         callbacks : list of callback functions
             List of callback functions that are applied at end of each iteration.
@@ -1199,9 +1158,9 @@ def _dmat_init(group, **params):
         feval = eval_metric if callable(eval_metric) else None
         if eval_metric is not None:
             if callable(eval_metric):
-                eval_metric = None
-            else:
-                params.update({'eval_metric': eval_metric})
+                raise ValueError('Custom evaluation metric is not yet supported '
+                                 'for XGBRanker.')
+            params.update({'eval_metric': eval_metric})

         self._Booster = train(params, train_dmatrix,
                               self.n_estimators,
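A sketch of the ranker interface after this change, on a hypothetical toy dataset with three query groups; the callable metric in the second call is expected to trigger the new ValueError:

    import numpy as np
    import xgboost as xgb

    rng = np.random.RandomState(7)
    X = rng.rand(10, 5)
    y = rng.randint(4, size=10)   # relevance labels, data sorted by query group
    group = [3, 4, 3]             # query group sizes; must sum to len(X)

    ranker = xgb.XGBRanker(n_estimators=20)
    ranker.fit(X, y, group, eval_set=[(X, y)], eval_group=[group],
               eval_metric=['ndcg', 'map'],   # built-ins only; 'map' drives stopping
               early_stopping_rounds=5, verbose=False)

    try:
        ranker.fit(X, y, group, eval_metric=lambda preds, dmat: ('const', 0.0))
    except ValueError as err:
        print(err)   # custom metrics are rejected by XGBRanker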
32 changes: 19 additions & 13 deletions python-package/xgboost/training.py
@@ -127,20 +127,23 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
     num_boost_round: int
         Number of boosting iterations.
     evals: list of pairs (DMatrix, string)
-        List of items to be evaluated during training, this allows user to watch
-        performance on the validation set.
+        List of validation sets for which metrics will be evaluated during training.
+        Validation metrics will help us track the performance of the model.
     obj : function
         Customized objective function.
     feval : function
         Customized evaluation function.
     maximize : bool
         Whether to maximize feval.
     early_stopping_rounds: int
-        Activates early stopping. Validation error needs to decrease at least
+        Activates early stopping. Validation metric needs to improve at least once in
         every **early_stopping_rounds** round(s) to continue training.
         Requires at least one item in **evals**.
-        If there's more than one, will use the last.
-        Returns the model from the last iteration (not the best one).
+        The method returns the model from the last iteration (not the best one).
+        If there's more than one item in **evals**, the last entry will be used
+        for early stopping.
+        If there's more than one metric in the **eval_metric** parameter given in
+        **params**, the last metric will be used for early stopping.
         If early stopping occurs, the model will have three additional fields:
         ``bst.best_score``, ``bst.best_iteration`` and ``bst.best_ntree_limit``.
         (Use ``bst.best_ntree_limit`` to get the correct value if
         ``num_parallel_tree`` and/or ``num_class`` appears in the parameters)
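A sketch of picking these fields up after training stops early (synthetic data; ``bst.best_ntree_limit`` matters when ``num_parallel_tree`` and/or ``num_class`` multiply the trees built per round):

    import numpy as np
    import xgboost as xgb

    rng = np.random.RandomState(1)
    X, y = rng.rand(300, 8), rng.randint(2, size=300)
    dtrain = xgb.DMatrix(X[:240], label=y[:240])
    dvalid = xgb.DMatrix(X[240:], label=y[240:])

    bst = xgb.train({'objective': 'binary:logistic', 'eval_metric': 'logloss'},
                    dtrain, num_boost_round=500,
                    evals=[(dvalid, 'valid')], early_stopping_rounds=10,
                    verbose_eval=False)

    # Predict with only the trees built up to the best iteration.
    preds = bst.predict(dvalid, ntree_limit=bst.best_ntree_limit)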
@@ -352,16 +355,16 @@ def aggcv(rlist):
     for line in rlist:
         arr = line.split()
         assert idx == arr[0]
-        for it in arr[1:]:
+        for metric_idx, it in enumerate(arr[1:]):
             if not isinstance(it, STRING_TYPES):
                 it = it.decode()
             k, v = it.split(':')
-            if k not in cvmap:
-                cvmap[k] = []
-            cvmap[k].append(float(v))
+            if (metric_idx, k) not in cvmap:
+                cvmap[(metric_idx, k)] = []
+            cvmap[(metric_idx, k)].append(float(v))
     msg = idx
     results = []
-    for k, v in sorted(cvmap.items(), key=lambda x: (x[0].startswith('test'), x[0])):
+    for (metric_idx, k), v in sorted(cvmap.items(), key=lambda x: x[0][0]):
         v = np.array(v)
         if not isinstance(msg, STRING_TYPES):
             msg = msg.decode()
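Why the ``(metric_idx, k)`` key matters: keying by metric name alone sorted the output alphabetically, whereas the tuple key preserves the order in which metrics were specified, so the last metric stays last and lines up with early stopping. A standalone sketch of the aggregation (simplified; not the library's exact code):

    # One fold's eval string, metrics in the user-given order: logloss, then auc.
    line = '[0] test-logloss:0.31 test-auc:0.92'
    arr = line.split()

    cvmap = {}
    for metric_idx, it in enumerate(arr[1:]):
        k, v = it.split(':')
        cvmap.setdefault((metric_idx, k), []).append(float(v))

    # Sorting on metric_idx keeps the user's order: auc stays last.
    print([k for (_, k), _ in sorted(cvmap.items(), key=lambda x: x[0][0])])
    # -> ['test-logloss', 'test-auc']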
@@ -405,9 +408,12 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
     maximize : bool
         Whether to maximize feval.
     early_stopping_rounds: int
-        Activates early stopping. CV error needs to decrease at least
-        every <early_stopping_rounds> round(s) to continue.
-        Last entry in evaluation history is the one from best iteration.
+        Activates early stopping. Cross-validation metric (average of validation
+        metric computed over CV folds) needs to improve at least once in
+        every **early_stopping_rounds** round(s) to continue training.
+        The last entry in the evaluation history will represent the best iteration.
+        If there's more than one metric in the **eval_metric** parameter given in
+        **params**, the last metric will be used for early stopping.
     fpreproc : function
         Preprocessing function that takes (dtrain, dtest, param) and returns
         transformed versions of those.
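A sketch of the corresponding ``xgb.cv`` call (synthetic data): with ``metrics=['auc', 'error']``, 'error' is last and is what early stopping tracks:

    import numpy as np
    import xgboost as xgb

    rng = np.random.RandomState(3)
    X, y = rng.rand(400, 10), rng.randint(2, size=400)
    dm = xgb.DMatrix(X, label=y)

    res = xgb.cv({'objective': 'binary:logistic'}, dm,
                 num_boost_round=500, nfold=5,
                 metrics=['auc', 'error'],   # 'error' drives early stopping
                 early_stopping_rounds=10, seed=42)
    # The last row of the history corresponds to the best iteration for 'error'.
    print(res.tail(1))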
26 changes: 26 additions & 0 deletions tests/python/test_early_stopping.py
@@ -80,3 +80,29 @@ def test_cv_early_stopping(self):
                          feval=self.evalerror, maximize=True,
                          early_stopping_rounds=1)
         self.assert_metrics_length(cv, 1)
+
+    @pytest.mark.skipif(**tm.no_sklearn())
+    def test_cv_early_stopping_with_multiple_eval_sets_and_metrics(self):
+        from sklearn.datasets import load_breast_cancer
+
+        X, y = load_breast_cancer(return_X_y=True)
+        dm = xgb.DMatrix(X, label=y)
+        params = {'objective': 'binary:logistic'}
+
+        metrics = [['auc'], ['error'], ['logloss'],
+                   ['logloss', 'auc'], ['logloss', 'error'], ['error', 'logloss']]
+
+        num_iteration_history = []
+
+        # If more than one metric is given, early stopping should use the last one.
+        for m in metrics:
+            result = xgb.cv(params, dm, num_boost_round=1000, nfold=5, stratified=True,
+                            metrics=m, early_stopping_rounds=20, seed=42)
+            num_iteration_history.append(len(result))
+            df = result['test-{}-mean'.format(m[-1])]
+            # When early stopping is invoked, the last metric should be at its best.
+            if m[-1] == 'auc':
+                assert np.all(df <= df.iloc[-1])   # auc is maximized
+            else:
+                assert np.all(df >= df.iloc[-1])   # error/logloss are minimized
+        # Runs 3-5 end with the same last metric as runs 0-2 (auc, error, logloss),
+        # so each pair should stop after the same number of iterations.
+        assert num_iteration_history[:3] == num_iteration_history[3:]
