/** * Copyright 2023, XGBoost Contributors */ #include #include // for Json #include "../../../../src/collective/comm_group.h" #include "../../helpers.h" #include "test_worker.h" namespace xgboost::collective { TEST(CommGroup, FederatedGPU) { std::int32_t n_workers = common::AllVisibleGPUs(); TestFederatedGroup(n_workers, [&](std::shared_ptr comm_group, std::int32_t r) { Context ctx = MakeCUDACtx(0); auto const& comm = comm_group->Ctx(&ctx, DeviceOrd::CUDA(0)); ASSERT_EQ(comm_group->Rank(), r); ASSERT_EQ(comm.TaskID(), std::to_string(r)); ASSERT_EQ(comm.Retry(), 2); }); } } // namespace xgboost::collective