diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 68c06a7ae..f0589cd2e 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -783,7 +783,7 @@ class XGBModel(XGBModelBase): """ Predict with `X`. - .. note:: This function is only thread safe for `gbtree` + .. note:: This function is only thread safe for `gbtree` and `dart`. Parameters ---------- @@ -813,6 +813,24 @@ class XGBModel(XGBModelBase): self.get_booster(), ntree_limit, iteration_range ) iteration_range = self._get_iteration_range(iteration_range) + if self._can_use_inplace_predict(): + try: + predts = self.get_booster().inplace_predict( + data=X, + iteration_range=iteration_range, + predict_type="margin" if output_margin else "value", + missing=self.missing, + base_margin=base_margin, + validate_features=validate_features, + ) + if _is_cupy_array(predts): + import cupy # pylint: disable=import-error + predts = cupy.asnumpy(predts) # ensure numpy array is used. + return predts + except TypeError: + # coo, csc, dt + pass + test = DMatrix( X, base_margin=base_margin, missing=self.missing, nthread=self.n_jobs ) @@ -1217,7 +1235,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase): ) -> np.ndarray: """ Predict the probability of each `X` example being of a given class. - .. note:: This function is only thread safe for `gbtree` + .. note:: This function is only thread safe for `gbtree` and `dart`. Parameters ----------