Improve allgather functions (#9649)

This commit is contained in:
Rong Ou
2023-10-12 08:31:43 -07:00
committed by GitHub
parent d1dee4ad99
commit e164d51c43
20 changed files with 346 additions and 122 deletions

View File

@@ -29,6 +29,11 @@ class InMemoryCommunicatorTest : public ::testing::Test {
VerifyAllgather(comm, rank);
}
static void AllgatherV(int rank) {
InMemoryCommunicator comm{kWorldSize, rank};
VerifyAllgatherV(comm, rank);
}
static void AllreduceMax(int rank) {
InMemoryCommunicator comm{kWorldSize, rank};
VerifyAllreduceMax(comm, rank);
@@ -80,14 +85,19 @@ class InMemoryCommunicatorTest : public ::testing::Test {
protected:
static void VerifyAllgather(InMemoryCommunicator &comm, int rank) {
char buffer[kWorldSize] = {'a', 'b', 'c'};
buffer[rank] = '0' + rank;
comm.AllGather(buffer, kWorldSize);
std::string input{static_cast<char>('0' + rank)};
auto output = comm.AllGather(input);
for (auto i = 0; i < kWorldSize; i++) {
EXPECT_EQ(buffer[i], '0' + i);
EXPECT_EQ(output[i], static_cast<char>('0' + i));
}
}
static void VerifyAllgatherV(InMemoryCommunicator &comm, int rank) {
std::vector<std::string_view> inputs{"a", "bb", "ccc"};
auto output = comm.AllGatherV(inputs[rank]);
EXPECT_EQ(output, "abbccc");
}
static void VerifyAllreduceMax(InMemoryCommunicator &comm, int rank) {
int buffer[] = {1 + rank, 2 + rank, 3 + rank, 4 + rank, 5 + rank};
comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kMax);
@@ -205,6 +215,8 @@ TEST(InMemoryCommunicatorSimpleTest, IsDistributed) {
TEST_F(InMemoryCommunicatorTest, Allgather) { Verify(&Allgather); }
TEST_F(InMemoryCommunicatorTest, AllgatherV) { Verify(&AllgatherV); }
TEST_F(InMemoryCommunicatorTest, AllreduceMax) { Verify(&AllreduceMax); }
TEST_F(InMemoryCommunicatorTest, AllreduceMin) { Verify(&AllreduceMin); }