Partitioner for multi-target tree. (#8922)

This commit is contained in:
Jiaming Yuan 2023-03-16 18:49:34 +08:00 committed by GitHub
parent 26209a42a5
commit a093770f36
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 239 additions and 178 deletions

View File

@ -1,5 +1,5 @@
/*! /**
* Copyright 2021-2022 by Contributors * Copyright 2021-2023 by Contributors
* \file row_set.h * \file row_set.h
* \brief Quick Utility to compute subset of rows * \brief Quick Utility to compute subset of rows
* \author Philip Cho, Tianqi Chen * \author Philip Cho, Tianqi Chen
@ -10,6 +10,7 @@
#include <xgboost/data.h> #include <xgboost/data.h>
#include <algorithm> #include <algorithm>
#include <cstddef> // for size_t
#include <limits> #include <limits>
#include <memory> #include <memory>
#include <utility> #include <utility>
@ -21,9 +22,7 @@
#include "xgboost/context.h" #include "xgboost/context.h"
#include "xgboost/tree_model.h" #include "xgboost/tree_model.h"
namespace xgboost { namespace xgboost::common {
namespace common {
// The builder is required for samples partition to left and rights children for set of nodes // The builder is required for samples partition to left and rights children for set of nodes
// Responsible for: // Responsible for:
// 1) Effective memory allocation for intermediate results for multi-thread work // 1) Effective memory allocation for intermediate results for multi-thread work
@ -109,18 +108,17 @@ class PartitionBuilder {
return {nleft_elems, nright_elems}; return {nleft_elems, nright_elems};
} }
template <typename BinIdxType, bool any_missing, bool any_cat> template <typename BinIdxType, bool any_missing, bool any_cat, typename ExpandEntry>
void Partition(const size_t node_in_set, std::vector<xgboost::tree::CPUExpandEntry> const &nodes, void Partition(const size_t node_in_set, std::vector<ExpandEntry> const& nodes,
const common::Range1d range, const common::Range1d range, const bst_bin_t split_cond,
const bst_bin_t split_cond, GHistIndexMatrix const& gmat, GHistIndexMatrix const& gmat, const common::ColumnMatrix& column_matrix,
const common::ColumnMatrix& column_matrix,
const RegTree& tree, const size_t* rid) { const RegTree& tree, const size_t* rid) {
common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end()); common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end());
common::Span<size_t> left = GetLeftBuffer(node_in_set, range.begin(), range.end()); common::Span<size_t> left = GetLeftBuffer(node_in_set, range.begin(), range.end());
common::Span<size_t> right = GetRightBuffer(node_in_set, range.begin(), range.end()); common::Span<size_t> right = GetRightBuffer(node_in_set, range.begin(), range.end());
std::size_t nid = nodes[node_in_set].nid; std::size_t nid = nodes[node_in_set].nid;
bst_feature_t fid = tree[nid].SplitIndex(); bst_feature_t fid = tree.SplitIndex(nid);
bool default_left = tree[nid].DefaultLeft(); bool default_left = tree.DefaultLeft(nid);
bool is_cat = tree.GetSplitTypes()[nid] == FeatureType::kCategorical; bool is_cat = tree.GetSplitTypes()[nid] == FeatureType::kCategorical;
auto node_cats = tree.NodeCats(nid); auto node_cats = tree.NodeCats(nid);
auto const& cut_values = gmat.cut.Values(); auto const& cut_values = gmat.cut.Values();
@ -190,10 +188,10 @@ class PartitionBuilder {
* worker, so we go through all the rows and mark the bit vectors on whether the decision is made * worker, so we go through all the rows and mark the bit vectors on whether the decision is made
* to go right, or if the feature value used for the split is missing. * to go right, or if the feature value used for the split is missing.
*/ */
void MaskRows(const size_t node_in_set, std::vector<xgboost::tree::CPUExpandEntry> const &nodes, template <typename ExpandEntry>
void MaskRows(const size_t node_in_set, std::vector<ExpandEntry> const& nodes,
const common::Range1d range, GHistIndexMatrix const& gmat, const common::Range1d range, GHistIndexMatrix const& gmat,
const common::ColumnMatrix& column_matrix, const common::ColumnMatrix& column_matrix, const RegTree& tree, const size_t* rid,
const RegTree& tree, const size_t* rid,
BitVector* decision_bits, BitVector* missing_bits) { BitVector* decision_bits, BitVector* missing_bits) {
common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end()); common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end());
std::size_t nid = nodes[node_in_set].nid; std::size_t nid = nodes[node_in_set].nid;
@ -228,8 +226,8 @@ class PartitionBuilder {
* @brief Once we've aggregated the decision and missing bits from all the workers, we can then * @brief Once we've aggregated the decision and missing bits from all the workers, we can then
* use them to partition the rows accordingly. * use them to partition the rows accordingly.
*/ */
void PartitionByMask(const size_t node_in_set, template <typename ExpandEntry>
std::vector<xgboost::tree::CPUExpandEntry> const& nodes, void PartitionByMask(const size_t node_in_set, std::vector<ExpandEntry> const& nodes,
const common::Range1d range, GHistIndexMatrix const& gmat, const common::Range1d range, GHistIndexMatrix const& gmat,
const common::ColumnMatrix& column_matrix, const RegTree& tree, const common::ColumnMatrix& column_matrix, const RegTree& tree,
const size_t* rid, BitVector const& decision_bits, const size_t* rid, BitVector const& decision_bits,
@ -293,11 +291,11 @@ class PartitionBuilder {
} }
size_t GetNLeftElems(int nid) const { [[nodiscard]] std::size_t GetNLeftElems(int nid) const {
return left_right_nodes_sizes_[nid].first; return left_right_nodes_sizes_[nid].first;
} }
size_t GetNRightElems(int nid) const { [[nodiscard]] std::size_t GetNRightElems(int nid) const {
return left_right_nodes_sizes_[nid].second; return left_right_nodes_sizes_[nid].second;
} }
@ -349,7 +347,7 @@ class PartitionBuilder {
if (node.node_id < 0) { if (node.node_id < 0) {
return; return;
} }
CHECK(tree[node.node_id].IsLeaf()); CHECK(tree.IsLeaf(node.node_id));
if (node.begin) { // guard for empty node. if (node.begin) { // guard for empty node.
size_t ptr_offset = node.end - p_begin; size_t ptr_offset = node.end - p_begin;
CHECK_LE(ptr_offset, row_set.Data()->size()) << node.node_id; CHECK_LE(ptr_offset, row_set.Data()->size()) << node.node_id;
@ -384,8 +382,5 @@ class PartitionBuilder {
std::vector<std::shared_ptr<BlockInfo>> mem_blocks_; std::vector<std::shared_ptr<BlockInfo>> mem_blocks_;
size_t max_n_tasks_ = 0; size_t max_n_tasks_ = 0;
}; };
} // namespace xgboost::common
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_PARTITION_BUILDER_H_ #endif // XGBOOST_COMMON_PARTITION_BUILDER_H_

View File

@ -343,8 +343,8 @@ struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
.add_enum("monolithic", MultiStrategy::kMonolithic) .add_enum("monolithic", MultiStrategy::kMonolithic)
.set_default(MultiStrategy::kComposite) .set_default(MultiStrategy::kComposite)
.describe( .describe(
"Strategy used for training multi-target models. `mono` means building one single tree " "Strategy used for training multi-target models. `monolithic` means building one "
"for all targets."); "single tree for all targets.");
} }
}; };

