Send default configuration from metric to objective. (#8760)

This commit is contained in:
Jiaming Yuan
2023-02-09 20:18:07 +08:00
committed by GitHub
parent 5f76edd296
commit 199c421d60
4 changed files with 62 additions and 7 deletions

View File

@@ -520,6 +520,7 @@ class LearnerConfiguration : public Learner {
auto const& objective_fn = learner_parameters.at("objective");
if (!obj_) {
CHECK_EQ(get<String const>(objective_fn["name"]), tparam_.objective);
obj_.reset(ObjFunction::Create(tparam_.objective, &ctx_));
}
obj_->LoadConfig(objective_fn);
@@ -1311,8 +1312,10 @@ class LearnerImpl : public LearnerIO {
std::ostringstream os;
os.precision(std::numeric_limits<double>::max_digits10);
os << '[' << iter << ']' << std::setiosflags(std::ios::fixed);
if (metrics_.size() == 0 && tparam_.disable_default_eval_metric <= 0) {
if (metrics_.empty() && tparam_.disable_default_eval_metric <= 0) {
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric(), &ctx_));
auto config = obj_->DefaultMetricConfig();
metrics_.back()->LoadConfig(config);
metrics_.back()->Configure({cfg_.begin(), cfg_.end()});
}

View File

@@ -134,6 +134,12 @@ class AFTObj : public ObjFunction {
void LoadConfig(Json const& in) override {
FromJson(in["aft_loss_param"], &param_);
}
Json DefaultMetricConfig() const override {
Json config{Object{}};
config["name"] = String{this->DefaultEvalMetric()};
config["aft_loss_param"] = ToJson(param_);
return config;
}
private:
AFTParam param_;