[REVIEW] Enable Multi-Node Multi-GPU functionality (#4095)
* Initial commit to support multi-node multi-gpu xgboost using dask * Fixed NCCL initialization by not ignoring the opg parameter. - it now crashes on NCCL initialization, but at least we're attempting it properly * At the root node, perform a rabit::Allreduce to get initial sum_gradient across workers * Synchronizing in a couple of more places. - now the workers don't go down, but just hang - no more "wild" values of gradients - probably needs syncing in more places * Added another missing max-allreduce operation inside BuildHistLeftRight * Removed unnecessary collective operations. * Simplified rabit::Allreduce() sync of gradient sums. * Removed unnecessary rabit syncs around ncclAllReduce. - this improves performance _significantly_ (7x faster for overall training, 20x faster for xgboost proper) * pulling in latest xgboost * removing changes to updater_quantile_hist.cc * changing use_nccl_opg initialization, removing unnecessary if statements * added definition for opaque ncclUniqueId struct to properly encapsulate GetUniqueId * placing struct defintion in guard to avoid duplicate code errors * addressing linting errors * removing * removing additional arguments to AllReduer initialization * removing distributed flag * making comm init symmetric * removing distributed flag * changing ncclCommInit to support multiple modalities * fix indenting * updating ncclCommInitRank block with necessary group calls * fix indenting * adding print statement, and updating accessor in vector * improving print statement to end-line * generalizing nccl_rank construction using rabit * assume device_ordinals is the same for every node * test, assume device_ordinals is identical for all nodes * test, assume device_ordinals is unique for all nodes * changing names of offset variable to be more descriptive, editing indenting * wrapping ncclUniqueId GetUniqueId() and aesthetic changes * adding synchronization, and tests for distributed * adding to tests * fixing broken #endif * fixing initialization of gpu histograms, correcting errors in tests * adding to contributors list * adding distributed tests to jenkins * fixing bad path in distributed test * debugging * adding kubernetes for distributed tests * adding proper import for OrderedDict * adding urllib3==1.22 to address ordered_dict import error * added sleep to allow workers to save their models for comparison * adding name to GPU contributors under docs
This commit is contained in:
parent
9fefa2128d
commit
92b7577c62
@ -85,5 +85,6 @@ List of Contributors
|
||||
* [Andrew Thia](https://github.com/BlueTea88)
|
||||
- Andrew Thia implemented feature interaction constraints
|
||||
* [Wei Tian](https://github.com/weitian)
|
||||
* [Chen Qin] (https://github.com/chenqin)
|
||||
* [Chen Qin](https://github.com/chenqin)
|
||||
* [Sam Wilkinson](https://samwilkinson.io)
|
||||
* [Matthew Jones](https://github.com/mt-jones)
|
||||
|
||||
@ -208,6 +208,7 @@ Many thanks to the following contributors (alphabetical order):
|
||||
* Andrey Adinets
|
||||
* Jiaming Yuan
|
||||
* Jonathan C. McKinney
|
||||
* Matthew Jones
|
||||
* Philip Cho
|
||||
* Rory Mitchell
|
||||
* Shankara Rao Thejaswi Nanditale
|
||||
|
||||
@ -23,6 +23,7 @@
|
||||
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
#include "nccl.h"
|
||||
#include "../common/io.h"
|
||||
#endif
|
||||
|
||||
// Uncomment to enable
|
||||
@ -853,6 +854,8 @@ class AllReducer {
|
||||
std::vector<ncclComm_t> comms;
|
||||
std::vector<cudaStream_t> streams;
|
||||
std::vector<int> device_ordinals; // device id from CUDA
|
||||
std::vector<int> device_counts; // device count from CUDA
|
||||
ncclUniqueId id;
|
||||
#endif
|
||||
|
||||
public:
|
||||
@ -872,14 +875,41 @@ class AllReducer {
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
/** \brief this >monitor . init. */
|
||||
this->device_ordinals = device_ordinals;
|
||||
comms.resize(device_ordinals.size());
|
||||
dh::safe_nccl(ncclCommInitAll(comms.data(),
|
||||
static_cast<int>(device_ordinals.size()),
|
||||
device_ordinals.data()));
|
||||
streams.resize(device_ordinals.size());
|
||||
this->device_counts.resize(rabit::GetWorldSize());
|
||||
this->comms.resize(device_ordinals.size());
|
||||
this->streams.resize(device_ordinals.size());
|
||||
this->id = GetUniqueId();
|
||||
|
||||
device_counts.at(rabit::GetRank()) = device_ordinals.size();
|
||||
for (size_t i = 0; i < device_counts.size(); i++) {
|
||||
int dev_count = device_counts.at(i);
|
||||
rabit::Allreduce<rabit::op::Sum, int>(&dev_count, 1);
|
||||
device_counts.at(i) = dev_count;
|
||||
}
|
||||
|
||||
int nccl_rank = 0;
|
||||
int nccl_rank_offset = std::accumulate(device_counts.begin(),
|
||||
device_counts.begin() + rabit::GetRank(), 0);
|
||||
int nccl_nranks = std::accumulate(device_counts.begin(),
|
||||
device_counts.end(), 0);
|
||||
nccl_rank += nccl_rank_offset;
|
||||
|
||||
GroupStart();
|
||||
for (size_t i = 0; i < device_ordinals.size(); i++) {
|
||||
safe_cuda(cudaSetDevice(device_ordinals[i]));
|
||||
safe_cuda(cudaStreamCreate(&streams[i]));
|
||||
int dev = device_ordinals.at(i);
|
||||
dh::safe_cuda(cudaSetDevice(dev));
|
||||
dh::safe_nccl(ncclCommInitRank(
|
||||
&comms.at(i),
|
||||
nccl_nranks, id,
|
||||
nccl_rank));
|
||||
|
||||
nccl_rank++;
|
||||
}
|
||||
GroupEnd();
|
||||
|
||||
for (size_t i = 0; i < device_ordinals.size(); i++) {
|
||||
safe_cuda(cudaSetDevice(device_ordinals.at(i)));
|
||||
safe_cuda(cudaStreamCreate(&streams.at(i)));
|
||||
}
|
||||
initialised_ = true;
|
||||
#else
|
||||
@ -1010,7 +1040,30 @@ class AllReducer {
|
||||
dh::safe_cuda(cudaStreamSynchronize(streams[i]));
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
/**
|
||||
* \fn ncclUniqueId GetUniqueId()
|
||||
*
|
||||
* \brief Gets the Unique ID from NCCL to be used in setting up interprocess
|
||||
* communication
|
||||
*
|
||||
* \return the Unique ID
|
||||
*/
|
||||
ncclUniqueId GetUniqueId() {
|
||||
static const int RootRank = 0;
|
||||
ncclUniqueId id;
|
||||
if (rabit::GetRank() == RootRank) {
|
||||
dh::safe_nccl(ncclGetUniqueId(&id));
|
||||
}
|
||||
rabit::Broadcast(
|
||||
(void*)&id,
|
||||
(size_t)sizeof(ncclUniqueId),
|
||||
(int)RootRank);
|
||||
return id;
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
class SaveCudaContext {
|
||||
|
||||
@ -628,10 +628,12 @@ struct DeviceShard {
|
||||
dh::safe_cuda(cudaMemcpy(split_candidates.data(), d_split_candidates.data(),
|
||||
split_candidates.size() * sizeof(DeviceSplitCandidate),
|
||||
cudaMemcpyDeviceToHost));
|
||||
|
||||
DeviceSplitCandidate best_split;
|
||||
for (auto candidate : split_candidates) {
|
||||
best_split.Update(candidate, param);
|
||||
}
|
||||
|
||||
return best_split;
|
||||
}
|
||||
|
||||
@ -1049,7 +1051,8 @@ class GPUHistMakerSpecialised{
|
||||
}
|
||||
|
||||
void AllReduceHist(int nidx) {
|
||||
if (shards_.size() == 1) return;
|
||||
if (shards_.size() == 1 && !rabit::IsDistributed())
|
||||
return;
|
||||
monitor_.Start("AllReduce");
|
||||
|
||||
reducer_.GroupStart();
|
||||
@ -1080,6 +1083,9 @@ class GPUHistMakerSpecialised{
|
||||
right_node_max_elements, shard->ridx_segments[nidx_right].Size());
|
||||
}
|
||||
|
||||
rabit::Allreduce<rabit::op::Max, size_t>(&left_node_max_elements, 1);
|
||||
rabit::Allreduce<rabit::op::Max, size_t>(&right_node_max_elements, 1);
|
||||
|
||||
auto build_hist_nidx = nidx_left;
|
||||
auto subtraction_trick_nidx = nidx_right;
|
||||
|
||||
@ -1142,9 +1148,12 @@ class GPUHistMakerSpecialised{
|
||||
tmp_sums[i] = dh::SumReduction(
|
||||
shard->temp_memory, shard->gpair.Data(), shard->gpair.Size());
|
||||
});
|
||||
|
||||
GradientPair sum_gradient =
|
||||
std::accumulate(tmp_sums.begin(), tmp_sums.end(), GradientPair());
|
||||
|
||||
rabit::Allreduce<rabit::op::Sum>((GradientPair::ValueT*)&sum_gradient, 2);
|
||||
|
||||
// Generate root histogram
|
||||
dh::ExecuteIndexShards(
|
||||
&shards_,
|
||||
|
||||
@ -35,7 +35,7 @@ ENV CPP=/opt/rh/devtoolset-2/root/usr/bin/cpp
|
||||
|
||||
# Install Python packages
|
||||
RUN \
|
||||
pip install numpy pytest scipy scikit-learn wheel
|
||||
pip install numpy pytest scipy scikit-learn wheel kubernetes urllib3==1.22
|
||||
|
||||
ENV GOSU_VERSION 1.10
|
||||
|
||||
|
||||
@ -6,3 +6,6 @@ python setup.py install --user
|
||||
cd ..
|
||||
pytest -v -s --fulltrace -m "(not slow) and mgpu" tests/python-gpu
|
||||
./testxgboost --gtest_filter=*.MGPU_*
|
||||
|
||||
cd tests/distributed-gpu
|
||||
./runtests-gpu.sh
|
||||
19
tests/distributed-gpu/runtests-gpu.sh
Executable file
19
tests/distributed-gpu/runtests-gpu.sh
Executable file
@ -0,0 +1,19 @@
|
||||
#!/bin/bash
|
||||
|
||||
rm -f *.model*
|
||||
|
||||
echo -e "\n ====== 1. Basic distributed-gpu test with Python: 4 workers; 1 GPU per worker ====== \n"
|
||||
PYTHONPATH=../../python-package/ python ../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=4 \
|
||||
python test_gpu_basic_1x4.py
|
||||
|
||||
echo -e "\n ====== 2. Basic distributed-gpu test with Python: 2 workers; 2 GPUs per worker ====== \n"
|
||||
PYTHONPATH=../../python-package/ python ../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=2 \
|
||||
python test_gpu_basic_2x2.py
|
||||
|
||||
echo -e "\n ====== 3. Basic distributed-gpu test with Python: 2 workers; Rank 0: 1 GPU, Rank 1: 3 GPUs ====== \n"
|
||||
PYTHONPATH=../../python-package/ python ../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=2 \
|
||||
python test_gpu_basic_asym.py
|
||||
|
||||
echo -e "\n ====== 4. Basic distributed-gpu test with Python: 1 worker; 4 GPUs per worker ====== \n"
|
||||
PYTHONPATH=../../python-package/ python ../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=1 \
|
||||
python test_gpu_basic_4x1.py
|
||||
51
tests/distributed-gpu/test_gpu_basic_1x4.py
Normal file
51
tests/distributed-gpu/test_gpu_basic_1x4.py
Normal file
@ -0,0 +1,51 @@
|
||||
#!/usr/bin/python
|
||||
import xgboost as xgb
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
|
||||
# Always call this before using distributed module
|
||||
xgb.rabit.init()
|
||||
rank = xgb.rabit.get_rank()
|
||||
world = xgb.rabit.get_world_size()
|
||||
|
||||
# Load file, file will be automatically sharded in distributed mode.
|
||||
dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
|
||||
dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')
|
||||
|
||||
# Specify parameters via map, definition are same as c++ version
|
||||
param = {'n_gpus': 1, 'gpu_id': rank, 'tree_method': 'gpu_hist', 'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic' }
|
||||
|
||||
# Specify validations set to watch performance
|
||||
watchlist = [(dtest,'eval'), (dtrain,'train')]
|
||||
num_round = 20
|
||||
|
||||
# Run training, all the features in training API is available.
|
||||
# Currently, this script only support calling train once for fault recovery purpose.
|
||||
bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)
|
||||
|
||||
# Have each worker save its model
|
||||
model_name = "test.model.1x4." + str(rank)
|
||||
bst.dump_model(model_name, with_stats=True); time.sleep(2)
|
||||
xgb.rabit.tracker_print("Finished training\n")
|
||||
|
||||
fail = False
|
||||
if (rank == 0):
|
||||
for i in range(0, world):
|
||||
model_name_root = "test.model.1x4." + str(i)
|
||||
for j in range(0, world):
|
||||
if i != j:
|
||||
with open(model_name_root, 'r') as model_root:
|
||||
model_name_rank = "test.model.1x4." + str(j)
|
||||
with open(model_name_rank, 'r') as model_rank:
|
||||
diff = set(model_root).difference(model_rank)
|
||||
if len(diff) != 0:
|
||||
fail = True
|
||||
xgb.rabit.finalize()
|
||||
raise Exception('Worker models diverged: test.model.1x4.{} differs from test.model.1x4.{}'.format(i, j))
|
||||
|
||||
if (rank != 0) and (fail):
|
||||
xgb.rabit.finalize()
|
||||
|
||||
# Notify the tracker all training has been successful
|
||||
# This is only needed in distributed training.
|
||||
xgb.rabit.finalize()
|
||||
51
tests/distributed-gpu/test_gpu_basic_2x2.py
Normal file
51
tests/distributed-gpu/test_gpu_basic_2x2.py
Normal file
@ -0,0 +1,51 @@
|
||||
#!/usr/bin/python
|
||||
import xgboost as xgb
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
|
||||
# Always call this before using distributed module
|
||||
xgb.rabit.init()
|
||||
rank = xgb.rabit.get_rank()
|
||||
world = xgb.rabit.get_world_size()
|
||||
|
||||
# Load file, file will be automatically sharded in distributed mode.
|
||||
dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
|
||||
dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')
|
||||
|
||||
# Specify parameters via map, definition are same as c++ version
|
||||
param = {'n_gpus': 2, 'gpu_id': 2*rank, 'tree_method': 'gpu_hist', 'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic' }
|
||||
|
||||
# Specify validations set to watch performance
|
||||
watchlist = [(dtest,'eval'), (dtrain,'train')]
|
||||
num_round = 20
|
||||
|
||||
# Run training, all the features in training API is available.
|
||||
# Currently, this script only support calling train once for fault recovery purpose.
|
||||
bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)
|
||||
|
||||
# Have each worker save its model
|
||||
model_name = "test.model.2x2." + str(rank)
|
||||
bst.dump_model(model_name, with_stats=True); time.sleep(2)
|
||||
xgb.rabit.tracker_print("Finished training\n")
|
||||
|
||||
fail = False
|
||||
if (rank == 0):
|
||||
for i in range(0, world):
|
||||
model_name_root = "test.model.2x2." + str(i)
|
||||
for j in range(0, world):
|
||||
if i != j:
|
||||
with open(model_name_root, 'r') as model_root:
|
||||
model_name_rank = "test.model.2x2." + str(j)
|
||||
with open(model_name_rank, 'r') as model_rank:
|
||||
diff = set(model_root).difference(model_rank)
|
||||
if len(diff) != 0:
|
||||
fail = True
|
||||
xgb.rabit.finalize()
|
||||
raise Exception('Worker models diverged: test.model.2x2.{} differs from test.model.2x2.{}'.format(i, j))
|
||||
|
||||
if (rank != 0) and (fail):
|
||||
xgb.rabit.finalize()
|
||||
|
||||
# Notify the tracker all training has been successful
|
||||
# This is only needed in distributed training.
|
||||
xgb.rabit.finalize()
|
||||
34
tests/distributed-gpu/test_gpu_basic_4x1.py
Normal file
34
tests/distributed-gpu/test_gpu_basic_4x1.py
Normal file
@ -0,0 +1,34 @@
|
||||
#!/usr/bin/python
|
||||
import xgboost as xgb
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
|
||||
# Always call this before using distributed module
|
||||
xgb.rabit.init()
|
||||
rank = xgb.rabit.get_rank()
|
||||
world = xgb.rabit.get_world_size()
|
||||
|
||||
# Load file, file will be automatically sharded in distributed mode.
|
||||
dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
|
||||
dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')
|
||||
|
||||
# Specify parameters via map, definition are same as c++ version
|
||||
param = {'n_gpus': 4, 'gpu_id': rank, 'tree_method': 'gpu_hist', 'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic' }
|
||||
|
||||
# Specify validations set to watch performance
|
||||
watchlist = [(dtest,'eval'), (dtrain,'train')]
|
||||
num_round = 20
|
||||
|
||||
# Run training, all the features in training API is available.
|
||||
# Currently, this script only support calling train once for fault recovery purpose.
|
||||
bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)
|
||||
|
||||
# Have root save its model
|
||||
if(rank == 0):
|
||||
model_name = "test.model.4x1." + str(rank)
|
||||
bst.dump_model(model_name, with_stats=True)
|
||||
xgb.rabit.tracker_print("Finished training\n")
|
||||
|
||||
# Notify the tracker all training has been successful
|
||||
# This is only needed in distributed training.
|
||||
xgb.rabit.finalize()
|
||||
54
tests/distributed-gpu/test_gpu_basic_asym.py
Normal file
54
tests/distributed-gpu/test_gpu_basic_asym.py
Normal file
@ -0,0 +1,54 @@
|
||||
#!/usr/bin/python
|
||||
import xgboost as xgb
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
|
||||
# Always call this before using distributed module
|
||||
xgb.rabit.init()
|
||||
rank = xgb.rabit.get_rank()
|
||||
world = xgb.rabit.get_world_size()
|
||||
|
||||
# Load file, file will be automatically sharded in distributed mode.
|
||||
dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
|
||||
dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')
|
||||
|
||||
# Specify parameters via map, definition are same as c++ version
|
||||
if rank == 0:
|
||||
param = {'n_gpus': 1, 'gpu_id': rank, 'tree_method': 'gpu_hist', 'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic' }
|
||||
else:
|
||||
param = {'n_gpus': 3, 'gpu_id': rank, 'tree_method': 'gpu_hist', 'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic' }
|
||||
|
||||
# Specify validations set to watch performance
|
||||
watchlist = [(dtest,'eval'), (dtrain,'train')]
|
||||
num_round = 20
|
||||
|
||||
# Run training, all the features in training API is available.
|
||||
# Currently, this script only support calling train once for fault recovery purpose.
|
||||
bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)
|
||||
|
||||
# Have each worker save its model
|
||||
model_name = "test.model.asym." + str(rank)
|
||||
bst.dump_model(model_name, with_stats=True); time.sleep(2)
|
||||
xgb.rabit.tracker_print("Finished training\n")
|
||||
|
||||
fail = False
|
||||
if (rank == 0):
|
||||
for i in range(0, world):
|
||||
model_name_root = "test.model.asym." + str(i)
|
||||
for j in range(0, world):
|
||||
if i != j:
|
||||
with open(model_name_root, 'r') as model_root:
|
||||
model_name_rank = "test.model.asym." + str(j)
|
||||
with open(model_name_rank, 'r') as model_rank:
|
||||
diff = set(model_root).difference(model_rank)
|
||||
if len(diff) != 0:
|
||||
fail = True
|
||||
xgb.rabit.finalize()
|
||||
raise Exception('Worker models diverged: test.model.asym.{} differs from test.model.asym.{}'.format(i, j))
|
||||
|
||||
if (rank != 0) and (fail):
|
||||
xgb.rabit.finalize()
|
||||
|
||||
# Notify the tracker all training has been successful
|
||||
# This is only needed in distributed training.
|
||||
xgb.rabit.finalize()
|
||||
Loading…
x
Reference in New Issue
Block a user