GPU performance logging/improvements (#3945)

- Improved GPU performance logging

- Only use one execute shards function

- Revert performance regression on multi-GPU

- Use threads to launch NCCL AllReduce
This commit is contained in:
Rory Mitchell
2018-11-29 14:36:51 +13:00
committed by GitHub
parent c5f92df475
commit a9d684db18
8 changed files with 127 additions and 102 deletions

View File

@@ -19,6 +19,7 @@
#include <sstream>
#include <string>
#include <vector>
#include "timer.h"
#ifdef XGBOOST_USE_NCCL
#include "nccl.h"
@@ -840,14 +841,17 @@ void Gather(int device_idx, T *out, const T *in, const int *instId, int nVals) {
*/
class AllReducer {
bool initialised;
bool initialised_;
bool debug_verbose_;
size_t allreduce_bytes_; // Keep statistics of the number of bytes communicated
size_t allreduce_calls_; // Keep statistics of the number of reduce calls
#ifdef XGBOOST_USE_NCCL
std::vector<ncclComm_t> comms;
std::vector<cudaStream_t> streams;
std::vector<int> device_ordinals;
#endif
public:
AllReducer() : initialised(false) {}
AllReducer() : initialised_(false),debug_verbose_(false) {}
/**
* \fn void Init(const std::vector<int> &device_ordinals)
@@ -858,8 +862,10 @@ class AllReducer {
* \param device_ordinals The device ordinals.
*/
void Init(const std::vector<int> &device_ordinals) {
void Init(const std::vector<int> &device_ordinals, bool debug_verbose) {
#ifdef XGBOOST_USE_NCCL
/** \brief this >monitor . init. */
this->debug_verbose_ = debug_verbose;
this->device_ordinals = device_ordinals;
comms.resize(device_ordinals.size());
dh::safe_nccl(ncclCommInitAll(comms.data(),
@@ -870,7 +876,7 @@ class AllReducer {
safe_cuda(cudaSetDevice(device_ordinals[i]));
safe_cuda(cudaStreamCreate(&streams[i]));
}
initialised = true;
initialised_ = true;
#else
CHECK_EQ(device_ordinals.size(), 1)
<< "XGBoost must be compiled with NCCL to use more than one GPU.";
@@ -878,7 +884,7 @@ class AllReducer {
}
~AllReducer() {
#ifdef XGBOOST_USE_NCCL
if (initialised) {
if (initialised_) {
for (auto &stream : streams) {
dh::safe_cuda(cudaStreamDestroy(stream));
}
@@ -886,6 +892,11 @@ class AllReducer {
ncclCommDestroy(comm);
}
}
if (debug_verbose_) {
LOG(CONSOLE) << "======== NCCL Statistics========";
LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
LOG(CONSOLE) << "AllReduce total MB communicated: " << allreduce_bytes_/1000000;
}
#endif
}
@@ -920,11 +931,16 @@ class AllReducer {
void AllReduceSum(int communication_group_idx, const double *sendbuff,
double *recvbuff, int count) {
#ifdef XGBOOST_USE_NCCL
CHECK(initialised);
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinals.at(communication_group_idx)));
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclDouble, ncclSum,
comms.at(communication_group_idx),
streams.at(communication_group_idx)));
if(communication_group_idx == 0)
{
allreduce_bytes_ += count * sizeof(double);
allreduce_calls_ += 1;
}
#endif
}
@@ -942,7 +958,7 @@ class AllReducer {
void AllReduceSum(int communication_group_idx, const int64_t *sendbuff,
int64_t *recvbuff, int count) {
#ifdef XGBOOST_USE_NCCL
CHECK(initialised);
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinals[communication_group_idx]));
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclInt64, ncclSum,
@@ -989,27 +1005,6 @@ class SaveCudaContext {
}
};
/**
* \brief Executes some operation on each element of the input vector, using a
* single controlling thread for each element.
*
* \tparam T Generic type parameter.
* \tparam FunctionT Type of the function t.
* \param shards The shards.
* \param f The func_t to process.
*/
template <typename T, typename FunctionT>
void ExecuteShards(std::vector<T> *shards, FunctionT f) {
SaveCudaContext {
[&](){
#pragma omp parallel for schedule(static, 1) if (shards->size() > 1)
for (int shard = 0; shard < shards->size(); ++shard) {
f(shards->at(shard));
}
}};
}
/**
* \brief Executes some operation on each element of the input vector, using a
* single controlling thread for each element. In addition, passes the shard index
@@ -1023,13 +1018,12 @@ void ExecuteShards(std::vector<T> *shards, FunctionT f) {
template <typename T, typename FunctionT>
void ExecuteIndexShards(std::vector<T> *shards, FunctionT f) {
SaveCudaContext {
[&](){
SaveCudaContext{[&]() {
#pragma omp parallel for schedule(static, 1) if (shards->size() > 1)
for (int shard = 0; shard < shards->size(); ++shard) {
f(shard, shards->at(shard));
}
}};
for (int shard = 0; shard < shards->size(); ++shard) {
f(shard, shards->at(shard));
}
}};
}
/**

View File

@@ -358,7 +358,7 @@ struct GPUSketcher {
});
// compute sketches for each shard
dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
shard->Init(batch, info);
shard->Sketch(batch, info);
});

View File

@@ -275,7 +275,7 @@ struct HostDeviceVectorImpl {
(end - begin) * sizeof(T),
cudaMemcpyDeviceToHost));
} else {
dh::ExecuteShards(&shards_, [&](DeviceShard& shard) {
dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
shard.ScatterFrom(begin.get());
});
}
@@ -288,7 +288,7 @@ struct HostDeviceVectorImpl {
data_h_.size() * sizeof(T),
cudaMemcpyHostToDevice));
} else {
dh::ExecuteShards(&shards_, [&](DeviceShard& shard) { shard.GatherTo(begin); });
dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) { shard.GatherTo(begin); });
}
}
@@ -296,7 +296,7 @@ struct HostDeviceVectorImpl {
if (perm_h_.CanWrite()) {
std::fill(data_h_.begin(), data_h_.end(), v);
} else {
dh::ExecuteShards(&shards_, [&](DeviceShard& shard) { shard.Fill(v); });
dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) { shard.Fill(v); });
}
}
@@ -323,7 +323,7 @@ struct HostDeviceVectorImpl {
if (perm_h_.CanWrite()) {
std::copy(other.begin(), other.end(), data_h_.begin());
} else {
dh::ExecuteShards(&shards_, [&](DeviceShard& shard) {
dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
shard.ScatterFrom(other.data());
});
}
@@ -334,7 +334,7 @@ struct HostDeviceVectorImpl {
if (perm_h_.CanWrite()) {
std::copy(other.begin(), other.end(), data_h_.begin());
} else {
dh::ExecuteShards(&shards_, [&](DeviceShard& shard) {
dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
shard.ScatterFrom(other.begin());
});
}
@@ -387,14 +387,14 @@ struct HostDeviceVectorImpl {
if (perm_h_.CanAccess(access)) { return; }
if (perm_h_.CanRead()) {
// data is present, just need to deny access to the device
dh::ExecuteShards(&shards_, [&](DeviceShard& shard) {
dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
shard.perm_d_.DenyComplementary(access);
});
perm_h_.Grant(access);
return;
}
if (data_h_.size() != size_d_) { data_h_.resize(size_d_); }
dh::ExecuteShards(&shards_, [&](DeviceShard& shard) {
dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
shard.LazySyncHost(access);
});
perm_h_.Grant(access);

View File

@@ -45,9 +45,13 @@ struct Timer {
*/
struct Monitor {
struct Statistics {
Timer timer;
size_t count{0};
};
bool debug_verbose = false;
std::string label = "";
std::map<std::string, Timer> timer_map;
std::map<std::string, Statistics> statistics_map;
Timer self_timer;
Monitor() { self_timer.Start(); }
@@ -56,35 +60,46 @@ struct Monitor {
if (!debug_verbose) return;
LOG(CONSOLE) << "======== Monitor: " << label << " ========";
for (auto &kv : timer_map) {
kv.second.PrintElapsed(kv.first);
for (auto &kv : statistics_map) {
LOG(CONSOLE) << kv.first << ": " << kv.second.timer.ElapsedSeconds()
<< "s, " << kv.second.count << " calls @ "
<< std::chrono::duration_cast<std::chrono::microseconds>(
kv.second.timer.elapsed / kv.second.count)
.count()
<< "us";
}
self_timer.Stop();
self_timer.PrintElapsed(label + " Lifetime");
}
void Init(std::string label, bool debug_verbose) {
this->debug_verbose = debug_verbose;
this->label = label;
}
void Start(const std::string &name) { timer_map[name].Start(); }
void Start(const std::string &name) { statistics_map[name].timer.Start(); }
void Start(const std::string &name, GPUSet devices) {
if (debug_verbose) {
#ifdef __CUDACC__
#include "device_helpers.cuh"
dh::SynchronizeNDevices(devices);
for (auto device : devices) {
cudaSetDevice(device);
cudaDeviceSynchronize();
}
#endif
}
timer_map[name].Start();
statistics_map[name].timer.Start();
}
void Stop(const std::string &name) {
statistics_map[name].timer.Stop();
statistics_map[name].count++;
}
void Stop(const std::string &name) { timer_map[name].Stop(); }
void Stop(const std::string &name, GPUSet devices) {
if (debug_verbose) {
#ifdef __CUDACC__
#include "device_helpers.cuh"
dh::SynchronizeNDevices(devices);
for (auto device : devices) {
cudaSetDevice(device);
cudaDeviceSynchronize();
}
#endif
}
timer_map[name].Stop();
this->Stop(name);
}
};
} // namespace common