Consistently report error in tests. (#10453)
This commit is contained in:
@@ -28,7 +28,7 @@ TEST_F(FederatedCollTest, Allreduce) {
|
||||
auto coll = std::make_shared<FederatedColl>();
|
||||
auto rc = coll->Allreduce(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}),
|
||||
ArrayInterfaceHandler::kI4, Op::kSum);
|
||||
ASSERT_TRUE(rc.OK());
|
||||
SafeColl(rc);
|
||||
for (auto i = 0; i < 5; i++) {
|
||||
ASSERT_EQ(buffer[i], expected[i]);
|
||||
}
|
||||
@@ -49,7 +49,7 @@ TEST_F(FederatedCollTest, Broadcast) {
|
||||
rc = coll.Broadcast(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}), 0);
|
||||
ASSERT_EQ(buffer, "hello");
|
||||
}
|
||||
ASSERT_TRUE(rc.OK());
|
||||
SafeColl(rc);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ TEST_F(FederatedCollTest, Allgather) {
|
||||
std::vector<std::int32_t> buffer(n_workers, 0);
|
||||
buffer[comm->Rank()] = comm->Rank();
|
||||
auto rc = coll.Allgather(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}));
|
||||
ASSERT_TRUE(rc.OK());
|
||||
SafeColl(rc);
|
||||
for (auto i = 0; i < n_workers; i++) {
|
||||
ASSERT_EQ(buffer[i], i);
|
||||
}
|
||||
@@ -87,7 +87,7 @@ TEST_F(FederatedCollTest, AllgatherV) {
|
||||
common::EraseType(common::Span{r.data(), r.size()}), AllgatherVAlgo::kRing);
|
||||
|
||||
EXPECT_EQ(r, "Federated Learning!!!");
|
||||
ASSERT_TRUE(rc.OK());
|
||||
SafeColl(rc);
|
||||
});
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -41,7 +41,7 @@ void TestAllreduce(std::shared_ptr<FederatedComm> comm, std::int32_t rank, std::
|
||||
|
||||
auto rc = w.coll->Allreduce(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer)),
|
||||
ArrayInterfaceHandler::kI4, Op::kSum);
|
||||
ASSERT_TRUE(rc.OK());
|
||||
SafeColl(rc);
|
||||
for (auto i = 0; i < 5; i++) {
|
||||
ASSERT_EQ(buffer[i], expected[i]);
|
||||
}
|
||||
@@ -63,7 +63,7 @@ void TestBroadcast(std::shared_ptr<FederatedComm> comm, std::int32_t rank) {
|
||||
rc = w.coll->Broadcast(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer)), 0);
|
||||
ASSERT_EQ(buffer, expect);
|
||||
}
|
||||
ASSERT_TRUE(rc.OK());
|
||||
SafeColl(rc);
|
||||
}
|
||||
|
||||
void TestAllgather(std::shared_ptr<FederatedComm> comm, std::int32_t rank, std::int32_t n_workers) {
|
||||
@@ -72,7 +72,7 @@ void TestAllgather(std::shared_ptr<FederatedComm> comm, std::int32_t rank, std::
|
||||
dh::device_vector<std::int32_t> buffer(n_workers, 0);
|
||||
buffer[comm->Rank()] = comm->Rank();
|
||||
auto rc = w.coll->Allgather(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer)));
|
||||
ASSERT_TRUE(rc.OK());
|
||||
SafeColl(rc);
|
||||
for (auto i = 0; i < n_workers; i++) {
|
||||
ASSERT_EQ(buffer[i], i);
|
||||
}
|
||||
@@ -92,7 +92,7 @@ void TestAllgatherV(std::shared_ptr<FederatedComm> comm, std::int32_t rank) {
|
||||
auto rc = w.coll->AllgatherV(*w.nccl_comm, common::EraseType(dh::ToSpan(inputs[comm->Rank()])),
|
||||
common::Span{sizes.data(), sizes.size()}, recv_segments,
|
||||
common::EraseType(dh::ToSpan(r)), AllgatherVAlgo::kRing);
|
||||
ASSERT_TRUE(rc.OK());
|
||||
SafeColl(rc);
|
||||
|
||||
ASSERT_EQ(r[0], 1);
|
||||
for (std::size_t i = 1; i < r.size(); ++i) {
|
||||
|
||||
@@ -28,8 +28,8 @@ TEST(FederatedTrackerTest, Basic) {
|
||||
ASSERT_EQ(get<String const>(args["dmlc_tracker_uri"]), host);
|
||||
|
||||
rc = tracker->Shutdown();
|
||||
ASSERT_TRUE(rc.OK());
|
||||
ASSERT_TRUE(fut.get().OK());
|
||||
SafeColl(rc);
|
||||
SafeColl(fut.get());
|
||||
ASSERT_FALSE(tracker->Ready());
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -38,7 +38,7 @@ void TestFederatedImpl(std::int32_t n_workers, WorkerFn&& fn) {
|
||||
std::vector<std::thread> workers;
|
||||
using namespace std::chrono_literals;
|
||||
auto rc = tracker.WaitUntilReady();
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
SafeColl(rc);
|
||||
std::int32_t port = tracker.Port();
|
||||
|
||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||
@@ -50,8 +50,8 @@ void TestFederatedImpl(std::int32_t n_workers, WorkerFn&& fn) {
|
||||
}
|
||||
|
||||
rc = tracker.Shutdown();
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
ASSERT_TRUE(fut.get().OK());
|
||||
SafeColl(rc);
|
||||
SafeColl(fut.get());
|
||||
}
|
||||
|
||||
template <typename WorkerFn>
|
||||
|
||||
Reference in New Issue
Block a user