Support column split in GPU evaluate splits (#9511)

This commit is contained in:
Rong Ou
2023-08-23 01:33:43 -07:00
committed by GitHub
parent 8c10af45a0
commit 6103dca0bb
11 changed files with 240 additions and 113 deletions

View File

@@ -57,6 +57,20 @@ inline void AllReduce(int device, double *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
}
/**
* @brief Gather values from all all processes.
*
* This assumes all ranks have the same size.
*
* @param send_buffer Buffer storing the data to be sent.
* @param receive_buffer Buffer storing the gathered data.
* @param send_size Size of the sent data in bytes.
*/
inline void AllGather(int device, void const *send_buffer, void *receive_buffer,
std::size_t send_size) {
Communicator::GetDevice(device)->AllGather(send_buffer, receive_buffer, send_size);
}
/**
* @brief Gather variable-length values from all processes.
* @param device ID of the device.

View File

@@ -27,6 +27,17 @@ class DeviceCommunicator {
virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) = 0;
/**
* @brief Gather values from all all processes.
*
* This assumes all ranks have the same size.
*
* @param send_buffer Buffer storing the data to be sent.
* @param receive_buffer Buffer storing the gathered data.
* @param send_size Size of the sent data in bytes.
*/
virtual void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) = 0;
/**
* @brief Gather variable-length values from all processes.
* @param send_buffer Buffer storing the input data.

View File

@@ -28,12 +28,26 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
dh::safe_cuda(cudaSetDevice(device_ordinal_));
auto size = count * GetTypeSize(data_type);
host_buffer_.reserve(size);
host_buffer_.resize(size);
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
Allreduce(host_buffer_.data(), count, data_type, op);
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
}
void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) override {
if (world_size_ == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
host_buffer_.resize(send_size * world_size_);
dh::safe_cuda(cudaMemcpy(host_buffer_.data() + rank_ * send_size, send_buffer, send_size,
cudaMemcpyDefault));
Allgather(host_buffer_.data(), host_buffer_.size());
dh::safe_cuda(
cudaMemcpy(receive_buffer, host_buffer_.data(), host_buffer_.size(), cudaMemcpyDefault));
}
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) override {
if (world_size_ == 1) {
@@ -49,7 +63,7 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
receive_buffer->resize(total_bytes);
host_buffer_.reserve(total_bytes);
host_buffer_.resize(total_bytes);
size_t offset = 0;
for (int32_t i = 0; i < world_size_; ++i) {
size_t as_bytes = segments->at(i);

View File

@@ -178,6 +178,17 @@ void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t co
allreduce_calls_ += 1;
}
void NcclDeviceCommunicator::AllGather(void const *send_buffer, void *receive_buffer,
std::size_t send_size) {
if (world_size_ == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_nccl(ncclAllGather(send_buffer, receive_buffer, send_size, ncclInt8, nccl_comm_,
dh::DefaultStream()));
}
void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_bytes,
std::vector<std::size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) {

View File

@@ -29,6 +29,7 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
~NcclDeviceCommunicator() override;
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override;
void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) override;
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) override;
void Synchronize() override;