Allow unique prediction vector for each input matrix (#4275)

This commit is contained in:
Rory Mitchell 2019-03-21 11:38:16 +13:00 committed by GitHub
parent 09bd9e68cf
commit 8eab966998
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -485,10 +485,10 @@ class LearnerImpl : public Learner {
this->PerformTreeMethodHeuristic(train); this->PerformTreeMethodHeuristic(train);
monitor_.Start("PredictRaw"); monitor_.Start("PredictRaw");
this->PredictRaw(train, &preds_); this->PredictRaw(train, &preds_[train]);
monitor_.Stop("PredictRaw"); monitor_.Stop("PredictRaw");
monitor_.Start("GetGradient"); monitor_.Start("GetGradient");
obj_->GetGradient(preds_, train->Info(), iter, &gpair_); obj_->GetGradient(preds_[train], train->Info(), iter, &gpair_);
monitor_.Stop("GetGradient"); monitor_.Stop("GetGradient");
gbm_->DoBoost(train, &gpair_, obj_.get()); gbm_->DoBoost(train, &gpair_, obj_.get());
monitor_.Stop("UpdateOneIter"); monitor_.Stop("UpdateOneIter");
@ -520,11 +520,12 @@ class LearnerImpl : public Learner {
metrics_.back()->Configure(cfg_.begin(), cfg_.end()); metrics_.back()->Configure(cfg_.begin(), cfg_.end());
} }
for (size_t i = 0; i < data_sets.size(); ++i) { for (size_t i = 0; i < data_sets.size(); ++i) {
this->PredictRaw(data_sets[i], &preds_); DMatrix * dmat = data_sets[i];
obj_->EvalTransform(&preds_); this->PredictRaw(data_sets[i], &preds_[dmat]);
obj_->EvalTransform(&preds_[dmat]);
for (auto& ev : metrics_) { for (auto& ev : metrics_) {
os << '\t' << data_names[i] << '-' << ev->Name() << ':' os << '\t' << data_names[i] << '-' << ev->Name() << ':'
<< ev->Eval(preds_, data_sets[i]->Info(), << ev->Eval(preds_[dmat], data_sets[i]->Info(),
tparam_.dsplit == DataSplitMode::kRow); tparam_.dsplit == DataSplitMode::kRow);
} }
} }
@ -565,10 +566,10 @@ class LearnerImpl : public Learner {
std::string metric) { std::string metric) {
if (metric == "auto") metric = obj_->DefaultEvalMetric(); if (metric == "auto") metric = obj_->DefaultEvalMetric();
std::unique_ptr<Metric> ev(Metric::Create(metric.c_str())); std::unique_ptr<Metric> ev(Metric::Create(metric.c_str()));
this->PredictRaw(data, &preds_); this->PredictRaw(data, &preds_[data]);
obj_->EvalTransform(&preds_); obj_->EvalTransform(&preds_[data]);
return std::make_pair(metric, return std::make_pair(metric,
ev->Eval(preds_, data->Info(), ev->Eval(preds_[data], data->Info(),
tparam_.dsplit == DataSplitMode::kRow)); tparam_.dsplit == DataSplitMode::kRow));
} }
@ -771,7 +772,7 @@ class LearnerImpl : public Learner {
// name of objective function // name of objective function
std::string name_obj_; std::string name_obj_;
// temporal storages for prediction // temporal storages for prediction
HostDeviceVector<bst_float> preds_; std::map<DMatrix*, HostDeviceVector<bst_float>> preds_;
// gradient pairs // gradient pairs
HostDeviceVector<GradientPair> gpair_; HostDeviceVector<GradientPair> gpair_;