[coll] Add federated coll. (#9738)

- Define a new data type; the proto file is copied for now.
- Merge client and communicator into `FederatedColl`.
- Define CUDA variant.
- Migrate tests for CPU, add tests for CUDA.
This commit is contained in:
Jiaming Yuan
2023-11-01 04:06:46 +08:00
committed by GitHub
parent 6b98305db4
commit bc995a4865
24 changed files with 826 additions and 48 deletions

View File

@@ -24,7 +24,7 @@ class Coll : public std::enable_shared_from_this<Coll> {
Coll() = default;
virtual ~Coll() noexcept(false) {} // NOLINT
Coll* MakeCUDAVar();
virtual Coll* MakeCUDAVar();
/**
* @brief Allreduce

View File

@@ -9,7 +9,8 @@
#include <string> // for string
#include <utility> // for move, forward
#include "allgather.h"
#include "../common/common.h" // for AssertGPUSupport
#include "allgather.h" // for RingAllgather
#include "protocol.h" // for kMagic
#include "xgboost/base.h" // for XGBOOST_STRICT_R_MODE
#include "xgboost/collective/socket.h" // for TCPSocket
@@ -48,6 +49,14 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
this->Rank(), this->World());
}
#if !defined(XGBOOST_USE_NCCL)
/**
 * @brief Fallback definition of MakeCUDAVar, compiled only when XGBOOST_USE_NCCL is not
 *        defined (see the enclosing preprocessor guard).
 *
 * Both parameters are intentionally unnamed and unused. The function runs the GPU
 * support check first and the NCCL support check second, so the order determines
 * which diagnostic is reported.
 *
 * @return nullptr. NOTE(review): presumably never reached in practice, assuming
 *         LOG(FATAL) inside the assertion helpers does not return -- confirm.
 */
Comm* Comm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
common::AssertGPUSupport();
common::AssertNCCLSupport();
return nullptr;
}
#endif // !defined(XGBOOST_USE_NCCL)
[[nodiscard]] Result ConnectWorkers(Comm const& comm, TCPSocket* listener, std::int32_t lport,
proto::PeerInfo ninfo, std::chrono::seconds timeout,
std::int32_t retry,

View File

@@ -20,14 +20,14 @@
namespace xgboost::collective {
namespace {
Result GetUniqueId(Comm const& comm, ncclUniqueId* pid) {
Result GetUniqueId(Comm const& comm, std::shared_ptr<Coll> coll, ncclUniqueId* pid) {
static const int kRootRank = 0;
ncclUniqueId id;
if (comm.Rank() == kRootRank) {
dh::safe_nccl(ncclGetUniqueId(&id));
}
auto rc = Broadcast(comm, common::Span{reinterpret_cast<std::int8_t*>(&id), sizeof(ncclUniqueId)},
kRootRank);
auto rc = coll->Broadcast(
comm, common::Span{reinterpret_cast<std::int8_t*>(&id), sizeof(ncclUniqueId)}, kRootRank);
if (!rc.OK()) {
return rc;
}
@@ -53,7 +53,7 @@ static std::string PrintUUID(xgboost::common::Span<std::uint64_t, kUuidLength> c
}
} // namespace
Comm* Comm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) {
Comm* Comm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const {
// Builds an NCCLComm from the context, this communicator (passed as the root), and the
// supplied collective implementation. The object is allocated with a raw `new` and
// handed back as a base-class pointer.
// NOTE(review): the caller presumably takes ownership of the returned pointer -- confirm.
return new NCCLComm{ctx, *this, pimpl};
}
@@ -76,6 +76,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
GetCudaUUID(s_this_uuid, ctx->Device());
auto rc = pimpl->Allgather(root, common::EraseType(s_uuid), s_this_uuid.size_bytes());
CHECK(rc.OK()) << rc.Report();
std::vector<xgboost::common::Span<std::uint64_t, kUuidLength>> converted(root.World());
@@ -93,7 +94,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
<< "Multiple processes within communication group running on same CUDA "
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
rc = GetUniqueId(root, &nccl_unique_id_);
rc = GetUniqueId(root, pimpl, &nccl_unique_id_);
CHECK(rc.OK()) << rc.Report();
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, root.World(), nccl_unique_id_, root.Rank()));

View File

@@ -40,7 +40,7 @@ class Coll;
/**
* @brief Base communicator storing info about the tracker and other communicators.
*/
class Comm {
class Comm : public std::enable_shared_from_this<Comm> {
protected:
std::int32_t world_{-1};
std::int32_t rank_{0};
@@ -87,7 +87,7 @@ class Comm {
[[nodiscard]] virtual Result SignalError(Result const&) { return Success(); }
Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl);
virtual Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const;
};
class RabitComm : public Comm {

View File

@@ -163,6 +163,12 @@ inline void AssertGPUSupport() {
#endif // XGBOOST_USE_CUDA
}
/**
 * @brief Check that this build was compiled with NCCL support.
 *
 * When XGBOOST_USE_NCCL is not defined, a FATAL log message is emitted; when it is
 * defined, the function body is empty.
 */
inline void AssertNCCLSupport() {
#ifndef XGBOOST_USE_NCCL
  LOG(FATAL) << "XGBoost version not compiled with NCCL support.";
#endif  // XGBOOST_USE_NCCL
}
inline void AssertOneAPISupport() {
#ifndef XGBOOST_USE_ONEAPI
LOG(FATAL) << "XGBoost version not compiled with OneAPI support.";