Backport note about predict() behavior of DART booster

Philip Cho 2018-09-05 12:30:21 -07:00
parent a8d815fc1e
commit b1233ef2ae
2 changed files with 34 additions and 8 deletions


@@ -996,10 +996,22 @@ class Booster(object):
 """
 Predict with data.
-NOTE: This function is not thread safe.
+.. note:: This function is not thread safe.
 For each booster object, predict can only be called from one thread.
-If you want to run prediction using multiple thread, call bst.copy() to make copies
-of model object and then call predict
+If you want to run prediction using multiple thread, call ``bst.copy()`` to make copies
+of model object and then call ``predict()``.
+.. note:: Using ``predict()`` with DART booster
+If the booster object is DART type, ``predict()`` will perform dropouts, i.e. only
+some of the trees will be evaluated. This will produce incorrect results if ``data`` is
+not the training data. To obtain correct results on test sets, set ``ntree_limit`` to
+a nonzero value, e.g.
+.. code-block:: python
+preds = bst.predict(dtest, ntree_limit=num_round)
 Parameters
 ----------

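Note on the DART hunk above: the sketch below is a toy model of the behavior the new docstring describes, not XGBoost's implementation. It shows why a prediction that drops trees differs from one that evaluates every tree. The names `make_stump`, `predict`, and `rate_drop` are illustrative assumptions, and the weight rescaling that real DART applies is left out.

```python
import random


def make_stump(threshold, left, right):
    """Return a one-split 'tree' mapping a scalar x to a leaf value."""
    def stump(x):
        return left if x < threshold else right
    return stump


def predict(trees, x, ntree_limit=0, rate_drop=0.5, rng=None):
    """Toy stand-in for the documented behavior (hypothetical, simplified).

    ntree_limit == 0: each tree is skipped with probability rate_drop.
    ntree_limit  > 0: the first ntree_limit trees are all evaluated.
    """
    if ntree_limit > 0:
        return sum(tree(x) for tree in trees[:ntree_limit])
    rng = rng or random.Random()
    return sum(tree(x) for tree in trees if rng.random() >= rate_drop)


num_round = 10
trees = [make_stump(0.5, -0.1, 0.1) for _ in range(num_round)]

# Deterministic: every tree contributes to the sum.
full = predict(trees, 0.8, ntree_limit=num_round)

# Depends on the random draw: only a subset of the trees is summed.
dropped = [predict(trees, 0.8, rng=random.Random(seed)) for seed in range(5)]
```

In this toy, `full` is the same on every call, while each entry of `dropped` depends on which trees the random draw happened to skip. That is the reason the docstring recommends passing a nonzero `ntree_limit` when scoring a test set.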

@@ -578,10 +578,24 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
 def predict(self, data, output_margin=False, ntree_limit=0):
 """
 Predict with `data`.
-NOTE: This function is not thread safe.
+.. note:: This function is not thread safe.
 For each booster object, predict can only be called from one thread.
-If you want to run prediction using multiple thread, call xgb.copy() to make copies
-of model object and then call predict
+If you want to run prediction using multiple thread, call ``xgb.copy()`` to make copies
+of model object and then call ``predict()``.
+.. note:: Using ``predict()`` with DART booster
+If the booster object is DART type, ``predict()`` will perform dropouts, i.e. only
+some of the trees will be evaluated. This will produce incorrect results if ``data`` is
+not the training data. To obtain correct results on test sets, set ``ntree_limit`` to
+a nonzero value, e.g.
+.. code-block:: python
+preds = bst.predict(dtest, ntree_limit=num_round)
 Parameters
 ----------
 data : DMatrix