- Define a new data type, the proto file is copied for now. - Merge client and communicator into `FederatedColl`. - Define CUDA variant. - Migrate tests for CPU, add tests for CUDA.
21 lines
660 B
Plaintext
21 lines
660 B
Plaintext
/**
|
|
* Copyright 2023, XGBoost Contributors
|
|
*/
|
|
#include <memory> // for shared_ptr
|
|
|
|
#include "../../src/common/cuda_context.cuh"
|
|
#include "federated_comm.cuh"
|
|
#include "xgboost/context.h" // for Context
|
|
|
|
namespace xgboost::collective {
|
|
CUDAFederatedComm::CUDAFederatedComm(Context const* ctx, std::shared_ptr<FederatedComm const> impl)
|
|
: FederatedComm{impl}, stream_{ctx->CUDACtx()->Stream()} {
|
|
CHECK(impl);
|
|
}
|
|
|
|
Comm* FederatedComm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll>) const {
|
|
return new CUDAFederatedComm{
|
|
ctx, std::dynamic_pointer_cast<FederatedComm const>(this->shared_from_this())};
|
|
}
|
|
} // namespace xgboost::collective
|