View File

@ -1,22 +1,26 @@
/*! /**
* Copyright 2021-2022 XGBoost contributors * Copyright 2021-2023 XGBoost contributors
* \file common_row_partitioner.h * \file common_row_partitioner.h
* \brief Common partitioner logic for hist and approx methods. * \brief Common partitioner logic for hist and approx methods.
*/ */
#ifndef XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_ #ifndef XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_
#define XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_ #define XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_
#include <algorithm> // std::all_of
#include <cinttypes> // std::uint32_t
#include <limits> // std::numeric_limits #include <limits> // std::numeric_limits
#include <vector> #include <vector>
#include "../collective/communicator-inl.h" #include "../collective/communicator-inl.h"
#include "../common/linalg_op.h" // cbegin
#include "../common/numeric.h" // Iota #include "../common/numeric.h" // Iota
#include "../common/partition_builder.h" #include "../common/partition_builder.h"
#include "hist/expand_entry.h" // CPUExpandEntry #include "hist/expand_entry.h" // CPUExpandEntry
#include "xgboost/base.h"
#include "xgboost/context.h" // Context #include "xgboost/context.h" // Context
#include "xgboost/linalg.h" // TensorView
namespace xgboost { namespace xgboost::tree {
namespace tree {
static constexpr size_t kPartitionBlockSize = 2048; static constexpr size_t kPartitionBlockSize = 2048;
@ -34,9 +38,10 @@ class ColumnSplitHelper {
missing_bits_ = BitVector(common::Span<BitVector::value_type>(missing_storage_)); missing_bits_ = BitVector(common::Span<BitVector::value_type>(missing_storage_));
} }
template <typename ExpandEntry>
void Partition(common::BlockedSpace2d const& space, std::int32_t n_threads, void Partition(common::BlockedSpace2d const& space, std::int32_t n_threads,
GHistIndexMatrix const& gmat, common::ColumnMatrix const& column_matrix, GHistIndexMatrix const& gmat, common::ColumnMatrix const& column_matrix,
std::vector<CPUExpandEntry> const& nodes, RegTree const* p_tree) { std::vector<ExpandEntry> const& nodes, RegTree const* p_tree) {
// When data is split by column, we don't have all the feature values in the local worker, so // When data is split by column, we don't have all the feature values in the local worker, so
// we first collect all the decisions and whether the feature is missing into bit vectors. // we first collect all the decisions and whether the feature is missing into bit vectors.
std::fill(decision_storage_.begin(), decision_storage_.end(), 0); std::fill(decision_storage_.begin(), decision_storage_.end(), 0);
@ -97,17 +102,18 @@ class CommonRowPartitioner {
} }
} }
void FindSplitConditions(const std::vector<CPUExpandEntry>& nodes, const RegTree& tree, template <typename ExpandEntry>
void FindSplitConditions(const std::vector<ExpandEntry>& nodes, const RegTree& tree,
const GHistIndexMatrix& gmat, std::vector<int32_t>* split_conditions) { const GHistIndexMatrix& gmat, std::vector<int32_t>* split_conditions) {
auto const& ptrs = gmat.cut.Ptrs(); auto const& ptrs = gmat.cut.Ptrs();
auto const& vals = gmat.cut.Values(); auto const& vals = gmat.cut.Values();
for (std::size_t i = 0; i < nodes.size(); ++i) { for (std::size_t i = 0; i < nodes.size(); ++i) {
bst_node_t const nid = nodes[i].nid; bst_node_t const nidx = nodes[i].nid;
bst_feature_t const fid = tree[nid].SplitIndex(); bst_feature_t const fidx = tree.SplitIndex(nidx);
const float split_pt = tree[nid].SplitCond(); float const split_pt = tree.SplitCond(nidx);
const uint32_t lower_bound = ptrs[fid]; std::uint32_t const lower_bound = ptrs[fidx];
const uint32_t upper_bound = ptrs[fid + 1]; std::uint32_t const upper_bound = ptrs[fidx + 1];
bst_bin_t split_cond = -1; bst_bin_t split_cond = -1;
// convert floating-point split_pt into corresponding bin_id // convert floating-point split_pt into corresponding bin_id
// split_cond = -1 indicates that split_pt is less than all known cut points // split_cond = -1 indicates that split_pt is less than all known cut points
@ -121,20 +127,22 @@ class CommonRowPartitioner {
} }
} }
void AddSplitsToRowSet(const std::vector<CPUExpandEntry>& nodes, RegTree const* p_tree) { template <typename ExpandEntry>
void AddSplitsToRowSet(const std::vector<ExpandEntry>& nodes, RegTree const* p_tree) {
const size_t n_nodes = nodes.size(); const size_t n_nodes = nodes.size();
for (unsigned int i = 0; i < n_nodes; ++i) { for (unsigned int i = 0; i < n_nodes; ++i) {
const int32_t nid = nodes[i].nid; const int32_t nidx = nodes[i].nid;
const size_t n_left = partition_builder_.GetNLeftElems(i); const size_t n_left = partition_builder_.GetNLeftElems(i);
const size_t n_right = partition_builder_.GetNRightElems(i); const size_t n_right = partition_builder_.GetNRightElems(i);
CHECK_EQ((*p_tree)[nid].LeftChild() + 1, (*p_tree)[nid].RightChild()); CHECK_EQ(p_tree->LeftChild(nidx) + 1, p_tree->RightChild(nidx));
row_set_collection_.AddSplit(nid, (*p_tree)[nid].LeftChild(), (*p_tree)[nid].RightChild(), row_set_collection_.AddSplit(nidx, p_tree->LeftChild(nidx), p_tree->RightChild(nidx), n_left,
n_left, n_right); n_right);
} }
} }
template <typename ExpandEntry>
void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat, void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat,
std::vector<CPUExpandEntry> const& nodes, RegTree const* p_tree) { std::vector<ExpandEntry> const& nodes, RegTree const* p_tree) {
auto const& column_matrix = gmat.Transpose(); auto const& column_matrix = gmat.Transpose();
if (column_matrix.IsInitialized()) { if (column_matrix.IsInitialized()) {
if (gmat.cut.HasCategorical()) { if (gmat.cut.HasCategorical()) {
@ -152,10 +160,10 @@ class CommonRowPartitioner {
} }
} }
template <bool any_cat> template <bool any_cat, typename ExpandEntry>
void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat, void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat,
const common::ColumnMatrix& column_matrix, const common::ColumnMatrix& column_matrix,
std::vector<CPUExpandEntry> const& nodes, RegTree const* p_tree) { std::vector<ExpandEntry> const& nodes, RegTree const* p_tree) {
if (column_matrix.AnyMissing()) { if (column_matrix.AnyMissing()) {
this->template UpdatePosition<true, any_cat>(ctx, gmat, column_matrix, nodes, p_tree); this->template UpdatePosition<true, any_cat>(ctx, gmat, column_matrix, nodes, p_tree);
} else { } else {
@ -163,33 +171,21 @@ class CommonRowPartitioner {
} }
} }
template <bool any_missing, bool any_cat> template <bool any_missing, bool any_cat, typename ExpandEntry>
void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat, void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat,
const common::ColumnMatrix& column_matrix, const common::ColumnMatrix& column_matrix,
std::vector<CPUExpandEntry> const& nodes, RegTree const* p_tree) { std::vector<ExpandEntry> const& nodes, RegTree const* p_tree) {
switch (column_matrix.GetTypeSize()) { common::DispatchBinType(column_matrix.GetTypeSize(), [&](auto t) {
case common::kUint8BinsTypeSize: using T = decltype(t);
this->template UpdatePosition<uint8_t, any_missing, any_cat>(ctx, gmat, column_matrix, this->template UpdatePosition<T, any_missing, any_cat>(ctx, gmat, column_matrix, nodes,
nodes, p_tree); p_tree);
break; });
case common::kUint16BinsTypeSize:
this->template UpdatePosition<uint16_t, any_missing, any_cat>(ctx, gmat, column_matrix,
nodes, p_tree);
break;
case common::kUint32BinsTypeSize:
this->template UpdatePosition<uint32_t, any_missing, any_cat>(ctx, gmat, column_matrix,
nodes, p_tree);
break;
default:
// no default behavior
CHECK(false) << column_matrix.GetTypeSize();
}
} }
template <typename BinIdxType, bool any_missing, bool any_cat> template <typename BinIdxType, bool any_missing, bool any_cat, typename ExpandEntry>
void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat, void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat,
const common::ColumnMatrix& column_matrix, const common::ColumnMatrix& column_matrix,
std::vector<CPUExpandEntry> const& nodes, RegTree const* p_tree) { std::vector<ExpandEntry> const& nodes, RegTree const* p_tree) {
// 1. Find split condition for each split // 1. Find split condition for each split
size_t n_nodes = nodes.size(); size_t n_nodes = nodes.size();
@ -251,9 +247,9 @@ class CommonRowPartitioner {
AddSplitsToRowSet(nodes, p_tree); AddSplitsToRowSet(nodes, p_tree);
} }
auto const& Partitions() const { return row_set_collection_; } [[nodiscard]] auto const& Partitions() const { return row_set_collection_; }
size_t Size() const { [[nodiscard]] std::size_t Size() const {
return std::distance(row_set_collection_.begin(), row_set_collection_.end()); return std::distance(row_set_collection_.begin(), row_set_collection_.end());
} }
@ -266,12 +262,29 @@ class CommonRowPartitioner {
[&](size_t idx) -> bool { return hess[idx] - .0f == .0f; }); [&](size_t idx) -> bool { return hess[idx] - .0f == .0f; });
} }
void LeafPartition(Context const* ctx, RegTree const& tree,
linalg::TensorView<GradientPair const, 2> gpair,
std::vector<bst_node_t>* p_out_position) const {
if (gpair.Shape(1) > 1) {
partition_builder_.LeafPartition(
ctx, tree, this->Partitions(), p_out_position, [&](std::size_t idx) -> bool {
auto sample = gpair.Slice(idx, linalg::All());
return std::all_of(linalg::cbegin(sample), linalg::cend(sample),
[](GradientPair const& g) { return g.GetHess() - .0f == .0f; });
});
} else {
auto s = gpair.Slice(linalg::All(), 0);
partition_builder_.LeafPartition(
ctx, tree, this->Partitions(), p_out_position,
[&](std::size_t idx) -> bool { return s(idx).GetHess() - .0f == .0f; });
}
}
void LeafPartition(Context const* ctx, RegTree const& tree, void LeafPartition(Context const* ctx, RegTree const& tree,
common::Span<GradientPair const> gpair, common::Span<GradientPair const> gpair,
std::vector<bst_node_t>* p_out_position) const { std::vector<bst_node_t>* p_out_position) const {
partition_builder_.LeafPartition( partition_builder_.LeafPartition(
ctx, tree, this->Partitions(), p_out_position, ctx, tree, this->Partitions(), p_out_position,
[&](size_t idx) -> bool { return gpair[idx].GetHess() - .0f == .0f; }); [&](std::size_t idx) -> bool { return gpair[idx].GetHess() - .0f == .0f; });
} }
private: private:
@ -281,6 +294,5 @@ class CommonRowPartitioner {
ColumnSplitHelper column_split_helper_; ColumnSplitHelper column_split_helper_;
}; };
} // namespace tree } // namespace xgboost::tree
} // namespace xgboost
#endif // XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_ #endif // XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_

