[breaking] Save booster feature info in JSON, remove feature name generation. (#6605)
* Save feature info in booster in JSON model. * [breaking] Remove automatic feature name generation in `DMatrix`. This PR is to enable reliable feature validation in Python package.
This commit is contained in:
@@ -360,4 +360,60 @@ TEST(Learner, ConstantSeed) {
|
||||
CHECK_EQ(v_0, v_2);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Learner, FeatureInfo) {
|
||||
size_t constexpr kCols = 10;
|
||||
auto m = RandomDataGenerator{10, kCols, 0}.GenerateDMatrix(true);
|
||||
std::vector<std::string> names(kCols);
|
||||
for (size_t i = 0; i < kCols; ++i) {
|
||||
names[i] = ("f" + std::to_string(i));
|
||||
}
|
||||
|
||||
std::vector<std::string> types(kCols);
|
||||
for (size_t i = 0; i < kCols; ++i) {
|
||||
types[i] = "q";
|
||||
}
|
||||
types[8] = "f";
|
||||
types[0] = "int";
|
||||
types[3] = "i";
|
||||
types[7] = "i";
|
||||
|
||||
std::vector<char const*> c_names(kCols);
|
||||
for (size_t i = 0; i < names.size(); ++i) {
|
||||
c_names[i] = names[i].c_str();
|
||||
}
|
||||
std::vector<char const*> c_types(kCols);
|
||||
for (size_t i = 0; i < types.size(); ++i) {
|
||||
c_types[i] = names[i].c_str();
|
||||
}
|
||||
|
||||
std::vector<std::string> out_names;
|
||||
std::vector<std::string> out_types;
|
||||
|
||||
Json model{Object()};
|
||||
{
|
||||
std::unique_ptr<Learner> learner{Learner::Create({m})};
|
||||
learner->Configure();
|
||||
learner->SetFeatureNames(names);
|
||||
learner->GetFeatureNames(&out_names);
|
||||
|
||||
learner->SetFeatureTypes(types);
|
||||
learner->GetFeatureTypes(&out_types);
|
||||
|
||||
ASSERT_TRUE(std::equal(out_names.begin(), out_names.end(), names.begin()));
|
||||
ASSERT_TRUE(std::equal(out_types.begin(), out_types.end(), types.begin()));
|
||||
|
||||
learner->SaveModel(&model);
|
||||
}
|
||||
|
||||
{
|
||||
std::unique_ptr<Learner> learner{Learner::Create({m})};
|
||||
learner->LoadModel(model);
|
||||
|
||||
learner->GetFeatureNames(&out_names);
|
||||
learner->GetFeatureTypes(&out_types);
|
||||
ASSERT_TRUE(std::equal(out_names.begin(), out_names.end(), names.begin()));
|
||||
ASSERT_TRUE(std::equal(out_types.begin(), out_types.end(), types.begin()));
|
||||
}
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user