Revert "Switch to per-thread default stream (#9396)" (#9413)

This reverts commit f7f673b00c.
This commit is contained in:
Jiaming Yuan
2023-07-25 03:03:28 +08:00
committed by GitHub
parent 1b657a5513
commit 3a9996173e
8 changed files with 35 additions and 25 deletions

View File

@@ -44,12 +44,16 @@ NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sy
nccl_unique_id_ = GetUniqueId();
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_));
dh::safe_cuda(cudaStreamCreate(&cuda_stream_));
}
NcclDeviceCommunicator::~NcclDeviceCommunicator() {
if (world_size_ == 1) {
return;
}
if (cuda_stream_) {
dh::safe_cuda(cudaStreamDestroy(cuda_stream_));
}
if (nccl_comm_) {
dh::safe_nccl(ncclCommDestroy(nccl_comm_));
}
@@ -119,8 +123,8 @@ ncclRedOp_t GetNcclRedOp(Operation const &op) {
template <typename Func>
void RunBitwiseAllreduce(char *out_buffer, char const *device_buffer, Func func, int world_size,
std::size_t size) {
dh::LaunchN(size, [=] __device__(std::size_t idx) {
std::size_t size, cudaStream_t stream) {
dh::LaunchN(size, stream, [=] __device__(std::size_t idx) {
auto result = device_buffer[idx];
for (auto rank = 1; rank < world_size; rank++) {
result = func(result, device_buffer[rank * size + idx]);
@@ -138,22 +142,25 @@ void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::si
// First gather data from all the workers.
dh::safe_nccl(ncclAllGather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
nccl_comm_, dh::DefaultStream()));
nccl_comm_, cuda_stream_));
if (needs_sync_) {
dh::DefaultStream().Sync();
dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
}
// Then reduce locally.
auto *out_buffer = static_cast<char *>(send_receive_buffer);
switch (op) {
case Operation::kBitwiseAND:
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size_, size);
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size_, size,
cuda_stream_);
break;
case Operation::kBitwiseOR:
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size_, size);
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size_, size,
cuda_stream_);
break;
case Operation::kBitwiseXOR:
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size_, size);
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size_, size,
cuda_stream_);
break;
default:
LOG(FATAL) << "Not a bitwise reduce operation.";
@@ -172,7 +179,7 @@ void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t co
} else {
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count,
GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
dh::DefaultStream()));
cuda_stream_));
}
allreduce_bytes_ += count * GetTypeSize(data_type);
allreduce_calls_ += 1;
@@ -199,7 +206,7 @@ void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_b
for (int32_t i = 0; i < world_size_; ++i) {
size_t as_bytes = segments->at(i);
dh::safe_nccl(ncclBroadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
ncclChar, i, nccl_comm_, dh::DefaultStream()));
ncclChar, i, nccl_comm_, cuda_stream_));
offset += as_bytes;
}
dh::safe_nccl(ncclGroupEnd());
@@ -210,7 +217,7 @@ void NcclDeviceCommunicator::Synchronize() {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::DefaultStream().Sync();
dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
}
} // namespace collective

View File

@@ -77,6 +77,7 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
int const world_size_;
int const rank_;
ncclComm_t nccl_comm_{};
cudaStream_t cuda_stream_{};
ncclUniqueId nccl_unique_id_{};
size_t allreduce_bytes_{0}; // Keep statistics of the number of bytes communicated.
size_t allreduce_calls_{0}; // Keep statistics of the number of reduce calls.