Support external memory in CPU histogram building. (#7372)

This commit is contained in:
Jiaming Yuan 2021-11-23 01:13:33 +08:00 committed by GitHub
parent d33854af1b
commit 176110a22d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 305 additions and 162 deletions

View File

@ -133,36 +133,51 @@ struct Prefetch {
constexpr size_t Prefetch::kNoPrefetchSize; constexpr size_t Prefetch::kNoPrefetchSize;
template <typename FPType, bool do_prefetch, typename BinIdxType,
template<typename FPType, bool do_prefetch, typename BinIdxType, bool any_missing = true> bool first_page, bool any_missing = true>
void BuildHistKernel(const std::vector<GradientPair> &gpair, void BuildHistKernel(const std::vector<GradientPair> &gpair,
const RowSetCollection::Elem row_indices, const RowSetCollection::Elem row_indices,
const GHistIndexMatrix& gmat, const GHistIndexMatrix &gmat, GHistRow<FPType> hist) {
GHistRow<FPType> hist) {
const size_t size = row_indices.Size(); const size_t size = row_indices.Size();
const size_t *rid = row_indices.begin; const size_t *rid = row_indices.begin;
const float* pgh = reinterpret_cast<const float*>(gpair.data()); auto const *pgh = reinterpret_cast<const float *>(gpair.data());
const BinIdxType *gradient_index = gmat.index.data<BinIdxType>(); const BinIdxType *gradient_index = gmat.index.data<BinIdxType>();
const size_t* row_ptr = gmat.row_ptr.data();
auto const &row_ptr = gmat.row_ptr.data();
auto base_rowid = gmat.base_rowid;
const uint32_t *offsets = gmat.index.Offset(); const uint32_t *offsets = gmat.index.Offset();
const size_t n_features = row_ptr[row_indices.begin[0]+1] - row_ptr[row_indices.begin[0]]; auto get_row_ptr = [&](size_t ridx) {
FPType* hist_data = reinterpret_cast<FPType*>(hist.data()); return first_page ? row_ptr[ridx] : row_ptr[ridx - base_rowid];
};
auto get_rid = [&](size_t ridx) {
return first_page ? ridx : (ridx - base_rowid);
};
const size_t n_features =
get_row_ptr(row_indices.begin[0] + 1) - get_row_ptr(row_indices.begin[0]);
auto hist_data = reinterpret_cast<FPType *>(hist.data());
const uint32_t two{2}; // Each element from 'gpair' and 'hist' contains const uint32_t two{2}; // Each element from 'gpair' and 'hist' contains
// 2 FP values: gradient and hessian. // 2 FP values: gradient and hessian.
// So we need to multiply each row-index/bin-index by 2 // So we need to multiply each row-index/bin-index by 2
// to work with gradient pairs as a singe row FP array // to work with gradient pairs as a singe row FP array
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
const size_t icol_start = any_missing ? row_ptr[rid[i]] : rid[i] * n_features; const size_t icol_start =
const size_t icol_end = any_missing ? row_ptr[rid[i]+1] : icol_start + n_features; any_missing ? get_row_ptr(rid[i]) : get_rid(rid[i]) * n_features;
const size_t icol_end =
any_missing ? get_row_ptr(rid[i] + 1) : icol_start + n_features;
const size_t row_size = icol_end - icol_start; const size_t row_size = icol_end - icol_start;
const size_t idx_gh = two * rid[i]; const size_t idx_gh = two * rid[i];
if (do_prefetch) { if (do_prefetch) {
const size_t icol_start_prefetch = any_missing ? row_ptr[rid[i+Prefetch::kPrefetchOffset]] : const size_t icol_start_prefetch =
rid[i + Prefetch::kPrefetchOffset] * n_features; any_missing
const size_t icol_end_prefetch = any_missing ? row_ptr[rid[i+Prefetch::kPrefetchOffset]+1] : ? get_row_ptr(rid[i + Prefetch::kPrefetchOffset])
icol_start_prefetch + n_features; : get_rid(rid[i + Prefetch::kPrefetchOffset]) * n_features;
const size_t icol_end_prefetch =
any_missing ? get_row_ptr(rid[i + Prefetch::kPrefetchOffset] + 1)
: icol_start_prefetch + n_features;
PREFETCH_READ_T0(pgh + two * rid[i + Prefetch::kPrefetchOffset]); PREFETCH_READ_T0(pgh + two * rid[i + Prefetch::kPrefetchOffset]);
for (size_t j = icol_start_prefetch; j < icol_end_prefetch; for (size_t j = icol_start_prefetch; j < icol_end_prefetch;
@ -173,9 +188,8 @@ void BuildHistKernel(const std::vector<GradientPair>& gpair,
const BinIdxType *gr_index_local = gradient_index + icol_start; const BinIdxType *gr_index_local = gradient_index + icol_start;
for (size_t j = 0; j < row_size; ++j) { for (size_t j = 0; j < row_size; ++j) {
const uint32_t idx_bin = two * (static_cast<uint32_t>(gr_index_local[j]) + ( const uint32_t idx_bin = two * (static_cast<uint32_t>(gr_index_local[j]) +
any_missing ? 0 : offsets[j])); (any_missing ? 0 : offsets[j]));
hist_data[idx_bin] += pgh[idx_gh]; hist_data[idx_bin] += pgh[idx_gh];
hist_data[idx_bin + 1] += pgh[idx_gh + 1]; hist_data[idx_bin + 1] += pgh[idx_gh + 1];
} }
@ -186,95 +200,94 @@ template<typename FPType, bool do_prefetch, bool any_missing>
void BuildHistDispatch(const std::vector<GradientPair> &gpair, void BuildHistDispatch(const std::vector<GradientPair> &gpair,
const RowSetCollection::Elem row_indices, const RowSetCollection::Elem row_indices,
const GHistIndexMatrix &gmat, GHistRow<FPType> hist) { const GHistIndexMatrix &gmat, GHistRow<FPType> hist) {
auto first_page = gmat.base_rowid == 0;
if (first_page) {
switch (gmat.index.GetBinTypeSize()) { switch (gmat.index.GetBinTypeSize()) {
case kUint8BinsTypeSize: case kUint8BinsTypeSize:
BuildHistKernel<FPType, do_prefetch, uint8_t, any_missing>(gpair, row_indices, BuildHistKernel<FPType, do_prefetch, uint8_t, true, any_missing>(
gmat, hist); gpair, row_indices, gmat, hist);
break; break;
case kUint16BinsTypeSize: case kUint16BinsTypeSize:
BuildHistKernel<FPType, do_prefetch, uint16_t, any_missing>(gpair, row_indices, BuildHistKernel<FPType, do_prefetch, uint16_t, true, any_missing>(
gmat, hist); gpair, row_indices, gmat, hist);
break; break;
case kUint32BinsTypeSize: case kUint32BinsTypeSize:
BuildHistKernel<FPType, do_prefetch, uint32_t, any_missing>(gpair, row_indices, BuildHistKernel<FPType, do_prefetch, uint32_t, true, any_missing>(
gmat, hist); gpair, row_indices, gmat, hist);
break; break;
default: default:
CHECK(false); // no default behavior CHECK(false); // no default behavior
} }
} else {
switch (gmat.index.GetBinTypeSize()) {
case kUint8BinsTypeSize:
BuildHistKernel<FPType, do_prefetch, uint8_t, false, any_missing>(
gpair, row_indices, gmat, hist);
break;
case kUint16BinsTypeSize:
BuildHistKernel<FPType, do_prefetch, uint16_t, false, any_missing>(
gpair, row_indices, gmat, hist);
break;
case kUint32BinsTypeSize:
BuildHistKernel<FPType, do_prefetch, uint32_t, false, any_missing>(
gpair, row_indices, gmat, hist);
break;
default:
CHECK(false); // no default behavior
}
}
} }
template <typename GradientSumT> template <typename GradientSumT>
template <bool any_missing> template <bool any_missing>
void GHistBuilder<GradientSumT>::BuildHist( void GHistBuilder<GradientSumT>::BuildHist(
const std::vector<GradientPair> &gpair, const std::vector<GradientPair> &gpair,
const RowSetCollection::Elem row_indices, const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
const GHistIndexMatrix &gmat, GHistRowT hist) const {
GHistRowT hist) {
const size_t nrows = row_indices.Size(); const size_t nrows = row_indices.Size();
const size_t no_prefetch_size = Prefetch::NoPrefetchSize(nrows); const size_t no_prefetch_size = Prefetch::NoPrefetchSize(nrows);
// if need to work with all rows from bin-matrix (e.g. root node) // if need to work with all rows from bin-matrix (e.g. root node)
const bool contiguousBlock = (row_indices.begin[nrows - 1] - row_indices.begin[0]) == (nrows - 1); const bool contiguousBlock =
(row_indices.begin[nrows - 1] - row_indices.begin[0]) == (nrows - 1);
if (contiguousBlock) { if (contiguousBlock) {
// contiguous memory access, built-in HW prefetching is enough // contiguous memory access, built-in HW prefetching is enough
BuildHistDispatch<GradientSumT, false, any_missing>(gpair, row_indices, gmat, hist); BuildHistDispatch<GradientSumT, false, any_missing>(gpair, row_indices,
gmat, hist);
} else { } else {
const RowSetCollection::Elem span1(row_indices.begin, row_indices.end - no_prefetch_size); const RowSetCollection::Elem span1(row_indices.begin,
const RowSetCollection::Elem span2(row_indices.end - no_prefetch_size, row_indices.end); row_indices.end - no_prefetch_size);
const RowSetCollection::Elem span2(row_indices.end - no_prefetch_size,
row_indices.end);
BuildHistDispatch<GradientSumT, true, any_missing>(gpair, span1, gmat, hist); BuildHistDispatch<GradientSumT, true, any_missing>(gpair, span1, gmat,
hist);
// no prefetching to avoid loading extra memory // no prefetching to avoid loading extra memory
BuildHistDispatch<GradientSumT, false, any_missing>(gpair, span2, gmat, hist); BuildHistDispatch<GradientSumT, false, any_missing>(gpair, span2, gmat,
hist);
} }
} }
template void template void
GHistBuilder<float>::BuildHist<true>(const std::vector<GradientPair> &gpair, GHistBuilder<float>::BuildHist<true>(const std::vector<GradientPair> &gpair,
const RowSetCollection::Elem row_indices, const RowSetCollection::Elem row_indices,
const GHistIndexMatrix &gmat, const GHistIndexMatrix &gmat,
GHistRow<float> hist); GHistRow<float> hist) const;
template void template void
GHistBuilder<float>::BuildHist<false>(const std::vector<GradientPair> &gpair, GHistBuilder<float>::BuildHist<false>(const std::vector<GradientPair> &gpair,
const RowSetCollection::Elem row_indices, const RowSetCollection::Elem row_indices,
const GHistIndexMatrix &gmat, const GHistIndexMatrix &gmat,
GHistRow<float> hist); GHistRow<float> hist) const;
template void template void
GHistBuilder<double>::BuildHist<true>(const std::vector<GradientPair> &gpair, GHistBuilder<double>::BuildHist<true>(const std::vector<GradientPair> &gpair,
const RowSetCollection::Elem row_indices, const RowSetCollection::Elem row_indices,
const GHistIndexMatrix &gmat, const GHistIndexMatrix &gmat,
GHistRow<double> hist); GHistRow<double> hist) const;
template void template void
GHistBuilder<double>::BuildHist<false>(const std::vector<GradientPair> &gpair, GHistBuilder<double>::BuildHist<false>(const std::vector<GradientPair> &gpair,
const RowSetCollection::Elem row_indices, const RowSetCollection::Elem row_indices,
const GHistIndexMatrix &gmat, const GHistIndexMatrix &gmat,
GHistRow<double> hist); GHistRow<double> hist) const;
template<typename GradientSumT>
void GHistBuilder<GradientSumT>::SubtractionTrick(GHistRowT self,
GHistRowT sibling,
GHistRowT parent) {
const size_t size = self.size();
CHECK_EQ(sibling.size(), size);
CHECK_EQ(parent.size(), size);
const size_t block_size = 1024; // aproximatly 1024 values per block
size_t n_blocks = size/block_size + !!(size%block_size);
ParallelFor(omp_ulong(n_blocks), [&](omp_ulong iblock) {
const size_t ibegin = iblock*block_size;
const size_t iend = (((iblock+1)*block_size > size) ? size : ibegin + block_size);
SubtractionHist(self, parent, sibling, ibegin, iend);
});
}
template
void GHistBuilder<float>::SubtractionTrick(GHistRow<float> self,
GHistRow<float> sibling,
GHistRow<float> parent);
template
void GHistBuilder<double>::SubtractionTrick(GHistRow<double> self,
GHistRow<double> sibling,
GHistRow<double> parent);
} // namespace common } // namespace common
} // namespace xgboost } // namespace xgboost

View File

@ -460,7 +460,7 @@ class ParallelGHistBuilder {
} }
// Reduce following bins (begin, end] for nid-node in dst across threads // Reduce following bins (begin, end] for nid-node in dst across threads
void ReduceHist(size_t nid, size_t begin, size_t end) { void ReduceHist(size_t nid, size_t begin, size_t end) const {
CHECK_GT(end, begin); CHECK_GT(end, begin);
CHECK_LT(nid, nodes_); CHECK_LT(nid, nodes_);
@ -486,7 +486,6 @@ class ParallelGHistBuilder {
} }
} }
protected:
void MatchThreadsToNodes(const BlockedSpace2d& space) { void MatchThreadsToNodes(const BlockedSpace2d& space) {
const size_t space_size = space.Size(); const size_t space_size = space.Size();
const size_t chunck_size = space_size / nthreads_ + !!(space_size % nthreads_); const size_t chunck_size = space_size / nthreads_ + !!(space_size % nthreads_);
@ -533,6 +532,7 @@ class ParallelGHistBuilder {
} }
} }
private:
void MatchNodeNidPairToHist() { void MatchNodeNidPairToHist() {
size_t hist_allocated_additionally = 0; size_t hist_allocated_additionally = 0;
@ -586,26 +586,18 @@ class GHistBuilder {
using GHistRowT = GHistRow<GradientSumT>; using GHistRowT = GHistRow<GradientSumT>;
GHistBuilder() = default; GHistBuilder() = default;
GHistBuilder(size_t nthread, uint32_t nbins) : nthread_{nthread}, nbins_{nbins} {} explicit GHistBuilder(uint32_t nbins): nbins_{nbins} {}
// construct a histogram via histogram aggregation // construct a histogram via histogram aggregation
template <bool any_missing> template <bool any_missing>
void BuildHist(const std::vector<GradientPair> &gpair, void BuildHist(const std::vector<GradientPair> &gpair,
const RowSetCollection::Elem row_indices, const RowSetCollection::Elem row_indices,
const GHistIndexMatrix& gmat, const GHistIndexMatrix &gmat, GHistRowT hist) const;
GHistRowT hist);
// construct a histogram via subtraction trick
void SubtractionTrick(GHistRowT self,
GHistRowT sibling,
GHistRowT parent);
uint32_t GetNumBins() const { uint32_t GetNumBins() const {
return nbins_; return nbins_;
} }
private: private:
/*! \brief number of threads for parallel computation */
size_t nthread_ { 0 };
/*! \brief number of all bins over all features */ /*! \brief number of all bins over all features */
uint32_t nbins_ { 0 }; uint32_t nbins_ { 0 };
}; };

View File

@ -11,6 +11,7 @@
#include "rabit/rabit.h" #include "rabit/rabit.h"
#include "xgboost/tree_model.h" #include "xgboost/tree_model.h"
#include "../../common/hist_util.h" #include "../../common/hist_util.h"
#include "../../data/gradient_index.h"
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {
@ -25,8 +26,9 @@ template <typename GradientSumT, typename ExpandEntry> class HistogramBuilder {
common::GHistBuilder<GradientSumT> builder_; common::GHistBuilder<GradientSumT> builder_;
common::ParallelGHistBuilder<GradientSumT> buffer_; common::ParallelGHistBuilder<GradientSumT> buffer_;
rabit::Reducer<GradientPairT, GradientPairT::Reduce> reducer_; rabit::Reducer<GradientPairT, GradientPairT::Reduce> reducer_;
int32_t max_bin_ {-1}; BatchParam param_;
int32_t n_threads_{-1}; int32_t n_threads_{-1};
size_t n_batches_{0};
// Whether XGBoost is running in distributed environment. // Whether XGBoost is running in distributed environment.
bool is_distributed_{false}; bool is_distributed_{false};
@ -39,59 +41,54 @@ template <typename GradientSumT, typename ExpandEntry> class HistogramBuilder {
* \param is_distributed Mostly used for testing to allow injecting parameters instead * \param is_distributed Mostly used for testing to allow injecting parameters instead
* of using global rabit variable. * of using global rabit variable.
*/ */
void Reset(uint32_t total_bins, int32_t max_bin_per_feat, int32_t n_threads, void Reset(uint32_t total_bins, BatchParam p, int32_t n_threads, size_t n_batches,
bool is_distributed = rabit::IsDistributed()) { bool is_distributed) {
CHECK_GE(n_threads, 1); CHECK_GE(n_threads, 1);
n_threads_ = n_threads; n_threads_ = n_threads;
CHECK_GE(max_bin_per_feat, 2); n_batches_ = n_batches;
max_bin_ = max_bin_per_feat; param_ = p;
hist_.Init(total_bins); hist_.Init(total_bins);
hist_local_worker_.Init(total_bins); hist_local_worker_.Init(total_bins);
buffer_.Init(total_bins); buffer_.Init(total_bins);
builder_ = common::GHistBuilder<GradientSumT>(n_threads, total_bins); builder_ = common::GHistBuilder<GradientSumT>(total_bins);
is_distributed_ = is_distributed; is_distributed_ = is_distributed;
} }
template <bool any_missing> template <bool any_missing>
void void BuildLocalHistograms(size_t page_idx, common::BlockedSpace2d space,
BuildLocalHistograms(DMatrix *p_fmat, GHistIndexMatrix const &gidx,
std::vector<ExpandEntry> nodes_for_explicit_hist_build, std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
common::RowSetCollection const &row_set_collection, common::RowSetCollection const &row_set_collection,
const std::vector<GradientPair> &gpair_h) { const std::vector<GradientPair> &gpair_h) {
const size_t n_nodes = nodes_for_explicit_hist_build.size(); const size_t n_nodes = nodes_for_explicit_hist_build.size();
CHECK_GT(n_nodes, 0);
// create space of size (# rows in each node)
common::BlockedSpace2d space(
n_nodes,
[&](size_t node) {
const int32_t nid = nodes_for_explicit_hist_build[node].nid;
return row_set_collection[nid].Size();
},
256);
std::vector<GHistRowT> target_hists(n_nodes); std::vector<GHistRowT> target_hists(n_nodes);
for (size_t i = 0; i < n_nodes; ++i) { for (size_t i = 0; i < n_nodes; ++i) {
const int32_t nid = nodes_for_explicit_hist_build[i].nid; const int32_t nid = nodes_for_explicit_hist_build[i].nid;
target_hists[i] = hist_[nid]; target_hists[i] = hist_[nid];
} }
if (page_idx == 0) {
// FIXME(jiamingy): Handle different size of space. Right now we use the maximum
// partition size for the buffer, which might not be efficient if partition sizes
// has significant variance.
buffer_.Reset(this->n_threads_, n_nodes, space, target_hists); buffer_.Reset(this->n_threads_, n_nodes, space, target_hists);
}
// Parallel processing by nodes and data in each node // Parallel processing by nodes and data in each node
for (auto const &gmat : p_fmat->GetBatches<GHistIndexMatrix>( common::ParallelFor2d(space, this->n_threads_, [&](size_t nid_in_set, common::Range1d r) {
BatchParam{GenericParameter::kCpuId, max_bin_})) {
common::ParallelFor2d(
space, this->n_threads_, [&](size_t nid_in_set, common::Range1d r) {
const auto tid = static_cast<unsigned>(omp_get_thread_num()); const auto tid = static_cast<unsigned>(omp_get_thread_num());
const int32_t nid = nodes_for_explicit_hist_build[nid_in_set].nid; const int32_t nid = nodes_for_explicit_hist_build[nid_in_set].nid;
auto elem = row_set_collection[nid];
auto start_of_row_set = row_set_collection[nid].begin; auto start_of_row_set = std::min(r.begin(), elem.Size());
auto rid_set = common::RowSetCollection::Elem( auto end_of_row_set = std::min(r.end(), elem.Size());
start_of_row_set + r.begin(), start_of_row_set + r.end(), nid); auto rid_set = common::RowSetCollection::Elem(elem.begin + start_of_row_set,
builder_.template BuildHist<any_missing>( elem.begin + end_of_row_set, nid);
gpair_h, rid_set, gmat, auto hist = buffer_.GetInitializedHist(tid, nid_in_set);
buffer_.GetInitializedHist(tid, nid_in_set)); if (rid_set.Size() != 0) {
}); builder_.template BuildHist<any_missing>(gpair_h, rid_set, gidx, hist);
} }
});
} }
void void
@ -110,24 +107,34 @@ template <typename GradientSumT, typename ExpandEntry> class HistogramBuilder {
} }
} }
/* Main entry point of this class, build histogram for tree nodes. */ /** Main entry point of this class, build histogram for tree nodes. */
void BuildHist(DMatrix *p_fmat, RegTree *p_tree, void BuildHist(size_t page_id, common::BlockedSpace2d space, GHistIndexMatrix const &gidx,
common::RowSetCollection const &row_set_collection, RegTree *p_tree, common::RowSetCollection const &row_set_collection,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build, std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick, std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
std::vector<GradientPair> const &gpair) { std::vector<GradientPair> const &gpair) {
int starting_index = std::numeric_limits<int>::max(); int starting_index = std::numeric_limits<int>::max();
int sync_count = 0; int sync_count = 0;
if (page_id == 0) {
this->AddHistRows(&starting_index, &sync_count, this->AddHistRows(&starting_index, &sync_count,
nodes_for_explicit_hist_build, nodes_for_explicit_hist_build,
nodes_for_subtraction_trick, p_tree); nodes_for_subtraction_trick, p_tree);
if (p_fmat->IsDense()) { }
BuildLocalHistograms<false>(p_fmat, nodes_for_explicit_hist_build, if (gidx.IsDense()) {
this->BuildLocalHistograms<false>(page_id, space, gidx,
nodes_for_explicit_hist_build,
row_set_collection, gpair); row_set_collection, gpair);
} else { } else {
BuildLocalHistograms<true>(p_fmat, nodes_for_explicit_hist_build, this->BuildLocalHistograms<true>(page_id, space, gidx,
nodes_for_explicit_hist_build,
row_set_collection, gpair); row_set_collection, gpair);
} }
CHECK_GE(n_batches_, 1);
if (page_id != n_batches_ - 1) {
return;
}
if (is_distributed_) { if (is_distributed_) {
this->SyncHistogramDistributed(p_tree, nodes_for_explicit_hist_build, this->SyncHistogramDistributed(p_tree, nodes_for_explicit_hist_build,
nodes_for_subtraction_trick, nodes_for_subtraction_trick,
@ -138,6 +145,25 @@ template <typename GradientSumT, typename ExpandEntry> class HistogramBuilder {
sync_count); sync_count);
} }
} }
/** same as the other build hist but handles only single batch data (in-core) */
void BuildHist(size_t page_id, GHistIndexMatrix const &gidx, RegTree *p_tree,
common::RowSetCollection const &row_set_collection,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
std::vector<GradientPair> const &gpair) {
const size_t n_nodes = nodes_for_explicit_hist_build.size();
// create space of size (# rows in each node)
common::BlockedSpace2d space(
n_nodes,
[&](size_t nidx_in_set) {
const int32_t nidx = nodes_for_explicit_hist_build[nidx_in_set].nid;
return row_set_collection[nidx].Size();
},
256);
this->BuildHist(page_id, space, gidx, p_tree, row_set_collection,
nodes_for_explicit_hist_build, nodes_for_subtraction_trick,
gpair);
}
void SyncHistogramDistributed( void SyncHistogramDistributed(
RegTree *p_tree, RegTree *p_tree,

View File

@ -127,9 +127,14 @@ void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
nodes_for_subtraction_trick_.clear(); nodes_for_subtraction_trick_.clear();
nodes_for_explicit_hist_build_.push_back(node); nodes_for_explicit_hist_build_.push_back(node);
this->histogram_builder_->BuildHist(p_fmat, p_tree, row_set_collection_, size_t page_id = 0;
nodes_for_explicit_hist_build_, for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(
nodes_for_subtraction_trick_, gpair_h); {GenericParameter::kCpuId, param_.max_bin})) {
this->histogram_builder_->BuildHist(
page_id, gidx, p_tree, row_set_collection_,
nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_, gpair_h);
++page_id;
}
{ {
auto nid = RegTree::kRoot; auto nid = RegTree::kRoot;
@ -259,9 +264,15 @@ void QuantileHistMaker::Builder<GradientSumT>::ExpandTree(
SplitSiblings(nodes_for_apply_split, &nodes_to_evaluate, p_tree); SplitSiblings(nodes_for_apply_split, &nodes_to_evaluate, p_tree);
if (depth < param_.max_depth) { if (depth < param_.max_depth) {
size_t i = 0;
for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(
{GenericParameter::kCpuId, param_.max_bin})) {
this->histogram_builder_->BuildHist( this->histogram_builder_->BuildHist(
p_fmat, p_tree, row_set_collection_, nodes_for_explicit_hist_build_, i, gidx, p_tree, row_set_collection_,
nodes_for_subtraction_trick_, gpair_h); nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_,
gpair_h);
++i;
}
} else { } else {
int starting_index = std::numeric_limits<int>::max(); int starting_index = std::numeric_limits<int>::max();
int sync_count = 0; int sync_count = 0;
@ -432,7 +443,9 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(
}); });
} }
exc.Rethrow(); exc.Rethrow();
this->histogram_builder_->Reset(nbins, param_.max_bin, this->nthread_); this->histogram_builder_->Reset(
nbins, BatchParam{GenericParameter::kCpuId, param_.max_bin},
this->nthread_, 1, rabit::IsDistributed());
std::vector<size_t>& row_indices = *row_set_collection_.Data(); std::vector<size_t>& row_indices = *row_set_collection_.Data();
row_indices.resize(info.num_row_); row_indices.resize(info.num_row_);

View File

@ -36,7 +36,7 @@ template <typename GradientSumT> void TestEvaluateSplits() {
std::iota(row_indices.begin(), row_indices.end(), 0); std::iota(row_indices.begin(), row_indices.end(), 0);
row_set_collection.Init(); row_set_collection.Init();
auto hist_builder = GHistBuilder<GradientSumT>(omp_get_max_threads(), gmat.cut.Ptrs().back()); auto hist_builder = GHistBuilder<GradientSumT>(gmat.cut.Ptrs().back());
hist.Init(gmat.cut.Ptrs().back()); hist.Init(gmat.cut.Ptrs().back());
hist.AddHistRow(0); hist.AddHistRow(0);
hist.AllocateAllData(); hist.AllocateAllData();

View File

@ -8,6 +8,16 @@
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {
namespace {
void InitRowPartitionForTest(RowSetCollection *row_set, size_t n_samples,
size_t base_rowid = 0) {
auto &row_indices = *row_set->Data();
row_indices.resize(n_samples);
std::iota(row_indices.begin(), row_indices.end(), base_rowid);
row_set->Init();
}
} // anonymous namespace
template <typename GradientSumT> template <typename GradientSumT>
void TestAddHistRows(bool is_distributed) { void TestAddHistRows(bool is_distributed) {
std::vector<CPUExpandEntry> nodes_for_explicit_hist_build_; std::vector<CPUExpandEntry> nodes_for_explicit_hist_build_;
@ -35,8 +45,9 @@ void TestAddHistRows(bool is_distributed) {
nodes_for_subtraction_trick_.emplace_back(6, tree.GetDepth(6), 0.0f); nodes_for_subtraction_trick_.emplace_back(6, tree.GetDepth(6), 0.0f);
HistogramBuilder<GradientSumT, CPUExpandEntry> histogram_builder; HistogramBuilder<GradientSumT, CPUExpandEntry> histogram_builder;
histogram_builder.Reset(gmat.cut.TotalBins(), kMaxBins, omp_get_max_threads(), histogram_builder.Reset(gmat.cut.TotalBins(),
is_distributed); {GenericParameter::kCpuId, kMaxBins},
omp_get_max_threads(), 1, is_distributed);
histogram_builder.AddHistRows(&starting_index, &sync_count, histogram_builder.AddHistRows(&starting_index, &sync_count,
nodes_for_explicit_hist_build_, nodes_for_explicit_hist_build_,
nodes_for_subtraction_trick_, &tree); nodes_for_subtraction_trick_, &tree);
@ -81,7 +92,8 @@ void TestSyncHist(bool is_distributed) {
HistogramBuilder<GradientSumT, CPUExpandEntry> histogram; HistogramBuilder<GradientSumT, CPUExpandEntry> histogram;
uint32_t total_bins = gmat.cut.Ptrs().back(); uint32_t total_bins = gmat.cut.Ptrs().back();
histogram.Reset(total_bins, kMaxBins, omp_get_max_threads(), is_distributed); histogram.Reset(total_bins, {GenericParameter::kCpuId, kMaxBins},
omp_get_max_threads(), 1, is_distributed);
RowSetCollection row_set_collection_; RowSetCollection row_set_collection_;
{ {
@ -247,22 +259,26 @@ void TestBuildHistogram(bool is_distributed) {
bst_node_t nid = 0; bst_node_t nid = 0;
HistogramBuilder<GradientSumT, CPUExpandEntry> histogram; HistogramBuilder<GradientSumT, CPUExpandEntry> histogram;
histogram.Reset(total_bins, kMaxBins, omp_get_max_threads(), is_distributed); histogram.Reset(total_bins, {GenericParameter::kCpuId, kMaxBins},
omp_get_max_threads(), 1, is_distributed);
RegTree tree; RegTree tree;
RowSetCollection row_set_collection_; RowSetCollection row_set_collection;
row_set_collection_.Clear(); row_set_collection.Clear();
std::vector<size_t> &row_indices = *row_set_collection_.Data(); std::vector<size_t> &row_indices = *row_set_collection.Data();
row_indices.resize(kNRows); row_indices.resize(kNRows);
std::iota(row_indices.begin(), row_indices.end(), 0); std::iota(row_indices.begin(), row_indices.end(), 0);
row_set_collection_.Init(); row_set_collection.Init();
CPUExpandEntry node(RegTree::kRoot, tree.GetDepth(0), 0.0f); CPUExpandEntry node(RegTree::kRoot, tree.GetDepth(0), 0.0f);
std::vector<CPUExpandEntry> nodes_for_explicit_hist_build_; std::vector<CPUExpandEntry> nodes_for_explicit_hist_build;
nodes_for_explicit_hist_build_.push_back(node); nodes_for_explicit_hist_build.push_back(node);
histogram.BuildHist(p_fmat.get(), &tree, row_set_collection_, for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(
nodes_for_explicit_hist_build_, {}, gpair); {GenericParameter::kCpuId, kMaxBins})) {
histogram.BuildHist(0, gidx, &tree, row_set_collection,
nodes_for_explicit_hist_build, {}, gpair);
}
// Check if number of histogram bins is correct // Check if number of histogram bins is correct
ASSERT_EQ(histogram.Histogram()[nid].size(), gmat.cut.Ptrs().back()); ASSERT_EQ(histogram.Histogram()[nid].size(), gmat.cut.Ptrs().back());
@ -294,5 +310,88 @@ TEST(CPUHistogram, BuildHist) {
TestBuildHistogram<float>(false); TestBuildHistogram<float>(false);
TestBuildHistogram<double>(false); TestBuildHistogram<double>(false);
} }
TEST(CPUHistogram, ExternalMemory) {
size_t constexpr kEntries = 1 << 16;
int32_t constexpr kBins = 32;
auto m = CreateSparsePageDMatrix(kEntries, "cache");
std::vector<size_t> partition_size(1, 0);
size_t total_bins{0};
size_t n_samples{0};
auto gpair = GenerateRandomGradients(m->Info().num_row_, 0.0, 1.0);
auto const &h_gpair = gpair.HostVector();
RegTree tree;
std::vector<CPUExpandEntry> nodes;
nodes.emplace_back(0, tree.GetDepth(0), 0.0f);
GHistRow<double> multi_page;
HistogramBuilder<double, CPUExpandEntry> multi_build;
{
/**
* Multi page
*/
std::vector<RowSetCollection> rows_set;
std::vector<float> hess(m->Info().num_row_, 1.0);
for (auto const &page : m->GetBatches<GHistIndexMatrix>(
{GenericParameter::kCpuId, kBins, hess})) {
CHECK_LT(page.base_rowid, m->Info().num_row_);
auto n_rows_in_node = page.Size();
partition_size[0] = std::max(partition_size[0], n_rows_in_node);
total_bins = page.cut.TotalBins();
n_samples += n_rows_in_node;
rows_set.emplace_back();
InitRowPartitionForTest(&rows_set.back(), n_rows_in_node, page.base_rowid);
}
ASSERT_EQ(n_samples, m->Info().num_row_);
common::BlockedSpace2d space{
1, [&](size_t nidx_in_set) { return partition_size.at(nidx_in_set); },
256};
multi_build.Reset(total_bins, {GenericParameter::kCpuId, kBins},
omp_get_max_threads(), rows_set.size(), false);
size_t page_idx{0};
for (auto const &page : m->GetBatches<GHistIndexMatrix>(
{GenericParameter::kCpuId, kBins, hess})) {
multi_build.BuildHist(page_idx, space, page, &tree,
rows_set.at(page_idx), nodes, {}, h_gpair);
++page_idx;
}
ASSERT_EQ(page_idx, 2);
multi_page = multi_build.Histogram()[0];
}
HistogramBuilder<double, CPUExpandEntry> single_build;
GHistRow<double> single_page;
{
/**
* Single page
*/
RowSetCollection row_set_collection;
InitRowPartitionForTest(&row_set_collection, n_samples);
single_build.Reset(total_bins, {GenericParameter::kCpuId, kBins},
omp_get_max_threads(), 1, false);
size_t n_batches{0};
for (auto const &page :
m->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, kBins})) {
single_build.BuildHist(0, page, &tree, row_set_collection, nodes, {},
h_gpair);
n_batches ++;
}
ASSERT_EQ(n_batches, 1);
single_page = single_build.Histogram()[0];
}
for (size_t i = 0; i < single_page.size(); ++i) {
ASSERT_NEAR(single_page[i].GetGrad(), multi_page[i].GetGrad(), kRtEps);
ASSERT_NEAR(single_page[i].GetHess(), multi_page[i].GetHess(), kRtEps);
}
}
} // namespace tree } // namespace tree
} // namespace xgboost } // namespace xgboost