Fix tree param feature type. (#7565)

This commit is contained in:
Jiaming Yuan 2022-01-16 04:46:29 +08:00 committed by GitHub
parent a1bcd33a3b
commit 465dc63833
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -8,6 +8,52 @@
#include "../../../src/common/categorical.h"
namespace xgboost {
TEST(Tree, ModelShape) {
bst_feature_t n_features = std::numeric_limits<uint32_t>::max();
RegTree tree;
tree.param.UpdateAllowUnknown(Args{{"num_feature", std::to_string(n_features)}});
ASSERT_EQ(tree.param.num_feature, n_features);
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/tree.model";
{
// binary dump
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(tmp_file.c_str(), "w"));
tree.Save(fo.get());
}
{
// binary load
RegTree new_tree;
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(tmp_file.c_str(), "r"));
new_tree.Load(fi.get());
ASSERT_EQ(new_tree.param.num_feature, n_features);
}
{
// json
Json j_tree{Object{}};
tree.SaveModel(&j_tree);
std::vector<char> dumped;
Json::Dump(j_tree, &dumped);
RegTree new_tree;
auto j_loaded = Json::Load(StringView{dumped.data(), dumped.size()});
new_tree.LoadModel(j_loaded);
ASSERT_EQ(new_tree.param.num_feature, n_features);
}
{
// ubjson
Json j_tree{Object{}};
tree.SaveModel(&j_tree);
std::vector<char> dumped;
Json::Dump(j_tree, &dumped, std::ios::binary);
RegTree new_tree;
auto j_loaded = Json::Load(StringView{dumped.data(), dumped.size()}, std::ios::binary);
new_tree.LoadModel(j_loaded);
ASSERT_EQ(new_tree.param.num_feature, n_features);
}
}
#if DMLC_IO_NO_ENDIAN_SWAP // skip on big-endian machines
// Manually construct tree in binary format
// Do not use structs in case they change