diff --git a/python-package/xgboost/callback.py b/python-package/xgboost/callback.py index 5015aae14dcf..7a5504bd2b6c 100644 --- a/python-package/xgboost/callback.py +++ b/python-package/xgboost/callback.py @@ -10,262 +10,10 @@ import numpy from . import rabit -from .core import EarlyStopException, CallbackEnv, Booster, XGBoostError +from .core import Booster, XGBoostError from .compat import STRING_TYPES -def _get_callback_context(env): - """return whether the current callback context is cv or train""" - if env.model is not None and env.cvfolds is None: - context = 'train' - elif env.model is None and env.cvfolds is not None: - context = 'cv' - else: - raise ValueError("Unexpected input with both model and cvfolds.") - return context - - -def _fmt_metric(value, show_stdv=True): - """format metric string""" - if len(value) == 2: - return f"{value[0]}:{value[1]:.5f}" - if len(value) == 3: - if show_stdv: - return f"{value[0]}:{value[1]:.5f}+{value[2]:.5f}" - return f"{value[0]}:{value[1]:.5f}" - raise ValueError("wrong metric value", value) - - -def print_evaluation(period=1, show_stdv=True): - """Create a callback that print evaluation result. - - We print the evaluation results every **period** iterations - and on the first and the last iterations. - - Parameters - ---------- - period : int - The period to log the evaluation results - - show_stdv : bool, optional - Whether show stdv if provided - - Returns - ------- - callback : function - A callback that print evaluation every period iterations. - """ - def callback(env): - """internal function""" - if env.rank != 0 or (not env.evaluation_result_list) or period is False or period == 0: - return - i = env.iteration - if i % period == 0 or i + 1 == env.begin_iteration or i + 1 == env.end_iteration: - msg = '\t'.join([_fmt_metric(x, show_stdv) for x in env.evaluation_result_list]) - rabit.tracker_print(f"{i}\t{msg}\n") - return callback - - -def record_evaluation(eval_result): - """Create a call back that records the evaluation history into **eval_result**. - - Parameters - ---------- - eval_result : dict - A dictionary to store the evaluation results. - - Returns - ------- - callback : function - The requested callback function. - """ - if not isinstance(eval_result, dict): - raise TypeError('eval_result has to be a dictionary') - eval_result.clear() - - def init(env): - """internal function""" - for k, _ in env.evaluation_result_list: - pos = k.index('-') - key = k[:pos] - metric = k[pos + 1:] - if key not in eval_result: - eval_result[key] = {} - if metric not in eval_result[key]: - eval_result[key][metric] = [] - - def callback(env): - """internal function""" - if not eval_result: - init(env) - for k, v in env.evaluation_result_list: - pos = k.index('-') - key = k[:pos] - metric = k[pos + 1:] - eval_result[key][metric].append(v) - return callback - - -def reset_learning_rate(learning_rates): - """Reset learning rate after iteration 1 - - NOTE: the initial learning rate will still take in-effect on first iteration. - - Parameters - ---------- - learning_rates: list or function - List of learning rate for each boosting round - or a customized function that calculates eta in terms of - current number of round and the total number of boosting round (e.g. - yields learning rate decay) - - * list ``l``: ``eta = l[boosting_round]`` - * function ``f``: ``eta = f(boosting_round, num_boost_round)`` - - Returns - ------- - callback : function - The requested callback function. - """ - def get_learning_rate(i, n, learning_rates): - """helper providing the learning rate""" - if isinstance(learning_rates, list): - if len(learning_rates) != n: - raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.") - new_learning_rate = learning_rates[i] - else: - new_learning_rate = learning_rates(i, n) - return new_learning_rate - - def callback(env): - """internal function""" - context = _get_callback_context(env) - - if context == 'train': - bst, i, n = env.model, env.iteration, env.end_iteration - bst.set_param( - 'learning_rate', get_learning_rate(i, n, learning_rates)) - elif context == 'cv': - i, n = env.iteration, env.end_iteration - for cvpack in env.cvfolds: - bst = cvpack.bst - bst.set_param( - 'learning_rate', get_learning_rate(i, n, learning_rates)) - - callback.before_iteration = False - return callback - - -def early_stop(stopping_rounds, maximize=False, verbose=True): - """Create a callback that activates early stoppping. - - Validation error needs to decrease at least - every **stopping_rounds** round(s) to continue training. - Requires at least one item in **evals**. - If there's more than one, will use the last. - Returns the model from the last iteration (not the best one). - If early stopping occurs, the model will have three additional fields: - ``bst.best_score``, ``bst.best_iteration``. - - Parameters - ---------- - stopping_rounds : int - The stopping rounds before the trend occur. - - maximize : bool - Whether to maximize evaluation metric. - - verbose : optional, bool - Whether to print message about early stopping information. - - Returns - ------- - callback : function - The requested callback function. - """ - state = {} - - def init(env): - """internal function""" - bst = env.model - - if not env.evaluation_result_list: - raise ValueError('For early stopping you need at least one set in evals.') - if len(env.evaluation_result_list) > 1 and verbose: - msg = ("Multiple eval metrics have been passed: " - "'{0}' will be used for early stopping.\n\n") - rabit.tracker_print(msg.format(env.evaluation_result_list[-1][0])) - maximize_metrics = ('auc', 'aucpr', 'map', 'ndcg') - maximize_at_n_metrics = ('auc@', 'aucpr@', 'map@', 'ndcg@') - maximize_score = maximize - metric_label = env.evaluation_result_list[-1][0] - metric = metric_label.split('-', 1)[-1] - - if any(metric.startswith(x) for x in maximize_at_n_metrics): - maximize_score = True - - if any(metric.split(":")[0] == x for x in maximize_metrics): - maximize_score = True - - if verbose and env.rank == 0: - msg = "Will train until {} hasn't improved in {} rounds.\n" - rabit.tracker_print(msg.format(metric_label, stopping_rounds)) - - state['maximize_score'] = maximize_score - state['best_iteration'] = 0 - if maximize_score: - state['best_score'] = float('-inf') - else: - state['best_score'] = float('inf') - # pylint: disable=consider-using-f-string - msg = '[%d]\t%s' % ( - env.iteration, - '\t'.join([_fmt_metric(x) for x in env.evaluation_result_list]) - ) - state['best_msg'] = msg - - if bst is not None: - if bst.attr('best_score') is not None: - state['best_score'] = float(bst.attr('best_score')) - state['best_iteration'] = int(bst.attr('best_iteration')) - state['best_msg'] = bst.attr('best_msg') - else: - bst.set_attr(best_iteration=str(state['best_iteration'])) - bst.set_attr(best_score=str(state['best_score'])) - else: - assert env.cvfolds is not None - - def callback(env): - """internal function""" - if not state: - init(env) - score = env.evaluation_result_list[-1][1] - best_score = state['best_score'] - best_iteration = state['best_iteration'] - maximize_score = state['maximize_score'] - if (maximize_score and score > best_score) or \ - (not maximize_score and score < best_score): - # pylint: disable=consider-using-f-string - msg = '[%d]\t%s' % ( - env.iteration, - '\t'.join([_fmt_metric(x) for x in env.evaluation_result_list])) - state['best_msg'] = msg - state['best_score'] = score - state['best_iteration'] = env.iteration - # save the property to attributes, so they will occur in checkpoint. - if env.model is not None: - env.model.set_attr(best_score=str(state['best_score']), - best_iteration=str(state['best_iteration']), - best_msg=state['best_msg']) - elif env.iteration - best_iteration >= stopping_rounds: - best_msg = state['best_msg'] - if verbose and env.rank == 0: - msg = "Stopping. Best iteration:\n{}\n\n" - rabit.tracker_print(msg.format(best_msg)) - raise EarlyStopException(best_iteration) - return callback - - # The new implementation of callback functions. # Breaking: # - reset learning rate no longer accepts total boosting rounds @@ -741,100 +489,3 @@ def after_iteration(self, model, epoch: int, model.save_model(path) self._epoch += 1 return False - - -class LegacyCallbacks: - '''Adapter for legacy callback functions. - - .. versionadded:: 1.3.0 - - Parameters - ---------- - - callbacks : Sequence - A sequence of legacy callbacks (callbacks that are not instance of - TrainingCallback) - start_iteration : int - Begining iteration. - end_iteration : int - End iteration, normally is the number of boosting rounds. - evals : Sequence - Sequence of evaluation dataset tuples. - feval : Custom evaluation metric. - ''' - def __init__(self, callbacks, start_iteration, end_iteration, - feval, cvfolds=None): - self.callbacks_before_iter = [ - cb for cb in callbacks - if cb.__dict__.get('before_iteration', False)] - self.callbacks_after_iter = [ - cb for cb in callbacks - if not cb.__dict__.get('before_iteration', False)] - - self.start_iteration = start_iteration - self.end_iteration = end_iteration - self.cvfolds = cvfolds - - self.feval = feval - assert self.feval is None or callable(self.feval) - - if cvfolds is not None: - self.aggregated_cv = None - - super().__init__() - - def before_training(self, model): - '''Nothing to do for legacy callbacks''' - return model - - def after_training(self, model): - '''Nothing to do for legacy callbacks''' - return model - - def before_iteration(self, model, epoch, dtrain, evals): - '''Called before each iteration.''' - for cb in self.callbacks_before_iter: - rank = rabit.get_rank() - cb(CallbackEnv(model=None if self.cvfolds is not None else model, - cvfolds=self.cvfolds, - iteration=epoch, - begin_iteration=self.start_iteration, - end_iteration=self.end_iteration, - rank=rank, - evaluation_result_list=None)) - return False - - def after_iteration(self, model, epoch, dtrain, evals): - '''Called after each iteration.''' - evaluation_result_list = [] - if self.cvfolds is not None: - # dtrain is not used here. - scores = model.eval(epoch, self.feval) - self.aggregated_cv = _aggcv(scores) - evaluation_result_list = self.aggregated_cv - - if evals: - # When cv is used, evals are embedded into folds. - assert self.cvfolds is None - bst_eval_set = model.eval_set(evals, epoch, self.feval) - if isinstance(bst_eval_set, STRING_TYPES): - msg = bst_eval_set - else: - msg = bst_eval_set.decode() - res = [x.split(':') for x in msg.split()] - evaluation_result_list = [(k, float(v)) for k, v in res[1:]] - - try: - for cb in self.callbacks_after_iter: - rank = rabit.get_rank() - cb(CallbackEnv(model=None if self.cvfolds is not None else model, - cvfolds=self.cvfolds, - iteration=epoch, - begin_iteration=self.start_iteration, - end_iteration=self.end_iteration, - rank=rank, - evaluation_result_list=evaluation_result_list)) - except EarlyStopException: - return True - - return False diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 45a3d68d6586..5218630e7196 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -2,7 +2,6 @@ # pylint: disable=too-many-arguments, too-many-branches, invalid-name # pylint: disable=too-many-lines, too-many-locals, no-self-use """Core XGBoost Library.""" -import collections # pylint: disable=no-name-in-module,import-error from collections.abc import Mapping from typing import List, Optional, Any, Union, Dict, TypeVar @@ -46,18 +45,6 @@ def __init__(self, best_iteration): self.best_iteration = best_iteration -# Callback environment used by callbacks -CallbackEnv = collections.namedtuple( - "XGBoostCallbackEnv", - ["model", - "cvfolds", - "iteration", - "begin_iteration", - "end_iteration", - "rank", - "evaluation_result_list"]) - - def from_pystr_to_cstr(data: Union[str, List[str]]): """Convert a Python str or list of Python str to C pointer diff --git a/python-package/xgboost/training.py b/python-package/xgboost/training.py index c7fd14e435cf..3204dc2b3ff5 100644 --- a/python-package/xgboost/training.py +++ b/python-package/xgboost/training.py @@ -2,40 +2,24 @@ # pylint: disable=too-many-locals, too-many-arguments, invalid-name # pylint: disable=too-many-branches, too-many-statements """Training Library containing training routines.""" -import warnings import copy +from typing import Optional, List + import numpy as np from .core import Booster, XGBoostError, _get_booster_layer_trees from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold) from . import callback -def _configure_deprecated_callbacks( - verbose_eval, early_stopping_rounds, maximize, start_iteration, - num_boost_round, feval, evals_result, callbacks, show_stdv, cvfolds): - link = 'https://xgboost.readthedocs.io/en/latest/python/callbacks.html' - warnings.warn(f'Old style callback is deprecated. See: {link}', UserWarning) - # Most of legacy advanced options becomes callbacks - if early_stopping_rounds is not None: - callbacks.append(callback.early_stop(early_stopping_rounds, - maximize=maximize, - verbose=bool(verbose_eval))) - if isinstance(verbose_eval, bool) and verbose_eval: - callbacks.append(callback.print_evaluation(show_stdv=show_stdv)) - else: - if isinstance(verbose_eval, int): - callbacks.append(callback.print_evaluation(verbose_eval, - show_stdv=show_stdv)) - if evals_result is not None: - callbacks.append(callback.record_evaluation(evals_result)) - callbacks = callback.LegacyCallbacks( - callbacks, start_iteration, num_boost_round, feval, cvfolds=cvfolds) - return callbacks - - -def _is_new_callback(callbacks): - return any(isinstance(c, callback.TrainingCallback) - for c in callbacks) or not callbacks +def _assert_new_callback(callbacks: Optional[List[callback.TrainingCallback]]) -> None: + is_new_callback: bool = not callbacks or all( + isinstance(c, callback.TrainingCallback) for c in callbacks + ) + if not is_new_callback: + link = "https://xgboost.readthedocs.io/en/latest/python/callbacks.html" + raise ValueError( + f"Old style callback was removed in version 1.6. See: {link}." + ) def _train_internal(params, dtrain, @@ -56,22 +40,15 @@ def _train_internal(params, dtrain, start_iteration = 0 - is_new_callback = _is_new_callback(callbacks) - if is_new_callback: - assert all(isinstance(c, callback.TrainingCallback) - for c in callbacks), "You can't mix new and old callback styles." - if verbose_eval: - verbose_eval = 1 if verbose_eval is True else verbose_eval - callbacks.append(callback.EvaluationMonitor(period=verbose_eval)) - if early_stopping_rounds: - callbacks.append(callback.EarlyStopping( - rounds=early_stopping_rounds, maximize=maximize)) - callbacks = callback.CallbackContainer(callbacks, metric=feval) - else: - callbacks = _configure_deprecated_callbacks( - verbose_eval, early_stopping_rounds, maximize, start_iteration, - num_boost_round, feval, evals_result, callbacks, - show_stdv=False, cvfolds=None) + _assert_new_callback(callbacks) + if verbose_eval: + verbose_eval = 1 if verbose_eval is True else verbose_eval + callbacks.append(callback.EvaluationMonitor(period=verbose_eval)) + if early_stopping_rounds: + callbacks.append( + callback.EarlyStopping(rounds=early_stopping_rounds, maximize=maximize) + ) + callbacks = callback.CallbackContainer(callbacks, metric=feval) bst = callbacks.before_training(bst) @@ -84,7 +61,7 @@ def _train_internal(params, dtrain, bst = callbacks.after_training(bst) - if evals_result is not None and is_new_callback: + if evals_result is not None: evals_result.update(callbacks.history) # These should be moved into callback functions `after_training`, but until old @@ -468,25 +445,19 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None # setup callbacks callbacks = [] if callbacks is None else callbacks - is_new_callback = _is_new_callback(callbacks) - if is_new_callback: - assert all(isinstance(c, callback.TrainingCallback) - for c in callbacks), "You can't mix new and old callback styles." - if verbose_eval: - verbose_eval = 1 if verbose_eval is True else verbose_eval - callbacks.append( - callback.EvaluationMonitor(period=verbose_eval, show_stdv=show_stdv) - ) - if early_stopping_rounds: - callbacks.append( - callback.EarlyStopping(rounds=early_stopping_rounds, maximize=maximize) - ) - callbacks = callback.CallbackContainer(callbacks, metric=feval, is_cv=True) - else: - callbacks = _configure_deprecated_callbacks( - verbose_eval, early_stopping_rounds, maximize, 0, - num_boost_round, feval, None, callbacks, - show_stdv=show_stdv, cvfolds=cvfolds) + _assert_new_callback(callbacks) + + if verbose_eval: + verbose_eval = 1 if verbose_eval is True else verbose_eval + callbacks.append( + callback.EvaluationMonitor(period=verbose_eval, show_stdv=show_stdv) + ) + if early_stopping_rounds: + callbacks.append( + callback.EarlyStopping(rounds=early_stopping_rounds, maximize=maximize) + ) + callbacks = callback.CallbackContainer(callbacks, metric=feval, is_cv=True) + booster = _PackedBooster(cvfolds) callbacks.before_training(booster) diff --git a/tests/python-gpu/test_gpu_basic_models.py b/tests/python-gpu/test_gpu_basic_models.py index 3f40999861b9..b65370f10e09 100644 --- a/tests/python-gpu/test_gpu_basic_models.py +++ b/tests/python-gpu/test_gpu_basic_models.py @@ -41,8 +41,7 @@ def test_custom_objective(self): self.cpu_test_bm.run_custom_objective("gpu_hist") def test_eta_decay_gpu_hist(self): - self.cpu_test_cb.run_eta_decay('gpu_hist', True) - self.cpu_test_cb.run_eta_decay('gpu_hist', False) + self.cpu_test_cb.run_eta_decay('gpu_hist') def test_deterministic_gpu_hist(self): kRows = 1000 diff --git a/tests/python/test_basic.py b/tests/python/test_basic.py index f95043fed682..e155ab0478e9 100644 --- a/tests/python/test_basic.py +++ b/tests/python/test_basic.py @@ -76,23 +76,6 @@ def test_metric_config(self): predt_1 = booster.predict(dtrain) np.testing.assert_allclose(predt_0, predt_1) - def test_record_results(self): - dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train') - dtest = xgb.DMatrix(dpath + 'agaricus.txt.test') - param = {'max_depth': 2, 'eta': 1, 'verbosity': 0, - 'objective': 'binary:logistic', 'eval_metric': 'error'} - # specify validations set to watch performance - watchlist = [(dtest, 'eval'), (dtrain, 'train')] - num_round = 2 - result = {} - res2 = {} - xgb.train(param, dtrain, num_round, watchlist, - callbacks=[xgb.callback.record_evaluation(result)]) - xgb.train(param, dtrain, num_round, watchlist, - evals_result=res2) - assert result['train']['error'][0] < 0.1 - assert res2 == result - def test_multiclass(self): dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train') dtest = xgb.DMatrix(dpath + 'agaricus.txt.test') @@ -254,8 +237,18 @@ def test_cv_explicit_fold_indices_labels(self): ] # Use callback to log the test labels in each fold - def cb(cbackenv): - print([fold.dtest.get_label() for fold in cbackenv.cvfolds]) + class Callback(xgb.callback.TrainingCallback): + def __init__(self) -> None: + super().__init__() + + def after_iteration( + self, model, + epoch: int, + evals_log: xgb.callback.TrainingCallback.EvalsLog + ): + print([fold.dtest.get_label() for fold in model.cvfolds]) + + cb = Callback() # Run cross validation and capture standard out to test callback result with tm.captured_output() as (out, err): diff --git a/tests/python/test_callback.py b/tests/python/test_callback.py index 204e5383a975..3a2d5eecf1fc 100644 --- a/tests/python/test_callback.py +++ b/tests/python/test_callback.py @@ -249,12 +249,9 @@ def test_early_stopping_continuation(self): assert booster.num_boosted_rounds() == \ booster.best_iteration + early_stopping_rounds + 1 - def run_eta_decay(self, tree_method, deprecated_callback): + def run_eta_decay(self, tree_method): """Test learning rate scheduler, used by both CPU and GPU tests.""" - if deprecated_callback: - scheduler = xgb.callback.reset_learning_rate - else: - scheduler = xgb.callback.LearningRateScheduler + scheduler = xgb.callback.LearningRateScheduler dpath = os.path.join(tm.PROJECT_ROOT, 'demo/data/') dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train') @@ -262,10 +259,7 @@ def run_eta_decay(self, tree_method, deprecated_callback): watchlist = [(dtest, 'eval'), (dtrain, 'train')] num_round = 4 - if deprecated_callback: - warning_check = pytest.warns(UserWarning) - else: - warning_check = tm.noop_context() + warning_check = tm.noop_context() # learning_rates as a list # init eta with 0 to check whether learning_rates work @@ -339,19 +333,9 @@ def eta_decay(ithround, num_boost_round=num_round): with warning_check: xgb.cv(param, dtrain, num_round, callbacks=[scheduler(eta_decay)]) - @pytest.mark.parametrize( - "tree_method, deprecated_callback", - [ - ("hist", True), - ("hist", False), - ("approx", True), - ("approx", False), - ("exact", True), - ("exact", False), - ], - ) - def test_eta_decay(self, tree_method, deprecated_callback): - self.run_eta_decay(tree_method, deprecated_callback) + @pytest.mark.parametrize("tree_method", ["hist", "approx", "exact"]) + def test_eta_decay(self, tree_method): + self.run_eta_decay(tree_method) def test_check_point(self): from sklearn.datasets import load_breast_cancer diff --git a/tests/python/test_survival.py b/tests/python/test_survival.py index 41a618de5bd5..1fb931545c04 100644 --- a/tests/python/test_survival.py +++ b/tests/python/test_survival.py @@ -22,15 +22,25 @@ def test_aft_survival_toy_data(): # "Accuracy" = the number of data points whose ranged label (y_lower, y_upper) includes # the corresponding predicted label (y_pred) acc_rec = [] - def my_callback(env): - y_pred = env.model.predict(dmat) - acc = np.sum(np.logical_and(y_pred >= y_lower, y_pred <= y_upper)/len(X)) - acc_rec.append(acc) + + class Callback(xgb.callback.TrainingCallback): + def __init__(self): + super().__init__() + + def after_iteration( + self, model: xgb.Booster, + epoch: int, + evals_log: xgb.callback.TrainingCallback.EvalsLog + ): + y_pred = model.predict(dmat) + acc = np.sum(np.logical_and(y_pred >= y_lower, y_pred <= y_upper)/len(X)) + acc_rec.append(acc) + return False evals_result = {} - params = {'max_depth': 3, 'objective':'survival:aft', 'min_child_weight': 0} + params = {'max_depth': 3, 'objective': 'survival:aft', 'min_child_weight': 0} bst = xgb.train(params, dmat, 15, [(dmat, 'train')], evals_result=evals_result, - callbacks=[my_callback]) + callbacks=[Callback()]) nloglik_rec = evals_result['train']['aft-nloglik'] # AFT metric (negative log likelihood) improve monotonically