Simplify inplace-predict. (#7910)

Pass the `X` as part of Proxy DMatrix instead of an independent `dmlc::any`.
This commit is contained in:
Jiaming Yuan
2022-05-18 17:52:00 +08:00
committed by GitHub
parent 19775ffe15
commit 765097d514
17 changed files with 317 additions and 297 deletions

View File

@@ -1277,15 +1277,12 @@ class LearnerImpl : public LearnerIO {
return (*LearnerAPIThreadLocalStore::Get())[this];
}
void InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
PredictionType type, float missing,
HostDeviceVector<bst_float> **out_preds,
uint32_t iteration_begin,
void InplacePredict(std::shared_ptr<DMatrix> p_m, PredictionType type, float missing,
HostDeviceVector<bst_float>** out_preds, uint32_t iteration_begin,
uint32_t iteration_end) override {
this->Configure();
auto& out_predictions = this->GetThreadLocal().prediction_entry;
this->gbm_->InplacePredict(x, p_m, missing, &out_predictions,
iteration_begin, iteration_end);
this->gbm_->InplacePredict(p_m, missing, &out_predictions, iteration_begin, iteration_end);
if (type == PredictionType::kValue) {
obj_->PredTransform(&out_predictions.predictions);
} else if (type == PredictionType::kMargin) {