[REVIEW] Enable Multi-Node Multi-GPU functionality (#4095)
* Initial commit to support multi-node multi-gpu xgboost using dask * Fixed NCCL initialization by not ignoring the opg parameter. - it now crashes on NCCL initialization, but at least we're attempting it properly * At the root node, perform a rabit::Allreduce to get initial sum_gradient across workers * Synchronizing in a couple of more places. - now the workers don't go down, but just hang - no more "wild" values of gradients - probably needs syncing in more places * Added another missing max-allreduce operation inside BuildHistLeftRight * Removed unnecessary collective operations. * Simplified rabit::Allreduce() sync of gradient sums. * Removed unnecessary rabit syncs around ncclAllReduce. - this improves performance _significantly_ (7x faster for overall training, 20x faster for xgboost proper) * pulling in latest xgboost * removing changes to updater_quantile_hist.cc * changing use_nccl_opg initialization, removing unnecessary if statements * added definition for opaque ncclUniqueId struct to properly encapsulate GetUniqueId * placing struct defintion in guard to avoid duplicate code errors * addressing linting errors * removing * removing additional arguments to AllReduer initialization * removing distributed flag * making comm init symmetric * removing distributed flag * changing ncclCommInit to support multiple modalities * fix indenting * updating ncclCommInitRank block with necessary group calls * fix indenting * adding print statement, and updating accessor in vector * improving print statement to end-line * generalizing nccl_rank construction using rabit * assume device_ordinals is the same for every node * test, assume device_ordinals is identical for all nodes * test, assume device_ordinals is unique for all nodes * changing names of offset variable to be more descriptive, editing indenting * wrapping ncclUniqueId GetUniqueId() and aesthetic changes * adding synchronization, and tests for distributed * adding to tests * fixing broken #endif * fixing initialization of gpu histograms, correcting errors in tests * adding to contributors list * adding distributed tests to jenkins * fixing bad path in distributed test * debugging * adding kubernetes for distributed tests * adding proper import for OrderedDict * adding urllib3==1.22 to address ordered_dict import error * added sleep to allow workers to save their models for comparison * adding name to GPU contributors under docs
This commit is contained in:
committed by
Rory Mitchell
parent
9fefa2128d
commit
92b7577c62
@@ -23,6 +23,7 @@
|
||||
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
#include "nccl.h"
|
||||
#include "../common/io.h"
|
||||
#endif
|
||||
|
||||
// Uncomment to enable
|
||||
@@ -853,6 +854,8 @@ class AllReducer {
|
||||
std::vector<ncclComm_t> comms;
|
||||
std::vector<cudaStream_t> streams;
|
||||
std::vector<int> device_ordinals; // device id from CUDA
|
||||
std::vector<int> device_counts; // device count from CUDA
|
||||
ncclUniqueId id;
|
||||
#endif
|
||||
|
||||
public:
|
||||
@@ -872,14 +875,41 @@ class AllReducer {
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
/** \brief this >monitor . init. */
|
||||
this->device_ordinals = device_ordinals;
|
||||
comms.resize(device_ordinals.size());
|
||||
dh::safe_nccl(ncclCommInitAll(comms.data(),
|
||||
static_cast<int>(device_ordinals.size()),
|
||||
device_ordinals.data()));
|
||||
streams.resize(device_ordinals.size());
|
||||
this->device_counts.resize(rabit::GetWorldSize());
|
||||
this->comms.resize(device_ordinals.size());
|
||||
this->streams.resize(device_ordinals.size());
|
||||
this->id = GetUniqueId();
|
||||
|
||||
device_counts.at(rabit::GetRank()) = device_ordinals.size();
|
||||
for (size_t i = 0; i < device_counts.size(); i++) {
|
||||
int dev_count = device_counts.at(i);
|
||||
rabit::Allreduce<rabit::op::Sum, int>(&dev_count, 1);
|
||||
device_counts.at(i) = dev_count;
|
||||
}
|
||||
|
||||
int nccl_rank = 0;
|
||||
int nccl_rank_offset = std::accumulate(device_counts.begin(),
|
||||
device_counts.begin() + rabit::GetRank(), 0);
|
||||
int nccl_nranks = std::accumulate(device_counts.begin(),
|
||||
device_counts.end(), 0);
|
||||
nccl_rank += nccl_rank_offset;
|
||||
|
||||
GroupStart();
|
||||
for (size_t i = 0; i < device_ordinals.size(); i++) {
|
||||
safe_cuda(cudaSetDevice(device_ordinals[i]));
|
||||
safe_cuda(cudaStreamCreate(&streams[i]));
|
||||
int dev = device_ordinals.at(i);
|
||||
dh::safe_cuda(cudaSetDevice(dev));
|
||||
dh::safe_nccl(ncclCommInitRank(
|
||||
&comms.at(i),
|
||||
nccl_nranks, id,
|
||||
nccl_rank));
|
||||
|
||||
nccl_rank++;
|
||||
}
|
||||
GroupEnd();
|
||||
|
||||
for (size_t i = 0; i < device_ordinals.size(); i++) {
|
||||
safe_cuda(cudaSetDevice(device_ordinals.at(i)));
|
||||
safe_cuda(cudaStreamCreate(&streams.at(i)));
|
||||
}
|
||||
initialised_ = true;
|
||||
#else
|
||||
@@ -1010,7 +1040,30 @@ class AllReducer {
|
||||
dh::safe_cuda(cudaStreamSynchronize(streams[i]));
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
/**
|
||||
* \fn ncclUniqueId GetUniqueId()
|
||||
*
|
||||
* \brief Gets the Unique ID from NCCL to be used in setting up interprocess
|
||||
* communication
|
||||
*
|
||||
* \return the Unique ID
|
||||
*/
|
||||
ncclUniqueId GetUniqueId() {
|
||||
static const int RootRank = 0;
|
||||
ncclUniqueId id;
|
||||
if (rabit::GetRank() == RootRank) {
|
||||
dh::safe_nccl(ncclGetUniqueId(&id));
|
||||
}
|
||||
rabit::Broadcast(
|
||||
(void*)&id,
|
||||
(size_t)sizeof(ncclUniqueId),
|
||||
(int)RootRank);
|
||||
return id;
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
class SaveCudaContext {
|
||||
|
||||
@@ -628,10 +628,12 @@ struct DeviceShard {
|
||||
dh::safe_cuda(cudaMemcpy(split_candidates.data(), d_split_candidates.data(),
|
||||
split_candidates.size() * sizeof(DeviceSplitCandidate),
|
||||
cudaMemcpyDeviceToHost));
|
||||
|
||||
DeviceSplitCandidate best_split;
|
||||
for (auto candidate : split_candidates) {
|
||||
best_split.Update(candidate, param);
|
||||
}
|
||||
|
||||
return best_split;
|
||||
}
|
||||
|
||||
@@ -1049,7 +1051,8 @@ class GPUHistMakerSpecialised{
|
||||
}
|
||||
|
||||
void AllReduceHist(int nidx) {
|
||||
if (shards_.size() == 1) return;
|
||||
if (shards_.size() == 1 && !rabit::IsDistributed())
|
||||
return;
|
||||
monitor_.Start("AllReduce");
|
||||
|
||||
reducer_.GroupStart();
|
||||
@@ -1080,6 +1083,9 @@ class GPUHistMakerSpecialised{
|
||||
right_node_max_elements, shard->ridx_segments[nidx_right].Size());
|
||||
}
|
||||
|
||||
rabit::Allreduce<rabit::op::Max, size_t>(&left_node_max_elements, 1);
|
||||
rabit::Allreduce<rabit::op::Max, size_t>(&right_node_max_elements, 1);
|
||||
|
||||
auto build_hist_nidx = nidx_left;
|
||||
auto subtraction_trick_nidx = nidx_right;
|
||||
|
||||
@@ -1142,9 +1148,12 @@ class GPUHistMakerSpecialised{
|
||||
tmp_sums[i] = dh::SumReduction(
|
||||
shard->temp_memory, shard->gpair.Data(), shard->gpair.Size());
|
||||
});
|
||||
|
||||
GradientPair sum_gradient =
|
||||
std::accumulate(tmp_sums.begin(), tmp_sums.end(), GradientPair());
|
||||
|
||||
rabit::Allreduce<rabit::op::Sum>((GradientPair::ValueT*)&sum_gradient, 2);
|
||||
|
||||
// Generate root histogram
|
||||
dh::ExecuteIndexShards(
|
||||
&shards_,
|
||||
|
||||
Reference in New Issue
Block a user