Use inplace predict for sklearn. (#6718)
* Use inplace predict for sklearn when possible.
This commit is contained in:
parent
25077564ab
commit
872e559b91
@ -783,7 +783,7 @@ class XGBModel(XGBModelBase):
|
|||||||
"""
|
"""
|
||||||
Predict with `X`.
|
Predict with `X`.
|
||||||
|
|
||||||
.. note:: This function is only thread safe for `gbtree`
|
.. note:: This function is only thread safe for `gbtree` and `dart`.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@ -813,6 +813,24 @@ class XGBModel(XGBModelBase):
|
|||||||
self.get_booster(), ntree_limit, iteration_range
|
self.get_booster(), ntree_limit, iteration_range
|
||||||
)
|
)
|
||||||
iteration_range = self._get_iteration_range(iteration_range)
|
iteration_range = self._get_iteration_range(iteration_range)
|
||||||
|
if self._can_use_inplace_predict():
|
||||||
|
try:
|
||||||
|
predts = self.get_booster().inplace_predict(
|
||||||
|
data=X,
|
||||||
|
iteration_range=iteration_range,
|
||||||
|
predict_type="margin" if output_margin else "value",
|
||||||
|
missing=self.missing,
|
||||||
|
base_margin=base_margin,
|
||||||
|
validate_features=validate_features,
|
||||||
|
)
|
||||||
|
if _is_cupy_array(predts):
|
||||||
|
import cupy # pylint: disable=import-error
|
||||||
|
predts = cupy.asnumpy(predts) # ensure numpy array is used.
|
||||||
|
return predts
|
||||||
|
except TypeError:
|
||||||
|
# coo, csc, dt
|
||||||
|
pass
|
||||||
|
|
||||||
test = DMatrix(
|
test = DMatrix(
|
||||||
X, base_margin=base_margin, missing=self.missing, nthread=self.n_jobs
|
X, base_margin=base_margin, missing=self.missing, nthread=self.n_jobs
|
||||||
)
|
)
|
||||||
@ -1217,7 +1235,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
""" Predict the probability of each `X` example being of a given class.
|
""" Predict the probability of each `X` example being of a given class.
|
||||||
|
|
||||||
.. note:: This function is only thread safe for `gbtree`
|
.. note:: This function is only thread safe for `gbtree` and `dart`.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user