Fix device communicator dependency (#9346)
This commit is contained in:
@@ -31,7 +31,7 @@ class FederatedCommunicatorTest : public BaseFederatedTest {
|
||||
|
||||
protected:
|
||||
static void CheckAllgather(FederatedCommunicator &comm, int rank) {
|
||||
int buffer[kWorldSize] = {0, 0, 0};
|
||||
int buffer[kWorldSize] = {0, 0};
|
||||
buffer[rank] = rank;
|
||||
comm.AllGather(buffer, sizeof(buffer));
|
||||
for (auto i = 0; i < kWorldSize; i++) {
|
||||
@@ -42,7 +42,7 @@ class FederatedCommunicatorTest : public BaseFederatedTest {
|
||||
static void CheckAllreduce(FederatedCommunicator &comm) {
|
||||
int buffer[] = {1, 2, 3, 4, 5};
|
||||
comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kSum);
|
||||
int expected[] = {3, 6, 9, 12, 15};
|
||||
int expected[] = {2, 4, 6, 8, 10};
|
||||
for (auto i = 0; i < 5; i++) {
|
||||
EXPECT_EQ(buffer[i], expected[i]);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user