From 8eab966998b54ddd7076b74b597b0c9cc4bf323c Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Thu, 21 Mar 2019 11:38:16 +1300 Subject: [PATCH] Allow unique prediction vector for each input matrix (#4275) --- src/learner.cc | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/learner.cc b/src/learner.cc index 1269a6b2b..d21f07147 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -485,10 +485,10 @@ class LearnerImpl : public Learner { this->PerformTreeMethodHeuristic(train); monitor_.Start("PredictRaw"); - this->PredictRaw(train, &preds_); + this->PredictRaw(train, &preds_[train]); monitor_.Stop("PredictRaw"); monitor_.Start("GetGradient"); - obj_->GetGradient(preds_, train->Info(), iter, &gpair_); + obj_->GetGradient(preds_[train], train->Info(), iter, &gpair_); monitor_.Stop("GetGradient"); gbm_->DoBoost(train, &gpair_, obj_.get()); monitor_.Stop("UpdateOneIter"); @@ -520,11 +520,12 @@ class LearnerImpl : public Learner { metrics_.back()->Configure(cfg_.begin(), cfg_.end()); } for (size_t i = 0; i < data_sets.size(); ++i) { - this->PredictRaw(data_sets[i], &preds_); - obj_->EvalTransform(&preds_); + DMatrix * dmat = data_sets[i]; + this->PredictRaw(data_sets[i], &preds_[dmat]); + obj_->EvalTransform(&preds_[dmat]); for (auto& ev : metrics_) { os << '\t' << data_names[i] << '-' << ev->Name() << ':' - << ev->Eval(preds_, data_sets[i]->Info(), + << ev->Eval(preds_[dmat], data_sets[i]->Info(), tparam_.dsplit == DataSplitMode::kRow); } } @@ -565,10 +566,10 @@ class LearnerImpl : public Learner { std::string metric) { if (metric == "auto") metric = obj_->DefaultEvalMetric(); std::unique_ptr ev(Metric::Create(metric.c_str())); - this->PredictRaw(data, &preds_); - obj_->EvalTransform(&preds_); + this->PredictRaw(data, &preds_[data]); + obj_->EvalTransform(&preds_[data]); return std::make_pair(metric, - ev->Eval(preds_, data->Info(), + ev->Eval(preds_[data], data->Info(), tparam_.dsplit == DataSplitMode::kRow)); } @@ -771,7 +772,7 @@ class LearnerImpl : public Learner { // name of objective function std::string name_obj_; // temporal storages for prediction - HostDeviceVector preds_; + std::map> preds_; // gradient pairs HostDeviceVector gpair_;