diff --git a/doc/tutorials/dask.rst b/doc/tutorials/dask.rst index 1d00f7754696..80754baaca53 100644 --- a/doc/tutorials/dask.rst +++ b/doc/tutorials/dask.rst @@ -273,6 +273,84 @@ actual computation will return a coroutine and hence require awaiting: # Use `client.compute` instead of the `compute` method from dask collection print(await client.compute(prediction)) +***************************** +Evaluation and Early Stopping +***************************** + +.. versionadded:: 1.3.0 + +The Dask interface allows the use of validation sets that are stored in distributed collections (Dask DataFrame or Dask Array). These can be used for evaluation and early stopping. + +To enable early stopping, pass one or more validation sets containing ``DaskDMatrix`` objects. + +.. code-block:: python + + import dask.array as da + import xgboost as xgb + + num_rows = 1_000_000 + num_features = 100 + num_partitions = 10 + rows_per_chunk = num_rows // num_partitions + + data = da.random.random( + size=(num_rows, num_features), + chunks=(rows_per_chunk, num_features) + ) + + labels = da.random.random( + size=(num_rows, 1), + chunks=(rows_per_chunk, 1) + ) + + X_eval = da.random.random( + size=(num_rows, num_features), + chunks=(rows_per_chunk, num_features) + ) + + y_eval = da.random.random( + size=(num_rows, 1), + chunks=(rows_per_chunk, 1) + ) + + dtrain = xgb.dask.DaskDMatrix( + client=client, + data=data, + label=labels + ) + + dvalid = xgb.dask.DaskDMatrix( + client=client, + data=X_eval, + label=y_eval + ) + + result = xgb.dask.train( + client=client, + params={ + "objective": "reg:squarederror", + }, + dtrain=dtrain, + num_boost_round=10, + evals=[(dvalid, "valid1")], + early_stopping_rounds=3 + ) + +When validation sets are provided to ``xgb.dask.train()`` in this way, the dictionary returned by ``xgb.dask.train()`` contains a history of evaluation metrics for each validation set, across all boosting rounds. + +.. 
code-block:: python + + print(result["history"]) + # {'valid1': OrderedDict([('rmse', [0.28857, 0.28858, 0.288592, 0.288598])])} + +If early stopping is enabled by also passing ``early_stopping_rounds``, you can check the best iteration in the returned booster. + +.. code-block:: python + + booster = result["booster"] + print(booster.best_iteration) + best_model = booster[: booster.best_iteration + 1] + ***************************************************************************** Why is the initialization of ``DaskDMatrix`` so slow and throws weird errors *****************************************************************************