Support early stopping with training continuation, correct num boosted rounds. #6506

Merged 3 commits on Dec 17, 2020
9 changes: 9 additions & 0 deletions include/xgboost/c_api.h
@@ -614,6 +614,15 @@ XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer,
int end_layer, int step,
BoosterHandle *out);

/*!
 * \brief Get number of boosted rounds from gradient booster. When process_type is
 *        update, this number might drop due to removed trees.
* \param handle Handle to booster.
* \param out Pointer to output integer.
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterBoostedRounds(BoosterHandle handle, int* out);

/*!
* \brief set parameters
* \param handle handle
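The new C entry point backs the Booster.num_boosted_rounds() accessor added to the Python package further down in this diff. A minimal sketch of how the count behaves, including across training continuation; the data and parameters are illustrative, not taken from the PR:

import numpy as np
import xgboost as xgb

X, y = np.random.rand(100, 10), np.random.rand(100)  # toy data for illustration
dtrain = xgb.DMatrix(X, label=y)

booster = xgb.train({"objective": "reg:squarederror"}, dtrain, num_boost_round=5)
print(booster.num_boosted_rounds())   # expected: 5

# Continue training from the existing model; the counter keeps growing.
booster = xgb.train({"objective": "reg:squarederror"}, dtrain, num_boost_round=3,
                    xgb_model=booster)
print(booster.num_boosted_rounds())   # expected: 8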
3 changes: 3 additions & 0 deletions include/xgboost/gbm.h
@@ -79,6 +79,9 @@ class GradientBooster : public Model, public Configurable {
virtual bool AllowLazyCheckPoint() const {
return false;
}
/*! \brief Return number of boosted rounds.
*/
virtual int32_t BoostedRounds() const = 0;
/*!
* \brief perform update to the model(boosting)
* \param p_fmat feature matrix that provide access to features
5 changes: 5 additions & 0 deletions include/xgboost/learner.h
@@ -134,6 +134,11 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
HostDeviceVector<bst_float> **out_preds,
uint32_t layer_begin, uint32_t layer_end) = 0;

/*!
* \brief Get number of boosted rounds from gradient booster.
*/
virtual int32_t BoostedRounds() const = 0;

void LoadModel(Json const& in) override = 0;
void SaveModel(Json* out) const override = 0;

94 changes: 57 additions & 37 deletions python-package/xgboost/callback.py
@@ -6,7 +6,7 @@
import collections
import os
import pickle
from typing import Callable, List
from typing import Callable, List, Optional, Union, Dict, Tuple
import numpy

from . import rabit
@@ -285,11 +285,13 @@ def after_training(self, model):
'''Run after training is finished.'''
return model

def before_iteration(self, model, epoch, evals_log):
def before_iteration(self, model, epoch: int,
evals_log: 'CallbackContainer.EvalsLog') -> bool:
'''Run before each iteration. Return True when training should stop.'''
return False

def after_iteration(self, model, epoch, evals_log):
def after_iteration(self, model, epoch: int,
evals_log: 'CallbackContainer.EvalsLog') -> bool:
'''Run after each iteration. Return True when training should stop.'''
return False

@@ -346,16 +348,21 @@ class CallbackContainer:
.. versionadded:: 1.3.0

'''
def __init__(self, callbacks: List[TrainingCallback],
metric: Callable = None, is_cv: bool = False):

EvalsLog = Dict[str, Dict[str, Union[List[float], List[Tuple[float, float]]]]]

def __init__(self,
callbacks: List[TrainingCallback],
metric: Callable = None,
is_cv: bool = False):
self.callbacks = set(callbacks)
if metric is not None:
msg = 'metric must be callable object for monitoring. For ' + \
'builtin metrics, passing them in training parameter' + \
' will invoke monitor automatically.'
assert callable(metric), msg
self.metric = metric
self.history = collections.OrderedDict()
self.history: CallbackContainer.EvalsLog = collections.OrderedDict()
self.is_cv = is_cv

if self.is_cv:
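For reference, the EvalsLog alias above is the shape of the history that every callback receives: dataset name -> metric name -> list of scores, where a score is a plain float during normal training or a (mean, std) tuple under cv. A hand-written illustration with invented names and numbers:

# Possible contents of evals_log after three rounds with one validation set.
evals_log = {
    "Valid": {
        "rmse": [0.52, 0.40, 0.33],        # List[float] for plain training
    }
}
# Under xgboost.cv each entry is a (mean, std) tuple instead.
cv_evals_log = {
    "Valid": {
        "rmse": [(0.52, 0.01), (0.40, 0.02), (0.33, 0.02)],
    }
}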
@@ -383,7 +390,7 @@ def after_training(self, model):
assert isinstance(model, Booster), msg
return model

def before_iteration(self, model, epoch, dtrain, evals):
def before_iteration(self, model, epoch, dtrain, evals) -> bool:
'''Function called before training iteration.'''
return any(c.before_iteration(model, epoch, self.history)
for c in self.callbacks)
@@ -409,7 +416,7 @@ def _update_history(self, score, epoch):
self.history[data_name][metric_name] = [s]
return False

def after_iteration(self, model, epoch, dtrain, evals):
def after_iteration(self, model, epoch, dtrain, evals) -> bool:
'''Function called after training iteration.'''
if self.is_cv:
scores = model.eval(epoch, self.metric)
@@ -445,7 +452,7 @@ class LearningRateScheduler(TrainingCallback):
rounds.

'''
def __init__(self, learning_rates):
def __init__(self, learning_rates) -> None:
assert callable(learning_rates) or \
isinstance(learning_rates, collections.abc.Sequence)
if callable(learning_rates):
@@ -454,53 +461,59 @@ def __init__(self, learning_rates):
self.learning_rates = lambda epoch: learning_rates[epoch]
super().__init__()

def after_iteration(self, model, epoch, evals_log):
def after_iteration(self, model, epoch, evals_log) -> bool:
model.set_param('learning_rate', self.learning_rates(epoch))
return False
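Since the constructor accepts either a callable or a sequence of rates, a decaying schedule can be attached to training as in this sketch; the data, objective, and decay factor are illustrative assumptions:

import numpy as np
import xgboost as xgb

X, y = np.random.rand(64, 4), np.random.rand(64)   # toy data for illustration
dtrain = xgb.DMatrix(X, label=y)

# A callable mapping epoch -> learning rate; a plain list of rates also works.
scheduler = xgb.callback.LearningRateScheduler(lambda epoch: 0.3 * 0.95 ** epoch)
xgb.train({"objective": "reg:squarederror"}, dtrain,
          num_boost_round=20, callbacks=[scheduler])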


# pylint: disable=too-many-instance-attributes
class EarlyStopping(TrainingCallback):
''' Callback function for early stopping
"""Callback function for early stopping

.. versionadded:: 1.3.0

Parameters
----------
rounds : int
rounds
Early stopping rounds.
metric_name : str
metric_name
Name of metric that is used for early stopping.
data_name: str
data_name
Name of dataset that is used for early stopping.
maximize : bool
maximize
Whether to maximize evaluation metric. None means auto (discouraged).
save_best : bool
save_best
Whether training should return the best model or the last model.
'''
"""
def __init__(self,
rounds,
metric_name=None,
data_name=None,
maximize=None,
save_best=False):
rounds: int,
metric_name: Optional[str] = None,
data_name: Optional[str] = None,
maximize: Optional[bool] = None,
save_best: Optional[bool] = False) -> None:
self.data = data_name
self.metric_name = metric_name
self.rounds = rounds
self.save_best = save_best
self.maximize = maximize
self.stopping_history = {}
self.stopping_history: CallbackContainer.EvalsLog = {}

if self.maximize is not None:
if self.maximize:
self.improve_op = lambda x, y: x > y
else:
self.improve_op = lambda x, y: x < y

self.current_rounds = 0
self.best_scores = {}
self.current_rounds: int = 0
self.best_scores: dict = {}
self.starting_round: int = 0
super().__init__()

def _update_rounds(self, score, name, metric, model, epoch):
def before_training(self, model):
self.starting_round = model.num_boosted_rounds()
return model

def _update_rounds(self, score, name, metric, model, epoch) -> bool:
# Just to be compatible with the old behavior before 1.3. We should let
# the user decide.
if self.maximize is None:
@@ -536,7 +549,9 @@ def _update_rounds(self, score, name, metric, model, epoch):
return True
return False

def after_iteration(self, model: Booster, epoch, evals_log):
def after_iteration(self, model, epoch: int,
evals_log: CallbackContainer.EvalsLog) -> bool:
epoch += self.starting_round # training continuation
msg = 'Must have at least 1 validation dataset for early stopping.'
assert len(evals_log.keys()) >= 1, msg
data_name = ''
Expand All @@ -562,12 +577,14 @@ def after_iteration(self, model: Booster, epoch, evals_log):
score = data_log[metric_name][-1]
return self._update_rounds(score, data_name, metric_name, model, epoch)

def after_training(self, model: Booster):
def after_training(self, model):
try:
if self.save_best:
model = model[: int(model.attr('best_iteration'))]
model = model[: int(model.attr("best_iteration")) + 1]
except XGBoostError as e:
raise XGBoostError('`save_best` is not applicable to current booster') from e
raise XGBoostError(
"`save_best` is not applicable to current booster"
) from e
return model
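Taken together, before_training records how many rounds the booster already contains, after_iteration offsets the epoch by that amount, and save_best now slices up to and including best_iteration, so early stopping keeps working when training resumes from an existing model. A minimal usage sketch with invented data and parameters:

import numpy as np
import xgboost as xgb

rng = np.random.RandomState(0)
X, y = rng.rand(256, 8), rng.rand(256)              # toy data for illustration
dtrain = xgb.DMatrix(X[:200], label=y[:200])
dvalid = xgb.DMatrix(X[200:], label=y[200:])

params = {"objective": "reg:squarederror"}
es = xgb.callback.EarlyStopping(rounds=5, data_name="Valid",
                                metric_name="rmse", save_best=True)
booster = xgb.train(params, dtrain, num_boost_round=20,
                    evals=[(dvalid, "Valid")], callbacks=[es])

# Continue from the returned model with a fresh callback instance; the
# callback now accounts for the rounds already present in the booster.
booster = xgb.train(params, dtrain, num_boost_round=20,
                    evals=[(dvalid, "Valid")], xgb_model=booster,
                    callbacks=[xgb.callback.EarlyStopping(rounds=5, save_best=True)])
print(booster.num_boosted_rounds(), booster.attr("best_iteration"))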


@@ -588,36 +605,37 @@ class EvaluationMonitor(TrainingCallback):
show_stdv : bool
Used in cv to show standard deviation. Users should not specify it.
'''
def __init__(self, rank=0, period=1, show_stdv=False):
def __init__(self, rank=0, period=1, show_stdv=False) -> None:
self.printer_rank = rank
self.show_stdv = show_stdv
self.period = period
assert period > 0
# Last evaluation message; useful when early stopping and period are used together.
self._latest = None
self._latest: Optional[str] = None
super().__init__()

def _fmt_metric(self, data, metric, score, std):
def _fmt_metric(self, data, metric, score, std) -> str:
if std is not None and self.show_stdv:
msg = '\t{0}:{1:.5f}+{2:.5f}'.format(data + '-' + metric, score, std)
else:
msg = '\t{0}:{1:.5f}'.format(data + '-' + metric, score)
return msg

def after_iteration(self, model, epoch, evals_log):
def after_iteration(self, model, epoch: int,
evals_log: CallbackContainer.EvalsLog) -> bool:
if not evals_log:
return False

msg = f'[{epoch}]'
msg: str = f'[{epoch}]'
if rabit.get_rank() == self.printer_rank:
for data, metric in evals_log.items():
for metric_name, log in metric.items():
stdv: Optional[float] = None
if isinstance(log[-1], tuple):
score = log[-1][0]
stdv = log[-1][1]
else:
score = log[-1]
stdv = None
msg += self._fmt_metric(data, metric_name, score, stdv)
msg += '\n'
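The monitor only prints from printer_rank and only every period rounds, keeping the latest message around so the final score can still be reported when early stopping cuts training short. A small sketch of attaching it explicitly; passing verbose_eval=False here is an assumption to keep xgb.train from adding its own monitor as well:

import numpy as np
import xgboost as xgb

X, y = np.random.rand(64, 4), np.random.rand(64)    # toy data for illustration
dtrain = xgb.DMatrix(X, label=y)

# Report the evaluation result every second round only.
monitor = xgb.callback.EvaluationMonitor(period=2)
xgb.train({"objective": "reg:squarederror"}, dtrain, num_boost_round=10,
          evals=[(dtrain, "Train")], verbose_eval=False, callbacks=[monitor])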

@@ -665,7 +683,8 @@ def __init__(self, directory: os.PathLike, name: str = 'model',
self._epoch = 0
super().__init__()

def after_iteration(self, model, epoch, evals_log):
def after_iteration(self, model, epoch: int,
evals_log: CallbackContainer.EvalsLog) -> bool:
if self._epoch == self._iterations:
path = os.path.join(self._path, self._name + '_' + str(epoch) +
('.pkl' if self._as_pickle else '.json'))
@@ -677,6 +696,7 @@
else:
model.save_model(path)
self._epoch += 1
return False


class LegacyCallbacks:
28 changes: 11 additions & 17 deletions python-package/xgboost/core.py
@@ -1177,23 +1177,6 @@ def copy(self):
"""
return self.__copy__()

def load_rabit_checkpoint(self):
"""Initialize the model by load from rabit checkpoint.

Returns
-------
version: integer
The version number of the model.
"""
version = ctypes.c_int()
_check_call(_LIB.XGBoosterLoadRabitCheckpoint(
self.handle, ctypes.byref(version)))
return version.value

def save_rabit_checkpoint(self):
"""Save the current booster to rabit checkpoint."""
_check_call(_LIB.XGBoosterSaveRabitCheckpoint(self.handle))

def attr(self, key):
"""Get attribute string from the Booster.

@@ -1745,6 +1728,17 @@ def load_model(self, fname):
else:
raise TypeError('Unknown file type: ', fname)

def num_boosted_rounds(self) -> int:
'''Get number of boosted rounds. For gblinear this is reset to 0 after
serializing the model.

'''
rounds = ctypes.c_int()
assert self.handle is not None
_check_call(_LIB.XGBoosterBoostedRounds(
self.handle, ctypes.byref(rounds)))
return rounds.value

def dump_model(self, fout, fmap='', with_stats=False, dump_format="text"):
"""Dump model into a text or JSON file. Unlike `save_model`, the
output format is primarily used for visualization or interpretation,
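Finally, the counter also reflects model slicing, which is what EarlyStopping(save_best=True) relies on; a short sketch with invented data, where the expected counts assume the default gbtree booster:

import numpy as np
import xgboost as xgb

X, y = np.random.rand(128, 6), np.random.rand(128)  # toy data for illustration
dtrain = xgb.DMatrix(X, label=y)

booster = xgb.train({"objective": "reg:squarederror"}, dtrain, num_boost_round=10)
print(booster.num_boosted_rounds())       # expected: 10

# Slicing keeps only the selected boosting layers, mirroring what
# save_best does with best_iteration + 1.
print(booster[:4].num_boosted_rounds())   # expected: 4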