Fix #3663: Allow sklearn API to use callbacks #3682

Merged 3 commits on Sep 7, 2018
12 changes: 12 additions & 0 deletions doc/python/python_api.rst
@@ -53,3 +53,15 @@ Plotting API
.. autofunction:: xgboost.plot_tree

.. autofunction:: xgboost.to_graphviz

.. _callback_api:

Callback API
------------
.. autofunction:: xgboost.callback.print_evaluation

.. autofunction:: xgboost.callback.record_evaluation

.. autofunction:: xgboost.callback.reset_learning_rate

.. autofunction:: xgboost.callback.early_stop
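As a quick orientation for this diff, here is a minimal sketch of how the callbacks documented above can be combined and passed to `xgb.train`; the data names (`X_train`, `y_train`, `X_valid`, `y_valid`) and the parameter dict are placeholders, not part of this change:

```python
import xgboost as xgb

# Placeholder data: any feature matrix / label pair works here.
dtrain = xgb.DMatrix(X_train, label=y_train)
dvalid = xgb.DMatrix(X_valid, label=y_valid)

history = {}
bst = xgb.train(
    {'objective': 'binary:logistic'},
    dtrain,
    num_boost_round=100,
    evals=[(dvalid, 'validation')],
    callbacks=[
        xgb.callback.print_evaluation(period=10),    # log every 10th round
        xgb.callback.record_evaluation(history),     # collect metrics into `history`
        xgb.callback.early_stop(stopping_rounds=5),  # stop after 5 rounds without improvement
    ],
)
```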
23 changes: 12 additions & 11 deletions python-package/xgboost/callback.py
@@ -32,7 +32,7 @@ def _fmt_metric(value, show_stdv=True):
def print_evaluation(period=1, show_stdv=True):
"""Create a callback that print evaluation result.

We print the evaluation results every ``period`` iterations
We print the evaluation results every **period** iterations
and on the first and the last iterations.

Parameters
@@ -60,7 +60,7 @@ def callback(env):


def record_evaluation(eval_result):
"""Create a call back that records the evaluation history into eval_result.
"""Create a call back that records the evaluation history into **eval_result**.

Parameters
----------
@@ -109,10 +109,11 @@ def reset_learning_rate(learning_rates):
learning_rates: list or function
List of learning rate for each boosting round
or a customized function that calculates eta in terms of
current number of round and the total number of boosting round (e.g. yields
learning rate decay)
- list l: eta = l[boosting_round]
- function f: eta = f(boosting_round, num_boost_round)
current number of round and the total number of boosting round (e.g.
yields learning rate decay)

* list ``l``: ``eta = l[boosting_round]``
* function ``f``: ``eta = f(boosting_round, num_boost_round)``

Returns
-------
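To make the two forms above concrete, a minimal sketch assuming 100 boosting rounds and an arbitrary illustrative decay schedule:

```python
import xgboost as xgb

# Function form: called as f(boosting_round, num_boost_round) at each round.
def eta_decay(boosting_round, num_boost_round):
    return 0.3 * (0.99 ** boosting_round)

lr_cb = xgb.callback.reset_learning_rate(eta_decay)

# List form: one eta value per boosting round (length must equal num_boost_round).
lr_cb_from_list = xgb.callback.reset_learning_rate(
    [0.3 * (0.99 ** i) for i in range(100)]
)
```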
@@ -150,14 +151,14 @@ def early_stop(stopping_rounds, maximize=False, verbose=True):
"""Create a callback that activates early stoppping.

Validation error needs to decrease at least
every <stopping_rounds> round(s) to continue training.
Requires at least one item in evals.
every **stopping_rounds** round(s) to continue training.
Requires at least one item in **evals**.
If there's more than one, will use the last.
Returns the model from the last iteration (not the best one).
If early stopping occurs, the model will have three additional fields:
bst.best_score, bst.best_iteration and bst.best_ntree_limit.
(Use bst.best_ntree_limit to get the correct value if num_parallel_tree
and/or num_class appears in the parameters)
``bst.best_score``, ``bst.best_iteration`` and ``bst.best_ntree_limit``.
(Use ``bst.best_ntree_limit`` to get the correct value if ``num_parallel_tree``
and/or ``num_class`` appears in the parameters)

Parameters
----------
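A sketch of how the fields set by early stopping are typically consumed afterwards; `params`, `dtrain` and `dtest` are assumed to exist and are not part of this change:

```python
import xgboost as xgb

bst = xgb.train(
    params, dtrain,
    num_boost_round=500,
    evals=[(dtest, 'eval')],
    callbacks=[xgb.callback.early_stop(stopping_rounds=10)],
)

# If early stopping fired, bst.best_iteration and bst.best_score are set, and
# bst.best_ntree_limit gives the tree count to use at prediction time.
preds = bst.predict(dtest, ntree_limit=bst.best_ntree_limit)
```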
81 changes: 56 additions & 25 deletions python-package/xgboost/sklearn.py
@@ -1,5 +1,5 @@
# coding: utf-8
# pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme, E0012, R0912
# pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme, E0012, R0912, C0302
"""Scikit-Learn Wrapper interface for XGBoost."""
from __future__ import absolute_import

@@ -69,9 +69,9 @@ class XGBModel(XGBModelBase):
booster: string
Specify which booster to use: gbtree, gblinear or dart.
nthread : int
Number of parallel threads used to run xgboost. (Deprecated, please use n_jobs)
Number of parallel threads used to run xgboost. (Deprecated, please use ``n_jobs``)
n_jobs : int
Number of parallel threads used to run xgboost. (replaces nthread)
Number of parallel threads used to run xgboost. (replaces ``nthread``)
gamma : float
Minimum loss reduction required to make a further partition on a leaf node of the tree.
min_child_weight : int
@@ -242,7 +242,7 @@ def load_model(self, fname):

def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
early_stopping_rounds=None, verbose=True, xgb_model=None,
sample_weight_eval_set=None):
sample_weight_eval_set=None, callbacks=None):
# pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init
"""
Fit the gradient boosting model
@@ -285,6 +285,14 @@ def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
xgb_model : str
file name of stored xgb model or 'Booster' instance Xgb model to be
loaded before training (allows training continuation).
callbacks : list of callback functions
List of callback functions that are applied at end of each iteration.
It is possible to use predefined callbacks by using :ref:`callback_api`.
Example:

.. code-block:: python

[xgb.callback.reset_learning_rate(custom_rates)]
"""
if sample_weight is not None:
trainDmatrix = DMatrix(X, label=y, weight=sample_weight,
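A sketch of what the new `callbacks` argument enables at the sklearn level; the data names and the decay schedule are placeholders for illustration:

```python
import xgboost as xgb

# One eta value per boosting round; the list length matches n_estimators.
custom_rates = [0.3 * (0.95 ** i) for i in range(100)]

reg = xgb.XGBRegressor(n_estimators=100)
reg.fit(
    X_train, y_train,
    eval_set=[(X_valid, y_valid)],
    callbacks=[xgb.callback.reset_learning_rate(custom_rates)],
)
```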
@@ -325,7 +333,8 @@ def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
self.n_estimators, evals=evals,
early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result, obj=obj, feval=feval,
verbose_eval=verbose, xgb_model=xgb_model)
verbose_eval=verbose, xgb_model=xgb_model,
callbacks=callbacks)

if evals_result:
for val in evals_result.items():
@@ -413,10 +422,10 @@ def apply(self, X, ntree_limit=0):
def evals_result(self):
"""Return the evaluation results.

If ``eval_set`` is passed to the `fit` function, you can call ``evals_result()`` to
get evaluation results for all passed eval_sets. When ``eval_metric`` is also
passed to the ``fit`` function, the ``evals_result`` will contain the ``eval_metrics``
passed to the ``fit`` function
If **eval_set** is passed to the `fit` function, you can call
``evals_result()`` to get evaluation results for all passed **eval_sets**.
When **eval_metric** is also passed to the `fit` function, the
**evals_result** will contain the **eval_metrics** passed to the `fit` function.

Returns
-------
@@ -438,9 +447,9 @@ def evals_result(self):

evals_result = clf.evals_result()

The variable evals_result will contain:
The variable **evals_result** will contain:

.. code-block:: none
.. code-block:: python

{'validation_0': {'logloss': ['0.604835', '0.531479']},
'validation_1': {'logloss': ['0.41965', '0.17686']}}
@@ -492,7 +501,7 @@ def __init__(self, max_depth=3, learning_rate=0.1,

def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
early_stopping_rounds=None, verbose=True, xgb_model=None,
sample_weight_eval_set=None):
sample_weight_eval_set=None, callbacks=None):
# pylint: disable = attribute-defined-outside-init,arguments-differ
"""
Fit gradient boosting classifier
@@ -535,6 +544,14 @@ def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
xgb_model : str
file name of stored xgb model or 'Booster' instance Xgb model to be
loaded before training (allows training continuation).
callbacks : list of callback functions
List of callback functions that are applied at end of each iteration.
It is possible to use predefined callbacks by using :ref:`callback_api`.
Example:

.. code-block:: python

[xgb.callback.reset_learning_rate(custom_rates)]
"""
evals_result = {}
self.classes_ = np.unique(y)
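The classifier path looks the same; a sketch pairing the new `callbacks` argument with `record_evaluation` (data names are placeholders, and the `validation_0` key follows the naming shown in `evals_result()` below):

```python
import xgboost as xgb

history = {}
clf = xgb.XGBClassifier(n_estimators=50)
clf.fit(
    X_train, y_train,
    eval_set=[(X_valid, y_valid)],
    eval_metric='logloss',
    callbacks=[xgb.callback.record_evaluation(history)],
)

# history['validation_0']['logloss'] now holds one value per boosting round.
```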
@@ -592,7 +609,8 @@ def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
evals=evals,
early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result, obj=obj, feval=feval,
verbose_eval=verbose, xgb_model=None)
verbose_eval=verbose, xgb_model=None,
callbacks=callbacks)

self.objective = xgb_options["objective"]
if evals_result:
@@ -705,10 +723,10 @@ def predict_proba(self, data, ntree_limit=None, validate_features=True):
def evals_result(self):
"""Return the evaluation results.

If eval_set is passed to the `fit` function, you can call evals_result() to
get evaluation results for all passed eval_sets. When eval_metric is also
passed to the `fit` function, the evals_result will contain the eval_metrics
passed to the `fit` function
If **eval_set** is passed to the `fit` function, you can call
``evals_result()`` to get evaluation results for all passed **eval_sets**.
When **eval_metric** is also passed to the `fit` function, the
**evals_result** will contain the **eval_metrics** passed to the `fit` function.

Returns
-------
@@ -730,9 +748,9 @@ def evals_result(self):

evals_result = clf.evals_result()

The variable ``evals_result`` will contain
The variable **evals_result** will contain

.. code-block:: none
.. code-block:: python

{'validation_0': {'logloss': ['0.604835', '0.531479']},
'validation_1': {'logloss': ['0.41965', '0.17686']}}
@@ -771,9 +789,9 @@ class XGBRanker(XGBModel):
booster: string
Specify which booster to use: gbtree, gblinear or dart.
nthread : int
Number of parallel threads used to run xgboost. (Deprecated, please use n_jobs)
Number of parallel threads used to run xgboost. (Deprecated, please use ``n_jobs``)
n_jobs : int
Number of parallel threads used to run xgboost. (replaces nthread)
Number of parallel threads used to run xgboost. (replaces ``nthread``)
gamma : float
Minimum loss reduction required to make a further partition on a leaf node of the tree.
min_child_weight : int
@@ -816,8 +834,12 @@ class XGBRanker(XGBModel):
----
A custom objective function is currently not supported by XGBRanker.

Group information is required for ranking tasks. Before fitting the model, your data need to
be sorted by group. When fitting the model, you need to provide an additional array that
Note
----
Group information is required for ranking tasks.

Before fitting the model, your data need to be sorted by group. When
fitting the model, you need to provide an additional array that
contains the size of each group.

For example, if your original data look like:
@@ -863,7 +885,7 @@ def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,

def fit(self, X, y, group, sample_weight=None, eval_set=None, sample_weight_eval_set=None,
eval_group=None, eval_metric=None, early_stopping_rounds=None,
verbose=False, xgb_model=None):
verbose=False, xgb_model=None, callbacks=None):
# pylint: disable = attribute-defined-outside-init,arguments-differ
"""
Fit the gradient boosting model
@@ -911,6 +933,14 @@ def fit(self, X, y, group, sample_weight=None, eval_set=None, sample_weight_eval
xgb_model : str
file name of stored xgb model or 'Booster' instance Xgb model to be
loaded before training (allows training continuation).
callbacks : list of callback functions
List of callback functions that are applied at end of each iteration.
It is possible to use predefined callbacks by using :ref:`callback_api`.
Example:

.. code-block:: python

[xgb.callback.reset_learning_rate(custom_rates)]
"""
# check if group information is provided
if group is None:
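For the ranker, the new argument combines with the required group information; a sketch with assumed toy shapes (7 training rows in groups of 3 and 4, 4 validation rows split 2/2):

```python
import xgboost as xgb

ranker = xgb.XGBRanker(n_estimators=20)
ranker.fit(
    X_train, y_train,
    group=[3, 4],            # training rows must already be sorted by group
    eval_set=[(X_valid, y_valid)],
    eval_group=[[2, 2]],     # one list of group sizes per eval set
    callbacks=[xgb.callback.print_evaluation(period=5)],
)
```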
@@ -963,7 +993,8 @@ def _dmat_init(group, **params):
self.n_estimators,
early_stopping_rounds=early_stopping_rounds, evals=evals,
evals_result=evals_result, feval=feval,
verbose_eval=verbose, xgb_model=xgb_model)
verbose_eval=verbose, xgb_model=xgb_model,
callbacks=callbacks)

self.objective = params["objective"]

41 changes: 24 additions & 17 deletions python-package/xgboost/training.py
@@ -137,34 +137,35 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
Whether to maximize feval.
early_stopping_rounds: int
Activates early stopping. Validation error needs to decrease at least
every <early_stopping_rounds> round(s) to continue training.
Requires at least one item in evals.
every **early_stopping_rounds** round(s) to continue training.
Requires at least one item in **evals**.
If there's more than one, will use the last.
Returns the model from the last iteration (not the best one).
If early stopping occurs, the model will have three additional fields:
bst.best_score, bst.best_iteration and bst.best_ntree_limit.
(Use bst.best_ntree_limit to get the correct value if num_parallel_tree
and/or num_class appears in the parameters)
``bst.best_score``, ``bst.best_iteration`` and ``bst.best_ntree_limit``.
(Use ``bst.best_ntree_limit`` to get the correct value if
``num_parallel_tree`` and/or ``num_class`` appears in the parameters)
evals_result: dict
This dictionary stores the evaluation results of all the items in watchlist.

Example: with a watchlist containing [(dtest,'eval'), (dtrain,'train')] and
a parameter containing ('eval_metric': 'logloss'), the **evals_result**
returns
Example: with a watchlist containing
``[(dtest,'eval'), (dtrain,'train')]`` and
a parameter containing ``('eval_metric': 'logloss')``,
the **evals_result** returns

.. code-block:: none
.. code-block:: python

{'train': {'logloss': ['0.48253', '0.35953']},
'eval': {'logloss': ['0.480385', '0.357756']}}

verbose_eval : bool or int
Requires at least one item in evals.
Requires at least one item in **evals**.
If **verbose_eval** is True then the evaluation metric on the validation set is
printed at each boosting stage.
If **verbose_eval** is an integer then the evaluation metric on the validation set
is printed at every given **verbose_eval** boosting stage. The last boosting stage
/ the boosting stage found by using **early_stopping_rounds** is also printed.
Example: with ``verbose_eval=4`` and at least one item in evals, an evaluation metric
Example: with ``verbose_eval=4`` and at least one item in **evals**, an evaluation metric
is printed every 4 boosting stages, instead of every boosting stage.
learning_rates: list or function (deprecated - use callback API instead)
List of learning rate for each boosting round
@@ -175,12 +176,17 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
Xgb model to be loaded before training (allows training continuation).
callbacks : list of callback functions
List of callback functions that are applied at end of each iteration.
It is possible to use predefined callbacks by using xgb.callback module.
Example: [xgb.callback.reset_learning_rate(custom_rates)]
It is possible to use predefined callbacks by using
:ref:`Callback API <callback_api>`.
Example:

.. code-block:: python

[xgb.callback.reset_learning_rate(custom_rates)]

Returns
-------
booster : a trained booster model
Booster : a trained booster model
"""
callbacks = [] if callbacks is None else callbacks
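Since `learning_rates` is deprecated in favor of the callback API, here is a sketch of the replacement pattern together with `evals_result` collection; `dtrain` and `dtest` are assumed DMatrix objects:

```python
import xgboost as xgb

custom_rates = [0.3 * (0.95 ** i) for i in range(10)]
evals_result = {}
bst = xgb.train(
    {'objective': 'binary:logistic', 'eval_metric': 'logloss'},
    dtrain,
    num_boost_round=10,
    evals=[(dtest, 'eval'), (dtrain, 'train')],
    evals_result=evals_result,
    callbacks=[xgb.callback.reset_learning_rate(custom_rates)],
)

# evals_result['train']['logloss'] and evals_result['eval']['logloss']
# each hold one value per boosting round.
```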

@@ -334,7 +340,7 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
folds : a KFold or StratifiedKFold instance or list of fold indices
Sklearn KFolds or StratifiedKFolds object.
Alternatively may explicitly pass sample indices for each fold.
For ``n`` folds, ``folds`` should be a length ``n`` list of tuples.
For ``n`` folds, **folds** should be a length ``n`` list of tuples.
Each tuple is ``(in,out)`` where ``in`` is a list of indices to be used
as the training samples for the ``n`` th fold and ``out`` is a list of
indices to be used as the testing samples for the ``n`` th fold.
@@ -368,10 +374,11 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
Seed used to generate the folds (passed to numpy.random.seed).
callbacks : list of callback functions
List of callback functions that are applied at end of each iteration.
It is possible to use predefined callbacks by using xgb.callback module.
It is possible to use predefined callbacks by using
:ref:`Callback API <callback_api>`.
Example:

.. code-block:: none
.. code-block:: python

[xgb.callback.reset_learning_rate(custom_rates)]
shuffle : bool