From eee527d2647d22a1f68d804672c78f82da4bcd31 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 27 Nov 2021 15:22:06 +0800 Subject: [PATCH] Add approx partitioner. (#7467) --- src/common/partition_builder.h | 50 ++++++++++- src/tree/updater_approx.h | 146 +++++++++++++++++++++++++++++++++ tests/cpp/tree/test_approx.cc | 76 +++++++++++++++++ 3 files changed, 271 insertions(+), 1 deletion(-) create mode 100644 src/tree/updater_approx.h create mode 100644 tests/cpp/tree/test_approx.cc diff --git a/src/common/partition_builder.h b/src/common/partition_builder.h index 98612359e..5235ea3b9 100644 --- a/src/common/partition_builder.h +++ b/src/common/partition_builder.h @@ -1,4 +1,3 @@ - /*! * Copyright 2021 by Contributors * \file row_set.h @@ -77,6 +76,24 @@ class PartitionBuilder { return {nleft_elems, nright_elems}; } + template + inline std::pair PartitionRangeKernel(common::Span ridx, + common::Span left_part, + common::Span right_part, + Pred pred) { + size_t* p_left_part = left_part.data(); + size_t* p_right_part = right_part.data(); + size_t nleft_elems = 0; + size_t nright_elems = 0; + for (auto row_id : ridx) { + if (pred(row_id)) { + p_left_part[nleft_elems++] = row_id; + } else { + p_right_part[nright_elems++] = row_id; + } + } + return {nleft_elems, nright_elems}; + } template void Partition(const size_t node_in_set, const size_t nid, const common::Range1d range, @@ -123,6 +140,37 @@ class PartitionBuilder { SetNRightElems(node_in_set, range.begin(), range.end(), n_right); } + /** + * \brief Partition tree nodes with specific range of row indices. + * + * \tparam Pred Predicate for whether a row should be partitioned to the left node. + * + * \param node_in_set The index of node in current batch of nodes. + * \param nid The canonical node index (node index in the tree). + * \param range The range of input row index. + * \param fidx Feature index. + * \param p_row_set_collection Pointer to rows that are being partitioned. 
+ * \param pred A callback function that returns whether the current row should be + * partitioned to the left node; it should accept the row index as + * input and return a boolean value. + */ + template + void PartitionRange(const size_t node_in_set, const size_t nid, common::Range1d range, + bst_feature_t fidx, common::RowSetCollection* p_row_set_collection, + Pred pred) { + auto& row_set_collection = *p_row_set_collection; + const size_t* p_ridx = row_set_collection[nid].begin; + common::Span ridx(p_ridx + range.begin(), p_ridx + range.end()); + common::Span left = this->GetLeftBuffer(node_in_set, range.begin(), range.end()); + common::Span right = this->GetRightBuffer(node_in_set, range.begin(), range.end()); + std::pair child_nodes_sizes = PartitionRangeKernel(ridx, left, right, pred); + + const size_t n_left = child_nodes_sizes.first; + const size_t n_right = child_nodes_sizes.second; + + this->SetNLeftElems(node_in_set, range.begin(), range.end(), n_left); + this->SetNRightElems(node_in_set, range.begin(), range.end(), n_right); + } // allocate thread local memory, should be called for each specific task void AllocateForTask(size_t id) { diff --git a/src/tree/updater_approx.h b/src/tree/updater_approx.h new file mode 100644 index 000000000..5e16f568f --- /dev/null +++ b/src/tree/updater_approx.h @@ -0,0 +1,146 @@ +/*! + * Copyright 2021 XGBoost contributors + * + * \brief Implementation for the approx tree method. 
+ */ +#ifndef XGBOOST_TREE_UPDATER_APPROX_H_ +#define XGBOOST_TREE_UPDATER_APPROX_H_ + +#include +#include +#include + +#include "../common/partition_builder.h" +#include "../common/random.h" +#include "constraints.h" +#include "driver.h" +#include "hist/evaluate_splits.h" +#include "hist/expand_entry.h" +#include "hist/param.h" +#include "param.h" +#include "xgboost/json.h" +#include "xgboost/tree_updater.h" + +namespace xgboost { +namespace tree { +class ApproxRowPartitioner { + static constexpr size_t kPartitionBlockSize = 2048; + common::PartitionBuilder partition_builder_; + common::RowSetCollection row_set_collection_; + + public: + bst_row_t base_rowid = 0; + + static auto SearchCutValue(bst_row_t ridx, bst_feature_t fidx, GHistIndexMatrix const &index, + std::vector const &cut_ptrs, + std::vector const &cut_values) { + int32_t gidx = -1; + auto const &row_ptr = index.row_ptr; + auto get_rid = [&](size_t ridx) { return row_ptr[ridx - index.base_rowid]; }; + + if (index.IsDense()) { + gidx = index.index[get_rid(ridx) + fidx]; + } else { + auto begin = get_rid(ridx); + auto end = get_rid(ridx + 1); + auto f_begin = cut_ptrs[fidx]; + auto f_end = cut_ptrs[fidx + 1]; + gidx = common::BinarySearchBin(begin, end, index.index, f_begin, f_end); + } + if (gidx == -1) { + return std::numeric_limits::quiet_NaN(); + } + return cut_values[gidx]; + } + + public: + void UpdatePosition(GenericParameter const *ctx, GHistIndexMatrix const &index, + std::vector const &candidates, RegTree const *p_tree) { + size_t n_nodes = candidates.size(); + + auto const &cut_values = index.cut.Values(); + auto const &cut_ptrs = index.cut.Ptrs(); + + common::BlockedSpace2d space{n_nodes, + [&](size_t node_in_set) { + auto candidate = candidates[node_in_set]; + int32_t nid = candidate.nid; + return row_set_collection_[nid].Size(); + }, + kPartitionBlockSize}; + partition_builder_.Init(space.Size(), n_nodes, [&](size_t node_in_set) { + auto candidate = candidates[node_in_set]; + const int32_t 
nid = candidate.nid; + const size_t size = row_set_collection_[nid].Size(); + const size_t n_tasks = size / kPartitionBlockSize + !!(size % kPartitionBlockSize); + return n_tasks; + }); + auto node_ptr = p_tree->GetCategoriesMatrix().node_ptr; + auto categories = p_tree->GetCategoriesMatrix().categories; + common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) { + auto candidate = candidates[node_in_set]; + auto is_cat = candidate.split.is_cat; + const int32_t nid = candidate.nid; + auto fidx = candidate.split.SplitIndex(); + const size_t task_id = partition_builder_.GetTaskIdx(node_in_set, r.begin()); + partition_builder_.AllocateForTask(task_id); + partition_builder_.PartitionRange( + node_in_set, nid, r, fidx, &row_set_collection_, [&](size_t row_id) { + auto cut_value = SearchCutValue(row_id, fidx, index, cut_ptrs, cut_values); + if (std::isnan(cut_value)) { + return candidate.split.DefaultLeft(); + } + bst_node_t nidx = candidate.nid; + auto segment = node_ptr[nidx]; + auto node_cats = categories.subspan(segment.beg, segment.size); + bool go_left = true; + if (is_cat) { + go_left = common::Decision(node_cats, common::AsCat(cut_value)); + } else { + go_left = cut_value <= candidate.split.split_value; + } + return go_left; + }); + }); + + partition_builder_.CalculateRowOffsets(); + common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) { + auto candidate = candidates[node_in_set]; + const int32_t nid = candidate.nid; + partition_builder_.MergeToArray(node_in_set, r.begin(), + const_cast(row_set_collection_[nid].begin)); + }); + for (size_t i = 0; i < candidates.size(); ++i) { + auto const &candidate = candidates[i]; + auto nidx = candidate.nid; + auto n_left = partition_builder_.GetNLeftElems(i); + auto n_right = partition_builder_.GetNRightElems(i); + CHECK_EQ(n_left + n_right, row_set_collection_[nidx].Size()); + bst_node_t left_nidx = (*p_tree)[nidx].LeftChild(); + bst_node_t right_nidx = 
(*p_tree)[nidx].RightChild(); + row_set_collection_.AddSplit(nidx, left_nidx, right_nidx, n_left, n_right); + } + } + + auto const &Partitions() const { return row_set_collection_; } + + auto operator[](bst_node_t nidx) { return row_set_collection_[nidx]; } + auto const &operator[](bst_node_t nidx) const { return row_set_collection_[nidx]; } + + size_t Size() const { + return std::distance(row_set_collection_.begin(), row_set_collection_.end()); + } + + ApproxRowPartitioner() = default; + explicit ApproxRowPartitioner(bst_row_t num_row, bst_row_t _base_rowid) + : base_rowid{_base_rowid} { + row_set_collection_.Clear(); + auto p_positions = row_set_collection_.Data(); + p_positions->resize(num_row); + std::iota(p_positions->begin(), p_positions->end(), base_rowid); + row_set_collection_.Init(); + } +}; +} // namespace tree +} // namespace xgboost +#endif // XGBOOST_TREE_UPDATER_APPROX_H_ diff --git a/tests/cpp/tree/test_approx.cc b/tests/cpp/tree/test_approx.cc new file mode 100644 index 000000000..680ed9d4b --- /dev/null +++ b/tests/cpp/tree/test_approx.cc @@ -0,0 +1,76 @@ +/*! 
+ * Copyright 2021 XGBoost contributors + */ +#include + +#include "../../../src/tree/updater_approx.h" +#include "../helpers.h" + +namespace xgboost { +namespace tree { +TEST(Approx, Partitioner) { + size_t n_samples = 1024, n_features = 1, base_rowid = 0; + ApproxRowPartitioner partitioner{n_samples, base_rowid}; + ASSERT_EQ(partitioner.base_rowid, base_rowid); + ASSERT_EQ(partitioner.Size(), 1); + ASSERT_EQ(partitioner.Partitions()[0].Size(), n_samples); + + auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true); + GenericParameter ctx; + ctx.InitAllowUnknown(Args{}); + std::vector candidates{{0, 0, 0.4}}; + + for (auto const &page : Xy->GetBatches({GenericParameter::kCpuId, 64})) { + bst_feature_t split_ind = 0; + { + auto min_value = page.cut.MinValues()[split_ind]; + RegTree tree; + tree.ExpandNode( + /*nid=*/0, /*split_index=*/0, /*split_value=*/min_value, + /*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + /*left_sum=*/0.0f, + /*right_sum=*/0.0f); + ApproxRowPartitioner partitioner{n_samples, base_rowid}; + candidates.front().split.split_value = min_value; + candidates.front().split.sindex = 0; + candidates.front().split.sindex |= (1U << 31); + partitioner.UpdatePosition(&ctx, page, candidates, &tree); + ASSERT_EQ(partitioner.Size(), 3); + ASSERT_EQ(partitioner[1].Size(), 0); + ASSERT_EQ(partitioner[2].Size(), n_samples); + } + { + ApproxRowPartitioner partitioner{n_samples, base_rowid}; + auto ptr = page.cut.Ptrs()[split_ind + 1]; + float split_value = page.cut.Values().at(ptr / 2); + RegTree tree; + tree.ExpandNode( + /*nid=*/RegTree::kRoot, /*split_index=*/split_ind, + /*split_value=*/split_value, + /*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + /*left_sum=*/0.0f, + /*right_sum=*/0.0f); + auto left_nidx = tree[RegTree::kRoot].LeftChild(); + candidates.front().split.split_value = split_value; + candidates.front().split.sindex = 0; + candidates.front().split.sindex |= (1U << 31); + partitioner.UpdatePosition(&ctx, page, 
candidates, &tree); + + auto elem = partitioner[left_nidx]; + ASSERT_LT(elem.Size(), n_samples); + ASSERT_GT(elem.Size(), 1); + for (auto it = elem.begin; it != elem.end; ++it) { + auto value = page.cut.Values().at(page.index[*it]); + ASSERT_LE(value, split_value); + } + auto right_nidx = tree[RegTree::kRoot].RightChild(); + elem = partitioner[right_nidx]; + for (auto it = elem.begin; it != elem.end; ++it) { + auto value = page.cut.Values().at(page.index[*it]); + ASSERT_GT(value, split_value) << *it; + } + } + } +} +} // namespace tree +} // namespace xgboost