Fix device communicator dependency (#9346)

Rong Ou, 2023-06-28 19:34:30 -07:00 (committed by GitHub)
parent f4798718c7
commit f90771eec6
10 changed files with 107 additions and 123 deletions

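Summary: the device communicators no longer depend on a specific Communicator instance. NcclDeviceCommunicator and DeviceCommunicatorAdapter used to receive a raw Communicator * from Communicator::GetDevice() and call through it; they now capture the world size and rank once at construction and use the collective free functions (GetWorldSize(), GetRank(), Allreduce(), Broadcast()) instead, which also removes the null-communicator checks and their tests. The federated GPU tests are rewritten on top of a RunWithFederatedCommunicator helper, run one worker per visible GPU, and kWorldSize drops from 3 to 2, so the expected values in the federated tests change accordingly.

The free functions the new code calls into are a facade over the per-process communicator singleton. As a sketch for orientation only (signatures inferred from the call sites in this diff, not copied from XGBoost's communicator-inl headers; Get() is the singleton accessor the old code passed into the constructors):

namespace xgboost::collective {

// Hedged sketch of the free-function facade assumed by the new code paths.
inline int GetWorldSize() { return Communicator::Get()->GetWorldSize(); }
inline int GetRank() { return Communicator::Get()->GetRank(); }
inline void Allreduce(void *send_receive_buffer, std::size_t count, DataType data_type,
                      Operation op) {
  Communicator::Get()->AllReduce(send_receive_buffer, count, data_type, op);
}
inline void Broadcast(void *send_receive_buffer, std::size_t size, int root) {
  Communicator::Get()->Broadcast(send_receive_buffer, size, root);
}

}  // namespace xgboost::collective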
View File

@@ -30,12 +30,12 @@ DeviceCommunicator* Communicator::GetDevice(int device_ordinal) {
     old_world_size = communicator_->GetWorldSize();
 #ifdef XGBOOST_USE_NCCL
     if (type_ != CommunicatorType::kFederated) {
-      device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, Get()));
+      device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal));
     } else {
-      device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal, Get()));
+      device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
     }
 #else
-    device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal, Get()));
+    device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
 #endif
   }
   return device_communicator_.get();

View File

@@ -11,21 +11,18 @@ namespace collective {
 class DeviceCommunicatorAdapter : public DeviceCommunicator {
  public:
-  DeviceCommunicatorAdapter(int device_ordinal, Communicator *communicator)
-      : device_ordinal_{device_ordinal}, communicator_{communicator} {
+  explicit DeviceCommunicatorAdapter(int device_ordinal)
+      : device_ordinal_{device_ordinal}, world_size_{GetWorldSize()}, rank_{GetRank()} {
     if (device_ordinal_ < 0) {
       LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
     }
-    if (communicator_ == nullptr) {
-      LOG(FATAL) << "Communicator cannot be null.";
-    }
   }
 
   ~DeviceCommunicatorAdapter() override = default;
 
   void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
                  Operation op) override {
-    if (communicator_->GetWorldSize() == 1) {
+    if (world_size_ == 1) {
       return;
     }
@@ -33,37 +30,34 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
     auto size = count * GetTypeSize(data_type);
     host_buffer_.reserve(size);
     dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
-    communicator_->AllReduce(host_buffer_.data(), count, data_type, op);
+    Allreduce(host_buffer_.data(), count, data_type, op);
     dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
   }
 
   void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
                   dh::caching_device_vector<char> *receive_buffer) override {
-    if (communicator_->GetWorldSize() == 1) {
+    if (world_size_ == 1) {
       return;
     }
     dh::safe_cuda(cudaSetDevice(device_ordinal_));
 
-    int const world_size = communicator_->GetWorldSize();
-    int const rank = communicator_->GetRank();
-
     segments->clear();
-    segments->resize(world_size, 0);
-    segments->at(rank) = length_bytes;
-    communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64,
-                             Operation::kMax);
+    segments->resize(world_size_, 0);
+    segments->at(rank_) = length_bytes;
+    Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
     auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
     receive_buffer->resize(total_bytes);
     host_buffer_.reserve(total_bytes);
 
     size_t offset = 0;
-    for (int32_t i = 0; i < world_size; ++i) {
+    for (int32_t i = 0; i < world_size_; ++i) {
       size_t as_bytes = segments->at(i);
-      if (i == rank) {
-        dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank),
+      if (i == rank_) {
+        dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank_),
                                  cudaMemcpyDefault));
       }
-      communicator_->Broadcast(host_buffer_.data() + offset, as_bytes, i);
+      Broadcast(host_buffer_.data() + offset, as_bytes, i);
       offset += as_bytes;
     }
     dh::safe_cuda(cudaMemcpy(receive_buffer->data().get(), host_buffer_.data(), total_bytes,
@@ -76,7 +70,8 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
  private:
   int const device_ordinal_;
-  Communicator *communicator_;
+  int const world_size_;
+  int const rank_;
   /// Host buffer used to call communicator functions.
   std::vector<char> host_buffer_{};
 };

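The adapter keeps its host-staging design: device memory is copied into host_buffer_, the CPU-side collective runs on the host copy, and the result is copied back to the device; only the source of world size, rank, and the collective calls changed. The AllGatherV bookkeeping is worth spelling out (worked example with world_size_ == 2 and payloads of 2 and 3 bytes, the same shape the federated test below checks): each rank writes its own length into its slot of segments, so rank 0 holds {2, 0} and rank 1 holds {0, 3}; the kMax allreduce merges these into {2, 3} everywhere; total_bytes becomes 5; and the rank-by-rank Broadcast loop fills bytes 0..1 from rank 0 and bytes 2..4 from rank 1, leaving every worker with the identical 5-byte concatenation before the final copy into receive_buffer.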
View File

@@ -7,31 +7,24 @@
 namespace xgboost {
 namespace collective {
 
-NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, Communicator *communicator)
-    : device_ordinal_{device_ordinal}, communicator_{communicator} {
+NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal)
+    : device_ordinal_{device_ordinal}, world_size_{GetWorldSize()}, rank_{GetRank()} {
   if (device_ordinal_ < 0) {
     LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
   }
-  if (communicator_ == nullptr) {
-    LOG(FATAL) << "Communicator cannot be null.";
-  }
-  int32_t const rank = communicator_->GetRank();
-  int32_t const world = communicator_->GetWorldSize();
-  if (world == 1) {
+  if (world_size_ == 1) {
     return;
   }
-  std::vector<uint64_t> uuids(world * kUuidLength, 0);
+  std::vector<uint64_t> uuids(world_size_ * kUuidLength, 0);
   auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
-  auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength);
+  auto s_this_uuid = s_uuid.subspan(rank_ * kUuidLength, kUuidLength);
   GetCudaUUID(s_this_uuid);
 
   // TODO(rongou): replace this with allgather.
-  communicator_->AllReduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);
+  Allreduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);
 
-  std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world);
+  std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world_size_);
   size_t j = 0;
   for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
     converted[j] = xgboost::common::Span<uint64_t, kUuidLength>{uuids.data() + i, kUuidLength};
@@ -41,18 +34,18 @@ NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, Communicator
   auto iter = std::unique(converted.begin(), converted.end());
   auto n_uniques = std::distance(converted.begin(), iter);
-  CHECK_EQ(n_uniques, world)
+  CHECK_EQ(n_uniques, world_size_)
       << "Multiple processes within communication group running on same CUDA "
       << "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
 
   nccl_unique_id_ = GetUniqueId();
   dh::safe_cuda(cudaSetDevice(device_ordinal_));
-  dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world, nccl_unique_id_, rank));
+  dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_));
   dh::safe_cuda(cudaStreamCreate(&cuda_stream_));
 }
 
 NcclDeviceCommunicator::~NcclDeviceCommunicator() {
-  if (communicator_->GetWorldSize() == 1) {
+  if (world_size_ == 1) {
     return;
   }
   if (cuda_stream_) {
@@ -139,9 +132,8 @@ void RunBitwiseAllreduce(char *out_buffer, char const *device_buffer, Func func,
 void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::size_t count,
                                               DataType data_type, Operation op) {
-  auto const world_size = communicator_->GetWorldSize();
   auto const size = count * GetTypeSize(data_type);
-  dh::caching_device_vector<char> buffer(size * world_size);
+  dh::caching_device_vector<char> buffer(size * world_size_);
   auto *device_buffer = buffer.data().get();
 
   // First gather data from all the workers.
@@ -152,15 +144,15 @@ void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::si
   auto *out_buffer = static_cast<char *>(send_receive_buffer);
   switch (op) {
     case Operation::kBitwiseAND:
-      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size, size,
+      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size_, size,
                           cuda_stream_);
       break;
     case Operation::kBitwiseOR:
-      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size, size,
+      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size_, size,
                           cuda_stream_);
       break;
     case Operation::kBitwiseXOR:
-      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size, size,
+      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size_, size,
                           cuda_stream_);
       break;
     default:
@@ -170,7 +162,7 @@ void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::si
 void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t count,
                                        DataType data_type, Operation op) {
-  if (communicator_->GetWorldSize() == 1) {
+  if (world_size_ == 1) {
     return;
   }
@@ -189,24 +181,22 @@ void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t co
 void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_bytes,
                                         std::vector<std::size_t> *segments,
                                         dh::caching_device_vector<char> *receive_buffer) {
-  if (communicator_->GetWorldSize() == 1) {
+  if (world_size_ == 1) {
     return;
   }
   dh::safe_cuda(cudaSetDevice(device_ordinal_));
 
-  int const world_size = communicator_->GetWorldSize();
-  int const rank = communicator_->GetRank();
-
   segments->clear();
-  segments->resize(world_size, 0);
-  segments->at(rank) = length_bytes;
-  communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
+  segments->resize(world_size_, 0);
+  segments->at(rank_) = length_bytes;
+  Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
   auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
   receive_buffer->resize(total_bytes);
 
   size_t offset = 0;
   dh::safe_nccl(ncclGroupStart());
-  for (int32_t i = 0; i < world_size; ++i) {
+  for (int32_t i = 0; i < world_size_; ++i) {
     size_t as_bytes = segments->at(i);
     dh::safe_nccl(ncclBroadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
                                 ncclChar, i, nccl_comm_, cuda_stream_));
@@ -216,7 +206,7 @@ void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_b
 }
 
 void NcclDeviceCommunicator::Synchronize() {
-  if (communicator_->GetWorldSize() == 1) {
+  if (world_size_ == 1) {
     return;
   }
   dh::safe_cuda(cudaSetDevice(device_ordinal_));

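BitwiseAllReduce above first gathers every worker's buffer into one device allocation (code between the hunks, not shown here) and then folds the world_size_ copies together on the GPU via RunBitwiseAllreduce. The body of RunBitwiseAllreduce is outside this diff's hunks; the following is a hedged reconstruction from its call sites only (out buffer, gathered buffer, a thrust binary functor, world size, byte count, stream), not the actual XGBoost implementation:

#include <cstddef>
#include <thrust/copy.h>
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/transform.h>

// Hedged sketch: fold world_size gathered copies (each `size` bytes) into
// out_buffer with the binary functor `func`, entirely on the given stream.
template <typename Func>
void RunBitwiseAllreduceSketch(char *out_buffer, char const *device_buffer, Func func,
                               int world_size, std::size_t size, cudaStream_t stream) {
  auto in = thrust::device_pointer_cast(device_buffer);
  auto out = thrust::device_pointer_cast(out_buffer);
  // Start from rank 0's copy...
  thrust::copy_n(thrust::cuda::par.on(stream), in, size, out);
  // ...then combine each remaining rank's copy element-wise, in place.
  for (int i = 1; i < world_size; ++i) {
    thrust::transform(thrust::cuda::par.on(stream), out, out + size, in + i * size, out, func);
  }
}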
View File

@@ -12,7 +12,7 @@ namespace collective {
 class NcclDeviceCommunicator : public DeviceCommunicator {
  public:
-  NcclDeviceCommunicator(int device_ordinal, Communicator *communicator);
+  explicit NcclDeviceCommunicator(int device_ordinal);
   ~NcclDeviceCommunicator() override;
   void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
                  Operation op) override;
@@ -49,11 +49,10 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
   ncclUniqueId GetUniqueId() {
     static const int kRootRank = 0;
     ncclUniqueId id;
-    if (communicator_->GetRank() == kRootRank) {
+    if (rank_ == kRootRank) {
       dh::safe_nccl(ncclGetUniqueId(&id));
     }
-    communicator_->Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId),
-                             static_cast<int>(kRootRank));
+    Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId), static_cast<int>(kRootRank));
     return id;
   }
@@ -61,7 +60,8 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
                          Operation op);
 
   int const device_ordinal_;
-  Communicator *communicator_;
+  int const world_size_;
+  int const rank_;
   ncclComm_t nccl_comm_{};
   cudaStream_t cuda_stream_{};
   ncclUniqueId nccl_unique_id_{};

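GetUniqueId is the standard NCCL bootstrap, now routed through the free functions: only kRootRank may mint the id, every rank must then obtain that identical id through a CPU-side broadcast, and only afterwards can the collective ncclCommInitRank call in the constructor succeed, since all ranks have to enter it with the same id. Compressed to its bare sequence (sketch; error checking elided, and broadcast() stands for whatever CPU-side broadcast the ranks share):

ncclUniqueId id;
if (rank == 0) {
  ncclGetUniqueId(&id);                             // only the root creates the id
}
broadcast(&id, sizeof(ncclUniqueId), /*root=*/0);   // all ranks must see the same id
cudaSetDevice(device_ordinal);
ncclComm_t comm;
ncclCommInitRank(&comm, world_size, id, rank);      // collective: every rank participates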
View File

@@ -16,12 +16,7 @@ namespace xgboost {
 namespace collective {
 
 TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidDeviceOrdinal) {
-  auto construct = []() { NcclDeviceCommunicator comm{-1, nullptr}; };
+  auto construct = []() { NcclDeviceCommunicator comm{-1}; };
   EXPECT_THROW(construct(), dmlc::Error);
 }
-
-TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidCommunicator) {
-  auto construct = []() { NcclDeviceCommunicator comm{0, nullptr}; };
-  EXPECT_THROW(construct(), dmlc::Error);
-}

View File

@@ -37,7 +37,14 @@ class ServerForTest {
   }
 
   ~ServerForTest() {
+    using namespace std::chrono_literals;
+    while (!server_) {
+      std::this_thread::sleep_for(100ms);
+    }
     server_->Shutdown();
+    while (!server_thread_) {
+      std::this_thread::sleep_for(100ms);
+    }
     server_thread_->join();
   }
@@ -56,7 +63,7 @@ class BaseFederatedTest : public ::testing::Test {
   void TearDown() override { server_.reset(nullptr); }
 
-  static int constexpr kWorldSize{3};
+  static int constexpr kWorldSize{2};
   std::unique_ptr<ServerForTest> server_;
 };

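Two separate fixes here. The test server is brought up on a background thread, so at teardown time server_ and server_thread_ may not have been assigned yet; the destructor now spins in 100 ms steps until each exists before calling Shutdown() and join(). Independently, kWorldSize drops from 3 to 2 to match the rewritten multi-GPU tests, where each worker drives its own CUDA device. Every changed expectation below follows from that: an allreduce-sum in which both workers contribute {1, 2, 3, 4, 5} now yields {2, 4, 6, 8, 10} rather than the three-worker {3, 6, 9, 12, 15}, and the allgather buffers shrink from three slots to two.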
View File

@@ -9,6 +9,7 @@
 #include <thread>
 
 #include "../../../plugin/federated/federated_communicator.h"
+#include "../../../src/collective/communicator-inl.cuh"
 #include "../../../src/collective/device_communicator_adapter.cuh"
 #include "./helpers.h"
@@ -17,67 +18,63 @@ namespace xgboost::collective {
 
 class FederatedAdapterTest : public BaseFederatedTest {};
 
 TEST(FederatedAdapterSimpleTest, ThrowOnInvalidDeviceOrdinal) {
-  auto construct = []() { DeviceCommunicatorAdapter adapter{-1, nullptr}; };
+  auto construct = []() { DeviceCommunicatorAdapter adapter{-1}; };
   EXPECT_THROW(construct(), dmlc::Error);
 }
 
-TEST(FederatedAdapterSimpleTest, ThrowOnInvalidCommunicator) {
-  auto construct = []() { DeviceCommunicatorAdapter adapter{0, nullptr}; };
-  EXPECT_THROW(construct(), dmlc::Error);
-}
-
-TEST_F(FederatedAdapterTest, DeviceAllReduceSum) {
-  std::vector<std::thread> threads;
-  for (auto rank = 0; rank < kWorldSize; rank++) {
-    threads.emplace_back([rank, server_address = server_->Address()] {
-      FederatedCommunicator comm{kWorldSize, rank, server_address};
-      // Assign device 0 to all workers, since we run gtest in a single-GPU machine
-      DeviceCommunicatorAdapter adapter{0, &comm};
-      int count = 3;
-      thrust::device_vector<double> buffer(count, 0);
-      thrust::sequence(buffer.begin(), buffer.end());
-      adapter.AllReduce(buffer.data().get(), count, DataType::kDouble, Operation::kSum);
-      thrust::host_vector<double> host_buffer = buffer;
-      EXPECT_EQ(host_buffer.size(), count);
-      for (auto i = 0; i < count; i++) {
-        EXPECT_EQ(host_buffer[i], i * kWorldSize);
-      }
-    });
-  }
-  for (auto& thread : threads) {
-    thread.join();
-  }
-}
-
-TEST_F(FederatedAdapterTest, DeviceAllGatherV) {
-  std::vector<std::thread> threads;
-  for (auto rank = 0; rank < kWorldSize; rank++) {
-    threads.emplace_back([rank, server_address = server_->Address()] {
-      FederatedCommunicator comm{kWorldSize, rank, server_address};
-      // Assign device 0 to all workers, since we run gtest in a single-GPU machine
-      DeviceCommunicatorAdapter adapter{0, &comm};
-      int const count = rank + 2;
-      thrust::device_vector<char> buffer(count, 0);
-      thrust::sequence(buffer.begin(), buffer.end());
-      std::vector<std::size_t> segments(kWorldSize);
-      dh::caching_device_vector<char> receive_buffer{};
-      adapter.AllGatherV(buffer.data().get(), count, &segments, &receive_buffer);
-      EXPECT_EQ(segments[0], 2);
-      EXPECT_EQ(segments[1], 3);
-      thrust::host_vector<char> host_buffer = receive_buffer;
-      EXPECT_EQ(host_buffer.size(), 9);
-      int expected[] = {0, 1, 0, 1, 2, 0, 1, 2, 3};
-      for (auto i = 0; i < 9; i++) {
-        EXPECT_EQ(host_buffer[i], expected[i]);
-      }
-    });
-  }
-  for (auto& thread : threads) {
-    thread.join();
-  }
-}
+namespace {
+void VerifyAllReduceSum() {
+  auto const world_size = collective::GetWorldSize();
+  auto const rank = collective::GetRank();
+  int count = 3;
+  thrust::device_vector<double> buffer(count, 0);
+  thrust::sequence(buffer.begin(), buffer.end());
+  collective::AllReduce<collective::Operation::kSum>(rank, buffer.data().get(), count);
+  thrust::host_vector<double> host_buffer = buffer;
+  EXPECT_EQ(host_buffer.size(), count);
+  for (auto i = 0; i < count; i++) {
+    EXPECT_EQ(host_buffer[i], i * world_size);
+  }
+}
+}  // anonymous namespace
+
+TEST_F(FederatedAdapterTest, MGPUAllReduceSum) {
+  auto const n_gpus = common::AllVisibleGPUs();
+  if (n_gpus <= 1) {
+    GTEST_SKIP() << "Skipping MGPUAllReduceSum test with # GPUs = " << n_gpus;
+  }
+  RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAllReduceSum);
+}
+
+namespace {
+void VerifyAllGatherV() {
+  auto const world_size = collective::GetWorldSize();
+  auto const rank = collective::GetRank();
+  int const count = rank + 2;
+  thrust::device_vector<char> buffer(count, 0);
+  thrust::sequence(buffer.begin(), buffer.end());
+  std::vector<std::size_t> segments(world_size);
+  dh::caching_device_vector<char> receive_buffer{};
+  collective::AllGatherV(rank, buffer.data().get(), count, &segments, &receive_buffer);
+  EXPECT_EQ(segments[0], 2);
+  EXPECT_EQ(segments[1], 3);
+  thrust::host_vector<char> host_buffer = receive_buffer;
+  EXPECT_EQ(host_buffer.size(), 5);
+  int expected[] = {0, 1, 0, 1, 2};
+  for (auto i = 0; i < 5; i++) {
+    EXPECT_EQ(host_buffer[i], expected[i]);
+  }
+}
+}  // anonymous namespace
+
+TEST_F(FederatedAdapterTest, MGPUAllGatherV) {
+  auto const n_gpus = common::AllVisibleGPUs();
+  if (n_gpus <= 1) {
+    GTEST_SKIP() << "Skipping MGPUAllGatherV test with # GPUs = " << n_gpus;
+  }
+  RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAllGatherV);
+}
 
 }  // namespace xgboost::collective

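Both rewritten tests delegate the multi-worker boilerplate to RunWithFederatedCommunicator, which lives in the ./helpers.h included above and is not part of this diff. Judging from the thread-per-rank loops the old tests inlined, its shape is presumably close to the sketch below; the body is an assumption, and InitFederatedCommunicator/ShutdownCommunicator are hypothetical stand-ins for the real per-thread setup and teardown:

#include <string>
#include <thread>
#include <vector>

// Hedged sketch only: the real helper may differ in names and details.
template <typename Function>
void RunWithFederatedCommunicatorSketch(int world_size, std::string const &address, Function fn) {
  std::vector<std::thread> threads;
  for (int rank = 0; rank < world_size; ++rank) {
    threads.emplace_back([=] {
      // Hypothetical: install a FederatedCommunicator{world_size, rank, address} as the
      // communicator singleton for this thread, so the collective free functions called
      // inside fn (GetWorldSize, GetRank, AllReduce, AllGatherV) resolve to this worker.
      InitFederatedCommunicator(world_size, rank, address);
      fn();
      ShutdownCommunicator();  // hypothetical teardown
    });
  }
  for (auto &thread : threads) {
    thread.join();
  }
}

Running several "workers" as threads of one test binary only works if the communicator singleton is kept per thread rather than per process, which is what lets the verify functions above call the plain collective free functions.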
View File

@@ -31,7 +31,7 @@ class FederatedCommunicatorTest : public BaseFederatedTest {
  protected:
   static void CheckAllgather(FederatedCommunicator &comm, int rank) {
-    int buffer[kWorldSize] = {0, 0, 0};
+    int buffer[kWorldSize] = {0, 0};
     buffer[rank] = rank;
     comm.AllGather(buffer, sizeof(buffer));
     for (auto i = 0; i < kWorldSize; i++) {
@@ -42,7 +42,7 @@ class FederatedCommunicatorTest : public BaseFederatedTest {
   static void CheckAllreduce(FederatedCommunicator &comm) {
     int buffer[] = {1, 2, 3, 4, 5};
     comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kSum);
-    int expected[] = {3, 6, 9, 12, 15};
+    int expected[] = {2, 4, 6, 8, 10};
     for (auto i = 0; i < 5; i++) {
       EXPECT_EQ(buffer[i], expected[i]);
     }

View File

@@ -30,7 +30,7 @@ void VerifyLoadUri() {
   std::string uri = path + "?format=csv";
   dmat.reset(DMatrix::Load(uri, false, DataSplitMode::kCol));
 
-  ASSERT_EQ(dmat->Info().num_col_, 8 * collective::GetWorldSize() + 3);
+  ASSERT_EQ(dmat->Info().num_col_, 8 * collective::GetWorldSize() + 1);
   ASSERT_EQ(dmat->Info().num_row_, kRows);
 
   for (auto const& page : dmat->GetBatches<SparsePage>()) {

View File

@@ -39,7 +39,7 @@ class FederatedServerTest : public BaseFederatedTest {
  protected:
   static void CheckAllgather(federated::FederatedClient& client, int rank) {
-    int data[kWorldSize] = {0, 0, 0};
+    int data[kWorldSize] = {0, 0};
     data[rank] = rank;
     std::string send_buffer(reinterpret_cast<char const*>(data), sizeof(data));
     auto reply = client.Allgather(send_buffer);
@@ -54,7 +54,7 @@ class FederatedServerTest : public BaseFederatedTest {
     std::string send_buffer(reinterpret_cast<char const*>(data), sizeof(data));
     auto reply = client.Allreduce(send_buffer, federated::INT32, federated::SUM);
     auto const* result = reinterpret_cast<int const*>(reply.data());
-    int expected[] = {3, 6, 9, 12, 15};
+    int expected[] = {2, 4, 6, 8, 10};
     for (auto i = 0; i < 5; i++) {
       EXPECT_EQ(result[i], expected[i]);
     }