View File

@ -1,15 +1,17 @@
/**
* Copyright 2020-2023 by XGBoost contributors
*/
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <vector>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector>
#include "../../../src/common/row_set.h"
#include "../../../src/common/partition_builder.h" #include "../../../src/common/partition_builder.h"
#include "../../../src/common/row_set.h"
#include "../helpers.h" #include "../helpers.h"
namespace xgboost { namespace xgboost::common {
namespace common {
TEST(PartitionBuilder, BasicTest) { TEST(PartitionBuilder, BasicTest) {
constexpr size_t kBlockSize = 16; constexpr size_t kBlockSize = 16;
constexpr size_t kNodes = 5; constexpr size_t kNodes = 5;
@ -74,6 +76,4 @@ TEST(PartitionBuilder, BasicTest) {
ASSERT_EQ(n_right, (kBlockSize - rows_for_left_node[nid]) * tasks[nid]); ASSERT_EQ(n_right, (kBlockSize - rows_for_left_node[nid]) * tasks[nid]);
} }
} }
} // namespace xgboost::common
} // namespace common
} // namespace xgboost

View File

@ -148,78 +148,5 @@ TEST(Approx, PartitionerColSplit) {
RunWithInMemoryCommunicator(kWorkers, TestColumnSplitPartitioner, n_samples, base_rowid, Xy, RunWithInMemoryCommunicator(kWorkers, TestColumnSplitPartitioner, n_samples, base_rowid, Xy,
&hess, min_value, mid_value, mid_partitioner); &hess, min_value, mid_value, mid_partitioner);
} }
namespace {
void TestLeafPartition(size_t n_samples) {
size_t const n_features = 2, base_rowid = 0;
Context ctx;
common::RowSetCollection row_set;
CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid, false};
auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
std::vector<CPUExpandEntry> candidates{{0, 0}};
candidates.front().split.loss_chg = 0.4;
RegTree tree;
std::vector<float> hess(n_samples, 0);
// emulate sampling
auto not_sampled = [](size_t i) {
size_t const kSampleFactor{3};
return i % kSampleFactor != 0;
};
for (size_t i = 0; i < hess.size(); ++i) {
if (not_sampled(i)) {
hess[i] = 1.0f;
}
}
std::vector<size_t> h_nptr;
float split_value{0};
for (auto const& page : Xy->GetBatches<GHistIndexMatrix>({Context::kCpuId, 64})) {
bst_feature_t const split_ind = 0;
auto ptr = page.cut.Ptrs()[split_ind + 1];
split_value = page.cut.Values().at(ptr / 2);
GetSplit(&tree, split_value, &candidates);
partitioner.UpdatePosition(&ctx, page, candidates, &tree);
std::vector<bst_node_t> position;
partitioner.LeafPartition(&ctx, tree, hess, &position);
std::sort(position.begin(), position.end());
size_t beg = std::distance(
position.begin(),
std::find_if(position.begin(), position.end(), [&](bst_node_t nidx) { return nidx >= 0; }));
std::vector<size_t> nptr;
common::RunLengthEncode(position.cbegin() + beg, position.cend(), &nptr);
std::transform(nptr.begin(), nptr.end(), nptr.begin(), [&](size_t x) { return x + beg; });
auto n_uniques = std::unique(position.begin() + beg, position.end()) - (position.begin() + beg);
ASSERT_EQ(nptr.size(), n_uniques + 1);
ASSERT_EQ(nptr[0], beg);
ASSERT_EQ(nptr.back(), n_samples);
h_nptr = nptr;
}
if (h_nptr.front() == n_samples) {
return;
}
ASSERT_GE(h_nptr.size(), 2);
for (auto const& page : Xy->GetBatches<SparsePage>()) {
auto batch = page.GetView();
size_t left{0};
for (size_t i = 0; i < batch.Size(); ++i) {
if (not_sampled(i) && batch[i].front().fvalue < split_value) {
left++;
}
}
ASSERT_EQ(left, h_nptr[1] - h_nptr[0]); // equal to number of sampled assigned to left
}
}
} // anonymous namespace
TEST(Approx, LeafPartition) {
for (auto n_samples : {0ul, 1ul, 128ul, 256ul}) {
TestLeafPartition(n_samples);
}
}
} // namespace tree } // namespace tree
} // namespace xgboost } // namespace xgboost

