Consistently report error in tests. (#10453)

This commit is contained in:
Jiaming Yuan 2024-06-21 14:35:22 +08:00 committed by GitHub
parent b38c7fe2ce
commit 26eb68859f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 32 additions and 32 deletions

View File

@ -115,7 +115,7 @@ FederatedTracker::~FederatedTracker() = default;
Result FederatedTracker::Shutdown() { Result FederatedTracker::Shutdown() {
auto rc = this->WaitUntilReady(); auto rc = this->WaitUntilReady();
CHECK(rc.OK()) << rc.Report(); SafeColl(rc);
try { try {
server_->Shutdown(); server_->Shutdown();
@ -132,7 +132,7 @@ Result FederatedTracker::Shutdown() {
std::string host; std::string host;
rc = GetHostAddress(&host); rc = GetHostAddress(&host);
CHECK(rc.OK()); SafeColl(rc);
Json args{Object{}}; Json args{Object{}};
args["dmlc_tracker_uri"] = String{host}; args["dmlc_tracker_uri"] = String{host};
args["dmlc_tracker_port"] = this->Port(); args["dmlc_tracker_port"] = this->Port();

View File

@ -117,7 +117,7 @@ class Worker : public WorkerForTest {
common::Span{sizes.data(), sizes.size()}, common::Span{sizes.data(), sizes.size()},
common::Span{recv_segments.data(), recv_segments.size()}, common::Span{recv_segments.data(), recv_segments.size()},
common::EraseType(s_recv), AllgatherVAlgo::kBcast); common::EraseType(s_recv), AllgatherVAlgo::kBcast);
ASSERT_TRUE(rc.OK()); SafeColl(rc);
CheckV(s_recv); CheckV(s_recv);
// Test inplace // Test inplace
@ -130,7 +130,7 @@ class Worker : public WorkerForTest {
common::Span{sizes.data(), sizes.size()}, common::Span{sizes.data(), sizes.size()},
common::Span{recv_segments.data(), recv_segments.size()}, common::Span{recv_segments.data(), recv_segments.size()},
common::EraseType(s_recv), algo); common::EraseType(s_recv), algo);
ASSERT_TRUE(rc.OK()); SafeColl(rc);
CheckV(s_recv); CheckV(s_recv);
}; };

View File

@ -24,7 +24,7 @@ class AllreduceWorker : public WorkerForTest {
rhs[i] += lhs[i]; rhs[i] += lhs[i];
} }
}); });
ASSERT_TRUE(rc.OK()); SafeColl(rc);
ASSERT_EQ(std::accumulate(data.cbegin(), data.cend(), 0.0), 0.0); ASSERT_EQ(std::accumulate(data.cbegin(), data.cend(), 0.0), 0.0);
} }
{ {
@ -34,7 +34,7 @@ class AllreduceWorker : public WorkerForTest {
rhs[i] += lhs[i]; rhs[i] += lhs[i];
} }
}); });
ASSERT_TRUE(rc.OK()); SafeColl(rc);
ASSERT_EQ(data[0], static_cast<double>(comm_.World())); ASSERT_EQ(data[0], static_cast<double>(comm_.World()));
} }
} }
@ -49,7 +49,7 @@ class AllreduceWorker : public WorkerForTest {
rhs[i] += lhs[i]; rhs[i] += lhs[i];
} }
}); });
ASSERT_TRUE(rc.OK()); SafeColl(rc);
for (auto v : data) { for (auto v : data) {
ASSERT_EQ(v, comm_.World()); ASSERT_EQ(v, comm_.World());
} }
@ -62,7 +62,7 @@ class AllreduceWorker : public WorkerForTest {
rhs[i] += lhs[i]; rhs[i] += lhs[i];
} }
}); });
ASSERT_TRUE(rc.OK()); SafeColl(rc);
for (std::size_t i = 0; i < data.size(); ++i) { for (std::size_t i = 0; i < data.size(); ++i) {
auto v = data[i]; auto v = data[i];
ASSERT_EQ(v, 1.5 * static_cast<double>(comm_.World())) << i; ASSERT_EQ(v, 1.5 * static_cast<double>(comm_.World())) << i;

View File

@ -41,7 +41,7 @@ TEST_F(TrackerAPITest, CAPI) {
auto args = Json::Load(StringView{cargs}); auto args = Json::Load(StringView{cargs});
std::string host; std::string host;
ASSERT_TRUE(GetHostAddress(&host).OK()); SafeColl(GetHostAddress(&host));
ASSERT_EQ(host, get<String const>(args["dmlc_tracker_uri"])); ASSERT_EQ(host, get<String const>(args["dmlc_tracker_uri"]));
auto port = get<Integer const>(args["dmlc_tracker_port"]); auto port = get<Integer const>(args["dmlc_tracker_port"]);
ASSERT_NE(port, 0); ASSERT_NE(port, 0);

View File

@ -47,6 +47,6 @@ TEST_F(CommTest, Channel) {
w.join(); w.join();
} }
ASSERT_TRUE(fut.get().OK()); SafeColl(fut.get());
} }
} // namespace xgboost::collective } // namespace xgboost::collective

