From 872e559b91f6f201a4f690a024398412c128a976 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 22 Feb 2021 12:27:04 +0800 Subject: [PATCH] Use inplace predict for sklearn. (#6718) * Use inplace predict for sklearn when possible. --- python-package/xgboost/sklearn.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 68c06a7ae..f0589cd2e 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -783,7 +783,7 @@ class XGBModel(XGBModelBase): """ Predict with `X`. - .. note:: This function is only thread safe for `gbtree` + .. note:: This function is only thread safe for `gbtree` and `dart`. Parameters ---------- @@ -813,6 +813,24 @@ class XGBModel(XGBModelBase): self.get_booster(), ntree_limit, iteration_range ) iteration_range = self._get_iteration_range(iteration_range) + if self._can_use_inplace_predict(): + try: + predts = self.get_booster().inplace_predict( + data=X, + iteration_range=iteration_range, + predict_type="margin" if output_margin else "value", + missing=self.missing, + base_margin=base_margin, + validate_features=validate_features, + ) + if _is_cupy_array(predts): + import cupy # pylint: disable=import-error + predts = cupy.asnumpy(predts) # ensure numpy array is used. + return predts + except TypeError: + # coo, csc, dt + pass + test = DMatrix( X, base_margin=base_margin, missing=self.missing, nthread=self.n_jobs ) @@ -1217,7 +1235,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase): ) -> np.ndarray: """ Predict the probability of each `X` example being of a given class. - .. note:: This function is only thread safe for `gbtree` + .. note:: This function is only thread safe for `gbtree` and `dart`. Parameters ----------