Unify the partitioner for hist and approx.

Co-authored-by: dmitry.razdoburdin <drazdobu@jfldaal005.jf.intel.com>
Co-authored-by: jiamingy <jm.yuan@outlook.com>
This commit is contained in:
Dmitry Razdoburdin
2022-10-19 20:49:20 +02:00
committed by GitHub
parent c69af90319
commit 5bd849f1b5
13 changed files with 358 additions and 450 deletions

View File

@@ -0,0 +1,212 @@
/*!
* Copyright 2021-2022 XGBoost contributors
* \file common_row_partitioner.h
* \brief Common partitioner logic for hist and approx methods.
*/
#ifndef XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_
#define XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_
#include <limits> // std::numeric_limits
#include <vector>
#include "../common/numeric.h" // Iota
#include "../common/partition_builder.h"
#include "hist/expand_entry.h" // CPUExpandEntry
#include "xgboost/generic_parameters.h" // Context
namespace xgboost {
namespace tree {
class CommonRowPartitioner {
static constexpr size_t kPartitionBlockSize = 2048;
common::PartitionBuilder<kPartitionBlockSize> partition_builder_;
common::RowSetCollection row_set_collection_;
public:
bst_row_t base_rowid = 0;
CommonRowPartitioner() = default;
CommonRowPartitioner(Context const* ctx, bst_row_t num_row, bst_row_t _base_rowid)
: base_rowid{_base_rowid} {
row_set_collection_.Clear();
std::vector<size_t>& row_indices = *row_set_collection_.Data();
row_indices.resize(num_row);
std::size_t* p_row_indices = row_indices.data();
common::Iota(ctx, p_row_indices, p_row_indices + row_indices.size(), base_rowid);
row_set_collection_.Init();
}
void FindSplitConditions(const std::vector<CPUExpandEntry>& nodes, const RegTree& tree,
const GHistIndexMatrix& gmat, std::vector<int32_t>* split_conditions) {
for (size_t i = 0; i < nodes.size(); ++i) {
const int32_t nid = nodes[i].nid;
const bst_uint fid = tree[nid].SplitIndex();
const bst_float split_pt = tree[nid].SplitCond();
const uint32_t lower_bound = gmat.cut.Ptrs()[fid];
const uint32_t upper_bound = gmat.cut.Ptrs()[fid + 1];
bst_bin_t split_cond = -1;
// convert floating-point split_pt into corresponding bin_id
// split_cond = -1 indicates that split_pt is less than all known cut points
CHECK_LT(upper_bound, static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
for (auto bound = lower_bound; bound < upper_bound; ++bound) {
if (split_pt == gmat.cut.Values()[bound]) {
split_cond = static_cast<int32_t>(bound);
}
}
(*split_conditions).at(i) = split_cond;
}
}
void AddSplitsToRowSet(const std::vector<CPUExpandEntry>& nodes, RegTree const* p_tree) {
const size_t n_nodes = nodes.size();
for (unsigned int i = 0; i < n_nodes; ++i) {
const int32_t nid = nodes[i].nid;
const size_t n_left = partition_builder_.GetNLeftElems(i);
const size_t n_right = partition_builder_.GetNRightElems(i);
CHECK_EQ((*p_tree)[nid].LeftChild() + 1, (*p_tree)[nid].RightChild());
row_set_collection_.AddSplit(nid, (*p_tree)[nid].LeftChild(), (*p_tree)[nid].RightChild(),
n_left, n_right);
}
}
void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat,
std::vector<CPUExpandEntry> const& nodes, RegTree const* p_tree) {
auto const& column_matrix = gmat.Transpose();
if (column_matrix.IsInitialized()) {
if (gmat.cut.HasCategorical()) {
this->template UpdatePosition<true>(ctx, gmat, column_matrix, nodes, p_tree);
} else {
this->template UpdatePosition<false>(ctx, gmat, column_matrix, nodes, p_tree);
}
} else {
/* ColumnMatrix is not initilized.
* It means that we use 'approx' method.
* any_missing and any_cat don't metter in this case.
* Jump directly to the main method.
*/
this->template UpdatePosition<uint8_t, true, true>(ctx, gmat, column_matrix, nodes, p_tree);
}
}
template <bool any_cat>
void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat,
const common::ColumnMatrix& column_matrix,
std::vector<CPUExpandEntry> const& nodes, RegTree const* p_tree) {
if (column_matrix.AnyMissing()) {
this->template UpdatePosition<true, any_cat>(ctx, gmat, column_matrix, nodes, p_tree);
} else {
this->template UpdatePosition<false, any_cat>(ctx, gmat, column_matrix, nodes, p_tree);
}
}
template <bool any_missing, bool any_cat>
void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat,
const common::ColumnMatrix& column_matrix,
std::vector<CPUExpandEntry> const& nodes, RegTree const* p_tree) {
switch (column_matrix.GetTypeSize()) {
case common::kUint8BinsTypeSize:
this->template UpdatePosition<uint8_t, any_missing, any_cat>(ctx, gmat, column_matrix,
nodes, p_tree);
break;
case common::kUint16BinsTypeSize:
this->template UpdatePosition<uint16_t, any_missing, any_cat>(ctx, gmat, column_matrix,
nodes, p_tree);
break;
case common::kUint32BinsTypeSize:
this->template UpdatePosition<uint32_t, any_missing, any_cat>(ctx, gmat, column_matrix,
nodes, p_tree);
break;
default:
// no default behavior
CHECK(false) << column_matrix.GetTypeSize();
}
}
template <typename BinIdxType, bool any_missing, bool any_cat>
void UpdatePosition(Context const* ctx, GHistIndexMatrix const& gmat,
const common::ColumnMatrix& column_matrix,
std::vector<CPUExpandEntry> const& nodes, RegTree const* p_tree) {
// 1. Find split condition for each split
size_t n_nodes = nodes.size();
std::vector<int32_t> split_conditions;
if (column_matrix.IsInitialized()) {
split_conditions.resize(n_nodes);
FindSplitConditions(nodes, *p_tree, gmat, &split_conditions);
}
// 2.1 Create a blocked space of size SUM(samples in each node)
common::BlockedSpace2d space(
n_nodes,
[&](size_t node_in_set) {
int32_t nid = nodes[node_in_set].nid;
return row_set_collection_[nid].Size();
},
kPartitionBlockSize);
// 2.2 Initialize the partition builder
// allocate buffers for storage intermediate results by each thread
partition_builder_.Init(space.Size(), n_nodes, [&](size_t node_in_set) {
const int32_t nid = nodes[node_in_set].nid;
const size_t size = row_set_collection_[nid].Size();
const size_t n_tasks = size / kPartitionBlockSize + !!(size % kPartitionBlockSize);
return n_tasks;
});
CHECK_EQ(base_rowid, gmat.base_rowid);
// 2.3 Split elements of row_set_collection_ to left and right child-nodes for each node
// Store results in intermediate buffers from partition_builder_
common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) {
size_t begin = r.begin();
const int32_t nid = nodes[node_in_set].nid;
const size_t task_id = partition_builder_.GetTaskIdx(node_in_set, begin);
partition_builder_.AllocateForTask(task_id);
bst_bin_t split_cond = column_matrix.IsInitialized() ? split_conditions[node_in_set] : 0;
partition_builder_.template Partition<BinIdxType, any_missing, any_cat>(
node_in_set, nodes, r, split_cond, gmat, column_matrix, *p_tree,
row_set_collection_[nid].begin);
});
// 3. Compute offsets to copy blocks of row-indexes
// from partition_builder_ to row_set_collection_
partition_builder_.CalculateRowOffsets();
// 4. Copy elements from partition_builder_ to row_set_collection_ back
// with updated row-indexes for each tree-node
common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) {
const int32_t nid = nodes[node_in_set].nid;
partition_builder_.MergeToArray(node_in_set, r.begin(),
const_cast<size_t*>(row_set_collection_[nid].begin));
});
// 5. Add info about splits into row_set_collection_
AddSplitsToRowSet(nodes, p_tree);
}
auto const& Partitions() const { return row_set_collection_; }
size_t Size() const {
return std::distance(row_set_collection_.begin(), row_set_collection_.end());
}
auto& operator[](bst_node_t nidx) { return row_set_collection_[nidx]; }
auto const& operator[](bst_node_t nidx) const { return row_set_collection_[nidx]; }
void LeafPartition(Context const* ctx, RegTree const& tree, common::Span<float const> hess,
std::vector<bst_node_t>* p_out_position) const {
partition_builder_.LeafPartition(ctx, tree, this->Partitions(), p_out_position,
[&](size_t idx) -> bool { return hess[idx] - .0f == .0f; });
}
void LeafPartition(Context const* ctx, RegTree const& tree,
common::Span<GradientPair const> gpair,
std::vector<bst_node_t>* p_out_position) const {
partition_builder_.LeafPartition(
ctx, tree, this->Partitions(), p_out_position,
[&](size_t idx) -> bool { return gpair[idx].GetHess() - .0f == .0f; });
}
};
} // namespace tree
} // namespace xgboost
#endif // XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_

