[coll] Add comm group. (#9759)
- Implement `CommGroup` for double dispatching. - Small cleanup to tracker for handling abort.
This commit is contained in:
22
tests/cpp/plugin/federated/test_federated_comm_group.cu
Normal file
22
tests/cpp/plugin/federated/test_federated_comm_group.cu
Normal file
@@ -0,0 +1,22 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/json.h> // for Json
|
||||
|
||||
#include "../../../../src/collective/comm_group.h"
|
||||
#include "../../helpers.h"
|
||||
#include "test_worker.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
TEST(CommGroup, FederatedGPU) {
|
||||
std::int32_t n_workers = common::AllVisibleGPUs();
|
||||
TestFederatedGroup(n_workers, [&](std::shared_ptr<CommGroup> comm_group, std::int32_t r) {
|
||||
Context ctx = MakeCUDACtx(0);
|
||||
auto const& comm = comm_group->Ctx(&ctx, DeviceOrd::CUDA(0));
|
||||
ASSERT_EQ(comm_group->Rank(), r);
|
||||
ASSERT_EQ(comm.TaskID(), std::to_string(r));
|
||||
ASSERT_EQ(comm.Retry(), 2);
|
||||
});
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
Reference in New Issue
Block a user