Fix feature names and types in output model slice. (#7078)
This commit is contained in:
parent
ffa66aace0
commit
d7e1fa7664
@ -1032,6 +1032,8 @@ class LearnerImpl : public LearnerIO {
|
|||||||
out_impl->mparam_ = this->mparam_;
|
out_impl->mparam_ = this->mparam_;
|
||||||
out_impl->attributes_ = this->attributes_;
|
out_impl->attributes_ = this->attributes_;
|
||||||
out_impl->learner_model_param_ = this->learner_model_param_;
|
out_impl->learner_model_param_ = this->learner_model_param_;
|
||||||
|
out_impl->SetFeatureNames(this->feature_names_);
|
||||||
|
out_impl->SetFeatureTypes(this->feature_types_);
|
||||||
out_impl->LoadConfig(config);
|
out_impl->LoadConfig(config);
|
||||||
out_impl->Configure();
|
out_impl->Configure();
|
||||||
return out_impl;
|
return out_impl;
|
||||||
|
|||||||
@ -379,10 +379,13 @@ class TestModels:
|
|||||||
'num_parallel_tree': 4, 'subsample': 0.5, 'num_class': 3, 'booster': booster,
|
'num_parallel_tree': 4, 'subsample': 0.5, 'num_class': 3, 'booster': booster,
|
||||||
'objective': 'multi:softprob'},
|
'objective': 'multi:softprob'},
|
||||||
num_boost_round=num_boost_round, dtrain=dtrain)
|
num_boost_round=num_boost_round, dtrain=dtrain)
|
||||||
|
booster.feature_types = ["q"] * X.shape[1]
|
||||||
|
|
||||||
assert len(booster.get_dump()) == total_trees
|
assert len(booster.get_dump()) == total_trees
|
||||||
beg = 3
|
beg = 3
|
||||||
end = 7
|
end = 7
|
||||||
sliced: xgb.Booster = booster[beg: end]
|
sliced: xgb.Booster = booster[beg: end]
|
||||||
|
assert sliced.feature_types == booster.feature_types
|
||||||
|
|
||||||
sliced_trees = (end - beg) * num_parallel_tree * num_classes
|
sliced_trees = (end - beg) * num_parallel_tree * num_classes
|
||||||
assert sliced_trees == len(sliced.get_dump())
|
assert sliced_trees == len(sliced.get_dump())
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user