85 lines
2.7 KiB
C++
85 lines
2.7 KiB
C++
/**
|
|
* Copyright 2022-2023, XGBoost contributors
|
|
*/
|
|
#include <gtest/gtest.h>

#include <algorithm>  // for clamp
#include <cstdint>    // for int32_t
#include <string>     // for string
#include <thread>     // for thread

#include "../../../../plugin/federated/federated_comm.h"
#include "../../collective/net_test.h"  // for SocketTest
#include "../../helpers.h"              // for ExpectThrow
#include "test_worker.h"                // for TestFederated
#include "xgboost/json.h"               // for Json
|
|
|
|
namespace xgboost::collective {
|
|
namespace {
|
|
class FederatedCommTest : public SocketTest {};
|
|
} // namespace
|
|
|
|
TEST_F(FederatedCommTest, ThrowOnWorldSizeTooSmall) {
  // Constructor arguments appear to be (host, port, world size, rank), judging by
  // the sibling tests. A world size of zero must be rejected at construction.
  auto build_empty_world = [] {
    FederatedComm rejected{"localhost", 0, 0, 0};
  };
  ExpectThrow<dmlc::Error>("Invalid world size.", build_empty_world);
}
|
|
|
|
TEST_F(FederatedCommTest, ThrowOnRankTooSmall) {
  // The world size (1) is valid here; only the negative rank should trigger the error.
  auto build_negative_rank = [] {
    FederatedComm rejected{"localhost", 0, 1, -1};
  };
  ExpectThrow<dmlc::Error>("Invalid worker rank.", build_negative_rank);
}
|
|
|
|
TEST_F(FederatedCommTest, ThrowOnRankTooBig) {
  // A rank equal to the world size is out of range and must be rejected.
  auto build_rank_past_end = [] {
    FederatedComm rejected{"localhost", 0, 1, 1};
  };
  ExpectThrow<dmlc::Error>("Invalid worker rank.", build_rank_past_end);
}
|
|
|
|
TEST_F(FederatedCommTest, ThrowOnWorldSizeNotInteger) {
  // The world size is supplied as a JSON string ("1") instead of an integer.
  // Construction must fail with a type error naming the offending JSON type.
  auto build_from_string_world = [] {
    Json cfg{Object{}};
    cfg["federated_server_address"] = std::string("localhost:0");
    cfg["federated_world_size"] = std::string("1");  // wrong JSON type on purpose
    cfg["federated_rank"] = Integer(0);
    FederatedComm rejected(cfg);
  };
  ExpectThrow<dmlc::Error>("got: `String`", build_from_string_world);
}
|
|
|
|
TEST_F(FederatedCommTest, ThrowOnRankNotInteger) {
  // The rank is supplied as a JSON string ("0") instead of an integer.
  // Construction must fail with a type error naming the offending JSON type.
  auto build_from_string_rank = [] {
    Json cfg{Object{}};
    cfg["federated_server_address"] = std::string("localhost:0");
    cfg["federated_world_size"] = 1;
    cfg["federated_rank"] = std::string("0");  // wrong JSON type on purpose
    FederatedComm rejected(cfg);
  };
  ExpectThrow<dmlc::Error>("got: `String`", build_from_string_rank);
}
|
|
|
|
TEST_F(FederatedCommTest, GetWorldSizeAndRank) {
  // World size and rank taken from the JSON configuration must be reported back
  // unchanged by the communicator's accessors.
  auto constexpr kWorldSize = 6;
  auto constexpr kRank = 3;

  Json cfg{Object{}};
  cfg["federated_world_size"] = kWorldSize;
  cfg["federated_rank"] = kRank;
  cfg["federated_server_address"] = String{"localhost:0"};

  FederatedComm federated{cfg};
  EXPECT_EQ(federated.World(), kWorldSize);
  EXPECT_EQ(federated.Rank(), kRank);
}
|
|
|
|
TEST_F(FederatedCommTest, IsDistributed) {
  // A communicator for a two-worker world (this one is rank 1) reports itself
  // as distributed.
  FederatedComm second_of_two{"localhost", 0, 2, 1};
  EXPECT_TRUE(second_of_two.IsDistributed());
}
|
|
|
|
TEST_F(FederatedCommTest, InsecureTracker) {
  // Use between one and three workers. std::thread::hardware_concurrency() is
  // allowed to return 0 when the value cannot be determined; a plain
  // std::min(..., 3u) would then request zero workers, and the assertions below
  // would never run.
  std::int32_t const n_workers =
      static_cast<std::int32_t>(std::clamp(std::thread::hardware_concurrency(), 1u, 3u));

  // Each worker builds a communicator from a JSON configuration that has no
  // certificate settings, then checks the rank/world size it reports. TestFederated
  // presumably hands every worker the server port and its own rank.
  TestFederated(n_workers, [=](std::int32_t port, std::int32_t rank) {
    Json config{Object{}};
    config["federated_world_size"] = n_workers;
    config["federated_rank"] = rank;
    // NOTE(review): "0.0.0.0" as a *client* target works on Linux but is not
    // portable (e.g. Windows rejects it). Consider "127.0.0.1" once it is
    // confirmed how TestFederated binds the server.
    config["federated_server_address"] = "0.0.0.0:" + std::to_string(port);
    FederatedComm comm{config};
    ASSERT_EQ(comm.Rank(), rank);
    ASSERT_EQ(comm.World(), n_workers);
  });
}
|
|
} // namespace xgboost::collective
|