Remove old callback deprecated in 1.3.
CV.

Lint.

Fix tests.

Remove fmt too.

Fix GPU test.

Apply suggestions from code review

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>

Remove old error.

Remove duplicated code.

Remove callback env.

Unused import.

Lint.

Address reviewer's comment.
trivialfis committed Oct 8, 2021
1 parent 578de9f commit 8e1014e
Showing 7 changed files with 70 additions and 475 deletions.
351 changes: 1 addition & 350 deletions python-package/xgboost/callback.py
@@ -10,262 +10,10 @@
import numpy

from . import rabit
from .core import EarlyStopException, CallbackEnv, Booster, XGBoostError
from .core import Booster, XGBoostError
from .compat import STRING_TYPES


def _get_callback_context(env):
"""return whether the current callback context is cv or train"""
if env.model is not None and env.cvfolds is None:
context = 'train'
elif env.model is None and env.cvfolds is not None:
context = 'cv'
else:
raise ValueError("Unexpected input with both model and cvfolds.")
return context


def _fmt_metric(value, show_stdv=True):
"""format metric string"""
if len(value) == 2:
return f"{value[0]}:{value[1]:.5f}"
if len(value) == 3:
if show_stdv:
return f"{value[0]}:{value[1]:.5f}+{value[2]:.5f}"
return f"{value[0]}:{value[1]:.5f}"
raise ValueError("wrong metric value", value)


def print_evaluation(period=1, show_stdv=True):
"""Create a callback that print evaluation result.
We print the evaluation results every **period** iterations
and on the first and the last iterations.
Parameters
----------
period : int
The period to log the evaluation results
show_stdv : bool, optional
Whether to show stdv when provided.
Returns
-------
callback : function
A callback that prints the evaluation results every **period** iterations.
"""
def callback(env):
"""internal function"""
if env.rank != 0 or (not env.evaluation_result_list) or period is False or period == 0:
return
i = env.iteration
if i % period == 0 or i + 1 == env.begin_iteration or i + 1 == env.end_iteration:
msg = '\t'.join([_fmt_metric(x, show_stdv) for x in env.evaluation_result_list])
rabit.tracker_print(f"{i}\t{msg}\n")
return callback
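In the new interface this role is played by ``xgboost.callback.EvaluationMonitor``. A minimal migration sketch; the toy data and parameter values are illustrative assumptions, not part of this commit:

import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X, y = rng.standard_normal((64, 4)), rng.integers(0, 2, 64)
dtrain = xgb.DMatrix(X[:48], label=y[:48])
dvalid = xgb.DMatrix(X[48:], label=y[48:])

xgb.train(
    {"objective": "binary:logistic"},
    dtrain,
    num_boost_round=20,
    evals=[(dvalid, "validation")],
    # Prints results every `period` rounds, like print_evaluation(period=...) did.
    callbacks=[xgb.callback.EvaluationMonitor(period=5)],
)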


def record_evaluation(eval_result):
"""Create a call back that records the evaluation history into **eval_result**.
Parameters
----------
eval_result : dict
A dictionary to store the evaluation results.
Returns
-------
callback : function
The requested callback function.
"""
if not isinstance(eval_result, dict):
raise TypeError('eval_result has to be a dictionary')
eval_result.clear()

def init(env):
"""internal function"""
for k, _ in env.evaluation_result_list:
pos = k.index('-')
key = k[:pos]
metric = k[pos + 1:]
if key not in eval_result:
eval_result[key] = {}
if metric not in eval_result[key]:
eval_result[key][metric] = []

def callback(env):
"""internal function"""
if not eval_result:
init(env)
for k, v in env.evaluation_result_list:
pos = k.index('-')
key = k[:pos]
metric = k[pos + 1:]
eval_result[key][metric].append(v)
return callback
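The recording behaviour survives without this callback through the long-standing ``evals_result`` argument of ``xgb.train``. A minimal sketch with assumed toy data:

import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X, y = rng.standard_normal((64, 4)), rng.integers(0, 2, 64)
dtrain = xgb.DMatrix(X[:48], label=y[:48])
dvalid = xgb.DMatrix(X[48:], label=y[48:])

history = {}  # filled in place, e.g. {"validation": {"logloss": [...]}}
xgb.train(
    {"objective": "binary:logistic", "eval_metric": "logloss"},
    dtrain,
    num_boost_round=10,
    evals=[(dvalid, "validation")],
    evals_result=history,
)
print(history["validation"]["logloss"][-1])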


def reset_learning_rate(learning_rates):
"""Reset learning rate after iteration 1
NOTE: the initial learning rate will still take in-effect on first iteration.
Parameters
----------
learning_rates: list or function
List of learning rates, one for each boosting round,
or a customized function that calculates eta from the
current round number and the total number of boosting rounds
(e.g. to produce learning rate decay)
* list ``l``: ``eta = l[boosting_round]``
* function ``f``: ``eta = f(boosting_round, num_boost_round)``
Returns
-------
callback : function
The requested callback function.
"""
def get_learning_rate(i, n, learning_rates):
"""helper providing the learning rate"""
if isinstance(learning_rates, list):
if len(learning_rates) != n:
raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
new_learning_rate = learning_rates[i]
else:
new_learning_rate = learning_rates(i, n)
return new_learning_rate

def callback(env):
"""internal function"""
context = _get_callback_context(env)

if context == 'train':
bst, i, n = env.model, env.iteration, env.end_iteration
bst.set_param(
'learning_rate', get_learning_rate(i, n, learning_rates))
elif context == 'cv':
i, n = env.iteration, env.end_iteration
for cvpack in env.cvfolds:
bst = cvpack.bst
bst.set_param(
'learning_rate', get_learning_rate(i, n, learning_rates))

callback.before_iteration = False
return callback
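The 1.3+ replacement is ``xgboost.callback.LearningRateScheduler``. A hedged sketch with assumed toy data; note the breaking change recorded later in this file, that a custom function now receives only the current round:

import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X, y = rng.standard_normal((64, 4)), rng.integers(0, 2, 64)
dtrain = xgb.DMatrix(X, label=y)

# Exponential decay schedule: eta = 0.3 * 0.95 ** epoch.
scheduler = xgb.callback.LearningRateScheduler(lambda epoch: 0.3 * 0.95 ** epoch)
xgb.train({"objective": "binary:logistic"}, dtrain,
          num_boost_round=10, callbacks=[scheduler])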


def early_stop(stopping_rounds, maximize=False, verbose=True):
"""Create a callback that activates early stoppping.
Validation error needs to decrease at least
every **stopping_rounds** round(s) to continue training.
Requires at least one item in **evals**.
If there's more than one, the last one will be used.
Returns the model from the last iteration (not the best one).
If early stopping occurs, the model will have three additional attributes:
``bst.best_score``, ``bst.best_iteration`` and ``bst.best_msg``.
Parameters
----------
stopping_rounds : int
The number of rounds without improvement after which training stops.
maximize : bool
Whether to maximize evaluation metric.
verbose : bool, optional
Whether to print messages about early stopping.
Returns
-------
callback : function
The requested callback function.
"""
state = {}

def init(env):
"""internal function"""
bst = env.model

if not env.evaluation_result_list:
raise ValueError('For early stopping you need at least one set in evals.')
if len(env.evaluation_result_list) > 1 and verbose:
msg = ("Multiple eval metrics have been passed: "
"'{0}' will be used for early stopping.\n\n")
rabit.tracker_print(msg.format(env.evaluation_result_list[-1][0]))
maximize_metrics = ('auc', 'aucpr', 'map', 'ndcg')
maximize_at_n_metrics = ('auc@', 'aucpr@', 'map@', 'ndcg@')
maximize_score = maximize
metric_label = env.evaluation_result_list[-1][0]
metric = metric_label.split('-', 1)[-1]

if any(metric.startswith(x) for x in maximize_at_n_metrics):
maximize_score = True

if any(metric.split(":")[0] == x for x in maximize_metrics):
maximize_score = True

if verbose and env.rank == 0:
msg = "Will train until {} hasn't improved in {} rounds.\n"
rabit.tracker_print(msg.format(metric_label, stopping_rounds))

state['maximize_score'] = maximize_score
state['best_iteration'] = 0
if maximize_score:
state['best_score'] = float('-inf')
else:
state['best_score'] = float('inf')
# pylint: disable=consider-using-f-string
msg = '[%d]\t%s' % (
env.iteration,
'\t'.join([_fmt_metric(x) for x in env.evaluation_result_list])
)
state['best_msg'] = msg

if bst is not None:
if bst.attr('best_score') is not None:
state['best_score'] = float(bst.attr('best_score'))
state['best_iteration'] = int(bst.attr('best_iteration'))
state['best_msg'] = bst.attr('best_msg')
else:
bst.set_attr(best_iteration=str(state['best_iteration']))
bst.set_attr(best_score=str(state['best_score']))
else:
assert env.cvfolds is not None

def callback(env):
"""internal function"""
if not state:
init(env)
score = env.evaluation_result_list[-1][1]
best_score = state['best_score']
best_iteration = state['best_iteration']
maximize_score = state['maximize_score']
if (maximize_score and score > best_score) or \
(not maximize_score and score < best_score):
# pylint: disable=consider-using-f-string
msg = '[%d]\t%s' % (
env.iteration,
'\t'.join([_fmt_metric(x) for x in env.evaluation_result_list]))
state['best_msg'] = msg
state['best_score'] = score
state['best_iteration'] = env.iteration
# save the property to attributes, so they will occur in checkpoint.
if env.model is not None:
env.model.set_attr(best_score=str(state['best_score']),
best_iteration=str(state['best_iteration']),
best_msg=state['best_msg'])
elif env.iteration - best_iteration >= stopping_rounds:
best_msg = state['best_msg']
if verbose and env.rank == 0:
msg = "Stopping. Best iteration:\n{}\n\n"
rabit.tracker_print(msg.format(best_msg))
raise EarlyStopException(best_iteration)
return callback
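The 1.3+ replacement is ``xgboost.callback.EarlyStopping``. A minimal sketch with assumed toy data:

import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X, y = rng.standard_normal((64, 4)), rng.integers(0, 2, 64)
dtrain = xgb.DMatrix(X[:48], label=y[:48])
dvalid = xgb.DMatrix(X[48:], label=y[48:])

booster = xgb.train(
    {"objective": "binary:logistic"},
    dtrain,
    num_boost_round=100,
    evals=[(dvalid, "validation")],
    # Stop if the last metric on the last eval set fails to improve for 3 rounds.
    callbacks=[xgb.callback.EarlyStopping(rounds=3)],
)
print(booster.best_iteration, booster.best_score)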


# The new implementation of callback functions.
# Breaking:
# - reset learning rate no longer accepts total boosting rounds
@@ -741,100 +489,3 @@ def after_iteration(self, model, epoch: int,
model.save_model(path)
self._epoch += 1
return False
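The hunk above is the tail of the new checkpointing callback. A hedged usage sketch; the ``TrainingCheckPoint`` class and its ``directory`` parameter come from the 1.3+ API, everything else here is assumed:

import tempfile
import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X, y = rng.standard_normal((64, 4)), rng.integers(0, 2, 64)
dtrain = xgb.DMatrix(X, label=y)

with tempfile.TemporaryDirectory() as tmp:
    # Saves model snapshots into `tmp` as training progresses.
    checkpoint = xgb.callback.TrainingCheckPoint(directory=tmp)
    xgb.train({"objective": "binary:logistic"}, dtrain,
              num_boost_round=10, callbacks=[checkpoint])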


class LegacyCallbacks:
'''Adapter for legacy callback functions.
.. versionadded:: 1.3.0
Parameters
----------
callbacks : Sequence
A sequence of legacy callbacks (callbacks that are not instances of
TrainingCallback)
start_iteration : int
Beginning iteration.
end_iteration : int
End iteration, normally the number of boosting rounds.
evals : Sequence
Sequence of evaluation dataset tuples.
feval : callable
Custom evaluation metric.
'''
def __init__(self, callbacks, start_iteration, end_iteration,
feval, cvfolds=None):
self.callbacks_before_iter = [
cb for cb in callbacks
if cb.__dict__.get('before_iteration', False)]
self.callbacks_after_iter = [
cb for cb in callbacks
if not cb.__dict__.get('before_iteration', False)]

self.start_iteration = start_iteration
self.end_iteration = end_iteration
self.cvfolds = cvfolds

self.feval = feval
assert self.feval is None or callable(self.feval)

if cvfolds is not None:
self.aggregated_cv = None

super().__init__()

def before_training(self, model):
'''Nothing to do for legacy callbacks'''
return model

def after_training(self, model):
'''Nothing to do for legacy callbacks'''
return model

def before_iteration(self, model, epoch, dtrain, evals):
'''Called before each iteration.'''
for cb in self.callbacks_before_iter:
rank = rabit.get_rank()
cb(CallbackEnv(model=None if self.cvfolds is not None else model,
cvfolds=self.cvfolds,
iteration=epoch,
begin_iteration=self.start_iteration,
end_iteration=self.end_iteration,
rank=rank,
evaluation_result_list=None))
return False

def after_iteration(self, model, epoch, dtrain, evals):
'''Called after each iteration.'''
evaluation_result_list = []
if self.cvfolds is not None:
# dtrain is not used here.
scores = model.eval(epoch, self.feval)
self.aggregated_cv = _aggcv(scores)
evaluation_result_list = self.aggregated_cv

if evals:
# When cv is used, evals are embedded into folds.
assert self.cvfolds is None
bst_eval_set = model.eval_set(evals, epoch, self.feval)
if isinstance(bst_eval_set, STRING_TYPES):
msg = bst_eval_set
else:
msg = bst_eval_set.decode()
res = [x.split(':') for x in msg.split()]
evaluation_result_list = [(k, float(v)) for k, v in res[1:]]

try:
for cb in self.callbacks_after_iter:
rank = rabit.get_rank()
cb(CallbackEnv(model=None if self.cvfolds is not None else model,
cvfolds=self.cvfolds,
iteration=epoch,
begin_iteration=self.start_iteration,
end_iteration=self.end_iteration,
rank=rank,
evaluation_result_list=evaluation_result_list))
except EarlyStopException:
return True

return False
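With the adapter gone, custom logic subclasses ``xgboost.callback.TrainingCallback`` directly instead of receiving a ``CallbackEnv``. A minimal sketch; the class name here is hypothetical:

import xgboost as xgb

class MetricPrinter(xgb.callback.TrainingCallback):
    """Hypothetical example: print the latest value of every tracked metric."""

    def after_iteration(self, model, epoch, evals_log):
        # evals_log is a nested dict, e.g. {"validation": {"logloss": [...]}}
        for data, metrics in evals_log.items():
            for name, values in metrics.items():
                print(f"[{epoch}] {data}-{name}: {values[-1]:.5f}")
        return False  # returning True stops training early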
13 changes: 0 additions & 13 deletions python-package/xgboost/core.py
@@ -2,7 +2,6 @@
# pylint: disable=too-many-arguments, too-many-branches, invalid-name
# pylint: disable=too-many-lines, too-many-locals, no-self-use
"""Core XGBoost Library."""
import collections
# pylint: disable=no-name-in-module,import-error
from collections.abc import Mapping
from typing import List, Optional, Any, Union, Dict, TypeVar
@@ -46,18 +45,6 @@ def __init__(self, best_iteration):
self.best_iteration = best_iteration


# Callback environment used by callbacks
CallbackEnv = collections.namedtuple(
"XGBoostCallbackEnv",
["model",
"cvfolds",
"iteration",
"begin_iteration",
"end_iteration",
"rank",
"evaluation_result_list"])


def from_pystr_to_cstr(data: Union[str, List[str]]):
"""Convert a Python str or list of Python str to C pointer
(Diffs for the remaining changed files were not loaded.)