View File

@ -70,7 +70,7 @@ TEST_F(SocketTest, Bind) {
auto sock = TCPSocket::Create(domain); auto sock = TCPSocket::Create(domain);
std::int32_t port{0}; std::int32_t port{0};
auto rc = sock.Bind(any, &port); auto rc = sock.Bind(any, &port);
ASSERT_TRUE(rc.OK()); SafeColl(rc);
ASSERT_NE(port, 0); ASSERT_NE(port, 0);
}; };

View File

@ -59,7 +59,7 @@ TEST_F(TrackerTest, Print) {
std::vector<std::thread> workers; std::vector<std::thread> workers;
auto rc = tracker.WaitUntilReady(); auto rc = tracker.WaitUntilReady();
ASSERT_TRUE(rc.OK()); SafeColl(rc);
std::int32_t port = tracker.Port(); std::int32_t port = tracker.Port();
@ -74,7 +74,7 @@ TEST_F(TrackerTest, Print) {
w.join(); w.join();
} }
ASSERT_TRUE(fut.get().OK()); SafeColl(fut.get());
} }
TEST_F(TrackerTest, GetHostAddress) { ASSERT_TRUE(host.find("127.") == std::string::npos); } TEST_F(TrackerTest, GetHostAddress) { ASSERT_TRUE(host.find("127.") == std::string::npos); }
@ -88,7 +88,7 @@ TEST_F(TrackerTest, AfterShutdown) {
std::vector<std::thread> workers; std::vector<std::thread> workers;
auto rc = tracker.WaitUntilReady(); auto rc = tracker.WaitUntilReady();
ASSERT_TRUE(rc.OK()); SafeColl(rc);
std::int32_t port = tracker.Port(); std::int32_t port = tracker.Port();
@ -101,7 +101,7 @@ TEST_F(TrackerTest, AfterShutdown) {
w.join(); w.join();
} }
ASSERT_TRUE(fut.get().OK()); SafeColl(fut.get());
// Launch workers again, they should fail. // Launch workers again, they should fail.
workers.clear(); workers.clear();

View File

