Fix model slicing. (#7149)
* Use the correct pointer when constructing the sliced booster.
* Remove best_iteration/best_score from sliced models.
parent 36346f8f56
commit d080b5a953
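
For context, model slicing is the `Booster` indexing shown below: `booster[begin:end]` builds a new booster from a range of boosting rounds. A minimal Python sketch of the feature this commit fixes (the data and parameters are illustrative):

    import numpy as np
    import xgboost as xgb

    # Illustrative regression data; any DMatrix works.
    X = np.random.randn(256, 4)
    y = np.random.randn(256)
    dtrain = xgb.DMatrix(X, label=y)

    booster = xgb.train({"tree_method": "hist"}, dtrain, num_boost_round=8)
    sliced = booster[2:6]  # new booster holding rounds [2, 6)
    assert sliced.num_boosted_rounds() == 4
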
@@ -103,7 +103,9 @@ def _train_internal(params, dtrain,
     # Due to compatibility with version older than 1.4, these attributes are added
     # to Python object even if early stopping is not used.
     bst.best_iteration = bst.num_boosted_rounds() - 1
+    bst.set_attr(best_iteration=str(bst.best_iteration))
     bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
+    bst.set_attr(best_ntree_limit=str(bst.best_ntree_limit))
 
     # Copy to serialise and unserialise booster to reset state and free
     # training memory
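
Note the motivation for the two added `set_attr` calls: plain Python fields such as `bst.best_iteration` live only on the wrapper object, while attributes set via `set_attr` are stored as strings inside the model and survive serialisation. Continuing the sketch above:

    # Attributes set via set_attr are serialised with the model,
    # unlike plain Python fields on the Booster wrapper.
    booster.set_attr(best_iteration=str(booster.num_boosted_rounds() - 1))
    assert booster.attr("best_iteration") == "7"  # 8 rounds trained above
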
@@ -443,6 +443,8 @@ void GBTree::Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
   CHECK(p_gbtree);
   GBTreeModel &out_model = p_gbtree->model_;
   auto layer_trees = this->LayerTrees();
+  CHECK_NE(this->model_.learner_model_param->num_feature, 0);
+  CHECK_NE(layer_trees, 0);
 
   layer_end = layer_end == 0 ? model_.trees.size() / layer_trees : layer_end;
   CHECK_GT(layer_end, layer_begin);
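
`LayerTrees()` is the number of trees per boosting round (`num_parallel_tree`, times the number of output groups for multi-class models), and slicing operates on these layers rather than on individual trees; the new `CHECK_NE`s reject slicing a model with no features or an empty layer. An illustrative sketch reusing `dtrain` from above:

    # With random forests per round, one layer holds num_parallel_tree trees.
    forest = xgb.train(
        {"num_parallel_tree": 3, "tree_method": "hist"}, dtrain, num_boost_round=4
    )
    first_layer = forest[0:1]  # one layer, i.e. 3 trees here
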
@@ -453,7 +455,13 @@ void GBTree::Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
   std::vector<int32_t> &out_trees_info = out_model.tree_info;
   out_trees_info.resize(layer_trees * n_layers);
   out_model.param.num_trees = out_model.trees.size();
-  CHECK(this->model_.trees_to_update.empty());
+  if (!this->model_.trees_to_update.empty()) {
+    CHECK_EQ(this->model_.trees_to_update.size(), this->model_.trees.size())
+        << "Not all trees are updated, "
+        << this->model_.trees_to_update.size() - this->model_.trees.size()
+        << " trees remain.  Slice the model before making update if you only "
+           "want to update a portion of trees.";
+  }
 
   *out_of_bound = detail::SliceTrees(
       layer_begin, layer_end, step, this->model_, tparam_, layer_trees,
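
`trees_to_update` is only populated when retraining with `process_type=update`, so the hard `CHECK` is relaxed into a guarded check with an actionable error message. A hedged sketch of the workflow that message refers to (the `refresh` updater re-computes statistics of existing trees rather than appending new ones):

    # Update every tree of an existing model; an update pass that
    # consumes only some trees is what the new message rejects.
    refreshed = xgb.train(
        {"process_type": "update", "updater": "refresh"},
        dtrain,
        num_boost_round=booster.num_boosted_rounds(),
        xgb_model=booster,
    )
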
@@ -1024,22 +1024,37 @@ class LearnerImpl : public LearnerIO {
   Learner *Slice(int32_t begin_layer, int32_t end_layer, int32_t step,
                  bool *out_of_bound) override {
     this->Configure();
+    CHECK_NE(this->learner_model_param_.num_feature, 0);
     CHECK_GE(begin_layer, 0);
     auto *out_impl = new LearnerImpl({});
+    out_impl->learner_model_param_ = this->learner_model_param_;
+    out_impl->generic_parameters_ = this->generic_parameters_;
     auto gbm = std::unique_ptr<GradientBooster>(GradientBooster::Create(
-        this->tparam_.booster, &this->generic_parameters_,
-        &this->learner_model_param_));
+        this->tparam_.booster, &out_impl->generic_parameters_,
+        &out_impl->learner_model_param_));
     this->gbm_->Slice(begin_layer, end_layer, step, gbm.get(), out_of_bound);
     out_impl->gbm_ = std::move(gbm);
 
     Json config { Object() };
     this->SaveConfig(&config);
     out_impl->mparam_ = this->mparam_;
     out_impl->attributes_ = this->attributes_;
-    out_impl->learner_model_param_ = this->learner_model_param_;
     out_impl->SetFeatureNames(this->feature_names_);
     out_impl->SetFeatureTypes(this->feature_types_);
     out_impl->LoadConfig(config);
     out_impl->Configure();
+    CHECK_EQ(out_impl->learner_model_param_.num_feature, this->learner_model_param_.num_feature);
+    CHECK_NE(out_impl->learner_model_param_.num_feature, 0);
+
+    auto erase_attr = [&](std::string attr) {
+      // Erase invalid attributes.
+      auto attr_it = out_impl->attributes_.find(attr);
+      if (attr_it != out_impl->attributes_.cend()) {
+        out_impl->attributes_.erase(attr_it);
+      }
+    };
+    erase_attr("best_iteration");
+    erase_attr("best_score");
     return out_impl;
   }
 
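
The visible effect of `erase_attr` from Python: a sliced booster no longer inherits early-stopping bookkeeping whose values would be meaningless for the sub-model. A hedged sketch (the evaluation setup is illustrative):

    es_booster = xgb.train(
        {"tree_method": "hist"}, dtrain, num_boost_round=32,
        evals=[(dtrain, "train")], early_stopping_rounds=4, verbose_eval=False,
    )
    head = es_booster[0:2]
    # best_iteration/best_score referred to the full model,
    # so the slice starts without them.
    assert head.attr("best_iteration") is None
    assert head.attr("best_score") is None
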
@@ -397,6 +397,10 @@ std::pair<Json, Json> TestModelSlice(std::string booster) {
       j++;
     }
+
+  // CHECK sliced model doesn't have dependency on old one
+  learner.reset();
+  CHECK_EQ(sliced->GetNumFeature(), kCols);
 
   return std::make_pair(model, sliced_model);
 }
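
The same independence property, sketched from the Python side: the slice must remain usable after the parent is gone.

    del booster                  # drop the parent model
    _ = sliced.predict(dtrain)   # the slice owns an independent copy
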
@@ -99,6 +99,8 @@ eval[test] = {data_path}
     # CLI model doesn't contain feature info.
     booster.feature_names = None
     booster.feature_types = None
+    booster.set_attr(best_iteration=None)
+    booster.set_attr(best_ntree_limit=None)
 
     booster.save_model(model_out_py)
     py_predt = booster.predict(data)
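
`set_attr` with `None` deletes an attribute, which is how the test strips Python-side bookkeeping before comparing against the CLI-trained model:

    sliced.set_attr(note="1")
    sliced.set_attr(note=None)          # passing None removes the attribute
    assert sliced.attr("note") is None
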
@@ -114,7 +114,8 @@ def run_data_iterator(
     if tree_method != "gpu_hist":
         rtol = 1e-1  # flaky
     else:
-        np.testing.assert_allclose(it_predt, arr_predt, rtol=1e-3)
+        # Model can be sensitive to quantiles, use 1e-2 to relax the test.
+        np.testing.assert_allclose(it_predt, arr_predt, rtol=1e-2)
         rtol = 1e-6
 
     np.testing.assert_allclose(
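
For reference, `assert_allclose` passes elementwise when `|actual - desired| <= atol + rtol * |desired|` (with `atol=0` by default), so moving `rtol` from 1e-3 to 1e-2 widens the tolerance tenfold:

    np.testing.assert_allclose(1.000, 1.009, rtol=1e-2)  # passes
    # rtol=1e-3 would raise: 0.009 > 1e-3 * 1.009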