Implement GPU predict leaf. (#6187)

This commit is contained in:
Jiaming Yuan
2020-11-11 17:33:47 +08:00
committed by GitHub
parent 7f101d1b33
commit 8a17610666
12 changed files with 252 additions and 42 deletions

View File

@@ -1106,7 +1106,7 @@ class LearnerImpl : public LearnerIO {
gbm_->PredictInteractionContributions(data.get(), out_preds, ntree_limit,
approx_contribs);
} else if (pred_leaf) {
gbm_->PredictLeaf(data.get(), &out_preds->HostVector(), ntree_limit);
gbm_->PredictLeaf(data.get(), out_preds, ntree_limit);
} else {
auto local_cache = this->GetPredictionCache();
auto& prediction = local_cache->Cache(data, generic_parameters_.gpu_id);