diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index a34d094dc..4b06d92f1 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -374,6 +374,7 @@ class XGBModel(XGBModelBase): importance_type: str = "gain", gpu_id: Optional[int] = None, validate_parameters: Optional[bool] = None, + predictor: Optional[str] = None, **kwargs: Any ) -> None: if not SKLEARN_INSTALLED: @@ -409,6 +410,7 @@ class XGBModel(XGBModelBase): self.importance_type = importance_type self.gpu_id = gpu_id self.validate_parameters = validate_parameters + self.predictor = predictor def _more_tags(self) -> Dict[str, bool]: '''Tags used for scikit-learn data validation.''' diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index a4e34869d..4b10f854d 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -803,7 +803,11 @@ class Dart : public GBTree { bool success = predictor->InplacePredict(x, nullptr, model_, missing, &predts, i, i + 1); device = predts.predictions.DeviceIdx(); - CHECK(success) << msg; + CHECK(success) << msg << std::endl + << "Current Predictor: " + << (tparam_.predictor == PredictorType::kCPUPredictor + ? "cpu_predictor" + : "gpu_predictor"); } auto w = this->weight_drop_.at(i); diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index e4a0a53f6..d948c731e 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -291,7 +291,11 @@ class GBTree : public GradientBooster { } else { bool success = this->GetPredictor()->InplacePredict( x, p_m, model_, missing, out_preds, tree_begin, tree_end); - CHECK(success) << msg; + CHECK(success) << msg << std::endl + << "Current Predictor: " + << (tparam_.predictor == PredictorType::kCPUPredictor + ? "cpu_predictor" + : "gpu_predictor"); } }