Add approx partitioner. (#7467)
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
|
||||
/*!
|
||||
* Copyright 2021 by Contributors
|
||||
* \file row_set.h
|
||||
@@ -77,6 +76,24 @@ class PartitionBuilder {
|
||||
return {nleft_elems, nright_elems};
|
||||
}
|
||||
|
||||
template <typename Pred>
|
||||
inline std::pair<size_t, size_t> PartitionRangeKernel(common::Span<const size_t> ridx,
|
||||
common::Span<size_t> left_part,
|
||||
common::Span<size_t> right_part,
|
||||
Pred pred) {
|
||||
size_t* p_left_part = left_part.data();
|
||||
size_t* p_right_part = right_part.data();
|
||||
size_t nleft_elems = 0;
|
||||
size_t nright_elems = 0;
|
||||
for (auto row_id : ridx) {
|
||||
if (pred(row_id)) {
|
||||
p_left_part[nleft_elems++] = row_id;
|
||||
} else {
|
||||
p_right_part[nright_elems++] = row_id;
|
||||
}
|
||||
}
|
||||
return {nleft_elems, nright_elems};
|
||||
}
|
||||
|
||||
template <typename BinIdxType, bool any_missing>
|
||||
void Partition(const size_t node_in_set, const size_t nid, const common::Range1d range,
|
||||
@@ -123,6 +140,37 @@ class PartitionBuilder {
|
||||
SetNRightElems(node_in_set, range.begin(), range.end(), n_right);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Partition tree nodes with specific range of row indices.
|
||||
*
|
||||
* \tparam Pred Predicate for whether a row should be partitioned to the left node.
|
||||
*
|
||||
* \param node_in_set The index of node in current batch of nodes.
|
||||
* \param nid The cannonical node index (node index in the tree).
|
||||
* \param range The range of input row index.
|
||||
* \param fidx Feature index.
|
||||
* \param p_row_set_collection Pointer to rows that are being partitioned.
|
||||
* \param pred A callback function that returns whether current row should be
|
||||
* partitioned to the left node, it should accept the row index as
|
||||
* input and returns a boolean value.
|
||||
*/
|
||||
template <typename Pred>
|
||||
void PartitionRange(const size_t node_in_set, const size_t nid, common::Range1d range,
|
||||
bst_feature_t fidx, common::RowSetCollection* p_row_set_collection,
|
||||
Pred pred) {
|
||||
auto& row_set_collection = *p_row_set_collection;
|
||||
const size_t* p_ridx = row_set_collection[nid].begin;
|
||||
common::Span<const size_t> ridx(p_ridx + range.begin(), p_ridx + range.end());
|
||||
common::Span<size_t> left = this->GetLeftBuffer(node_in_set, range.begin(), range.end());
|
||||
common::Span<size_t> right = this->GetRightBuffer(node_in_set, range.begin(), range.end());
|
||||
std::pair<size_t, size_t> child_nodes_sizes = PartitionRangeKernel(ridx, left, right, pred);
|
||||
|
||||
const size_t n_left = child_nodes_sizes.first;
|
||||
const size_t n_right = child_nodes_sizes.second;
|
||||
|
||||
this->SetNLeftElems(node_in_set, range.begin(), range.end(), n_left);
|
||||
this->SetNRightElems(node_in_set, range.begin(), range.end(), n_right);
|
||||
}
|
||||
|
||||
// allocate thread local memory, should be called for each specific task
|
||||
void AllocateForTask(size_t id) {
|
||||
|
||||
Reference in New Issue
Block a user