Pass infomation about objective to tree methods. (#7385)
* Define the `ObjInfo` and pass it down to every tree updater.
This commit is contained in:
@@ -159,13 +159,12 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
|
||||
}
|
||||
};
|
||||
|
||||
LearnerModelParam::LearnerModelParam(
|
||||
LearnerModelParamLegacy const &user_param, float base_margin)
|
||||
: base_score{base_margin}, num_feature{user_param.num_feature},
|
||||
num_output_group{user_param.num_class == 0
|
||||
? 1
|
||||
: static_cast<uint32_t>(user_param.num_class)}
|
||||
{}
|
||||
LearnerModelParam::LearnerModelParam(LearnerModelParamLegacy const& user_param, float base_margin,
|
||||
ObjInfo t)
|
||||
: base_score{base_margin},
|
||||
num_feature{user_param.num_feature},
|
||||
num_output_group{user_param.num_class == 0 ? 1 : static_cast<uint32_t>(user_param.num_class)},
|
||||
task{t} {}
|
||||
|
||||
struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
|
||||
// data split mode, can be row, col, or none.
|
||||
@@ -339,8 +338,8 @@ class LearnerConfiguration : public Learner {
|
||||
// - model is created from scratch.
|
||||
// - model is configured second time due to change of parameter
|
||||
if (!learner_model_param_.Initialized() || mparam_.base_score != mparam_backup.base_score) {
|
||||
learner_model_param_ = LearnerModelParam(mparam_,
|
||||
obj_->ProbToMargin(mparam_.base_score));
|
||||
learner_model_param_ =
|
||||
LearnerModelParam(mparam_, obj_->ProbToMargin(mparam_.base_score), obj_->Task());
|
||||
}
|
||||
|
||||
this->ConfigureGBM(old_tparam, args);
|
||||
@@ -832,7 +831,7 @@ class LearnerIO : public LearnerConfiguration {
|
||||
}
|
||||
|
||||
learner_model_param_ =
|
||||
LearnerModelParam(mparam_, obj_->ProbToMargin(mparam_.base_score));
|
||||
LearnerModelParam(mparam_, obj_->ProbToMargin(mparam_.base_score), obj_->Task());
|
||||
if (attributes_.find("objective") != attributes_.cend()) {
|
||||
auto obj_str = attributes_.at("objective");
|
||||
auto j_obj = Json::Load({obj_str.c_str(), obj_str.size()});
|
||||
|
||||
Reference in New Issue
Block a user