Send default configuration from metric to objective. (#8760)
This commit is contained in:
@@ -520,6 +520,7 @@ class LearnerConfiguration : public Learner {
|
||||
|
||||
auto const& objective_fn = learner_parameters.at("objective");
|
||||
if (!obj_) {
|
||||
CHECK_EQ(get<String const>(objective_fn["name"]), tparam_.objective);
|
||||
obj_.reset(ObjFunction::Create(tparam_.objective, &ctx_));
|
||||
}
|
||||
obj_->LoadConfig(objective_fn);
|
||||
@@ -1311,8 +1312,10 @@ class LearnerImpl : public LearnerIO {
|
||||
std::ostringstream os;
|
||||
os.precision(std::numeric_limits<double>::max_digits10);
|
||||
os << '[' << iter << ']' << std::setiosflags(std::ios::fixed);
|
||||
if (metrics_.size() == 0 && tparam_.disable_default_eval_metric <= 0) {
|
||||
if (metrics_.empty() && tparam_.disable_default_eval_metric <= 0) {
|
||||
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric(), &ctx_));
|
||||
auto config = obj_->DefaultMetricConfig();
|
||||
metrics_.back()->LoadConfig(config);
|
||||
metrics_.back()->Configure({cfg_.begin(), cfg_.end()});
|
||||
}
|
||||
|
||||
|
||||
@@ -134,6 +134,12 @@ class AFTObj : public ObjFunction {
|
||||
void LoadConfig(Json const& in) override {
|
||||
FromJson(in["aft_loss_param"], ¶m_);
|
||||
}
|
||||
Json DefaultMetricConfig() const override {
|
||||
Json config{Object{}};
|
||||
config["name"] = String{this->DefaultEvalMetric()};
|
||||
config["aft_loss_param"] = ToJson(param_);
|
||||
return config;
|
||||
}
|
||||
|
||||
private:
|
||||
AFTParam param_;
|
||||
|
||||
Reference in New Issue
Block a user