Support graphviz plot for multi-target tree. (#10093)

This commit is contained in:
Jiaming Yuan
2024-03-09 05:35:25 +08:00
committed by GitHub
parent e14c3b9325
commit 2c13f90384
3 changed files with 133 additions and 63 deletions

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2023 by XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/context.h> // for Context
@@ -7,16 +7,23 @@
#include <xgboost/tree_model.h> // for RegTree
namespace xgboost {
TEST(MultiTargetTree, JsonIO) {
namespace {
auto MakeTreeForTest() {
bst_target_t n_targets{3};
bst_feature_t n_features{4};
RegTree tree{n_targets, n_features};
ASSERT_TRUE(tree.IsMultiTarget());
CHECK(tree.IsMultiTarget());
linalg::Vector<float> base_weight{{1.0f, 2.0f, 3.0f}, {3ul}, DeviceOrd::CPU()};
linalg::Vector<float> left_weight{{2.0f, 3.0f, 4.0f}, {3ul}, DeviceOrd::CPU()};
linalg::Vector<float> right_weight{{3.0f, 4.0f, 5.0f}, {3ul}, DeviceOrd::CPU()};
tree.ExpandNode(RegTree::kRoot, /*split_idx=*/1, 0.5f, true, base_weight.HostView(),
left_weight.HostView(), right_weight.HostView());
return tree;
}
} // namespace
TEST(MultiTargetTree, JsonIO) {
auto tree = MakeTreeForTest();
ASSERT_EQ(tree.NumNodes(), 3);
ASSERT_EQ(tree.NumTargets(), 3);
ASSERT_EQ(tree.GetMultiTargetTree()->Size(), 3);
@@ -44,4 +51,28 @@ TEST(MultiTargetTree, JsonIO) {
loaded.SaveModel(&jtree1);
check_jtree(jtree1, tree);
}
TEST(MultiTargetTree, DumpDot) {
auto tree = MakeTreeForTest();
auto n_features = tree.NumFeatures();
FeatureMap fmap;
for (bst_feature_t f = 0; f < n_features; ++f) {
auto name = "feat_" + std::to_string(f);
fmap.PushBack(f, name.c_str(), "q");
}
auto str = tree.DumpModel(fmap, true, "dot");
ASSERT_NE(str.find("leaf=[2, 3, 4]"), std::string::npos);
ASSERT_NE(str.find("leaf=[3, 4, 5]"), std::string::npos);
{
bst_target_t n_targets{4};
bst_feature_t n_features{4};
RegTree tree{n_targets, n_features};
linalg::Vector<float> weight{{1.0f, 2.0f, 3.0f, 4.0f}, {4ul}, DeviceOrd::CPU()};
tree.ExpandNode(RegTree::kRoot, /*split_idx=*/1, 0.5f, true, weight.HostView(),
weight.HostView(), weight.HostView());
auto str = tree.DumpModel(fmap, true, "dot");
ASSERT_NE(str.find("leaf=[1, 2, ..., 4]"), std::string::npos);
}
}
} // namespace xgboost