Refactor device communicator to make allreduce more flexible (#9295)

This commit is contained in:
Rong Ou
2023-06-13 12:53:03 -07:00
committed by GitHub
parent c2f0486d37
commit e70810be8a
11 changed files with 190 additions and 106 deletions

View File

@@ -8,6 +8,7 @@
#include <string> // for string
#include "../../../src/collective/nccl_device_communicator.cuh"
#include "../../../src/collective/communicator-inl.cuh"
namespace xgboost {
namespace collective {

View File

@@ -1,7 +1,7 @@
#include <gtest/gtest.h>
#include "test_quantile.h"
#include "../helpers.h"
#include "../../../src/collective/device_communicator.cuh"
#include "../../../src/collective/communicator-inl.cuh"
#include "../../../src/common/hist_util.cuh"
#include "../../../src/common/quantile.cuh"
@@ -464,10 +464,9 @@ void TestSameOnAllWorkers(std::int32_t n_gpus) {
thrust::copy(thrust::device, local_data.data(),
local_data.data() + local_data.size(),
all_workers.begin() + local_data.size() * rank);
collective::DeviceCommunicator* communicator = collective::Communicator::GetDevice(device);
communicator->AllReduceSum(all_workers.data().get(), all_workers.size());
communicator->Synchronize();
collective::AllReduce<collective::Operation::kSum>(device, all_workers.data().get(),
all_workers.size());
collective::Synchronize(device);
auto base_line = dh::ToSpan(all_workers).subspan(0, size_as_float);
std::vector<float> h_base_line(base_line.size());

View File

@@ -36,7 +36,7 @@ TEST_F(FederatedAdapterTest, DeviceAllReduceSum) {
int count = 3;
thrust::device_vector<double> buffer(count, 0);
thrust::sequence(buffer.begin(), buffer.end());
adapter.AllReduceSum(buffer.data().get(), count);
adapter.AllReduce(buffer.data().get(), count, DataType::kDouble, Operation::kSum);
thrust::host_vector<double> host_buffer = buffer;
EXPECT_EQ(host_buffer.size(), count);
for (auto i = 0; i < count; i++) {