xgboost/tests/cpp/tree/test_multi_target_tree_model.cc
Jiaming Yuan 5feee8d4a9
Define core multi-target regression tree structure. (#8884)
- Define a new tree struct embedded in the `RegTree`.
- Provide dispatching functions in `RegTree`.
- Fix some C++17 warnings about the use of `[[nodiscard]]` (the warning is currently disabled
  on the CI).
- Use uint32_t instead of size_t for `bst_target_t` as it has a defined size and can be used
  as part of dmlc parameter.
- Hide the `Segment` struct inside the categorical split matrix.
2023-03-09 19:03:06 +08:00

49 lines
1.8 KiB
C++

/**
* Copyright 2023 by XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/context.h> // for Context
#include <xgboost/multi_target_tree_model.h>
#include <xgboost/tree_model.h> // for RegTree
namespace xgboost {
// Round-trips a multi-target tree through JSON: build a tree with one split, save it, load
// it into a fresh RegTree, save again, and check that both JSON documents agree with the
// original tree's node count and leaf vector size.
TEST(MultiTargetTree, JsonIO) {
  bst_target_t n_targets{3};
  bst_feature_t n_features{4};
  RegTree tree{n_targets, n_features};
  ASSERT_TRUE(tree.IsMultiTarget());

  // Each weight vector has three entries, matching n_targets.
  linalg::Vector<float> base_weight{{1.0f, 2.0f, 3.0f}, {3ul}, Context::kCpuId};
  linalg::Vector<float> left_weight{{2.0f, 3.0f, 4.0f}, {3ul}, Context::kCpuId};
  linalg::Vector<float> right_weight{{3.0f, 4.0f, 5.0f}, {3ul}, Context::kCpuId};
  tree.ExpandNode(RegTree::kRoot, /*split_idx=*/1, 0.5f, true, base_weight.HostView(),
                  left_weight.HostView(), right_weight.HostView());

  // After expanding the root once the tree is expected to hold three nodes.
  ASSERT_EQ(tree.param.num_nodes, 3);
  ASSERT_EQ(tree.param.size_leaf_vector, 3);
  ASSERT_EQ(tree.GetMultiTargetTree()->Size(), 3);
  ASSERT_EQ(tree.Size(), 3);

  Json jtree{Object{}};
  tree.SaveModel(&jtree);

  // Checks that the array sizes stored in `doc` are consistent with the parameters of
  // `expected`. Parameters are named differently from the outer locals to avoid shadowing.
  auto check_jtree = [](Json doc, RegTree const& expected) {
    ASSERT_EQ(get<String const>(doc["tree_param"]["num_nodes"]),
              std::to_string(expected.param.num_nodes));
    ASSERT_EQ(get<F32Array const>(doc["base_weights"]).size(),
              expected.param.num_nodes * expected.param.size_leaf_vector);
    ASSERT_EQ(get<I32Array const>(doc["parents"]).size(), expected.param.num_nodes);
    ASSERT_EQ(get<I32Array const>(doc["left_children"]).size(), expected.param.num_nodes);
    ASSERT_EQ(get<I32Array const>(doc["right_children"]).size(), expected.param.num_nodes);
  };
  // A failing ASSERT_* inside the lambda only returns from the lambda. Wrapping the call
  // makes the fatal failure abort the test instead of continuing with a bad document.
  ASSERT_NO_FATAL_FAILURE(check_jtree(jtree, tree));

  RegTree loaded;
  loaded.LoadModel(jtree);
  ASSERT_TRUE(loaded.IsMultiTarget());
  ASSERT_EQ(loaded.param.num_nodes, 3);

  // Saving the loaded tree must produce a document that still matches the original tree.
  Json jtree1{Object{}};
  loaded.SaveModel(&jtree1);
  ASSERT_NO_FATAL_FAILURE(check_jtree(jtree1, tree));
}
}  // namespace xgboost