diff --git a/src/learner.cc b/src/learner.cc index cb49a3838..8c92556fc 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -44,8 +44,10 @@ struct LearnerModelParam int num_class; /*! \brief Model contain additional properties */ int contain_extra_attrs; + /*! \brief Model contain eval metrics */ + int contain_eval_metrics; /*! \brief reserved field */ - int reserved[30]; + int reserved[29]; /*! \brief constructor */ LearnerModelParam() { std::memset(this, 0, sizeof(LearnerModelParam)); @@ -147,6 +149,7 @@ class LearnerImpl : public Learner { }; if (std::all_of(metrics_.begin(), metrics_.end(), dup_check)) { metrics_.emplace_back(Metric::Create(kv.second)); + mparam.contain_eval_metrics = 1; } } else { cfg_[kv.first] = kv.second; @@ -273,6 +276,13 @@ class LearnerImpl : public Learner { fi->Read(&max_delta_step); cfg_["max_delta_step"] = max_delta_step; } + if (mparam.contain_eval_metrics != 0) { + std::vector metr; + fi->Read(&metr); + for (auto name : metr) { + metrics_.emplace_back(Metric::Create(name)); + } + } cfg_["num_class"] = common::ToString(mparam.num_class); cfg_["num_feature"] = common::ToString(mparam.num_feature); obj_->Configure(cfg_.begin(), cfg_.end()); @@ -294,6 +304,13 @@ class LearnerImpl : public Learner { if (it != cfg_.end()) fo->Write(it->second); } + if (mparam.contain_eval_metrics != 0) { + std::vector metr; + for (auto& ev : metrics_) { + metr.emplace_back(ev->Name()); + } + fo->Write(metr); + } } void UpdateOneIter(int iter, DMatrix* train) override {