View File

@ -0,0 +1,93 @@
/**
* Copyright 2022-2023 by XGBoost contributors.
*/
#include <gtest/gtest.h>
#include <xgboost/base.h> // for bst_node_t
#include <xgboost/context.h> // for Context
#include <algorithm> // for transform
#include <iterator> // for distance
#include <vector> // for vector
#include "../../../src/common/numeric.h" // for ==RunLengthEncode
#include "../../../src/common/row_set.h" // for RowSetCollection
#include "../../../src/data/gradient_index.h" // for GHistIndexMatrix
#include "../../../src/tree/common_row_partitioner.h"
#include "../../../src/tree/hist/expand_entry.h" // for CPUExpandEntry
#include "../helpers.h" // for RandomDataGenerator
#include "test_partitioner.h" // for GetSplit
namespace xgboost::tree {
namespace {
void TestLeafPartition(size_t n_samples) {
size_t const n_features = 2, base_rowid = 0;
Context ctx;
common::RowSetCollection row_set;
CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid, false};
auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
std::vector<CPUExpandEntry> candidates{{0, 0}};
candidates.front().split.loss_chg = 0.4;
RegTree tree;
std::vector<float> hess(n_samples, 0);
// emulate sampling
auto not_sampled = [](size_t i) {
size_t const kSampleFactor{3};
return i % kSampleFactor != 0;
};
for (size_t i = 0; i < hess.size(); ++i) {
if (not_sampled(i)) {
hess[i] = 1.0f;
}
}
std::vector<size_t> h_nptr;
float split_value{0};
for (auto const& page : Xy->GetBatches<GHistIndexMatrix>({Context::kCpuId, 64})) {
bst_feature_t const split_ind = 0;
auto ptr = page.cut.Ptrs()[split_ind + 1];
split_value = page.cut.Values().at(ptr / 2);
GetSplit(&tree, split_value, &candidates);
partitioner.UpdatePosition(&ctx, page, candidates, &tree);
std::vector<bst_node_t> position;
partitioner.LeafPartition(&ctx, tree, hess, &position);
std::sort(position.begin(), position.end());
size_t beg = std::distance(
position.begin(),
std::find_if(position.begin(), position.end(), [&](bst_node_t nidx) { return nidx >= 0; }));
std::vector<size_t> nptr;
common::RunLengthEncode(position.cbegin() + beg, position.cend(), &nptr);
std::transform(nptr.begin(), nptr.end(), nptr.begin(), [&](size_t x) { return x + beg; });
auto n_uniques = std::unique(position.begin() + beg, position.end()) - (position.begin() + beg);
ASSERT_EQ(nptr.size(), n_uniques + 1);
ASSERT_EQ(nptr[0], beg);
ASSERT_EQ(nptr.back(), n_samples);
h_nptr = nptr;
}
if (h_nptr.front() == n_samples) {
return;
}
ASSERT_GE(h_nptr.size(), 2);
for (auto const& page : Xy->GetBatches<SparsePage>()) {
auto batch = page.GetView();
size_t left{0};
for (size_t i = 0; i < batch.Size(); ++i) {
if (not_sampled(i) && batch[i].front().fvalue < split_value) {
left++;
}
}
ASSERT_EQ(left, h_nptr[1] - h_nptr[0]); // equal to number of sampled assigned to left
}
}
} // anonymous namespace
TEST(CommonRowPartitioner, LeafPartition) {
for (auto n_samples : {0ul, 1ul, 128ul, 256ul}) {
TestLeafPartition(n_samples);
}
}
} // namespace xgboost::tree

