Partitioner for multi-target tree. (#8922)
This commit is contained in:
@@ -1,22 +1,26 @@
|
||||
/*!
|
||||
* Copyright 2021-2022 XGBoost contributors
|
||||
/**
|
||||
* Copyright 2021-2023 XGBoost contributors
|
||||
* \file common_row_partitioner.h
|
||||
* \brief Common partitioner logic for hist and approx methods.
|
||||
*/
|
||||
#ifndef XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_
|
||||
#define XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_
|
||||
|
||||
#include <algorithm> // std::all_of
|
||||
#include <cinttypes> // std::uint32_t
|
||||
#include <limits> // std::numeric_limits
|
||||
#include <vector>
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/linalg_op.h" // cbegin
|
||||
#include "../common/numeric.h" // Iota
|
||||
#include "../common/partition_builder.h"
|
||||
#include "hist/expand_entry.h" // CPUExpandEntry
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/context.h" // Context
|
||||
#include "xgboost/linalg.h" // TensorView
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
namespace xgboost::tree {
|
||||
|
||||
static constexpr size_t kPartitionBlockSize = 2048;
|
||||
|
||||
@@ -34,9 +38,10 @@ class ColumnSplitHelper {
|
||||
missing_bits_ = BitVector(common::Span<BitVector::value_type>(missing_storage_));
|
||||
}
|
||||
|
||||
template <typename ExpandEntry>
|
||||
void Partition(common::BlockedSpace2d const& space, std::int32_t n_threads,
|
||||
GHistIndexMatrix const& gmat, common::ColumnMatrix const& column_matrix,
|
||||
std::vector<CPUExpandEntry> const& nodes, RegTree const* p_tree) {
|
||||
std::vector<ExpandEntry> const& nodes, RegTree const* p_tree) {
|
||||
// When data is split by column, we don't have all the feature values in the local worker, so
|
||||
// we first collect all the decisions and whether the feature is missing into bit vectors.
|
||||
std::fill(decision_storage_.begin(), decision_storage_.end(), 0);
|
||||
@@ -97,17 +102,18 @@ class CommonRowPartitioner {
|
||||
}
|
||||
}
|
||||
|
||||
void FindSplitConditions(const std::vector<CPUExpandEntry>& nodes, const RegTree& tree,
|
||||
template <typename ExpandEntry>
|
||||
void FindSplitConditions(const std::vector<ExpandEntry>& nodes, const RegTree& tree,
|
||||
const GHistIndexMatrix& gmat, std::vector<int32_t>* split_conditions) {
|
||||
auto const& ptrs = gmat.cut.Ptrs();
|
||||
auto const& vals = gmat.cut.Values();
|
||||
|
||||
for (std::size_t i = 0; i < nodes.size(); ++i) {
|
||||
bst_node_t const nid = nodes[i].nid;
|
||||
bst_feature_t const fid = tree[nid].SplitIndex();
|
||||
const float split_pt = tree[nid].SplitCond();
|
||||
const uint32_t lower_bound = ptrs[fid];
|
||||
const uint32_t upper_bound = ptrs[fid + 1];
|
||||
bst_node_t const nidx = nodes[i].nid;
|
||||
bst_feature_t const fidx = tree.SplitIndex(nidx);
|
||||
float const split_pt = tree.SplitCond(nidx);
|
||||
std::uint32_t const lower_bound = ptrs[fidx];
|
||||
std::uint32_t const upper_bound = ptrs[fidx + 1];
|
||||
bst_bin_t split_cond = -1;
|
||||
// convert floating-point split_pt into corresponding bin_id
|
||||
// split_cond = -1 indicates that split_pt is less than all known cut points
|
||||
@@ -121,20 +127,22 @@ class CommonRowPartitioner {
|
||||
}
|
||||
}
|
||||
|
||||
void AddSplitsToRowSet(const std::vector<CPUExpandEntry>& nodes, RegTree const* p_tree) {
|
||||
template <typename ExpandEntry>
|
||||
void AddSplitsToRowSet(const std::vector<ExpandEntry>& nodes, RegTree const* p_tree) {
|
||||
const size_t n_nodes = nodes.size();
|
||||
for (unsigned int i = 0; i < n_nodes; ++i) {
|
||||
const int32_t nid = nodes[i].nid;
|
||||
const int32_t nidx = nodes[i].nid;
|
||||
const size_t n_left = partition_builder_.GetNLeftElems(i);
|
||||
const size_t n_right = partition_builder_.GetNRightElems(i);
|
||||
CHECK_EQ((*p_tree)[nid].LeftChild() + 1, (*p_tree)[nid].RightChild());
|
||||
row_set_collection_.AddSplit(nid, (*p_tree)[nid].LeftChild(), (*p_tree)[nid].RightChild(),
|
||||
n_left, n_right);
|
||||
CHECK_EQ(p_tree->LeftChild(nidx) + 1, p_tree->RightChild(nidx));
|
||||
row_set_collection_.AddSplit(nidx, p_tree->LeftChild(nidx), p_tree->RightChild(nidx), n_left,
|
||||
n_right);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ExpandEntry>
|
||||
void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat,
|
||||
std::vector<CPUExpandEntry> const& nodes, RegTree const* p_tree) {
|
||||
std::vector<ExpandEntry> const& nodes, RegTree const* p_tree) {
|
||||
auto const& column_matrix = gmat.Transpose();
|
||||
if (column_matrix.IsInitialized()) {
|
||||
if (gmat.cut.HasCategorical()) {
|
||||
@@ -152,10 +160,10 @@ class CommonRowPartitioner {
|
||||
}
|
||||
}
|
||||
|
||||
template <bool any_cat>
|
||||
template <bool any_cat, typename ExpandEntry>
|
||||
void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat,
|
||||
const common::ColumnMatrix& column_matrix,
|
||||
std::vector<CPUExpandEntry> const& nodes, RegTree const* p_tree) {
|
||||
std::vector<ExpandEntry> const& nodes, RegTree const* p_tree) {
|
||||
if (column_matrix.AnyMissing()) {
|
||||
this->template UpdatePosition<true, any_cat>(ctx, gmat, column_matrix, nodes, p_tree);
|
||||
} else {
|
||||
@@ -163,33 +171,21 @@ class CommonRowPartitioner {
|
||||
}
|
||||
}
|
||||
|
||||
template <bool any_missing, bool any_cat>
|
||||
template <bool any_missing, bool any_cat, typename ExpandEntry>
|
||||
void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat,
|
||||
const common::ColumnMatrix& column_matrix,
|
||||
std::vector<CPUExpandEntry> const& nodes, RegTree const* p_tree) {
|
||||
switch (column_matrix.GetTypeSize()) {
|
||||
case common::kUint8BinsTypeSize:
|
||||
this->template UpdatePosition<uint8_t, any_missing, any_cat>(ctx, gmat, column_matrix,
|
||||
nodes, p_tree);
|
||||
break;
|
||||
case common::kUint16BinsTypeSize:
|
||||
this->template UpdatePosition<uint16_t, any_missing, any_cat>(ctx, gmat, column_matrix,
|
||||
nodes, p_tree);
|
||||
break;
|
||||
case common::kUint32BinsTypeSize:
|
||||
this->template UpdatePosition<uint32_t, any_missing, any_cat>(ctx, gmat, column_matrix,
|
||||
nodes, p_tree);
|
||||
break;
|
||||
default:
|
||||
// no default behavior
|
||||
CHECK(false) << column_matrix.GetTypeSize();
|
||||
}
|
||||
std::vector<ExpandEntry> const& nodes, RegTree const* p_tree) {
|
||||
common::DispatchBinType(column_matrix.GetTypeSize(), [&](auto t) {
|
||||
using T = decltype(t);
|
||||
this->template UpdatePosition<T, any_missing, any_cat>(ctx, gmat, column_matrix, nodes,
|
||||
p_tree);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename BinIdxType, bool any_missing, bool any_cat>
|
||||
template <typename BinIdxType, bool any_missing, bool any_cat, typename ExpandEntry>
|
||||
void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat,
|
||||
const common::ColumnMatrix& column_matrix,
|
||||
std::vector<CPUExpandEntry> const& nodes, RegTree const* p_tree) {
|
||||
std::vector<ExpandEntry> const& nodes, RegTree const* p_tree) {
|
||||
// 1. Find split condition for each split
|
||||
size_t n_nodes = nodes.size();
|
||||
|
||||
@@ -251,9 +247,9 @@ class CommonRowPartitioner {
|
||||
AddSplitsToRowSet(nodes, p_tree);
|
||||
}
|
||||
|
||||
auto const& Partitions() const { return row_set_collection_; }
|
||||
[[nodiscard]] auto const& Partitions() const { return row_set_collection_; }
|
||||
|
||||
size_t Size() const {
|
||||
[[nodiscard]] std::size_t Size() const {
|
||||
return std::distance(row_set_collection_.begin(), row_set_collection_.end());
|
||||
}
|
||||
|
||||
@@ -266,12 +262,29 @@ class CommonRowPartitioner {
|
||||
[&](size_t idx) -> bool { return hess[idx] - .0f == .0f; });
|
||||
}
|
||||
|
||||
void LeafPartition(Context const* ctx, RegTree const& tree,
|
||||
linalg::TensorView<GradientPair const, 2> gpair,
|
||||
std::vector<bst_node_t>* p_out_position) const {
|
||||
if (gpair.Shape(1) > 1) {
|
||||
partition_builder_.LeafPartition(
|
||||
ctx, tree, this->Partitions(), p_out_position, [&](std::size_t idx) -> bool {
|
||||
auto sample = gpair.Slice(idx, linalg::All());
|
||||
return std::all_of(linalg::cbegin(sample), linalg::cend(sample),
|
||||
[](GradientPair const& g) { return g.GetHess() - .0f == .0f; });
|
||||
});
|
||||
} else {
|
||||
auto s = gpair.Slice(linalg::All(), 0);
|
||||
partition_builder_.LeafPartition(
|
||||
ctx, tree, this->Partitions(), p_out_position,
|
||||
[&](std::size_t idx) -> bool { return s(idx).GetHess() - .0f == .0f; });
|
||||
}
|
||||
}
|
||||
void LeafPartition(Context const* ctx, RegTree const& tree,
|
||||
common::Span<GradientPair const> gpair,
|
||||
std::vector<bst_node_t>* p_out_position) const {
|
||||
partition_builder_.LeafPartition(
|
||||
ctx, tree, this->Partitions(), p_out_position,
|
||||
[&](size_t idx) -> bool { return gpair[idx].GetHess() - .0f == .0f; });
|
||||
[&](std::size_t idx) -> bool { return gpair[idx].GetHess() - .0f == .0f; });
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -281,6 +294,5 @@ class CommonRowPartitioner {
|
||||
ColumnSplitHelper column_split_helper_;
|
||||
};
|
||||
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::tree
|
||||
#endif // XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_
|
||||
|
||||
Reference in New Issue
Block a user