Unify the code path between local and distributed training. (#9433)
This removes the need for a local histogram space during distributed training, which cuts the cache size by half.
This commit is contained in:
@@ -28,7 +28,6 @@ void TestAddHistRows(bool is_distributed) {
|
||||
std::vector<CPUExpandEntry> nodes_for_explicit_hist_build_;
|
||||
std::vector<CPUExpandEntry> nodes_for_subtraction_trick_;
|
||||
int starting_index = std::numeric_limits<int>::max();
|
||||
int sync_count = 0;
|
||||
|
||||
size_t constexpr kNRows = 8, kNCols = 16;
|
||||
int32_t constexpr kMaxBins = 4;
|
||||
@@ -49,11 +48,9 @@ void TestAddHistRows(bool is_distributed) {
|
||||
HistogramBuilder<CPUExpandEntry> histogram_builder;
|
||||
histogram_builder.Reset(gmat.cut.TotalBins(), {kMaxBins, 0.5}, omp_get_max_threads(), 1,
|
||||
is_distributed, false);
|
||||
histogram_builder.AddHistRows(&starting_index, &sync_count,
|
||||
nodes_for_explicit_hist_build_,
|
||||
nodes_for_subtraction_trick_, &tree);
|
||||
histogram_builder.AddHistRows(&starting_index, nodes_for_explicit_hist_build_,
|
||||
nodes_for_subtraction_trick_);
|
||||
|
||||
ASSERT_EQ(sync_count, 2);
|
||||
ASSERT_EQ(starting_index, 3);
|
||||
|
||||
for (const CPUExpandEntry &node : nodes_for_explicit_hist_build_) {
|
||||
@@ -78,7 +75,6 @@ void TestSyncHist(bool is_distributed) {
|
||||
std::vector<CPUExpandEntry> nodes_for_explicit_hist_build_;
|
||||
std::vector<CPUExpandEntry> nodes_for_subtraction_trick_;
|
||||
int starting_index = std::numeric_limits<int>::max();
|
||||
int sync_count = 0;
|
||||
RegTree tree;
|
||||
|
||||
auto p_fmat = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
|
||||
@@ -100,9 +96,8 @@ void TestSyncHist(bool is_distributed) {
|
||||
|
||||
// level 0
|
||||
nodes_for_explicit_hist_build_.emplace_back(0, tree.GetDepth(0));
|
||||
histogram.AddHistRows(&starting_index, &sync_count,
|
||||
nodes_for_explicit_hist_build_,
|
||||
nodes_for_subtraction_trick_, &tree);
|
||||
histogram.AddHistRows(&starting_index, nodes_for_explicit_hist_build_,
|
||||
nodes_for_subtraction_trick_);
|
||||
|
||||
tree.ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
|
||||
nodes_for_explicit_hist_build_.clear();
|
||||
@@ -112,9 +107,8 @@ void TestSyncHist(bool is_distributed) {
|
||||
nodes_for_explicit_hist_build_.emplace_back(tree[0].LeftChild(), tree.GetDepth(1));
|
||||
nodes_for_subtraction_trick_.emplace_back(tree[0].RightChild(), tree.GetDepth(2));
|
||||
|
||||
histogram.AddHistRows(&starting_index, &sync_count,
|
||||
nodes_for_explicit_hist_build_,
|
||||
nodes_for_subtraction_trick_, &tree);
|
||||
histogram.AddHistRows(&starting_index, nodes_for_explicit_hist_build_,
|
||||
nodes_for_subtraction_trick_);
|
||||
|
||||
tree.ExpandNode(tree[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
|
||||
tree.ExpandNode(tree[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
|
||||
@@ -127,9 +121,8 @@ void TestSyncHist(bool is_distributed) {
|
||||
nodes_for_explicit_hist_build_.emplace_back(5, tree.GetDepth(5));
|
||||
nodes_for_subtraction_trick_.emplace_back(6, tree.GetDepth(6));
|
||||
|
||||
histogram.AddHistRows(&starting_index, &sync_count,
|
||||
nodes_for_explicit_hist_build_,
|
||||
nodes_for_subtraction_trick_, &tree);
|
||||
histogram.AddHistRows(&starting_index, nodes_for_explicit_hist_build_,
|
||||
nodes_for_subtraction_trick_);
|
||||
|
||||
const size_t n_nodes = nodes_for_explicit_hist_build_.size();
|
||||
ASSERT_EQ(n_nodes, 2ul);
|
||||
@@ -175,14 +168,8 @@ void TestSyncHist(bool is_distributed) {
|
||||
|
||||
histogram.Buffer().Reset(1, n_nodes, space, target_hists);
|
||||
// sync hist
|
||||
if (is_distributed) {
|
||||
histogram.SyncHistogramDistributed(&tree, nodes_for_explicit_hist_build_,
|
||||
nodes_for_subtraction_trick_,
|
||||
starting_index, sync_count);
|
||||
} else {
|
||||
histogram.SyncHistogramLocal(&tree, nodes_for_explicit_hist_build_,
|
||||
nodes_for_subtraction_trick_);
|
||||
}
|
||||
histogram.SyncHistogram(&tree, nodes_for_explicit_hist_build_,
|
||||
nodes_for_subtraction_trick_, starting_index);
|
||||
|
||||
using GHistRowT = common::GHistRow;
|
||||
auto check_hist = [](const GHistRowT parent, const GHistRowT left, const GHistRowT right,
|
||||
@@ -487,4 +474,3 @@ TEST(CPUHistogram, ExternalMemory) {
|
||||
TestHistogramExternalMemory(&ctx, {kBins, sparse_thresh}, false, true);
|
||||
}
|
||||
} // namespace xgboost::tree
|
||||
|
||||
|
||||
Reference in New Issue
Block a user