Support slicing tree model (#6302)

This PR is meant the end the confusion around best_ntree_limit and unify model slicing. We have multi-class and random forests, asking users to understand how to set ntree_limit is difficult and error prone.

* Implement the save_best option in early stopping.

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan
2020-11-03 02:27:39 -05:00
committed by GitHub
parent 29745c6df2
commit 2cc9662005
19 changed files with 550 additions and 37 deletions

View File

@@ -971,6 +971,26 @@ class LearnerImpl : public LearnerIO {
return gbm_->DumpModel(fmap, with_stats, format);
}
Learner *Slice(int32_t begin_layer, int32_t end_layer, int32_t step,
bool *out_of_bound) override {
this->Configure();
CHECK_GE(begin_layer, 0);
auto *out_impl = new LearnerImpl({});
auto gbm = std::unique_ptr<GradientBooster>(GradientBooster::Create(
this->tparam_.booster, &this->generic_parameters_,
&this->learner_model_param_));
this->gbm_->Slice(begin_layer, end_layer, step, gbm.get(), out_of_bound);
out_impl->gbm_ = std::move(gbm);
Json config { Object() };
this->SaveConfig(&config);
out_impl->mparam_ = this->mparam_;
out_impl->attributes_ = this->attributes_;
out_impl->learner_model_param_ = this->learner_model_param_;
out_impl->LoadConfig(config);
out_impl->Configure();
return out_impl;
}
void UpdateOneIter(int iter, std::shared_ptr<DMatrix> train) override {
monitor_.Start("UpdateOneIter");
TrainingObserver::Instance().Update(iter);