Small refactor for hist builder. (#8698)

- Use span instead of vector as parameter. No perf change, as the builder works on pointers.
- Use const pointer for reg tree.
This commit is contained in:
Jiaming Yuan 2023-01-30 14:06:41 +08:00 committed by GitHub
parent 8af98e30fc
commit 21a28f2cc5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 95 additions and 120 deletions

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2017-2020 by Contributors
/**
* Copyright 2017-2023 by XGBoost Contributors
* \file hist_util.cc
*/
#include <dmlc/timer.h>
@ -193,7 +193,7 @@ class GHistBuildingManager {
};
template <bool do_prefetch, class BuildingManager>
void RowsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
void RowsWiseBuildHistKernel(Span<GradientPair const> gpair,
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
GHistRow hist) {
constexpr bool kAnyMissing = BuildingManager::kAnyMissing;
@ -262,7 +262,7 @@ void RowsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
}
template <class BuildingManager>
void ColsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
void ColsWiseBuildHistKernel(Span<GradientPair const> gpair,
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
GHistRow hist) {
constexpr bool kAnyMissing = BuildingManager::kAnyMissing;
@ -315,9 +315,8 @@ void ColsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
}
template <class BuildingManager>
void BuildHistDispatch(const std::vector<GradientPair> &gpair,
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
GHistRow hist) {
void BuildHistDispatch(Span<GradientPair const> gpair, const RowSetCollection::Elem row_indices,
const GHistIndexMatrix &gmat, GHistRow hist) {
if (BuildingManager::kReadByColumn) {
ColsWiseBuildHistKernel<BuildingManager>(gpair, row_indices, gmat, hist);
} else {
@ -344,9 +343,8 @@ void BuildHistDispatch(const std::vector<GradientPair> &gpair,
}
template <bool any_missing>
void GHistBuilder::BuildHist(const std::vector<GradientPair> &gpair,
const RowSetCollection::Elem row_indices,
const GHistIndexMatrix &gmat,
void GHistBuilder::BuildHist(Span<GradientPair const> gpair,
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
GHistRow hist, bool force_read_by_column) const {
/* force_read_by_column is used for testing the columnwise building of histograms.
* default force_read_by_column = false
@ -358,19 +356,18 @@ void GHistBuilder::BuildHist(const std::vector<GradientPair> &gpair,
auto bin_type_size = gmat.index.GetBinTypeSize();
GHistBuildingManager<any_missing>::DispatchAndExecute(
{first_page, read_by_column || force_read_by_column, bin_type_size},
[&](auto t) {
{first_page, read_by_column || force_read_by_column, bin_type_size}, [&](auto t) {
using BuildingManager = decltype(t);
BuildHistDispatch<BuildingManager>(gpair, row_indices, gmat, hist);
});
}
template void GHistBuilder::BuildHist<true>(const std::vector<GradientPair> &gpair,
template void GHistBuilder::BuildHist<true>(Span<GradientPair const> gpair,
const RowSetCollection::Elem row_indices,
const GHistIndexMatrix &gmat, GHistRow hist,
bool force_read_by_column) const;
template void GHistBuilder::BuildHist<false>(const std::vector<GradientPair> &gpair,
template void GHistBuilder::BuildHist<false>(Span<GradientPair const> gpair,
const RowSetCollection::Elem row_indices,
const GHistIndexMatrix &gmat, GHistRow hist,
bool force_read_by_column) const;

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2017-2022 by XGBoost Contributors
/**
* Copyright 2017-2023 by XGBoost Contributors
* \file hist_util.h
* \brief Utility for fast histogram aggregation
* \author Philip Cho, Tianqi Chen
@ -23,6 +23,7 @@
#include "row_set.h"
#include "threading_utils.h"
#include "timer.h"
#include "xgboost/base.h" // bst_feature_t, bst_bin_t
namespace xgboost {
class GHistIndexMatrix;
@ -320,10 +321,10 @@ struct Index {
};
template <typename GradientIndex>
bst_bin_t XGBOOST_HOST_DEV_INLINE BinarySearchBin(size_t begin, size_t end,
bst_bin_t XGBOOST_HOST_DEV_INLINE BinarySearchBin(std::size_t begin, std::size_t end,
GradientIndex const& data,
uint32_t const fidx_begin,
uint32_t const fidx_end) {
bst_feature_t const fidx_begin,
bst_feature_t const fidx_end) {
size_t previous_middle = std::numeric_limits<size_t>::max();
while (end != begin) {
size_t middle = begin + (end - begin) / 2;
@ -635,7 +636,7 @@ class GHistBuilder {
// construct a histogram via histogram aggregation
template <bool any_missing>
void BuildHist(const std::vector<GradientPair>& gpair, const RowSetCollection::Elem row_indices,
void BuildHist(Span<GradientPair const> gpair, const RowSetCollection::Elem row_indices,
const GHistIndexMatrix& gmat, GHistRow hist,
bool force_read_by_column = false) const;
uint32_t GetNumBins() const {

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2021-2022 by XGBoost Contributors
/**
* Copyright 2021-2023 by XGBoost Contributors
*/
#ifndef XGBOOST_TREE_HIST_HISTOGRAM_H_
#define XGBOOST_TREE_HIST_HISTOGRAM_H_
@ -59,15 +59,14 @@ class HistogramBuilder {
GHistIndexMatrix const &gidx,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
common::RowSetCollection const &row_set_collection,
const std::vector<GradientPair> &gpair_h,
bool force_read_by_column) {
common::Span<GradientPair const> gpair_h, bool force_read_by_column) {
const size_t n_nodes = nodes_for_explicit_hist_build.size();
CHECK_GT(n_nodes, 0);
std::vector<common::GHistRow> target_hists(n_nodes);
for (size_t i = 0; i < n_nodes; ++i) {
const int32_t nid = nodes_for_explicit_hist_build[i].nid;
target_hists[i] = hist_[nid];
auto const nidx = nodes_for_explicit_hist_build[i].nid;
target_hists[i] = hist_[nidx];
}
if (page_idx == 0) {
// FIXME(jiamingy): Handle different size of space. Right now we use the maximum
@ -93,46 +92,37 @@ class HistogramBuilder {
});
}
void
AddHistRows(int *starting_index, int *sync_count,
void AddHistRows(int *starting_index, int *sync_count,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
RegTree *p_tree) {
RegTree const *p_tree) {
if (is_distributed_) {
this->AddHistRowsDistributed(starting_index, sync_count,
nodes_for_explicit_hist_build,
this->AddHistRowsDistributed(starting_index, sync_count, nodes_for_explicit_hist_build,
nodes_for_subtraction_trick, p_tree);
} else {
this->AddHistRowsLocal(starting_index, sync_count,
nodes_for_explicit_hist_build,
this->AddHistRowsLocal(starting_index, sync_count, nodes_for_explicit_hist_build,
nodes_for_subtraction_trick);
}
}
/** Main entry point of this class, build histogram for tree nodes. */
void BuildHist(size_t page_id, common::BlockedSpace2d space, GHistIndexMatrix const &gidx,
RegTree *p_tree, common::RowSetCollection const &row_set_collection,
RegTree const *p_tree, common::RowSetCollection const &row_set_collection,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
std::vector<GradientPair> const &gpair,
bool force_read_by_column = false) {
common::Span<GradientPair const> gpair, bool force_read_by_column = false) {
int starting_index = std::numeric_limits<int>::max();
int sync_count = 0;
if (page_id == 0) {
this->AddHistRows(&starting_index, &sync_count,
nodes_for_explicit_hist_build,
this->AddHistRows(&starting_index, &sync_count, nodes_for_explicit_hist_build,
nodes_for_subtraction_trick, p_tree);
}
if (gidx.IsDense()) {
this->BuildLocalHistograms<false>(page_id, space, gidx,
nodes_for_explicit_hist_build,
row_set_collection, gpair,
force_read_by_column);
this->BuildLocalHistograms<false>(page_id, space, gidx, nodes_for_explicit_hist_build,
row_set_collection, gpair, force_read_by_column);
} else {
this->BuildLocalHistograms<true>(page_id, space, gidx,
nodes_for_explicit_hist_build,
row_set_collection, gpair,
force_read_by_column);
this->BuildLocalHistograms<true>(page_id, space, gidx, nodes_for_explicit_hist_build,
row_set_collection, gpair, force_read_by_column);
}
CHECK_GE(n_batches_, 1);
@ -153,8 +143,7 @@ class HistogramBuilder {
common::RowSetCollection const &row_set_collection,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
std::vector<GradientPair> const &gpair,
bool force_read_by_column = false) {
common::Span<GradientPair const> gpair, bool force_read_by_column = false) {
const size_t n_nodes = nodes_for_explicit_hist_build.size();
// create space of size (# rows in each node)
common::BlockedSpace2d space(
@ -164,22 +153,18 @@ class HistogramBuilder {
return row_set_collection[nidx].Size();
},
256);
this->BuildHist(page_id, space, gidx, p_tree, row_set_collection,
nodes_for_explicit_hist_build, nodes_for_subtraction_trick,
gpair, force_read_by_column);
this->BuildHist(page_id, space, gidx, p_tree, row_set_collection, nodes_for_explicit_hist_build,
nodes_for_subtraction_trick, gpair, force_read_by_column);
}
void SyncHistogramDistributed(
RegTree *p_tree,
void SyncHistogramDistributed(RegTree const *p_tree,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
int starting_index, int sync_count) {
const size_t nbins = builder_.GetNumBins();
common::BlockedSpace2d space(
nodes_for_explicit_hist_build.size(), [&](size_t) { return nbins; },
1024);
common::ParallelFor2d(
space, n_threads_, [&](size_t node, common::Range1d r) {
nodes_for_explicit_hist_build.size(), [&](size_t) { return nbins; }, 1024);
common::ParallelFor2d(space, n_threads_, [&](size_t node, common::Range1d r) {
const auto &entry = nodes_for_explicit_hist_build[node];
auto this_hist = this->hist_[entry.nid];
// Merging histograms from each thread into one
@ -190,12 +175,10 @@ class HistogramBuilder {
if (!(*p_tree)[entry.nid].IsRoot()) {
const size_t parent_id = (*p_tree)[entry.nid].Parent();
const int subtraction_node_id =
nodes_for_subtraction_trick[node].nid;
const int subtraction_node_id = nodes_for_subtraction_trick[node].nid;
auto parent_hist = this->hist_local_worker_[parent_id];
auto sibling_hist = this->hist_[subtraction_node_id];
common::SubtractionHist(sibling_hist, parent_hist, this_hist,
r.begin(), r.end());
common::SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end());
// Store possible parent node
auto sibling_local = hist_local_worker_[subtraction_node_id];
common::CopyHist(sibling_local, sibling_hist, r.begin(), r.end());
@ -206,39 +189,34 @@ class HistogramBuilder {
reinterpret_cast<double *>(this->hist_[starting_index].data()),
builder_.GetNumBins() * sync_count * 2);
ParallelSubtractionHist(space, nodes_for_explicit_hist_build,
nodes_for_subtraction_trick, p_tree);
ParallelSubtractionHist(space, nodes_for_explicit_hist_build, nodes_for_subtraction_trick,
p_tree);
common::BlockedSpace2d space2(
nodes_for_subtraction_trick.size(), [&](size_t) { return nbins; },
1024);
ParallelSubtractionHist(space2, nodes_for_subtraction_trick,
nodes_for_explicit_hist_build, p_tree);
nodes_for_subtraction_trick.size(), [&](size_t) { return nbins; }, 1024);
ParallelSubtractionHist(space2, nodes_for_subtraction_trick, nodes_for_explicit_hist_build,
p_tree);
}
void SyncHistogramLocal(RegTree *p_tree,
void SyncHistogramLocal(RegTree const *p_tree,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick) {
const size_t nbins = this->builder_.GetNumBins();
common::BlockedSpace2d space(
nodes_for_explicit_hist_build.size(), [&](size_t) { return nbins; },
1024);
nodes_for_explicit_hist_build.size(), [&](size_t) { return nbins; }, 1024);
common::ParallelFor2d(
space, this->n_threads_, [&](size_t node, common::Range1d r) {
common::ParallelFor2d(space, this->n_threads_, [&](size_t node, common::Range1d r) {
const auto &entry = nodes_for_explicit_hist_build[node];
auto this_hist = this->hist_[entry.nid];
// Merging histograms from each thread into one
this->buffer_.ReduceHist(node, r.begin(), r.end());
if (!(*p_tree)[entry.nid].IsRoot()) {
const size_t parent_id = (*p_tree)[entry.nid].Parent();
const int subtraction_node_id =
nodes_for_subtraction_trick[node].nid;
auto const parent_id = (*p_tree)[entry.nid].Parent();
auto const subtraction_node_id = nodes_for_subtraction_trick[node].nid;
auto parent_hist = this->hist_[parent_id];
auto sibling_hist = this->hist_[subtraction_node_id];
common::SubtractionHist(sibling_hist, parent_hist, this_hist,
r.begin(), r.end());
common::SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end());
}
});
}
@ -289,11 +267,10 @@ class HistogramBuilder {
this->hist_.AllocateAllData();
}
void AddHistRowsDistributed(
int *starting_index, int *sync_count,
void AddHistRowsDistributed(int *starting_index, int *sync_count,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
RegTree *p_tree) {
RegTree const *p_tree) {
const size_t explicit_size = nodes_for_explicit_hist_build.size();
const size_t subtaction_size = nodes_for_subtraction_trick.size();
std::vector<int> merged_node_ids(explicit_size + subtaction_size);