xgboost/plugin/federated/federated_comm.cuh
Jiaming Yuan bc995a4865
[coll] Add federated coll. (#9738)
- Define a new data type, the proto file is copied for now.
- Merge client and communicator into `FederatedColl`.
- Define CUDA variant.
- Migrate tests for CPU, add tests for CUDA.
2023-11-01 04:06:46 +08:00

21 lines
612 B
Plaintext

/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <memory> // for shared_ptr
#include "../../src/common/device_helpers.cuh" // for CUDAStreamView
#include "federated_comm.h" // for FederatedComm
#include "xgboost/context.h" // for Context
namespace xgboost::collective {
/**
 * @brief CUDA variant of the federated communicator.
 *
 * Derives from FederatedComm and additionally holds a `dh::CUDAStreamView`, which
 * is exposed through Stream().
 */
class CUDAFederatedComm : public FederatedComm {
 public:
  /**
   * @brief Construct from a context and an existing federated communicator.
   *
   * The definition is not in this header.
   *
   * @param ctx  Context for this communicator. Presumably it supplies the device and
   *             stream that initialize the stored stream view; TODO confirm
   *             against the out-of-line definition.
   * @param impl Existing (const) FederatedComm. It is presumably used to initialize
   *             the FederatedComm base; verify against the definition.
   */
  explicit CUDAFederatedComm(Context const* ctx, std::shared_ptr<FederatedComm const> impl);

  /**
   * @brief Return a copy of the stored CUDA stream view.
   */
  [[nodiscard]] dh::CUDAStreamView Stream() const { return stream_; }

 private:
  // Stream view handed out by Stream(). It is set up by the out-of-line constructor.
  dh::CUDAStreamView stream_;
};
}  // namespace xgboost::collective