/**
 * Copyright 2023-2024, XGBoost contributors
 */
#pragma once

#include <cstdint>  // for int8_t, int32_t, int64_t
#include <memory>   // for shared_ptr
#include <utility>  // for move

#include "../../src/collective/comm.h"  // for Comm, Coll
#include "federated_coll.h"             // for FederatedColl
#include "xgboost/collective/result.h"  // for Result
#include "xgboost/span.h"               // for Span

namespace xgboost::collective {
/**
 * @brief CUDA variant of the federated collective.
 *
 * Overrides the four collective entry points of @ref Coll. Only the declarations live in
 * this header; the definitions are in a separate translation unit. A @ref FederatedColl
 * instance is held through a shared pointer (pimpl), presumably so the CUDA
 * implementation can delegate the actual federated communication to it.
 * NOTE(review): confirm the delegation pattern against the .cu definitions.
 */
class CUDAFederatedColl : public Coll {
  /** @brief Wrapped federated collective implementation (shared ownership). */
  std::shared_ptr<FederatedColl> p_impl_;

 public:
  /**
   * @brief Wrap an existing federated collective.
   *
   * @param pimpl Federated collective implementation. It is moved into the member and is
   *              not null-checked here, so callers are expected to pass a valid pointer.
   */
  explicit CUDAFederatedColl(std::shared_ptr<FederatedColl> pimpl) : p_impl_{std::move(pimpl)} {}

  /**
   * @brief Allreduce override.
   *
   * @param comm Communicator to run the collective on.
   * @param data Byte view of the buffer the reduction operates on.
   * @param type Element type of the values stored in @p data.
   * @param op   Reduction operation.
   */
  [[nodiscard]] Result Allreduce(Comm const &comm, common::Span<std::int8_t> data,
                                 ArrayInterfaceHandler::Type type, Op op) override;

  /**
   * @brief Broadcast override.
   *
   * @param comm Communicator to run the collective on.
   * @param data Byte view of the buffer being broadcast.
   * @param root Rank that owns the source data.
   */
  [[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data,
                                 std::int32_t root) override;

  /**
   * @brief Allgather override.
   *
   * The communicator parameter is unnamed in this declaration.
   *
   * @param data Byte view of the buffer the allgather operates on.
   */
  [[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data) override;

  /**
   * @brief Variable-length allgather override.
   *
   * @param comm          Communicator to run the collective on.
   * @param data          Input bytes contributed by this worker.
   * @param sizes         Per-worker input sizes (as the name suggests; confirm in the
   *                      implementation).
   * @param recv_segments Output offsets of each worker's segment within @p recv.
   * @param recv          Output buffer receiving the gathered bytes.
   * @param algo          Algorithm selector for the variable-length allgather.
   */
  [[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data,
                                  common::Span<std::int64_t const> sizes,
                                  common::Span<std::int64_t> recv_segments,
                                  common::Span<std::int8_t> recv, AllgatherVAlgo algo) override;
};
}  // namespace xgboost::collective