[coll] Add federated coll. (#9738)
- Define a new data type; the proto file is copied for now.
- Merge the client and communicator into `FederatedColl`.
- Define the CUDA variant.
- Migrate the tests for CPU and add tests for CUDA.
This commit is contained in:
@@ -24,7 +24,7 @@ class Coll : public std::enable_shared_from_this<Coll> {
|
||||
Coll() = default;
|
||||
virtual ~Coll() noexcept(false) {} // NOLINT
|
||||
|
||||
Coll* MakeCUDAVar();
|
||||
virtual Coll* MakeCUDAVar();
|
||||
|
||||
/**
|
||||
* @brief Allreduce
|
||||
|
||||
@@ -9,7 +9,8 @@
|
||||
#include <string> // for string
|
||||
#include <utility> // for move, forward
|
||||
|
||||
#include "allgather.h"
|
||||
#include "../common/common.h" // for AssertGPUSupport
|
||||
#include "allgather.h" // for RingAllgather
|
||||
#include "protocol.h" // for kMagic
|
||||
#include "xgboost/base.h" // for XGBOOST_STRICT_R_MODE
|
||||
#include "xgboost/collective/socket.h" // for TCPSocket
|
||||
@@ -48,6 +49,14 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
|
||||
this->Rank(), this->World());
|
||||
}
|
||||
|
||||
#if !defined(XGBOOST_USE_NCCL)
|
||||
/**
 * @brief Definition of Comm::MakeCUDAVar that is compiled when XGBOOST_USE_NCCL is
 *        not defined (see the enclosing preprocessor guard).
 *
 * Both parameters are unnamed and unused here. The function runs the GPU and NCCL
 * support assertions and then returns nullptr.
 */
Comm* Comm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
|
||||
common::AssertGPUSupport();
|
||||
// In a build without XGBOOST_USE_NCCL, AssertNCCLSupport emits LOG(FATAL).
common::AssertNCCLSupport();
|
||||
// NOTE(review): presumably unreachable because LOG(FATAL) does not return; the
// statement would then exist only to satisfy the return type -- confirm.
return nullptr;
|
||||
}
|
||||
#endif // !defined(XGBOOST_USE_NCCL)
|
||||
|
||||
[[nodiscard]] Result ConnectWorkers(Comm const& comm, TCPSocket* listener, std::int32_t lport,
|
||||
proto::PeerInfo ninfo, std::chrono::seconds timeout,
|
||||
std::int32_t retry,
|
||||
|
||||
@@ -20,14 +20,14 @@
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace {
|
||||
Result GetUniqueId(Comm const& comm, ncclUniqueId* pid) {
|
||||
Result GetUniqueId(Comm const& comm, std::shared_ptr<Coll> coll, ncclUniqueId* pid) {
|
||||
static const int kRootRank = 0;
|
||||
ncclUniqueId id;
|
||||
if (comm.Rank() == kRootRank) {
|
||||
dh::safe_nccl(ncclGetUniqueId(&id));
|
||||
}
|
||||
auto rc = Broadcast(comm, common::Span{reinterpret_cast<std::int8_t*>(&id), sizeof(ncclUniqueId)},
|
||||
kRootRank);
|
||||
auto rc = coll->Broadcast(
|
||||
comm, common::Span{reinterpret_cast<std::int8_t*>(&id), sizeof(ncclUniqueId)}, kRootRank);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
@@ -53,7 +53,7 @@ static std::string PrintUUID(xgboost::common::Span<std::uint64_t, kUuidLength> c
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Comm* Comm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) {
|
||||
Comm* Comm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const {
|
||||
return new NCCLComm{ctx, *this, pimpl};
|
||||
}
|
||||
|
||||
@@ -76,6 +76,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
|
||||
GetCudaUUID(s_this_uuid, ctx->Device());
|
||||
|
||||
auto rc = pimpl->Allgather(root, common::EraseType(s_uuid), s_this_uuid.size_bytes());
|
||||
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
|
||||
std::vector<xgboost::common::Span<std::uint64_t, kUuidLength>> converted(root.World());
|
||||
@@ -93,7 +94,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
|
||||
<< "Multiple processes within communication group running on same CUDA "
|
||||
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
|
||||
|
||||
rc = GetUniqueId(root, &nccl_unique_id_);
|
||||
rc = GetUniqueId(root, pimpl, &nccl_unique_id_);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, root.World(), nccl_unique_id_, root.Rank()));
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ class Coll;
|
||||
/**
|
||||
* @brief Base communicator storing info about the tracker and other communicators.
|
||||
*/
|
||||
class Comm {
|
||||
class Comm : public std::enable_shared_from_this<Comm> {
|
||||
protected:
|
||||
std::int32_t world_{-1};
|
||||
std::int32_t rank_{0};
|
||||
@@ -87,7 +87,7 @@ class Comm {
|
||||
|
||||
[[nodiscard]] virtual Result SignalError(Result const&) { return Success(); }
|
||||
|
||||
Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl);
|
||||
virtual Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const;
|
||||
};
|
||||
|
||||
class RabitComm : public Comm {
|
||||
|
||||
@@ -163,6 +163,12 @@ inline void AssertGPUSupport() {
|
||||
#endif // XGBOOST_USE_CUDA
|
||||
}
|
||||
|
||||
/**
 * @brief Report a fatal error when XGBoost was built without NCCL.
 *
 * When XGBOOST_USE_NCCL is not defined this logs a FATAL message; when it is
 * defined the body is empty and the call does nothing.
 */
inline void AssertNCCLSupport() {
|
||||
#if !defined(XGBOOST_USE_NCCL)
|
||||
LOG(FATAL) << "XGBoost version not compiled with NCCL support.";
|
||||
#endif // !defined(XGBOOST_USE_NCCL)
|
||||
}
|
||||
|
||||
inline void AssertOneAPISupport() {
|
||||
#ifndef XGBOOST_USE_ONEAPI
|
||||
LOG(FATAL) << "XGBoost version not compiled with OneAPI support.";
|
||||
|
||||
Reference in New Issue
Block a user