[Dask] Add example of using custom callback in Dask
Showing 1 changed file with 83 additions and 0 deletions.
"""Example of using callbacks in Dask"""
import numpy as np
import xgboost as xgb
from xgboost.dask import DaskDMatrix
from dask.distributed import Client
from dask.distributed import LocalCluster
from dask_ml.datasets import make_regression
from dask_ml.model_selection import train_test_split


def probability_for_going_backward(epoch):
    return 0.999 / (1.0 + 0.05 * np.log(1.0 + epoch))
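# Illustrative values, not in the original file: the tolerance probability
# above decays slowly as boosting progresses, e.g.
#   probability_for_going_backward(0)    -> 0.999
#   probability_for_going_backward(100)  -> ~0.81
#   probability_for_going_backward(1000) -> ~0.74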


# All callback classes must inherit from TrainingCallback
class CustomEarlyStopping(xgb.callback.TrainingCallback):
    """A custom early stopping class where early stopping is determined
    stochastically. In the beginning, allow the metric to become worse with
    a probability of 0.999. As boosting progresses, the probability should
    be adjusted downward."""
    def __init__(self, *, validation_set, target_metric, maximize, seed):
        self.validation_set = validation_set
        self.target_metric = target_metric
        self.maximize = maximize
        self.seed = seed
        self.rng = np.random.default_rng(seed=seed)
        if maximize:
            self.better = lambda x, y: x > y
        else:
            self.better = lambda x, y: x < y

    def after_iteration(self, model, epoch, evals_log):
        metric_history = evals_log[self.validation_set][self.target_metric]
        if len(metric_history) < 2 or self.better(metric_history[-1], metric_history[-2]):
            return False  # continue training
        p = probability_for_going_backward(epoch)
        # Draw one Bernoulli sample: with probability p, tolerate the worse
        # metric and keep training.
        go_backward = self.rng.choice(2, size=(1,), replace=True,
                                      p=[1 - p, p]).astype(bool)[0]
        print('The validation metric moved in the wrong direction. '
              + f'Stopping training with probability {1 - p}...')
        if go_backward:
            return False  # continue training
        else:
            return True  # stop training


def main(client):
    m = 100000
    n = 100
    X, y = make_regression(n_samples=m, n_features=n, chunks=200, random_state=0)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    dtrain = DaskDMatrix(client, X_train, y_train)
    dtest = DaskDMatrix(client, X_test, y_test)

    # Use the train method from xgboost.dask instead of xgboost. This
    # distributed version of train returns a dictionary containing the
    # resulting booster and the evaluation history obtained from the
    # evaluation metrics.
    output = xgb.dask.train(client,
                            {'verbosity': 1,
                             'tree_method': 'hist',
                             'objective': 'reg:squarederror',
                             'eval_metric': 'rmse',
                             'max_depth': 6,
                             'learning_rate': 1.0},
                            dtrain,
                            num_boost_round=1000,
                            evals=[(dtrain, 'train'), (dtest, 'test')],
                            callbacks=[CustomEarlyStopping(
                                validation_set='test',
                                target_metric='rmse',
                                maximize=False,
                                seed=0)])
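
    # Not part of the original commit: a small sketch of inspecting the
    # returned dict, which holds the trained booster under 'booster' and
    # the per-dataset metric history under 'history'.
    print('Trained booster:', output['booster'])
    print('Evaluation history:', output['history'])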


if __name__ == '__main__':
    # or use other clusters for scaling
    with LocalCluster(n_workers=4, threads_per_worker=1) as cluster:
        with Client(cluster) as client:
            main(client)
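
# Hypothetical variation, not in the original commit: instead of a
# LocalCluster, the same main() can run against an existing Dask cluster
# by connecting the client to its scheduler address, e.g.
#
#     with Client('tcp://scheduler-address:8786') as client:
#         main(client)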