/** * Copyright 2023-2024, XGBoost Contributors */ #include #include // for Context #include #include // for RegTree namespace xgboost { namespace { auto MakeTreeForTest() { bst_target_t n_targets{3}; bst_feature_t n_features{4}; RegTree tree{n_targets, n_features}; CHECK(tree.IsMultiTarget()); linalg::Vector base_weight{{1.0f, 2.0f, 3.0f}, {3ul}, DeviceOrd::CPU()}; linalg::Vector left_weight{{2.0f, 3.0f, 4.0f}, {3ul}, DeviceOrd::CPU()}; linalg::Vector right_weight{{3.0f, 4.0f, 5.0f}, {3ul}, DeviceOrd::CPU()}; tree.ExpandNode(RegTree::kRoot, /*split_idx=*/1, 0.5f, true, base_weight.HostView(), left_weight.HostView(), right_weight.HostView()); return tree; } } // namespace TEST(MultiTargetTree, JsonIO) { auto tree = MakeTreeForTest(); ASSERT_EQ(tree.NumNodes(), 3); ASSERT_EQ(tree.NumTargets(), 3); ASSERT_EQ(tree.GetMultiTargetTree()->Size(), 3); ASSERT_EQ(tree.Size(), 3); Json jtree{Object{}}; tree.SaveModel(&jtree); auto check_jtree = [](Json jtree, RegTree const& tree) { ASSERT_EQ(get(jtree["tree_param"]["num_nodes"]), std::to_string(tree.NumNodes())); ASSERT_EQ(get(jtree["base_weights"]).size(), tree.NumNodes() * tree.NumTargets()); ASSERT_EQ(get(jtree["parents"]).size(), tree.NumNodes()); ASSERT_EQ(get(jtree["left_children"]).size(), tree.NumNodes()); ASSERT_EQ(get(jtree["right_children"]).size(), tree.NumNodes()); }; check_jtree(jtree, tree); RegTree loaded; loaded.LoadModel(jtree); ASSERT_TRUE(loaded.IsMultiTarget()); ASSERT_EQ(loaded.NumNodes(), 3); Json jtree1{Object{}}; loaded.SaveModel(&jtree1); check_jtree(jtree1, tree); } TEST(MultiTargetTree, DumpDot) { auto tree = MakeTreeForTest(); auto n_features = tree.NumFeatures(); FeatureMap fmap; for (bst_feature_t f = 0; f < n_features; ++f) { auto name = "feat_" + std::to_string(f); fmap.PushBack(f, name.c_str(), "q"); } auto str = tree.DumpModel(fmap, true, "dot"); ASSERT_NE(str.find("leaf=[2, 3, 4]"), std::string::npos); ASSERT_NE(str.find("leaf=[3, 4, 5]"), std::string::npos); { bst_target_t n_targets{4}; bst_feature_t n_features{4}; RegTree tree{n_targets, n_features}; linalg::Vector weight{{1.0f, 2.0f, 3.0f, 4.0f}, {4ul}, DeviceOrd::CPU()}; tree.ExpandNode(RegTree::kRoot, /*split_idx=*/1, 0.5f, true, weight.HostView(), weight.HostView(), weight.HostView()); auto str = tree.DumpModel(fmap, true, "dot"); ASSERT_NE(str.find("leaf=[1, 2, ..., 4]"), std::string::npos); } } } // namespace xgboost