Unify the code path between local and distributed training. (#9433)

This removes the need for a local histogram space during distributed training, which cuts the cache size by half.
This commit is contained in:
Jiaming Yuan 2023-08-03 21:46:36 +08:00 committed by GitHub
parent f958e32683
commit 1332ff787f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 45 additions and 173 deletions

View File

@ -453,6 +453,7 @@ class HistCollection {
data_[0].resize(new_size); data_[0].resize(new_size);
} }
} }
[[nodiscard]] bool IsContiguous() const { return contiguous_allocation_; }
private: private:
/*! \brief number of all bins over all features */ /*! \brief number of all bins over all features */

View File

@ -14,14 +14,11 @@
#include "expand_entry.h" #include "expand_entry.h"
#include "xgboost/tree_model.h" // for RegTree #include "xgboost/tree_model.h" // for RegTree
namespace xgboost { namespace xgboost::tree {
namespace tree {
template <typename ExpandEntry> template <typename ExpandEntry>
class HistogramBuilder { class HistogramBuilder {
/*! \brief culmulative histogram of gradients. */ /*! \brief culmulative histogram of gradients. */
common::HistCollection hist_; common::HistCollection hist_;
/*! \brief culmulative local parent histogram of gradients. */
common::HistCollection hist_local_worker_;
common::ParallelGHistBuilder buffer_; common::ParallelGHistBuilder buffer_;
BatchParam param_; BatchParam param_;
int32_t n_threads_{-1}; int32_t n_threads_{-1};
@ -46,12 +43,9 @@ class HistogramBuilder {
n_batches_ = n_batches; n_batches_ = n_batches;
param_ = p; param_ = p;
hist_.Init(total_bins); hist_.Init(total_bins);
hist_local_worker_.Init(total_bins);
buffer_.Init(total_bins); buffer_.Init(total_bins);
is_distributed_ = is_distributed; is_distributed_ = is_distributed;
is_col_split_ = is_col_split; is_col_split_ = is_col_split;
// Workaround s390x gcc 7.5.0
auto DMLC_ATTRIBUTE_UNUSED __force_instantiation = &GradientPairPrecise::Reduce;
} }
template <bool any_missing> template <bool any_missing>
@ -91,17 +85,19 @@ class HistogramBuilder {
}); });
} }
void AddHistRows(int *starting_index, int *sync_count, void AddHistRows(int *starting_index,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build, std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick, std::vector<ExpandEntry> const &nodes_for_subtraction_trick) {
RegTree const *p_tree) { for (auto const &entry : nodes_for_explicit_hist_build) {
if (is_distributed_ && !is_col_split_) { int nid = entry.nid;
this->AddHistRowsDistributed(starting_index, sync_count, nodes_for_explicit_hist_build, this->hist_.AddHistRow(nid);
nodes_for_subtraction_trick, p_tree); (*starting_index) = std::min(nid, (*starting_index));
} else {
this->AddHistRowsLocal(starting_index, sync_count, nodes_for_explicit_hist_build,
nodes_for_subtraction_trick);
} }
for (auto const &node : nodes_for_subtraction_trick) {
this->hist_.AddHistRow(node.nid);
}
this->hist_.AllocateAllData();
} }
/** Main entry point of this class, build histogram for tree nodes. */ /** Main entry point of this class, build histogram for tree nodes. */
@ -111,10 +107,9 @@ class HistogramBuilder {
std::vector<ExpandEntry> const &nodes_for_subtraction_trick, std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
common::Span<GradientPair const> gpair, bool force_read_by_column = false) { common::Span<GradientPair const> gpair, bool force_read_by_column = false) {
int starting_index = std::numeric_limits<int>::max(); int starting_index = std::numeric_limits<int>::max();
int sync_count = 0;
if (page_id == 0) { if (page_id == 0) {
this->AddHistRows(&starting_index, &sync_count, nodes_for_explicit_hist_build, this->AddHistRows(&starting_index, nodes_for_explicit_hist_build,
nodes_for_subtraction_trick, p_tree); nodes_for_subtraction_trick);
} }
if (gidx.IsDense()) { if (gidx.IsDense()) {
this->BuildLocalHistograms<false>(page_id, space, gidx, nodes_for_explicit_hist_build, this->BuildLocalHistograms<false>(page_id, space, gidx, nodes_for_explicit_hist_build,
@ -129,13 +124,8 @@ class HistogramBuilder {
return; return;
} }
if (is_distributed_ && !is_col_split_) { this->SyncHistogram(p_tree, nodes_for_explicit_hist_build,
this->SyncHistogramDistributed(p_tree, nodes_for_explicit_hist_build, nodes_for_subtraction_trick, starting_index);
nodes_for_subtraction_trick,
starting_index, sync_count);
} else {
this->SyncHistogramLocal(p_tree, nodes_for_explicit_hist_build, nodes_for_subtraction_trick);
}
} }
/** same as the other build hist but handles only single batch data (in-core) */ /** same as the other build hist but handles only single batch data (in-core) */
void BuildHist(size_t page_id, GHistIndexMatrix const &gidx, RegTree *p_tree, void BuildHist(size_t page_id, GHistIndexMatrix const &gidx, RegTree *p_tree,
@ -156,62 +146,33 @@ class HistogramBuilder {
nodes_for_subtraction_trick, gpair, force_read_by_column); nodes_for_subtraction_trick, gpair, force_read_by_column);
} }
void SyncHistogramDistributed(RegTree const *p_tree, void SyncHistogram(RegTree const *p_tree,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build, std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick, std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
int starting_index, int sync_count) { int starting_index) {
auto n_bins = buffer_.TotalBins(); auto n_bins = buffer_.TotalBins();
common::BlockedSpace2d space( common::BlockedSpace2d space(
nodes_for_explicit_hist_build.size(), [&](size_t) { return n_bins; }, 1024); nodes_for_explicit_hist_build.size(), [&](size_t) { return n_bins; }, 1024);
common::ParallelFor2d(space, n_threads_, [&](size_t node, common::Range1d r) { CHECK(hist_.IsContiguous());
const auto &entry = nodes_for_explicit_hist_build[node];
auto this_hist = this->hist_[entry.nid];
// Merging histograms from each thread into once
buffer_.ReduceHist(node, r.begin(), r.end());
// Store posible parent node
auto this_local = hist_local_worker_[entry.nid];
common::CopyHist(this_local, this_hist, r.begin(), r.end());
if (!p_tree->IsRoot(entry.nid)) {
const size_t parent_id = p_tree->Parent(entry.nid);
const int subtraction_node_id = nodes_for_subtraction_trick[node].nid;
auto parent_hist = this->hist_local_worker_[parent_id];
auto sibling_hist = this->hist_[subtraction_node_id];
common::SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end());
// Store posible parent node
auto sibling_local = hist_local_worker_[subtraction_node_id];
common::CopyHist(sibling_local, sibling_hist, r.begin(), r.end());
}
});
collective::Allreduce<collective::Operation::kSum>(
reinterpret_cast<double *>(this->hist_[starting_index].data()), n_bins * sync_count * 2);
ParallelSubtractionHist(space, nodes_for_explicit_hist_build, nodes_for_subtraction_trick,
p_tree);
common::BlockedSpace2d space2(
nodes_for_subtraction_trick.size(), [&](size_t) { return n_bins; }, 1024);
ParallelSubtractionHist(space2, nodes_for_subtraction_trick, nodes_for_explicit_hist_build,
p_tree);
}
void SyncHistogramLocal(RegTree const *p_tree,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick) {
const size_t nbins = this->buffer_.TotalBins();
common::BlockedSpace2d space(
nodes_for_explicit_hist_build.size(), [&](size_t) { return nbins; }, 1024);
common::ParallelFor2d(space, this->n_threads_, [&](size_t node, common::Range1d r) { common::ParallelFor2d(space, this->n_threads_, [&](size_t node, common::Range1d r) {
const auto &entry = nodes_for_explicit_hist_build[node]; const auto &entry = nodes_for_explicit_hist_build[node];
auto this_hist = this->hist_[entry.nid]; auto this_hist = this->hist_[entry.nid];
// Merging histograms from each thread into once // Merging histograms from each thread into once
this->buffer_.ReduceHist(node, r.begin(), r.end()); this->buffer_.ReduceHist(node, r.begin(), r.end());
});
if (is_distributed_ && !is_col_split_) {
collective::Allreduce<collective::Operation::kSum>(
reinterpret_cast<double *>(this->hist_[starting_index].data()),
n_bins * nodes_for_explicit_hist_build.size() * 2);
}
common::ParallelFor2d(space, this->n_threads_, [&](std::size_t nidx_in_set, common::Range1d r) {
const auto &entry = nodes_for_explicit_hist_build[nidx_in_set];
auto this_hist = this->hist_[entry.nid];
if (!p_tree->IsRoot(entry.nid)) { if (!p_tree->IsRoot(entry.nid)) {
auto const parent_id = p_tree->Parent(entry.nid); auto const parent_id = p_tree->Parent(entry.nid);
auto const subtraction_node_id = nodes_for_subtraction_trick[node].nid; auto const subtraction_node_id = nodes_for_subtraction_trick[nidx_in_set].nid;
auto parent_hist = this->hist_[parent_id]; auto parent_hist = this->hist_[parent_id];
auto sibling_hist = this->hist_[subtraction_node_id]; auto sibling_hist = this->hist_[subtraction_node_id];
common::SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end()); common::SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end());
@ -222,82 +183,7 @@ class HistogramBuilder {
public: public:
/* Getters for tests. */ /* Getters for tests. */
common::HistCollection const &Histogram() { return hist_; } common::HistCollection const &Histogram() { return hist_; }
auto& Buffer() { return buffer_; } auto &Buffer() { return buffer_; }
private:
void
ParallelSubtractionHist(const common::BlockedSpace2d &space,
const std::vector<ExpandEntry> &nodes,
const std::vector<ExpandEntry> &subtraction_nodes,
const RegTree *p_tree) {
common::ParallelFor2d(
space, this->n_threads_, [&](size_t node, common::Range1d r) {
const auto &entry = nodes[node];
if (!(p_tree->IsLeftChild(entry.nid))) {
auto this_hist = this->hist_[entry.nid];
if (!p_tree->IsRoot(entry.nid)) {
const int subtraction_node_id = subtraction_nodes[node].nid;
auto parent_hist = hist_[(*p_tree)[entry.nid].Parent()];
auto sibling_hist = hist_[subtraction_node_id];
common::SubtractionHist(this_hist, parent_hist, sibling_hist,
r.begin(), r.end());
}
}
});
}
// Add a tree node to histogram buffer in local training environment.
void AddHistRowsLocal(
int *starting_index, int *sync_count,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick) {
for (auto const &entry : nodes_for_explicit_hist_build) {
int nid = entry.nid;
this->hist_.AddHistRow(nid);
(*starting_index) = std::min(nid, (*starting_index));
}
(*sync_count) = nodes_for_explicit_hist_build.size();
for (auto const &node : nodes_for_subtraction_trick) {
this->hist_.AddHistRow(node.nid);
}
this->hist_.AllocateAllData();
}
void AddHistRowsDistributed(int *starting_index, int *sync_count,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
RegTree const *p_tree) {
const size_t explicit_size = nodes_for_explicit_hist_build.size();
const size_t subtaction_size = nodes_for_subtraction_trick.size();
std::vector<int> merged_node_ids(explicit_size + subtaction_size);
for (size_t i = 0; i < explicit_size; ++i) {
merged_node_ids[i] = nodes_for_explicit_hist_build[i].nid;
}
for (size_t i = 0; i < subtaction_size; ++i) {
merged_node_ids[explicit_size + i] = nodes_for_subtraction_trick[i].nid;
}
std::sort(merged_node_ids.begin(), merged_node_ids.end());
int n_left = 0;
for (auto const &nid : merged_node_ids) {
if (p_tree->IsLeftChild(nid)) {
this->hist_.AddHistRow(nid);
(*starting_index) = std::min(nid, (*starting_index));
n_left++;
this->hist_local_worker_.AddHistRow(nid);
}
}
for (auto const &nid : merged_node_ids) {
if (!(p_tree->IsLeftChild(nid))) {
this->hist_.AddHistRow(nid);
this->hist_local_worker_.AddHistRow(nid);
}
}
this->hist_.AllocateAllData();
this->hist_local_worker_.AllocateAllData();
(*sync_count) = std::max(1, n_left);
}
}; };
// Construct a work space for building histogram. Eventually we should move this // Construct a work space for building histogram. Eventually we should move this
@ -318,6 +204,5 @@ common::BlockedSpace2d ConstructHistSpace(Partitioner const &partitioners,
nodes_to_build.size(), [&](size_t nidx_in_set) { return partition_size[nidx_in_set]; }, 256}; nodes_to_build.size(), [&](size_t nidx_in_set) { return partition_size[nidx_in_set]; }, 256};
return space; return space;
} }
} // namespace tree } // namespace xgboost::tree
} // namespace xgboost
#endif // XGBOOST_TREE_HIST_HISTOGRAM_H_ #endif // XGBOOST_TREE_HIST_HISTOGRAM_H_

