Refactor device communicator to make allreduce more flexible (#9295)
This commit is contained in:
81
src/collective/communicator-inl.cuh
Normal file
81
src/collective/communicator-inl.cuh
Normal file
@@ -0,0 +1,81 @@
|
||||
/**
|
||||
* Copyright 2023 by XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "communicator.h"
|
||||
#include "device_communicator.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
/**
|
||||
* @brief Reduce values from all processes and distribute the result back to all processes.
|
||||
* @param device ID of the device.
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param count Number of elements in the buffer.
|
||||
*/
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, std::int8_t *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt8, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, std::uint8_t *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt8, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, std::int32_t *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt32, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, std::uint32_t *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt32, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, std::int64_t *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt64, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, std::uint64_t *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, float *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kFloat, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, double *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gather variable-length values from all processes.
|
||||
* @param device ID of the device.
|
||||
* @param send_buffer Buffer storing the input data.
|
||||
* @param length_bytes Length in bytes of the input data.
|
||||
* @param segments Size of each segment.
|
||||
* @param receive_buffer Buffer storing the output data.
|
||||
*/
|
||||
inline void AllGatherV(int device, void const *send_buffer, size_t length_bytes,
|
||||
std::vector<size_t> *segments,
|
||||
dh::caching_device_vector<char> *receive_buffer) {
|
||||
Communicator::GetDevice(device)->AllGatherV(send_buffer, length_bytes, segments, receive_buffer);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Synchronize device operations.
|
||||
* @param device ID of the device.
|
||||
*/
|
||||
inline void Synchronize(int device) { Communicator::GetDevice(device)->Synchronize(); }
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@@ -17,32 +17,15 @@ class DeviceCommunicator {
|
||||
virtual ~DeviceCommunicator() = default;
|
||||
|
||||
/**
|
||||
* @brief Sum values from all processes and distribute the result back to all processes.
|
||||
* @brief Combines values from all processes and distributes the result back to all processes.
|
||||
*
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param count Number of elements in the buffer.
|
||||
* @param data_type Data type stored in the buffer.
|
||||
* @param op The operation to perform.
|
||||
*/
|
||||
virtual void AllReduceSum(float *send_receive_buffer, size_t count) = 0;
|
||||
|
||||
/**
|
||||
* @brief Sum values from all processes and distribute the result back to all processes.
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param count Number of elements in the buffer.
|
||||
*/
|
||||
virtual void AllReduceSum(double *send_receive_buffer, size_t count) = 0;
|
||||
|
||||
/**
|
||||
* @brief Sum values from all processes and distribute the result back to all processes.
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param count Number of elements in the buffer.
|
||||
*/
|
||||
virtual void AllReduceSum(int64_t *send_receive_buffer, size_t count) = 0;
|
||||
|
||||
/**
|
||||
* @brief Sum values from all processes and distribute the result back to all processes.
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param count Number of elements in the buffer.
|
||||
*/
|
||||
virtual void AllReduceSum(uint64_t *send_receive_buffer, size_t count) = 0;
|
||||
virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) = 0;
|
||||
|
||||
/**
|
||||
* @brief Gather variable-length values from all processes.
|
||||
|
||||
@@ -23,20 +23,18 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
||||
|
||||
~DeviceCommunicatorAdapter() override = default;
|
||||
|
||||
void AllReduceSum(float *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<collective::DataType::kFloat>(send_receive_buffer, count);
|
||||
}
|
||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) override {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
void AllReduceSum(double *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<collective::DataType::kDouble>(send_receive_buffer, count);
|
||||
}
|
||||
|
||||
void AllReduceSum(int64_t *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<collective::DataType::kInt64>(send_receive_buffer, count);
|
||||
}
|
||||
|
||||
void AllReduceSum(uint64_t *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<collective::DataType::kUInt64>(send_receive_buffer, count);
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
auto size = count * GetTypeSize(data_type);
|
||||
host_buffer_.reserve(size);
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
|
||||
communicator_->AllReduce(host_buffer_.data(), count, data_type, op);
|
||||
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
|
||||
@@ -77,20 +75,6 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
||||
}
|
||||
|
||||
private:
|
||||
template <collective::DataType data_type, typename T>
|
||||
void DoAllReduceSum(T *send_receive_buffer, size_t count) {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
auto size = count * sizeof(T);
|
||||
host_buffer_.reserve(size);
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
|
||||
communicator_->AllReduce(host_buffer_.data(), count, data_type, collective::Operation::kSum);
|
||||
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
int const device_ordinal_;
|
||||
Communicator *communicator_;
|
||||
/// Host buffer used to call communicator functions.
|
||||
|
||||
@@ -72,20 +72,18 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
}
|
||||
}
|
||||
|
||||
void AllReduceSum(float *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<ncclFloat>(send_receive_buffer, count);
|
||||
}
|
||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) override {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
void AllReduceSum(double *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<ncclDouble>(send_receive_buffer, count);
|
||||
}
|
||||
|
||||
void AllReduceSum(int64_t *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<ncclInt64>(send_receive_buffer, count);
|
||||
}
|
||||
|
||||
void AllReduceSum(uint64_t *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<ncclUint64>(send_receive_buffer, count);
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count,
|
||||
GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
|
||||
cuda_stream_));
|
||||
allreduce_bytes_ += count * GetTypeSize(data_type);
|
||||
allreduce_calls_ += 1;
|
||||
}
|
||||
|
||||
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
|
||||
@@ -162,17 +160,59 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
return id;
|
||||
}
|
||||
|
||||
template <ncclDataType_t data_type, typename T>
|
||||
void DoAllReduceSum(T *send_receive_buffer, size_t count) {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
return;
|
||||
static ncclDataType_t GetNcclDataType(DataType const &data_type) {
|
||||
ncclDataType_t result;
|
||||
switch (data_type) {
|
||||
case DataType::kInt8:
|
||||
result = ncclInt8;
|
||||
break;
|
||||
case DataType::kUInt8:
|
||||
result = ncclUint8;
|
||||
break;
|
||||
case DataType::kInt32:
|
||||
result = ncclInt32;
|
||||
break;
|
||||
case DataType::kUInt32:
|
||||
result = ncclUint32;
|
||||
break;
|
||||
case DataType::kInt64:
|
||||
result = ncclInt64;
|
||||
break;
|
||||
case DataType::kUInt64:
|
||||
result = ncclUint64;
|
||||
break;
|
||||
case DataType::kFloat:
|
||||
result = ncclFloat;
|
||||
break;
|
||||
case DataType::kDouble:
|
||||
result = ncclDouble;
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown data type.";
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, data_type, ncclSum,
|
||||
nccl_comm_, cuda_stream_));
|
||||
allreduce_bytes_ += count * sizeof(T);
|
||||
allreduce_calls_ += 1;
|
||||
static ncclRedOp_t GetNcclRedOp(Operation const &op) {
|
||||
ncclRedOp_t result;
|
||||
switch (op) {
|
||||
case Operation::kMax:
|
||||
result = ncclMax;
|
||||
break;
|
||||
case Operation::kMin:
|
||||
result = ncclMin;
|
||||
break;
|
||||
case Operation::kSum:
|
||||
result = ncclSum;
|
||||
break;
|
||||
case Operation::kBitwiseAND:
|
||||
case Operation::kBitwiseOR:
|
||||
case Operation::kBitwiseXOR:
|
||||
LOG(FATAL) << "Not implemented yet.";
|
||||
default:
|
||||
LOG(FATAL) << "Unknown reduce operation.";
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
int const device_ordinal_;
|
||||
|
||||
Reference in New Issue
Block a user