Fix feature names and types in output model slice. (#7078)

This commit is contained in:
Jiaming Yuan 2021-07-06 11:47:49 +08:00 committed by GitHub
parent ffa66aace0
commit d7e1fa7664
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 0 deletions

View File

@ -1032,6 +1032,8 @@ class LearnerImpl : public LearnerIO {
out_impl->mparam_ = this->mparam_; out_impl->mparam_ = this->mparam_;
out_impl->attributes_ = this->attributes_; out_impl->attributes_ = this->attributes_;
out_impl->learner_model_param_ = this->learner_model_param_; out_impl->learner_model_param_ = this->learner_model_param_;
out_impl->SetFeatureNames(this->feature_names_);
out_impl->SetFeatureTypes(this->feature_types_);
out_impl->LoadConfig(config); out_impl->LoadConfig(config);
out_impl->Configure(); out_impl->Configure();
return out_impl; return out_impl;

View File

@ -379,10 +379,13 @@ class TestModels:
'num_parallel_tree': 4, 'subsample': 0.5, 'num_class': 3, 'booster': booster, 'num_parallel_tree': 4, 'subsample': 0.5, 'num_class': 3, 'booster': booster,
'objective': 'multi:softprob'}, 'objective': 'multi:softprob'},
num_boost_round=num_boost_round, dtrain=dtrain) num_boost_round=num_boost_round, dtrain=dtrain)
booster.feature_types = ["q"] * X.shape[1]
assert len(booster.get_dump()) == total_trees assert len(booster.get_dump()) == total_trees
beg = 3 beg = 3
end = 7 end = 7
sliced: xgb.Booster = booster[beg: end] sliced: xgb.Booster = booster[beg: end]
assert sliced.feature_types == booster.feature_types
sliced_trees = (end - beg) * num_parallel_tree * num_classes sliced_trees = (end - beg) * num_parallel_tree * num_classes
assert sliced_trees == len(sliced.get_dump()) assert sliced_trees == len(sliced.get_dump())