Support bitwise allreduce operations in the communicator (#8623)

This commit is contained in:
Rong Ou
2022-12-24 14:40:05 -08:00
committed by GitHub
parent c7e82b5914
commit 77b069c25d
10 changed files with 207 additions and 26 deletions

View File

@@ -4,6 +4,7 @@
#include <dmlc/parameter.h>
#include <gtest/gtest.h>
#include <bitset>
#include <thread>
#include "../../../src/collective/in_memory_communicator.h"
@@ -13,7 +14,37 @@ namespace collective {
class InMemoryCommunicatorTest : public ::testing::Test {
public:
static void VerifyAllreduce(int rank) {
static void Verify(void (*function)(int)) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(function, rank);
}
for (auto &thread : threads) {
thread.join();
}
}
static void AllreduceMax(int rank) {
InMemoryCommunicator comm{kWorldSize, rank};
int buffer[] = {1 + rank, 2 + rank, 3 + rank, 4 + rank, 5 + rank};
comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kMax);
int expected[] = {3, 4, 5, 6, 7};
for (auto i = 0; i < 5; i++) {
EXPECT_EQ(buffer[i], expected[i]);
}
}
static void AllreduceMin(int rank) {
InMemoryCommunicator comm{kWorldSize, rank};
int buffer[] = {1 + rank, 2 + rank, 3 + rank, 4 + rank, 5 + rank};
comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kMin);
int expected[] = {1, 2, 3, 4, 5};
for (auto i = 0; i < 5; i++) {
EXPECT_EQ(buffer[i], expected[i]);
}
}
static void AllreduceSum(int rank) {
InMemoryCommunicator comm{kWorldSize, rank};
int buffer[] = {1, 2, 3, 4, 5};
comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kSum);
@@ -23,7 +54,35 @@ class InMemoryCommunicatorTest : public ::testing::Test {
}
}
static void VerifyBroadcast(int rank) {
static void AllreduceBitwiseAND(int rank) {
InMemoryCommunicator comm{kWorldSize, rank};
std::bitset<2> original(rank);
auto buffer = original.to_ulong();
comm.AllReduce(&buffer, 1, DataType::kUInt32, Operation::kBitwiseAND);
EXPECT_EQ(buffer, 0UL);
}
static void AllreduceBitwiseOR(int rank) {
InMemoryCommunicator comm{kWorldSize, rank};
std::bitset<2> original(rank);
auto buffer = original.to_ulong();
comm.AllReduce(&buffer, 1, DataType::kUInt32, Operation::kBitwiseOR);
std::bitset<2> actual(buffer);
std::bitset<2> expected{0b11};
EXPECT_EQ(actual, expected);
}
static void AllreduceBitwiseXOR(int rank) {
InMemoryCommunicator comm{kWorldSize, rank};
std::bitset<3> original(rank * 2);
auto buffer = original.to_ulong();
comm.AllReduce(&buffer, 1, DataType::kUInt32, Operation::kBitwiseXOR);
std::bitset<3> actual(buffer);
std::bitset<3> expected{0b110};
EXPECT_EQ(actual, expected);
}
static void Broadcast(int rank) {
InMemoryCommunicator comm{kWorldSize, rank};
if (rank == 0) {
std::string buffer{"hello"};
@@ -88,25 +147,19 @@ TEST(InMemoryCommunicatorSimpleTest, IsDistributed) {
EXPECT_TRUE(comm.IsDistributed());
}
TEST_F(InMemoryCommunicatorTest, Allreduce) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(std::thread(&InMemoryCommunicatorTest::VerifyAllreduce, rank));
}
for (auto &thread : threads) {
thread.join();
}
}
TEST_F(InMemoryCommunicatorTest, AllreduceMax) { Verify(&AllreduceMax); }
TEST_F(InMemoryCommunicatorTest, Broadcast) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(std::thread(&InMemoryCommunicatorTest::VerifyBroadcast, rank));
}
for (auto &thread : threads) {
thread.join();
}
}
TEST_F(InMemoryCommunicatorTest, AllreduceMin) { Verify(&AllreduceMin); }
TEST_F(InMemoryCommunicatorTest, AllreduceSum) { Verify(&AllreduceSum); }
TEST_F(InMemoryCommunicatorTest, AllreduceBitwiseAND) { Verify(&AllreduceBitwiseAND); }
TEST_F(InMemoryCommunicatorTest, AllreduceBitwiseOR) { Verify(&AllreduceBitwiseOR); }
TEST_F(InMemoryCommunicatorTest, AllreduceBitwiseXOR) { Verify(&AllreduceBitwiseXOR); }
TEST_F(InMemoryCommunicatorTest, Broadcast) { Verify(&Broadcast); }
} // namespace collective
} // namespace xgboost