Store metrics with learner (#2241)

Storing and then loading a model loses any eval_metric that was
provided. This causes implementations that always store/load, like
xgboost4j-spark, to be unable to eval with the desired metric.
This commit is contained in:
ebernhardson 2017-04-30 14:23:24 -07:00 committed by Nan Zhu
parent d3b866e3fd
commit da58f34ff8

View File

@ -44,8 +44,10 @@ struct LearnerModelParam
int num_class;
/*! \brief Model contain additional properties */
int contain_extra_attrs;
/*! \brief Model contain eval metrics */
int contain_eval_metrics;
/*! \brief reserved field */
int reserved[30];
int reserved[29];
/*! \brief constructor */
LearnerModelParam() {
std::memset(this, 0, sizeof(LearnerModelParam));
@ -147,6 +149,7 @@ class LearnerImpl : public Learner {
};
if (std::all_of(metrics_.begin(), metrics_.end(), dup_check)) {
metrics_.emplace_back(Metric::Create(kv.second));
mparam.contain_eval_metrics = 1;
}
} else {
cfg_[kv.first] = kv.second;
@ -273,6 +276,13 @@ class LearnerImpl : public Learner {
fi->Read(&max_delta_step);
cfg_["max_delta_step"] = max_delta_step;
}
if (mparam.contain_eval_metrics != 0) {
std::vector<std::string> metr;
fi->Read(&metr);
for (auto name : metr) {
metrics_.emplace_back(Metric::Create(name));
}
}
cfg_["num_class"] = common::ToString(mparam.num_class);
cfg_["num_feature"] = common::ToString(mparam.num_feature);
obj_->Configure(cfg_.begin(), cfg_.end());
@ -294,6 +304,13 @@ class LearnerImpl : public Learner {
if (it != cfg_.end())
fo->Write(it->second);
}
if (mparam.contain_eval_metrics != 0) {
std::vector<std::string> metr;
for (auto& ev : metrics_) {
metr.emplace_back(ev->Name());
}
fo->Write(metr);
}
}
void UpdateOneIter(int iter, DMatrix* train) override {