Check whether current updater can modify a tree. (#5406)
* Check whether current updater can modify a tree. * Fix tree model JSON IO for pruned trees.
This commit is contained in:
@@ -51,6 +51,22 @@ TEST(GBTree, SelectTreeMethod) {
|
||||
#endif // XGBOOST_USE_CUDA
|
||||
}
|
||||
|
||||
TEST(GBTree, WrongUpdater) {
|
||||
size_t constexpr kRows = 17;
|
||||
size_t constexpr kCols = 15;
|
||||
|
||||
auto pp_dmat = CreateDMatrix(kRows, kCols, 0);
|
||||
std::shared_ptr<DMatrix> p_dmat {*pp_dmat};
|
||||
|
||||
p_dmat->Info().labels_.Resize(kRows);
|
||||
|
||||
auto learner = std::unique_ptr<Learner>(Learner::Create({p_dmat}));
|
||||
// Hist can not be used for updating tree.
|
||||
learner->SetParams(Args{{"tree_method", "hist"}, {"process_type", "update"}});
|
||||
ASSERT_THROW(learner->UpdateOneIter(0, p_dmat), dmlc::Error);
|
||||
delete pp_dmat;
|
||||
}
|
||||
|
||||
#ifdef XGBOOST_USE_CUDA
|
||||
TEST(GBTree, ChoosePredictor) {
|
||||
size_t constexpr kRows = 17;
|
||||
|
||||
@@ -225,8 +225,6 @@ TEST(Tree, JsonIO) {
|
||||
tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
||||
Json j_tree{Object()};
|
||||
tree.SaveModel(&j_tree);
|
||||
std::stringstream ss;
|
||||
Json::Dump(j_tree, &ss);
|
||||
|
||||
auto tparam = j_tree["tree_param"];
|
||||
ASSERT_EQ(get<String>(tparam["num_feature"]), "0");
|
||||
@@ -243,6 +241,23 @@ TEST(Tree, JsonIO) {
|
||||
RegTree loaded_tree;
|
||||
loaded_tree.LoadModel(j_tree);
|
||||
ASSERT_EQ(loaded_tree.param.num_nodes, 3);
|
||||
|
||||
ASSERT_TRUE(loaded_tree == tree);
|
||||
|
||||
auto left = tree[0].LeftChild();
|
||||
auto right = tree[0].RightChild();
|
||||
tree.ExpandNode(left, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
||||
tree.ExpandNode(right, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
||||
tree.SaveModel(&j_tree);
|
||||
|
||||
tree.ChangeToLeaf(1, 1.0f);
|
||||
ASSERT_EQ(tree[1].LeftChild(), -1);
|
||||
ASSERT_EQ(tree[1].RightChild(), -1);
|
||||
tree.SaveModel(&j_tree);
|
||||
loaded_tree.LoadModel(j_tree);
|
||||
ASSERT_EQ(loaded_tree[1].LeftChild(), -1);
|
||||
ASSERT_EQ(loaded_tree[1].RightChild(), -1);
|
||||
ASSERT_TRUE(tree.Equal(loaded_tree));
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user