Skip to content

Commit

Permalink
Fix dmlc#3663: Allow sklearn API to use callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
mxxun committed Sep 7, 2018
1 parent 5a8bbb3 commit 306edef
Showing 1 changed file with 33 additions and 6 deletions.
39 changes: 33 additions & 6 deletions python-package/xgboost/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def load_model(self, fname):

def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
early_stopping_rounds=None, verbose=True, xgb_model=None,
sample_weight_eval_set=None):
sample_weight_eval_set=None, callbacks=None):
# pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init
"""
Fit the gradient boosting model
Expand Down Expand Up @@ -285,6 +285,14 @@ def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
xgb_model : str
file name of stored xgb model or 'Booster' instance Xgb model to be
loaded before training (allows training continuation).
callbacks : list of callback functions
List of callback functions that are applied at end of each iteration.
It is possible to use predefined callbacks by using xgb.callback module.
Example:
.. code-block:: none
[xgb.callback.reset_learning_rate(custom_rates)]
"""
if sample_weight is not None:
trainDmatrix = DMatrix(X, label=y, weight=sample_weight,
Expand Down Expand Up @@ -325,7 +333,8 @@ def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
self.n_estimators, evals=evals,
early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result, obj=obj, feval=feval,
verbose_eval=verbose, xgb_model=xgb_model)
verbose_eval=verbose, xgb_model=xgb_model,
callbacks=callbacks)

if evals_result:
for val in evals_result.items():
Expand Down Expand Up @@ -492,7 +501,7 @@ def __init__(self, max_depth=3, learning_rate=0.1,

def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
early_stopping_rounds=None, verbose=True, xgb_model=None,
sample_weight_eval_set=None):
sample_weight_eval_set=None, callbacks=None):
# pylint: disable = attribute-defined-outside-init,arguments-differ
"""
Fit gradient boosting classifier
Expand Down Expand Up @@ -535,6 +544,14 @@ def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
xgb_model : str
file name of stored xgb model or 'Booster' instance Xgb model to be
loaded before training (allows training continuation).
callbacks : list of callback functions
List of callback functions that are applied at end of each iteration.
It is possible to use predefined callbacks by using xgb.callback module.
Example:
.. code-block:: none
[xgb.callback.reset_learning_rate(custom_rates)]
"""
evals_result = {}
self.classes_ = np.unique(y)
Expand Down Expand Up @@ -592,7 +609,8 @@ def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
evals=evals,
early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result, obj=obj, feval=feval,
verbose_eval=verbose, xgb_model=None)
verbose_eval=verbose, xgb_model=None,
callbacks=callbacks)

self.objective = xgb_options["objective"]
if evals_result:
Expand Down Expand Up @@ -863,7 +881,7 @@ def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,

def fit(self, X, y, group, sample_weight=None, eval_set=None, sample_weight_eval_set=None,
eval_group=None, eval_metric=None, early_stopping_rounds=None,
verbose=False, xgb_model=None):
verbose=False, xgb_model=None, callbacks=None):
# pylint: disable = attribute-defined-outside-init,arguments-differ
"""
Fit the gradient boosting model
Expand Down Expand Up @@ -911,6 +929,14 @@ def fit(self, X, y, group, sample_weight=None, eval_set=None, sample_weight_eval
xgb_model : str
file name of stored xgb model or 'Booster' instance Xgb model to be
loaded before training (allows training continuation).
callbacks : list of callback functions
List of callback functions that are applied at end of each iteration.
It is possible to use predefined callbacks by using xgb.callback module.
Example:
.. code-block:: none
[xgb.callback.reset_learning_rate(custom_rates)]
"""
# check if group information is provided
if group is None:
Expand Down Expand Up @@ -963,7 +989,8 @@ def _dmat_init(group, **params):
self.n_estimators,
early_stopping_rounds=early_stopping_rounds, evals=evals,
evals_result=evals_result, feval=feval,
verbose_eval=verbose, xgb_model=xgb_model)
verbose_eval=verbose, xgb_model=xgb_model,
callbacks=callbacks)

self.objective = params["objective"]

Expand Down

0 comments on commit 306edef

Please sign in to comment.