Fix feature names and types in output model slice. (#7078)
This commit is contained in:
parent
ffa66aace0
commit
d7e1fa7664
@ -1032,6 +1032,8 @@ class LearnerImpl : public LearnerIO {
|
||||
out_impl->mparam_ = this->mparam_;
|
||||
out_impl->attributes_ = this->attributes_;
|
||||
out_impl->learner_model_param_ = this->learner_model_param_;
|
||||
out_impl->SetFeatureNames(this->feature_names_);
|
||||
out_impl->SetFeatureTypes(this->feature_types_);
|
||||
out_impl->LoadConfig(config);
|
||||
out_impl->Configure();
|
||||
return out_impl;
|
||||
|
||||
@ -379,10 +379,13 @@ class TestModels:
|
||||
'num_parallel_tree': 4, 'subsample': 0.5, 'num_class': 3, 'booster': booster,
|
||||
'objective': 'multi:softprob'},
|
||||
num_boost_round=num_boost_round, dtrain=dtrain)
|
||||
booster.feature_types = ["q"] * X.shape[1]
|
||||
|
||||
assert len(booster.get_dump()) == total_trees
|
||||
beg = 3
|
||||
end = 7
|
||||
sliced: xgb.Booster = booster[beg: end]
|
||||
assert sliced.feature_types == booster.feature_types
|
||||
|
||||
sliced_trees = (end - beg) * num_parallel_tree * num_classes
|
||||
assert sliced_trees == len(sliced.get_dump())
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user