Add predictor to skl constructor. (#7000)
This commit is contained in:
parent
55b823b27d
commit
816b789bf0
@ -374,6 +374,7 @@ class XGBModel(XGBModelBase):
|
|||||||
importance_type: str = "gain",
|
importance_type: str = "gain",
|
||||||
gpu_id: Optional[int] = None,
|
gpu_id: Optional[int] = None,
|
||||||
validate_parameters: Optional[bool] = None,
|
validate_parameters: Optional[bool] = None,
|
||||||
|
predictor: Optional[str] = None,
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
if not SKLEARN_INSTALLED:
|
if not SKLEARN_INSTALLED:
|
||||||
@ -409,6 +410,7 @@ class XGBModel(XGBModelBase):
|
|||||||
self.importance_type = importance_type
|
self.importance_type = importance_type
|
||||||
self.gpu_id = gpu_id
|
self.gpu_id = gpu_id
|
||||||
self.validate_parameters = validate_parameters
|
self.validate_parameters = validate_parameters
|
||||||
|
self.predictor = predictor
|
||||||
|
|
||||||
def _more_tags(self) -> Dict[str, bool]:
|
def _more_tags(self) -> Dict[str, bool]:
|
||||||
'''Tags used for scikit-learn data validation.'''
|
'''Tags used for scikit-learn data validation.'''
|
||||||
|
|||||||
@ -803,7 +803,11 @@ class Dart : public GBTree {
|
|||||||
bool success = predictor->InplacePredict(x, nullptr, model_, missing,
|
bool success = predictor->InplacePredict(x, nullptr, model_, missing,
|
||||||
&predts, i, i + 1);
|
&predts, i, i + 1);
|
||||||
device = predts.predictions.DeviceIdx();
|
device = predts.predictions.DeviceIdx();
|
||||||
CHECK(success) << msg;
|
CHECK(success) << msg << std::endl
|
||||||
|
<< "Current Predictor: "
|
||||||
|
<< (tparam_.predictor == PredictorType::kCPUPredictor
|
||||||
|
? "cpu_predictor"
|
||||||
|
: "gpu_predictor");
|
||||||
}
|
}
|
||||||
|
|
||||||
auto w = this->weight_drop_.at(i);
|
auto w = this->weight_drop_.at(i);
|
||||||
|
|||||||
@ -291,7 +291,11 @@ class GBTree : public GradientBooster {
|
|||||||
} else {
|
} else {
|
||||||
bool success = this->GetPredictor()->InplacePredict(
|
bool success = this->GetPredictor()->InplacePredict(
|
||||||
x, p_m, model_, missing, out_preds, tree_begin, tree_end);
|
x, p_m, model_, missing, out_preds, tree_begin, tree_end);
|
||||||
CHECK(success) << msg;
|
CHECK(success) << msg << std::endl
|
||||||
|
<< "Current Predictor: "
|
||||||
|
<< (tparam_.predictor == PredictorType::kCPUPredictor
|
||||||
|
? "cpu_predictor"
|
||||||
|
: "gpu_predictor");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user