Fix device communicator dependency (#9346)

This commit is contained in:
Rong Ou
2023-06-28 19:34:30 -07:00
committed by GitHub
parent f4798718c7
commit f90771eec6
10 changed files with 107 additions and 123 deletions

View File

@@ -16,12 +16,7 @@ namespace xgboost {
namespace collective {
TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidDeviceOrdinal) {
auto construct = []() { NcclDeviceCommunicator comm{-1, nullptr}; };
EXPECT_THROW(construct(), dmlc::Error);
}
TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidCommunicator) {
auto construct = []() { NcclDeviceCommunicator comm{0, nullptr}; };
auto construct = []() { NcclDeviceCommunicator comm{-1}; };
EXPECT_THROW(construct(), dmlc::Error);
}

View File

@@ -37,7 +37,14 @@ class ServerForTest {
}
~ServerForTest() {
using namespace std::chrono_literals;
while (!server_) {
std::this_thread::sleep_for(100ms);
}
server_->Shutdown();
while (!server_thread_) {
std::this_thread::sleep_for(100ms);
}
server_thread_->join();
}
@@ -56,7 +63,7 @@ class BaseFederatedTest : public ::testing::Test {
void TearDown() override { server_.reset(nullptr); }
static int constexpr kWorldSize{3};
static int constexpr kWorldSize{2};
std::unique_ptr<ServerForTest> server_;
};

View File

@@ -9,6 +9,7 @@
#include <thread>
#include "../../../plugin/federated/federated_communicator.h"
#include "../../../src/collective/communicator-inl.cuh"
#include "../../../src/collective/device_communicator_adapter.cuh"
#include "./helpers.h"
@@ -17,67 +18,63 @@ namespace xgboost::collective {
class FederatedAdapterTest : public BaseFederatedTest {};
TEST(FederatedAdapterSimpleTest, ThrowOnInvalidDeviceOrdinal) {
auto construct = []() { DeviceCommunicatorAdapter adapter{-1, nullptr}; };
auto construct = []() { DeviceCommunicatorAdapter adapter{-1}; };
EXPECT_THROW(construct(), dmlc::Error);
}
TEST(FederatedAdapterSimpleTest, ThrowOnInvalidCommunicator) {
auto construct = []() { DeviceCommunicatorAdapter adapter{0, nullptr}; };
EXPECT_THROW(construct(), dmlc::Error);
}
TEST_F(FederatedAdapterTest, DeviceAllReduceSum) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back([rank, server_address = server_->Address()] {
FederatedCommunicator comm{kWorldSize, rank, server_address};
// Assign device 0 to all workers, since we run gtest in a single-GPU machine
DeviceCommunicatorAdapter adapter{0, &comm};
int count = 3;
thrust::device_vector<double> buffer(count, 0);
thrust::sequence(buffer.begin(), buffer.end());
adapter.AllReduce(buffer.data().get(), count, DataType::kDouble, Operation::kSum);
thrust::host_vector<double> host_buffer = buffer;
EXPECT_EQ(host_buffer.size(), count);
for (auto i = 0; i < count; i++) {
EXPECT_EQ(host_buffer[i], i * kWorldSize);
}
});
}
for (auto& thread : threads) {
thread.join();
namespace {
void VerifyAllReduceSum() {
auto const world_size = collective::GetWorldSize();
auto const rank = collective::GetRank();
int count = 3;
thrust::device_vector<double> buffer(count, 0);
thrust::sequence(buffer.begin(), buffer.end());
collective::AllReduce<collective::Operation::kSum>(rank, buffer.data().get(), count);
thrust::host_vector<double> host_buffer = buffer;
EXPECT_EQ(host_buffer.size(), count);
for (auto i = 0; i < count; i++) {
EXPECT_EQ(host_buffer[i], i * world_size);
}
}
} // anonymous namespace
TEST_F(FederatedAdapterTest, DeviceAllGatherV) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back([rank, server_address = server_->Address()] {
FederatedCommunicator comm{kWorldSize, rank, server_address};
// Assign device 0 to all workers, since we run gtest in a single-GPU machine
DeviceCommunicatorAdapter adapter{0, &comm};
int const count = rank + 2;
thrust::device_vector<char> buffer(count, 0);
thrust::sequence(buffer.begin(), buffer.end());
std::vector<std::size_t> segments(kWorldSize);
dh::caching_device_vector<char> receive_buffer{};
adapter.AllGatherV(buffer.data().get(), count, &segments, &receive_buffer);
EXPECT_EQ(segments[0], 2);
EXPECT_EQ(segments[1], 3);
thrust::host_vector<char> host_buffer = receive_buffer;
EXPECT_EQ(host_buffer.size(), 9);
int expected[] = {0, 1, 0, 1, 2, 0, 1, 2, 3};
for (auto i = 0; i < 9; i++) {
EXPECT_EQ(host_buffer[i], expected[i]);
}
});
TEST_F(FederatedAdapterTest, MGPUAllReduceSum) {
auto const n_gpus = common::AllVisibleGPUs();
if (n_gpus <= 1) {
GTEST_SKIP() << "Skipping MGPUAllReduceSum test with # GPUs = " << n_gpus;
}
for (auto& thread : threads) {
thread.join();
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAllReduceSum);
}
namespace {
void VerifyAllGatherV() {
auto const world_size = collective::GetWorldSize();
auto const rank = collective::GetRank();
int const count = rank + 2;
thrust::device_vector<char> buffer(count, 0);
thrust::sequence(buffer.begin(), buffer.end());
std::vector<std::size_t> segments(world_size);
dh::caching_device_vector<char> receive_buffer{};
collective::AllGatherV(rank, buffer.data().get(), count, &segments, &receive_buffer);
EXPECT_EQ(segments[0], 2);
EXPECT_EQ(segments[1], 3);
thrust::host_vector<char> host_buffer = receive_buffer;
EXPECT_EQ(host_buffer.size(), 5);
int expected[] = {0, 1, 0, 1, 2};
for (auto i = 0; i < 5; i++) {
EXPECT_EQ(host_buffer[i], expected[i]);
}
}
} // anonymous namespace
TEST_F(FederatedAdapterTest, MGPUAllGatherV) {
auto const n_gpus = common::AllVisibleGPUs();
if (n_gpus <= 1) {
GTEST_SKIP() << "Skipping MGPUAllGatherV test with # GPUs = " << n_gpus;
}
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAllGatherV);
}
} // namespace xgboost::collective

