Unify the code path between local and distributed training. (#9433)

This removes the need for a local histogram space during distributed training, which cuts the cache size by half.
This commit is contained in:
Jiaming Yuan 2023-08-03 21:46:36 +08:00 committed by GitHub
parent f958e32683
commit 1332ff787f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 45 additions and 173 deletions

View File

@ -453,6 +453,7 @@ class HistCollection {
data_[0].resize(new_size);
}
}
[[nodiscard]] bool IsContiguous() const { return contiguous_allocation_; }
private:
/*! \brief number of all bins over all features */

View File

@ -14,14 +14,11 @@
#include "expand_entry.h"
#include "xgboost/tree_model.h" // for RegTree
namespace xgboost {
namespace tree {
namespace xgboost::tree {
template <typename ExpandEntry>
class HistogramBuilder {
/*! \brief culmulative histogram of gradients. */
common::HistCollection hist_;
/*! \brief culmulative local parent histogram of gradients. */
common::HistCollection hist_local_worker_;
common::ParallelGHistBuilder buffer_;
BatchParam param_;
int32_t n_threads_{-1};
@ -46,12 +43,9 @@ class HistogramBuilder {
n_batches_ = n_batches;
param_ = p;
hist_.Init(total_bins);
hist_local_worker_.Init(total_bins);
buffer_.Init(total_bins);
is_distributed_ = is_distributed;
is_col_split_ = is_col_split;
// Workaround s390x gcc 7.5.0
auto DMLC_ATTRIBUTE_UNUSED __force_instantiation = &GradientPairPrecise::Reduce;
}
template <bool any_missing>
@ -91,17 +85,19 @@ class HistogramBuilder {
});
}
void AddHistRows(int *starting_index, int *sync_count,
void AddHistRows(int *starting_index,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
RegTree const *p_tree) {
if (is_distributed_ && !is_col_split_) {
this->AddHistRowsDistributed(starting_index, sync_count, nodes_for_explicit_hist_build,
nodes_for_subtraction_trick, p_tree);
} else {
this->AddHistRowsLocal(starting_index, sync_count, nodes_for_explicit_hist_build,
nodes_for_subtraction_trick);
std::vector<ExpandEntry> const &nodes_for_subtraction_trick) {
for (auto const &entry : nodes_for_explicit_hist_build) {
int nid = entry.nid;
this->hist_.AddHistRow(nid);
(*starting_index) = std::min(nid, (*starting_index));
}
for (auto const &node : nodes_for_subtraction_trick) {
this->hist_.AddHistRow(node.nid);
}
this->hist_.AllocateAllData();
}
/** Main entry point of this class, build histogram for tree nodes. */
@ -111,10 +107,9 @@ class HistogramBuilder {
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
common::Span<GradientPair const> gpair, bool force_read_by_column = false) {
int starting_index = std::numeric_limits<int>::max();
int sync_count = 0;
if (page_id == 0) {
this->AddHistRows(&starting_index, &sync_count, nodes_for_explicit_hist_build,
nodes_for_subtraction_trick, p_tree);
this->AddHistRows(&starting_index, nodes_for_explicit_hist_build,
nodes_for_subtraction_trick);
}
if (gidx.IsDense()) {
this->BuildLocalHistograms<false>(page_id, space, gidx, nodes_for_explicit_hist_build,
@ -129,13 +124,8 @@ class HistogramBuilder {
return;
}
if (is_distributed_ && !is_col_split_) {
this->SyncHistogramDistributed(p_tree, nodes_for_explicit_hist_build,
nodes_for_subtraction_trick,
starting_index, sync_count);
} else {
this->SyncHistogramLocal(p_tree, nodes_for_explicit_hist_build, nodes_for_subtraction_trick);
}
this->SyncHistogram(p_tree, nodes_for_explicit_hist_build,
nodes_for_subtraction_trick, starting_index);
}
/** same as the other build hist but handles only single batch data (in-core) */
void BuildHist(size_t page_id, GHistIndexMatrix const &gidx, RegTree *p_tree,
@ -156,62 +146,33 @@ class HistogramBuilder {
nodes_for_subtraction_trick, gpair, force_read_by_column);
}
void SyncHistogramDistributed(RegTree const *p_tree,
void SyncHistogram(RegTree const *p_tree,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
int starting_index, int sync_count) {
int starting_index) {
auto n_bins = buffer_.TotalBins();
common::BlockedSpace2d space(
nodes_for_explicit_hist_build.size(), [&](size_t) { return n_bins; }, 1024);
common::ParallelFor2d(space, n_threads_, [&](size_t node, common::Range1d r) {
const auto &entry = nodes_for_explicit_hist_build[node];
auto this_hist = this->hist_[entry.nid];
// Merging histograms from each thread into once
buffer_.ReduceHist(node, r.begin(), r.end());
// Store posible parent node
auto this_local = hist_local_worker_[entry.nid];
common::CopyHist(this_local, this_hist, r.begin(), r.end());
if (!p_tree->IsRoot(entry.nid)) {
const size_t parent_id = p_tree->Parent(entry.nid);
const int subtraction_node_id = nodes_for_subtraction_trick[node].nid;
auto parent_hist = this->hist_local_worker_[parent_id];
auto sibling_hist = this->hist_[subtraction_node_id];
common::SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end());
// Store posible parent node
auto sibling_local = hist_local_worker_[subtraction_node_id];
common::CopyHist(sibling_local, sibling_hist, r.begin(), r.end());
}
});
collective::Allreduce<collective::Operation::kSum>(
reinterpret_cast<double *>(this->hist_[starting_index].data()), n_bins * sync_count * 2);
ParallelSubtractionHist(space, nodes_for_explicit_hist_build, nodes_for_subtraction_trick,
p_tree);
common::BlockedSpace2d space2(
nodes_for_subtraction_trick.size(), [&](size_t) { return n_bins; }, 1024);
ParallelSubtractionHist(space2, nodes_for_subtraction_trick, nodes_for_explicit_hist_build,
p_tree);
}
void SyncHistogramLocal(RegTree const *p_tree,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick) {
const size_t nbins = this->buffer_.TotalBins();
common::BlockedSpace2d space(
nodes_for_explicit_hist_build.size(), [&](size_t) { return nbins; }, 1024);
CHECK(hist_.IsContiguous());
common::ParallelFor2d(space, this->n_threads_, [&](size_t node, common::Range1d r) {
const auto &entry = nodes_for_explicit_hist_build[node];
auto this_hist = this->hist_[entry.nid];
// Merging histograms from each thread into once
this->buffer_.ReduceHist(node, r.begin(), r.end());
});
if (is_distributed_ && !is_col_split_) {
collective::Allreduce<collective::Operation::kSum>(
reinterpret_cast<double *>(this->hist_[starting_index].data()),
n_bins * nodes_for_explicit_hist_build.size() * 2);
}
common::ParallelFor2d(space, this->n_threads_, [&](std::size_t nidx_in_set, common::Range1d r) {
const auto &entry = nodes_for_explicit_hist_build[nidx_in_set];
auto this_hist = this->hist_[entry.nid];
if (!p_tree->IsRoot(entry.nid)) {
auto const parent_id = p_tree->Parent(entry.nid);
auto const subtraction_node_id = nodes_for_subtraction_trick[node].nid;
auto const subtraction_node_id = nodes_for_subtraction_trick[nidx_in_set].nid;
auto parent_hist = this->hist_[parent_id];
auto sibling_hist = this->hist_[subtraction_node_id];
common::SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end());
@ -223,81 +184,6 @@ class HistogramBuilder {
/* Getters for tests. */
common::HistCollection const &Histogram() { return hist_; }
auto &Buffer() { return buffer_; }
private:
void
ParallelSubtractionHist(const common::BlockedSpace2d &space,
const std::vector<ExpandEntry> &nodes,
const std::vector<ExpandEntry> &subtraction_nodes,
const RegTree *p_tree) {
common::ParallelFor2d(
space, this->n_threads_, [&](size_t node, common::Range1d r) {
const auto &entry = nodes[node];
if (!(p_tree->IsLeftChild(entry.nid))) {
auto this_hist = this->hist_[entry.nid];
if (!p_tree->IsRoot(entry.nid)) {
const int subtraction_node_id = subtraction_nodes[node].nid;
auto parent_hist = hist_[(*p_tree)[entry.nid].Parent()];
auto sibling_hist = hist_[subtraction_node_id];
common::SubtractionHist(this_hist, parent_hist, sibling_hist,
r.begin(), r.end());
}
}
});
}
// Add a tree node to histogram buffer in local training environment.
void AddHistRowsLocal(
int *starting_index, int *sync_count,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick) {
for (auto const &entry : nodes_for_explicit_hist_build) {
int nid = entry.nid;
this->hist_.AddHistRow(nid);
(*starting_index) = std::min(nid, (*starting_index));
}
(*sync_count) = nodes_for_explicit_hist_build.size();
for (auto const &node : nodes_for_subtraction_trick) {
this->hist_.AddHistRow(node.nid);
}
this->hist_.AllocateAllData();
}
void AddHistRowsDistributed(int *starting_index, int *sync_count,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
RegTree const *p_tree) {
const size_t explicit_size = nodes_for_explicit_hist_build.size();
const size_t subtaction_size = nodes_for_subtraction_trick.size();
std::vector<int> merged_node_ids(explicit_size + subtaction_size);
for (size_t i = 0; i < explicit_size; ++i) {
merged_node_ids[i] = nodes_for_explicit_hist_build[i].nid;
}
for (size_t i = 0; i < subtaction_size; ++i) {
merged_node_ids[explicit_size + i] = nodes_for_subtraction_trick[i].nid;
}
std::sort(merged_node_ids.begin(), merged_node_ids.end());
int n_left = 0;
for (auto const &nid : merged_node_ids) {
if (p_tree->IsLeftChild(nid)) {
this->hist_.AddHistRow(nid);
(*starting_index) = std::min(nid, (*starting_index));
n_left++;
this->hist_local_worker_.AddHistRow(nid);
}
}
for (auto const &nid : merged_node_ids) {
if (!(p_tree->IsLeftChild(nid))) {
this->hist_.AddHistRow(nid);
this->hist_local_worker_.AddHistRow(nid);
}
}
this->hist_.AllocateAllData();
this->hist_local_worker_.AllocateAllData();
(*sync_count) = std::max(1, n_left);
}
};
// Construct a work space for building histogram. Eventually we should move this
@ -318,6 +204,5 @@ common::BlockedSpace2d ConstructHistSpace(Partitioner const &partitioners,
nodes_to_build.size(), [&](size_t nidx_in_set) { return partition_size[nidx_in_set]; }, 256};
return space;
}
} // namespace tree
} // namespace xgboost
} // namespace xgboost::tree
#endif // XGBOOST_TREE_HIST_HISTOGRAM_H_

View File

@ -28,7 +28,6 @@ void TestAddHistRows(bool is_distributed) {
std::vector<CPUExpandEntry> nodes_for_explicit_hist_build_;
std::vector<CPUExpandEntry> nodes_for_subtraction_trick_;
int starting_index = std::numeric_limits<int>::max();
int sync_count = 0;
size_t constexpr kNRows = 8, kNCols = 16;
int32_t constexpr kMaxBins = 4;
@ -49,11 +48,9 @@ void TestAddHistRows(bool is_distributed) {
HistogramBuilder<CPUExpandEntry> histogram_builder;
histogram_builder.Reset(gmat.cut.TotalBins(), {kMaxBins, 0.5}, omp_get_max_threads(), 1,
is_distributed, false);
histogram_builder.AddHistRows(&starting_index, &sync_count,
nodes_for_explicit_hist_build_,
nodes_for_subtraction_trick_, &tree);
histogram_builder.AddHistRows(&starting_index, nodes_for_explicit_hist_build_,
nodes_for_subtraction_trick_);
ASSERT_EQ(sync_count, 2);
ASSERT_EQ(starting_index, 3);
for (const CPUExpandEntry &node : nodes_for_explicit_hist_build_) {
@ -78,7 +75,6 @@ void TestSyncHist(bool is_distributed) {
std::vector<CPUExpandEntry> nodes_for_explicit_hist_build_;
std::vector<CPUExpandEntry> nodes_for_subtraction_trick_;
int starting_index = std::numeric_limits<int>::max();
int sync_count = 0;
RegTree tree;
auto p_fmat = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
@ -100,9 +96,8 @@ void TestSyncHist(bool is_distributed) {
// level 0
nodes_for_explicit_hist_build_.emplace_back(0, tree.GetDepth(0));
histogram.AddHistRows(&starting_index, &sync_count,
nodes_for_explicit_hist_build_,
nodes_for_subtraction_trick_, &tree);
histogram.AddHistRows(&starting_index, nodes_for_explicit_hist_build_,
nodes_for_subtraction_trick_);
tree.ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
nodes_for_explicit_hist_build_.clear();
@ -112,9 +107,8 @@ void TestSyncHist(bool is_distributed) {
nodes_for_explicit_hist_build_.emplace_back(tree[0].LeftChild(), tree.GetDepth(1));
nodes_for_subtraction_trick_.emplace_back(tree[0].RightChild(), tree.GetDepth(2));
histogram.AddHistRows(&starting_index, &sync_count,
nodes_for_explicit_hist_build_,
nodes_for_subtraction_trick_, &tree);
histogram.AddHistRows(&starting_index, nodes_for_explicit_hist_build_,
nodes_for_subtraction_trick_);
tree.ExpandNode(tree[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
tree.ExpandNode(tree[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
@ -127,9 +121,8 @@ void TestSyncHist(bool is_distributed) {
nodes_for_explicit_hist_build_.emplace_back(5, tree.GetDepth(5));
nodes_for_subtraction_trick_.emplace_back(6, tree.GetDepth(6));
histogram.AddHistRows(&starting_index, &sync_count,
nodes_for_explicit_hist_build_,
nodes_for_subtraction_trick_, &tree);
histogram.AddHistRows(&starting_index, nodes_for_explicit_hist_build_,
nodes_for_subtraction_trick_);
const size_t n_nodes = nodes_for_explicit_hist_build_.size();
ASSERT_EQ(n_nodes, 2ul);
@ -175,14 +168,8 @@ void TestSyncHist(bool is_distributed) {
histogram.Buffer().Reset(1, n_nodes, space, target_hists);
// sync hist
if (is_distributed) {
histogram.SyncHistogramDistributed(&tree, nodes_for_explicit_hist_build_,
nodes_for_subtraction_trick_,
starting_index, sync_count);
} else {
histogram.SyncHistogramLocal(&tree, nodes_for_explicit_hist_build_,
nodes_for_subtraction_trick_);
}
histogram.SyncHistogram(&tree, nodes_for_explicit_hist_build_,
nodes_for_subtraction_trick_, starting_index);
using GHistRowT = common::GHistRow;
auto check_hist = [](const GHistRowT parent, const GHistRowT left, const GHistRowT right,
@ -487,4 +474,3 @@ TEST(CPUHistogram, ExternalMemory) {
TestHistogramExternalMemory(&ctx, {kBins, sparse_thresh}, false, true);
}
} // namespace xgboost::tree