Use inplace predict for sklearn. (#6718)
* Use inplace predict for sklearn when possible.
This commit is contained in:
parent
25077564ab
commit
872e559b91
@ -783,7 +783,7 @@ class XGBModel(XGBModelBase):
|
||||
"""
|
||||
Predict with `X`.
|
||||
|
||||
.. note:: This function is only thread safe for `gbtree`
|
||||
.. note:: This function is only thread safe for `gbtree` and `dart`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@ -813,6 +813,24 @@ class XGBModel(XGBModelBase):
|
||||
self.get_booster(), ntree_limit, iteration_range
|
||||
)
|
||||
iteration_range = self._get_iteration_range(iteration_range)
|
||||
if self._can_use_inplace_predict():
|
||||
try:
|
||||
predts = self.get_booster().inplace_predict(
|
||||
data=X,
|
||||
iteration_range=iteration_range,
|
||||
predict_type="margin" if output_margin else "value",
|
||||
missing=self.missing,
|
||||
base_margin=base_margin,
|
||||
validate_features=validate_features,
|
||||
)
|
||||
if _is_cupy_array(predts):
|
||||
import cupy # pylint: disable=import-error
|
||||
predts = cupy.asnumpy(predts) # ensure numpy array is used.
|
||||
return predts
|
||||
except TypeError:
|
||||
# coo, csc, dt
|
||||
pass
|
||||
|
||||
test = DMatrix(
|
||||
X, base_margin=base_margin, missing=self.missing, nthread=self.n_jobs
|
||||
)
|
||||
@ -1217,7 +1235,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
) -> np.ndarray:
|
||||
""" Predict the probability of each `X` example being of a given class.
|
||||
|
||||
.. note:: This function is only thread safe for `gbtree`
|
||||
.. note:: This function is only thread safe for `gbtree` and `dart`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user