View File

@@ -31,7 +31,7 @@ class FederatedCommunicatorTest : public BaseFederatedTest {
protected:
static void CheckAllgather(FederatedCommunicator &comm, int rank) {
int buffer[kWorldSize] = {0, 0, 0};
int buffer[kWorldSize] = {0, 0};
buffer[rank] = rank;
comm.AllGather(buffer, sizeof(buffer));
for (auto i = 0; i < kWorldSize; i++) {
@@ -42,7 +42,7 @@ class FederatedCommunicatorTest : public BaseFederatedTest {
static void CheckAllreduce(FederatedCommunicator &comm) {
int buffer[] = {1, 2, 3, 4, 5};
comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kSum);
int expected[] = {3, 6, 9, 12, 15};
int expected[] = {2, 4, 6, 8, 10};
for (auto i = 0; i < 5; i++) {
EXPECT_EQ(buffer[i], expected[i]);
}

View File

@@ -30,7 +30,7 @@ void VerifyLoadUri() {
std::string uri = path + "?format=csv";
dmat.reset(DMatrix::Load(uri, false, DataSplitMode::kCol));
ASSERT_EQ(dmat->Info().num_col_, 8 * collective::GetWorldSize() + 3);
ASSERT_EQ(dmat->Info().num_col_, 8 * collective::GetWorldSize() + 1);
ASSERT_EQ(dmat->Info().num_row_, kRows);
for (auto const& page : dmat->GetBatches<SparsePage>()) {

View File

@@ -39,7 +39,7 @@ class FederatedServerTest : public BaseFederatedTest {
protected:
static void CheckAllgather(federated::FederatedClient& client, int rank) {
int data[kWorldSize] = {0, 0, 0};
int data[kWorldSize] = {0, 0};
data[rank] = rank;
std::string send_buffer(reinterpret_cast<char const*>(data), sizeof(data));
auto reply = client.Allgather(send_buffer);
@@ -54,7 +54,7 @@ class FederatedServerTest : public BaseFederatedTest {
std::string send_buffer(reinterpret_cast<char const*>(data), sizeof(data));
auto reply = client.Allreduce(send_buffer, federated::INT32, federated::SUM);
auto const* result = reinterpret_cast<int const*>(reply.data());
int expected[] = {3, 6, 9, 12, 15};
int expected[] = {2, 4, 6, 8, 10};
for (auto i = 0; i < 5; i++) {
EXPECT_EQ(result[i], expected[i]);
}