Refactor device communicator to make allreduce more flexible (#9295)

This commit is contained in:
Rong Ou
2023-06-13 12:53:03 -07:00
committed by GitHub
parent c2f0486d37
commit e70810be8a
11 changed files with 190 additions and 106 deletions

View File

@@ -1,7 +1,7 @@
#include <gtest/gtest.h>
#include "test_quantile.h"
#include "../helpers.h"
#include "../../../src/collective/device_communicator.cuh"
#include "../../../src/collective/communicator-inl.cuh"
#include "../../../src/common/hist_util.cuh"
#include "../../../src/common/quantile.cuh"
@@ -464,10 +464,9 @@ void TestSameOnAllWorkers(std::int32_t n_gpus) {
thrust::copy(thrust::device, local_data.data(),
local_data.data() + local_data.size(),
all_workers.begin() + local_data.size() * rank);
collective::DeviceCommunicator* communicator = collective::Communicator::GetDevice(device);
communicator->AllReduceSum(all_workers.data().get(), all_workers.size());
communicator->Synchronize();
collective::AllReduce<collective::Operation::kSum>(device, all_workers.data().get(),
all_workers.size());
collective::Synchronize(device);
auto base_line = dh::ToSpan(all_workers).subspan(0, size_as_float);
std::vector<float> h_base_line(base_line.size());