Fix device communicator dependency (#9346)

This commit is contained in:
Rong Ou
2023-06-28 19:34:30 -07:00
committed by GitHub
parent f4798718c7
commit f90771eec6
10 changed files with 107 additions and 123 deletions

View File

@@ -31,7 +31,7 @@ class FederatedCommunicatorTest : public BaseFederatedTest {
protected:
static void CheckAllgather(FederatedCommunicator &comm, int rank) {
int buffer[kWorldSize] = {0, 0, 0};
int buffer[kWorldSize] = {0, 0};
buffer[rank] = rank;
comm.AllGather(buffer, sizeof(buffer));
for (auto i = 0; i < kWorldSize; i++) {
@@ -42,7 +42,7 @@ class FederatedCommunicatorTest : public BaseFederatedTest {
static void CheckAllreduce(FederatedCommunicator &comm) {
int buffer[] = {1, 2, 3, 4, 5};
comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kSum);
int expected[] = {3, 6, 9, 12, 15};
int expected[] = {2, 4, 6, 8, 10};
for (auto i = 0; i < 5; i++) {
EXPECT_EQ(buffer[i], expected[i]);
}