Pass infomation about objective to tree methods. (#7385)

* Define the `ObjInfo` and pass it down to every tree updater.
This commit is contained in:
Jiaming Yuan
2021-11-04 01:52:44 +08:00
committed by GitHub
parent ccdabe4512
commit 4100827971
28 changed files with 178 additions and 69 deletions

View File

@@ -159,13 +159,12 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
}
};
LearnerModelParam::LearnerModelParam(
LearnerModelParamLegacy const &user_param, float base_margin)
: base_score{base_margin}, num_feature{user_param.num_feature},
num_output_group{user_param.num_class == 0
? 1
: static_cast<uint32_t>(user_param.num_class)}
{}
LearnerModelParam::LearnerModelParam(LearnerModelParamLegacy const& user_param, float base_margin,
ObjInfo t)
: base_score{base_margin},
num_feature{user_param.num_feature},
num_output_group{user_param.num_class == 0 ? 1 : static_cast<uint32_t>(user_param.num_class)},
task{t} {}
struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
// data split mode, can be row, col, or none.
@@ -339,8 +338,8 @@ class LearnerConfiguration : public Learner {
// - model is created from scratch.
// - model is configured second time due to change of parameter
if (!learner_model_param_.Initialized() || mparam_.base_score != mparam_backup.base_score) {
learner_model_param_ = LearnerModelParam(mparam_,
obj_->ProbToMargin(mparam_.base_score));
learner_model_param_ =
LearnerModelParam(mparam_, obj_->ProbToMargin(mparam_.base_score), obj_->Task());
}
this->ConfigureGBM(old_tparam, args);
@@ -832,7 +831,7 @@ class LearnerIO : public LearnerConfiguration {
}
learner_model_param_ =
LearnerModelParam(mparam_, obj_->ProbToMargin(mparam_.base_score));
LearnerModelParam(mparam_, obj_->ProbToMargin(mparam_.base_score), obj_->Task());
if (attributes_.find("objective") != attributes_.cend()) {
auto obj_str = attributes_.at("objective");
auto j_obj = Json::Load({obj_str.c_str(), obj_str.size()});