Commit

Backport note about predict() behavior of DART booster
hcho3 committed Sep 5, 2018
1 parent a8d815f commit b1233ef
Showing 2 changed files with 34 additions and 8 deletions.
20 changes: 16 additions & 4 deletions python-package/xgboost/core.py
@@ -996,10 +996,22 @@ def predict(self, data, output_margin=False, ntree_limit=0, pred_leaf=False,
"""
Predict with data.
-NOTE: This function is not thread safe.
-For each booster object, predict can only be called from one thread.
-If you want to run prediction using multiple thread, call bst.copy() to make copies
-of model object and then call predict
+.. note:: This function is not thread safe.
+For each booster object, predict can only be called from one thread.
+If you want to run prediction using multiple thread, call ``bst.copy()`` to make copies
+of model object and then call ``predict()``.
+.. note:: Using ``predict()`` with DART booster
+If the booster object is DART type, ``predict()`` will perform dropouts, i.e. only
+some of the trees will be evaluated. This will produce incorrect results if ``data`` is
+not the training data. To obtain correct results on test sets, set ``ntree_limit`` to
+a nonzero value, e.g.
+.. code-block:: python
+preds = bst.predict(dtest, ntree_limit=num_round)
Parameters
----------
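The thread-safety note in `core.py` can be sketched as a one-copy-per-thread pattern. The snippet below is only an illustration under stated assumptions: `ToyBooster` and `predict_in_threads` are hypothetical stand-ins (not part of xgboost), and only the `copy()` / `predict()` method names mirror the docstring. All copies are created in the calling thread before any worker starts.

```python
from concurrent.futures import ThreadPoolExecutor


class ToyBooster:
    """Hypothetical stand-in for a model whose predict() writes to internal state."""

    def __init__(self, weights):
        self.weights = list(weights)
        self._scratch = None  # per-object buffer; sharing it across threads would race

    def copy(self):
        # Plays the role of bst.copy(): an independent object with its own buffer.
        return ToyBooster(self.weights)

    def predict(self, rows):
        # Writes into the per-object buffer, then returns a snapshot of it.
        self._scratch = [sum(w * x for w, x in zip(self.weights, row)) for row in rows]
        return list(self._scratch)


def predict_in_threads(booster, batches):
    """Predict each batch in a worker thread, giving every batch its own model copy."""
    # Copies are made up front, in the calling thread, so workers never share an object.
    copies = [booster.copy() for _ in batches]
    with ThreadPoolExecutor(max_workers=len(batches) or 1) as pool:
        jobs = zip(copies, batches)
        return list(pool.map(lambda job: job[0].predict(job[1]), jobs))
```

With a real `Booster`, the same shape would apply: call `bst.copy()` once per worker and have each worker call `predict()` only on its own copy.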
22 changes: 18 additions & 4 deletions python-package/xgboost/sklearn.py
@@ -578,10 +578,24 @@ def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
def predict(self, data, output_margin=False, ntree_limit=0):
"""
Predict with `data`.
-NOTE: This function is not thread safe.
-For each booster object, predict can only be called from one thread.
-If you want to run prediction using multiple thread, call xgb.copy() to make copies
-of model object and then call predict
+.. note:: This function is not thread safe.
+For each booster object, predict can only be called from one thread.
+If you want to run prediction using multiple thread, call ``xgb.copy()`` to make copies
+of model object and then call ``predict()``.
+.. note:: Using ``predict()`` with DART booster
+If the booster object is DART type, ``predict()`` will perform dropouts, i.e. only
+some of the trees will be evaluated. This will produce incorrect results if ``data`` is
+not the training data. To obtain correct results on test sets, set ``ntree_limit`` to
+a nonzero value, e.g.
+.. code-block:: python
+preds = bst.predict(dtest, ntree_limit=num_round)
Parameters
----------
data : DMatrix
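To see why the DART note matters, here is a toy illustration of dropout at prediction time. It is a sketch, not xgboost's implementation: the "trees" are plain functions, `predict_all_trees` and `predict_with_dropout` are hypothetical helpers, and DART's weight normalisation is deliberately ignored. It only shows that summing a random subset of trees generally differs from summing all of them.

```python
import random


def predict_all_trees(trees, x):
    """Full-ensemble prediction: sum the output of every tree."""
    return sum(tree(x) for tree in trees)


def predict_with_dropout(trees, x, rate_drop, rng):
    """Sum only the trees that survive a random drop with probability rate_drop."""
    kept = [tree for tree in trees if rng.random() >= rate_drop]
    return sum(tree(x) for tree in kept)


# Three toy "trees"; each one is just a function of the input value.
trees = [lambda x: 1.0 * x, lambda x: 0.5 * x, lambda x: 0.25 * x]

full = predict_all_trees(trees, 2.0)
dropped = predict_with_dropout(trees, 2.0, rate_drop=0.5, rng=random.Random(0))
print(full, dropped)  # dropped depends on which trees the random draw keeps
```

According to the note in this commit, passing a nonzero `ntree_limit` to `predict()` is what makes a DART booster evaluate the trees without dropout on test data.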