View File

@ -28,7 +28,6 @@ void TestAddHistRows(bool is_distributed) {
std::vector<CPUExpandEntry> nodes_for_explicit_hist_build_; std::vector<CPUExpandEntry> nodes_for_explicit_hist_build_;
std::vector<CPUExpandEntry> nodes_for_subtraction_trick_; std::vector<CPUExpandEntry> nodes_for_subtraction_trick_;
int starting_index = std::numeric_limits<int>::max(); int starting_index = std::numeric_limits<int>::max();
int sync_count = 0;
size_t constexpr kNRows = 8, kNCols = 16; size_t constexpr kNRows = 8, kNCols = 16;
int32_t constexpr kMaxBins = 4; int32_t constexpr kMaxBins = 4;
@ -49,11 +48,9 @@ void TestAddHistRows(bool is_distributed) {
HistogramBuilder<CPUExpandEntry> histogram_builder; HistogramBuilder<CPUExpandEntry> histogram_builder;
histogram_builder.Reset(gmat.cut.TotalBins(), {kMaxBins, 0.5}, omp_get_max_threads(), 1, histogram_builder.Reset(gmat.cut.TotalBins(), {kMaxBins, 0.5}, omp_get_max_threads(), 1,
is_distributed, false); is_distributed, false);
histogram_builder.AddHistRows(&starting_index, &sync_count, histogram_builder.AddHistRows(&starting_index, nodes_for_explicit_hist_build_,
nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_);
nodes_for_subtraction_trick_, &tree);
ASSERT_EQ(sync_count, 2);
ASSERT_EQ(starting_index, 3); ASSERT_EQ(starting_index, 3);
for (const CPUExpandEntry &node : nodes_for_explicit_hist_build_) { for (const CPUExpandEntry &node : nodes_for_explicit_hist_build_) {
@ -78,7 +75,6 @@ void TestSyncHist(bool is_distributed) {
std::vector<CPUExpandEntry> nodes_for_explicit_hist_build_; std::vector<CPUExpandEntry> nodes_for_explicit_hist_build_;
std::vector<CPUExpandEntry> nodes_for_subtraction_trick_; std::vector<CPUExpandEntry> nodes_for_subtraction_trick_;
int starting_index = std::numeric_limits<int>::max(); int starting_index = std::numeric_limits<int>::max();
int sync_count = 0;
RegTree tree; RegTree tree;
auto p_fmat = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix(); auto p_fmat = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
@ -100,9 +96,8 @@ void TestSyncHist(bool is_distributed) {
// level 0 // level 0
nodes_for_explicit_hist_build_.emplace_back(0, tree.GetDepth(0)); nodes_for_explicit_hist_build_.emplace_back(0, tree.GetDepth(0));
histogram.AddHistRows(&starting_index, &sync_count, histogram.AddHistRows(&starting_index, nodes_for_explicit_hist_build_,
nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_);
nodes_for_subtraction_trick_, &tree);
tree.ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0); tree.ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
nodes_for_explicit_hist_build_.clear(); nodes_for_explicit_hist_build_.clear();
@ -112,9 +107,8 @@ void TestSyncHist(bool is_distributed) {
nodes_for_explicit_hist_build_.emplace_back(tree[0].LeftChild(), tree.GetDepth(1)); nodes_for_explicit_hist_build_.emplace_back(tree[0].LeftChild(), tree.GetDepth(1));
nodes_for_subtraction_trick_.emplace_back(tree[0].RightChild(), tree.GetDepth(2)); nodes_for_subtraction_trick_.emplace_back(tree[0].RightChild(), tree.GetDepth(2));
histogram.AddHistRows(&starting_index, &sync_count, histogram.AddHistRows(&starting_index, nodes_for_explicit_hist_build_,
nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_);
nodes_for_subtraction_trick_, &tree);
tree.ExpandNode(tree[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0); tree.ExpandNode(tree[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
tree.ExpandNode(tree[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0); tree.ExpandNode(tree[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
@ -127,9 +121,8 @@ void TestSyncHist(bool is_distributed) {
nodes_for_explicit_hist_build_.emplace_back(5, tree.GetDepth(5)); nodes_for_explicit_hist_build_.emplace_back(5, tree.GetDepth(5));
nodes_for_subtraction_trick_.emplace_back(6, tree.GetDepth(6)); nodes_for_subtraction_trick_.emplace_back(6, tree.GetDepth(6));
histogram.AddHistRows(&starting_index, &sync_count, histogram.AddHistRows(&starting_index, nodes_for_explicit_hist_build_,
nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_);
nodes_for_subtraction_trick_, &tree);
const size_t n_nodes = nodes_for_explicit_hist_build_.size(); const size_t n_nodes = nodes_for_explicit_hist_build_.size();
ASSERT_EQ(n_nodes, 2ul); ASSERT_EQ(n_nodes, 2ul);
@ -175,14 +168,8 @@ void TestSyncHist(bool is_distributed) {
histogram.Buffer().Reset(1, n_nodes, space, target_hists); histogram.Buffer().Reset(1, n_nodes, space, target_hists);
// sync hist // sync hist
if (is_distributed) { histogram.SyncHistogram(&tree, nodes_for_explicit_hist_build_,
histogram.SyncHistogramDistributed(&tree, nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_, starting_index);
nodes_for_subtraction_trick_,
starting_index, sync_count);
} else {
histogram.SyncHistogramLocal(&tree, nodes_for_explicit_hist_build_,
nodes_for_subtraction_trick_);
}
using GHistRowT = common::GHistRow; using GHistRowT = common::GHistRow;
auto check_hist = [](const GHistRowT parent, const GHistRowT left, const GHistRowT right, auto check_hist = [](const GHistRowT parent, const GHistRowT left, const GHistRowT right,
@ -487,4 +474,3 @@ TEST(CPUHistogram, ExternalMemory) {
TestHistogramExternalMemory(&ctx, {kBins, sparse_thresh}, false, true); TestHistogramExternalMemory(&ctx, {kBins, sparse_thresh}, false, true);
} }
} // namespace xgboost::tree } // namespace xgboost::tree