Add predictor to skl constructor. (#7000)
This commit is contained in:
parent
55b823b27d
commit
816b789bf0
@ -374,6 +374,7 @@ class XGBModel(XGBModelBase):
|
||||
importance_type: str = "gain",
|
||||
gpu_id: Optional[int] = None,
|
||||
validate_parameters: Optional[bool] = None,
|
||||
predictor: Optional[str] = None,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
if not SKLEARN_INSTALLED:
|
||||
@ -409,6 +410,7 @@ class XGBModel(XGBModelBase):
|
||||
self.importance_type = importance_type
|
||||
self.gpu_id = gpu_id
|
||||
self.validate_parameters = validate_parameters
|
||||
self.predictor = predictor
|
||||
|
||||
def _more_tags(self) -> Dict[str, bool]:
|
||||
'''Tags used for scikit-learn data validation.'''
|
||||
|
||||
@ -803,7 +803,11 @@ class Dart : public GBTree {
|
||||
bool success = predictor->InplacePredict(x, nullptr, model_, missing,
|
||||
&predts, i, i + 1);
|
||||
device = predts.predictions.DeviceIdx();
|
||||
CHECK(success) << msg;
|
||||
CHECK(success) << msg << std::endl
|
||||
<< "Current Predictor: "
|
||||
<< (tparam_.predictor == PredictorType::kCPUPredictor
|
||||
? "cpu_predictor"
|
||||
: "gpu_predictor");
|
||||
}
|
||||
|
||||
auto w = this->weight_drop_.at(i);
|
||||
|
||||
@ -291,7 +291,11 @@ class GBTree : public GradientBooster {
|
||||
} else {
|
||||
bool success = this->GetPredictor()->InplacePredict(
|
||||
x, p_m, model_, missing, out_preds, tree_begin, tree_end);
|
||||
CHECK(success) << msg;
|
||||
CHECK(success) << msg << std::endl
|
||||
<< "Current Predictor: "
|
||||
<< (tparam_.predictor == PredictorType::kCPUPredictor
|
||||
? "cpu_predictor"
|
||||
: "gpu_predictor");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user