Use inplace predict for sklearn. (#6718)

* Use inplace predict for sklearn when possible.
This commit is contained in:
Jiaming Yuan 2021-02-22 12:27:04 +08:00 committed by GitHub
parent 25077564ab
commit 872e559b91
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -783,7 +783,7 @@ class XGBModel(XGBModelBase):
"""
Predict with `X`.
.. note:: This function is only thread safe for `gbtree`
.. note:: This function is only thread safe for `gbtree` and `dart`.
Parameters
----------
@ -813,6 +813,24 @@ class XGBModel(XGBModelBase):
self.get_booster(), ntree_limit, iteration_range
)
iteration_range = self._get_iteration_range(iteration_range)
if self._can_use_inplace_predict():
try:
predts = self.get_booster().inplace_predict(
data=X,
iteration_range=iteration_range,
predict_type="margin" if output_margin else "value",
missing=self.missing,
base_margin=base_margin,
validate_features=validate_features,
)
if _is_cupy_array(predts):
import cupy # pylint: disable=import-error
predts = cupy.asnumpy(predts) # ensure numpy array is used.
return predts
except TypeError:
# coo, csc, dt
pass
test = DMatrix(
X, base_margin=base_margin, missing=self.missing, nthread=self.n_jobs
)
@ -1217,7 +1235,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
) -> np.ndarray:
""" Predict the probability of each `X` example being of a given class.
.. note:: This function is only thread safe for `gbtree`
.. note:: This function is only thread safe for `gbtree` and `dart`.
Parameters
----------