/** * Copyright 2023, XGBoost Contributors */ #include // for shared_ptr #include "../../src/common/cuda_context.cuh" #include "federated_comm.cuh" #include "xgboost/context.h" // for Context namespace xgboost::collective { CUDAFederatedComm::CUDAFederatedComm(Context const* ctx, std::shared_ptr impl) : FederatedComm{impl}, stream_{ctx->CUDACtx()->Stream()} { CHECK(impl); CHECK(ctx->IsCUDA()); dh::safe_cuda(cudaSetDevice(ctx->Ordinal())); } Comm* FederatedComm::MakeCUDAVar(Context const* ctx, std::shared_ptr) const { return new CUDAFederatedComm{ ctx, std::dynamic_pointer_cast(this->shared_from_this())}; } } // namespace xgboost::collective