Simplify inplace-predict. (#7910)
Pass the `X` as part of Proxy DMatrix instead of an independent `dmlc::any`.
This commit is contained in:
@@ -1277,15 +1277,12 @@ class LearnerImpl : public LearnerIO {
|
||||
return (*LearnerAPIThreadLocalStore::Get())[this];
|
||||
}
|
||||
|
||||
void InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
|
||||
PredictionType type, float missing,
|
||||
HostDeviceVector<bst_float> **out_preds,
|
||||
uint32_t iteration_begin,
|
||||
void InplacePredict(std::shared_ptr<DMatrix> p_m, PredictionType type, float missing,
|
||||
HostDeviceVector<bst_float>** out_preds, uint32_t iteration_begin,
|
||||
uint32_t iteration_end) override {
|
||||
this->Configure();
|
||||
auto& out_predictions = this->GetThreadLocal().prediction_entry;
|
||||
this->gbm_->InplacePredict(x, p_m, missing, &out_predictions,
|
||||
iteration_begin, iteration_end);
|
||||
this->gbm_->InplacePredict(p_m, missing, &out_predictions, iteration_begin, iteration_end);
|
||||
if (type == PredictionType::kValue) {
|
||||
obj_->PredTransform(&out_predictions.predictions);
|
||||
} else if (type == PredictionType::kMargin) {
|
||||
|
||||
Reference in New Issue
Block a user