View File

@@ -3,14 +3,13 @@
*
* \brief Implementation for the approx tree method.
*/
#include "updater_approx.h"
#include <algorithm>
#include <memory>
#include <vector>
#include "../common/random.h"
#include "../data/gradient_index.h"
#include "common_row_partitioner.h"
#include "constraints.h"
#include "driver.h"
#include "hist/evaluate_splits.h"
@@ -46,7 +45,7 @@ class GloablApproxBuilder {
Context const *ctx_;
ObjInfo const task_;
std::vector<ApproxRowPartitioner> partitioner_;
std::vector<CommonRowPartitioner> partitioner_;
// Pointer to last updated tree, used for update prediction cache.
RegTree *p_last_tree_{nullptr};
common::Monitor *monitor_;
@@ -69,7 +68,7 @@ class GloablApproxBuilder {
} else {
CHECK_EQ(n_total_bins, page.cut.TotalBins());
}
partitioner_.emplace_back(page.Size(), page.base_rowid);
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid);
n_batches_++;
}
@@ -151,7 +150,7 @@ class GloablApproxBuilder {
monitor_->Stop(__func__);
}
void LeafPartition(RegTree const &tree, common::Span<float> hess,
void LeafPartition(RegTree const &tree, common::Span<float const> hess,
std::vector<bst_node_t> *p_out_position) {
monitor_->Start(__func__);
if (!task_.UpdateTreeLeaf()) {

View File

@@ -1,150 +0,0 @@
/*!
* Copyright 2021-2022 XGBoost contributors
*
* \brief Implementation for the approx tree method.
*/
#ifndef XGBOOST_TREE_UPDATER_APPROX_H_
#define XGBOOST_TREE_UPDATER_APPROX_H_
#include <limits>
#include <utility>
#include <vector>
#include "../common/partition_builder.h"
#include "../common/random.h"
#include "constraints.h"
#include "driver.h"
#include "hist/evaluate_splits.h"
#include "hist/expand_entry.h"
#include "param.h"
#include "xgboost/generic_parameters.h"
#include "xgboost/json.h"
#include "xgboost/tree_updater.h"
namespace xgboost {
namespace tree {
class ApproxRowPartitioner {
static constexpr size_t kPartitionBlockSize = 2048;
common::PartitionBuilder<kPartitionBlockSize> partition_builder_;
common::RowSetCollection row_set_collection_;
public:
bst_row_t base_rowid = 0;
static auto SearchCutValue(bst_row_t ridx, bst_feature_t fidx, GHistIndexMatrix const &index,
std::vector<uint32_t> const &cut_ptrs,
std::vector<float> const &cut_values) {
int32_t gidx = -1;
if (index.IsDense()) {
// RowIdx returns the starting pos of this row
gidx = index.index[index.RowIdx(ridx) + fidx];
} else {
auto begin = index.RowIdx(ridx);
auto end = index.RowIdx(ridx + 1);
auto f_begin = cut_ptrs[fidx];
auto f_end = cut_ptrs[fidx + 1];
gidx = common::BinarySearchBin(begin, end, index.index, f_begin, f_end);
}
if (gidx == -1) {
return std::numeric_limits<float>::quiet_NaN();
}
return cut_values[gidx];
}
public:
void UpdatePosition(GenericParameter const *ctx, GHistIndexMatrix const &index,
std::vector<CPUExpandEntry> const &candidates, RegTree const *p_tree) {
size_t n_nodes = candidates.size();
auto const &cut_values = index.cut.Values();
auto const &cut_ptrs = index.cut.Ptrs();
common::BlockedSpace2d space{n_nodes,
[&](size_t node_in_set) {
auto candidate = candidates[node_in_set];
int32_t nid = candidate.nid;
return row_set_collection_[nid].Size();
},
kPartitionBlockSize};
partition_builder_.Init(space.Size(), n_nodes, [&](size_t node_in_set) {
auto candidate = candidates[node_in_set];
const int32_t nid = candidate.nid;
const size_t size = row_set_collection_[nid].Size();
const size_t n_tasks = size / kPartitionBlockSize + !!(size % kPartitionBlockSize);
return n_tasks;
});
auto node_ptr = p_tree->GetCategoriesMatrix().node_ptr;
auto categories = p_tree->GetCategoriesMatrix().categories;
common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) {
auto candidate = candidates[node_in_set];
auto is_cat = candidate.split.is_cat;
const int32_t nid = candidate.nid;
auto fidx = candidate.split.SplitIndex();
const size_t task_id = partition_builder_.GetTaskIdx(node_in_set, r.begin());
partition_builder_.AllocateForTask(task_id);
partition_builder_.PartitionRange(
node_in_set, nid, r, &row_set_collection_, [&](size_t row_id) {
auto cut_value = SearchCutValue(row_id, fidx, index, cut_ptrs, cut_values);
if (std::isnan(cut_value)) {
return candidate.split.DefaultLeft();
}
bst_node_t nidx = candidate.nid;
auto segment = node_ptr[nidx];
auto node_cats = categories.subspan(segment.beg, segment.size);
bool go_left = true;
if (is_cat) {
go_left = common::Decision(node_cats, cut_value, candidate.split.DefaultLeft());
} else {
go_left = cut_value <= candidate.split.split_value;
}
return go_left;
});
});
partition_builder_.CalculateRowOffsets();
common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) {
auto candidate = candidates[node_in_set];
const int32_t nid = candidate.nid;
partition_builder_.MergeToArray(node_in_set, r.begin(),
const_cast<size_t *>(row_set_collection_[nid].begin));
});
for (size_t i = 0; i < candidates.size(); ++i) {
auto const &candidate = candidates[i];
auto nidx = candidate.nid;
auto n_left = partition_builder_.GetNLeftElems(i);
auto n_right = partition_builder_.GetNRightElems(i);
CHECK_EQ(n_left + n_right, row_set_collection_[nidx].Size());
bst_node_t left_nidx = (*p_tree)[nidx].LeftChild();
bst_node_t right_nidx = (*p_tree)[nidx].RightChild();
row_set_collection_.AddSplit(nidx, left_nidx, right_nidx, n_left, n_right);
}
}
auto const &Partitions() const { return row_set_collection_; }
void LeafPartition(Context const *ctx, RegTree const &tree, common::Span<float const> hess,
std::vector<bst_node_t> *p_out_position) const {
partition_builder_.LeafPartition(ctx, tree, this->Partitions(), p_out_position,
[&](size_t idx) -> bool { return hess[idx] - .0f == .0f; });
}
auto operator[](bst_node_t nidx) { return row_set_collection_[nidx]; }
auto const &operator[](bst_node_t nidx) const { return row_set_collection_[nidx]; }
size_t Size() const {
return std::distance(row_set_collection_.begin(), row_set_collection_.end());
}
ApproxRowPartitioner() = default;
explicit ApproxRowPartitioner(bst_row_t num_row, bst_row_t _base_rowid)
: base_rowid{_base_rowid} {
row_set_collection_.Clear();
auto p_positions = row_set_collection_.Data();
p_positions->resize(num_row);
std::iota(p_positions->begin(), p_positions->end(), base_rowid);
row_set_collection_.Init();
}
};
} // namespace tree
} // namespace xgboost
#endif // XGBOOST_TREE_UPDATER_APPROX_H_

