Unify the code path between local and distributed training. (#9433)
This removes the need for a local histogram space during distributed training, which cuts the cache size by half.
parent f958e32683
commit 1332ff787f
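The core idea behind the unified path: every node's histogram lives in one contiguous collection, so distributed training only needs a single allreduce over that memory before the usual subtraction trick recovers the sibling histograms, and the separate per-worker copy (hist_local_worker_) becomes unnecessary. Below is a minimal standalone C++ sketch of that flow; it is illustrative only, not XGBoost's internal API, and the allreduce step is stubbed out with a comment.

// Illustrative sketch only (hypothetical names, not XGBoost internals).
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  std::size_t const n_bins = 4;    // bins per node histogram
  std::size_t const n_built = 2;   // nodes whose histograms are built explicitly
  // One contiguous buffer: [built nodes | sibling nodes], so a single
  // allreduce over hist.data() covers everything that must be synchronized.
  std::vector<double> hist(2 * n_built * n_bins, 0.0);
  std::vector<double> parent(n_bins, 10.0);  // pretend parent histogram

  // 1. Reduce per-thread partial sums into the built nodes (fake values here).
  for (std::size_t node = 0; node < n_built; ++node) {
    for (std::size_t b = 0; b < n_bins; ++b) {
      hist[node * n_bins + b] = static_cast<double>(node + 1);
    }
  }

  // 2. Distributed training would allreduce the built range in one call,
  //    e.g. allreduce(hist.data(), n_built * n_bins); locally this is a no-op.

  // 3. Subtraction trick: sibling = parent - built, written straight into the
  //    same buffer, so no second "local worker" histogram space is needed.
  for (std::size_t node = 0; node < n_built; ++node) {
    double const *built = hist.data() + node * n_bins;
    double *sibling = hist.data() + (n_built + node) * n_bins;
    for (std::size_t b = 0; b < n_bins; ++b) {
      sibling[b] = parent[b] - built[b];
    }
  }

  std::cout << "sibling of node 0, bin 0: " << hist[n_built * n_bins] << "\n";  // 10 - 1 = 9
  return 0;
}

The same layout (explicitly built nodes allocated first, starting at starting_index) is what lets the unified SyncHistogram in the diff below perform one allreduce over this->hist_[starting_index].data() covering n_bins * nodes_for_explicit_hist_build.size() * 2 doubles.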
@@ -453,6 +453,7 @@ class HistCollection {
       data_[0].resize(new_size);
     }
   }
+  [[nodiscard]] bool IsContiguous() const { return contiguous_allocation_; }

  private:
   /*! \brief number of all bins over all features */
@@ -14,14 +14,11 @@
 #include "expand_entry.h"
 #include "xgboost/tree_model.h"  // for RegTree

-namespace xgboost {
-namespace tree {
+namespace xgboost::tree {
 template <typename ExpandEntry>
 class HistogramBuilder {
   /*! \brief culmulative histogram of gradients. */
   common::HistCollection hist_;
-  /*! \brief culmulative local parent histogram of gradients. */
-  common::HistCollection hist_local_worker_;
   common::ParallelGHistBuilder buffer_;
   BatchParam param_;
   int32_t n_threads_{-1};
@@ -46,12 +43,9 @@ class HistogramBuilder {
     n_batches_ = n_batches;
     param_ = p;
     hist_.Init(total_bins);
-    hist_local_worker_.Init(total_bins);
     buffer_.Init(total_bins);
     is_distributed_ = is_distributed;
     is_col_split_ = is_col_split;
-    // Workaround s390x gcc 7.5.0
-    auto DMLC_ATTRIBUTE_UNUSED __force_instantiation = &GradientPairPrecise::Reduce;
   }

   template <bool any_missing>
@@ -91,17 +85,19 @@ class HistogramBuilder {
     });
   }

-  void AddHistRows(int *starting_index, int *sync_count,
+  void AddHistRows(int *starting_index,
                    std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
-                   std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
-                   RegTree const *p_tree) {
-    if (is_distributed_ && !is_col_split_) {
-      this->AddHistRowsDistributed(starting_index, sync_count, nodes_for_explicit_hist_build,
-                                   nodes_for_subtraction_trick, p_tree);
-    } else {
-      this->AddHistRowsLocal(starting_index, sync_count, nodes_for_explicit_hist_build,
-                             nodes_for_subtraction_trick);
+                   std::vector<ExpandEntry> const &nodes_for_subtraction_trick) {
+    for (auto const &entry : nodes_for_explicit_hist_build) {
+      int nid = entry.nid;
+      this->hist_.AddHistRow(nid);
+      (*starting_index) = std::min(nid, (*starting_index));
     }
+
+    for (auto const &node : nodes_for_subtraction_trick) {
+      this->hist_.AddHistRow(node.nid);
+    }
+    this->hist_.AllocateAllData();
   }

   /** Main entry point of this class, build histogram for tree nodes. */
@@ -111,10 +107,9 @@ class HistogramBuilder {
                  std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
                  common::Span<GradientPair const> gpair, bool force_read_by_column = false) {
     int starting_index = std::numeric_limits<int>::max();
-    int sync_count = 0;
     if (page_id == 0) {
-      this->AddHistRows(&starting_index, &sync_count, nodes_for_explicit_hist_build,
-                        nodes_for_subtraction_trick, p_tree);
+      this->AddHistRows(&starting_index, nodes_for_explicit_hist_build,
+                        nodes_for_subtraction_trick);
     }
     if (gidx.IsDense()) {
       this->BuildLocalHistograms<false>(page_id, space, gidx, nodes_for_explicit_hist_build,
@@ -129,13 +124,8 @@ class HistogramBuilder {
       return;
     }

-    if (is_distributed_ && !is_col_split_) {
-      this->SyncHistogramDistributed(p_tree, nodes_for_explicit_hist_build,
-                                     nodes_for_subtraction_trick,
-                                     starting_index, sync_count);
-    } else {
-      this->SyncHistogramLocal(p_tree, nodes_for_explicit_hist_build, nodes_for_subtraction_trick);
-    }
+    this->SyncHistogram(p_tree, nodes_for_explicit_hist_build,
+                        nodes_for_subtraction_trick, starting_index);
   }
   /** same as the other build hist but handles only single batch data (in-core) */
   void BuildHist(size_t page_id, GHistIndexMatrix const &gidx, RegTree *p_tree,
@@ -156,62 +146,33 @@ class HistogramBuilder {
                     nodes_for_subtraction_trick, gpair, force_read_by_column);
   }

-  void SyncHistogramDistributed(RegTree const *p_tree,
-                                std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
-                                std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
-                                int starting_index, int sync_count) {
+  void SyncHistogram(RegTree const *p_tree,
+                     std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
+                     std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
+                     int starting_index) {
     auto n_bins = buffer_.TotalBins();
     common::BlockedSpace2d space(
         nodes_for_explicit_hist_build.size(), [&](size_t) { return n_bins; }, 1024);
-    common::ParallelFor2d(space, n_threads_, [&](size_t node, common::Range1d r) {
-      const auto &entry = nodes_for_explicit_hist_build[node];
-      auto this_hist = this->hist_[entry.nid];
-      // Merging histograms from each thread into once
-      buffer_.ReduceHist(node, r.begin(), r.end());
-      // Store posible parent node
-      auto this_local = hist_local_worker_[entry.nid];
-      common::CopyHist(this_local, this_hist, r.begin(), r.end());
-
-      if (!p_tree->IsRoot(entry.nid)) {
-        const size_t parent_id = p_tree->Parent(entry.nid);
-        const int subtraction_node_id = nodes_for_subtraction_trick[node].nid;
-        auto parent_hist = this->hist_local_worker_[parent_id];
-        auto sibling_hist = this->hist_[subtraction_node_id];
-        common::SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end());
-        // Store posible parent node
-        auto sibling_local = hist_local_worker_[subtraction_node_id];
-        common::CopyHist(sibling_local, sibling_hist, r.begin(), r.end());
-      }
-    });
-
-    collective::Allreduce<collective::Operation::kSum>(
-        reinterpret_cast<double *>(this->hist_[starting_index].data()), n_bins * sync_count * 2);
-
-    ParallelSubtractionHist(space, nodes_for_explicit_hist_build, nodes_for_subtraction_trick,
-                            p_tree);
-
-    common::BlockedSpace2d space2(
-        nodes_for_subtraction_trick.size(), [&](size_t) { return n_bins; }, 1024);
-    ParallelSubtractionHist(space2, nodes_for_subtraction_trick, nodes_for_explicit_hist_build,
-                            p_tree);
-  }
-
-  void SyncHistogramLocal(RegTree const *p_tree,
-                          std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
-                          std::vector<ExpandEntry> const &nodes_for_subtraction_trick) {
-    const size_t nbins = this->buffer_.TotalBins();
-    common::BlockedSpace2d space(
-        nodes_for_explicit_hist_build.size(), [&](size_t) { return nbins; }, 1024);
-
+    CHECK(hist_.IsContiguous());
     common::ParallelFor2d(space, this->n_threads_, [&](size_t node, common::Range1d r) {
       const auto &entry = nodes_for_explicit_hist_build[node];
       auto this_hist = this->hist_[entry.nid];
       // Merging histograms from each thread into once
       this->buffer_.ReduceHist(node, r.begin(), r.end());
+    });

+    if (is_distributed_ && !is_col_split_) {
+      collective::Allreduce<collective::Operation::kSum>(
+          reinterpret_cast<double *>(this->hist_[starting_index].data()),
+          n_bins * nodes_for_explicit_hist_build.size() * 2);
+    }
+
+    common::ParallelFor2d(space, this->n_threads_, [&](std::size_t nidx_in_set, common::Range1d r) {
+      const auto &entry = nodes_for_explicit_hist_build[nidx_in_set];
+      auto this_hist = this->hist_[entry.nid];
       if (!p_tree->IsRoot(entry.nid)) {
         auto const parent_id = p_tree->Parent(entry.nid);
-        auto const subtraction_node_id = nodes_for_subtraction_trick[node].nid;
+        auto const subtraction_node_id = nodes_for_subtraction_trick[nidx_in_set].nid;
         auto parent_hist = this->hist_[parent_id];
         auto sibling_hist = this->hist_[subtraction_node_id];
         common::SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end());
@@ -222,82 +183,7 @@ class HistogramBuilder {
  public:
   /* Getters for tests. */
   common::HistCollection const &Histogram() { return hist_; }
-  auto& Buffer() { return buffer_; }
-
- private:
-  void
-  ParallelSubtractionHist(const common::BlockedSpace2d &space,
-                          const std::vector<ExpandEntry> &nodes,
-                          const std::vector<ExpandEntry> &subtraction_nodes,
-                          const RegTree *p_tree) {
-    common::ParallelFor2d(
-        space, this->n_threads_, [&](size_t node, common::Range1d r) {
-          const auto &entry = nodes[node];
-          if (!(p_tree->IsLeftChild(entry.nid))) {
-            auto this_hist = this->hist_[entry.nid];
-
-            if (!p_tree->IsRoot(entry.nid)) {
-              const int subtraction_node_id = subtraction_nodes[node].nid;
-              auto parent_hist = hist_[(*p_tree)[entry.nid].Parent()];
-              auto sibling_hist = hist_[subtraction_node_id];
-              common::SubtractionHist(this_hist, parent_hist, sibling_hist,
-                                      r.begin(), r.end());
-            }
-          }
-        });
-  }
-
-  // Add a tree node to histogram buffer in local training environment.
-  void AddHistRowsLocal(
-      int *starting_index, int *sync_count,
-      std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
-      std::vector<ExpandEntry> const &nodes_for_subtraction_trick) {
-    for (auto const &entry : nodes_for_explicit_hist_build) {
-      int nid = entry.nid;
-      this->hist_.AddHistRow(nid);
-      (*starting_index) = std::min(nid, (*starting_index));
-    }
-    (*sync_count) = nodes_for_explicit_hist_build.size();
-
-    for (auto const &node : nodes_for_subtraction_trick) {
-      this->hist_.AddHistRow(node.nid);
-    }
-    this->hist_.AllocateAllData();
-  }
-
-  void AddHistRowsDistributed(int *starting_index, int *sync_count,
-                              std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
-                              std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
-                              RegTree const *p_tree) {
-    const size_t explicit_size = nodes_for_explicit_hist_build.size();
-    const size_t subtaction_size = nodes_for_subtraction_trick.size();
-    std::vector<int> merged_node_ids(explicit_size + subtaction_size);
-    for (size_t i = 0; i < explicit_size; ++i) {
-      merged_node_ids[i] = nodes_for_explicit_hist_build[i].nid;
-    }
-    for (size_t i = 0; i < subtaction_size; ++i) {
-      merged_node_ids[explicit_size + i] = nodes_for_subtraction_trick[i].nid;
-    }
-    std::sort(merged_node_ids.begin(), merged_node_ids.end());
-    int n_left = 0;
-    for (auto const &nid : merged_node_ids) {
-      if (p_tree->IsLeftChild(nid)) {
-        this->hist_.AddHistRow(nid);
-        (*starting_index) = std::min(nid, (*starting_index));
-        n_left++;
-        this->hist_local_worker_.AddHistRow(nid);
-      }
-    }
-    for (auto const &nid : merged_node_ids) {
-      if (!(p_tree->IsLeftChild(nid))) {
-        this->hist_.AddHistRow(nid);
-        this->hist_local_worker_.AddHistRow(nid);
-      }
-    }
-    this->hist_.AllocateAllData();
-    this->hist_local_worker_.AllocateAllData();
-    (*sync_count) = std::max(1, n_left);
-  }
+  auto &Buffer() { return buffer_; }
 };

 // Construct a work space for building histogram. Eventually we should move this
@@ -318,6 +204,5 @@ common::BlockedSpace2d ConstructHistSpace(Partitioner const &partitioners,
       nodes_to_build.size(), [&](size_t nidx_in_set) { return partition_size[nidx_in_set]; }, 256};
   return space;
 }
-}  // namespace tree
-}  // namespace xgboost
+}  // namespace xgboost::tree
 #endif  // XGBOOST_TREE_HIST_HISTOGRAM_H_
@@ -28,7 +28,6 @@ void TestAddHistRows(bool is_distributed) {
   std::vector<CPUExpandEntry> nodes_for_explicit_hist_build_;
   std::vector<CPUExpandEntry> nodes_for_subtraction_trick_;
   int starting_index = std::numeric_limits<int>::max();
-  int sync_count = 0;

   size_t constexpr kNRows = 8, kNCols = 16;
   int32_t constexpr kMaxBins = 4;
@@ -49,11 +48,9 @@
   HistogramBuilder<CPUExpandEntry> histogram_builder;
   histogram_builder.Reset(gmat.cut.TotalBins(), {kMaxBins, 0.5}, omp_get_max_threads(), 1,
                           is_distributed, false);
-  histogram_builder.AddHistRows(&starting_index, &sync_count,
-                                nodes_for_explicit_hist_build_,
-                                nodes_for_subtraction_trick_, &tree);
+  histogram_builder.AddHistRows(&starting_index, nodes_for_explicit_hist_build_,
+                                nodes_for_subtraction_trick_);

-  ASSERT_EQ(sync_count, 2);
   ASSERT_EQ(starting_index, 3);

   for (const CPUExpandEntry &node : nodes_for_explicit_hist_build_) {
@@ -78,7 +75,6 @@ void TestSyncHist(bool is_distributed) {
   std::vector<CPUExpandEntry> nodes_for_explicit_hist_build_;
   std::vector<CPUExpandEntry> nodes_for_subtraction_trick_;
   int starting_index = std::numeric_limits<int>::max();
-  int sync_count = 0;
   RegTree tree;

   auto p_fmat = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
@@ -100,9 +96,8 @@ void TestSyncHist(bool is_distributed) {

   // level 0
   nodes_for_explicit_hist_build_.emplace_back(0, tree.GetDepth(0));
-  histogram.AddHistRows(&starting_index, &sync_count,
-                        nodes_for_explicit_hist_build_,
-                        nodes_for_subtraction_trick_, &tree);
+  histogram.AddHistRows(&starting_index, nodes_for_explicit_hist_build_,
+                        nodes_for_subtraction_trick_);

   tree.ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
   nodes_for_explicit_hist_build_.clear();
@@ -112,9 +107,8 @@ void TestSyncHist(bool is_distributed) {
   nodes_for_explicit_hist_build_.emplace_back(tree[0].LeftChild(), tree.GetDepth(1));
   nodes_for_subtraction_trick_.emplace_back(tree[0].RightChild(), tree.GetDepth(2));

-  histogram.AddHistRows(&starting_index, &sync_count,
-                        nodes_for_explicit_hist_build_,
-                        nodes_for_subtraction_trick_, &tree);
+  histogram.AddHistRows(&starting_index, nodes_for_explicit_hist_build_,
+                        nodes_for_subtraction_trick_);

   tree.ExpandNode(tree[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
   tree.ExpandNode(tree[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
@@ -127,9 +121,8 @@ void TestSyncHist(bool is_distributed) {
   nodes_for_explicit_hist_build_.emplace_back(5, tree.GetDepth(5));
   nodes_for_subtraction_trick_.emplace_back(6, tree.GetDepth(6));

-  histogram.AddHistRows(&starting_index, &sync_count,
-                        nodes_for_explicit_hist_build_,
-                        nodes_for_subtraction_trick_, &tree);
+  histogram.AddHistRows(&starting_index, nodes_for_explicit_hist_build_,
+                        nodes_for_subtraction_trick_);

   const size_t n_nodes = nodes_for_explicit_hist_build_.size();
   ASSERT_EQ(n_nodes, 2ul);
@@ -175,14 +168,8 @@ void TestSyncHist(bool is_distributed) {

   histogram.Buffer().Reset(1, n_nodes, space, target_hists);
   // sync hist
-  if (is_distributed) {
-    histogram.SyncHistogramDistributed(&tree, nodes_for_explicit_hist_build_,
-                                       nodes_for_subtraction_trick_,
-                                       starting_index, sync_count);
-  } else {
-    histogram.SyncHistogramLocal(&tree, nodes_for_explicit_hist_build_,
-                                 nodes_for_subtraction_trick_);
-  }
+  histogram.SyncHistogram(&tree, nodes_for_explicit_hist_build_,
+                          nodes_for_subtraction_trick_, starting_index);

   using GHistRowT = common::GHistRow;
   auto check_hist = [](const GHistRowT parent, const GHistRowT left, const GHistRowT right,
@@ -487,4 +474,3 @@ TEST(CPUHistogram, ExternalMemory) {
   TestHistogramExternalMemory(&ctx, {kBins, sparse_thresh}, false, true);
 }
 }  // namespace xgboost::tree
-