Porting elementwise metrics to GPU. (#3952)
* Port elementwise metrics to GPU. * All elementwise metrics are converted to static polymorphic. * Create a reducer for metrics reduction. * Remove const of Metric::Eval to accommodate CubMemory.
This commit is contained in:
@@ -310,6 +310,10 @@ class LearnerImpl : public Learner {
|
||||
if (obj_ != nullptr) {
|
||||
obj_->Configure(cfg_.begin(), cfg_.end());
|
||||
}
|
||||
|
||||
for (auto& p_metric : metrics_) {
|
||||
p_metric->Configure(cfg_.begin(), cfg_.end());
|
||||
}
|
||||
}
|
||||
|
||||
void InitModel() override { this->LazyInitModel(); }
|
||||
@@ -407,6 +411,10 @@ class LearnerImpl : public Learner {
|
||||
cfg_["num_class"] = common::ToString(mparam_.num_class);
|
||||
cfg_["num_feature"] = common::ToString(mparam_.num_feature);
|
||||
obj_->Configure(cfg_.begin(), cfg_.end());
|
||||
|
||||
for (auto& p_metric : metrics_) {
|
||||
p_metric->Configure(cfg_.begin(), cfg_.end());
|
||||
}
|
||||
}
|
||||
|
||||
// rabit save model to rabit checkpoint
|
||||
@@ -503,13 +511,14 @@ class LearnerImpl : public Learner {
|
||||
os << '[' << iter << ']' << std::setiosflags(std::ios::fixed);
|
||||
if (metrics_.size() == 0 && tparam_.disable_default_eval_metric <= 0) {
|
||||
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric()));
|
||||
metrics_.back()->Configure(cfg_.begin(), cfg_.end());
|
||||
}
|
||||
for (size_t i = 0; i < data_sets.size(); ++i) {
|
||||
this->PredictRaw(data_sets[i], &preds_);
|
||||
obj_->EvalTransform(&preds_);
|
||||
for (auto& ev : metrics_) {
|
||||
os << '\t' << data_names[i] << '-' << ev->Name() << ':'
|
||||
<< ev->Eval(preds_.ConstHostVector(), data_sets[i]->Info(),
|
||||
<< ev->Eval(preds_, data_sets[i]->Info(),
|
||||
tparam_.dsplit == DataSplitMode::kRow);
|
||||
}
|
||||
}
|
||||
@@ -553,7 +562,7 @@ class LearnerImpl : public Learner {
|
||||
this->PredictRaw(data, &preds_);
|
||||
obj_->EvalTransform(&preds_);
|
||||
return std::make_pair(metric,
|
||||
ev->Eval(preds_.ConstHostVector(), data->Info(),
|
||||
ev->Eval(preds_, data->Info(),
|
||||
tparam_.dsplit == DataSplitMode::kRow));
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user