Allow unique prediction vector for each input matrix (#4275)
This commit is contained in:
parent
09bd9e68cf
commit
8eab966998
@ -485,10 +485,10 @@ class LearnerImpl : public Learner {
|
|||||||
this->PerformTreeMethodHeuristic(train);
|
this->PerformTreeMethodHeuristic(train);
|
||||||
|
|
||||||
monitor_.Start("PredictRaw");
|
monitor_.Start("PredictRaw");
|
||||||
this->PredictRaw(train, &preds_);
|
this->PredictRaw(train, &preds_[train]);
|
||||||
monitor_.Stop("PredictRaw");
|
monitor_.Stop("PredictRaw");
|
||||||
monitor_.Start("GetGradient");
|
monitor_.Start("GetGradient");
|
||||||
obj_->GetGradient(preds_, train->Info(), iter, &gpair_);
|
obj_->GetGradient(preds_[train], train->Info(), iter, &gpair_);
|
||||||
monitor_.Stop("GetGradient");
|
monitor_.Stop("GetGradient");
|
||||||
gbm_->DoBoost(train, &gpair_, obj_.get());
|
gbm_->DoBoost(train, &gpair_, obj_.get());
|
||||||
monitor_.Stop("UpdateOneIter");
|
monitor_.Stop("UpdateOneIter");
|
||||||
@ -520,11 +520,12 @@ class LearnerImpl : public Learner {
|
|||||||
metrics_.back()->Configure(cfg_.begin(), cfg_.end());
|
metrics_.back()->Configure(cfg_.begin(), cfg_.end());
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < data_sets.size(); ++i) {
|
for (size_t i = 0; i < data_sets.size(); ++i) {
|
||||||
this->PredictRaw(data_sets[i], &preds_);
|
DMatrix * dmat = data_sets[i];
|
||||||
obj_->EvalTransform(&preds_);
|
this->PredictRaw(data_sets[i], &preds_[dmat]);
|
||||||
|
obj_->EvalTransform(&preds_[dmat]);
|
||||||
for (auto& ev : metrics_) {
|
for (auto& ev : metrics_) {
|
||||||
os << '\t' << data_names[i] << '-' << ev->Name() << ':'
|
os << '\t' << data_names[i] << '-' << ev->Name() << ':'
|
||||||
<< ev->Eval(preds_, data_sets[i]->Info(),
|
<< ev->Eval(preds_[dmat], data_sets[i]->Info(),
|
||||||
tparam_.dsplit == DataSplitMode::kRow);
|
tparam_.dsplit == DataSplitMode::kRow);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -565,10 +566,10 @@ class LearnerImpl : public Learner {
|
|||||||
std::string metric) {
|
std::string metric) {
|
||||||
if (metric == "auto") metric = obj_->DefaultEvalMetric();
|
if (metric == "auto") metric = obj_->DefaultEvalMetric();
|
||||||
std::unique_ptr<Metric> ev(Metric::Create(metric.c_str()));
|
std::unique_ptr<Metric> ev(Metric::Create(metric.c_str()));
|
||||||
this->PredictRaw(data, &preds_);
|
this->PredictRaw(data, &preds_[data]);
|
||||||
obj_->EvalTransform(&preds_);
|
obj_->EvalTransform(&preds_[data]);
|
||||||
return std::make_pair(metric,
|
return std::make_pair(metric,
|
||||||
ev->Eval(preds_, data->Info(),
|
ev->Eval(preds_[data], data->Info(),
|
||||||
tparam_.dsplit == DataSplitMode::kRow));
|
tparam_.dsplit == DataSplitMode::kRow));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -771,7 +772,7 @@ class LearnerImpl : public Learner {
|
|||||||
// name of objective function
|
// name of objective function
|
||||||
std::string name_obj_;
|
std::string name_obj_;
|
||||||
// temporal storages for prediction
|
// temporal storages for prediction
|
||||||
HostDeviceVector<bst_float> preds_;
|
std::map<DMatrix*, HostDeviceVector<bst_float>> preds_;
|
||||||
// gradient pairs
|
// gradient pairs
|
||||||
HostDeviceVector<GradientPair> gpair_;
|
HostDeviceVector<GradientPair> gpair_;
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user