@ -50,8 +50,8 @@ class WorkerForTest {
for (std::int32_t i = 0; i < comm_.World(); ++i) { for (std::int32_t i = 0; i < comm_.World(); ++i) {
if (i != comm_.Rank()) { if (i != comm_.Rank()) {
ASSERT_TRUE(comm_.Chan(i)->Socket()->NonBlocking()); ASSERT_TRUE(comm_.Chan(i)->Socket()->NonBlocking());
ASSERT_TRUE(comm_.Chan(i)->Socket()->SetBufSize(n_bytes).OK()); SafeColl(comm_.Chan(i)->Socket()->SetBufSize(n_bytes));
ASSERT_TRUE(comm_.Chan(i)->Socket()->SetNoDelay().OK()); SafeColl(comm_.Chan(i)->Socket()->SetNoDelay());
} }
} }
} }
@ -131,7 +131,7 @@ void TestDistributed(std::int32_t n_workers, WorkerFn worker_fn) {
t.join(); t.join();
} }
ASSERT_TRUE(fut.get().OK()); SafeColl(fut.get());
} }
inline auto MakeDistributedTestConfig(std::string host, std::int32_t port, inline auto MakeDistributedTestConfig(std::string host, std::int32_t port,
@ -182,7 +182,7 @@ void TestDistributedGlobal(std::int32_t n_workers, WorkerFn worker_fn, bool need
t.join(); t.join();
} }
ASSERT_TRUE(fut.get().OK()); SafeColl(fut.get());
system::SocketFinalize(); system::SocketFinalize();
} }

View File

@ -28,7 +28,7 @@ TEST_F(FederatedCollTest, Allreduce) {
auto coll = std::make_shared<FederatedColl>(); auto coll = std::make_shared<FederatedColl>();
auto rc = coll->Allreduce(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}), auto rc = coll->Allreduce(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}),
ArrayInterfaceHandler::kI4, Op::kSum); ArrayInterfaceHandler::kI4, Op::kSum);
ASSERT_TRUE(rc.OK()); SafeColl(rc);
for (auto i = 0; i < 5; i++) { for (auto i = 0; i < 5; i++) {
ASSERT_EQ(buffer[i], expected[i]); ASSERT_EQ(buffer[i], expected[i]);
} }
@ -49,7 +49,7 @@ TEST_F(FederatedCollTest, Broadcast) {
rc = coll.Broadcast(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}), 0); rc = coll.Broadcast(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}), 0);
ASSERT_EQ(buffer, "hello"); ASSERT_EQ(buffer, "hello");
} }
ASSERT_TRUE(rc.OK()); SafeColl(rc);
}); });
} }
@ -61,7 +61,7 @@ TEST_F(FederatedCollTest, Allgather) {
std::vector<std::int32_t> buffer(n_workers, 0); std::vector<std::int32_t> buffer(n_workers, 0);
buffer[comm->Rank()] = comm->Rank(); buffer[comm->Rank()] = comm->Rank();
auto rc = coll.Allgather(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()})); auto rc = coll.Allgather(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}));
ASSERT_TRUE(rc.OK()); SafeColl(rc);
for (auto i = 0; i < n_workers; i++) { for (auto i = 0; i < n_workers; i++) {
ASSERT_EQ(buffer[i], i); ASSERT_EQ(buffer[i], i);
} }
@ -87,7 +87,7 @@ TEST_F(FederatedCollTest, AllgatherV) {
common::EraseType(common::Span{r.data(), r.size()}), AllgatherVAlgo::kRing); common::EraseType(common::Span{r.data(), r.size()}), AllgatherVAlgo::kRing);
EXPECT_EQ(r, "Federated Learning!!!"); EXPECT_EQ(r, "Federated Learning!!!");
ASSERT_TRUE(rc.OK()); SafeColl(rc);
}); });
} }
} // namespace xgboost::collective } // namespace xgboost::collective

View File

