Support early stopping with training continuation, correct num boosted rounds. (#6506)

* Implement early stopping with training continuation.

* Add new C API for obtaining boosted rounds.

* Fix off-by-one error in `save_best`.

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan
2020-12-17 19:59:19 +08:00
committed by GitHub
parent 125b3c0f2d
commit ca3da55de4
16 changed files with 210 additions and 118 deletions

View File

@@ -502,6 +502,14 @@ XGB_DLL int XGBoosterGetNumFeature(BoosterHandle handle,
API_END();
}
/*!
 * \brief C API entry point returning the number of completed boosting rounds.
 *
 * \param handle Handle to the booster (a `Learner*` under the hood).
 * \param out    Output parameter receiving the round count.
 * \return 0 on success, non-zero on failure (via API_BEGIN/API_END macros).
 */
XGB_DLL int XGBoosterBoostedRounds(BoosterHandle handle, int* out) {
  API_BEGIN();
  CHECK_HANDLE();
  // Hoist the repeated cast into a single local; Configure() must run
  // first because LearnerImpl::BoostedRounds checks !need_configuration_.
  auto* learner = static_cast<Learner*>(handle);
  learner->Configure();
  *out = learner->BoostedRounds();
  API_END();
}
XGB_DLL int XGBoosterLoadJsonConfig(BoosterHandle handle, char const* json_parameters) {
API_BEGIN();
CHECK_HANDLE();

View File

@@ -73,6 +73,10 @@ class GBLinear : public GradientBooster {
}
}
// Number of boosting iterations performed so far.  gblinear tracks this
// explicitly in the model (the counter is bumped once per DoBoost call),
// since there are no trees to count.
int32_t BoostedRounds() const override {
return model_.num_boosted_rounds;
}
// Deserialize the linear model from a dmlc binary stream; delegates
// entirely to GBLinearModel::Load.
void Load(dmlc::Stream* fi) override {
model_.Load(fi);
}
@@ -122,7 +126,7 @@ class GBLinear : public GradientBooster {
if (!this->CheckConvergence()) {
updater_->Update(in_gpair, p_fmat, &model_, sum_instance_weight_);
}
model_.num_boosted_rounds++;
monitor_.Stop("DoBoost");
}

View File

@@ -44,11 +44,12 @@ class GBLinearModel : public Model {
DeprecatedGBLinearModelParam param_;
public:
int32_t num_boosted_rounds;
LearnerModelParam const* learner_model_param;
public:
explicit GBLinearModel(LearnerModelParam const* learner_model_param) :
learner_model_param {learner_model_param} {}
num_boosted_rounds{0}, learner_model_param {learner_model_param} {}
void Configure(Args const &) { }
// weight for each of feature, bias is the last one

View File

@@ -249,10 +249,17 @@ class GBTree : public GradientBooster {
auto n_trees = model_.learner_model_param->num_output_group * tparam_.num_parallel_tree;
return n_trees;
}
// Slice the tree layers in [layer_begin, layer_end) with the given step.
// `out` must be pre-allocated by the caller; *out_of_bound is set when the
// requested range exceeds the available layers.
void Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
GradientBooster *out, bool* out_of_bound) const override;
// Completed boosting rounds = total trees / trees-per-layer (one layer is
// num_output_group * num_parallel_tree trees).  The CHECKs guard against a
// division by zero when the model parameters are not yet populated.
int32_t BoostedRounds() const override {
CHECK_NE(tparam_.num_parallel_tree, 0);
CHECK_NE(model_.learner_model_param->num_output_group, 0);
return model_.trees.size() / this->LayerTrees();
}
void PredictBatch(DMatrix* p_fmat,
PredictionCacheEntry* out_preds,
bool training,

View File

@@ -1107,6 +1107,12 @@ class LearnerImpl : public LearnerIO {
}
}
// Number of completed boosting rounds.  Returns 0 when no gradient booster
// exists yet; otherwise requires the learner to be configured and delegates
// to the booster-specific implementation.
int32_t BoostedRounds() const override {
if (!this->gbm_) { return 0; }  // haven't called train or LoadModel yet.
CHECK(!this->need_configuration_);
return this->gbm_->BoostedRounds();
}
// Fetch this learner's entry from the global thread-local store, keyed by
// the learner pointer.  NOTE(review): presumably holds per-thread C API
// return buffers — confirm against the C API callers.
XGBAPIThreadLocalEntry& GetThreadLocal() const override {
return (*LearnerAPIThreadLocalStore::Get())[this];
}