[coll] Add comm group. (#9759)
- Implement `CommGroup` for double dispatching. - Small cleanup to tracker for handling abort.
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
*/
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string> // for string
|
||||
@@ -19,12 +20,14 @@ class FederatedCommTest : public SocketTest {};
|
||||
|
||||
TEST_F(FederatedCommTest, ThrowOnWorldSizeTooSmall) {
|
||||
auto construct = [] { FederatedComm comm{"localhost", 0, 0, 0}; };
|
||||
ExpectThrow<dmlc::Error>("Invalid world size.", construct);
|
||||
ASSERT_THAT(construct,
|
||||
::testing::ThrowsMessage<dmlc::Error>(::testing::HasSubstr("Invalid world size")));
|
||||
}
|
||||
|
||||
TEST_F(FederatedCommTest, ThrowOnRankTooSmall) {
|
||||
auto construct = [] { FederatedComm comm{"localhost", 0, 1, -1}; };
|
||||
ExpectThrow<dmlc::Error>("Invalid worker rank.", construct);
|
||||
ASSERT_THAT(construct,
|
||||
::testing::ThrowsMessage<dmlc::Error>(::testing::HasSubstr("Invalid worker rank.")));
|
||||
}
|
||||
|
||||
TEST_F(FederatedCommTest, ThrowOnRankTooBig) {
|
||||
@@ -38,7 +41,7 @@ TEST_F(FederatedCommTest, ThrowOnWorldSizeNotInteger) {
|
||||
config["federated_server_address"] = std::string("localhost:0");
|
||||
config["federated_world_size"] = std::string("1");
|
||||
config["federated_rank"] = Integer(0);
|
||||
FederatedComm comm(config);
|
||||
FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config};
|
||||
};
|
||||
ExpectThrow<dmlc::Error>("got: `String`", construct);
|
||||
}
|
||||
@@ -49,7 +52,7 @@ TEST_F(FederatedCommTest, ThrowOnRankNotInteger) {
|
||||
config["federated_server_address"] = std::string("localhost:0");
|
||||
config["federated_world_size"] = 1;
|
||||
config["federated_rank"] = std::string("0");
|
||||
FederatedComm comm(config);
|
||||
FederatedComm comm(DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config);
|
||||
};
|
||||
ExpectThrow<dmlc::Error>("got: `String`", construct);
|
||||
}
|
||||
@@ -59,7 +62,7 @@ TEST_F(FederatedCommTest, GetWorldSizeAndRank) {
|
||||
config["federated_world_size"] = 6;
|
||||
config["federated_rank"] = 3;
|
||||
config["federated_server_address"] = String{"localhost:0"};
|
||||
FederatedComm comm{config};
|
||||
FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config};
|
||||
EXPECT_EQ(comm.World(), 6);
|
||||
EXPECT_EQ(comm.Rank(), 3);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user