Partitioner for multi-target tree. (#8922)

This commit is contained in:
Jiaming Yuan
2023-03-16 18:49:34 +08:00
committed by GitHub
parent 26209a42a5
commit a093770f36
8 changed files with 239 additions and 178 deletions

View File

@@ -1,22 +1,26 @@
/*!
* Copyright 2021-2022 XGBoost contributors
/**
* Copyright 2021-2023 XGBoost contributors
* \file common_row_partitioner.h
* \brief Common partitioner logic for hist and approx methods.
*/
#ifndef XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_
#define XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_
#include <algorithm> // std::all_of
#include <cinttypes> // std::uint32_t
#include <limits> // std::numeric_limits
#include <vector>
#include "../collective/communicator-inl.h"
#include "../common/linalg_op.h" // cbegin
#include "../common/numeric.h" // Iota
#include "../common/partition_builder.h"
#include "hist/expand_entry.h" // CPUExpandEntry
#include "xgboost/base.h"
#include "xgboost/context.h" // Context
#include "xgboost/linalg.h" // TensorView
namespace xgboost {
namespace tree {
namespace xgboost::tree {
static constexpr size_t kPartitionBlockSize = 2048;
@@ -34,9 +38,10 @@ class ColumnSplitHelper {
missing_bits_ = BitVector(common::Span<BitVector::value_type>(missing_storage_));
}
template <typename ExpandEntry>
void Partition(common::BlockedSpace2d const& space, std::int32_t n_threads,
GHistIndexMatrix const& gmat, common::ColumnMatrix const& column_matrix,
std::vector<CPUExpandEntry> const& nodes, RegTree const* p_tree) {
std::vector<ExpandEntry> const& nodes, RegTree const* p_tree) {
// When data is split by column, we don't have all the feature values in the local worker, so
// we first collect all the decisions and whether the feature is missing into bit vectors.
std::fill(decision_storage_.begin(), decision_storage_.end(), 0);
@@ -97,17 +102,18 @@ class CommonRowPartitioner {
}
}
void FindSplitConditions(const std::vector<CPUExpandEntry>& nodes, const RegTree& tree,
template <typename ExpandEntry>
void FindSplitConditions(const std::vector<ExpandEntry>& nodes, const RegTree& tree,
const GHistIndexMatrix& gmat, std::vector<int32_t>* split_conditions) {
auto const& ptrs = gmat.cut.Ptrs();
auto const& vals = gmat.cut.Values();
for (std::size_t i = 0; i < nodes.size(); ++i) {
bst_node_t const nid = nodes[i].nid;
bst_feature_t const fid = tree[nid].SplitIndex();
const float split_pt = tree[nid].SplitCond();
const uint32_t lower_bound = ptrs[fid];
const uint32_t upper_bound = ptrs[fid + 1];
bst_node_t const nidx = nodes[i].nid;
bst_feature_t const fidx = tree.SplitIndex(nidx);
float const split_pt = tree.SplitCond(nidx);
std::uint32_t const lower_bound = ptrs[fidx];
std::uint32_t const upper_bound = ptrs[fidx + 1];
bst_bin_t split_cond = -1;
// convert floating-point split_pt into corresponding bin_id
// split_cond = -1 indicates that split_pt is less than all known cut points
@@ -121,20 +127,22 @@ class CommonRowPartitioner {
}
}
void AddSplitsToRowSet(const std::vector<CPUExpandEntry>& nodes, RegTree const* p_tree) {
template <typename ExpandEntry>
void AddSplitsToRowSet(const std::vector<ExpandEntry>& nodes, RegTree const* p_tree) {
const size_t n_nodes = nodes.size();
for (unsigned int i = 0; i < n_nodes; ++i) {
const int32_t nid = nodes[i].nid;
const int32_t nidx = nodes[i].nid;
const size_t n_left = partition_builder_.GetNLeftElems(i);
const size_t n_right = partition_builder_.GetNRightElems(i);
CHECK_EQ((*p_tree)[nid].LeftChild() + 1, (*p_tree)[nid].RightChild());
row_set_collection_.AddSplit(nid, (*p_tree)[nid].LeftChild(), (*p_tree)[nid].RightChild(),
n_left, n_right);
CHECK_EQ(p_tree->LeftChild(nidx) + 1, p_tree->RightChild(nidx));
row_set_collection_.AddSplit(nidx, p_tree->LeftChild(nidx), p_tree->RightChild(nidx), n_left,
n_right);
}
}
template <typename ExpandEntry>
void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat,
std::vector<CPUExpandEntry> const& nodes, RegTree const* p_tree) {
std::vector<ExpandEntry> const& nodes, RegTree const* p_tree) {
auto const& column_matrix = gmat.Transpose();
if (column_matrix.IsInitialized()) {
if (gmat.cut.HasCategorical()) {
@@ -152,10 +160,10 @@ class CommonRowPartitioner {
}
}
template <bool any_cat>
template <bool any_cat, typename ExpandEntry>
void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat,
const common::ColumnMatrix& column_matrix,
std::vector<CPUExpandEntry> const& nodes, RegTree const* p_tree) {
std::vector<ExpandEntry> const& nodes, RegTree const* p_tree) {
if (column_matrix.AnyMissing()) {
this->template UpdatePosition<true, any_cat>(ctx, gmat, column_matrix, nodes, p_tree);
} else {
@@ -163,33 +171,21 @@ class CommonRowPartitioner {
}
}
template <bool any_missing, bool any_cat>
template <bool any_missing, bool any_cat, typename ExpandEntry>
void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat,
const common::ColumnMatrix& column_matrix,
std::vector<CPUExpandEntry> const& nodes, RegTree const* p_tree) {
switch (column_matrix.GetTypeSize()) {
case common::kUint8BinsTypeSize:
this->template UpdatePosition<uint8_t, any_missing, any_cat>(ctx, gmat, column_matrix,
nodes, p_tree);
break;
case common::kUint16BinsTypeSize:
this->template UpdatePosition<uint16_t, any_missing, any_cat>(ctx, gmat, column_matrix,
nodes, p_tree);
break;
case common::kUint32BinsTypeSize:
this->template UpdatePosition<uint32_t, any_missing, any_cat>(ctx, gmat, column_matrix,
nodes, p_tree);
break;
default:
// no default behavior
CHECK(false) << column_matrix.GetTypeSize();
}
std::vector<ExpandEntry> const& nodes, RegTree const* p_tree) {
common::DispatchBinType(column_matrix.GetTypeSize(), [&](auto t) {
using T = decltype(t);
this->template UpdatePosition<T, any_missing, any_cat>(ctx, gmat, column_matrix, nodes,
p_tree);
});
}
template <typename BinIdxType, bool any_missing, bool any_cat>
template <typename BinIdxType, bool any_missing, bool any_cat, typename ExpandEntry>
void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat,
const common::ColumnMatrix& column_matrix,
std::vector<CPUExpandEntry> const& nodes, RegTree const* p_tree) {
std::vector<ExpandEntry> const& nodes, RegTree const* p_tree) {
// 1. Find split condition for each split
size_t n_nodes = nodes.size();
@@ -251,9 +247,9 @@ class CommonRowPartitioner {
AddSplitsToRowSet(nodes, p_tree);
}
auto const& Partitions() const { return row_set_collection_; }
[[nodiscard]] auto const& Partitions() const { return row_set_collection_; }
size_t Size() const {
[[nodiscard]] std::size_t Size() const {
return std::distance(row_set_collection_.begin(), row_set_collection_.end());
}
@@ -266,12 +262,29 @@ class CommonRowPartitioner {
[&](size_t idx) -> bool { return hess[idx] - .0f == .0f; });
}
void LeafPartition(Context const* ctx, RegTree const& tree,
linalg::TensorView<GradientPair const, 2> gpair,
std::vector<bst_node_t>* p_out_position) const {
if (gpair.Shape(1) > 1) {
partition_builder_.LeafPartition(
ctx, tree, this->Partitions(), p_out_position, [&](std::size_t idx) -> bool {
auto sample = gpair.Slice(idx, linalg::All());
return std::all_of(linalg::cbegin(sample), linalg::cend(sample),
[](GradientPair const& g) { return g.GetHess() - .0f == .0f; });
});
} else {
auto s = gpair.Slice(linalg::All(), 0);
partition_builder_.LeafPartition(
ctx, tree, this->Partitions(), p_out_position,
[&](std::size_t idx) -> bool { return s(idx).GetHess() - .0f == .0f; });
}
}
void LeafPartition(Context const* ctx, RegTree const& tree,
common::Span<GradientPair const> gpair,
std::vector<bst_node_t>* p_out_position) const {
partition_builder_.LeafPartition(
ctx, tree, this->Partitions(), p_out_position,
[&](size_t idx) -> bool { return gpair[idx].GetHess() - .0f == .0f; });
[&](std::size_t idx) -> bool { return gpair[idx].GetHess() - .0f == .0f; });
}
private:
@@ -281,6 +294,5 @@ class CommonRowPartitioner {
ColumnSplitHelper column_split_helper_;
};
} // namespace tree
} // namespace xgboost
} // namespace xgboost::tree
#endif // XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_