Add approx partitioner. (#7467)
This commit is contained in:
parent
85cbd32c5a
commit
eee527d264
@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
/*!
|
/*!
|
||||||
* Copyright 2021 by Contributors
|
* Copyright 2021 by Contributors
|
||||||
* \file row_set.h
|
* \file row_set.h
|
||||||
@ -77,6 +76,24 @@ class PartitionBuilder {
|
|||||||
return {nleft_elems, nright_elems};
|
return {nleft_elems, nright_elems};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename Pred>
|
||||||
|
inline std::pair<size_t, size_t> PartitionRangeKernel(common::Span<const size_t> ridx,
|
||||||
|
common::Span<size_t> left_part,
|
||||||
|
common::Span<size_t> right_part,
|
||||||
|
Pred pred) {
|
||||||
|
size_t* p_left_part = left_part.data();
|
||||||
|
size_t* p_right_part = right_part.data();
|
||||||
|
size_t nleft_elems = 0;
|
||||||
|
size_t nright_elems = 0;
|
||||||
|
for (auto row_id : ridx) {
|
||||||
|
if (pred(row_id)) {
|
||||||
|
p_left_part[nleft_elems++] = row_id;
|
||||||
|
} else {
|
||||||
|
p_right_part[nright_elems++] = row_id;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return {nleft_elems, nright_elems};
|
||||||
|
}
|
||||||
|
|
||||||
template <typename BinIdxType, bool any_missing>
|
template <typename BinIdxType, bool any_missing>
|
||||||
void Partition(const size_t node_in_set, const size_t nid, const common::Range1d range,
|
void Partition(const size_t node_in_set, const size_t nid, const common::Range1d range,
|
||||||
@ -123,6 +140,37 @@ class PartitionBuilder {
|
|||||||
SetNRightElems(node_in_set, range.begin(), range.end(), n_right);
|
SetNRightElems(node_in_set, range.begin(), range.end(), n_right);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Partition tree nodes with specific range of row indices.
|
||||||
|
*
|
||||||
|
* \tparam Pred Predicate for whether a row should be partitioned to the left node.
|
||||||
|
*
|
||||||
|
* \param node_in_set The index of node in current batch of nodes.
|
||||||
|
* \param nid The cannonical node index (node index in the tree).
|
||||||
|
* \param range The range of input row index.
|
||||||
|
* \param fidx Feature index.
|
||||||
|
* \param p_row_set_collection Pointer to rows that are being partitioned.
|
||||||
|
* \param pred A callback function that returns whether current row should be
|
||||||
|
* partitioned to the left node, it should accept the row index as
|
||||||
|
* input and returns a boolean value.
|
||||||
|
*/
|
||||||
|
template <typename Pred>
|
||||||
|
void PartitionRange(const size_t node_in_set, const size_t nid, common::Range1d range,
|
||||||
|
bst_feature_t fidx, common::RowSetCollection* p_row_set_collection,
|
||||||
|
Pred pred) {
|
||||||
|
auto& row_set_collection = *p_row_set_collection;
|
||||||
|
const size_t* p_ridx = row_set_collection[nid].begin;
|
||||||
|
common::Span<const size_t> ridx(p_ridx + range.begin(), p_ridx + range.end());
|
||||||
|
common::Span<size_t> left = this->GetLeftBuffer(node_in_set, range.begin(), range.end());
|
||||||
|
common::Span<size_t> right = this->GetRightBuffer(node_in_set, range.begin(), range.end());
|
||||||
|
std::pair<size_t, size_t> child_nodes_sizes = PartitionRangeKernel(ridx, left, right, pred);
|
||||||
|
|
||||||
|
const size_t n_left = child_nodes_sizes.first;
|
||||||
|
const size_t n_right = child_nodes_sizes.second;
|
||||||
|
|
||||||
|
this->SetNLeftElems(node_in_set, range.begin(), range.end(), n_left);
|
||||||
|
this->SetNRightElems(node_in_set, range.begin(), range.end(), n_right);
|
||||||
|
}
|
||||||
|
|
||||||
// allocate thread local memory, should be called for each specific task
|
// allocate thread local memory, should be called for each specific task
|
||||||
void AllocateForTask(size_t id) {
|
void AllocateForTask(size_t id) {
|
||||||
|
|||||||
146
src/tree/updater_approx.h
Normal file
146
src/tree/updater_approx.h
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2021 XGBoost contributors
|
||||||
|
*
|
||||||
|
* \brief Implementation for the approx tree method.
|
||||||
|
*/
|
||||||
|
#ifndef XGBOOST_TREE_UPDATER_APPROX_H_
|
||||||
|
#define XGBOOST_TREE_UPDATER_APPROX_H_
|
||||||
|
|
||||||
|
#include <cmath>
#include <iterator>
#include <limits>
#include <numeric>
#include <utility>
#include <vector>

#include "../common/partition_builder.h"
#include "../common/random.h"
#include "constraints.h"
#include "driver.h"
#include "hist/evaluate_splits.h"
#include "hist/expand_entry.h"
#include "hist/param.h"
#include "param.h"
#include "xgboost/json.h"
#include "xgboost/tree_updater.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace tree {
|
||||||
|
class ApproxRowPartitioner {
|
||||||
|
static constexpr size_t kPartitionBlockSize = 2048;
|
||||||
|
common::PartitionBuilder<kPartitionBlockSize> partition_builder_;
|
||||||
|
common::RowSetCollection row_set_collection_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
bst_row_t base_rowid = 0;
|
||||||
|
|
||||||
|
static auto SearchCutValue(bst_row_t ridx, bst_feature_t fidx, GHistIndexMatrix const &index,
|
||||||
|
std::vector<uint32_t> const &cut_ptrs,
|
||||||
|
std::vector<float> const &cut_values) {
|
||||||
|
int32_t gidx = -1;
|
||||||
|
auto const &row_ptr = index.row_ptr;
|
||||||
|
auto get_rid = [&](size_t ridx) { return row_ptr[ridx - index.base_rowid]; };
|
||||||
|
|
||||||
|
if (index.IsDense()) {
|
||||||
|
gidx = index.index[get_rid(ridx) + fidx];
|
||||||
|
} else {
|
||||||
|
auto begin = get_rid(ridx);
|
||||||
|
auto end = get_rid(ridx + 1);
|
||||||
|
auto f_begin = cut_ptrs[fidx];
|
||||||
|
auto f_end = cut_ptrs[fidx + 1];
|
||||||
|
gidx = common::BinarySearchBin(begin, end, index.index, f_begin, f_end);
|
||||||
|
}
|
||||||
|
if (gidx == -1) {
|
||||||
|
return std::numeric_limits<float>::quiet_NaN();
|
||||||
|
}
|
||||||
|
return cut_values[gidx];
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
void UpdatePosition(GenericParameter const *ctx, GHistIndexMatrix const &index,
|
||||||
|
std::vector<CPUExpandEntry> const &candidates, RegTree const *p_tree) {
|
||||||
|
size_t n_nodes = candidates.size();
|
||||||
|
|
||||||
|
auto const &cut_values = index.cut.Values();
|
||||||
|
auto const &cut_ptrs = index.cut.Ptrs();
|
||||||
|
|
||||||
|
common::BlockedSpace2d space{n_nodes,
|
||||||
|
[&](size_t node_in_set) {
|
||||||
|
auto candidate = candidates[node_in_set];
|
||||||
|
int32_t nid = candidate.nid;
|
||||||
|
return row_set_collection_[nid].Size();
|
||||||
|
},
|
||||||
|
kPartitionBlockSize};
|
||||||
|
partition_builder_.Init(space.Size(), n_nodes, [&](size_t node_in_set) {
|
||||||
|
auto candidate = candidates[node_in_set];
|
||||||
|
const int32_t nid = candidate.nid;
|
||||||
|
const size_t size = row_set_collection_[nid].Size();
|
||||||
|
const size_t n_tasks = size / kPartitionBlockSize + !!(size % kPartitionBlockSize);
|
||||||
|
return n_tasks;
|
||||||
|
});
|
||||||
|
auto node_ptr = p_tree->GetCategoriesMatrix().node_ptr;
|
||||||
|
auto categories = p_tree->GetCategoriesMatrix().categories;
|
||||||
|
common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) {
|
||||||
|
auto candidate = candidates[node_in_set];
|
||||||
|
auto is_cat = candidate.split.is_cat;
|
||||||
|
const int32_t nid = candidate.nid;
|
||||||
|
auto fidx = candidate.split.SplitIndex();
|
||||||
|
const size_t task_id = partition_builder_.GetTaskIdx(node_in_set, r.begin());
|
||||||
|
partition_builder_.AllocateForTask(task_id);
|
||||||
|
partition_builder_.PartitionRange(
|
||||||
|
node_in_set, nid, r, fidx, &row_set_collection_, [&](size_t row_id) {
|
||||||
|
auto cut_value = SearchCutValue(row_id, fidx, index, cut_ptrs, cut_values);
|
||||||
|
if (std::isnan(cut_value)) {
|
||||||
|
return candidate.split.DefaultLeft();
|
||||||
|
}
|
||||||
|
bst_node_t nidx = candidate.nid;
|
||||||
|
auto segment = node_ptr[nidx];
|
||||||
|
auto node_cats = categories.subspan(segment.beg, segment.size);
|
||||||
|
bool go_left = true;
|
||||||
|
if (is_cat) {
|
||||||
|
go_left = common::Decision(node_cats, common::AsCat(cut_value));
|
||||||
|
} else {
|
||||||
|
go_left = cut_value <= candidate.split.split_value;
|
||||||
|
}
|
||||||
|
return go_left;
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
partition_builder_.CalculateRowOffsets();
|
||||||
|
common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) {
|
||||||
|
auto candidate = candidates[node_in_set];
|
||||||
|
const int32_t nid = candidate.nid;
|
||||||
|
partition_builder_.MergeToArray(node_in_set, r.begin(),
|
||||||
|
const_cast<size_t *>(row_set_collection_[nid].begin));
|
||||||
|
});
|
||||||
|
for (size_t i = 0; i < candidates.size(); ++i) {
|
||||||
|
auto const &candidate = candidates[i];
|
||||||
|
auto nidx = candidate.nid;
|
||||||
|
auto n_left = partition_builder_.GetNLeftElems(i);
|
||||||
|
auto n_right = partition_builder_.GetNRightElems(i);
|
||||||
|
CHECK_EQ(n_left + n_right, row_set_collection_[nidx].Size());
|
||||||
|
bst_node_t left_nidx = (*p_tree)[nidx].LeftChild();
|
||||||
|
bst_node_t right_nidx = (*p_tree)[nidx].RightChild();
|
||||||
|
row_set_collection_.AddSplit(nidx, left_nidx, right_nidx, n_left, n_right);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto const &Partitions() const { return row_set_collection_; }
|
||||||
|
|
||||||
|
auto operator[](bst_node_t nidx) { return row_set_collection_[nidx]; }
|
||||||
|
auto const &operator[](bst_node_t nidx) const { return row_set_collection_[nidx]; }
|
||||||
|
|
||||||
|
size_t Size() const {
|
||||||
|
return std::distance(row_set_collection_.begin(), row_set_collection_.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
ApproxRowPartitioner() = default;
|
||||||
|
explicit ApproxRowPartitioner(bst_row_t num_row, bst_row_t _base_rowid)
|
||||||
|
: base_rowid{_base_rowid} {
|
||||||
|
row_set_collection_.Clear();
|
||||||
|
auto p_positions = row_set_collection_.Data();
|
||||||
|
p_positions->resize(num_row);
|
||||||
|
std::iota(p_positions->begin(), p_positions->end(), base_rowid);
|
||||||
|
row_set_collection_.Init();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace tree
|
||||||
|
} // namespace xgboost
|
||||||
|
#endif // XGBOOST_TREE_UPDATER_APPROX_H_
|
||||||
76
tests/cpp/tree/test_approx.cc
Normal file
76
tests/cpp/tree/test_approx.cc
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2021 XGBoost contributors
|
||||||
|
*/
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include "../../../src/tree/updater_approx.h"
|
||||||
|
#include "../helpers.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace tree {
|
||||||
|
// Exercises ApproxRowPartitioner: initial state, a split that sends every row to the
// right child, and a split in the middle of the cut values.
TEST(Approx, Partitioner) {
  size_t n_samples = 1024;
  size_t n_features = 1;
  size_t base_rowid = 0;

  // A freshly constructed partitioner holds one row set containing every sample.
  ApproxRowPartitioner initial{n_samples, base_rowid};
  ASSERT_EQ(initial.base_rowid, base_rowid);
  ASSERT_EQ(initial.Size(), 1);
  ASSERT_EQ(initial.Partitions()[0].Size(), n_samples);

  auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
  GenericParameter ctx;
  ctx.InitAllowUnknown(Args{});
  std::vector<CPUExpandEntry> candidates{{0, 0, 0.4}};
  auto &root_candidate = candidates.front();

  for (auto const &page : Xy->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, 64})) {
    bst_feature_t split_ind = 0;
    {
      // Split at the feature's minimum value: the left child is expected to stay empty.
      auto lowest = page.cut.MinValues()[split_ind];
      RegTree tree;
      tree.ExpandNode(
          /*nid=*/0, /*split_index=*/0, /*split_value=*/lowest,
          /*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
          /*left_sum=*/0.0f,
          /*right_sum=*/0.0f);
      ApproxRowPartitioner all_right{n_samples, base_rowid};
      root_candidate.split.split_value = lowest;
      // Feature 0 with the top bit set.
      root_candidate.split.sindex = (1U << 31);
      all_right.UpdatePosition(&ctx, page, candidates, &tree);
      ASSERT_EQ(all_right.Size(), 3);
      ASSERT_EQ(all_right[1].Size(), 0);
      ASSERT_EQ(all_right[2].Size(), n_samples);
    }
    {
      // Split at the cut value halfway through the feature's bins.
      auto ptr = page.cut.Ptrs()[split_ind + 1];
      float split_value = page.cut.Values().at(ptr / 2);
      RegTree tree;
      tree.ExpandNode(
          /*nid=*/RegTree::kRoot, /*split_index=*/split_ind,
          /*split_value=*/split_value,
          /*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
          /*left_sum=*/0.0f,
          /*right_sum=*/0.0f);
      auto left_nidx = tree[RegTree::kRoot].LeftChild();
      auto right_nidx = tree[RegTree::kRoot].RightChild();

      ApproxRowPartitioner halved{n_samples, base_rowid};
      root_candidate.split.split_value = split_value;
      root_candidate.split.sindex = (1U << 31);
      halved.UpdatePosition(&ctx, page, candidates, &tree);

      // Every row in the left child has a cut value no greater than the split value.
      auto left_rows = halved[left_nidx];
      ASSERT_LT(left_rows.Size(), n_samples);
      ASSERT_GT(left_rows.Size(), 1);
      for (auto p_row = left_rows.begin; p_row != left_rows.end; ++p_row) {
        auto value = page.cut.Values().at(page.index[*p_row]);
        ASSERT_LE(value, split_value);
      }

      // Every row in the right child has a cut value strictly greater than it.
      auto right_rows = halved[right_nidx];
      for (auto p_row = right_rows.begin; p_row != right_rows.end; ++p_row) {
        auto value = page.cut.Values().at(page.index[*p_row]);
        ASSERT_GT(value, split_value) << *p_row;
      }
    }
  }
}
|
||||||
|
} // namespace tree
|
||||||
|
} // namespace xgboost
|
||||||
Loading…
x
Reference in New Issue
Block a user