Add predictor to skl constructor. (#7000)

This commit is contained in:
Jiaming Yuan 2021-05-29 04:52:56 +08:00 committed by GitHub
parent 55b823b27d
commit 816b789bf0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 12 additions and 2 deletions

View File

@ -374,6 +374,7 @@ class XGBModel(XGBModelBase):
importance_type: str = "gain", importance_type: str = "gain",
gpu_id: Optional[int] = None, gpu_id: Optional[int] = None,
validate_parameters: Optional[bool] = None, validate_parameters: Optional[bool] = None,
predictor: Optional[str] = None,
**kwargs: Any **kwargs: Any
) -> None: ) -> None:
if not SKLEARN_INSTALLED: if not SKLEARN_INSTALLED:
@ -409,6 +410,7 @@ class XGBModel(XGBModelBase):
self.importance_type = importance_type self.importance_type = importance_type
self.gpu_id = gpu_id self.gpu_id = gpu_id
self.validate_parameters = validate_parameters self.validate_parameters = validate_parameters
self.predictor = predictor
def _more_tags(self) -> Dict[str, bool]: def _more_tags(self) -> Dict[str, bool]:
'''Tags used for scikit-learn data validation.''' '''Tags used for scikit-learn data validation.'''

View File

@ -803,7 +803,11 @@ class Dart : public GBTree {
bool success = predictor->InplacePredict(x, nullptr, model_, missing, bool success = predictor->InplacePredict(x, nullptr, model_, missing,
&predts, i, i + 1); &predts, i, i + 1);
device = predts.predictions.DeviceIdx(); device = predts.predictions.DeviceIdx();
CHECK(success) << msg; CHECK(success) << msg << std::endl
<< "Current Predictor: "
<< (tparam_.predictor == PredictorType::kCPUPredictor
? "cpu_predictor"
: "gpu_predictor");
} }
auto w = this->weight_drop_.at(i); auto w = this->weight_drop_.at(i);

View File

@ -291,7 +291,11 @@ class GBTree : public GradientBooster {
} else { } else {
bool success = this->GetPredictor()->InplacePredict( bool success = this->GetPredictor()->InplacePredict(
x, p_m, model_, missing, out_preds, tree_begin, tree_end); x, p_m, model_, missing, out_preds, tree_begin, tree_end);
CHECK(success) << msg; CHECK(success) << msg << std::endl
<< "Current Predictor: "
<< (tparam_.predictor == PredictorType::kCPUPredictor
? "cpu_predictor"
: "gpu_predictor");
} }
} }