xgboost/plugin/federated/federated_comm.cu
Jiaming Yuan 6c0a190f6d
[coll] Add comm group. (#9759)
- Implement `CommGroup` for double dispatching.
- Small cleanup to tracker for handling abort.
2023-11-07 11:12:31 +08:00

23 lines
732 B
Plaintext

/**
* Copyright 2023, XGBoost Contributors
*/
#include <memory>   // for shared_ptr
#include <utility>  // for move

#include "../../src/common/cuda_context.cuh"
#include "federated_comm.cuh"
#include "xgboost/context.h"  // for Context
namespace xgboost::collective {
CUDAFederatedComm::CUDAFederatedComm(Context const* ctx, std::shared_ptr<FederatedComm const> impl)
: FederatedComm{impl}, stream_{ctx->CUDACtx()->Stream()} {
CHECK(impl);
CHECK(ctx->IsCUDA());
dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
}
/**
 * @brief Build the CUDA variant of this federated communicator.
 *
 * A shared pointer to `*this` (obtained through `shared_from_this`, so `*this` must already
 * be owned by a std::shared_ptr) is handed to the CUDAFederatedComm constructor together
 * with `ctx`. The collective argument is not used.
 *
 * @param ctx Context forwarded to the CUDAFederatedComm constructor.
 *
 * @return A heap-allocated (`new`) CUDAFederatedComm; the caller is responsible for its
 *         lifetime.
 */
Comm* FederatedComm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll>) const {
  std::shared_ptr<FederatedComm const> self =
      std::dynamic_pointer_cast<FederatedComm const>(this->shared_from_this());
  Comm* cuda_comm = new CUDAFederatedComm{ctx, self};
  return cuda_comm;
}
} // namespace xgboost::collective