Use Predictor for dart. (#6693)
* Use normal predictor for dart booster.
* Implement `inplace_predict` for dart.
* Enable `dart` for the dask interface now that it's thread-safe.
* Categorical data should work out of the box for dart now.

The implementation is not very efficient, as it has to pull back the data and
apply the weight for each tree, but it is still a significant improvement over the
previous implementation, since we no longer run a binary search for each sample
(a rough sketch of the weighting follows below).

* Fix output prediction shape on dataframe.
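
A rough Python sketch of the per-tree weighting described above (conceptual only,
not the C++ predictor code; `trees`, their `predict` method, and `weights` are
hypothetical stand-ins for booster internals):

import numpy as np

def dart_predict(trees, weights, X, base_score=0.5):
    """Hypothetical sketch: dart scales each tree's output by its weight."""
    out = np.full(X.shape[0], base_score, dtype=np.float32)
    for tree, weight in zip(trees, weights):
        # Pull the per-tree prediction back and apply the dart weight,
        # rather than binary-searching the dropped-tree weights per sample.
        out += weight * tree.predict(X)
    return out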
trivialfis authored Feb 9, 2021
1 parent dbf7e9d commit e8c5c53
Showing 13 changed files with 245 additions and 179 deletions.
11 changes: 11 additions & 0 deletions include/xgboost/predictor.h
@@ -119,6 +119,17 @@ class Predictor {
*/
virtual void Configure(const std::vector<std::pair<std::string, std::string>>&);

/**
* \brief Initialize output prediction
*
* \param info Meta info for the DMatrix object used for prediction.
* \param out_predt Prediction vector to be initialized.
* \param model Tree model used for prediction.
*/
virtual void InitOutPredictions(const MetaInfo &info,
HostDeviceVector<bst_float> *out_predt,
const gbm::GBTreeModel &model) const = 0;

/**
* \brief Generate batch predictions for a given feature matrix. May use
* cached predictions if available instead of calculating from scratch.
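For context on the new `InitOutPredictions` hook above: conceptually it sizes the
output vector to n_samples * n_groups and seeds it with the base score (or a
caller-supplied base margin) before tree contributions are accumulated. A minimal
NumPy sketch of that idea, not the actual C++ implementation:

import numpy as np

def init_out_predictions(n_samples, n_groups, base_score, base_margin=None):
    # Hypothetical sketch: allocate the flat prediction buffer and seed it
    # with the global base score, or with a caller-supplied base margin.
    out = np.full(n_samples * n_groups, base_score, dtype=np.float32)
    if base_margin is not None:
        out[:] = np.asarray(base_margin, dtype=np.float32).reshape(-1)
    return out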
31 changes: 22 additions & 9 deletions python-package/xgboost/dask.py
@@ -804,7 +804,7 @@ async def _train_async(
workers = list(_get_workers_from_data(dtrain, evals))
_rabit_args = await _get_rabit_args(len(workers), client)

if params.get("booster", None) is not None and params["booster"] != "gbtree":
if params.get("booster", None) == "gblinear":
raise NotImplementedError(
f"booster `{params['booster']}` is not yet supported for dask."
)
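
Since only `gblinear` is rejected now, `booster="dart"` goes through the dask
interface. A hedged usage sketch (the local cluster and random data are
placeholders, not part of this change):

from dask.distributed import Client, LocalCluster
import dask.array as da
import xgboost as xgb

cluster = LocalCluster(n_workers=2)
client = Client(cluster)
X = da.random.random((1000, 10), chunks=(100, 10))
y = da.random.random(1000, chunks=100)

dtrain = xgb.dask.DaskDMatrix(client, X, y)
# dart is accepted now; only gblinear hits the NotImplementedError above.
output = xgb.dask.train(client, {"booster": "dart"}, dtrain, num_boost_round=10)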
@@ -949,6 +949,15 @@ async def _direct_predict_impl(
meta: Dict[int, str],
) -> _DaskCollection:
columns = list(meta.keys())
if len(output_shape) >= 3 and isinstance(data, dd.DataFrame):
# Without this check, dask will finish the prediction silently even if output
# dimension is greater than 3. But during map_partitions, dask passes a
# `dd.DataFrame` as local input to xgboost, which is converted to csr_matrix by
# `_convert_unknown_data` since dd.DataFrame is not known to xgboost native
# binding.
raise ValueError(
"Use `da.Array` or `DaskDMatrix` when output has more than 2 dimensions."
)
if _can_output_df(isinstance(data, dd.DataFrame), output_shape):
if base_margin is not None and isinstance(base_margin, da.Array):
# Easier for map_partitions
@@ -1012,6 +1021,7 @@ def _infer_predict_output(
if kwargs.pop("predict_type") == "margin":
kwargs["output_margin"] = True
m = DMatrix(test_sample)
# generated DMatrix doesn't have feature name, so no validation.
test_predt = booster.predict(m, validate_features=False, **kwargs)
n_columns = test_predt.shape[1] if len(test_predt.shape) > 1 else 1
meta: Dict[int, str] = {}
@@ -1098,6 +1108,7 @@ def mapped_predict(
pred_contribs=pred_contribs,
approx_contribs=approx_contribs,
pred_interactions=pred_interactions,
strict_shape=strict_shape,
)
)
return await _direct_predict_impl(
@@ -1116,6 +1127,7 @@ def mapped_predict(
pred_contribs=pred_contribs,
approx_contribs=approx_contribs,
pred_interactions=pred_interactions,
strict_shape=strict_shape,
)
)
# Prediction on dask DMatrix.
@@ -1206,10 +1218,9 @@ def predict( # pylint: disable=unused-argument
.. note::
Using ``inplace_predict`` might be faster when some features are not needed. See
:py:meth:`xgboost.Booster.predict` for details on various parameters. When using
``pred_interactions`` with mutli-class model, input should be ``da.Array`` or
``DaskDMatrix`` due to limitation in ``da.map_blocks``.
:py:meth:`xgboost.Booster.predict` for details on various parameters. When output
has more than 2 dimensions (shap value, leaf with strict_shape), input should be
``da.Array`` or ``DaskDMatrix``.
.. versionadded:: 1.0.0
@@ -1233,8 +1244,8 @@ def predict( # pylint: disable=unused-argument
prediction: dask.array.Array/dask.dataframe.Series
When input data is ``dask.array.Array`` or ``DaskDMatrix``, the return value is an
array, when input data is ``dask.dataframe.DataFrame``, return value can be
``dask.dataframe.Series``, ``dask.dataframe.DataFrame`` or ``dask.array.Array``,
depending on the output shape.
``dask.dataframe.Series``, ``dask.dataframe.DataFrame``, depending on the output
shape.
'''
_assert_dask_support()
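
To make the rewritten return-shape note concrete, a hedged sketch reusing the
hypothetical `client`, `X`, and `output` from the training example earlier:
outputs with more than two dimensions need `da.Array` or `DaskDMatrix` input,
while 2-D outputs also work from DataFrames.

booster = output["booster"]  # trained model from the xgb.dask.train call above

predt = xgb.dask.predict(client, booster, X, strict_shape=True)
# SHAP contributions for a single-output model are 2-D; pred_interactions
# (and multi-class contributions) add another dimension, so array or
# DaskDMatrix input is required for those.
contribs = xgb.dask.predict(client, booster, X, pred_contribs=True)
interactions = xgb.dask.predict(client, booster, X, pred_interactions=True)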
@@ -1297,6 +1308,7 @@ def mapped_predict(
inplace=True,
predict_type=predict_type,
iteration_range=iteration_range,
strict_shape=strict_shape,
)
)
return await _direct_predict_impl(
@@ -1352,8 +1364,9 @@ def inplace_predict( # pylint: disable=unused-argument
prediction :
When input data is ``dask.array.Array``, the return value is an array, when input
data is ``dask.dataframe.DataFrame``, return value can be
``dask.dataframe.Series``, ``dask.dataframe.DataFrame`` or ``dask.array.Array``,
depending on the output shape.
``dask.dataframe.Series``, ``dask.dataframe.DataFrame``, depending on the output
shape.
"""
_assert_dask_support()
client = _xgb_get_client(client)
5 changes: 1 addition & 4 deletions python-package/xgboost/sklearn.py
@@ -754,10 +754,7 @@ def _can_use_inplace_predict(self) -> bool:
# Inplace predict doesn't handle as many data types as DMatrix, but it's
# sufficient for dask interface where input is simpiler.
params = self.get_params()
booster = self.booster
if params.get("predictor", None) is None and (
booster is None or booster == "gbtree"
):
if params.get("predictor", None) is None and self.booster != "gblinear":
return True
return False

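On the sklearn side, the relaxed check means the wrapper can fall back to
`inplace_predict` for both gbtree and dart whenever no explicit `predictor` is
set. A hedged end-to-end sketch with the dask estimator (reusing the hypothetical
`client`, `X`, and `y` from the training example above):

reg = xgb.dask.DaskXGBRegressor(booster="dart", n_estimators=10)
reg.client = client
reg.fit(X, y)
# No `predictor` set and booster != gblinear, so _can_use_inplace_predict
# returns True and prediction can go through inplace_predict.
predt = reg.predict(X)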
