further cleanup of single process multi-GPU code (#4810)

* use subspan in gpu predictor instead of copying
* Revise `HostDeviceVector`
Author: Rong Ou
Date: 2019-08-30 02:27:23 -07:00
Committed by: Jiaming Yuan
Parent: 0184eb5d02
Commit: 733ed24dd9
12 changed files with 289 additions and 593 deletions
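
The first bullet is about handing the GPU predictor a view into device memory instead of materializing a copy. A minimal sketch of the idea, assuming the dh::ToSpan helper visible at the end of this diff wraps a thrust::device_vector; the function and variable names (BatchView, d_data, batch_offset, batch_size) are illustrative and not taken from the predictor:

#include <thrust/device_vector.h>
#include <xgboost/span.h>  // xgboost::common::Span (assumed header location)

// Return a zero-copy view of one batch rather than copying it into a new buffer.
// dh::ToSpan is the helper from device_helpers.cuh shown at the bottom of this diff.
xgboost::common::Span<float> BatchView(thrust::device_vector<float>& d_data,
                                       size_t batch_offset, size_t batch_size) {
  // ToSpan wraps the device memory without copying; subspan narrows the view.
  return dh::ToSpan(d_data).subspan(batch_offset, batch_size);
}

Span::subspan only adjusts a pointer and a length, so the predictor keeps reading the buffer the HostDeviceVector already owns.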


@@ -238,8 +238,7 @@ class MemoryLogger {
device_allocations.erase(itr);
}
};
std::map<int, DeviceStats>
stats_; // Map device ordinal to memory information
DeviceStats stats_;
std::mutex mutex_;
public:
@@ -249,8 +248,8 @@ public:
std::lock_guard<std::mutex> guard(mutex_);
int current_device;
safe_cuda(cudaGetDevice(&current_device));
stats_[current_device].RegisterAllocation(ptr, n);
CHECK_LE(stats_[current_device].peak_allocated_bytes, dh::TotalMemory(current_device));
stats_.RegisterAllocation(ptr, n);
CHECK_LE(stats_.peak_allocated_bytes, dh::TotalMemory(current_device));
}
void RegisterDeallocation(void *ptr, size_t n) {
if (!xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug))
@@ -258,19 +257,19 @@ public:
std::lock_guard<std::mutex> guard(mutex_);
int current_device;
safe_cuda(cudaGetDevice(&current_device));
stats_[current_device].RegisterDeallocation(ptr, n, current_device);
stats_.RegisterDeallocation(ptr, n, current_device);
}
void Log() {
if (!xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug))
return;
std::lock_guard<std::mutex> guard(mutex_);
for (const auto &kv : stats_) {
LOG(CONSOLE) << "======== Device " << kv.first << " Memory Allocations: "
<< " ========";
LOG(CONSOLE) << "Peak memory usage: "
<< kv.second.peak_allocated_bytes / 1000000 << "mb";
LOG(CONSOLE) << "Number of allocations: " << kv.second.num_allocations;
}
int current_device;
safe_cuda(cudaGetDevice(&current_device));
LOG(CONSOLE) << "======== Device " << current_device << " Memory Allocations: "
<< " ========";
LOG(CONSOLE) << "Peak memory usage: "
<< stats_.peak_allocated_bytes / 1000000 << "mb";
LOG(CONSOLE) << "Number of allocations: " << stats_.num_allocations;
}
};
};
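
These three hunks collapse MemoryLogger's per-device std::map into a single DeviceStats, since each process now drives exactly one GPU. Pieced together, the resulting shape is roughly the following; the DeviceStats body here is a condensed stand-in, not the full class:

#include <algorithm>
#include <cstddef>
#include <mutex>

class MemoryLogger {
  struct DeviceStats {
    size_t currently_allocated_bytes{0};
    size_t peak_allocated_bytes{0};
    size_t num_allocations{0};
    void RegisterAllocation(void* /*ptr*/, size_t n) {
      currently_allocated_bytes += n;
      peak_allocated_bytes = std::max(peak_allocated_bytes, currently_allocated_bytes);
      ++num_allocations;
    }
  };
  DeviceStats stats_;  // no longer keyed by device ordinal
  std::mutex mutex_;   // allocations may arrive from several host threads

 public:
  void RegisterAllocation(void* ptr, size_t n) {
    std::lock_guard<std::mutex> guard(mutex_);
    stats_.RegisterAllocation(ptr, n);
  }
};

Log() correspondingly queries cudaGetDevice once and prints a single peak instead of iterating over a map.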
@@ -940,10 +939,9 @@ class AllReducer {
size_t allreduce_calls_; // Keep statistics of the number of reduce calls
std::vector<size_t> host_data; // Used for all reduce on host
#ifdef XGBOOST_USE_NCCL
std::vector<ncclComm_t> comms;
std::vector<cudaStream_t> streams;
std::vector<int> device_ordinals; // device id from CUDA
std::vector<int> device_counts; // device count from CUDA
ncclComm_t comm;
cudaStream_t stream;
int device_ordinal;
ncclUniqueId id;
#endif
@@ -952,79 +950,28 @@ class AllReducer {
allreduce_calls_(0) {}
/**
* \brief If we are using a single GPU only
*/
bool IsSingleGPU() {
#ifdef XGBOOST_USE_NCCL
CHECK(device_counts.size() > 0) << "AllReducer not initialised.";
return device_counts.size() <= 1 && device_counts.at(0) == 1;
#else
return true;
#endif
}
/**
* \brief Initialise with the desired device ordinals for this communication
* \brief Initialise with the desired device ordinal for this communication
* group.
*
* \param device_ordinals The device ordinals.
* \param device_ordinal The device ordinal.
*/
void Init(const std::vector<int> &device_ordinals) {
void Init(int _device_ordinal) {
#ifdef XGBOOST_USE_NCCL
/** \brief this >monitor . init. */
this->device_ordinals = device_ordinals;
this->device_counts.resize(rabit::GetWorldSize());
this->comms.resize(device_ordinals.size());
this->streams.resize(device_ordinals.size());
this->id = GetUniqueId();
device_counts.at(rabit::GetRank()) = device_ordinals.size();
for (size_t i = 0; i < device_counts.size(); i++) {
int dev_count = device_counts.at(i);
rabit::Allreduce<rabit::op::Sum, int>(&dev_count, 1);
device_counts.at(i) = dev_count;
}
int nccl_rank = 0;
int nccl_rank_offset = std::accumulate(device_counts.begin(),
device_counts.begin() + rabit::GetRank(), 0);
int nccl_nranks = std::accumulate(device_counts.begin(),
device_counts.end(), 0);
nccl_rank += nccl_rank_offset;
GroupStart();
for (size_t i = 0; i < device_ordinals.size(); i++) {
int dev = device_ordinals.at(i);
dh::safe_cuda(cudaSetDevice(dev));
dh::safe_nccl(ncclCommInitRank(
&comms.at(i),
nccl_nranks, id,
nccl_rank));
nccl_rank++;
}
GroupEnd();
for (size_t i = 0; i < device_ordinals.size(); i++) {
safe_cuda(cudaSetDevice(device_ordinals.at(i)));
safe_cuda(cudaStreamCreate(&streams.at(i)));
}
device_ordinal = _device_ordinal;
id = GetUniqueId();
dh::safe_cuda(cudaSetDevice(device_ordinal));
dh::safe_nccl(ncclCommInitRank(&comm, rabit::GetWorldSize(), id, rabit::GetRank()));
safe_cuda(cudaStreamCreate(&stream));
initialised_ = true;
#else
CHECK_EQ(device_ordinals.size(), 1)
<< "XGBoost must be compiled with NCCL to use more than one GPU.";
#endif
}
~AllReducer() {
#ifdef XGBOOST_USE_NCCL
if (initialised_) {
for (auto &stream : streams) {
dh::safe_cuda(cudaStreamDestroy(stream));
}
for (auto &comm : comms) {
ncclCommDestroy(comm);
}
dh::safe_cuda(cudaStreamDestroy(stream));
ncclCommDestroy(comm);
}
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
LOG(CONSOLE) << "======== NCCL Statistics========";
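
With one GPU per process, Init boils down to a single ncclCommInitRank whose rank and size come straight from rabit. The GetUniqueId helper it calls first is not shown in this diff; a sketch consistent with how it is used here (rank 0 asks NCCL for an id, rabit broadcasts the bytes to every worker) would be:

ncclUniqueId GetUniqueId() {
  static const int kRootRank = 0;
  ncclUniqueId id;
  if (rabit::GetRank() == kRootRank) {
    dh::safe_nccl(ncclGetUniqueId(&id));  // only the root generates the id
  }
  // Every worker ends up with the same id, so one communicator can be formed.
  rabit::Broadcast(static_cast<void*>(&id), sizeof(ncclUniqueId), kRootRank);
  return id;
}

Once the id is shared, each worker calls ncclCommInitRank(&comm, rabit::GetWorldSize(), id, rabit::GetRank()) on its own device, exactly as the hunk above does.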
@@ -1035,20 +982,21 @@ class AllReducer {
}
/**
* \brief Use in exactly the same way as ncclGroupStart
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
* streams or comms.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void GroupStart() {
#ifdef XGBOOST_USE_NCCL
dh::safe_nccl(ncclGroupStart());
#endif
}
/**
* \brief Use in exactly the same way as ncclGroupEnd
*/
void GroupEnd() {
void AllReduceSum(const double *sendbuff, double *recvbuff, int count) {
#ifdef XGBOOST_USE_NCCL
dh::safe_nccl(ncclGroupEnd());
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal));
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclDouble, ncclSum, comm, stream));
allreduce_bytes_ += count * sizeof(double);
allreduce_calls_ += 1;
#endif
}
@@ -1056,51 +1004,18 @@ class AllReducer {
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
* streams or comms.
*
* \param communication_group_idx Zero-based index of the communication group.
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void AllReduceSum(int communication_group_idx, const double *sendbuff,
double *recvbuff, int count) {
void AllReduceSum(const float *sendbuff, float *recvbuff, int count) {
#ifdef XGBOOST_USE_NCCL
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinals.at(communication_group_idx)));
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclDouble, ncclSum,
comms.at(communication_group_idx),
streams.at(communication_group_idx)));
if(communication_group_idx == 0)
{
allreduce_bytes_ += count * sizeof(double);
allreduce_calls_ += 1;
}
#endif
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
* streams or comms.
*
* \param communication_group_idx Zero-based index of the communication group.
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void AllReduceSum(int communication_group_idx, const float *sendbuff,
float *recvbuff, int count) {
#ifdef XGBOOST_USE_NCCL
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinals.at(communication_group_idx)));
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclFloat, ncclSum,
comms.at(communication_group_idx),
streams.at(communication_group_idx)));
if(communication_group_idx == 0)
{
allreduce_bytes_ += count * sizeof(float);
allreduce_calls_ += 1;
}
dh::safe_cuda(cudaSetDevice(device_ordinal));
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclFloat, ncclSum, comm, stream));
allreduce_bytes_ += count * sizeof(float);
allreduce_calls_ += 1;
#endif
}
@@ -1109,21 +1024,17 @@ class AllReducer {
*
* \param count Number of.
*
* \param communication_group_idx Zero-based index of the communication group. \param sendbuff.
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of.
*/
void AllReduceSum(int communication_group_idx, const int64_t *sendbuff,
int64_t *recvbuff, int count) {
void AllReduceSum(const int64_t *sendbuff, int64_t *recvbuff, int count) {
#ifdef XGBOOST_USE_NCCL
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinals[communication_group_idx]));
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclInt64, ncclSum,
comms[communication_group_idx],
streams[communication_group_idx]));
dh::safe_cuda(cudaSetDevice(device_ordinal));
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclInt64, ncclSum, comm, stream));
#endif
}
@@ -1134,26 +1045,8 @@ class AllReducer {
*/
void Synchronize() {
#ifdef XGBOOST_USE_NCCL
for (size_t i = 0; i < device_ordinals.size(); i++) {
dh::safe_cuda(cudaSetDevice(device_ordinals[i]));
dh::safe_cuda(cudaStreamSynchronize(streams[i]));
}
#endif
};
/**
* \brief Synchronizes the device
*
* \param device_id Identifier for the device.
*/
void Synchronize(int device_id) {
#ifdef XGBOOST_USE_NCCL
SaveCudaContext([&]() {
dh::safe_cuda(cudaSetDevice(device_id));
int idx = std::find(device_ordinals.begin(), device_ordinals.end(), device_id) - device_ordinals.begin();
CHECK(idx < device_ordinals.size());
dh::safe_cuda(cudaStreamSynchronize(streams[idx]));
});
dh::safe_cuda(cudaSetDevice(device_ordinal));
dh::safe_cuda(cudaStreamSynchronize(stream));
#endif
};
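
Taken together, the reworked AllReducer mirrors a plain ncclAllReduce on the process's single device: no communication-group index, one comm, one stream. A hedged usage sketch (the reducer is assumed to live in namespace dh; the buffer name and size are illustrative):

#include <thrust/device_vector.h>

void AllReduceHistogram(int device_ordinal, size_t num_bins) {
  dh::AllReducer reducer;
  reducer.Init(device_ordinal);                        // one GPU per process

  thrust::device_vector<double> sums(num_bins, 0.0);   // per-worker partial sums
  // In-place sum across all rabit workers; NCCL permits sendbuff == recvbuff.
  reducer.AllReduceSum(sums.data().get(), sums.data().get(),
                       static_cast<int>(sums.size()));
  reducer.Synchronize();                               // wait on the NCCL stream
}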
@@ -1219,58 +1112,6 @@ class AllReducer {
}
};
/**
* \brief Executes some operation on each element of the input vector, using a
* single controlling thread for each element. In addition, passes the shard index
* into the function.
*
* \tparam T Generic type parameter.
* \tparam FunctionT Type of the function t.
* \param shards The shards.
* \param f The func_t to process.
*/
template <typename T, typename FunctionT>
void ExecuteIndexShards(std::vector<T> *shards, FunctionT f) {
SaveCudaContext{[&]() {
// Temporarily turn off dynamic so we have a guaranteed number of threads
bool dynamic = omp_get_dynamic();
omp_set_dynamic(false);
const long shards_size = static_cast<long>(shards->size());
#pragma omp parallel for schedule(static, 1) if (shards_size > 1) num_threads(shards_size)
for (long shard = 0; shard < shards_size; ++shard) {
f(shard, shards->at(shard));
}
omp_set_dynamic(dynamic);
}};
}
/**
* \brief Executes some operation on each element of the input vector, using a single controlling
* thread for each element, returns the sum of the results.
*
* \tparam ReduceT Type of the reduce t.
* \tparam T Generic type parameter.
* \tparam FunctionT Type of the function t.
* \param shards The shards.
* \param f The func_t to process.
*
* \return A reduce_t.
*/
template <typename ReduceT, typename ShardT, typename FunctionT>
ReduceT ReduceShards(std::vector<ShardT> *shards, FunctionT f) {
std::vector<ReduceT> sums(shards->size());
SaveCudaContext {
[&](){
#pragma omp parallel for schedule(static, 1) if (shards->size() > 1)
for (int shard = 0; shard < shards->size(); ++shard) {
sums[shard] = f(shards->at(shard));
}}
};
return std::accumulate(sums.begin(), sums.end(), ReduceT());
}
template <typename T,
typename IndexT = typename xgboost::common::Span<T>::index_type>
xgboost::common::Span<T> ToSpan(