Allow unique prediction vector for each input matrix (#4275)
This commit is contained in:
parent
09bd9e68cf
commit
8eab966998
@ -485,10 +485,10 @@ class LearnerImpl : public Learner {
|
||||
this->PerformTreeMethodHeuristic(train);
|
||||
|
||||
monitor_.Start("PredictRaw");
|
||||
this->PredictRaw(train, &preds_);
|
||||
this->PredictRaw(train, &preds_[train]);
|
||||
monitor_.Stop("PredictRaw");
|
||||
monitor_.Start("GetGradient");
|
||||
obj_->GetGradient(preds_, train->Info(), iter, &gpair_);
|
||||
obj_->GetGradient(preds_[train], train->Info(), iter, &gpair_);
|
||||
monitor_.Stop("GetGradient");
|
||||
gbm_->DoBoost(train, &gpair_, obj_.get());
|
||||
monitor_.Stop("UpdateOneIter");
|
||||
@ -520,11 +520,12 @@ class LearnerImpl : public Learner {
|
||||
metrics_.back()->Configure(cfg_.begin(), cfg_.end());
|
||||
}
|
||||
for (size_t i = 0; i < data_sets.size(); ++i) {
|
||||
this->PredictRaw(data_sets[i], &preds_);
|
||||
obj_->EvalTransform(&preds_);
|
||||
DMatrix * dmat = data_sets[i];
|
||||
this->PredictRaw(data_sets[i], &preds_[dmat]);
|
||||
obj_->EvalTransform(&preds_[dmat]);
|
||||
for (auto& ev : metrics_) {
|
||||
os << '\t' << data_names[i] << '-' << ev->Name() << ':'
|
||||
<< ev->Eval(preds_, data_sets[i]->Info(),
|
||||
<< ev->Eval(preds_[dmat], data_sets[i]->Info(),
|
||||
tparam_.dsplit == DataSplitMode::kRow);
|
||||
}
|
||||
}
|
||||
@ -565,10 +566,10 @@ class LearnerImpl : public Learner {
|
||||
std::string metric) {
|
||||
if (metric == "auto") metric = obj_->DefaultEvalMetric();
|
||||
std::unique_ptr<Metric> ev(Metric::Create(metric.c_str()));
|
||||
this->PredictRaw(data, &preds_);
|
||||
obj_->EvalTransform(&preds_);
|
||||
this->PredictRaw(data, &preds_[data]);
|
||||
obj_->EvalTransform(&preds_[data]);
|
||||
return std::make_pair(metric,
|
||||
ev->Eval(preds_, data->Info(),
|
||||
ev->Eval(preds_[data], data->Info(),
|
||||
tparam_.dsplit == DataSplitMode::kRow));
|
||||
}
|
||||
|
||||
@ -771,7 +772,7 @@ class LearnerImpl : public Learner {
|
||||
// name of objective function
|
||||
std::string name_obj_;
|
||||
// temporal storages for prediction
|
||||
HostDeviceVector<bst_float> preds_;
|
||||
std::map<DMatrix*, HostDeviceVector<bst_float>> preds_;
|
||||
// gradient pairs
|
||||
HostDeviceVector<GradientPair> gpair_;
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user