diff --git a/plugin/federated/federated_tracker.cc b/plugin/federated/federated_tracker.cc index 95c0824d9..387b4a368 100644 --- a/plugin/federated/federated_tracker.cc +++ b/plugin/federated/federated_tracker.cc @@ -115,7 +115,7 @@ FederatedTracker::~FederatedTracker() = default; Result FederatedTracker::Shutdown() { auto rc = this->WaitUntilReady(); - CHECK(rc.OK()) << rc.Report(); + SafeColl(rc); try { server_->Shutdown(); @@ -132,7 +132,7 @@ Result FederatedTracker::Shutdown() { std::string host; rc = GetHostAddress(&host); - CHECK(rc.OK()); + SafeColl(rc); Json args{Object{}}; args["dmlc_tracker_uri"] = String{host}; args["dmlc_tracker_port"] = this->Port(); diff --git a/tests/cpp/collective/test_allgather.cc b/tests/cpp/collective/test_allgather.cc index 61e34cb57..7764a2adc 100644 --- a/tests/cpp/collective/test_allgather.cc +++ b/tests/cpp/collective/test_allgather.cc @@ -117,7 +117,7 @@ class Worker : public WorkerForTest { common::Span{sizes.data(), sizes.size()}, common::Span{recv_segments.data(), recv_segments.size()}, common::EraseType(s_recv), AllgatherVAlgo::kBcast); - ASSERT_TRUE(rc.OK()); + SafeColl(rc); CheckV(s_recv); // Test inplace @@ -130,7 +130,7 @@ class Worker : public WorkerForTest { common::Span{sizes.data(), sizes.size()}, common::Span{recv_segments.data(), recv_segments.size()}, common::EraseType(s_recv), algo); - ASSERT_TRUE(rc.OK()); + SafeColl(rc); CheckV(s_recv); }; diff --git a/tests/cpp/collective/test_allreduce.cc b/tests/cpp/collective/test_allreduce.cc index 1ce2f35fd..6af659a3f 100644 --- a/tests/cpp/collective/test_allreduce.cc +++ b/tests/cpp/collective/test_allreduce.cc @@ -24,7 +24,7 @@ class AllreduceWorker : public WorkerForTest { rhs[i] += lhs[i]; } }); - ASSERT_TRUE(rc.OK()); + SafeColl(rc); ASSERT_EQ(std::accumulate(data.cbegin(), data.cend(), 0.0), 0.0); } { @@ -34,7 +34,7 @@ class AllreduceWorker : public WorkerForTest { rhs[i] += lhs[i]; } }); - ASSERT_TRUE(rc.OK()); + SafeColl(rc); ASSERT_EQ(data[0], static_cast(comm_.World())); } } @@ -49,7 +49,7 @@ class AllreduceWorker : public WorkerForTest { rhs[i] += lhs[i]; } }); - ASSERT_TRUE(rc.OK()); + SafeColl(rc); for (auto v : data) { ASSERT_EQ(v, comm_.World()); } @@ -62,7 +62,7 @@ class AllreduceWorker : public WorkerForTest { rhs[i] += lhs[i]; } }); - ASSERT_TRUE(rc.OK()); + SafeColl(rc); for (std::size_t i = 0; i < data.size(); ++i) { auto v = data[i]; ASSERT_EQ(v, 1.5 * static_cast(comm_.World())) << i; diff --git a/tests/cpp/collective/test_coll_c_api.cc b/tests/cpp/collective/test_coll_c_api.cc index c7229ff77..d73191361 100644 --- a/tests/cpp/collective/test_coll_c_api.cc +++ b/tests/cpp/collective/test_coll_c_api.cc @@ -41,7 +41,7 @@ TEST_F(TrackerAPITest, CAPI) { auto args = Json::Load(StringView{cargs}); std::string host; - ASSERT_TRUE(GetHostAddress(&host).OK()); + SafeColl(GetHostAddress(&host)); ASSERT_EQ(host, get(args["dmlc_tracker_uri"])); auto port = get(args["dmlc_tracker_port"]); ASSERT_NE(port, 0); diff --git a/tests/cpp/collective/test_comm.cc b/tests/cpp/collective/test_comm.cc index c1eb06465..dc3351b2b 100644 --- a/tests/cpp/collective/test_comm.cc +++ b/tests/cpp/collective/test_comm.cc @@ -47,6 +47,6 @@ TEST_F(CommTest, Channel) { w.join(); } - ASSERT_TRUE(fut.get().OK()); + SafeColl(fut.get()); } } // namespace xgboost::collective diff --git a/tests/cpp/collective/test_socket.cc b/tests/cpp/collective/test_socket.cc index ea57da9b4..8e455d100 100644 --- a/tests/cpp/collective/test_socket.cc +++ b/tests/cpp/collective/test_socket.cc @@ -70,7 +70,7 @@ TEST_F(SocketTest, Bind) { auto sock = TCPSocket::Create(domain); std::int32_t port{0}; auto rc = sock.Bind(any, &port); - ASSERT_TRUE(rc.OK()); + SafeColl(rc); ASSERT_NE(port, 0); }; diff --git a/tests/cpp/collective/test_tracker.cc b/tests/cpp/collective/test_tracker.cc index e31e26628..e44760597 100644 --- a/tests/cpp/collective/test_tracker.cc +++ b/tests/cpp/collective/test_tracker.cc @@ -59,7 +59,7 @@ TEST_F(TrackerTest, Print) { std::vector workers; auto rc = tracker.WaitUntilReady(); - ASSERT_TRUE(rc.OK()); + SafeColl(rc); std::int32_t port = tracker.Port(); @@ -74,7 +74,7 @@ TEST_F(TrackerTest, Print) { w.join(); } - ASSERT_TRUE(fut.get().OK()); + SafeColl(fut.get()); } TEST_F(TrackerTest, GetHostAddress) { ASSERT_TRUE(host.find("127.") == std::string::npos); } @@ -88,7 +88,7 @@ TEST_F(TrackerTest, AfterShutdown) { std::vector workers; auto rc = tracker.WaitUntilReady(); - ASSERT_TRUE(rc.OK()); + SafeColl(rc); std::int32_t port = tracker.Port(); @@ -101,7 +101,7 @@ TEST_F(TrackerTest, AfterShutdown) { w.join(); } - ASSERT_TRUE(fut.get().OK()); + SafeColl(fut.get()); // Launch workers again, they should fail. workers.clear(); diff --git a/tests/cpp/collective/test_worker.h b/tests/cpp/collective/test_worker.h index f1889200b..230c3796d 100644 --- a/tests/cpp/collective/test_worker.h +++ b/tests/cpp/collective/test_worker.h @@ -50,8 +50,8 @@ class WorkerForTest { for (std::int32_t i = 0; i < comm_.World(); ++i) { if (i != comm_.Rank()) { ASSERT_TRUE(comm_.Chan(i)->Socket()->NonBlocking()); - ASSERT_TRUE(comm_.Chan(i)->Socket()->SetBufSize(n_bytes).OK()); - ASSERT_TRUE(comm_.Chan(i)->Socket()->SetNoDelay().OK()); + SafeColl(comm_.Chan(i)->Socket()->SetBufSize(n_bytes)); + SafeColl(comm_.Chan(i)->Socket()->SetNoDelay()); } } } @@ -131,7 +131,7 @@ void TestDistributed(std::int32_t n_workers, WorkerFn worker_fn) { t.join(); } - ASSERT_TRUE(fut.get().OK()); + SafeColl(fut.get()); } inline auto MakeDistributedTestConfig(std::string host, std::int32_t port, @@ -182,7 +182,7 @@ void TestDistributedGlobal(std::int32_t n_workers, WorkerFn worker_fn, bool need t.join(); } - ASSERT_TRUE(fut.get().OK()); + SafeColl(fut.get()); system::SocketFinalize(); } diff --git a/tests/cpp/plugin/federated/test_federated_coll.cc b/tests/cpp/plugin/federated/test_federated_coll.cc index 6b7000ef9..6c5c74f4f 100644 --- a/tests/cpp/plugin/federated/test_federated_coll.cc +++ b/tests/cpp/plugin/federated/test_federated_coll.cc @@ -28,7 +28,7 @@ TEST_F(FederatedCollTest, Allreduce) { auto coll = std::make_shared(); auto rc = coll->Allreduce(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}), ArrayInterfaceHandler::kI4, Op::kSum); - ASSERT_TRUE(rc.OK()); + SafeColl(rc); for (auto i = 0; i < 5; i++) { ASSERT_EQ(buffer[i], expected[i]); } @@ -49,7 +49,7 @@ TEST_F(FederatedCollTest, Broadcast) { rc = coll.Broadcast(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}), 0); ASSERT_EQ(buffer, "hello"); } - ASSERT_TRUE(rc.OK()); + SafeColl(rc); }); } @@ -61,7 +61,7 @@ TEST_F(FederatedCollTest, Allgather) { std::vector buffer(n_workers, 0); buffer[comm->Rank()] = comm->Rank(); auto rc = coll.Allgather(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()})); - ASSERT_TRUE(rc.OK()); + SafeColl(rc); for (auto i = 0; i < n_workers; i++) { ASSERT_EQ(buffer[i], i); } @@ -87,7 +87,7 @@ TEST_F(FederatedCollTest, AllgatherV) { common::EraseType(common::Span{r.data(), r.size()}), AllgatherVAlgo::kRing); EXPECT_EQ(r, "Federated Learning!!!"); - ASSERT_TRUE(rc.OK()); + SafeColl(rc); }); } } // namespace xgboost::collective diff --git a/tests/cpp/plugin/federated/test_federated_coll.cu b/tests/cpp/plugin/federated/test_federated_coll.cu index 008952a4f..f3b906613 100644 --- a/tests/cpp/plugin/federated/test_federated_coll.cu +++ b/tests/cpp/plugin/federated/test_federated_coll.cu @@ -41,7 +41,7 @@ void TestAllreduce(std::shared_ptr comm, std::int32_t rank, std:: auto rc = w.coll->Allreduce(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer)), ArrayInterfaceHandler::kI4, Op::kSum); - ASSERT_TRUE(rc.OK()); + SafeColl(rc); for (auto i = 0; i < 5; i++) { ASSERT_EQ(buffer[i], expected[i]); } @@ -63,7 +63,7 @@ void TestBroadcast(std::shared_ptr comm, std::int32_t rank) { rc = w.coll->Broadcast(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer)), 0); ASSERT_EQ(buffer, expect); } - ASSERT_TRUE(rc.OK()); + SafeColl(rc); } void TestAllgather(std::shared_ptr comm, std::int32_t rank, std::int32_t n_workers) { @@ -72,7 +72,7 @@ void TestAllgather(std::shared_ptr comm, std::int32_t rank, std:: dh::device_vector buffer(n_workers, 0); buffer[comm->Rank()] = comm->Rank(); auto rc = w.coll->Allgather(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer))); - ASSERT_TRUE(rc.OK()); + SafeColl(rc); for (auto i = 0; i < n_workers; i++) { ASSERT_EQ(buffer[i], i); } @@ -92,7 +92,7 @@ void TestAllgatherV(std::shared_ptr comm, std::int32_t rank) { auto rc = w.coll->AllgatherV(*w.nccl_comm, common::EraseType(dh::ToSpan(inputs[comm->Rank()])), common::Span{sizes.data(), sizes.size()}, recv_segments, common::EraseType(dh::ToSpan(r)), AllgatherVAlgo::kRing); - ASSERT_TRUE(rc.OK()); + SafeColl(rc); ASSERT_EQ(r[0], 1); for (std::size_t i = 1; i < r.size(); ++i) { diff --git a/tests/cpp/plugin/federated/test_federated_tracker.cc b/tests/cpp/plugin/federated/test_federated_tracker.cc index aa979ff15..19bfd798a 100644 --- a/tests/cpp/plugin/federated/test_federated_tracker.cc +++ b/tests/cpp/plugin/federated/test_federated_tracker.cc @@ -28,8 +28,8 @@ TEST(FederatedTrackerTest, Basic) { ASSERT_EQ(get(args["dmlc_tracker_uri"]), host); rc = tracker->Shutdown(); - ASSERT_TRUE(rc.OK()); - ASSERT_TRUE(fut.get().OK()); + SafeColl(rc); + SafeColl(fut.get()); ASSERT_FALSE(tracker->Ready()); } } // namespace xgboost::collective diff --git a/tests/cpp/plugin/federated/test_worker.h b/tests/cpp/plugin/federated/test_worker.h index 8ec76237d..5f53965da 100644 --- a/tests/cpp/plugin/federated/test_worker.h +++ b/tests/cpp/plugin/federated/test_worker.h @@ -38,7 +38,7 @@ void TestFederatedImpl(std::int32_t n_workers, WorkerFn&& fn) { std::vector workers; using namespace std::chrono_literals; auto rc = tracker.WaitUntilReady(); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); std::int32_t port = tracker.Port(); for (std::int32_t i = 0; i < n_workers; ++i) { @@ -50,8 +50,8 @@ void TestFederatedImpl(std::int32_t n_workers, WorkerFn&& fn) { } rc = tracker.Shutdown(); - ASSERT_TRUE(rc.OK()) << rc.Report(); - ASSERT_TRUE(fut.get().OK()); + SafeColl(rc); + SafeColl(fut.get()); } template