Support external memory in CPU histogram building. (#7372)
This commit is contained in:
@@ -11,6 +11,7 @@
|
||||
#include "rabit/rabit.h"
|
||||
#include "xgboost/tree_model.h"
|
||||
#include "../../common/hist_util.h"
|
||||
#include "../../data/gradient_index.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
@@ -25,10 +26,11 @@ template <typename GradientSumT, typename ExpandEntry> class HistogramBuilder {
|
||||
common::GHistBuilder<GradientSumT> builder_;
|
||||
common::ParallelGHistBuilder<GradientSumT> buffer_;
|
||||
rabit::Reducer<GradientPairT, GradientPairT::Reduce> reducer_;
|
||||
int32_t max_bin_ {-1};
|
||||
int32_t n_threads_ {-1};
|
||||
BatchParam param_;
|
||||
int32_t n_threads_{-1};
|
||||
size_t n_batches_{0};
|
||||
// Whether XGBoost is running in distributed environment.
|
||||
bool is_distributed_ {false};
|
||||
bool is_distributed_{false};
|
||||
|
||||
public:
|
||||
/**
|
||||
@@ -39,59 +41,54 @@ template <typename GradientSumT, typename ExpandEntry> class HistogramBuilder {
|
||||
* \param is_distributed Mostly used for testing to allow injecting parameters instead
|
||||
* of using global rabit variable.
|
||||
*/
|
||||
void Reset(uint32_t total_bins, int32_t max_bin_per_feat, int32_t n_threads,
|
||||
bool is_distributed = rabit::IsDistributed()) {
|
||||
void Reset(uint32_t total_bins, BatchParam p, int32_t n_threads, size_t n_batches,
|
||||
bool is_distributed) {
|
||||
CHECK_GE(n_threads, 1);
|
||||
n_threads_ = n_threads;
|
||||
CHECK_GE(max_bin_per_feat, 2);
|
||||
max_bin_ = max_bin_per_feat;
|
||||
n_batches_ = n_batches;
|
||||
param_ = p;
|
||||
hist_.Init(total_bins);
|
||||
hist_local_worker_.Init(total_bins);
|
||||
buffer_.Init(total_bins);
|
||||
builder_ = common::GHistBuilder<GradientSumT>(n_threads, total_bins);
|
||||
builder_ = common::GHistBuilder<GradientSumT>(total_bins);
|
||||
is_distributed_ = is_distributed;
|
||||
}
|
||||
|
||||
template <bool any_missing>
|
||||
void
|
||||
BuildLocalHistograms(DMatrix *p_fmat,
|
||||
std::vector<ExpandEntry> nodes_for_explicit_hist_build,
|
||||
common::RowSetCollection const &row_set_collection,
|
||||
const std::vector<GradientPair> &gpair_h) {
|
||||
void BuildLocalHistograms(size_t page_idx, common::BlockedSpace2d space,
|
||||
GHistIndexMatrix const &gidx,
|
||||
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
|
||||
common::RowSetCollection const &row_set_collection,
|
||||
const std::vector<GradientPair> &gpair_h) {
|
||||
const size_t n_nodes = nodes_for_explicit_hist_build.size();
|
||||
|
||||
// create space of size (# rows in each node)
|
||||
common::BlockedSpace2d space(
|
||||
n_nodes,
|
||||
[&](size_t node) {
|
||||
const int32_t nid = nodes_for_explicit_hist_build[node].nid;
|
||||
return row_set_collection[nid].Size();
|
||||
},
|
||||
256);
|
||||
CHECK_GT(n_nodes, 0);
|
||||
|
||||
std::vector<GHistRowT> target_hists(n_nodes);
|
||||
for (size_t i = 0; i < n_nodes; ++i) {
|
||||
const int32_t nid = nodes_for_explicit_hist_build[i].nid;
|
||||
target_hists[i] = hist_[nid];
|
||||
}
|
||||
buffer_.Reset(this->n_threads_, n_nodes, space, target_hists);
|
||||
if (page_idx == 0) {
|
||||
// FIXME(jiamingy): Handle different size of space. Right now we use the maximum
|
||||
// partition size for the buffer, which might not be efficient if partition sizes
|
||||
// have significant variance.
|
||||
buffer_.Reset(this->n_threads_, n_nodes, space, target_hists);
|
||||
}
|
||||
|
||||
// Parallel processing by nodes and data in each node
|
||||
for (auto const &gmat : p_fmat->GetBatches<GHistIndexMatrix>(
|
||||
BatchParam{GenericParameter::kCpuId, max_bin_})) {
|
||||
common::ParallelFor2d(
|
||||
space, this->n_threads_, [&](size_t nid_in_set, common::Range1d r) {
|
||||
const auto tid = static_cast<unsigned>(omp_get_thread_num());
|
||||
const int32_t nid = nodes_for_explicit_hist_build[nid_in_set].nid;
|
||||
|
||||
auto start_of_row_set = row_set_collection[nid].begin;
|
||||
auto rid_set = common::RowSetCollection::Elem(
|
||||
start_of_row_set + r.begin(), start_of_row_set + r.end(), nid);
|
||||
builder_.template BuildHist<any_missing>(
|
||||
gpair_h, rid_set, gmat,
|
||||
buffer_.GetInitializedHist(tid, nid_in_set));
|
||||
});
|
||||
}
|
||||
common::ParallelFor2d(space, this->n_threads_, [&](size_t nid_in_set, common::Range1d r) {
|
||||
const auto tid = static_cast<unsigned>(omp_get_thread_num());
|
||||
const int32_t nid = nodes_for_explicit_hist_build[nid_in_set].nid;
|
||||
auto elem = row_set_collection[nid];
|
||||
auto start_of_row_set = std::min(r.begin(), elem.Size());
|
||||
auto end_of_row_set = std::min(r.end(), elem.Size());
|
||||
auto rid_set = common::RowSetCollection::Elem(elem.begin + start_of_row_set,
|
||||
elem.begin + end_of_row_set, nid);
|
||||
auto hist = buffer_.GetInitializedHist(tid, nid_in_set);
|
||||
if (rid_set.Size() != 0) {
|
||||
builder_.template BuildHist<any_missing>(gpair_h, rid_set, gidx, hist);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void
|
||||
@@ -110,24 +107,34 @@ template <typename GradientSumT, typename ExpandEntry> class HistogramBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
/* Main entry point of this class, build histogram for tree nodes. */
|
||||
void BuildHist(DMatrix *p_fmat, RegTree *p_tree,
|
||||
common::RowSetCollection const &row_set_collection,
|
||||
/** Main entry point of this class, build histogram for tree nodes. */
|
||||
void BuildHist(size_t page_id, common::BlockedSpace2d space, GHistIndexMatrix const &gidx,
|
||||
RegTree *p_tree, common::RowSetCollection const &row_set_collection,
|
||||
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
|
||||
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
|
||||
std::vector<GradientPair> const &gpair) {
|
||||
int starting_index = std::numeric_limits<int>::max();
|
||||
int sync_count = 0;
|
||||
this->AddHistRows(&starting_index, &sync_count,
|
||||
nodes_for_explicit_hist_build,
|
||||
nodes_for_subtraction_trick, p_tree);
|
||||
if (p_fmat->IsDense()) {
|
||||
BuildLocalHistograms<false>(p_fmat, nodes_for_explicit_hist_build,
|
||||
row_set_collection, gpair);
|
||||
} else {
|
||||
BuildLocalHistograms<true>(p_fmat, nodes_for_explicit_hist_build,
|
||||
row_set_collection, gpair);
|
||||
if (page_id == 0) {
|
||||
this->AddHistRows(&starting_index, &sync_count,
|
||||
nodes_for_explicit_hist_build,
|
||||
nodes_for_subtraction_trick, p_tree);
|
||||
}
|
||||
if (gidx.IsDense()) {
|
||||
this->BuildLocalHistograms<false>(page_id, space, gidx,
|
||||
nodes_for_explicit_hist_build,
|
||||
row_set_collection, gpair);
|
||||
} else {
|
||||
this->BuildLocalHistograms<true>(page_id, space, gidx,
|
||||
nodes_for_explicit_hist_build,
|
||||
row_set_collection, gpair);
|
||||
}
|
||||
|
||||
CHECK_GE(n_batches_, 1);
|
||||
if (page_id != n_batches_ - 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (is_distributed_) {
|
||||
this->SyncHistogramDistributed(p_tree, nodes_for_explicit_hist_build,
|
||||
nodes_for_subtraction_trick,
|
||||
@@ -138,6 +145,25 @@ template <typename GradientSumT, typename ExpandEntry> class HistogramBuilder {
|
||||
sync_count);
|
||||
}
|
||||
}
|
||||
/** same as the other build hist but handles only single batch data (in-core) */
|
||||
void BuildHist(size_t page_id, GHistIndexMatrix const &gidx, RegTree *p_tree,
|
||||
common::RowSetCollection const &row_set_collection,
|
||||
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
|
||||
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
|
||||
std::vector<GradientPair> const &gpair) {
|
||||
const size_t n_nodes = nodes_for_explicit_hist_build.size();
|
||||
// create space of size (# rows in each node)
|
||||
common::BlockedSpace2d space(
|
||||
n_nodes,
|
||||
[&](size_t nidx_in_set) {
|
||||
const int32_t nidx = nodes_for_explicit_hist_build[nidx_in_set].nid;
|
||||
return row_set_collection[nidx].Size();
|
||||
},
|
||||
256);
|
||||
this->BuildHist(page_id, space, gidx, p_tree, row_set_collection,
|
||||
nodes_for_explicit_hist_build, nodes_for_subtraction_trick,
|
||||
gpair);
|
||||
}
|
||||
|
||||
void SyncHistogramDistributed(
|
||||
RegTree *p_tree,
|
||||
|
||||
@@ -127,9 +127,14 @@ void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
|
||||
nodes_for_subtraction_trick_.clear();
|
||||
nodes_for_explicit_hist_build_.push_back(node);
|
||||
|
||||
this->histogram_builder_->BuildHist(p_fmat, p_tree, row_set_collection_,
|
||||
nodes_for_explicit_hist_build_,
|
||||
nodes_for_subtraction_trick_, gpair_h);
|
||||
size_t page_id = 0;
|
||||
for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(
|
||||
{GenericParameter::kCpuId, param_.max_bin})) {
|
||||
this->histogram_builder_->BuildHist(
|
||||
page_id, gidx, p_tree, row_set_collection_,
|
||||
nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_, gpair_h);
|
||||
++page_id;
|
||||
}
|
||||
|
||||
{
|
||||
auto nid = RegTree::kRoot;
|
||||
@@ -259,9 +264,15 @@ void QuantileHistMaker::Builder<GradientSumT>::ExpandTree(
|
||||
SplitSiblings(nodes_for_apply_split, &nodes_to_evaluate, p_tree);
|
||||
|
||||
if (depth < param_.max_depth) {
|
||||
this->histogram_builder_->BuildHist(
|
||||
p_fmat, p_tree, row_set_collection_, nodes_for_explicit_hist_build_,
|
||||
nodes_for_subtraction_trick_, gpair_h);
|
||||
size_t i = 0;
|
||||
for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(
|
||||
{GenericParameter::kCpuId, param_.max_bin})) {
|
||||
this->histogram_builder_->BuildHist(
|
||||
i, gidx, p_tree, row_set_collection_,
|
||||
nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_,
|
||||
gpair_h);
|
||||
++i;
|
||||
}
|
||||
} else {
|
||||
int starting_index = std::numeric_limits<int>::max();
|
||||
int sync_count = 0;
|
||||
@@ -432,7 +443,9 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(
|
||||
});
|
||||
}
|
||||
exc.Rethrow();
|
||||
this->histogram_builder_->Reset(nbins, param_.max_bin, this->nthread_);
|
||||
this->histogram_builder_->Reset(
|
||||
nbins, BatchParam{GenericParameter::kCpuId, param_.max_bin},
|
||||
this->nthread_, 1, rabit::IsDistributed());
|
||||
|
||||
std::vector<size_t>& row_indices = *row_set_collection_.Data();
|
||||
row_indices.resize(info.num_row_);
|
||||
|
||||
Reference in New Issue
Block a user