diff --git a/src/common/partition_builder.h b/src/common/partition_builder.h
index df151ce9a..e5e6971e5 100644
--- a/src/common/partition_builder.h
+++ b/src/common/partition_builder.h
@@ -1,5 +1,5 @@
-/*!
- * Copyright 2021-2022 by Contributors
+/**
+ * Copyright 2021-2023 by Contributors
  * \file row_set.h
  * \brief Quick Utility to compute subset of rows
  * \author Philip Cho, Tianqi Chen
@@ -10,6 +10,7 @@
 #include
 #include
+#include <cstddef>  // for size_t
 #include
 #include
 #include
@@ -21,9 +22,7 @@
 #include "xgboost/context.h"
 #include "xgboost/tree_model.h"
 
-namespace xgboost {
-namespace common {
-
+namespace xgboost::common {
 // The builder is required for samples partition to left and rights children for set of nodes
 // Responsible for:
 // 1) Effective memory allocation for intermediate results for multi-thread work
@@ -109,18 +108,17 @@ class PartitionBuilder {
     return {nleft_elems, nright_elems};
   }
 
-  template
-  void Partition(const size_t node_in_set, std::vector const &nodes,
-                 const common::Range1d range,
-                 const bst_bin_t split_cond, GHistIndexMatrix const& gmat,
-                 const common::ColumnMatrix& column_matrix,
+  template
+  void Partition(const size_t node_in_set, std::vector const& nodes,
+                 const common::Range1d range, const bst_bin_t split_cond,
+                 GHistIndexMatrix const& gmat, const common::ColumnMatrix& column_matrix,
                  const RegTree& tree, const size_t* rid) {
     common::Span rid_span(rid + range.begin(), rid + range.end());
     common::Span left = GetLeftBuffer(node_in_set, range.begin(), range.end());
     common::Span right = GetRightBuffer(node_in_set, range.begin(), range.end());
     std::size_t nid = nodes[node_in_set].nid;
-    bst_feature_t fid = tree[nid].SplitIndex();
-    bool default_left = tree[nid].DefaultLeft();
+    bst_feature_t fid = tree.SplitIndex(nid);
+    bool default_left = tree.DefaultLeft(nid);
     bool is_cat = tree.GetSplitTypes()[nid] == FeatureType::kCategorical;
     auto node_cats = tree.NodeCats(nid);
     auto const& cut_values = gmat.cut.Values();
@@ -190,10 +188,10 @@
    * worker, so we go through all the rows and mark the bit vectors on whether the decision is made
    * to go right, or if the feature value used for the split is missing.
    */
-  void MaskRows(const size_t node_in_set, std::vector const &nodes,
+  template
+  void MaskRows(const size_t node_in_set, std::vector const& nodes,
                 const common::Range1d range, GHistIndexMatrix const& gmat,
-                const common::ColumnMatrix& column_matrix,
-                const RegTree& tree, const size_t* rid,
+                const common::ColumnMatrix& column_matrix, const RegTree& tree, const size_t* rid,
                 BitVector* decision_bits, BitVector* missing_bits) {
     common::Span rid_span(rid + range.begin(), rid + range.end());
     std::size_t nid = nodes[node_in_set].nid;
@@ -228,8 +226,8 @@
    * @brief Once we've aggregated the decision and missing bits from all the workers, we can then
    * use them to partition the rows accordingly.
    */
-  void PartitionByMask(const size_t node_in_set,
-                       std::vector const& nodes,
+  template
+  void PartitionByMask(const size_t node_in_set, std::vector const& nodes,
                        const common::Range1d range, GHistIndexMatrix const& gmat,
                        const common::ColumnMatrix& column_matrix, const RegTree& tree,
                        const size_t* rid, BitVector const& decision_bits,
@@ -293,11 +291,11 @@
   }
 
-  size_t GetNLeftElems(int nid) const {
+  [[nodiscard]] std::size_t GetNLeftElems(int nid) const {
     return left_right_nodes_sizes_[nid].first;
   }
 
-  size_t GetNRightElems(int nid) const {
+  [[nodiscard]] std::size_t GetNRightElems(int nid) const {
     return left_right_nodes_sizes_[nid].second;
   }
@@ -349,7 +347,7 @@
       if (node.node_id < 0) {
         return;
       }
-      CHECK(tree[node.node_id].IsLeaf());
+      CHECK(tree.IsLeaf(node.node_id));
       if (node.begin) {  // guard for empty node.
         size_t ptr_offset = node.end - p_begin;
         CHECK_LE(ptr_offset, row_set.Data()->size()) << node.node_id;
@@ -384,8 +382,5 @@
   std::vector> mem_blocks_;
   size_t max_n_tasks_ = 0;
 };
-
-}  // namespace common
-}  // namespace xgboost
-
+}  // namespace xgboost::common
 #endif  // XGBOOST_COMMON_PARTITION_BUILDER_H_
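The MaskRows / PartitionByMask pair above describes a two-phase protocol for column-wise data splits: every worker first records, per row, the split decision it can compute locally and whether the split feature is missing, the bit vectors are aggregated across workers, and only afterwards are rows moved. Below is a minimal standalone sketch of the second phase, not the XGBoost implementation: RowMasks, MaskedPartition and the convention that a set decision bit means "go right" are assumptions made purely for illustration (the real code works on BitVector spans inside PartitionBuilder).

#include <cstddef>
#include <utility>
#include <vector>

// Aggregated per-row masks: decision[i] is the direction chosen by the worker that owns the
// split feature (true = right in this sketch), missing[i] marks a missing feature value.
struct RowMasks {
  std::vector<bool> decision;
  std::vector<bool> missing;
};

// Route rows using only the aggregated masks: rows with a missing split feature follow the
// default direction, every other row follows the recorded decision bit.
inline std::pair<std::vector<std::size_t>, std::vector<std::size_t>> MaskedPartition(
    std::vector<std::size_t> const& rows, RowMasks const& masks, bool default_left) {
  std::vector<std::size_t> left, right;
  for (std::size_t ridx : rows) {
    bool go_right = masks.missing[ridx] ? !default_left : masks.decision[ridx];
    (go_right ? right : left).push_back(ridx);
  }
  return {left, right};
}

The point of the protocol is that workers only need to exchange the two bit vectors per node, not the raw feature values they hold.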
diff --git a/src/learner.cc b/src/learner.cc
index d91add70d..e1b5605ca 100644
--- a/src/learner.cc
+++ b/src/learner.cc
@@ -343,8 +343,8 @@ struct LearnerTrainParam : public XGBoostParameter {
         .add_enum("monolithic", MultiStrategy::kMonolithic)
         .set_default(MultiStrategy::kComposite)
         .describe(
-            "Strategy used for training multi-target models. `mono` means building one single tree "
-            "for all targets.");
+            "Strategy used for training multi-target models. `monolithic` means building one "
+            "single tree for all targets.");
   }
 };
diff --git a/src/tree/common_row_partitioner.h b/src/tree/common_row_partitioner.h
index a58dbb452..ba69d8921 100644
--- a/src/tree/common_row_partitioner.h
+++ b/src/tree/common_row_partitioner.h
@@ -1,22 +1,26 @@
-/*!
- * Copyright 2021-2022 XGBoost contributors
+/**
+ * Copyright 2021-2023 XGBoost contributors
  * \file common_row_partitioner.h
  * \brief Common partitioner logic for hist and approx methods.
  */
 #ifndef XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_
 #define XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_
 
+#include <algorithm>  // std::all_of
+#include <cstdint>    // std::uint32_t
 #include <limits>     // std::numeric_limits
 #include
 
 #include "../collective/communicator-inl.h"
+#include "../common/linalg_op.h"  // cbegin
 #include "../common/numeric.h"  // Iota
 #include "../common/partition_builder.h"
 #include "hist/expand_entry.h"  // CPUExpandEntry
+#include "xgboost/base.h"
 #include "xgboost/context.h"  // Context
+#include "xgboost/linalg.h"  // TensorView
 
-namespace xgboost {
-namespace tree {
+namespace xgboost::tree {
 
 static constexpr size_t kPartitionBlockSize = 2048;
@@ -34,9 +38,10 @@ class ColumnSplitHelper {
     missing_bits_ = BitVector(common::Span(missing_storage_));
   }
 
+  template
   void Partition(common::BlockedSpace2d const& space, std::int32_t n_threads,
                  GHistIndexMatrix const& gmat, common::ColumnMatrix const& column_matrix,
-                 std::vector const& nodes, RegTree const* p_tree) {
+                 std::vector const& nodes, RegTree const* p_tree) {
     // When data is split by column, we don't have all the feature values in the local worker, so
     // we first collect all the decisions and whether the feature is missing into bit vectors.
     std::fill(decision_storage_.begin(), decision_storage_.end(), 0);
@@ -97,17 +102,18 @@ class CommonRowPartitioner {
     }
   }
 
-  void FindSplitConditions(const std::vector& nodes, const RegTree& tree,
+  template
+  void FindSplitConditions(const std::vector& nodes, const RegTree& tree,
                            const GHistIndexMatrix& gmat, std::vector* split_conditions) {
     auto const& ptrs = gmat.cut.Ptrs();
     auto const& vals = gmat.cut.Values();
     for (std::size_t i = 0; i < nodes.size(); ++i) {
-      bst_node_t const nid = nodes[i].nid;
-      bst_feature_t const fid = tree[nid].SplitIndex();
-      const float split_pt = tree[nid].SplitCond();
-      const uint32_t lower_bound = ptrs[fid];
-      const uint32_t upper_bound = ptrs[fid + 1];
+      bst_node_t const nidx = nodes[i].nid;
+      bst_feature_t const fidx = tree.SplitIndex(nidx);
+      float const split_pt = tree.SplitCond(nidx);
+      std::uint32_t const lower_bound = ptrs[fidx];
+      std::uint32_t const upper_bound = ptrs[fidx + 1];
       bst_bin_t split_cond = -1;
       // convert floating-point split_pt into corresponding bin_id
       // split_cond = -1 indicates that split_pt is less than all known cut points
@@ -121,20 +127,22 @@
     }
   }
 
-  void AddSplitsToRowSet(const std::vector& nodes, RegTree const* p_tree) {
+  template
+  void AddSplitsToRowSet(const std::vector& nodes, RegTree const* p_tree) {
     const size_t n_nodes = nodes.size();
     for (unsigned int i = 0; i < n_nodes; ++i) {
-      const int32_t nid = nodes[i].nid;
+      const int32_t nidx = nodes[i].nid;
       const size_t n_left = partition_builder_.GetNLeftElems(i);
      const size_t n_right = partition_builder_.GetNRightElems(i);
-      CHECK_EQ((*p_tree)[nid].LeftChild() + 1, (*p_tree)[nid].RightChild());
-      row_set_collection_.AddSplit(nid, (*p_tree)[nid].LeftChild(), (*p_tree)[nid].RightChild(),
-                                   n_left, n_right);
+      CHECK_EQ(p_tree->LeftChild(nidx) + 1, p_tree->RightChild(nidx));
+      row_set_collection_.AddSplit(nidx, p_tree->LeftChild(nidx), p_tree->RightChild(nidx), n_left,
+                                   n_right);
     }
   }
 
+  template
   void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat,
-                      std::vector const& nodes, RegTree const* p_tree) {
+                      std::vector const& nodes, RegTree const* p_tree) {
     auto const& column_matrix = gmat.Transpose();
     if (column_matrix.IsInitialized()) {
       if (gmat.cut.HasCategorical()) {
@@ -152,10 +160,10 @@
     }
   }
 
-  template
+  template
   void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat,
                       const common::ColumnMatrix& column_matrix,
-                      std::vector const& nodes, RegTree const* p_tree) {
+                      std::vector const& nodes, RegTree const* p_tree) {
     if (column_matrix.AnyMissing()) {
       this->template UpdatePosition(ctx, gmat, column_matrix, nodes, p_tree);
     } else {
@@ -163,33 +171,21 @@
     }
   }
 
-  template
+  template
   void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat,
                       const common::ColumnMatrix& column_matrix,
-                      std::vector const& nodes, RegTree const* p_tree) {
-    switch (column_matrix.GetTypeSize()) {
-      case common::kUint8BinsTypeSize:
-        this->template UpdatePosition(ctx, gmat, column_matrix,
-                                      nodes, p_tree);
-        break;
-      case common::kUint16BinsTypeSize:
-        this->template UpdatePosition(ctx, gmat, column_matrix,
-                                      nodes, p_tree);
-        break;
-      case common::kUint32BinsTypeSize:
-        this->template UpdatePosition(ctx, gmat, column_matrix,
-                                      nodes, p_tree);
-        break;
-      default:
-        // no default behavior
-        CHECK(false) << column_matrix.GetTypeSize();
-    }
+                      std::vector const& nodes, RegTree const* p_tree) {
+    common::DispatchBinType(column_matrix.GetTypeSize(), [&](auto t) {
+      using T = decltype(t);
+      this->template UpdatePosition(ctx, gmat, column_matrix, nodes,
+                                    p_tree);
+    });
   }
 
-  template
+  template
   void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat,
                       const common::ColumnMatrix& column_matrix,
-                      std::vector const& nodes, RegTree const* p_tree) {
+                      std::vector const& nodes, RegTree const* p_tree) {
     // 1. Find split condition for each split
     size_t n_nodes = nodes.size();
@@ -251,9 +247,9 @@
     AddSplitsToRowSet(nodes, p_tree);
   }
 
-  auto const& Partitions() const { return row_set_collection_; }
+  [[nodiscard]] auto const& Partitions() const { return row_set_collection_; }
 
-  size_t Size() const {
+  [[nodiscard]] std::size_t Size() const {
     return std::distance(row_set_collection_.begin(), row_set_collection_.end());
   }
@@ -266,12 +262,29 @@
                                      [&](size_t idx) -> bool { return hess[idx] - .0f == .0f; });
   }
 
+  void LeafPartition(Context const* ctx, RegTree const& tree,
+                     linalg::TensorView gpair,
+                     std::vector* p_out_position) const {
+    if (gpair.Shape(1) > 1) {
+      partition_builder_.LeafPartition(
+          ctx, tree, this->Partitions(), p_out_position, [&](std::size_t idx) -> bool {
+            auto sample = gpair.Slice(idx, linalg::All());
+            return std::all_of(linalg::cbegin(sample), linalg::cend(sample),
+                               [](GradientPair const& g) { return g.GetHess() - .0f == .0f; });
+          });
+    } else {
+      auto s = gpair.Slice(linalg::All(), 0);
+      partition_builder_.LeafPartition(
+          ctx, tree, this->Partitions(), p_out_position,
+          [&](std::size_t idx) -> bool { return s(idx).GetHess() - .0f == .0f; });
+    }
+  }
   void LeafPartition(Context const* ctx, RegTree const& tree, common::Span gpair,
                      std::vector* p_out_position) const {
     partition_builder_.LeafPartition(
         ctx, tree, this->Partitions(), p_out_position,
-        [&](size_t idx) -> bool { return gpair[idx].GetHess() - .0f == .0f; });
+        [&](std::size_t idx) -> bool { return gpair[idx].GetHess() - .0f == .0f; });
   }
 
  private:
@@ -281,6 +294,5 @@
   ColumnSplitHelper column_split_helper_;
 };
 
-}  // namespace tree
-}  // namespace xgboost
+}  // namespace xgboost::tree
 #endif  // XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_
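FindSplitConditions above has to map the tree's floating-point split value back to a histogram bin index before UpdatePosition can route rows by bin id; per the comment, -1 means the split value lies below every known cut point. A minimal sketch of that conversion, assuming (as the hist method does) that the split value was taken verbatim from the feature's cut values; SearchBin is a hypothetical helper, not the XGBoost API:

#include <cstdint>
#include <vector>

// Scan the feature's cut-value range [lower, upper) for the bin whose cut value equals the
// split value; return -1 when no bin matches, i.e. the split value is below all cut points.
inline std::int32_t SearchBin(std::vector<float> const& cut_values, std::uint32_t lower,
                              std::uint32_t upper, float split_pt) {
  std::int32_t split_cond = -1;
  for (std::uint32_t bin = lower; bin < upper; ++bin) {
    if (cut_values[bin] == split_pt) {
      split_cond = static_cast<std::int32_t>(bin);
    }
  }
  return split_cond;
}

With the quantile cuts, the call would look like SearchBin(vals, ptrs[fidx], ptrs[fidx + 1], split_pt), mirroring the lower_bound / upper_bound variables in the hunk above.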
diff --git a/tests/cpp/common/test_partition_builder.cc b/tests/cpp/common/test_partition_builder.cc
index 4e6d800a7..08dd345f2 100644
--- a/tests/cpp/common/test_partition_builder.cc
+++ b/tests/cpp/common/test_partition_builder.cc
@@ -1,15 +1,17 @@
+/**
+ * Copyright 2020-2023 by XGBoost contributors
+ */
 #include
-#include
+
 #include
 #include
+#include
 
-#include "../../../src/common/row_set.h"
 #include "../../../src/common/partition_builder.h"
+#include "../../../src/common/row_set.h"
 #include "../helpers.h"
 
-namespace xgboost {
-namespace common {
-
+namespace xgboost::common {
 TEST(PartitionBuilder, BasicTest) {
   constexpr size_t kBlockSize = 16;
   constexpr size_t kNodes = 5;
@@ -74,6 +76,4 @@ TEST(PartitionBuilder, BasicTest) {
     ASSERT_EQ(n_right, (kBlockSize - rows_for_left_node[nid]) * tasks[nid]);
   }
 }
-
-}  // namespace common
-}  // namespace xgboost
+}  // namespace xgboost::common
diff --git a/tests/cpp/tree/test_approx.cc b/tests/cpp/tree/test_approx.cc
index 308ae0823..6f2b83511 100644
--- a/tests/cpp/tree/test_approx.cc
+++ b/tests/cpp/tree/test_approx.cc
@@ -148,78 +148,5 @@ TEST(Approx, PartitionerColSplit) {
   RunWithInMemoryCommunicator(kWorkers, TestColumnSplitPartitioner, n_samples, base_rowid, Xy,
                               &hess, min_value, mid_value, mid_partitioner);
 }
-
-namespace {
-void TestLeafPartition(size_t n_samples) {
-  size_t const n_features = 2, base_rowid = 0;
-  Context ctx;
-
-  common::RowSetCollection row_set;
-  CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid, false};
-
-  auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
-  std::vector candidates{{0, 0}};
-  candidates.front().split.loss_chg = 0.4;
-  RegTree tree;
-  std::vector hess(n_samples, 0);
-  // emulate sampling
-  auto not_sampled = [](size_t i) {
-    size_t const kSampleFactor{3};
-    return i % kSampleFactor != 0;
-  };
-  for (size_t i = 0; i < hess.size(); ++i) {
-    if (not_sampled(i)) {
-      hess[i] = 1.0f;
-    }
-  }
-
-  std::vector h_nptr;
-  float split_value{0};
-  for (auto const& page : Xy->GetBatches({Context::kCpuId, 64})) {
-    bst_feature_t const split_ind = 0;
-    auto ptr = page.cut.Ptrs()[split_ind + 1];
-    split_value = page.cut.Values().at(ptr / 2);
-    GetSplit(&tree, split_value, &candidates);
-    partitioner.UpdatePosition(&ctx, page, candidates, &tree);
-    std::vector position;
-    partitioner.LeafPartition(&ctx, tree, hess, &position);
-    std::sort(position.begin(), position.end());
-    size_t beg = std::distance(
-        position.begin(),
-        std::find_if(position.begin(), position.end(), [&](bst_node_t nidx) { return nidx >= 0; }));
-    std::vector nptr;
-    common::RunLengthEncode(position.cbegin() + beg, position.cend(), &nptr);
-    std::transform(nptr.begin(), nptr.end(), nptr.begin(), [&](size_t x) { return x + beg; });
-    auto n_uniques = std::unique(position.begin() + beg, position.end()) - (position.begin() + beg);
-    ASSERT_EQ(nptr.size(), n_uniques + 1);
-    ASSERT_EQ(nptr[0], beg);
-    ASSERT_EQ(nptr.back(), n_samples);
-
-    h_nptr = nptr;
-  }
-
-  if (h_nptr.front() == n_samples) {
-    return;
-  }
-
-  ASSERT_GE(h_nptr.size(), 2);
-
-  for (auto const& page : Xy->GetBatches()) {
-    auto batch = page.GetView();
-    size_t left{0};
-    for (size_t i = 0; i < batch.Size(); ++i) {
-      if (not_sampled(i) && batch[i].front().fvalue < split_value) {
-        left++;
-      }
-    }
-    ASSERT_EQ(left, h_nptr[1] - h_nptr[0]);  // equal to number of sampled assigned to left
-  }
-}
-}  // anonymous namespace
-
-TEST(Approx, LeafPartition) {
-  for (auto n_samples : {0ul, 1ul, 128ul, 256ul}) {
-    TestLeafPartition(n_samples);
-  }
-}
 }  // namespace tree
 }  // namespace xgboost
diff --git a/tests/cpp/tree/test_common_partitioner.cc b/tests/cpp/tree/test_common_partitioner.cc
new file mode 100644
index 000000000..7e47ec289
--- /dev/null
+++ b/tests/cpp/tree/test_common_partitioner.cc
@@ -0,0 +1,93 @@
+/**
+ * Copyright 2022-2023 by XGBoost contributors.
+ */
+#include
+#include <xgboost/base.h>     // for bst_node_t
+#include <xgboost/context.h>  // for Context
+
+#include <algorithm>  // for transform
+#include <iterator>   // for distance
+#include <vector>     // for vector
+
+#include "../../../src/common/numeric.h"       // for RunLengthEncode
+#include "../../../src/common/row_set.h"       // for RowSetCollection
+#include "../../../src/data/gradient_index.h"  // for GHistIndexMatrix
+#include "../../../src/tree/common_row_partitioner.h"
+#include "../../../src/tree/hist/expand_entry.h"  // for CPUExpandEntry
+#include "../helpers.h"        // for RandomDataGenerator
+#include "test_partitioner.h"  // for GetSplit
+
+namespace xgboost::tree {
+namespace {
+void TestLeafPartition(size_t n_samples) {
+  size_t const n_features = 2, base_rowid = 0;
+  Context ctx;
+  common::RowSetCollection row_set;
+  CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid, false};
+
+  auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
+  std::vector candidates{{0, 0}};
+  candidates.front().split.loss_chg = 0.4;
+  RegTree tree;
+  std::vector hess(n_samples, 0);
+  // emulate sampling
+  auto not_sampled = [](size_t i) {
+    size_t const kSampleFactor{3};
+    return i % kSampleFactor != 0;
+  };
+  for (size_t i = 0; i < hess.size(); ++i) {
+    if (not_sampled(i)) {
+      hess[i] = 1.0f;
+    }
+  }
+
+  std::vector h_nptr;
+  float split_value{0};
+  for (auto const& page : Xy->GetBatches({Context::kCpuId, 64})) {
+    bst_feature_t const split_ind = 0;
+    auto ptr = page.cut.Ptrs()[split_ind + 1];
+    split_value = page.cut.Values().at(ptr / 2);
+    GetSplit(&tree, split_value, &candidates);
+    partitioner.UpdatePosition(&ctx, page, candidates, &tree);
+    std::vector position;
+    partitioner.LeafPartition(&ctx, tree, hess, &position);
+    std::sort(position.begin(), position.end());
+    size_t beg = std::distance(
+        position.begin(),
+        std::find_if(position.begin(), position.end(), [&](bst_node_t nidx) { return nidx >= 0; }));
+    std::vector nptr;
+    common::RunLengthEncode(position.cbegin() + beg, position.cend(), &nptr);
+    std::transform(nptr.begin(), nptr.end(), nptr.begin(), [&](size_t x) { return x + beg; });
+    auto n_uniques = std::unique(position.begin() + beg, position.end()) - (position.begin() + beg);
+    ASSERT_EQ(nptr.size(), n_uniques + 1);
+    ASSERT_EQ(nptr[0], beg);
+    ASSERT_EQ(nptr.back(), n_samples);
+
+    h_nptr = nptr;
+  }
+
+  if (h_nptr.front() == n_samples) {
+    return;
+  }
+
+  ASSERT_GE(h_nptr.size(), 2);
+
+  for (auto const& page : Xy->GetBatches()) {
+    auto batch = page.GetView();
+    size_t left{0};
+    for (size_t i = 0; i < batch.Size(); ++i) {
+      if (not_sampled(i) && batch[i].front().fvalue < split_value) {
+        left++;
+      }
+    }
+    ASSERT_EQ(left, h_nptr[1] - h_nptr[0]);  // equal to number of sampled assigned to left
+  }
+}
+}  // anonymous namespace
+
+TEST(CommonRowPartitioner, LeafPartition) {
+  for (auto n_samples : {0ul, 1ul, 128ul, 256ul}) {
+    TestLeafPartition(n_samples);
+  }
+}
+}  // namespace xgboost::tree
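The moved test sorts the per-row leaf positions and then uses common::RunLengthEncode to recover one segment per leaf, asserting that the segment sizes match the rows routed to each child. The helper below only illustrates that run-length idea on a plain vector; RunBoundaries is a made-up name and this is not the common::RunLengthEncode used in the test.

#include <cstddef>
#include <vector>

// For positions sorted by node id, return the offset where each run of equal ids starts plus a
// final end offset, e.g. {3, 3, 4, 4, 4, 5} -> {0, 2, 5, 6}.
inline std::vector<std::size_t> RunBoundaries(std::vector<int> const& sorted_positions) {
  std::vector<std::size_t> boundaries;
  for (std::size_t i = 0; i < sorted_positions.size(); ++i) {
    if (i == 0 || sorted_positions[i] != sorted_positions[i - 1]) {
      boundaries.push_back(i);
    }
  }
  boundaries.push_back(sorted_positions.size());
  return boundaries;
}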
diff --git a/tests/cpp/tree/test_partitioner.h b/tests/cpp/tree/test_partitioner.h
index 093aa69eb..fbd98ddf9 100644
--- a/tests/cpp/tree/test_partitioner.h
+++ b/tests/cpp/tree/test_partitioner.h
@@ -1,17 +1,20 @@
-/*!
- * Copyright 2021-2022, XGBoost contributors.
+/**
+ * Copyright 2021-2023 by XGBoost contributors.
  */
 #ifndef XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_
 #define XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_
-#include
+#include <xgboost/context.h>     // for Context
+#include <xgboost/linalg.h>      // for Constant, Vector
+#include <xgboost/logging.h>     // for CHECK
+#include <xgboost/tree_model.h>  // for RegTree
 
-#include
+#include <vector>  // for vector
 
-#include "../../../src/tree/hist/expand_entry.h"
+#include "../../../src/tree/hist/expand_entry.h"  // for CPUExpandEntry, MultiExpandEntry
 
-namespace xgboost {
-namespace tree {
+namespace xgboost::tree {
 inline void GetSplit(RegTree *tree, float split_value, std::vector *candidates) {
+  CHECK(!tree->IsMultiTarget());
   tree->ExpandNode(
       /*nid=*/RegTree::kRoot, /*split_index=*/0, /*split_value=*/split_value,
       /*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
@@ -21,6 +24,22 @@ inline void GetSplit(RegTree *tree, float split_value, std::vector
   candidates->front().split.sindex = 0;
   candidates->front().split.sindex |= (1U << 31);
 }
-}  // namespace tree
-}  // namespace xgboost
+
+inline void GetMultiSplitForTest(RegTree *tree, float split_value,
+                                 std::vector *candidates) {
+  CHECK(tree->IsMultiTarget());
+  auto n_targets = tree->NumTargets();
+  Context ctx;
+  linalg::Vector base_weight{linalg::Constant(&ctx, 0.0f, n_targets)};
+  linalg::Vector left_weight{linalg::Constant(&ctx, 0.0f, n_targets)};
+  linalg::Vector right_weight{linalg::Constant(&ctx, 0.0f, n_targets)};
+
+  tree->ExpandNode(/*nidx=*/RegTree::kRoot, /*split_index=*/0, /*split_value=*/split_value,
+                   /*default_left=*/true, base_weight.HostView(), left_weight.HostView(),
+                   right_weight.HostView());
+  candidates->front().split.split_value = split_value;
+  candidates->front().split.sindex = 0;
+  candidates->front().split.sindex |= (1U << 31);
+}
+}  // namespace xgboost::tree
 #endif  // XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_
diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc
index 42edc2124..2aa1b8f47 100644
--- a/tests/cpp/tree/test_quantile_hist.cc
+++ b/tests/cpp/tree/test_quantile_hist.cc
@@ -1,25 +1,29 @@
-/*!
- * Copyright 2018-2022 by XGBoost Contributors
+/**
+ * Copyright 2018-2023 by XGBoost Contributors
  */
 #include
 #include
 #include
 #include
+#include <cstddef>  // for size_t
 #include
 #include
+#include "../../../src/tree/common_row_partitioner.h"
+#include "../../../src/tree/hist/expand_entry.h"  // for MultiExpandEntry, CPUExpandEntry
 #include "../../../src/tree/param.h"
 #include "../../../src/tree/split_evaluator.h"
-#include "../../../src/tree/common_row_partitioner.h"
 #include "../helpers.h"
 #include "test_partitioner.h"
 #include "xgboost/data.h"
 
-namespace xgboost {
-namespace tree {
-TEST(QuantileHist, Partitioner) {
-  size_t n_samples = 1024, n_features = 1, base_rowid = 0;
+namespace xgboost::tree {
+template
+void TestPartitioner(bst_target_t n_targets) {
+  std::size_t n_samples = 1024, base_rowid = 0;
+  bst_feature_t n_features = 1;
+
   Context ctx;
   ctx.InitAllowUnknown(Args{});
@@ -29,7 +33,7 @@
   ASSERT_EQ(partitioner.Partitions()[0].Size(), n_samples);
 
   auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
-  std::vector candidates{{0, 0}};
+  std::vector candidates{{0, 0}};
   candidates.front().split.loss_chg = 0.4;
 
   auto cuts = common::SketchOnDMatrix(Xy.get(), 64, ctx.Threads());
@@ -41,9 +45,13 @@
     column_indices.InitFromSparse(page, gmat, 0.5, ctx.Threads());
     {
      auto min_value = gmat.cut.MinValues()[split_ind];
-      RegTree tree;
+      RegTree tree{n_targets, n_features};
       CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid, false};
-      GetSplit(&tree, min_value, &candidates);
+      if constexpr (std::is_same::value) {
+        GetSplit(&tree, min_value, &candidates);
+      } else {
+        GetMultiSplitForTest(&tree, min_value, &candidates);
+      }
       partitioner.UpdatePosition(&ctx, gmat, column_indices, candidates, &tree);
       ASSERT_EQ(partitioner.Size(), 3);
       ASSERT_EQ(partitioner[1].Size(), 0);
@@ -53,9 +61,13 @@
       CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid, false};
       auto ptr = gmat.cut.Ptrs()[split_ind + 1];
       float split_value = gmat.cut.Values().at(ptr / 2);
-      RegTree tree;
-      GetSplit(&tree, split_value, &candidates);
-      auto left_nidx = tree[RegTree::kRoot].LeftChild();
+      RegTree tree{n_targets, n_features};
+      if constexpr (std::is_same::value) {
+        GetSplit(&tree, split_value, &candidates);
+      } else {
+        GetMultiSplitForTest(&tree, split_value, &candidates);
+      }
+      auto left_nidx = tree.LeftChild(RegTree::kRoot);
       partitioner.UpdatePosition(&ctx, gmat, column_indices, candidates, &tree);
 
       auto elem = partitioner[left_nidx];
@@ -65,14 +77,17 @@
         auto value = gmat.cut.Values().at(gmat.index[*it]);
         ASSERT_LE(value, split_value);
       }
-      auto right_nidx = tree[RegTree::kRoot].RightChild();
+      auto right_nidx = tree.RightChild(RegTree::kRoot);
       elem = partitioner[right_nidx];
       for (auto it = elem.begin; it != elem.end; ++it) {
         auto value = gmat.cut.Values().at(gmat.index[*it]);
-        ASSERT_GT(value, split_value) << *it;
+        ASSERT_GT(value, split_value);
       }
     }
   }
 }
-}  // namespace tree
-}  // namespace xgboost
+
+TEST(QuantileHist, Partitioner) { TestPartitioner(1); }
+
+TEST(QuantileHist, MultiPartitioner) { TestPartitioner(3); }
+}  // namespace xgboost::tree
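The new multi-target LeafPartition overload added in common_row_partitioner.h treats a row as unsampled only when its hessian is zero for every target. Below is a standalone sketch of that rule, with a flat row-major vector standing in for the linalg::TensorView and a hypothetical NotSampled helper; it is an illustration under those assumptions, not the XGBoost code.

#include <algorithm>
#include <cstddef>
#include <vector>

// Row-major (n_samples x n_targets) hessian buffer: a sample counts as "not sampled" only when
// the hessian is zero for every one of its targets, matching the std::all_of check in the diff.
inline bool NotSampled(std::vector<float> const& hess, std::size_t n_targets, std::size_t ridx) {
  auto begin = hess.begin() + ridx * n_targets;
  return std::all_of(begin, begin + n_targets, [](float h) { return h == 0.0f; });
}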