merge latest changes
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023 by XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/context.h> // for Context
|
||||
@@ -7,16 +7,23 @@
|
||||
#include <xgboost/tree_model.h> // for RegTree
|
||||
|
||||
namespace xgboost {
|
||||
TEST(MultiTargetTree, JsonIO) {
|
||||
namespace {
|
||||
auto MakeTreeForTest() {
|
||||
bst_target_t n_targets{3};
|
||||
bst_feature_t n_features{4};
|
||||
RegTree tree{n_targets, n_features};
|
||||
ASSERT_TRUE(tree.IsMultiTarget());
|
||||
CHECK(tree.IsMultiTarget());
|
||||
linalg::Vector<float> base_weight{{1.0f, 2.0f, 3.0f}, {3ul}, DeviceOrd::CPU()};
|
||||
linalg::Vector<float> left_weight{{2.0f, 3.0f, 4.0f}, {3ul}, DeviceOrd::CPU()};
|
||||
linalg::Vector<float> right_weight{{3.0f, 4.0f, 5.0f}, {3ul}, DeviceOrd::CPU()};
|
||||
tree.ExpandNode(RegTree::kRoot, /*split_idx=*/1, 0.5f, true, base_weight.HostView(),
|
||||
left_weight.HostView(), right_weight.HostView());
|
||||
return tree;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TEST(MultiTargetTree, JsonIO) {
|
||||
auto tree = MakeTreeForTest();
|
||||
ASSERT_EQ(tree.NumNodes(), 3);
|
||||
ASSERT_EQ(tree.NumTargets(), 3);
|
||||
ASSERT_EQ(tree.GetMultiTargetTree()->Size(), 3);
|
||||
@@ -44,4 +51,28 @@ TEST(MultiTargetTree, JsonIO) {
|
||||
loaded.SaveModel(&jtree1);
|
||||
check_jtree(jtree1, tree);
|
||||
}
|
||||
|
||||
TEST(MultiTargetTree, DumpDot) {
|
||||
auto tree = MakeTreeForTest();
|
||||
auto n_features = tree.NumFeatures();
|
||||
FeatureMap fmap;
|
||||
for (bst_feature_t f = 0; f < n_features; ++f) {
|
||||
auto name = "feat_" + std::to_string(f);
|
||||
fmap.PushBack(f, name.c_str(), "q");
|
||||
}
|
||||
auto str = tree.DumpModel(fmap, true, "dot");
|
||||
ASSERT_NE(str.find("leaf=[2, 3, 4]"), std::string::npos);
|
||||
ASSERT_NE(str.find("leaf=[3, 4, 5]"), std::string::npos);
|
||||
|
||||
{
|
||||
bst_target_t n_targets{4};
|
||||
bst_feature_t n_features{4};
|
||||
RegTree tree{n_targets, n_features};
|
||||
linalg::Vector<float> weight{{1.0f, 2.0f, 3.0f, 4.0f}, {4ul}, DeviceOrd::CPU()};
|
||||
tree.ExpandNode(RegTree::kRoot, /*split_idx=*/1, 0.5f, true, weight.HostView(),
|
||||
weight.HostView(), weight.HostView());
|
||||
auto str = tree.DumpModel(fmap, true, "dot");
|
||||
ASSERT_NE(str.find("leaf=[1, 2, ..., 4]"), std::string::npos);
|
||||
}
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2018-2023 by XGBoost Contributors
|
||||
* Copyright 2018-2024, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
@@ -18,7 +18,6 @@
|
||||
#include "xgboost/data.h"
|
||||
|
||||
namespace xgboost::tree {
|
||||
|
||||
namespace {
|
||||
template <typename ExpandEntry>
|
||||
void TestPartitioner(bst_target_t n_targets) {
|
||||
@@ -253,5 +252,5 @@ void TestColumnSplit(bst_target_t n_targets) {
|
||||
|
||||
TEST(QuantileHist, ColumnSplit) { TestColumnSplit(1); }
|
||||
|
||||
TEST(QuantileHist, DISABLED_ColumnSplitMultiTarget) { TestColumnSplit(3); }
|
||||
TEST(QuantileHist, ColumnSplitMultiTarget) { TestColumnSplit(3); }
|
||||
} // namespace xgboost::tree
|
||||
|
||||
Reference in New Issue
Block a user