Refactor device communicator to make allreduce more flexible (#9295)

This commit is contained in:
Rong Ou
2023-06-13 12:53:03 -07:00
committed by GitHub
parent c2f0486d37
commit e70810be8a
11 changed files with 190 additions and 106 deletions

View File

@@ -11,7 +11,7 @@
#include <cstddef> // std::size_t
#include "../collective/device_communicator.cuh" // DeviceCommunicator
#include "../collective/communicator-inl.cuh"
#include "../common/device_helpers.cuh" // dh::MakeTransformIterator
#include "fit_stump.h"
#include "xgboost/base.h" // GradientPairPrecise, GradientPair, XGBOOST_DEVICE
@@ -49,8 +49,8 @@ void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpai
thrust::reduce_by_key(policy, key_it, key_it + gpair.Size(), grad_it,
thrust::make_discard_iterator(), dh::tbegin(d_sum.Values()));
collective::DeviceCommunicator* communicator = collective::Communicator::GetDevice(ctx->gpu_id);
communicator->AllReduceSum(reinterpret_cast<double*>(d_sum.Values().data()), d_sum.Size() * 2);
collective::AllReduce<collective::Operation::kSum>(
ctx->gpu_id, reinterpret_cast<double*>(d_sum.Values().data()), d_sum.Size() * 2);
thrust::for_each_n(policy, thrust::make_counting_iterator(0ul), n_targets,
[=] XGBOOST_DEVICE(std::size_t i) mutable {

View File

@@ -12,7 +12,7 @@
#include <utility>
#include <vector>
#include "../collective/device_communicator.cuh"
#include "../collective/communicator-inl.cuh"
#include "../common/bitfield.h"
#include "../common/categorical.h"
#include "../common/cuda_context.cuh" // CUDAContext
@@ -546,12 +546,13 @@ struct GPUHistMakerDevice {
}
// num histograms is the number of contiguous histograms in memory to reduce over
void AllReduceHist(int nidx, collective::DeviceCommunicator* communicator, int num_histograms) {
void AllReduceHist(int nidx, int num_histograms) {
monitor.Start("AllReduce");
auto d_node_hist = hist.GetNodeHistogram(nidx).data();
using ReduceT = typename std::remove_pointer<decltype(d_node_hist)>::type::ValueT;
communicator->AllReduceSum(reinterpret_cast<ReduceT*>(d_node_hist),
page->Cuts().TotalBins() * 2 * num_histograms);
collective::AllReduce<collective::Operation::kSum>(
ctx_->gpu_id, reinterpret_cast<ReduceT*>(d_node_hist),
page->Cuts().TotalBins() * 2 * num_histograms);
monitor.Stop("AllReduce");
}
@@ -559,8 +560,7 @@ struct GPUHistMakerDevice {
/**
* \brief Build GPU local histograms for the left and right child of some parent node
*/
void BuildHistLeftRight(std::vector<GPUExpandEntry> const& candidates,
collective::DeviceCommunicator* communicator, const RegTree& tree) {
void BuildHistLeftRight(std::vector<GPUExpandEntry> const& candidates, const RegTree& tree) {
if (candidates.empty()) return;
// Some nodes we will manually compute histograms
// others we will do by subtraction
@@ -591,7 +591,7 @@ struct GPUHistMakerDevice {
// Reduce all in one go
// This gives much better latency in a distributed setting
// when processing a large batch
this->AllReduceHist(hist_nidx.at(0), communicator, hist_nidx.size());
this->AllReduceHist(hist_nidx.at(0), hist_nidx.size());
for (size_t i = 0; i < subtraction_nidx.size(); i++) {
auto build_hist_nidx = hist_nidx.at(i);
@@ -601,7 +601,7 @@ struct GPUHistMakerDevice {
if (!this->SubtractionTrick(parent_nidx, build_hist_nidx, subtraction_trick_nidx)) {
// Calculate other histogram manually
this->BuildHist(subtraction_trick_nidx);
this->AllReduceHist(subtraction_trick_nidx, communicator, 1);
this->AllReduceHist(subtraction_trick_nidx, 1);
}
}
}
@@ -659,7 +659,7 @@ struct GPUHistMakerDevice {
parent.RightChild());
}
GPUExpandEntry InitRoot(RegTree* p_tree, collective::DeviceCommunicator* communicator) {
GPUExpandEntry InitRoot(RegTree* p_tree) {
constexpr bst_node_t kRootNIdx = 0;
dh::XGBCachingDeviceAllocator<char> alloc;
auto quantiser = *this->quantiser;
@@ -676,7 +676,7 @@ struct GPUHistMakerDevice {
hist.AllocateHistograms({kRootNIdx});
this->BuildHist(kRootNIdx);
this->AllReduceHist(kRootNIdx, communicator, 1);
this->AllReduceHist(kRootNIdx, 1);
// Remember root stats
auto root_sum = quantiser.ToFloatingPoint(root_sum_quantised);
@@ -692,7 +692,6 @@ struct GPUHistMakerDevice {
void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat,
ObjInfo const* task, RegTree* p_tree,
collective::DeviceCommunicator* communicator,
HostDeviceVector<bst_node_t>* p_out_position) {
auto& tree = *p_tree;
// Process maximum 32 nodes at a time
@@ -703,7 +702,7 @@ struct GPUHistMakerDevice {
monitor.Stop("Reset");
monitor.Start("InitRoot");
driver.Push({ this->InitRoot(p_tree, communicator) });
driver.Push({this->InitRoot(p_tree)});
monitor.Stop("InitRoot");
// The set of leaves that can be expanded asynchronously
@@ -730,7 +729,7 @@ struct GPUHistMakerDevice {
monitor.Stop("UpdatePosition");
monitor.Start("BuildHist");
this->BuildHistLeftRight(filtered_expand_set, communicator, tree);
this->BuildHistLeftRight(filtered_expand_set, tree);
monitor.Stop("BuildHist");
monitor.Start("EvaluateSplits");
@@ -851,8 +850,7 @@ class GPUHistMaker : public TreeUpdater {
monitor_.Stop("InitData");
gpair->SetDevice(ctx_->gpu_id);
auto* communicator = collective::Communicator::GetDevice(ctx_->gpu_id);
maker->UpdateTree(gpair, p_fmat, task_, p_tree, communicator, p_out_position);
maker->UpdateTree(gpair, p_fmat, task_, p_tree, p_out_position);
}
bool UpdatePredictionCache(const DMatrix* data,