diff --git a/src/learner.cc b/src/learner.cc index 15adf95bf..5ffafa782 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -1032,6 +1032,8 @@ class LearnerImpl : public LearnerIO { out_impl->mparam_ = this->mparam_; out_impl->attributes_ = this->attributes_; out_impl->learner_model_param_ = this->learner_model_param_; + out_impl->SetFeatureNames(this->feature_names_); + out_impl->SetFeatureTypes(this->feature_types_); out_impl->LoadConfig(config); out_impl->Configure(); return out_impl; diff --git a/tests/python/test_basic_models.py b/tests/python/test_basic_models.py index 83efaef77..219a6899a 100644 --- a/tests/python/test_basic_models.py +++ b/tests/python/test_basic_models.py @@ -379,10 +379,13 @@ class TestModels: 'num_parallel_tree': 4, 'subsample': 0.5, 'num_class': 3, 'booster': booster, 'objective': 'multi:softprob'}, num_boost_round=num_boost_round, dtrain=dtrain) + booster.feature_types = ["q"] * X.shape[1] + assert len(booster.get_dump()) == total_trees beg = 3 end = 7 sliced: xgb.Booster = booster[beg: end] + assert sliced.feature_types == booster.feature_types sliced_trees = (end - beg) * num_parallel_tree * num_classes assert sliced_trees == len(sliced.get_dump())