@ -41,7 +41,7 @@ void TestAllreduce(std::shared_ptr<FederatedComm> comm, std::int32_t rank, std::
auto rc = w.coll->Allreduce(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer)), auto rc = w.coll->Allreduce(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer)),
ArrayInterfaceHandler::kI4, Op::kSum); ArrayInterfaceHandler::kI4, Op::kSum);
ASSERT_TRUE(rc.OK()); SafeColl(rc);
for (auto i = 0; i < 5; i++) { for (auto i = 0; i < 5; i++) {
ASSERT_EQ(buffer[i], expected[i]); ASSERT_EQ(buffer[i], expected[i]);
} }
@ -63,7 +63,7 @@ void TestBroadcast(std::shared_ptr<FederatedComm> comm, std::int32_t rank) {
rc = w.coll->Broadcast(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer)), 0); rc = w.coll->Broadcast(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer)), 0);
ASSERT_EQ(buffer, expect); ASSERT_EQ(buffer, expect);
} }
ASSERT_TRUE(rc.OK()); SafeColl(rc);
} }
void TestAllgather(std::shared_ptr<FederatedComm> comm, std::int32_t rank, std::int32_t n_workers) { void TestAllgather(std::shared_ptr<FederatedComm> comm, std::int32_t rank, std::int32_t n_workers) {
@ -72,7 +72,7 @@ void TestAllgather(std::shared_ptr<FederatedComm> comm, std::int32_t rank, std::
dh::device_vector<std::int32_t> buffer(n_workers, 0); dh::device_vector<std::int32_t> buffer(n_workers, 0);
buffer[comm->Rank()] = comm->Rank(); buffer[comm->Rank()] = comm->Rank();
auto rc = w.coll->Allgather(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer))); auto rc = w.coll->Allgather(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer)));
ASSERT_TRUE(rc.OK()); SafeColl(rc);
for (auto i = 0; i < n_workers; i++) { for (auto i = 0; i < n_workers; i++) {
ASSERT_EQ(buffer[i], i); ASSERT_EQ(buffer[i], i);
} }
@ -92,7 +92,7 @@ void TestAllgatherV(std::shared_ptr<FederatedComm> comm, std::int32_t rank) {
auto rc = w.coll->AllgatherV(*w.nccl_comm, common::EraseType(dh::ToSpan(inputs[comm->Rank()])), auto rc = w.coll->AllgatherV(*w.nccl_comm, common::EraseType(dh::ToSpan(inputs[comm->Rank()])),
common::Span{sizes.data(), sizes.size()}, recv_segments, common::Span{sizes.data(), sizes.size()}, recv_segments,
common::EraseType(dh::ToSpan(r)), AllgatherVAlgo::kRing); common::EraseType(dh::ToSpan(r)), AllgatherVAlgo::kRing);
ASSERT_TRUE(rc.OK()); SafeColl(rc);
ASSERT_EQ(r[0], 1); ASSERT_EQ(r[0], 1);
for (std::size_t i = 1; i < r.size(); ++i) { for (std::size_t i = 1; i < r.size(); ++i) {

View File

@ -28,8 +28,8 @@ TEST(FederatedTrackerTest, Basic) {
ASSERT_EQ(get<String const>(args["dmlc_tracker_uri"]), host); ASSERT_EQ(get<String const>(args["dmlc_tracker_uri"]), host);
rc = tracker->Shutdown(); rc = tracker->Shutdown();
ASSERT_TRUE(rc.OK()); SafeColl(rc);
ASSERT_TRUE(fut.get().OK()); SafeColl(fut.get());
ASSERT_FALSE(tracker->Ready()); ASSERT_FALSE(tracker->Ready());
} }
} // namespace xgboost::collective } // namespace xgboost::collective

View File

@ -38,7 +38,7 @@ void TestFederatedImpl(std::int32_t n_workers, WorkerFn&& fn) {
std::vector<std::thread> workers; std::vector<std::thread> workers;
using namespace std::chrono_literals; using namespace std::chrono_literals;
auto rc = tracker.WaitUntilReady(); auto rc = tracker.WaitUntilReady();
ASSERT_TRUE(rc.OK()) << rc.Report(); SafeColl(rc);
std::int32_t port = tracker.Port(); std::int32_t port = tracker.Port();
for (std::int32_t i = 0; i < n_workers; ++i) { for (std::int32_t i = 0; i < n_workers; ++i) {
@ -50,8 +50,8 @@ void TestFederatedImpl(std::int32_t n_workers, WorkerFn&& fn) {
} }
rc = tracker.Shutdown(); rc = tracker.Shutdown();
ASSERT_TRUE(rc.OK()) << rc.Report(); SafeColl(rc);
ASSERT_TRUE(fut.get().OK()); SafeColl(fut.get());
} }
template <typename WorkerFn> template <typename WorkerFn>