Refactor device communicator to make allreduce more flexible (#9295)

This commit is contained in:
Rong Ou
2023-06-13 12:53:03 -07:00
committed by GitHub
parent c2f0486d37
commit e70810be8a
11 changed files with 190 additions and 106 deletions

View File

@@ -72,20 +72,18 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
}
}
void AllReduceSum(float *send_receive_buffer, size_t count) override {
DoAllReduceSum<ncclFloat>(send_receive_buffer, count);
}
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override {
if (communicator_->GetWorldSize() == 1) {
return;
}
void AllReduceSum(double *send_receive_buffer, size_t count) override {
DoAllReduceSum<ncclDouble>(send_receive_buffer, count);
}
void AllReduceSum(int64_t *send_receive_buffer, size_t count) override {
DoAllReduceSum<ncclInt64>(send_receive_buffer, count);
}
void AllReduceSum(uint64_t *send_receive_buffer, size_t count) override {
DoAllReduceSum<ncclUint64>(send_receive_buffer, count);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count,
GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
cuda_stream_));
allreduce_bytes_ += count * GetTypeSize(data_type);
allreduce_calls_ += 1;
}
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
@@ -162,17 +160,59 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
return id;
}
template <ncclDataType_t data_type, typename T>
void DoAllReduceSum(T *send_receive_buffer, size_t count) {
if (communicator_->GetWorldSize() == 1) {
return;
static ncclDataType_t GetNcclDataType(DataType const &data_type) {
ncclDataType_t result;
switch (data_type) {
case DataType::kInt8:
result = ncclInt8;
break;
case DataType::kUInt8:
result = ncclUint8;
break;
case DataType::kInt32:
result = ncclInt32;
break;
case DataType::kUInt32:
result = ncclUint32;
break;
case DataType::kInt64:
result = ncclInt64;
break;
case DataType::kUInt64:
result = ncclUint64;
break;
case DataType::kFloat:
result = ncclFloat;
break;
case DataType::kDouble:
result = ncclDouble;
break;
default:
LOG(FATAL) << "Unknown data type.";
}
return result;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, data_type, ncclSum,
nccl_comm_, cuda_stream_));
allreduce_bytes_ += count * sizeof(T);
allreduce_calls_ += 1;
static ncclRedOp_t GetNcclRedOp(Operation const &op) {
ncclRedOp_t result;
switch (op) {
case Operation::kMax:
result = ncclMax;
break;
case Operation::kMin:
result = ncclMin;
break;
case Operation::kSum:
result = ncclSum;
break;
case Operation::kBitwiseAND:
case Operation::kBitwiseOR:
case Operation::kBitwiseXOR:
LOG(FATAL) << "Not implemented yet.";
default:
LOG(FATAL) << "Unknown reduce operation.";
}
return result;
}
int const device_ordinal_;