View File

@@ -12,7 +12,9 @@
#include <utility>
#include <vector>
#include "common_row_partitioner.h"
#include "constraints.h"
#include "hist/histogram.h"
#include "hist/evaluate_splits.h"
#include "param.h"
#include "xgboost/logging.h"
@@ -309,7 +311,7 @@ void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree,
} else {
CHECK_EQ(n_total_bins, page.cut.TotalBins());
}
partitioner_.emplace_back(page.Size(), page.base_rowid, this->ctx_->Threads());
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid);
++page_id;
}
histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
@@ -331,44 +333,6 @@ void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree,
monitor_->Stop(__func__);
}
void HistRowPartitioner::FindSplitConditions(const std::vector<CPUExpandEntry> &nodes,
const RegTree &tree, const GHistIndexMatrix &gmat,
std::vector<int32_t> *split_conditions) {
const size_t n_nodes = nodes.size();
split_conditions->resize(n_nodes);
for (size_t i = 0; i < nodes.size(); ++i) {
const int32_t nid = nodes[i].nid;
const bst_uint fid = tree[nid].SplitIndex();
const bst_float split_pt = tree[nid].SplitCond();
const uint32_t lower_bound = gmat.cut.Ptrs()[fid];
const uint32_t upper_bound = gmat.cut.Ptrs()[fid + 1];
bst_bin_t split_cond = -1;
// convert floating-point split_pt into corresponding bin_id
// split_cond = -1 indicates that split_pt is less than all known cut points
CHECK_LT(upper_bound, static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
for (auto bound = lower_bound; bound < upper_bound; ++bound) {
if (split_pt == gmat.cut.Values()[bound]) {
split_cond = static_cast<int32_t>(bound);
}
}
(*split_conditions)[i] = split_cond;
}
}
void HistRowPartitioner::AddSplitsToRowSet(const std::vector<CPUExpandEntry> &nodes,
RegTree const *p_tree) {
const size_t n_nodes = nodes.size();
for (unsigned int i = 0; i < n_nodes; ++i) {
const int32_t nid = nodes[i].nid;
const size_t n_left = partition_builder_.GetNLeftElems(i);
const size_t n_right = partition_builder_.GetNRightElems(i);
CHECK_EQ((*p_tree)[nid].LeftChild() + 1, (*p_tree)[nid].RightChild());
row_set_collection_.AddSplit(nid, (*p_tree)[nid].LeftChild(), (*p_tree)[nid].RightChild(),
n_left, n_right);
}
}
XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker")
.describe("Grow tree using quantized histogram.")
.set_body([](GenericParameter const *ctx, ObjInfo task) {

View File

@@ -24,6 +24,7 @@
#include "hist/histogram.h"
#include "hist/expand_entry.h"
#include "common_row_partitioner.h"
#include "constraints.h"
#include "./param.h"
#include "./driver.h"
@@ -77,155 +78,6 @@ struct RandomReplace {
};
namespace tree {
class HistRowPartitioner {
// heuristically chosen block size of parallel partitioning
static constexpr size_t kPartitionBlockSize = 2048;
// worker class that partition a block of rows
common::PartitionBuilder<kPartitionBlockSize> partition_builder_;
// storage for row index
common::RowSetCollection row_set_collection_;
/**
* \brief Turn split values into discrete bin indices.
*/
static void FindSplitConditions(const std::vector<CPUExpandEntry>& nodes, const RegTree& tree,
const GHistIndexMatrix& gmat,
std::vector<int32_t>* split_conditions);
/**
* \brief Update the row set for new splits specifed by nodes.
*/
void AddSplitsToRowSet(const std::vector<CPUExpandEntry>& nodes, RegTree const* p_tree);
public:
bst_row_t base_rowid = 0;
public:
HistRowPartitioner(size_t n_samples, size_t base_rowid, int32_t n_threads) {
row_set_collection_.Clear();
const size_t block_size = n_samples / n_threads + !!(n_samples % n_threads);
dmlc::OMPException exc;
std::vector<size_t>& row_indices = *row_set_collection_.Data();
row_indices.resize(n_samples);
size_t* p_row_indices = row_indices.data();
// parallel initialization o f row indices. (std::iota)
#pragma omp parallel num_threads(n_threads)
{
exc.Run([&]() {
const size_t tid = omp_get_thread_num();
const size_t ibegin = tid * block_size;
const size_t iend = std::min(static_cast<size_t>(ibegin + block_size), n_samples);
for (size_t i = ibegin; i < iend; ++i) {
p_row_indices[i] = i + base_rowid;
}
});
}
row_set_collection_.Init();
this->base_rowid = base_rowid;
}
template <bool any_missing, bool any_cat>
void UpdatePosition(GenericParameter const* ctx, GHistIndexMatrix const& gmat,
common::ColumnMatrix const& column_matrix,
std::vector<CPUExpandEntry> const& nodes, RegTree const* p_tree) {
// 1. Find split condition for each split
const size_t n_nodes = nodes.size();
std::vector<int32_t> split_conditions;
FindSplitConditions(nodes, *p_tree, gmat, &split_conditions);
// 2.1 Create a blocked space of size SUM(samples in each node)
common::BlockedSpace2d space(
n_nodes,
[&](size_t node_in_set) {
int32_t nid = nodes[node_in_set].nid;
return row_set_collection_[nid].Size();
},
kPartitionBlockSize);
// 2.2 Initialize the partition builder
// allocate buffers for storage intermediate results by each thread
partition_builder_.Init(space.Size(), n_nodes, [&](size_t node_in_set) {
const int32_t nid = nodes[node_in_set].nid;
const size_t size = row_set_collection_[nid].Size();
const size_t n_tasks = size / kPartitionBlockSize + !!(size % kPartitionBlockSize);
return n_tasks;
});
CHECK_EQ(base_rowid, gmat.base_rowid);
// 2.3 Split elements of row_set_collection_ to left and right child-nodes for each node
// Store results in intermediate buffers from partition_builder_
common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) {
size_t begin = r.begin();
const int32_t nid = nodes[node_in_set].nid;
const size_t task_id = partition_builder_.GetTaskIdx(node_in_set, begin);
partition_builder_.AllocateForTask(task_id);
switch (column_matrix.GetTypeSize()) {
case common::kUint8BinsTypeSize:
partition_builder_.template Partition<uint8_t, any_missing, any_cat>(
node_in_set, nid, r, split_conditions[node_in_set], gmat, column_matrix, *p_tree,
row_set_collection_[nid].begin);
break;
case common::kUint16BinsTypeSize:
partition_builder_.template Partition<uint16_t, any_missing, any_cat>(
node_in_set, nid, r, split_conditions[node_in_set], gmat, column_matrix, *p_tree,
row_set_collection_[nid].begin);
break;
case common::kUint32BinsTypeSize:
partition_builder_.template Partition<uint32_t, any_missing, any_cat>(
node_in_set, nid, r, split_conditions[node_in_set], gmat, column_matrix, *p_tree,
row_set_collection_[nid].begin);
break;
default:
// no default behavior
CHECK(false) << column_matrix.GetTypeSize();
}
});
// 3. Compute offsets to copy blocks of row-indexes
// from partition_builder_ to row_set_collection_
partition_builder_.CalculateRowOffsets();
// 4. Copy elements from partition_builder_ to row_set_collection_ back
// with updated row-indexes for each tree-node
common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) {
const int32_t nid = nodes[node_in_set].nid;
partition_builder_.MergeToArray(node_in_set, r.begin(),
const_cast<size_t*>(row_set_collection_[nid].begin));
});
// 5. Add info about splits into row_set_collection_
AddSplitsToRowSet(nodes, p_tree);
}
void UpdatePosition(GenericParameter const* ctx, GHistIndexMatrix const& page,
std::vector<CPUExpandEntry> const& applied, RegTree const* p_tree) {
auto const& column_matrix = page.Transpose();
if (page.cut.HasCategorical()) {
if (column_matrix.AnyMissing()) {
this->template UpdatePosition<true, true>(ctx, page, column_matrix, applied, p_tree);
} else {
this->template UpdatePosition<false, true>(ctx, page, column_matrix, applied, p_tree);
}
} else {
if (column_matrix.AnyMissing()) {
this->template UpdatePosition<true, false>(ctx, page, column_matrix, applied, p_tree);
} else {
this->template UpdatePosition<false, false>(ctx, page, column_matrix, applied, p_tree);
}
}
}
auto const& Partitions() const { return row_set_collection_; }
size_t Size() const {
return std::distance(row_set_collection_.begin(), row_set_collection_.end());
}
void LeafPartition(Context const* ctx, RegTree const& tree,
common::Span<GradientPair const> gpair,
std::vector<bst_node_t>* p_out_position) const {
partition_builder_.LeafPartition(
ctx, tree, this->Partitions(), p_out_position,
[&](size_t idx) -> bool { return gpair[idx].GetHess() - .0f == .0f; });
}
auto& operator[](bst_node_t nidx) { return row_set_collection_[nidx]; }
auto const& operator[](bst_node_t nidx) const { return row_set_collection_[nidx]; }
};
inline BatchParam HistBatch(TrainParam const& param) {
return {param.max_bin, param.sparse_threshold};
}
@@ -314,7 +166,7 @@ class QuantileHistMaker: public TreeUpdater {
std::vector<GradientPair> gpair_local_;
std::unique_ptr<HistEvaluator<CPUExpandEntry>> evaluator_;
std::vector<HistRowPartitioner> partitioner_;
std::vector<CommonRowPartitioner> partitioner_;
// back pointers to tree and data matrix
const RegTree* p_last_tree_{nullptr};