Switch to per-thread default stream (#9396)
This commit is contained in:
@@ -44,16 +44,12 @@ NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sy
|
||||
nccl_unique_id_ = GetUniqueId();
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_));
|
||||
dh::safe_cuda(cudaStreamCreate(&cuda_stream_));
|
||||
}
|
||||
|
||||
NcclDeviceCommunicator::~NcclDeviceCommunicator() {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
if (cuda_stream_) {
|
||||
dh::safe_cuda(cudaStreamDestroy(cuda_stream_));
|
||||
}
|
||||
if (nccl_comm_) {
|
||||
dh::safe_nccl(ncclCommDestroy(nccl_comm_));
|
||||
}
|
||||
@@ -123,8 +119,8 @@ ncclRedOp_t GetNcclRedOp(Operation const &op) {
|
||||
|
||||
template <typename Func>
|
||||
void RunBitwiseAllreduce(char *out_buffer, char const *device_buffer, Func func, int world_size,
|
||||
std::size_t size, cudaStream_t stream) {
|
||||
dh::LaunchN(size, stream, [=] __device__(std::size_t idx) {
|
||||
std::size_t size) {
|
||||
dh::LaunchN(size, [=] __device__(std::size_t idx) {
|
||||
auto result = device_buffer[idx];
|
||||
for (auto rank = 1; rank < world_size; rank++) {
|
||||
result = func(result, device_buffer[rank * size + idx]);
|
||||
@@ -142,25 +138,22 @@ void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::si
|
||||
|
||||
// First gather data from all the workers.
|
||||
dh::safe_nccl(ncclAllGather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
|
||||
nccl_comm_, cuda_stream_));
|
||||
nccl_comm_, dh::DefaultStream()));
|
||||
if (needs_sync_) {
|
||||
dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
|
||||
dh::DefaultStream().Sync();
|
||||
}
|
||||
|
||||
// Then reduce locally.
|
||||
auto *out_buffer = static_cast<char *>(send_receive_buffer);
|
||||
switch (op) {
|
||||
case Operation::kBitwiseAND:
|
||||
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size_, size,
|
||||
cuda_stream_);
|
||||
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size_, size);
|
||||
break;
|
||||
case Operation::kBitwiseOR:
|
||||
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size_, size,
|
||||
cuda_stream_);
|
||||
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size_, size);
|
||||
break;
|
||||
case Operation::kBitwiseXOR:
|
||||
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size_, size,
|
||||
cuda_stream_);
|
||||
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size_, size);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Not a bitwise reduce operation.";
|
||||
@@ -179,7 +172,7 @@ void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t co
|
||||
} else {
|
||||
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count,
|
||||
GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
|
||||
cuda_stream_));
|
||||
dh::DefaultStream()));
|
||||
}
|
||||
allreduce_bytes_ += count * GetTypeSize(data_type);
|
||||
allreduce_calls_ += 1;
|
||||
@@ -206,7 +199,7 @@ void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_b
|
||||
for (int32_t i = 0; i < world_size_; ++i) {
|
||||
size_t as_bytes = segments->at(i);
|
||||
dh::safe_nccl(ncclBroadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
|
||||
ncclChar, i, nccl_comm_, cuda_stream_));
|
||||
ncclChar, i, nccl_comm_, dh::DefaultStream()));
|
||||
offset += as_bytes;
|
||||
}
|
||||
dh::safe_nccl(ncclGroupEnd());
|
||||
@@ -217,7 +210,7 @@ void NcclDeviceCommunicator::Synchronize() {
|
||||
return;
|
||||
}
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
|
||||
dh::DefaultStream().Sync();
|
||||
}
|
||||
|
||||
} // namespace collective
|
||||
|
||||
@@ -77,7 +77,6 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
int const world_size_;
|
||||
int const rank_;
|
||||
ncclComm_t nccl_comm_{};
|
||||
cudaStream_t cuda_stream_{};
|
||||
ncclUniqueId nccl_unique_id_{};
|
||||
size_t allreduce_bytes_{0}; // Keep statistics of the number of bytes communicated.
|
||||
size_t allreduce_calls_{0}; // Keep statistics of the number of reduce calls.
|
||||
|
||||
Reference in New Issue
Block a user