Add approx partitioner. (#7467)

This commit is contained in:
Jiaming Yuan
2021-11-27 15:22:06 +08:00
committed by GitHub
parent 85cbd32c5a
commit eee527d264
3 changed files with 271 additions and 1 deletions

View File

@@ -1,4 +1,3 @@
/*!
* Copyright 2021 by Contributors
* \file row_set.h
@@ -77,6 +76,24 @@ class PartitionBuilder {
return {nleft_elems, nright_elems};
}
/**
 * \brief Distribute row indices into a left and a right output buffer.
 *
 * Rows are visited in the order they appear in `ridx`; each row for which the
 * predicate holds is appended to the left buffer, every other row is appended
 * to the right buffer. The relative order of rows inside each buffer follows
 * the input order.
 *
 * \tparam Pred Callable invoked with a row index; its result is tested as a boolean.
 *
 * \param ridx       Row indices to distribute.
 * \param left_part  Destination for rows accepted by the predicate.
 * \param right_part Destination for rows rejected by the predicate.
 * \param pred       Predicate deciding whether a row goes to the left buffer.
 *
 * \return Pair of (number of rows written to the left buffer, number of rows
 *         written to the right buffer).
 *
 * \note Writes go through the raw data pointers of the output spans, so no
 *       bounds checking is performed here.
 */
template <typename Pred>
inline std::pair<size_t, size_t> PartitionRangeKernel(common::Span<const size_t> ridx,
                                                      common::Span<size_t> left_part,
                                                      common::Span<size_t> right_part,
                                                      Pred pred) {
  size_t* const left_begin = left_part.data();
  size_t* const right_begin = right_part.data();

  // Write cursors; each one advances past every element stored through it.
  size_t* left_cursor = left_begin;
  size_t* right_cursor = right_begin;

  for (auto rid : ridx) {
    if (pred(rid)) {
      *left_cursor++ = rid;
    } else {
      *right_cursor++ = rid;
    }
  }

  // The distance each cursor travelled is the number of rows stored on that side.
  return {static_cast<size_t>(left_cursor - left_begin),
          static_cast<size_t>(right_cursor - right_begin)};
}
template <typename BinIdxType, bool any_missing>
void Partition(const size_t node_in_set, const size_t nid, const common::Range1d range,
@@ -123,6 +140,37 @@ class PartitionBuilder {
SetNRightElems(node_in_set, range.begin(), range.end(), n_right);
}
/**
* \brief Partition tree nodes with specific range of row indices.
*
* \tparam Pred Predicate for whether a row should be partitioned to the left node.
*
* \param node_in_set The index of node in current batch of nodes.
* \param nid The cannonical node index (node index in the tree).
* \param range The range of input row index.
* \param fidx Feature index.
* \param p_row_set_collection Pointer to rows that are being partitioned.
* \param pred A callback function that returns whether current row should be
* partitioned to the left node, it should accept the row index as
* input and returns a boolean value.
*/
template <typename Pred>
void PartitionRange(const size_t node_in_set, const size_t nid, common::Range1d range,
bst_feature_t fidx, common::RowSetCollection* p_row_set_collection,
Pred pred) {
auto& row_set_collection = *p_row_set_collection;
const size_t* p_ridx = row_set_collection[nid].begin;
common::Span<const size_t> ridx(p_ridx + range.begin(), p_ridx + range.end());
common::Span<size_t> left = this->GetLeftBuffer(node_in_set, range.begin(), range.end());
common::Span<size_t> right = this->GetRightBuffer(node_in_set, range.begin(), range.end());
std::pair<size_t, size_t> child_nodes_sizes = PartitionRangeKernel(ridx, left, right, pred);
const size_t n_left = child_nodes_sizes.first;
const size_t n_right = child_nodes_sizes.second;
this->SetNLeftElems(node_in_set, range.begin(), range.end(), n_left);
this->SetNRightElems(node_in_set, range.begin(), range.end(), n_right);
}
// allocate thread local memory, should be called for each specific task
void AllocateForTask(size_t id) {