Refactor device communicator to make allreduce more flexible (#9295)
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
#include <string> // for string
|
||||
|
||||
#include "../../../src/collective/nccl_device_communicator.cuh"
|
||||
#include "../../../src/collective/communicator-inl.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include "test_quantile.h"
|
||||
#include "../helpers.h"
|
||||
#include "../../../src/collective/device_communicator.cuh"
|
||||
#include "../../../src/collective/communicator-inl.cuh"
|
||||
#include "../../../src/common/hist_util.cuh"
|
||||
#include "../../../src/common/quantile.cuh"
|
||||
|
||||
@@ -464,10 +464,9 @@ void TestSameOnAllWorkers(std::int32_t n_gpus) {
|
||||
thrust::copy(thrust::device, local_data.data(),
|
||||
local_data.data() + local_data.size(),
|
||||
all_workers.begin() + local_data.size() * rank);
|
||||
collective::DeviceCommunicator* communicator = collective::Communicator::GetDevice(device);
|
||||
|
||||
communicator->AllReduceSum(all_workers.data().get(), all_workers.size());
|
||||
communicator->Synchronize();
|
||||
collective::AllReduce<collective::Operation::kSum>(device, all_workers.data().get(),
|
||||
all_workers.size());
|
||||
collective::Synchronize(device);
|
||||
|
||||
auto base_line = dh::ToSpan(all_workers).subspan(0, size_as_float);
|
||||
std::vector<float> h_base_line(base_line.size());
|
||||
|
||||
@@ -36,7 +36,7 @@ TEST_F(FederatedAdapterTest, DeviceAllReduceSum) {
|
||||
int count = 3;
|
||||
thrust::device_vector<double> buffer(count, 0);
|
||||
thrust::sequence(buffer.begin(), buffer.end());
|
||||
adapter.AllReduceSum(buffer.data().get(), count);
|
||||
adapter.AllReduce(buffer.data().get(), count, DataType::kDouble, Operation::kSum);
|
||||
thrust::host_vector<double> host_buffer = buffer;
|
||||
EXPECT_EQ(host_buffer.size(), count);
|
||||
for (auto i = 0; i < count; i++) {
|
||||
|
||||
Reference in New Issue
Block a user