diff --git a/src/common/column_matrix.h b/src/common/column_matrix.h index fe999ef80..f121b7a46 100644 --- a/src/common/column_matrix.h +++ b/src/common/column_matrix.h @@ -103,15 +103,18 @@ class SparseColumnIter : public Column { template class DenseColumnIter : public Column { + public: + using ByteType = bool; + private: using Base = Column; /* flags for missing values in dense columns */ - std::vector const& missing_flags_; + std::vector const& missing_flags_; size_t feature_offset_; public: explicit DenseColumnIter(common::Span index, bst_bin_t index_base, - std::vector const& missing_flags, size_t feature_offset) + std::vector const& missing_flags, size_t feature_offset) : Base{index, index_base}, missing_flags_{missing_flags}, feature_offset_{feature_offset} {} DenseColumnIter(DenseColumnIter const&) = delete; DenseColumnIter(DenseColumnIter&&) = default; @@ -153,6 +156,7 @@ class ColumnMatrix { } public: + using ByteType = bool; // get number of features bst_feature_t GetNumFeature() const { return static_cast(type_.size()); } @@ -195,6 +199,8 @@ class ColumnMatrix { } } + bool IsInitialized() const { return !type_.empty(); } + /** * \brief Push batch of data for Quantile DMatrix support. 
* @@ -352,6 +358,13 @@ class ColumnMatrix { fi->Read(&row_ind_); fi->Read(&feature_offsets_); + + std::vector missing; + fi->Read(&missing); + missing_flags_.resize(missing.size()); + std::transform(missing.cbegin(), missing.cend(), missing_flags_.begin(), + [](std::uint8_t flag) { return !!flag; }); + index_base_ = index_base; #if !DMLC_LITTLE_ENDIAN std::underlying_type::type v; @@ -386,6 +399,11 @@ class ColumnMatrix { #endif // !DMLC_LITTLE_ENDIAN write_vec(row_ind_); write_vec(feature_offsets_); + // dmlc can not handle bool vector + std::vector missing(missing_flags_.size()); + std::transform(missing_flags_.cbegin(), missing_flags_.cend(), missing.begin(), + [](bool flag) { return static_cast(flag); }); + write_vec(missing); #if !DMLC_LITTLE_ENDIAN auto v = static_cast::type>(bins_type_size_); @@ -413,7 +431,7 @@ class ColumnMatrix { // index_base_[fid]: least bin id for feature fid uint32_t const* index_base_; - std::vector missing_flags_; + std::vector missing_flags_; BinTypeSize bins_type_size_; bool any_missing_; }; diff --git a/src/common/numeric.h b/src/common/numeric.h index cff3e8a12..e839c7119 100644 --- a/src/common/numeric.h +++ b/src/common/numeric.h @@ -4,6 +4,8 @@ #ifndef XGBOOST_COMMON_NUMERIC_H_ #define XGBOOST_COMMON_NUMERIC_H_ +#include // OMPException + #include // std::max #include // std::iterator_traits #include @@ -106,6 +108,26 @@ inline double Reduce(Context const*, HostDeviceVector const&) { * \brief Reduction with summation. 
*/ double Reduce(Context const* ctx, HostDeviceVector const& values); + +template +void Iota(Context const* ctx, It first, It last, + typename std::iterator_traits::value_type const& value) { + auto n = std::distance(first, last); + std::int32_t n_threads = ctx->Threads(); + const size_t block_size = n / n_threads + !!(n % n_threads); + dmlc::OMPException exc; +#pragma omp parallel num_threads(n_threads) + { + exc.Run([&]() { + const size_t tid = omp_get_thread_num(); + const size_t ibegin = tid * block_size; + const size_t iend = std::min(ibegin + block_size, static_cast(n)); + for (size_t i = ibegin; i < iend; ++i) { + first[i] = i + value; + } + }); + } +} } // namespace common } // namespace xgboost diff --git a/src/common/partition_builder.h b/src/common/partition_builder.h index b19dd0ae7..34864ee90 100644 --- a/src/common/partition_builder.h +++ b/src/common/partition_builder.h @@ -17,6 +17,7 @@ #include "categorical.h" #include "column_matrix.h" +#include "../tree/hist/expand_entry.h" #include "xgboost/generic_parameters.h" #include "xgboost/tree_model.h" @@ -107,14 +108,17 @@ class PartitionBuilder { } template - void Partition(const size_t node_in_set, const size_t nid, const common::Range1d range, + void Partition(const size_t node_in_set, std::vector const &nodes, + const common::Range1d range, const bst_bin_t split_cond, GHistIndexMatrix const& gmat, - const ColumnMatrix& column_matrix, const RegTree& tree, const size_t* rid) { + const common::ColumnMatrix& column_matrix, + const RegTree& tree, const size_t* rid) { common::Span rid_span(rid + range.begin(), rid + range.end()); common::Span left = GetLeftBuffer(node_in_set, range.begin(), range.end()); common::Span right = GetRightBuffer(node_in_set, range.begin(), range.end()); - const bst_uint fid = tree[nid].SplitIndex(); - const bool default_left = tree[nid].DefaultLeft(); + std::size_t nid = nodes[node_in_set].nid; + bst_feature_t fid = tree[nid].SplitIndex(); + bool default_left = 
tree[nid].DefaultLeft(); bool is_cat = tree.GetSplitTypes()[nid] == FeatureType::kCategorical; auto node_cats = tree.NodeCats(nid); @@ -122,19 +126,24 @@ class PartitionBuilder { auto const& cut_values = gmat.cut.Values(); auto const& cut_ptrs = gmat.cut.Ptrs(); - auto pred = [&](auto ridx, auto bin_id) { + auto gidx_calc = [&](auto ridx) { + auto begin = gmat.RowIdx(ridx); + if (gmat.IsDense()) { + return static_cast(index[begin + fid]); + } + auto end = gmat.RowIdx(ridx + 1); + auto f_begin = cut_ptrs[fid]; + auto f_end = cut_ptrs[fid + 1]; + // bypassing the column matrix as we need the cut value instead of bin idx for categorical + // features. + return BinarySearchBin(begin, end, index, f_begin, f_end); + }; + + auto pred_hist = [&](auto ridx, auto bin_id) { if (any_cat && is_cat) { - auto begin = gmat.RowIdx(ridx); - auto end = gmat.RowIdx(ridx + 1); - auto f_begin = cut_ptrs[fid]; - auto f_end = cut_ptrs[fid + 1]; - // bypassing the column matrix as we need the cut value instead of bin idx for categorical - // features. 
- auto gidx = BinarySearchBin(begin, end, index, f_begin, f_end); - bool go_left; - if (gidx == -1) { - go_left = default_left; - } else { + auto gidx = gidx_calc(ridx); + bool go_left = default_left; + if (gidx > -1) { go_left = Decision(node_cats, cut_values[gidx], default_left); } return go_left; @@ -143,25 +152,43 @@ class PartitionBuilder { } }; - std::pair child_nodes_sizes; - if (column_matrix.GetColumnType(fid) == xgboost::common::kDenseColumn) { - auto column = column_matrix.DenseColumn(fid); - if (default_left) { - child_nodes_sizes = PartitionKernel(&column, rid_span, left, right, - gmat.base_rowid, pred); - } else { - child_nodes_sizes = PartitionKernel(&column, rid_span, left, right, - gmat.base_rowid, pred); + auto pred_approx = [&](auto ridx) { + auto gidx = gidx_calc(ridx); + bool go_left = default_left; + if (gidx > -1) { + if (is_cat) { + go_left = Decision(node_cats, cut_values[gidx], default_left); + } else { + go_left = cut_values[gidx] <= nodes[node_in_set].split.split_value; + } } + return go_left; + }; + + std::pair child_nodes_sizes; + if (!column_matrix.IsInitialized()) { + child_nodes_sizes = PartitionRangeKernel(rid_span, left, right, pred_approx); } else { - CHECK_EQ(any_missing, true); - auto column = column_matrix.SparseColumn(fid, rid_span.front() - gmat.base_rowid); - if (default_left) { - child_nodes_sizes = PartitionKernel(&column, rid_span, left, right, - gmat.base_rowid, pred); + if (column_matrix.GetColumnType(fid) == xgboost::common::kDenseColumn) { + auto column = column_matrix.DenseColumn(fid); + if (default_left) { + child_nodes_sizes = PartitionKernel(&column, rid_span, left, right, + gmat.base_rowid, pred_hist); + } else { + child_nodes_sizes = PartitionKernel(&column, rid_span, left, right, + gmat.base_rowid, pred_hist); + } } else { - child_nodes_sizes = PartitionKernel(&column, rid_span, left, right, - gmat.base_rowid, pred); + CHECK_EQ(any_missing, true); + auto column = + column_matrix.SparseColumn(fid, 
rid_span.front() - gmat.base_rowid); + if (default_left) { + child_nodes_sizes = PartitionKernel(&column, rid_span, left, right, + gmat.base_rowid, pred_hist); + } else { + child_nodes_sizes = PartitionKernel(&column, rid_span, left, right, + gmat.base_rowid, pred_hist); + } } } @@ -172,37 +199,6 @@ class PartitionBuilder { SetNRightElems(node_in_set, range.begin(), n_right); } - /** - * \brief Partition tree nodes with specific range of row indices. - * - * \tparam Pred Predicate for whether a row should be partitioned to the left node. - * - * \param node_in_set The index of node in current batch of nodes. - * \param nid The canonical node index (node index in the tree). - * \param range The range of input row index. - * \param fidx Feature index. - * \param p_row_set_collection Pointer to rows that are being partitioned. - * \param pred A callback function that returns whether current row should be - * partitioned to the left node, it should accept the row index as - * input and returns a boolean value. 
- */ - template - void PartitionRange(const size_t node_in_set, const size_t nid, common::Range1d range, - common::RowSetCollection* p_row_set_collection, Pred pred) { - auto& row_set_collection = *p_row_set_collection; - const size_t* p_ridx = row_set_collection[nid].begin; - common::Span ridx(p_ridx + range.begin(), p_ridx + range.end()); - common::Span left = this->GetLeftBuffer(node_in_set, range.begin(), range.end()); - common::Span right = this->GetRightBuffer(node_in_set, range.begin(), range.end()); - std::pair child_nodes_sizes = PartitionRangeKernel(ridx, left, right, pred); - - const size_t n_left = child_nodes_sizes.first; - const size_t n_right = child_nodes_sizes.second; - - this->SetNLeftElems(node_in_set, range.begin(), n_left); - this->SetNRightElems(node_in_set, range.begin(), n_right); - } - // allocate thread local memory, should be called for each specific task void AllocateForTask(size_t id) { if (mem_blocks_[id].get() == nullptr) { diff --git a/src/tree/common_row_partitioner.h b/src/tree/common_row_partitioner.h new file mode 100644 index 000000000..949948856 --- /dev/null +++ b/src/tree/common_row_partitioner.h @@ -0,0 +1,212 @@ +/*! + * Copyright 2021-2022 XGBoost contributors + * \file common_row_partitioner.h + * \brief Common partitioner logic for hist and approx methods. 
+ */ +#ifndef XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_ +#define XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_ + +#include // std::numeric_limits +#include + +#include "../common/numeric.h" // Iota +#include "../common/partition_builder.h" +#include "hist/expand_entry.h" // CPUExpandEntry +#include "xgboost/generic_parameters.h" // Context + +namespace xgboost { +namespace tree { +class CommonRowPartitioner { + static constexpr size_t kPartitionBlockSize = 2048; + common::PartitionBuilder partition_builder_; + common::RowSetCollection row_set_collection_; + + public: + bst_row_t base_rowid = 0; + + CommonRowPartitioner() = default; + CommonRowPartitioner(Context const* ctx, bst_row_t num_row, bst_row_t _base_rowid) + : base_rowid{_base_rowid} { + row_set_collection_.Clear(); + std::vector& row_indices = *row_set_collection_.Data(); + row_indices.resize(num_row); + + std::size_t* p_row_indices = row_indices.data(); + common::Iota(ctx, p_row_indices, p_row_indices + row_indices.size(), base_rowid); + row_set_collection_.Init(); + } + + void FindSplitConditions(const std::vector& nodes, const RegTree& tree, + const GHistIndexMatrix& gmat, std::vector* split_conditions) { + for (size_t i = 0; i < nodes.size(); ++i) { + const int32_t nid = nodes[i].nid; + const bst_uint fid = tree[nid].SplitIndex(); + const bst_float split_pt = tree[nid].SplitCond(); + const uint32_t lower_bound = gmat.cut.Ptrs()[fid]; + const uint32_t upper_bound = gmat.cut.Ptrs()[fid + 1]; + bst_bin_t split_cond = -1; + // convert floating-point split_pt into corresponding bin_id + // split_cond = -1 indicates that split_pt is less than all known cut points + CHECK_LT(upper_bound, static_cast(std::numeric_limits::max())); + for (auto bound = lower_bound; bound < upper_bound; ++bound) { + if (split_pt == gmat.cut.Values()[bound]) { + split_cond = static_cast(bound); + } + } + (*split_conditions).at(i) = split_cond; + } + } + + void AddSplitsToRowSet(const std::vector& nodes, RegTree const* p_tree) { + const 
size_t n_nodes = nodes.size(); + for (unsigned int i = 0; i < n_nodes; ++i) { + const int32_t nid = nodes[i].nid; + const size_t n_left = partition_builder_.GetNLeftElems(i); + const size_t n_right = partition_builder_.GetNRightElems(i); + CHECK_EQ((*p_tree)[nid].LeftChild() + 1, (*p_tree)[nid].RightChild()); + row_set_collection_.AddSplit(nid, (*p_tree)[nid].LeftChild(), (*p_tree)[nid].RightChild(), + n_left, n_right); + } + } + + void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat, + std::vector const& nodes, RegTree const* p_tree) { + auto const& column_matrix = gmat.Transpose(); + if (column_matrix.IsInitialized()) { + if (gmat.cut.HasCategorical()) { + this->template UpdatePosition(ctx, gmat, column_matrix, nodes, p_tree); + } else { + this->template UpdatePosition(ctx, gmat, column_matrix, nodes, p_tree); + } + } else { + /* ColumnMatrix is not initialized. + * It means that we use 'approx' method. + * any_missing and any_cat don't matter in this case. + * Jump directly to the main method. 
+ */ + this->template UpdatePosition(ctx, gmat, column_matrix, nodes, p_tree); + } + } + + template + void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat, + const common::ColumnMatrix& column_matrix, + std::vector const& nodes, RegTree const* p_tree) { + if (column_matrix.AnyMissing()) { + this->template UpdatePosition(ctx, gmat, column_matrix, nodes, p_tree); + } else { + this->template UpdatePosition(ctx, gmat, column_matrix, nodes, p_tree); + } + } + + template + void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat, + const common::ColumnMatrix& column_matrix, + std::vector const& nodes, RegTree const* p_tree) { + switch (column_matrix.GetTypeSize()) { + case common::kUint8BinsTypeSize: + this->template UpdatePosition(ctx, gmat, column_matrix, + nodes, p_tree); + break; + case common::kUint16BinsTypeSize: + this->template UpdatePosition(ctx, gmat, column_matrix, + nodes, p_tree); + break; + case common::kUint32BinsTypeSize: + this->template UpdatePosition(ctx, gmat, column_matrix, + nodes, p_tree); + break; + default: + // no default behavior + CHECK(false) << column_matrix.GetTypeSize(); + } + } + + template + void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat, + const common::ColumnMatrix& column_matrix, + std::vector const& nodes, RegTree const* p_tree) { + // 1. 
Find split condition for each split + size_t n_nodes = nodes.size(); + + std::vector split_conditions; + if (column_matrix.IsInitialized()) { + split_conditions.resize(n_nodes); + FindSplitConditions(nodes, *p_tree, gmat, &split_conditions); + } + + // 2.1 Create a blocked space of size SUM(samples in each node) + common::BlockedSpace2d space( + n_nodes, + [&](size_t node_in_set) { + int32_t nid = nodes[node_in_set].nid; + return row_set_collection_[nid].Size(); + }, + kPartitionBlockSize); + + // 2.2 Initialize the partition builder + // allocate buffers for storage intermediate results by each thread + partition_builder_.Init(space.Size(), n_nodes, [&](size_t node_in_set) { + const int32_t nid = nodes[node_in_set].nid; + const size_t size = row_set_collection_[nid].Size(); + const size_t n_tasks = size / kPartitionBlockSize + !!(size % kPartitionBlockSize); + return n_tasks; + }); + CHECK_EQ(base_rowid, gmat.base_rowid); + + // 2.3 Split elements of row_set_collection_ to left and right child-nodes for each node + // Store results in intermediate buffers from partition_builder_ + common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) { + size_t begin = r.begin(); + const int32_t nid = nodes[node_in_set].nid; + const size_t task_id = partition_builder_.GetTaskIdx(node_in_set, begin); + partition_builder_.AllocateForTask(task_id); + bst_bin_t split_cond = column_matrix.IsInitialized() ? split_conditions[node_in_set] : 0; + partition_builder_.template Partition( + node_in_set, nodes, r, split_cond, gmat, column_matrix, *p_tree, + row_set_collection_[nid].begin); + }); + + // 3. Compute offsets to copy blocks of row-indexes + // from partition_builder_ to row_set_collection_ + partition_builder_.CalculateRowOffsets(); + + // 4. 
Copy elements from partition_builder_ to row_set_collection_ back + // with updated row-indexes for each tree-node + common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) { + const int32_t nid = nodes[node_in_set].nid; + partition_builder_.MergeToArray(node_in_set, r.begin(), + const_cast(row_set_collection_[nid].begin)); + }); + + // 5. Add info about splits into row_set_collection_ + AddSplitsToRowSet(nodes, p_tree); + } + + auto const& Partitions() const { return row_set_collection_; } + + size_t Size() const { + return std::distance(row_set_collection_.begin(), row_set_collection_.end()); + } + + auto& operator[](bst_node_t nidx) { return row_set_collection_[nidx]; } + auto const& operator[](bst_node_t nidx) const { return row_set_collection_[nidx]; } + + void LeafPartition(Context const* ctx, RegTree const& tree, common::Span hess, + std::vector* p_out_position) const { + partition_builder_.LeafPartition(ctx, tree, this->Partitions(), p_out_position, + [&](size_t idx) -> bool { return hess[idx] - .0f == .0f; }); + } + + void LeafPartition(Context const* ctx, RegTree const& tree, + common::Span gpair, + std::vector* p_out_position) const { + partition_builder_.LeafPartition( + ctx, tree, this->Partitions(), p_out_position, + [&](size_t idx) -> bool { return gpair[idx].GetHess() - .0f == .0f; }); + } +}; + +} // namespace tree +} // namespace xgboost +#endif // XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_ diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc index 5b56eaa52..734138da5 100644 --- a/src/tree/updater_approx.cc +++ b/src/tree/updater_approx.cc @@ -3,14 +3,13 @@ * * \brief Implementation for the approx tree method. 
*/ -#include "updater_approx.h" - #include #include #include #include "../common/random.h" #include "../data/gradient_index.h" +#include "common_row_partitioner.h" #include "constraints.h" #include "driver.h" #include "hist/evaluate_splits.h" @@ -46,7 +45,7 @@ class GloablApproxBuilder { Context const *ctx_; ObjInfo const task_; - std::vector partitioner_; + std::vector partitioner_; // Pointer to last updated tree, used for update prediction cache. RegTree *p_last_tree_{nullptr}; common::Monitor *monitor_; @@ -69,7 +68,7 @@ class GloablApproxBuilder { } else { CHECK_EQ(n_total_bins, page.cut.TotalBins()); } - partitioner_.emplace_back(page.Size(), page.base_rowid); + partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid); n_batches_++; } @@ -151,7 +150,7 @@ class GloablApproxBuilder { monitor_->Stop(__func__); } - void LeafPartition(RegTree const &tree, common::Span hess, + void LeafPartition(RegTree const &tree, common::Span hess, std::vector *p_out_position) { monitor_->Start(__func__); if (!task_.UpdateTreeLeaf()) { diff --git a/src/tree/updater_approx.h b/src/tree/updater_approx.h deleted file mode 100644 index 3be12fc57..000000000 --- a/src/tree/updater_approx.h +++ /dev/null @@ -1,150 +0,0 @@ -/*! - * Copyright 2021-2022 XGBoost contributors - * - * \brief Implementation for the approx tree method. 
- */ -#ifndef XGBOOST_TREE_UPDATER_APPROX_H_ -#define XGBOOST_TREE_UPDATER_APPROX_H_ - -#include -#include -#include - -#include "../common/partition_builder.h" -#include "../common/random.h" -#include "constraints.h" -#include "driver.h" -#include "hist/evaluate_splits.h" -#include "hist/expand_entry.h" -#include "param.h" -#include "xgboost/generic_parameters.h" -#include "xgboost/json.h" -#include "xgboost/tree_updater.h" - -namespace xgboost { -namespace tree { -class ApproxRowPartitioner { - static constexpr size_t kPartitionBlockSize = 2048; - common::PartitionBuilder partition_builder_; - common::RowSetCollection row_set_collection_; - - public: - bst_row_t base_rowid = 0; - - static auto SearchCutValue(bst_row_t ridx, bst_feature_t fidx, GHistIndexMatrix const &index, - std::vector const &cut_ptrs, - std::vector const &cut_values) { - int32_t gidx = -1; - if (index.IsDense()) { - // RowIdx returns the starting pos of this row - gidx = index.index[index.RowIdx(ridx) + fidx]; - } else { - auto begin = index.RowIdx(ridx); - auto end = index.RowIdx(ridx + 1); - auto f_begin = cut_ptrs[fidx]; - auto f_end = cut_ptrs[fidx + 1]; - gidx = common::BinarySearchBin(begin, end, index.index, f_begin, f_end); - } - if (gidx == -1) { - return std::numeric_limits::quiet_NaN(); - } - return cut_values[gidx]; - } - - public: - void UpdatePosition(GenericParameter const *ctx, GHistIndexMatrix const &index, - std::vector const &candidates, RegTree const *p_tree) { - size_t n_nodes = candidates.size(); - - auto const &cut_values = index.cut.Values(); - auto const &cut_ptrs = index.cut.Ptrs(); - - common::BlockedSpace2d space{n_nodes, - [&](size_t node_in_set) { - auto candidate = candidates[node_in_set]; - int32_t nid = candidate.nid; - return row_set_collection_[nid].Size(); - }, - kPartitionBlockSize}; - partition_builder_.Init(space.Size(), n_nodes, [&](size_t node_in_set) { - auto candidate = candidates[node_in_set]; - const int32_t nid = candidate.nid; - const size_t size 
= row_set_collection_[nid].Size(); - const size_t n_tasks = size / kPartitionBlockSize + !!(size % kPartitionBlockSize); - return n_tasks; - }); - auto node_ptr = p_tree->GetCategoriesMatrix().node_ptr; - auto categories = p_tree->GetCategoriesMatrix().categories; - common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) { - auto candidate = candidates[node_in_set]; - auto is_cat = candidate.split.is_cat; - const int32_t nid = candidate.nid; - auto fidx = candidate.split.SplitIndex(); - const size_t task_id = partition_builder_.GetTaskIdx(node_in_set, r.begin()); - partition_builder_.AllocateForTask(task_id); - partition_builder_.PartitionRange( - node_in_set, nid, r, &row_set_collection_, [&](size_t row_id) { - auto cut_value = SearchCutValue(row_id, fidx, index, cut_ptrs, cut_values); - if (std::isnan(cut_value)) { - return candidate.split.DefaultLeft(); - } - bst_node_t nidx = candidate.nid; - auto segment = node_ptr[nidx]; - auto node_cats = categories.subspan(segment.beg, segment.size); - bool go_left = true; - if (is_cat) { - go_left = common::Decision(node_cats, cut_value, candidate.split.DefaultLeft()); - } else { - go_left = cut_value <= candidate.split.split_value; - } - return go_left; - }); - }); - - partition_builder_.CalculateRowOffsets(); - common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) { - auto candidate = candidates[node_in_set]; - const int32_t nid = candidate.nid; - partition_builder_.MergeToArray(node_in_set, r.begin(), - const_cast(row_set_collection_[nid].begin)); - }); - for (size_t i = 0; i < candidates.size(); ++i) { - auto const &candidate = candidates[i]; - auto nidx = candidate.nid; - auto n_left = partition_builder_.GetNLeftElems(i); - auto n_right = partition_builder_.GetNRightElems(i); - CHECK_EQ(n_left + n_right, row_set_collection_[nidx].Size()); - bst_node_t left_nidx = (*p_tree)[nidx].LeftChild(); - bst_node_t right_nidx = (*p_tree)[nidx].RightChild(); - 
row_set_collection_.AddSplit(nidx, left_nidx, right_nidx, n_left, n_right); - } - } - - auto const &Partitions() const { return row_set_collection_; } - - void LeafPartition(Context const *ctx, RegTree const &tree, common::Span hess, - std::vector *p_out_position) const { - partition_builder_.LeafPartition(ctx, tree, this->Partitions(), p_out_position, - [&](size_t idx) -> bool { return hess[idx] - .0f == .0f; }); - } - - auto operator[](bst_node_t nidx) { return row_set_collection_[nidx]; } - auto const &operator[](bst_node_t nidx) const { return row_set_collection_[nidx]; } - - size_t Size() const { - return std::distance(row_set_collection_.begin(), row_set_collection_.end()); - } - - ApproxRowPartitioner() = default; - explicit ApproxRowPartitioner(bst_row_t num_row, bst_row_t _base_rowid) - : base_rowid{_base_rowid} { - row_set_collection_.Clear(); - auto p_positions = row_set_collection_.Data(); - p_positions->resize(num_row); - std::iota(p_positions->begin(), p_positions->end(), base_rowid); - row_set_collection_.Init(); - } -}; -} // namespace tree -} // namespace xgboost -#endif // XGBOOST_TREE_UPDATER_APPROX_H_ diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index cd6345619..1e9d76d4f 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -12,7 +12,9 @@ #include #include +#include "common_row_partitioner.h" #include "constraints.h" +#include "hist/histogram.h" #include "hist/evaluate_splits.h" #include "param.h" #include "xgboost/logging.h" @@ -309,7 +311,7 @@ void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree, } else { CHECK_EQ(n_total_bins, page.cut.TotalBins()); } - partitioner_.emplace_back(page.Size(), page.base_rowid, this->ctx_->Threads()); + partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid); ++page_id; } histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id, @@ -331,44 +333,6 @@ void 
QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree, monitor_->Stop(__func__); } -void HistRowPartitioner::FindSplitConditions(const std::vector &nodes, - const RegTree &tree, const GHistIndexMatrix &gmat, - std::vector *split_conditions) { - const size_t n_nodes = nodes.size(); - split_conditions->resize(n_nodes); - - for (size_t i = 0; i < nodes.size(); ++i) { - const int32_t nid = nodes[i].nid; - const bst_uint fid = tree[nid].SplitIndex(); - const bst_float split_pt = tree[nid].SplitCond(); - const uint32_t lower_bound = gmat.cut.Ptrs()[fid]; - const uint32_t upper_bound = gmat.cut.Ptrs()[fid + 1]; - bst_bin_t split_cond = -1; - // convert floating-point split_pt into corresponding bin_id - // split_cond = -1 indicates that split_pt is less than all known cut points - CHECK_LT(upper_bound, static_cast(std::numeric_limits::max())); - for (auto bound = lower_bound; bound < upper_bound; ++bound) { - if (split_pt == gmat.cut.Values()[bound]) { - split_cond = static_cast(bound); - } - } - (*split_conditions)[i] = split_cond; - } -} - -void HistRowPartitioner::AddSplitsToRowSet(const std::vector &nodes, - RegTree const *p_tree) { - const size_t n_nodes = nodes.size(); - for (unsigned int i = 0; i < n_nodes; ++i) { - const int32_t nid = nodes[i].nid; - const size_t n_left = partition_builder_.GetNLeftElems(i); - const size_t n_right = partition_builder_.GetNRightElems(i); - CHECK_EQ((*p_tree)[nid].LeftChild() + 1, (*p_tree)[nid].RightChild()); - row_set_collection_.AddSplit(nid, (*p_tree)[nid].LeftChild(), (*p_tree)[nid].RightChild(), - n_left, n_right); - } -} - XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker") .describe("Grow tree using quantized histogram.") .set_body([](GenericParameter const *ctx, ObjInfo task) { diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h index 4c939bf7d..29bda34d4 100644 --- a/src/tree/updater_quantile_hist.h +++ b/src/tree/updater_quantile_hist.h @@ -24,6 +24,7 
@@ #include "hist/histogram.h" #include "hist/expand_entry.h" +#include "common_row_partitioner.h" #include "constraints.h" #include "./param.h" #include "./driver.h" @@ -77,155 +78,6 @@ struct RandomReplace { }; namespace tree { -class HistRowPartitioner { - // heuristically chosen block size of parallel partitioning - static constexpr size_t kPartitionBlockSize = 2048; - // worker class that partition a block of rows - common::PartitionBuilder partition_builder_; - // storage for row index - common::RowSetCollection row_set_collection_; - - /** - * \brief Turn split values into discrete bin indices. - */ - static void FindSplitConditions(const std::vector& nodes, const RegTree& tree, - const GHistIndexMatrix& gmat, - std::vector* split_conditions); - /** - * \brief Update the row set for new splits specifed by nodes. - */ - void AddSplitsToRowSet(const std::vector& nodes, RegTree const* p_tree); - - public: - bst_row_t base_rowid = 0; - - public: - HistRowPartitioner(size_t n_samples, size_t base_rowid, int32_t n_threads) { - row_set_collection_.Clear(); - const size_t block_size = n_samples / n_threads + !!(n_samples % n_threads); - dmlc::OMPException exc; - std::vector& row_indices = *row_set_collection_.Data(); - row_indices.resize(n_samples); - size_t* p_row_indices = row_indices.data(); - // parallel initialization o f row indices. (std::iota) -#pragma omp parallel num_threads(n_threads) - { - exc.Run([&]() { - const size_t tid = omp_get_thread_num(); - const size_t ibegin = tid * block_size; - const size_t iend = std::min(static_cast(ibegin + block_size), n_samples); - for (size_t i = ibegin; i < iend; ++i) { - p_row_indices[i] = i + base_rowid; - } - }); - } - row_set_collection_.Init(); - this->base_rowid = base_rowid; - } - - template - void UpdatePosition(GenericParameter const* ctx, GHistIndexMatrix const& gmat, - common::ColumnMatrix const& column_matrix, - std::vector const& nodes, RegTree const* p_tree) { - // 1. 
Find split condition for each split - const size_t n_nodes = nodes.size(); - std::vector split_conditions; - FindSplitConditions(nodes, *p_tree, gmat, &split_conditions); - // 2.1 Create a blocked space of size SUM(samples in each node) - common::BlockedSpace2d space( - n_nodes, - [&](size_t node_in_set) { - int32_t nid = nodes[node_in_set].nid; - return row_set_collection_[nid].Size(); - }, - kPartitionBlockSize); - // 2.2 Initialize the partition builder - // allocate buffers for storage intermediate results by each thread - partition_builder_.Init(space.Size(), n_nodes, [&](size_t node_in_set) { - const int32_t nid = nodes[node_in_set].nid; - const size_t size = row_set_collection_[nid].Size(); - const size_t n_tasks = size / kPartitionBlockSize + !!(size % kPartitionBlockSize); - return n_tasks; - }); - CHECK_EQ(base_rowid, gmat.base_rowid); - // 2.3 Split elements of row_set_collection_ to left and right child-nodes for each node - // Store results in intermediate buffers from partition_builder_ - common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) { - size_t begin = r.begin(); - const int32_t nid = nodes[node_in_set].nid; - const size_t task_id = partition_builder_.GetTaskIdx(node_in_set, begin); - partition_builder_.AllocateForTask(task_id); - switch (column_matrix.GetTypeSize()) { - case common::kUint8BinsTypeSize: - partition_builder_.template Partition( - node_in_set, nid, r, split_conditions[node_in_set], gmat, column_matrix, *p_tree, - row_set_collection_[nid].begin); - break; - case common::kUint16BinsTypeSize: - partition_builder_.template Partition( - node_in_set, nid, r, split_conditions[node_in_set], gmat, column_matrix, *p_tree, - row_set_collection_[nid].begin); - break; - case common::kUint32BinsTypeSize: - partition_builder_.template Partition( - node_in_set, nid, r, split_conditions[node_in_set], gmat, column_matrix, *p_tree, - row_set_collection_[nid].begin); - break; - default: - // no default behavior - 
CHECK(false) << column_matrix.GetTypeSize(); - } - }); - // 3. Compute offsets to copy blocks of row-indexes - // from partition_builder_ to row_set_collection_ - partition_builder_.CalculateRowOffsets(); - - // 4. Copy elements from partition_builder_ to row_set_collection_ back - // with updated row-indexes for each tree-node - common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) { - const int32_t nid = nodes[node_in_set].nid; - partition_builder_.MergeToArray(node_in_set, r.begin(), - const_cast(row_set_collection_[nid].begin)); - }); - // 5. Add info about splits into row_set_collection_ - AddSplitsToRowSet(nodes, p_tree); - } - - void UpdatePosition(GenericParameter const* ctx, GHistIndexMatrix const& page, - std::vector const& applied, RegTree const* p_tree) { - auto const& column_matrix = page.Transpose(); - if (page.cut.HasCategorical()) { - if (column_matrix.AnyMissing()) { - this->template UpdatePosition(ctx, page, column_matrix, applied, p_tree); - } else { - this->template UpdatePosition(ctx, page, column_matrix, applied, p_tree); - } - } else { - if (column_matrix.AnyMissing()) { - this->template UpdatePosition(ctx, page, column_matrix, applied, p_tree); - } else { - this->template UpdatePosition(ctx, page, column_matrix, applied, p_tree); - } - } - } - - auto const& Partitions() const { return row_set_collection_; } - size_t Size() const { - return std::distance(row_set_collection_.begin(), row_set_collection_.end()); - } - - void LeafPartition(Context const* ctx, RegTree const& tree, - common::Span gpair, - std::vector* p_out_position) const { - partition_builder_.LeafPartition( - ctx, tree, this->Partitions(), p_out_position, - [&](size_t idx) -> bool { return gpair[idx].GetHess() - .0f == .0f; }); - } - - auto& operator[](bst_node_t nidx) { return row_set_collection_[nidx]; } - auto const& operator[](bst_node_t nidx) const { return row_set_collection_[nidx]; } -}; - inline BatchParam HistBatch(TrainParam const& 
param) { return {param.max_bin, param.sparse_threshold}; } @@ -314,7 +166,7 @@ class QuantileHistMaker: public TreeUpdater { std::vector gpair_local_; std::unique_ptr> evaluator_; - std::vector partitioner_; + std::vector partitioner_; // back pointers to tree and data matrix const RegTree* p_last_tree_{nullptr}; diff --git a/tests/cpp/tree/hist/test_evaluate_splits.cc b/tests/cpp/tree/hist/test_evaluate_splits.cc index 7e1d285e7..7000240df 100644 --- a/tests/cpp/tree/hist/test_evaluate_splits.cc +++ b/tests/cpp/tree/hist/test_evaluate_splits.cc @@ -5,8 +5,8 @@ #include #include "../../../../src/common/hist_util.h" +#include "../../../../src/tree/common_row_partitioner.h" #include "../../../../src/tree/hist/evaluate_splits.h" -#include "../../../../src/tree/updater_quantile_hist.h" #include "../test_evaluate_splits.h" #include "../../helpers.h" diff --git a/tests/cpp/tree/test_approx.cc b/tests/cpp/tree/test_approx.cc index 0ae23557f..ba8d6f129 100644 --- a/tests/cpp/tree/test_approx.cc +++ b/tests/cpp/tree/test_approx.cc @@ -4,7 +4,7 @@ #include #include "../../../src/common/numeric.h" -#include "../../../src/tree/updater_approx.h" +#include "../../../src/tree/common_row_partitioner.h" #include "../helpers.h" #include "test_partitioner.h" @@ -12,13 +12,13 @@ namespace xgboost { namespace tree { TEST(Approx, Partitioner) { size_t n_samples = 1024, n_features = 1, base_rowid = 0; - ApproxRowPartitioner partitioner{n_samples, base_rowid}; + GenericParameter ctx; + CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid}; ASSERT_EQ(partitioner.base_rowid, base_rowid); ASSERT_EQ(partitioner.Size(), 1); ASSERT_EQ(partitioner.Partitions()[0].Size(), n_samples); auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true); - GenericParameter ctx; ctx.InitAllowUnknown(Args{}); std::vector candidates{{0, 0, 0.4}}; @@ -32,7 +32,7 @@ TEST(Approx, Partitioner) { { auto min_value = page.cut.MinValues()[split_ind]; RegTree tree; - ApproxRowPartitioner 
partitioner{n_samples, base_rowid}; + CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid}; GetSplit(&tree, min_value, &candidates); partitioner.UpdatePosition(&ctx, page, candidates, &tree); ASSERT_EQ(partitioner.Size(), 3); @@ -40,7 +40,7 @@ TEST(Approx, Partitioner) { ASSERT_EQ(partitioner[2].Size(), n_samples); } { - ApproxRowPartitioner partitioner{n_samples, base_rowid}; + CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid}; auto ptr = page.cut.Ptrs()[split_ind + 1]; float split_value = page.cut.Values().at(ptr / 2); RegTree tree; @@ -65,14 +65,15 @@ TEST(Approx, Partitioner) { } } } + namespace { void TestLeafPartition(size_t n_samples) { size_t const n_features = 2, base_rowid = 0; + GenericParameter ctx; common::RowSetCollection row_set; - ApproxRowPartitioner partitioner{n_samples, base_rowid}; + CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid}; auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true); - GenericParameter ctx; std::vector candidates{{0, 0, 0.4}}; RegTree tree; std::vector hess(n_samples, 0); @@ -81,11 +82,9 @@ void TestLeafPartition(size_t n_samples) { size_t const kSampleFactor{3}; return i % kSampleFactor != 0; }; - size_t n{0}; for (size_t i = 0; i < hess.size(); ++i) { if (not_sampled(i)) { hess[i] = 1.0f; - ++n; } } diff --git a/tests/cpp/tree/test_histmaker.cc b/tests/cpp/tree/test_histmaker.cc index 90dc0a411..17dcb4c93 100644 --- a/tests/cpp/tree/test_histmaker.cc +++ b/tests/cpp/tree/test_histmaker.cc @@ -12,8 +12,7 @@ TEST(GrowHistMaker, InteractionConstraint) { size_t constexpr kRows = 32; size_t constexpr kCols = 16; - GenericParameter param; - param.UpdateAllowUnknown(Args{{"gpu_id", "0"}}); + Context ctx; auto p_dmat = RandomDataGenerator{kRows, kCols, 0.6f}.Seed(3).GenerateDMatrix(); @@ -35,7 +34,7 @@ TEST(GrowHistMaker, InteractionConstraint) { tree.param.num_feature = kCols; std::unique_ptr updater{ - TreeUpdater::Create("grow_histmaker", ¶m, 
ObjInfo{ObjInfo::kRegression})}; + TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})}; updater->Configure(Args{ {"interaction_constraints", "[[0, 1]]"}, {"num_feature", std::to_string(kCols)}}); @@ -54,7 +53,7 @@ TEST(GrowHistMaker, InteractionConstraint) { tree.param.num_feature = kCols; std::unique_ptr updater{ - TreeUpdater::Create("grow_histmaker", ¶m, ObjInfo{ObjInfo::kRegression})}; + TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})}; updater->Configure(Args{{"num_feature", std::to_string(kCols)}}); std::vector> position(1); updater->Update(&gradients, p_dmat.get(), position, {&tree}); diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc index f1491b829..222339aae 100644 --- a/tests/cpp/tree/test_quantile_hist.cc +++ b/tests/cpp/tree/test_quantile_hist.cc @@ -11,7 +11,7 @@ #include "../../../src/tree/param.h" #include "../../../src/tree/split_evaluator.h" -#include "../../../src/tree/updater_quantile_hist.h" +#include "../../../src/tree/common_row_partitioner.h" #include "../helpers.h" #include "test_partitioner.h" #include "xgboost/data.h" @@ -23,7 +23,7 @@ TEST(QuantileHist, Partitioner) { GenericParameter ctx; ctx.InitAllowUnknown(Args{}); - HistRowPartitioner partitioner{n_samples, base_rowid, ctx.Threads()}; + CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid}; ASSERT_EQ(partitioner.base_rowid, base_rowid); ASSERT_EQ(partitioner.Size(), 1); ASSERT_EQ(partitioner.Partitions()[0].Size(), n_samples); @@ -41,7 +41,7 @@ TEST(QuantileHist, Partitioner) { { auto min_value = gmat.cut.MinValues()[split_ind]; RegTree tree; - HistRowPartitioner partitioner{n_samples, base_rowid, ctx.Threads()}; + CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid}; GetSplit(&tree, min_value, &candidates); partitioner.UpdatePosition(&ctx, gmat, column_indices, candidates, &tree); ASSERT_EQ(partitioner.Size(), 3); @@ -49,7 +49,7 @@ TEST(QuantileHist, Partitioner) { 
ASSERT_EQ(partitioner[2].Size(), n_samples); } { - HistRowPartitioner partitioner{n_samples, base_rowid, ctx.Threads()}; + CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid}; auto ptr = gmat.cut.Ptrs()[split_ind + 1]; float split_value = gmat.cut.Values().at(ptr / 2); RegTree tree; diff --git a/tests/cpp/tree/test_tree_stat.cc b/tests/cpp/tree/test_tree_stat.cc index 723ca34eb..1a4ee5acb 100644 --- a/tests/cpp/tree/test_tree_stat.cc +++ b/tests/cpp/tree/test_tree_stat.cc @@ -1,6 +1,6 @@ -#include -#include #include +#include +#include #include "../helpers.h" @@ -21,9 +21,10 @@ class UpdaterTreeStatTest : public ::testing::Test { } void RunTest(std::string updater) { - auto tparam = CreateEmptyGenericParam(0); + Context ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0) + : CreateEmptyGenericParam(Context::kCpuId)); auto up = std::unique_ptr{ - TreeUpdater::Create(updater, &tparam, ObjInfo{ObjInfo::kRegression})}; + TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kRegression})}; up->Configure(Args{}); RegTree tree; tree.param.num_feature = kCols; @@ -41,22 +42,14 @@ class UpdaterTreeStatTest : public ::testing::Test { }; #if defined(XGBOOST_USE_CUDA) -TEST_F(UpdaterTreeStatTest, GpuHist) { - this->RunTest("grow_gpu_hist"); -} +TEST_F(UpdaterTreeStatTest, GpuHist) { this->RunTest("grow_gpu_hist"); } #endif // defined(XGBOOST_USE_CUDA) -TEST_F(UpdaterTreeStatTest, Hist) { - this->RunTest("grow_quantile_histmaker"); -} +TEST_F(UpdaterTreeStatTest, Hist) { this->RunTest("grow_quantile_histmaker"); } -TEST_F(UpdaterTreeStatTest, Exact) { - this->RunTest("grow_colmaker"); -} +TEST_F(UpdaterTreeStatTest, Exact) { this->RunTest("grow_colmaker"); } -TEST_F(UpdaterTreeStatTest, Approx) { - this->RunTest("grow_histmaker"); -} +TEST_F(UpdaterTreeStatTest, Approx) { this->RunTest("grow_histmaker"); } class UpdaterEtaTest : public ::testing::Test { protected: @@ -74,14 +67,15 @@ class UpdaterEtaTest : public ::testing::Test { } void RunTest(std::string 
updater) { - auto tparam = CreateEmptyGenericParam(0); + GenericParameter ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0) + : CreateEmptyGenericParam(Context::kCpuId)); float eta = 0.4; auto up_0 = std::unique_ptr{ - TreeUpdater::Create(updater, &tparam, ObjInfo{ObjInfo::kClassification})}; + TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kClassification})}; up_0->Configure(Args{{"eta", std::to_string(eta)}}); auto up_1 = std::unique_ptr{ - TreeUpdater::Create(updater, &tparam, ObjInfo{ObjInfo::kClassification})}; + TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kClassification})}; up_1->Configure(Args{{"eta", "1.0"}}); for (size_t iter = 0; iter < 4; ++iter) { @@ -130,7 +124,7 @@ class TestMinSplitLoss : public ::testing::Test { gpair_ = GenerateRandomGradients(kRows); } - int32_t Update(std::string updater, float gamma) { + std::int32_t Update(std::string updater, float gamma) { Args args{{"max_depth", "1"}, {"max_leaves", "0"}, @@ -146,9 +140,12 @@ class TestMinSplitLoss : public ::testing::Test { // test gamma {"gamma", std::to_string(gamma)}}; - GenericParameter generic_param(CreateEmptyGenericParam(0)); + std::cout << "updater:" << updater << std::endl; + GenericParameter ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0) + : CreateEmptyGenericParam(Context::kCpuId)); + std::cout << ctx.gpu_id << std::endl; auto up = std::unique_ptr{ - TreeUpdater::Create(updater, &generic_param, ObjInfo{ObjInfo::kRegression})}; + TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kRegression})}; up->Configure(args); RegTree tree;