- Define a new data type; the proto file is copied for now. - Merge the client and communicator into `FederatedColl`. - Define a CUDA variant. - Migrate the CPU tests and add tests for CUDA.
27 lines
1.3 KiB
Plaintext
27 lines
1.3 KiB
Plaintext
/**
|
|
* Copyright 2023, XGBoost contributors
|
|
*/
|
|
#include <cstdint>  // for int8_t, int32_t, int64_t
#include <memory>   // for shared_ptr
#include <utility>  // for move

#include "../../src/collective/comm.h"  // for Comm, Coll
#include "federated_coll.h"             // for FederatedColl
#include "xgboost/collective/result.h"  // for Result
#include "xgboost/span.h"               // for Span
|
|
|
|
namespace xgboost::collective {
|
|
class CUDAFederatedColl : public Coll {
|
|
std::shared_ptr<FederatedColl> p_impl_;
|
|
|
|
public:
|
|
explicit CUDAFederatedColl(std::shared_ptr<FederatedColl> pimpl) : p_impl_{std::move(pimpl)} {}
|
|
[[nodiscard]] Result Allreduce(Comm const &comm, common::Span<std::int8_t> data,
|
|
ArrayInterfaceHandler::Type type, Op op) override;
|
|
[[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data,
|
|
std::int32_t root) override;
|
|
[[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data,
|
|
std::int64_t size) override;
|
|
[[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data,
|
|
common::Span<std::int64_t const> sizes,
|
|
common::Span<std::int64_t> recv_segments,
|
|
common::Span<std::int8_t> recv, AllgatherVAlgo algo) override;
|
|
};
|
|
} // namespace xgboost::collective
|