View File

@ -1,17 +1,20 @@
/*! /**
* Copyright 2021-2022, XGBoost contributors. * Copyright 2021-2023 by XGBoost contributors.
*/ */
#ifndef XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_ #ifndef XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_
#define XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_ #define XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_
#include <xgboost/tree_model.h> #include <xgboost/context.h> // for Context
#include <xgboost/linalg.h> // for Constant, Vector
#include <xgboost/logging.h> // for CHECK
#include <xgboost/tree_model.h> // for RegTree
#include <vector> #include <vector> // for vector
#include "../../../src/tree/hist/expand_entry.h" #include "../../../src/tree/hist/expand_entry.h" // for CPUExpandEntry, MultiExpandEntry
namespace xgboost { namespace xgboost::tree {
namespace tree {
inline void GetSplit(RegTree *tree, float split_value, std::vector<CPUExpandEntry> *candidates) { inline void GetSplit(RegTree *tree, float split_value, std::vector<CPUExpandEntry> *candidates) {
CHECK(!tree->IsMultiTarget());
tree->ExpandNode( tree->ExpandNode(
/*nid=*/RegTree::kRoot, /*split_index=*/0, /*split_value=*/split_value, /*nid=*/RegTree::kRoot, /*split_index=*/0, /*split_value=*/split_value,
/*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, /*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
@ -21,6 +24,22 @@ inline void GetSplit(RegTree *tree, float split_value, std::vector<CPUExpandEntr
candidates->front().split.sindex = 0; candidates->front().split.sindex = 0;
candidates->front().split.sindex |= (1U << 31); candidates->front().split.sindex |= (1U << 31);
} }
} // namespace tree
} // namespace xgboost inline void GetMultiSplitForTest(RegTree *tree, float split_value,
std::vector<MultiExpandEntry> *candidates) {
CHECK(tree->IsMultiTarget());
auto n_targets = tree->NumTargets();
Context ctx;
linalg::Vector<float> base_weight{linalg::Constant(&ctx, 0.0f, n_targets)};
linalg::Vector<float> left_weight{linalg::Constant(&ctx, 0.0f, n_targets)};
linalg::Vector<float> right_weight{linalg::Constant(&ctx, 0.0f, n_targets)};
tree->ExpandNode(/*nidx=*/RegTree::kRoot, /*split_index=*/0, /*split_value=*/split_value,
/*default_left=*/true, base_weight.HostView(), left_weight.HostView(),
right_weight.HostView());
candidates->front().split.split_value = split_value;
candidates->front().split.sindex = 0;
candidates->front().split.sindex |= (1U << 31);
}
} // namespace xgboost::tree
#endif // XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_ #endif // XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_

View File

@ -1,25 +1,29 @@
/*! /**
* Copyright 2018-2022 by XGBoost Contributors * Copyright 2018-2023 by XGBoost Contributors
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <xgboost/host_device_vector.h> #include <xgboost/host_device_vector.h>
#include <xgboost/tree_updater.h> #include <xgboost/tree_updater.h>
#include <algorithm> #include <algorithm>
#include <cstddef> // for size_t
#include <string> #include <string>
#include <vector> #include <vector>
#include "../../../src/tree/common_row_partitioner.h"
#include "../../../src/tree/hist/expand_entry.h" // for MultiExpandEntry, CPUExpandEntry
#include "../../../src/tree/param.h" #include "../../../src/tree/param.h"
#include "../../../src/tree/split_evaluator.h" #include "../../../src/tree/split_evaluator.h"
#include "../../../src/tree/common_row_partitioner.h"
#include "../helpers.h" #include "../helpers.h"
#include "test_partitioner.h" #include "test_partitioner.h"
#include "xgboost/data.h" #include "xgboost/data.h"
namespace xgboost { namespace xgboost::tree {
namespace tree { template <typename ExpandEntry>
TEST(QuantileHist, Partitioner) { void TestPartitioner(bst_target_t n_targets) {
size_t n_samples = 1024, n_features = 1, base_rowid = 0; std::size_t n_samples = 1024, base_rowid = 0;
bst_feature_t n_features = 1;
Context ctx; Context ctx;
ctx.InitAllowUnknown(Args{}); ctx.InitAllowUnknown(Args{});
@ -29,7 +33,7 @@ TEST(QuantileHist, Partitioner) {
ASSERT_EQ(partitioner.Partitions()[0].Size(), n_samples); ASSERT_EQ(partitioner.Partitions()[0].Size(), n_samples);
auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true); auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
std::vector<CPUExpandEntry> candidates{{0, 0}}; std::vector<ExpandEntry> candidates{{0, 0}};
candidates.front().split.loss_chg = 0.4; candidates.front().split.loss_chg = 0.4;
auto cuts = common::SketchOnDMatrix(Xy.get(), 64, ctx.Threads()); auto cuts = common::SketchOnDMatrix(Xy.get(), 64, ctx.Threads());
@ -41,9 +45,13 @@ TEST(QuantileHist, Partitioner) {
column_indices.InitFromSparse(page, gmat, 0.5, ctx.Threads()); column_indices.InitFromSparse(page, gmat, 0.5, ctx.Threads());
{ {
auto min_value = gmat.cut.MinValues()[split_ind]; auto min_value = gmat.cut.MinValues()[split_ind];
RegTree tree; RegTree tree{n_targets, n_features};
CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid, false}; CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid, false};
if constexpr (std::is_same<ExpandEntry, CPUExpandEntry>::value) {
GetSplit(&tree, min_value, &candidates); GetSplit(&tree, min_value, &candidates);
} else {
GetMultiSplitForTest(&tree, min_value, &candidates);
}
partitioner.UpdatePosition<false, true>(&ctx, gmat, column_indices, candidates, &tree); partitioner.UpdatePosition<false, true>(&ctx, gmat, column_indices, candidates, &tree);
ASSERT_EQ(partitioner.Size(), 3); ASSERT_EQ(partitioner.Size(), 3);
ASSERT_EQ(partitioner[1].Size(), 0); ASSERT_EQ(partitioner[1].Size(), 0);
@ -53,9 +61,13 @@ TEST(QuantileHist, Partitioner) {
CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid, false}; CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid, false};
auto ptr = gmat.cut.Ptrs()[split_ind + 1]; auto ptr = gmat.cut.Ptrs()[split_ind + 1];
float split_value = gmat.cut.Values().at(ptr / 2); float split_value = gmat.cut.Values().at(ptr / 2);
RegTree tree; RegTree tree{n_targets, n_features};
if constexpr (std::is_same<ExpandEntry, CPUExpandEntry>::value) {
GetSplit(&tree, split_value, &candidates); GetSplit(&tree, split_value, &candidates);
auto left_nidx = tree[RegTree::kRoot].LeftChild(); } else {
GetMultiSplitForTest(&tree, split_value, &candidates);
}
auto left_nidx = tree.LeftChild(RegTree::kRoot);
partitioner.UpdatePosition<false, true>(&ctx, gmat, column_indices, candidates, &tree); partitioner.UpdatePosition<false, true>(&ctx, gmat, column_indices, candidates, &tree);
auto elem = partitioner[left_nidx]; auto elem = partitioner[left_nidx];
@ -65,14 +77,17 @@ TEST(QuantileHist, Partitioner) {
auto value = gmat.cut.Values().at(gmat.index[*it]); auto value = gmat.cut.Values().at(gmat.index[*it]);
ASSERT_LE(value, split_value); ASSERT_LE(value, split_value);
} }
auto right_nidx = tree[RegTree::kRoot].RightChild(); auto right_nidx = tree.RightChild(RegTree::kRoot);
elem = partitioner[right_nidx]; elem = partitioner[right_nidx];
for (auto it = elem.begin; it != elem.end; ++it) { for (auto it = elem.begin; it != elem.end; ++it) {
auto value = gmat.cut.Values().at(gmat.index[*it]); auto value = gmat.cut.Values().at(gmat.index[*it]);
ASSERT_GT(value, split_value) << *it; ASSERT_GT(value, split_value);
} }
} }
} }
} }
} // namespace tree
} // namespace xgboost TEST(QuantileHist, Partitioner) { TestPartitioner<CPUExpandEntry>(1); }
TEST(QuantileHist, MultiPartitioner) { TestPartitioner<MultiExpandEntry>(3); }
} // namespace xgboost::tree