Retry switching to per-thread default stream (#9416)

This commit is contained in:
Rong Ou
2023-07-25 16:09:12 -07:00
committed by GitHub
parent 54579da4d7
commit 7579905e18
9 changed files with 37 additions and 36 deletions

View File

@@ -44,16 +44,12 @@ NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sy
nccl_unique_id_ = GetUniqueId();
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_));
-dh::safe_cuda(cudaStreamCreate(&cuda_stream_));
}
NcclDeviceCommunicator::~NcclDeviceCommunicator() {
if (world_size_ == 1) {
return;
}
-if (cuda_stream_) {
-dh::safe_cuda(cudaStreamDestroy(cuda_stream_));
-}
if (nccl_comm_) {
dh::safe_nccl(ncclCommDestroy(nccl_comm_));
}
@@ -123,8 +119,8 @@ ncclRedOp_t GetNcclRedOp(Operation const &op) {
template <typename Func>
void RunBitwiseAllreduce(char *out_buffer, char const *device_buffer, Func func, int world_size,
-std::size_t size, cudaStream_t stream) {
-dh::LaunchN(size, stream, [=] __device__(std::size_t idx) {
+std::size_t size) {
+dh::LaunchN(size, [=] __device__(std::size_t idx) {
auto result = device_buffer[idx];
for (auto rank = 1; rank < world_size; rank++) {
result = func(result, device_buffer[rank * size + idx]);
@@ -142,25 +138,22 @@ void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::si
// First gather data from all the workers.
dh::safe_nccl(ncclAllGather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
-nccl_comm_, cuda_stream_));
+nccl_comm_, dh::DefaultStream()));
if (needs_sync_) {
-dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
+dh::DefaultStream().Sync();
}
// Then reduce locally.
auto *out_buffer = static_cast<char *>(send_receive_buffer);
switch (op) {
case Operation::kBitwiseAND:
-RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size_, size,
-cuda_stream_);
+RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size_, size);
break;
case Operation::kBitwiseOR:
-RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size_, size,
-cuda_stream_);
+RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size_, size);
break;
case Operation::kBitwiseXOR:
-RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size_, size,
-cuda_stream_);
+RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size_, size);
break;
default:
LOG(FATAL) << "Not a bitwise reduce operation.";
@@ -179,7 +172,7 @@ void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t co
} else {
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count,
GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
-cuda_stream_));
+dh::DefaultStream()));
}
allreduce_bytes_ += count * GetTypeSize(data_type);
allreduce_calls_ += 1;
@@ -206,7 +199,7 @@ void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_b
for (int32_t i = 0; i < world_size_; ++i) {
size_t as_bytes = segments->at(i);
dh::safe_nccl(ncclBroadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
-ncclChar, i, nccl_comm_, cuda_stream_));
+ncclChar, i, nccl_comm_, dh::DefaultStream()));
offset += as_bytes;
}
dh::safe_nccl(ncclGroupEnd());
@@ -217,7 +210,7 @@ void NcclDeviceCommunicator::Synchronize() {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
-dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
+dh::DefaultStream().Sync();
}
} // namespace collective

View File

@@ -77,7 +77,6 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
int const world_size_;
int const rank_;
ncclComm_t nccl_comm_{};
-cudaStream_t cuda_stream_{};
ncclUniqueId nccl_unique_id_{};
size_t allreduce_bytes_{0}; // Keep statistics of the number of bytes communicated.
size_t allreduce_calls_{0}; // Keep statistics of the number of reduce calls.