Avoid the use of size_t in the partitioner. (#10541)

- Avoid the use of size_t in the partitioner.
- Use `Span` instead of `Elem` where `node_id` is not needed.
- Remove the `const_cast`.
- Make sure constness is not cast away through `Elem` by making it non-copyable, so that it is only accessed by reference.

The width of `size_t` is implementation-defined, which causes issues when we want to pass a pointer or a span.
This commit is contained in:
Jiaming Yuan 2024-07-11 00:43:08 +08:00 committed by GitHub
parent baba3e9eb0
commit 34b154c284
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 203 additions and 181 deletions

View File

@ -187,15 +187,14 @@ class GHistBuildingManager {
}; };
template <bool do_prefetch, class BuildingManager> template <bool do_prefetch, class BuildingManager>
void RowsWiseBuildHistKernel(Span<GradientPair const> gpair, void RowsWiseBuildHistKernel(Span<GradientPair const> gpair, Span<bst_idx_t const> row_indices,
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, const GHistIndexMatrix &gmat, GHistRow hist) {
GHistRow hist) {
constexpr bool kAnyMissing = BuildingManager::kAnyMissing; constexpr bool kAnyMissing = BuildingManager::kAnyMissing;
constexpr bool kFirstPage = BuildingManager::kFirstPage; constexpr bool kFirstPage = BuildingManager::kFirstPage;
using BinIdxType = typename BuildingManager::BinIdxType; using BinIdxType = typename BuildingManager::BinIdxType;
const size_t size = row_indices.Size(); const size_t size = row_indices.size();
const size_t *rid = row_indices.begin; bst_idx_t const *rid = row_indices.data();
auto const *p_gpair = reinterpret_cast<const float *>(gpair.data()); auto const *p_gpair = reinterpret_cast<const float *>(gpair.data());
const BinIdxType *gradient_index = gmat.index.data<BinIdxType>(); const BinIdxType *gradient_index = gmat.index.data<BinIdxType>();
@ -216,9 +215,9 @@ void RowsWiseBuildHistKernel(Span<GradientPair const> gpair,
return kFirstPage ? ridx : (ridx - base_rowid); return kFirstPage ? ridx : (ridx - base_rowid);
}; };
CHECK_NE(row_indices.Size(), 0); CHECK_NE(row_indices.size(), 0);
const size_t n_features = const size_t n_features =
get_row_ptr(row_indices.begin[0] + 1) - get_row_ptr(row_indices.begin[0]); get_row_ptr(row_indices.data()[0] + 1) - get_row_ptr(row_indices.data()[0]);
auto hist_data = reinterpret_cast<double *>(hist.data()); auto hist_data = reinterpret_cast<double *>(hist.data());
const uint32_t two{2}; // Each element from 'gpair' and 'hist' contains const uint32_t two{2}; // Each element from 'gpair' and 'hist' contains
// 2 FP values: gradient and hessian. // 2 FP values: gradient and hessian.
@ -264,14 +263,13 @@ void RowsWiseBuildHistKernel(Span<GradientPair const> gpair,
} }
template <class BuildingManager> template <class BuildingManager>
void ColsWiseBuildHistKernel(Span<GradientPair const> gpair, void ColsWiseBuildHistKernel(Span<GradientPair const> gpair, Span<bst_idx_t const> row_indices,
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, const GHistIndexMatrix &gmat, GHistRow hist) {
GHistRow hist) {
constexpr bool kAnyMissing = BuildingManager::kAnyMissing; constexpr bool kAnyMissing = BuildingManager::kAnyMissing;
constexpr bool kFirstPage = BuildingManager::kFirstPage; constexpr bool kFirstPage = BuildingManager::kFirstPage;
using BinIdxType = typename BuildingManager::BinIdxType; using BinIdxType = typename BuildingManager::BinIdxType;
const size_t size = row_indices.Size(); const size_t size = row_indices.size();
const size_t *rid = row_indices.begin; bst_idx_t const *rid = row_indices.data();
auto const *pgh = reinterpret_cast<const float *>(gpair.data()); auto const *pgh = reinterpret_cast<const float *>(gpair.data());
const BinIdxType *gradient_index = gmat.index.data<BinIdxType>(); const BinIdxType *gradient_index = gmat.index.data<BinIdxType>();
@ -315,31 +313,31 @@ void ColsWiseBuildHistKernel(Span<GradientPair const> gpair,
} }
template <class BuildingManager> template <class BuildingManager>
void BuildHistDispatch(Span<GradientPair const> gpair, const RowSetCollection::Elem row_indices, void BuildHistDispatch(Span<GradientPair const> gpair, Span<bst_idx_t const> row_indices,
const GHistIndexMatrix &gmat, GHistRow hist) { const GHistIndexMatrix &gmat, GHistRow hist) {
if (BuildingManager::kReadByColumn) { if (BuildingManager::kReadByColumn) {
ColsWiseBuildHistKernel<BuildingManager>(gpair, row_indices, gmat, hist); ColsWiseBuildHistKernel<BuildingManager>(gpair, row_indices, gmat, hist);
} else { } else {
const size_t nrows = row_indices.Size(); const size_t nrows = row_indices.size();
const size_t no_prefetch_size = Prefetch::NoPrefetchSize(nrows); const size_t no_prefetch_size = Prefetch::NoPrefetchSize(nrows);
// if need to work with all rows from bin-matrix (e.g. root node) // if need to work with all rows from bin-matrix (e.g. root node)
const bool contiguousBlock = const bool contiguousBlock =
(row_indices.begin[nrows - 1] - row_indices.begin[0]) == (nrows - 1); (row_indices.begin()[nrows - 1] - row_indices.begin()[0]) == (nrows - 1);
if (contiguousBlock) { if (contiguousBlock) {
// contiguous memory access, built-in HW prefetching is enough if (row_indices.empty()) {
if (row_indices.Size() == 0) {
return; return;
} }
// contiguous memory access, built-in HW prefetching is enough
RowsWiseBuildHistKernel<false, BuildingManager>(gpair, row_indices, gmat, hist); RowsWiseBuildHistKernel<false, BuildingManager>(gpair, row_indices, gmat, hist);
} else { } else {
const RowSetCollection::Elem span1(row_indices.begin, row_indices.end - no_prefetch_size); auto span1 = row_indices.subspan(0, row_indices.size() - no_prefetch_size);
if (span1.Size() != 0) { if (!span1.empty()) {
RowsWiseBuildHistKernel<true, BuildingManager>(gpair, span1, gmat, hist); RowsWiseBuildHistKernel<true, BuildingManager>(gpair, span1, gmat, hist);
} }
// no prefetching to avoid loading extra memory // no prefetching to avoid loading extra memory
const RowSetCollection::Elem span2(row_indices.end - no_prefetch_size, row_indices.end); auto span2 = row_indices.subspan(row_indices.size() - no_prefetch_size);
if (span2.Size() != 0) { if (!span2.empty()) {
RowsWiseBuildHistKernel<false, BuildingManager>(gpair, span2, gmat, hist); RowsWiseBuildHistKernel<false, BuildingManager>(gpair, span2, gmat, hist);
} }
} }
@ -347,7 +345,7 @@ void BuildHistDispatch(Span<GradientPair const> gpair, const RowSetCollection::E
} }
template <bool any_missing> template <bool any_missing>
void BuildHist(Span<GradientPair const> gpair, const RowSetCollection::Elem row_indices, void BuildHist(Span<GradientPair const> gpair, Span<bst_idx_t const> row_indices,
const GHistIndexMatrix &gmat, GHistRow hist, bool force_read_by_column) { const GHistIndexMatrix &gmat, GHistRow hist, bool force_read_by_column) {
/* force_read_by_column is used for testing the columnwise building of histograms. /* force_read_by_column is used for testing the columnwise building of histograms.
* default force_read_by_column = false * default force_read_by_column = false
@ -365,13 +363,11 @@ void BuildHist(Span<GradientPair const> gpair, const RowSetCollection::Elem row_
}); });
} }
template void BuildHist<true>(Span<GradientPair const> gpair, template void BuildHist<true>(Span<GradientPair const> gpair, Span<bst_idx_t const> row_indices,
const RowSetCollection::Elem row_indices,
const GHistIndexMatrix &gmat, GHistRow hist, const GHistIndexMatrix &gmat, GHistRow hist,
bool force_read_by_column); bool force_read_by_column);
template void BuildHist<false>(Span<GradientPair const> gpair, template void BuildHist<false>(Span<GradientPair const> gpair, Span<bst_idx_t const> row_indices,
const RowSetCollection::Elem row_indices,
const GHistIndexMatrix &gmat, GHistRow hist, const GHistIndexMatrix &gmat, GHistRow hist,
bool force_read_by_column); bool force_read_by_column);
} // namespace xgboost::common } // namespace xgboost::common

View File

@ -635,7 +635,7 @@ class ParallelGHistBuilder {
// construct a histogram via histogram aggregation // construct a histogram via histogram aggregation
template <bool any_missing> template <bool any_missing>
void BuildHist(Span<GradientPair const> gpair, const RowSetCollection::Elem row_indices, void BuildHist(Span<GradientPair const> gpair, Span<bst_idx_t const> row_indices,
const GHistIndexMatrix& gmat, GHistRow hist, bool force_read_by_column = false); const GHistIndexMatrix& gmat, GHistRow hist, bool force_read_by_column = false);
} // namespace common } // namespace common
} // namespace xgboost } // namespace xgboost

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2021-2023 by Contributors * Copyright 2021-2024, XGBoost Contributors
* \file row_set.h * \file row_set.h
* \brief Quick Utility to compute subset of rows * \brief Quick Utility to compute subset of rows
* \author Philip Cho, Tianqi Chen * \author Philip Cho, Tianqi Chen
@ -16,7 +16,6 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "../tree/hist/expand_entry.h"
#include "categorical.h" #include "categorical.h"
#include "column_matrix.h" #include "column_matrix.h"
#include "xgboost/context.h" #include "xgboost/context.h"
@ -54,23 +53,23 @@ class PartitionBuilder {
// Handle dense columns // Handle dense columns
// Analog of std::stable_partition, but in no-inplace manner // Analog of std::stable_partition, but in no-inplace manner
template <bool default_left, bool any_missing, typename ColumnType, typename Predicate> template <bool default_left, bool any_missing, typename ColumnType, typename Predicate>
inline std::pair<size_t, size_t> PartitionKernel(ColumnType* p_column, std::pair<size_t, size_t> PartitionKernel(ColumnType* p_column,
common::Span<const size_t> row_indices, common::Span<const bst_idx_t> row_indices,
common::Span<size_t> left_part, common::Span<bst_idx_t> left_part,
common::Span<size_t> right_part, common::Span<bst_idx_t> right_part,
size_t base_rowid, Predicate&& pred) { bst_idx_t base_rowid, Predicate&& pred) {
auto& column = *p_column; auto& column = *p_column;
size_t* p_left_part = left_part.data(); bst_idx_t* p_left_part = left_part.data();
size_t* p_right_part = right_part.data(); bst_idx_t* p_right_part = right_part.data();
size_t nleft_elems = 0; bst_idx_t nleft_elems = 0;
size_t nright_elems = 0; bst_idx_t nright_elems = 0;
auto p_row_indices = row_indices.data(); auto p_row_indices = row_indices.data();
auto n_samples = row_indices.size(); auto n_samples = row_indices.size();
for (size_t i = 0; i < n_samples; ++i) { for (size_t i = 0; i < n_samples; ++i) {
auto rid = p_row_indices[i]; auto rid = p_row_indices[i];
const int32_t bin_id = column[rid - base_rowid]; bst_bin_t const bin_id = column[rid - base_rowid];
if (any_missing && bin_id == ColumnType::kMissingId) { if (any_missing && bin_id == ColumnType::kMissingId) {
if (default_left) { if (default_left) {
p_left_part[nleft_elems++] = rid; p_left_part[nleft_elems++] = rid;
@ -90,14 +89,14 @@ class PartitionBuilder {
} }
template <typename Pred> template <typename Pred>
inline std::pair<size_t, size_t> PartitionRangeKernel(common::Span<const size_t> ridx, inline std::pair<size_t, size_t> PartitionRangeKernel(common::Span<const bst_idx_t> ridx,
common::Span<size_t> left_part, common::Span<bst_idx_t> left_part,
common::Span<size_t> right_part, common::Span<bst_idx_t> right_part,
Pred pred) { Pred pred) {
size_t* p_left_part = left_part.data(); bst_idx_t* p_left_part = left_part.data();
size_t* p_right_part = right_part.data(); bst_idx_t* p_right_part = right_part.data();
size_t nleft_elems = 0; bst_idx_t nleft_elems = 0;
size_t nright_elems = 0; bst_idx_t nright_elems = 0;
for (auto row_id : ridx) { for (auto row_id : ridx) {
if (pred(row_id)) { if (pred(row_id)) {
p_left_part[nleft_elems++] = row_id; p_left_part[nleft_elems++] = row_id;
@ -112,10 +111,10 @@ class PartitionBuilder {
void Partition(const size_t node_in_set, std::vector<ExpandEntry> const& nodes, void Partition(const size_t node_in_set, std::vector<ExpandEntry> const& nodes,
const common::Range1d range, const bst_bin_t split_cond, const common::Range1d range, const bst_bin_t split_cond,
GHistIndexMatrix const& gmat, const common::ColumnMatrix& column_matrix, GHistIndexMatrix const& gmat, const common::ColumnMatrix& column_matrix,
const RegTree& tree, const size_t* rid) { const RegTree& tree, bst_idx_t const* rid) {
common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end()); common::Span<bst_idx_t const> rid_span{rid + range.begin(), rid + range.end()};
common::Span<size_t> left = GetLeftBuffer(node_in_set, range.begin(), range.end()); common::Span<bst_idx_t> left = GetLeftBuffer(node_in_set, range.begin(), range.end());
common::Span<size_t> right = GetRightBuffer(node_in_set, range.begin(), range.end()); common::Span<bst_idx_t> right = GetRightBuffer(node_in_set, range.begin(), range.end());
std::size_t nid = nodes[node_in_set].nid; std::size_t nid = nodes[node_in_set].nid;
bst_feature_t fid = tree.SplitIndex(nid); bst_feature_t fid = tree.SplitIndex(nid);
bool default_left = tree.DefaultLeft(nid); bool default_left = tree.DefaultLeft(nid);
@ -184,8 +183,9 @@ class PartitionBuilder {
} }
template <bool any_missing, typename ColumnType, typename Predicate> template <bool any_missing, typename ColumnType, typename Predicate>
void MaskKernel(ColumnType* p_column, common::Span<const size_t> row_indices, size_t base_rowid, void MaskKernel(ColumnType* p_column, common::Span<bst_idx_t const> row_indices,
BitVector* decision_bits, BitVector* missing_bits, Predicate&& pred) { bst_idx_t base_rowid, BitVector* decision_bits, BitVector* missing_bits,
Predicate&& pred) {
auto& column = *p_column; auto& column = *p_column;
for (auto const row_id : row_indices) { for (auto const row_id : row_indices) {
auto const bin_id = column[row_id - base_rowid]; auto const bin_id = column[row_id - base_rowid];
@ -205,9 +205,9 @@ class PartitionBuilder {
template <typename BinIdxType, bool any_missing, bool any_cat, typename ExpandEntry> template <typename BinIdxType, bool any_missing, bool any_cat, typename ExpandEntry>
void MaskRows(const size_t node_in_set, std::vector<ExpandEntry> const& nodes, void MaskRows(const size_t node_in_set, std::vector<ExpandEntry> const& nodes,
const common::Range1d range, bst_bin_t split_cond, GHistIndexMatrix const& gmat, const common::Range1d range, bst_bin_t split_cond, GHistIndexMatrix const& gmat,
const common::ColumnMatrix& column_matrix, const RegTree& tree, const size_t* rid, const common::ColumnMatrix& column_matrix, const RegTree& tree,
BitVector* decision_bits, BitVector* missing_bits) { bst_idx_t const* rid, BitVector* decision_bits, BitVector* missing_bits) {
common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end()); common::Span<bst_idx_t const> rid_span{rid + range.begin(), rid + range.end()};
std::size_t nid = nodes[node_in_set].nid; std::size_t nid = nodes[node_in_set].nid;
bst_feature_t fid = tree.SplitIndex(nid); bst_feature_t fid = tree.SplitIndex(nid);
bool is_cat = tree.GetSplitTypes()[nid] == FeatureType::kCategorical; bool is_cat = tree.GetSplitTypes()[nid] == FeatureType::kCategorical;
@ -263,11 +263,11 @@ class PartitionBuilder {
template <typename ExpandEntry> template <typename ExpandEntry>
void PartitionByMask(const size_t node_in_set, std::vector<ExpandEntry> const& nodes, void PartitionByMask(const size_t node_in_set, std::vector<ExpandEntry> const& nodes,
const common::Range1d range, GHistIndexMatrix const& gmat, const common::Range1d range, GHistIndexMatrix const& gmat,
const RegTree& tree, const size_t* rid, BitVector const& decision_bits, const RegTree& tree, bst_idx_t const* rid, BitVector const& decision_bits,
BitVector const& missing_bits) { BitVector const& missing_bits) {
common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end()); common::Span<bst_idx_t const> rid_span(rid + range.begin(), rid + range.end());
common::Span<size_t> left = GetLeftBuffer(node_in_set, range.begin(), range.end()); common::Span<bst_idx_t> left = GetLeftBuffer(node_in_set, range.begin(), range.end());
common::Span<size_t> right = GetRightBuffer(node_in_set, range.begin(), range.end()); common::Span<bst_idx_t> right = GetRightBuffer(node_in_set, range.begin(), range.end());
std::size_t nid = nodes[node_in_set].nid; std::size_t nid = nodes[node_in_set].nid;
bool default_left = tree.DefaultLeft(nid); bool default_left = tree.DefaultLeft(nid);
@ -299,12 +299,12 @@ class PartitionBuilder {
} }
} }
common::Span<size_t> GetLeftBuffer(int nid, size_t begin, size_t end) { common::Span<bst_idx_t> GetLeftBuffer(int nid, size_t begin, size_t end) {
const size_t task_idx = GetTaskIdx(nid, begin); const size_t task_idx = GetTaskIdx(nid, begin);
return { mem_blocks_.at(task_idx)->Left(), end - begin }; return { mem_blocks_.at(task_idx)->Left(), end - begin };
} }
common::Span<size_t> GetRightBuffer(int nid, size_t begin, size_t end) { common::Span<bst_idx_t> GetRightBuffer(int nid, size_t begin, size_t end) {
const size_t task_idx = GetTaskIdx(nid, begin); const size_t task_idx = GetTaskIdx(nid, begin);
return { mem_blocks_.at(task_idx)->Right(), end - begin }; return { mem_blocks_.at(task_idx)->Right(), end - begin };
} }
@ -346,14 +346,14 @@ class PartitionBuilder {
} }
} }
void MergeToArray(int nid, size_t begin, size_t* rows_indexes) { void MergeToArray(bst_node_t nid, size_t begin, bst_idx_t* rows_indexes) {
size_t task_idx = GetTaskIdx(nid, begin); size_t task_idx = GetTaskIdx(nid, begin);
size_t* left_result = rows_indexes + mem_blocks_[task_idx]->n_offset_left; bst_idx_t* left_result = rows_indexes + mem_blocks_[task_idx]->n_offset_left;
size_t* right_result = rows_indexes + mem_blocks_[task_idx]->n_offset_right; bst_idx_t* right_result = rows_indexes + mem_blocks_[task_idx]->n_offset_right;
const size_t* left = mem_blocks_[task_idx]->Left(); bst_idx_t const* left = mem_blocks_[task_idx]->Left();
const size_t* right = mem_blocks_[task_idx]->Right(); bst_idx_t const* right = mem_blocks_[task_idx]->Right();
std::copy_n(left, mem_blocks_[task_idx]->n_left, left_result); std::copy_n(left, mem_blocks_[task_idx]->n_left, left_result);
std::copy_n(right, mem_blocks_[task_idx]->n_right, right_result); std::copy_n(right, mem_blocks_[task_idx]->n_right, right_result);
@ -377,10 +377,10 @@ class PartitionBuilder {
return; return;
} }
CHECK(tree.IsLeaf(node.node_id)); CHECK(tree.IsLeaf(node.node_id));
if (node.begin) { // guard for empty node. if (node.begin()) { // guard for empty node.
size_t ptr_offset = node.end - p_begin; size_t ptr_offset = node.end() - p_begin;
CHECK_LE(ptr_offset, row_set.Data()->size()) << node.node_id; CHECK_LE(ptr_offset, row_set.Data()->size()) << node.node_id;
for (auto idx = node.begin; idx != node.end; ++idx) { for (auto idx = node.begin(); idx != node.end(); ++idx) {
h_pos[*idx] = sampledp(*idx) ? ~node.node_id : node.node_id; h_pos[*idx] = sampledp(*idx) ? ~node.node_id : node.node_id;
} }
} }
@ -395,16 +395,16 @@ class PartitionBuilder {
size_t n_offset_left; size_t n_offset_left;
size_t n_offset_right; size_t n_offset_right;
size_t* Left() { bst_idx_t* Left() {
return &left_data_[0]; return &left_data_[0];
} }
size_t* Right() { bst_idx_t* Right() {
return &right_data_[0]; return &right_data_[0];
} }
private: private:
size_t left_data_[BlockSize]; bst_idx_t left_data_[BlockSize];
size_t right_data_[BlockSize]; bst_idx_t right_data_[BlockSize];
}; };
std::vector<std::pair<size_t, size_t>> left_right_nodes_sizes_; std::vector<std::pair<size_t, size_t>> left_right_nodes_sizes_;
std::vector<size_t> blocks_offsets_; std::vector<size_t> blocks_offsets_;

View File

@ -31,15 +31,29 @@ class RowSetCollection {
* associated with a particular node in a decision tree. * associated with a particular node in a decision tree.
*/ */
struct Elem { struct Elem {
std::size_t const* begin{nullptr}; private:
std::size_t const* end{nullptr}; bst_idx_t* begin_{nullptr};
bst_idx_t* end_{nullptr};
public:
bst_node_t node_id{-1}; bst_node_t node_id{-1};
// id of node associated with this instance set; -1 means uninitialized // id of node associated with this instance set; -1 means uninitialized
Elem() = default; Elem() = default;
Elem(std::size_t const* begin, std::size_t const* end, bst_node_t node_id = -1) Elem(bst_idx_t* begin, bst_idx_t* end, bst_node_t node_id = -1)
: begin(begin), end(end), node_id(node_id) {} : begin_(begin), end_(end), node_id(node_id) {}
std::size_t Size() const { return end - begin; } // Disable copy ctor to avoid casting away the constness via copy.
Elem(Elem const& that) = delete;
Elem& operator=(Elem const& that) = delete;
Elem(Elem&& that) = default;
Elem& operator=(Elem&& that) = default;
[[nodiscard]] std::size_t Size() const { return std::distance(begin(), end()); }
[[nodiscard]] bst_idx_t const* begin() const { return this->begin_; } // NOLINT
[[nodiscard]] bst_idx_t const* end() const { return this->end_; } // NOLINT
[[nodiscard]] bst_idx_t* begin() { return this->begin_; } // NOLINT
[[nodiscard]] bst_idx_t* end() { return this->end_; } // NOLINT
}; };
[[nodiscard]] std::vector<Elem>::const_iterator begin() const { // NOLINT [[nodiscard]] std::vector<Elem>::const_iterator begin() const { // NOLINT
@ -71,55 +85,57 @@ class RowSetCollection {
CHECK(elem_of_each_node_.empty()); CHECK(elem_of_each_node_.empty());
if (row_indices_.empty()) { // edge case: empty instance set if (row_indices_.empty()) { // edge case: empty instance set
constexpr std::size_t* kBegin = nullptr; constexpr bst_idx_t* kBegin = nullptr;
constexpr std::size_t* kEnd = nullptr; constexpr bst_idx_t* kEnd = nullptr;
static_assert(kEnd - kBegin == 0); static_assert(kEnd - kBegin == 0);
elem_of_each_node_.emplace_back(kBegin, kEnd, 0); elem_of_each_node_.emplace_back(kBegin, kEnd, 0);
return; return;
} }
const std::size_t* begin = dmlc::BeginPtr(row_indices_); bst_idx_t* begin = row_indices_.data();
const std::size_t* end = dmlc::BeginPtr(row_indices_) + row_indices_.size(); bst_idx_t* end = row_indices_.data() + row_indices_.size();
elem_of_each_node_.emplace_back(begin, end, 0); elem_of_each_node_.emplace_back(begin, end, 0);
} }
[[nodiscard]] std::vector<std::size_t>* Data() { return &row_indices_; } [[nodiscard]] std::vector<bst_idx_t>* Data() { return &row_indices_; }
[[nodiscard]] std::vector<std::size_t> const* Data() const { return &row_indices_; } [[nodiscard]] std::vector<bst_idx_t> const* Data() const { return &row_indices_; }
// split rowset into two // split rowset into two
void AddSplit(bst_node_t node_id, bst_node_t left_node_id, bst_node_t right_node_id, void AddSplit(bst_node_t node_id, bst_node_t left_node_id, bst_node_t right_node_id,
bst_idx_t n_left, bst_idx_t n_right) { bst_idx_t n_left, bst_idx_t n_right) {
const Elem e = elem_of_each_node_[node_id]; Elem& e = elem_of_each_node_[node_id];
std::size_t* all_begin{nullptr}; bst_idx_t* all_begin{nullptr};
std::size_t* begin{nullptr}; bst_idx_t* begin{nullptr};
if (e.begin == nullptr) { bst_idx_t* end{nullptr};
if (e.begin() == nullptr) {
CHECK_EQ(n_left, 0); CHECK_EQ(n_left, 0);
CHECK_EQ(n_right, 0); CHECK_EQ(n_right, 0);
} else { } else {
all_begin = row_indices_.data(); all_begin = row_indices_.data();
begin = all_begin + (e.begin - all_begin); begin = all_begin + (e.begin() - all_begin);
end = elem_of_each_node_[node_id].end();
} }
CHECK_EQ(n_left + n_right, e.Size()); CHECK_EQ(n_left + n_right, e.Size());
CHECK_LE(begin + n_left, e.end); CHECK_LE(begin + n_left, e.end());
CHECK_EQ(begin + n_left + n_right, e.end); CHECK_EQ(begin + n_left + n_right, e.end());
if (left_node_id >= static_cast<bst_node_t>(elem_of_each_node_.size())) { if (left_node_id >= static_cast<bst_node_t>(elem_of_each_node_.size())) {
elem_of_each_node_.resize(left_node_id + 1, Elem{nullptr, nullptr, -1}); elem_of_each_node_.resize(left_node_id + 1);
} }
if (right_node_id >= static_cast<bst_node_t>(elem_of_each_node_.size())) { if (right_node_id >= static_cast<bst_node_t>(elem_of_each_node_.size())) {
elem_of_each_node_.resize(right_node_id + 1, Elem{nullptr, nullptr, -1}); elem_of_each_node_.resize(right_node_id + 1);
} }
elem_of_each_node_[left_node_id] = Elem{begin, begin + n_left, left_node_id}; elem_of_each_node_[left_node_id] = Elem{begin, begin + n_left, left_node_id};
elem_of_each_node_[right_node_id] = Elem{begin + n_left, e.end, right_node_id}; elem_of_each_node_[right_node_id] = Elem{begin + n_left, end, right_node_id};
elem_of_each_node_[node_id] = Elem{nullptr, nullptr, -1}; elem_of_each_node_[node_id] = Elem{nullptr, nullptr, -1};
} }
private: private:
// stores the row indexes in the set // stores the row indexes in the set
std::vector<std::size_t> row_indices_; std::vector<bst_idx_t> row_indices_;
// vector: node_id -> elements // vector: node_id -> elements
std::vector<Elem> elem_of_each_node_; std::vector<Elem> elem_of_each_node_;
}; };

View File

@ -7,7 +7,7 @@
#define XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_ #define XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_
#include <algorithm> // for all_of, fill #include <algorithm> // for all_of, fill
#include <cinttypes> // for uint32_t #include <cstdint> // for uint32_t, int32_t
#include <limits> // for numeric_limits #include <limits> // for numeric_limits
#include <vector> // for vector #include <vector> // for vector
@ -18,7 +18,7 @@
#include "../common/partition_builder.h" // for PartitionBuilder #include "../common/partition_builder.h" // for PartitionBuilder
#include "../common/row_set.h" // for RowSetCollection #include "../common/row_set.h" // for RowSetCollection
#include "../common/threading_utils.h" // for ParallelFor2d #include "../common/threading_utils.h" // for ParallelFor2d
#include "xgboost/base.h" // for bst_row_t #include "xgboost/base.h" // for bst_idx_t
#include "xgboost/collective/result.h" // for Success, SafeColl #include "xgboost/collective/result.h" // for Success, SafeColl
#include "xgboost/context.h" // for Context #include "xgboost/context.h" // for Context
#include "xgboost/linalg.h" // for TensorView #include "xgboost/linalg.h" // for TensorView
@ -46,7 +46,7 @@ class ColumnSplitHelper {
void Partition(Context const* ctx, common::BlockedSpace2d const& space, std::int32_t n_threads, void Partition(Context const* ctx, common::BlockedSpace2d const& space, std::int32_t n_threads,
GHistIndexMatrix const& gmat, common::ColumnMatrix const& column_matrix, GHistIndexMatrix const& gmat, common::ColumnMatrix const& column_matrix,
std::vector<ExpandEntry> const& nodes, std::vector<ExpandEntry> const& nodes,
std::vector<int32_t> const& split_conditions, RegTree const* p_tree) { std::vector<std::int32_t> const& split_conditions, RegTree const* p_tree) {
// When data is split by column, we don't have all the feature values in the local worker, so // When data is split by column, we don't have all the feature values in the local worker, so
// we first collect all the decisions and whether the feature is missing into bit vectors. // we first collect all the decisions and whether the feature is missing into bit vectors.
std::fill(decision_storage_.begin(), decision_storage_.end(), 0); std::fill(decision_storage_.begin(), decision_storage_.end(), 0);
@ -56,7 +56,7 @@ class ColumnSplitHelper {
bst_bin_t split_cond = column_matrix.IsInitialized() ? split_conditions[node_in_set] : 0; bst_bin_t split_cond = column_matrix.IsInitialized() ? split_conditions[node_in_set] : 0;
partition_builder_->MaskRows<BinIdxType, any_missing, any_cat>( partition_builder_->MaskRows<BinIdxType, any_missing, any_cat>(
node_in_set, nodes, r, split_cond, gmat, column_matrix, *p_tree, node_in_set, nodes, r, split_cond, gmat, column_matrix, *p_tree,
(*row_set_collection_)[nid].begin, &decision_bits_, &missing_bits_); (*row_set_collection_)[nid].begin(), &decision_bits_, &missing_bits_);
}); });
// Then aggregate the bit vectors across all the workers. // Then aggregate the bit vectors across all the workers.
@ -74,7 +74,7 @@ class ColumnSplitHelper {
const size_t task_id = partition_builder_->GetTaskIdx(node_in_set, begin); const size_t task_id = partition_builder_->GetTaskIdx(node_in_set, begin);
partition_builder_->AllocateForTask(task_id); partition_builder_->AllocateForTask(task_id);
partition_builder_->PartitionByMask(node_in_set, nodes, r, gmat, *p_tree, partition_builder_->PartitionByMask(node_in_set, nodes, r, gmat, *p_tree,
(*row_set_collection_)[nid].begin, decision_bits_, (*row_set_collection_)[nid].begin(), decision_bits_,
missing_bits_); missing_bits_);
}); });
} }
@ -98,10 +98,10 @@ class CommonRowPartitioner {
bool is_col_split) bool is_col_split)
: base_rowid{_base_rowid}, is_col_split_{is_col_split} { : base_rowid{_base_rowid}, is_col_split_{is_col_split} {
row_set_collection_.Clear(); row_set_collection_.Clear();
std::vector<size_t>& row_indices = *row_set_collection_.Data(); std::vector<bst_idx_t>& row_indices = *row_set_collection_.Data();
row_indices.resize(num_row); row_indices.resize(num_row);
std::size_t* p_row_indices = row_indices.data(); bst_idx_t* p_row_indices = row_indices.data();
common::Iota(ctx, p_row_indices, p_row_indices + row_indices.size(), base_rowid); common::Iota(ctx, p_row_indices, p_row_indices + row_indices.size(), base_rowid);
row_set_collection_.Init(); row_set_collection_.Init();
@ -112,7 +112,7 @@ class CommonRowPartitioner {
template <typename ExpandEntry> template <typename ExpandEntry>
void FindSplitConditions(const std::vector<ExpandEntry>& nodes, const RegTree& tree, void FindSplitConditions(const std::vector<ExpandEntry>& nodes, const RegTree& tree,
const GHistIndexMatrix& gmat, std::vector<int32_t>* split_conditions) { const GHistIndexMatrix& gmat, std::vector<bst_bin_t>* split_conditions) {
auto const& ptrs = gmat.cut.Ptrs(); auto const& ptrs = gmat.cut.Ptrs();
auto const& vals = gmat.cut.Values(); auto const& vals = gmat.cut.Values();
@ -197,7 +197,7 @@ class CommonRowPartitioner {
// 1. Find split condition for each split // 1. Find split condition for each split
size_t n_nodes = nodes.size(); size_t n_nodes = nodes.size();
std::vector<int32_t> split_conditions; std::vector<bst_bin_t> split_conditions;
if (column_matrix.IsInitialized()) { if (column_matrix.IsInitialized()) {
split_conditions.resize(n_nodes); split_conditions.resize(n_nodes);
FindSplitConditions(nodes, *p_tree, gmat, &split_conditions); FindSplitConditions(nodes, *p_tree, gmat, &split_conditions);
@ -206,8 +206,8 @@ class CommonRowPartitioner {
// 2.1 Create a blocked space of size SUM(samples in each node) // 2.1 Create a blocked space of size SUM(samples in each node)
common::BlockedSpace2d space( common::BlockedSpace2d space(
n_nodes, n_nodes,
[&](size_t node_in_set) { [&](std::size_t node_in_set) {
int32_t nid = nodes[node_in_set].nid; auto nid = nodes[node_in_set].nid;
return row_set_collection_[nid].Size(); return row_set_collection_[nid].Size();
}, },
kPartitionBlockSize); kPartitionBlockSize);
@ -236,7 +236,7 @@ class CommonRowPartitioner {
bst_bin_t split_cond = column_matrix.IsInitialized() ? split_conditions[node_in_set] : 0; bst_bin_t split_cond = column_matrix.IsInitialized() ? split_conditions[node_in_set] : 0;
partition_builder_.template Partition<BinIdxType, any_missing, any_cat>( partition_builder_.template Partition<BinIdxType, any_missing, any_cat>(
node_in_set, nodes, r, split_cond, gmat, column_matrix, *p_tree, node_in_set, nodes, r, split_cond, gmat, column_matrix, *p_tree,
row_set_collection_[nid].begin); row_set_collection_[nid].begin());
}); });
} }
@ -248,8 +248,7 @@ class CommonRowPartitioner {
// with updated row-indexes for each tree-node // with updated row-indexes for each tree-node
common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) { common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) {
const int32_t nid = nodes[node_in_set].nid; const int32_t nid = nodes[node_in_set].nid;
partition_builder_.MergeToArray(node_in_set, r.begin(), partition_builder_.MergeToArray(node_in_set, r.begin(), row_set_collection_[nid].begin());
const_cast<size_t*>(row_set_collection_[nid].begin));
}); });
// 5. Add info about splits into row_set_collection_ // 5. Add info about splits into row_set_collection_

View File

@ -739,7 +739,7 @@ void UpdatePredictionCacheImpl(Context const *ctx, RegTree const *p_last_tree,
if (!tree[nidx].IsDeleted() && tree[nidx].IsLeaf()) { if (!tree[nidx].IsDeleted() && tree[nidx].IsLeaf()) {
auto const &rowset = part[nidx]; auto const &rowset = part[nidx];
auto leaf_value = tree[nidx].LeafValue(); auto leaf_value = tree[nidx].LeafValue();
for (const size_t *it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) { for (auto const *it = rowset.begin() + r.begin(); it < rowset.begin() + r.end(); ++it) {
out_preds(*it) += leaf_value; out_preds(*it) += leaf_value;
} }
} }
@ -774,7 +774,8 @@ void UpdatePredictionCacheImpl(Context const *ctx, RegTree const *p_last_tree,
if (tree.IsLeaf(nidx)) { if (tree.IsLeaf(nidx)) {
auto const &rowset = part[nidx]; auto const &rowset = part[nidx];
auto leaf_value = mttree->LeafValue(nidx); auto leaf_value = mttree->LeafValue(nidx);
for (std::size_t const *it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) { for (bst_idx_t const *it = rowset.begin() + r.begin(); it < rowset.begin() + r.end();
++it) {
for (std::size_t i = 0; i < n_targets; ++i) { for (std::size_t i = 0; i < n_targets; ++i) {
out_preds(*it, i) += leaf_value(i); out_preds(*it, i) += leaf_value(i);
} }

View File

@ -76,13 +76,13 @@ class HistogramBuilder {
common::ParallelFor2d(space, this->n_threads_, [&](size_t nid_in_set, common::Range1d r) { common::ParallelFor2d(space, this->n_threads_, [&](size_t nid_in_set, common::Range1d r) {
const auto tid = static_cast<unsigned>(omp_get_thread_num()); const auto tid = static_cast<unsigned>(omp_get_thread_num());
bst_node_t const nidx = nodes_to_build[nid_in_set]; bst_node_t const nidx = nodes_to_build[nid_in_set];
auto elem = row_set_collection[nidx]; auto const& elem = row_set_collection[nidx];
auto start_of_row_set = std::min(r.begin(), elem.Size()); auto start_of_row_set = std::min(r.begin(), elem.Size());
auto end_of_row_set = std::min(r.end(), elem.Size()); auto end_of_row_set = std::min(r.end(), elem.Size());
auto rid_set = common::RowSetCollection::Elem(elem.begin + start_of_row_set, auto rid_set = common::Span<bst_idx_t const>{elem.begin() + start_of_row_set,
elem.begin + end_of_row_set, nidx); elem.begin() + end_of_row_set};
auto hist = buffer_.GetInitializedHist(tid, nid_in_set); auto hist = buffer_.GetInitializedHist(tid, nid_in_set);
if (rid_set.Size() != 0) { if (rid_set.size() != 0) {
common::BuildHist<any_missing>(gpair_h, rid_set, gidx, hist, force_read_by_column); common::BuildHist<any_missing>(gpair_h, rid_set, gidx, hist, force_read_by_column);
} }
}); });

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020-2023 by XGBoost contributors * Copyright 2020-2024, XGBoost contributors
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
@ -58,7 +58,7 @@ TEST(PartitionBuilder, BasicTest) {
} }
builder.CalculateRowOffsets(); builder.CalculateRowOffsets();
std::vector<size_t> v(*std::max_element(tasks.begin(), tasks.end()) * kBlockSize); std::vector<bst_idx_t> v(*std::max_element(tasks.begin(), tasks.end()) * kBlockSize);
for(size_t nid = 0; nid < kNodes; ++nid) { for(size_t nid = 0; nid < kNodes; ++nid) {

View File

@ -45,7 +45,7 @@ void TestEvaluateSplits(bool force_read_by_column) {
// dense, no missing values // dense, no missing values
GHistIndexMatrix gmat(&ctx, dmat.get(), kMaxBins, 0.5, false); GHistIndexMatrix gmat(&ctx, dmat.get(), kMaxBins, 0.5, false);
common::RowSetCollection row_set_collection; common::RowSetCollection row_set_collection;
std::vector<size_t> &row_indices = *row_set_collection.Data(); std::vector<bst_idx_t> &row_indices = *row_set_collection.Data();
row_indices.resize(kRows); row_indices.resize(kRows);
std::iota(row_indices.begin(), row_indices.end(), 0); std::iota(row_indices.begin(), row_indices.end(), 0);
row_set_collection.Init(); row_set_collection.Init();
@ -53,7 +53,9 @@ void TestEvaluateSplits(bool force_read_by_column) {
HistMakerTrainParam hist_param; HistMakerTrainParam hist_param;
hist.Reset(gmat.cut.Ptrs().back(), hist_param.max_cached_hist_node); hist.Reset(gmat.cut.Ptrs().back(), hist_param.max_cached_hist_node);
hist.AllocateHistograms({0}); hist.AllocateHistograms({0});
common::BuildHist<false>(row_gpairs, row_set_collection[0], gmat, hist[0], force_read_by_column); auto const &elem = row_set_collection[0];
common::BuildHist<false>(row_gpairs, common::Span{elem.begin(), elem.end()}, gmat, hist[0],
force_read_by_column);
// Compute total gradient for all data points // Compute total gradient for all data points
GradientPairPrecise total_gpair; GradientPairPrecise total_gpair;

View File

@ -14,7 +14,6 @@
#include <algorithm> // for max #include <algorithm> // for max
#include <cstddef> // for size_t #include <cstddef> // for size_t
#include <cstdint> // for int32_t, uint32_t #include <cstdint> // for int32_t, uint32_t
#include <functional> // for function
#include <iterator> // for back_inserter #include <iterator> // for back_inserter
#include <limits> // for numeric_limits #include <limits> // for numeric_limits
#include <memory> // for shared_ptr, allocator, unique_ptr #include <memory> // for shared_ptr, allocator, unique_ptr
@ -108,7 +107,7 @@ void TestSyncHist(bool is_distributed) {
common::RowSetCollection row_set_collection; common::RowSetCollection row_set_collection;
{ {
row_set_collection.Clear(); row_set_collection.Clear();
std::vector<size_t> &row_indices = *row_set_collection.Data(); std::vector<bst_idx_t> &row_indices = *row_set_collection.Data();
row_indices.resize(kNRows); row_indices.resize(kNRows);
std::iota(row_indices.begin(), row_indices.end(), 0); std::iota(row_indices.begin(), row_indices.end(), 0);
row_set_collection.Init(); row_set_collection.Init();
@ -251,7 +250,7 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_
common::RowSetCollection row_set_collection; common::RowSetCollection row_set_collection;
row_set_collection.Clear(); row_set_collection.Clear();
std::vector<size_t> &row_indices = *row_set_collection.Data(); std::vector<bst_idx_t> &row_indices = *row_set_collection.Data();
row_indices.resize(kNRows); row_indices.resize(kNRows);
std::iota(row_indices.begin(), row_indices.end(), 0); std::iota(row_indices.begin(), row_indices.end(), 0);
row_set_collection.Init(); row_set_collection.Init();
@ -345,7 +344,7 @@ void TestHistogramCategorical(size_t n_categories, bool force_read_by_column) {
common::RowSetCollection row_set_collection; common::RowSetCollection row_set_collection;
row_set_collection.Clear(); row_set_collection.Clear();
std::vector<size_t> &row_indices = *row_set_collection.Data(); std::vector<bst_idx_t> &row_indices = *row_set_collection.Data();
row_indices.resize(kRows); row_indices.resize(kRows);
std::iota(row_indices.begin(), row_indices.end(), 0); std::iota(row_indices.begin(), row_indices.end(), 0);
row_set_collection.Init(); row_set_collection.Init();

View File

@ -3,7 +3,6 @@
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "../../../src/common/numeric.h"
#include "../../../src/tree/common_row_partitioner.h" #include "../../../src/tree/common_row_partitioner.h"
#include "../collective/test_worker.h" // for TestDistributedGlobal #include "../collective/test_worker.h" // for TestDistributedGlobal
#include "../helpers.h" #include "../helpers.h"
@ -54,20 +53,23 @@ TEST(Approx, Partitioner) {
GetSplit(&tree, split_value, &candidates); GetSplit(&tree, split_value, &candidates);
partitioner.UpdatePosition(&ctx, page, candidates, &tree); partitioner.UpdatePosition(&ctx, page, candidates, &tree);
{
auto left_nidx = tree[RegTree::kRoot].LeftChild(); auto left_nidx = tree[RegTree::kRoot].LeftChild();
auto elem = partitioner[left_nidx]; auto const& elem = partitioner[left_nidx];
ASSERT_LT(elem.Size(), n_samples); ASSERT_LT(elem.Size(), n_samples);
ASSERT_GT(elem.Size(), 1); ASSERT_GT(elem.Size(), 1);
for (auto it = elem.begin; it != elem.end; ++it) { for (auto& it : elem) {
auto value = page.cut.Values().at(page.index[*it]); auto value = page.cut.Values().at(page.index[it]);
ASSERT_LE(value, split_value); ASSERT_LE(value, split_value);
} }
}
{
auto right_nidx = tree[RegTree::kRoot].RightChild(); auto right_nidx = tree[RegTree::kRoot].RightChild();
elem = partitioner[right_nidx]; auto const& elem = partitioner[right_nidx];
for (auto it = elem.begin; it != elem.end; ++it) { for (auto& it : elem) {
auto value = page.cut.Values().at(page.index[*it]); auto value = page.cut.Values().at(page.index[it]);
ASSERT_GT(value, split_value) << *it; ASSERT_GT(value, split_value) << it;
}
} }
} }
} }
@ -99,26 +101,28 @@ void TestColumnSplitPartitioner(size_t n_samples, size_t base_rowid, std::shared
RegTree tree; RegTree tree;
GetSplit(&tree, mid_value, &candidates); GetSplit(&tree, mid_value, &candidates);
partitioner.UpdatePosition(&ctx, page, candidates, &tree); partitioner.UpdatePosition(&ctx, page, candidates, &tree);
{
auto left_nidx = tree[RegTree::kRoot].LeftChild(); auto left_nidx = tree[RegTree::kRoot].LeftChild();
auto elem = partitioner[left_nidx]; auto const& elem = partitioner[left_nidx];
ASSERT_LT(elem.Size(), n_samples); ASSERT_LT(elem.Size(), n_samples);
ASSERT_GT(elem.Size(), 1); ASSERT_GT(elem.Size(), 1);
auto expected_elem = expected_mid_partitioner[left_nidx]; auto const& expected_elem = expected_mid_partitioner[left_nidx];
ASSERT_EQ(elem.Size(), expected_elem.Size()); ASSERT_EQ(elem.Size(), expected_elem.Size());
for (auto it = elem.begin, eit = expected_elem.begin; it != elem.end; ++it, ++eit) { for (auto it = elem.begin(), eit = expected_elem.begin(); it != elem.end(); ++it, ++eit) {
ASSERT_EQ(*it, *eit); ASSERT_EQ(*it, *eit);
} }
}
{
auto right_nidx = tree[RegTree::kRoot].RightChild(); auto right_nidx = tree[RegTree::kRoot].RightChild();
elem = partitioner[right_nidx]; auto const& elem = partitioner[right_nidx];
expected_elem = expected_mid_partitioner[right_nidx]; auto const& expected_elem = expected_mid_partitioner[right_nidx];
ASSERT_EQ(elem.Size(), expected_elem.Size()); ASSERT_EQ(elem.Size(), expected_elem.Size());
for (auto it = elem.begin, eit = expected_elem.begin; it != elem.end; ++it, ++eit) { for (auto it = elem.begin(), eit = expected_elem.begin(); it != elem.end(); ++it, ++eit) {
ASSERT_EQ(*it, *eit); ASSERT_EQ(*it, *eit);
} }
} }
} }
}
} }
} // anonymous namespace } // anonymous namespace

View File

@ -5,7 +5,6 @@
#include <xgboost/host_device_vector.h> #include <xgboost/host_device_vector.h>
#include <xgboost/tree_updater.h> #include <xgboost/tree_updater.h>
#include <algorithm>
#include <cstddef> // for size_t #include <cstddef> // for size_t
#include <string> #include <string>
#include <vector> #include <vector>
@ -68,24 +67,27 @@ void TestPartitioner(bst_target_t n_targets) {
} else { } else {
GetMultiSplitForTest(&tree, split_value, &candidates); GetMultiSplitForTest(&tree, split_value, &candidates);
} }
auto left_nidx = tree.LeftChild(RegTree::kRoot);
partitioner.UpdatePosition<false, true>(&ctx, gmat, column_indices, candidates, &tree); partitioner.UpdatePosition<false, true>(&ctx, gmat, column_indices, candidates, &tree);
{
auto elem = partitioner[left_nidx]; auto left_nidx = tree.LeftChild(RegTree::kRoot);
auto const& elem = partitioner[left_nidx];
ASSERT_LT(elem.Size(), n_samples); ASSERT_LT(elem.Size(), n_samples);
ASSERT_GT(elem.Size(), 1); ASSERT_GT(elem.Size(), 1);
for (auto it = elem.begin; it != elem.end; ++it) { for (auto& it : elem) {
auto value = gmat.cut.Values().at(gmat.index[*it]); auto value = gmat.cut.Values().at(gmat.index[it]);
ASSERT_LE(value, split_value); ASSERT_LE(value, split_value);
} }
}
{
auto right_nidx = tree.RightChild(RegTree::kRoot); auto right_nidx = tree.RightChild(RegTree::kRoot);
elem = partitioner[right_nidx]; auto const& elem = partitioner[right_nidx];
for (auto it = elem.begin; it != elem.end; ++it) { for (auto& it : elem) {
auto value = gmat.cut.Values().at(gmat.index[*it]); auto value = gmat.cut.Values().at(gmat.index[it]);
ASSERT_GT(value, split_value); ASSERT_GT(value, split_value);
} }
} }
} }
}
} }
} // anonymous namespace } // anonymous namespace
@ -138,24 +140,27 @@ void VerifyColumnSplitPartitioner(bst_target_t n_targets, size_t n_samples,
auto left_nidx = tree.LeftChild(RegTree::kRoot); auto left_nidx = tree.LeftChild(RegTree::kRoot);
partitioner.UpdatePosition<false, true>(&ctx, gmat, column_indices, candidates, &tree); partitioner.UpdatePosition<false, true>(&ctx, gmat, column_indices, candidates, &tree);
auto elem = partitioner[left_nidx]; {
auto const& elem = partitioner[left_nidx];
ASSERT_LT(elem.Size(), n_samples); ASSERT_LT(elem.Size(), n_samples);
ASSERT_GT(elem.Size(), 1); ASSERT_GT(elem.Size(), 1);
auto expected_elem = expected_mid_partitioner[left_nidx]; auto const& expected_elem = expected_mid_partitioner[left_nidx];
ASSERT_EQ(elem.Size(), expected_elem.Size()); ASSERT_EQ(elem.Size(), expected_elem.Size());
for (auto it = elem.begin, eit = expected_elem.begin; it != elem.end; ++it, ++eit) { for (auto it = elem.begin(), eit = expected_elem.begin(); it != elem.end(); ++it, ++eit) {
ASSERT_EQ(*it, *eit); ASSERT_EQ(*it, *eit);
} }
}
{
auto right_nidx = tree.RightChild(RegTree::kRoot); auto right_nidx = tree.RightChild(RegTree::kRoot);
elem = partitioner[right_nidx]; auto const& elem = partitioner[right_nidx];
expected_elem = expected_mid_partitioner[right_nidx]; auto const& expected_elem = expected_mid_partitioner[right_nidx];
ASSERT_EQ(elem.Size(), expected_elem.Size()); ASSERT_EQ(elem.Size(), expected_elem.Size());
for (auto it = elem.begin, eit = expected_elem.begin; it != elem.end; ++it, ++eit) { for (auto it = elem.begin(), eit = expected_elem.begin(); it != elem.end(); ++it, ++eit) {
ASSERT_EQ(*it, *eit); ASSERT_EQ(*it, *eit);
} }
} }
} }
}
} }
template <typename ExpandEntry> template <typename ExpandEntry>