diff --git a/plugin/federated/federated_coll.cc b/plugin/federated/federated_coll.cc
index 980992d61..b62abdada 100644
--- a/plugin/federated/federated_coll.cc
+++ b/plugin/federated/federated_coll.cc
@@ -89,19 +89,15 @@ Coll *FederatedColl::MakeCUDAVar() {
 [[nodiscard]] Result FederatedColl::Broadcast(Comm const &comm, common::Span<std::int8_t> data,
                                               std::int32_t root) {
-  if (comm.Rank() == root) {
-    return BroadcastImpl(comm, &this->sequence_number_, data, root);
-  } else {
-    return BroadcastImpl(comm, &this->sequence_number_, data, root);
-  }
+  return BroadcastImpl(comm, &this->sequence_number_, data, root);
 }
 
-[[nodiscard]] Result FederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data,
-                                              std::int64_t size) {
+[[nodiscard]] Result FederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data) {
   using namespace federated;  // NOLINT
   auto fed = dynamic_cast<FederatedComm const *>(&comm);
   CHECK(fed);
   auto stub = fed->Handle();
 
+  auto size = data.size_bytes() / comm.World();
   auto offset = comm.Rank() * size;
   auto segment = data.subspan(offset, size);
diff --git a/plugin/federated/federated_coll.cu b/plugin/federated/federated_coll.cu
index a922e1c11..3f604c50d 100644
--- a/plugin/federated/federated_coll.cu
+++ b/plugin/federated/federated_coll.cu
@@ -53,8 +53,7 @@ Coll *FederatedColl::MakeCUDAVar() {
   };
 }
 
-[[nodiscard]] Result CUDAFederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data,
-                                                  std::int64_t size) {
+[[nodiscard]] Result CUDAFederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data) {
   auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
   CHECK(cufed);
   std::vector<std::int8_t> h_data(data.size());
@@ -63,7 +62,7 @@ Coll *FederatedColl::MakeCUDAVar() {
         return GetCUDAResult(
             cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
       } << [&] {
-        return p_impl_->Allgather(comm, common::Span{h_data.data(), h_data.size()}, size);
+        return p_impl_->Allgather(comm, common::Span{h_data.data(), h_data.size()});
       } << [&] {
         return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(),
                                              cudaMemcpyHostToDevice, cufed->Stream()));
diff --git a/plugin/federated/federated_coll.cuh b/plugin/federated/federated_coll.cuh
index a1121d88f..6a690a33d 100644
--- a/plugin/federated/federated_coll.cuh
+++ b/plugin/federated/federated_coll.cuh
@@ -1,5 +1,5 @@
 /**
- * Copyright 2023, XGBoost contributors
+ * Copyright 2023-2024, XGBoost contributors
  */
 #include "../../src/collective/comm.h"  // for Comm, Coll
 #include "federated_coll.h"             // for FederatedColl
@@ -16,8 +16,7 @@ class CUDAFederatedColl : public Coll {
                                 ArrayInterfaceHandler::Type type, Op op) override;
   [[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data,
                                  std::int32_t root) override;
-  [[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data,
-                                 std::int64_t size) override;
+  [[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data) override;
   [[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data,
                                   common::Span<std::int64_t const> sizes,
                                   common::Span<std::int64_t> recv_segments,
diff --git a/plugin/federated/federated_coll.h b/plugin/federated/federated_coll.h
index c261b01e1..12443a3e1 100644
--- a/plugin/federated/federated_coll.h
+++ b/plugin/federated/federated_coll.h
@@ -1,12 +1,9 @@
 /**
- * Copyright 2023, XGBoost contributors
+ * Copyright 2023-2024, XGBoost contributors
  */
 #pragma once
 #include "../../src/collective/coll.h"    // for Coll
 #include "../../src/collective/comm.h"    // for Comm
-#include "../../src/common/io.h"          // for ReadAll
-#include "../../src/common/json_utils.h"  // for OptionalArg
-#include "xgboost/json.h"                 // for Json
 
 namespace xgboost::collective {
 class FederatedColl : public Coll {
@@ -20,8 +17,7 @@ class FederatedColl : public Coll {
                                 ArrayInterfaceHandler::Type type, Op op) override;
   [[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data,
                                  std::int32_t root) override;
-  [[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data,
-                                 std::int64_t) override;
+  [[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data) override;
   [[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data,
                                   common::Span<std::int64_t const> sizes,
                                   common::Span<std::int64_t> recv_segments,
diff --git a/plugin/federated/federated_comm.cuh b/plugin/federated/federated_comm.cuh
index 58c52f67e..85cecb3eb 100644
--- a/plugin/federated/federated_comm.cuh
+++ b/plugin/federated/federated_comm.cuh
@@ -1,5 +1,5 @@
 /**
- * Copyright 2023, XGBoost Contributors
+ * Copyright 2023-2024, XGBoost Contributors
  */
 #pragma once
 
@@ -9,7 +9,6 @@
 #include "../../src/common/device_helpers.cuh"  // for CUDAStreamView
 #include "federated_comm.h"                     // for FederatedComm
 #include "xgboost/context.h"                    // for Context
-#include "xgboost/logging.h"
 
 namespace xgboost::collective {
 class CUDAFederatedComm : public FederatedComm {
diff --git a/plugin/federated/federated_comm.h b/plugin/federated/federated_comm.h
index 750d94abd..b39e1878a 100644
--- a/plugin/federated/federated_comm.h
+++ b/plugin/federated/federated_comm.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2023, XGBoost contributors
+ * Copyright 2023-2024, XGBoost contributors
  */
 #pragma once
 
@@ -11,7 +11,6 @@
 #include <string>  // for string
 
 #include "../../src/collective/comm.h"    // for HostComm
-#include "../../src/common/json_utils.h"  // for OptionalArg
 #include "xgboost/json.h"
 
 namespace xgboost::collective {
@@ -51,6 +50,10 @@ class FederatedComm : public HostComm {
                 std::int32_t rank) {
     this->Init(host, port, world, rank, {}, {}, {});
   }
+  [[nodiscard]] Result Shutdown() final {
+    this->ResetState();
+    return Success();
+  }
   ~FederatedComm() override { stub_.reset(); }
 
   [[nodiscard]] std::shared_ptr<Channel> Chan(std::int32_t) const override {
@@ -65,5 +68,13 @@ class FederatedComm : public HostComm {
   [[nodiscard]] federated::Federated::Stub* Handle() const { return stub_.get(); }
 
   [[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
+  /**
+   * @brief Get a string ID for the current process.
+   */
+  [[nodiscard]] Result ProcessorName(std::string* out) const final {
+    auto rank = this->Rank();
+    *out = "rank:" + std::to_string(rank);
+    return Success();
+  };
 };
 }  // namespace xgboost::collective
diff --git a/plugin/federated/federated_server.h b/plugin/federated/federated_server.h
index de760d9d8..4692ad6c2 100644
--- a/plugin/federated/federated_server.h
+++ b/plugin/federated/federated_server.h
@@ -1,22 +1,18 @@
 /**
- * Copyright 2022-2023, XGBoost contributors
+ * Copyright 2022-2024, XGBoost contributors
  */
 #pragma once
 
 #include <federated.grpc.pb.h>
 
 #include <cstdint>  // for int32_t
-#include <future>   // for future
 
 #include "../../src/collective/in_memory_handler.h"
-#include "../../src/collective/tracker.h"  // for Tracker
-#include "xgboost/collective/result.h"     // for Result
 
 namespace xgboost::federated {
 class FederatedService final : public Federated::Service {
  public:
-  explicit FederatedService(std::int32_t world_size)
-      : handler_{static_cast<std::size_t>(world_size)} {}
+  explicit FederatedService(std::int32_t world_size) : handler_{world_size} {}
 
   grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
                          AllgatherReply* reply) override;
diff --git a/plugin/federated/federated_tracker.h b/plugin/federated/federated_tracker.h
index 33592fefe..ac46b6eaa 100644
--- a/plugin/federated/federated_tracker.h
+++ b/plugin/federated/federated_tracker.h
@@ -17,8 +17,7 @@ namespace xgboost::collective {
 namespace federated {
 class FederatedService final : public Federated::Service {
  public:
-  explicit FederatedService(std::int32_t world_size)
-      : handler_{static_cast<std::size_t>(world_size)} {}
+  explicit FederatedService(std::int32_t world_size) : handler_{world_size} {}
 
   grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
                          AllgatherReply* reply) override;
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 598b7f2f5..79d9793e6 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -694,9 +694,9 @@ XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field, void
         common::Span{cast_d_ptr, static_cast<typename common::Span<T>::index_type>(size)},
         {size}, DeviceOrd::CPU());
     CHECK(t.CContiguous());
-    Json interface{linalg::ArrayInterface(t)};
-    CHECK(ArrayInterface<1>{interface}.is_contiguous);
-    str = Json::Dump(interface);
+    Json iface{linalg::ArrayInterface(t)};
+    CHECK(ArrayInterface<1>{iface}.is_contiguous);
+    str = Json::Dump(iface);
     return str;
   };
diff --git a/src/c_api/coll_c_api.cc b/src/c_api/coll_c_api.cc
index 01713dbad..24e94f3de 100644
--- a/src/c_api/coll_c_api.cc
+++ b/src/c_api/coll_c_api.cc
@@ -1,8 +1,7 @@
 /**
- * Copyright 2023, XGBoost Contributors
+ * Copyright 2023-2024, XGBoost Contributors
  */
 #include <chrono>   // for seconds
-#include <cstddef>  // for size_t
 #include <future>   // for future
 #include <memory>   // for unique_ptr
 #include <string>   // for string
@@ -10,6 +9,7 @@
 #include <utility>  // for pair
 
 #include "../collective/tracker.h"  // for RabitTracker
+#include "../common/timer.h"        // for Timer
 #include "c_api_error.h"            // for API_BEGIN
 #include "xgboost/c_api.h"
 #include "xgboost/collective/result.h"  // for Result
@@ -40,17 +40,27 @@ struct CollAPIEntry {
 };
 using CollAPIThreadLocalStore = dmlc::ThreadLocalStore<CollAPIEntry>;
 
-void WaitImpl(TrackerHandleT *ptr) {
-  std::chrono::seconds wait_for{100};
+void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) {
+  constexpr std::int64_t kDft{60};
+  std::chrono::seconds wait_for{timeout.count() != 0 ? std::min(kDft, timeout.count()) : kDft};
+
+  common::Timer timer;
+  timer.Start();
+
   auto fut = ptr->second;
   while (fut.valid()) {
     auto res = fut.wait_for(wait_for);
     CHECK(res != std::future_status::deferred);
+
     if (res == std::future_status::ready) {
       auto const &rc = ptr->second.get();
-      CHECK(rc.OK()) << rc.Report();
+      collective::SafeColl(rc);
       break;
     }
+
+    if (timer.Duration() > timeout && timeout.count() != 0) {
+      collective::SafeColl(collective::Fail("Timeout waiting for the tracker."));
+    }
   }
 }
 }  // namespace
@@ -106,14 +116,17 @@ XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config) {
   auto *ptr = GetTrackerHandle(handle);
   xgboost_CHECK_C_ARG_PTR(config);
   auto jconfig = Json::Load(StringView{config});
-  WaitImpl(ptr);
+  // Internally, 0 indicates no timeout, which is the default since we don't want to
+  // interrupt the model training.
+  auto timeout = OptionalArg<Integer const>(jconfig, "timeout", std::int64_t{0});
+  WaitImpl(ptr, std::chrono::seconds{timeout});
   API_END();
 }
 
 XGB_DLL int XGTrackerFree(TrackerHandle handle) {
   API_BEGIN();
   auto *ptr = GetTrackerHandle(handle);
-  WaitImpl(ptr);
+  WaitImpl(ptr, ptr->first->Timeout());
   delete ptr;
   API_END();
 }
diff --git a/src/collective/allgather.cc b/src/collective/allgather.cc
index 148cb6cd2..446db73b5 100644
--- a/src/collective/allgather.cc
+++ b/src/collective/allgather.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2023, XGBoost Contributors
+ * Copyright 2023-2024, XGBoost Contributors
  */
 #include "allgather.h"
 
 #include <cstddef>  // for size_t
 #include <cstdint>  // for int8_t, int32_t, int64_t
 #include <memory>   // for shared_ptr
+#include <utility>  // for move
 
 #include "broadcast.h"
 #include "comm.h"  // for Comm, Channel
@@ -29,16 +30,20 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
     auto rc = Success() << [&] {
       auto send_rank = (rank + world - r + worker_off) % world;
       auto send_off = send_rank * segment_size;
-      send_off = std::min(send_off, data.size_bytes());
-      auto send_seg = data.subspan(send_off, std::min(segment_size, data.size_bytes() - send_off));
+      bool is_last_segment = send_rank == (world - 1);
+      auto send_nbytes = is_last_segment ? (data.size_bytes() - send_off) : segment_size;
+      auto send_seg = data.subspan(send_off, send_nbytes);
       return next_ch->SendAll(send_seg.data(), send_seg.size_bytes());
     } << [&] {
      auto recv_rank = (rank + world - r - 1 + worker_off) % world;
       auto recv_off = recv_rank * segment_size;
-      recv_off = std::min(recv_off, data.size_bytes());
-      auto recv_seg = data.subspan(recv_off, std::min(segment_size, data.size_bytes() - recv_off));
+      bool is_last_segment = recv_rank == (world - 1);
+      auto recv_nbytes = is_last_segment ? (data.size_bytes() - recv_off) : segment_size;
+      auto recv_seg = data.subspan(recv_off, recv_nbytes);
       return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
-    } << [&] { return prev_ch->Block(); };
+    } << [&] {
+      return prev_ch->Block();
+    };
     if (!rc.OK()) {
       return rc;
     }
@@ -91,7 +96,9 @@ namespace detail {
       auto recv_size = sizes[recv_rank];
       auto recv_seg = erased_result.subspan(recv_off, recv_size);
       return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
-    } << [&] { return prev_ch->Block(); };
+    } << [&] {
+      return prev_ch->Block();
+    };
     if (!rc.OK()) {
       return rc;
     }
diff --git a/src/collective/allgather.h b/src/collective/allgather.h
index 4f13014be..8de9f1984 100644
--- a/src/collective/allgather.h
+++ b/src/collective/allgather.h
@@ -1,25 +1,27 @@
 /**
- * Copyright 2023, XGBoost Contributors
+ * Copyright 2023-2024, XGBoost Contributors
  */
 #pragma once
 #include <cstddef>      // for size_t
 #include <cstdint>      // for int32_t
 #include <memory>       // for shared_ptr
 #include <numeric>      // for accumulate
+#include <string>       // for string
 #include <type_traits>  // for remove_cv_t
 #include <vector>       // for vector
 
-#include "../common/type.h"  // for EraseType
+#include "../common/type.h"             // for EraseType
 #include "comm.h"                       // for Comm, Channel
+#include "comm_group.h"                 // for CommGroup
 #include "xgboost/collective/result.h"  // for Result
-#include "xgboost/linalg.h"
-#include "xgboost/span.h"  // for Span
+#include "xgboost/linalg.h"             // for MakeVec
+#include "xgboost/span.h"               // for Span
 
 namespace xgboost::collective {
 namespace cpu_impl {
 /**
  * @param worker_off Segment offset. For example, if the rank 2 worker specifies
- *                   worker_off = 1, then it owns the third segment.
+ *                   worker_off = 1, then it owns the third segment (2 + 1).
  */
 [[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data,
                                    std::size_t segment_size, std::int32_t worker_off,
@@ -51,8 +53,10 @@ inline void AllgatherVOffset(common::Span<std::int64_t const> sizes,
 }  // namespace detail
 
 template <typename T>
-[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data, std::size_t size) {
-  auto n_bytes = sizeof(T) * size;
+[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data) {
+  // This function is also used for ring allreduce, hence we allow the last segment to be
+  // larger due to round-down.
+  auto n_bytes_per_segment = data.size_bytes() / comm.World();
   auto erased = common::EraseType(data);
 
   auto rank = comm.Rank();
@@ -61,7 +65,7 @@ template <typename T>
   auto prev_ch = comm.Chan(prev);
   auto next_ch = comm.Chan(next);
 
-  auto rc = cpu_impl::RingAllgather(comm, erased, n_bytes, 0, prev_ch, next_ch);
+  auto rc = cpu_impl::RingAllgather(comm, erased, n_bytes_per_segment, 0, prev_ch, next_ch);
   if (!rc.OK()) {
     return rc;
   }
@@ -76,7 +80,7 @@ template <typename T>
   std::vector<std::int64_t> sizes(world, 0);
   sizes[rank] = data.size_bytes();
 
-  auto rc = RingAllgather(comm, common::Span{sizes.data(), sizes.size()}, 1);
+  auto rc = RingAllgather(comm, common::Span{sizes.data(), sizes.size()});
   if (!rc.OK()) {
     return rc;
   }
diff --git a/src/collective/allreduce.cc b/src/collective/allreduce.cc
index 93b76355f..d9cf8b828 100644
--- a/src/collective/allreduce.cc
+++ b/src/collective/allreduce.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2023, XGBoost Contributors
+ * Copyright 2023-2024, XGBoost Contributors
  */
 #include "allreduce.h"
 
@@ -16,7 +16,44 @@
 #include "xgboost/span.h"  // for Span
 
 namespace xgboost::collective::cpu_impl {
+namespace {
 template <typename T, typename Func>
+Result RingAllreduceSmall(Comm const& comm, common::Span<std::int8_t> data, Func const& op) {
+  auto rank = comm.Rank();
+  auto world = comm.World();
+
+  auto next_ch = comm.Chan(BootstrapNext(rank, world));
+  auto prev_ch = comm.Chan(BootstrapPrev(rank, world));
+
+  std::vector<std::int8_t> buffer(data.size_bytes() * world, 0);
+  auto s_buffer = common::Span{buffer.data(), buffer.size()};
+
+  auto offset = data.size_bytes() * rank;
+  auto self = s_buffer.subspan(offset, data.size_bytes());
+  std::copy_n(data.data(), data.size_bytes(), self.data());
+
+  auto typed = common::RestoreType<T>(s_buffer);
+  auto rc = RingAllgather(comm, typed);
+
+  if (!rc.OK()) {
+    return rc;
+  }
+  auto first = s_buffer.subspan(0, data.size_bytes());
+  CHECK_EQ(first.size(), data.size());
+
+  for (std::int32_t r = 1; r < world; ++r) {
+    auto offset = data.size_bytes() * r;
+    auto buf = s_buffer.subspan(offset, data.size_bytes());
+    op(buf, first);
+  }
+  std::copy_n(first.data(), first.size(), data.data());
+
+  return Success();
+}
+}  // namespace
+
+template <typename T, typename Func>
+// note that n_bytes_in_seg is calculated with round-down.
 Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
                               std::size_t n_bytes_in_seg, Func const& op) {
   auto rank = comm.Rank();
@@ -27,14 +64,17 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
   auto next_ch = comm.Chan(dst_rank);
   auto prev_ch = comm.Chan(src_rank);
 
-  std::vector<std::int8_t> buffer(n_bytes_in_seg, 0);
+  std::vector<std::int8_t> buffer(data.size_bytes() - (world - 1) * n_bytes_in_seg, 0);
   auto s_buf = common::Span{buffer.data(), buffer.size()};
 
   for (std::int32_t r = 0; r < world - 1; ++r) {
     // send to ring next
-    auto send_off = ((rank + world - r) % world) * n_bytes_in_seg;
-    send_off = std::min(send_off, data.size_bytes());
-    auto seg_nbytes = std::min(data.size_bytes() - send_off, n_bytes_in_seg);
+    auto send_rank = (rank + world - r) % world;
+    auto send_off = send_rank * n_bytes_in_seg;
+
+    bool is_last_segment = send_rank == (world - 1);
+
+    auto seg_nbytes = is_last_segment ? data.size_bytes() - send_off : n_bytes_in_seg;
     auto send_seg = data.subspan(send_off, seg_nbytes);
 
     auto rc = next_ch->SendAll(send_seg);
@@ -43,14 +83,21 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
     }
 
     // receive from ring prev
-    auto recv_off = ((rank + world - r - 1) % world) * n_bytes_in_seg;
-    recv_off = std::min(recv_off, data.size_bytes());
-    seg_nbytes = std::min(data.size_bytes() - recv_off, n_bytes_in_seg);
+    auto recv_rank = (rank + world - r - 1) % world;
+    auto recv_off = recv_rank * n_bytes_in_seg;
+
+    is_last_segment = recv_rank == (world - 1);
+
+    seg_nbytes = is_last_segment ? data.size_bytes() - recv_off : n_bytes_in_seg;
     CHECK_EQ(seg_nbytes % sizeof(T), 0);
     auto recv_seg = data.subspan(recv_off, seg_nbytes);
     auto seg = s_buf.subspan(0, recv_seg.size());
-    rc = std::move(rc) << [&] { return prev_ch->RecvAll(seg); } << [&] { return comm.Block(); };
+    rc = std::move(rc) << [&] {
+      return prev_ch->RecvAll(seg);
+    } << [&] {
+      return comm.Block();
+    };
     if (!rc.OK()) {
       return rc;
     }
@@ -68,6 +115,9 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func cons
   if (comm.World() == 1) {
     return Success();
   }
+  if (data.size_bytes() == 0) {
+    return Success();
+  }
   return DispatchDType(type, [&](auto t) {
     using T = decltype(t);
     // Divide the data into segments according to the number of workers.
@@ -75,7 +125,11 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func cons
     CHECK_EQ(data.size_bytes() % n_bytes_elem, 0);
     auto n = data.size_bytes() / n_bytes_elem;
     auto world = comm.World();
-    auto n_bytes_in_seg = common::DivRoundUp(n, world) * sizeof(T);
+    if (n < static_cast<std::size_t>(world)) {
+      return RingAllreduceSmall<T>(comm, data, op);
+    }
+
+    auto n_bytes_in_seg = (n / world) * sizeof(T);
     auto rc = RingScatterReduceTyped<T>(comm, data, n_bytes_in_seg, op);
     if (!rc.OK()) {
       return rc;
     }
@@ -88,7 +142,9 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func cons
 
     return std::move(rc) << [&] {
       return RingAllgather(comm, data, n_bytes_in_seg, 1, prev_ch, next_ch);
-    } << [&] { return comm.Block(); };
+    } << [&] {
+      return comm.Block();
+    };
   });
 }
 }  // namespace xgboost::collective::cpu_impl
diff --git a/src/collective/coll.cc b/src/collective/coll.cc
index 1f47d0c55..c6d03c6df 100644
--- a/src/collective/coll.cc
+++ b/src/collective/coll.cc
@@ -104,9 +104,8 @@ bool constexpr IsFloatingPointV() {
   return cpu_impl::Broadcast(comm, data, root);
 }
 
-[[nodiscard]] Result Coll::Allgather(Comm const& comm, common::Span<std::int8_t> data,
-                                     std::int64_t size) {
-  return RingAllgather(comm, data, size);
+[[nodiscard]] Result Coll::Allgather(Comm const& comm, common::Span<std::int8_t> data) {
+  return RingAllgather(comm, data);
 }
 
 [[nodiscard]] Result Coll::AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
diff --git a/src/collective/coll.cu b/src/collective/coll.cu
index d1b66a8ce..b06435bfe 100644
--- a/src/collective/coll.cu
+++ b/src/collective/coll.cu
@@ -1,10 +1,9 @@
 /**
- * Copyright 2023, XGBoost Contributors
+ * Copyright 2023-2024, XGBoost Contributors
  */
 #if defined(XGBOOST_USE_NCCL)
 #include <cstdint>  // for int8_t, int64_t
 
-#include "../common/cuda_context.cuh"
 #include "../common/device_helpers.cuh"
 #include "../data/array_interface.h"
 #include "allgather.h"  // for AllgatherVOffset
@@ -162,14 +161,14 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {
       } << [&] { return nccl->Block(); };
 }
 
-[[nodiscard]] Result NCCLColl::Allgather(Comm const& comm, common::Span<std::int8_t> data,
-                                         std::int64_t size) {
+[[nodiscard]] Result NCCLColl::Allgather(Comm const& comm, common::Span<std::int8_t> data) {
   if (!comm.IsDistributed()) {
     return Success();
   }
   auto nccl = dynamic_cast<NCCLComm const*>(&comm);
   CHECK(nccl);
   auto stub = nccl->Stub();
 
+  auto size = data.size_bytes() / comm.World();
   auto send = data.subspan(comm.Rank() * size, size);
   return Success() << [&] {
diff --git a/src/collective/coll.cuh b/src/collective/coll.cuh
index 6ededd101..4d45295d7 100644
--- a/src/collective/coll.cuh
+++ b/src/collective/coll.cuh
@@ -1,5 +1,5 @@
 /**
- * Copyright 2023, XGBoost Contributors
+ * Copyright 2023-2024, XGBoost Contributors
  */
 #pragma once
 
@@ -8,8 +8,7 @@
 #include "../data/array_interface.h"  // for ArrayInterfaceHandler
 #include "coll.h"                     // for Coll
 #include "comm.h"                     // for Comm
-#include "nccl_stub.h"
-#include "xgboost/span.h"  // for Span
+#include "xgboost/span.h"             // for Span
 
 namespace xgboost::collective {
 class NCCLColl : public Coll {
@@ -20,8 +19,7 @@ class NCCLColl : public Coll {
                                 ArrayInterfaceHandler::Type type, Op op) override;
   [[nodiscard]] Result Broadcast(Comm const& comm, common::Span<std::int8_t> data,
                                  std::int32_t root) override;
-  [[nodiscard]] Result Allgather(Comm const& comm, common::Span<std::int8_t> data,
-                                 std::int64_t size) override;
+  [[nodiscard]] Result Allgather(Comm const& comm, common::Span<std::int8_t> data) override;
   [[nodiscard]] Result AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
                                   common::Span<std::int64_t const> sizes,
                                   common::Span<std::int64_t> recv_segments,
diff --git a/src/collective/coll.h b/src/collective/coll.h
index 1afc8ed59..96fe35229 100644
--- a/src/collective/coll.h
+++ b/src/collective/coll.h
@@ -48,10 +48,8 @@ class Coll : public std::enable_shared_from_this<Coll> {
    * @brief Allgather
    *
    * @param [in,out] data Data buffer for input and output.
-   * @param [in] size Size of data for each worker.
    */
-  [[nodiscard]] virtual Result Allgather(Comm const& comm, common::Span<std::int8_t> data,
-                                         std::int64_t size);
+  [[nodiscard]] virtual Result Allgather(Comm const& comm, common::Span<std::int8_t> data);
   /**
    * @brief Allgather with variable length.
    *
diff --git a/src/collective/comm.cc b/src/collective/comm.cc
index 8260b28f6..23a8e89ed 100644
--- a/src/collective/comm.cc
+++ b/src/collective/comm.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2023, XGBoost Contributors
+ * Copyright 2023-2024, XGBoost Contributors
  */
 #include "comm.h"
 
@@ -9,8 +9,9 @@
 #include <memory>   // for shared_ptr
 #include <string>   // for string
 #include <utility>  // for move, forward
-
-#include "../common/common.h"  // for AssertGPUSupport
+#if !defined(XGBOOST_USE_NCCL)
+#include "../common/common.h"  // for AssertNCCLSupport
+#endif  // !defined(XGBOOST_USE_NCCL)
 #include "allgather.h"     // for RingAllgather
 #include "protocol.h"      // for kMagic
 #include "xgboost/base.h"  // for XGBOOST_STRICT_R_MODE
@@ -21,11 +22,7 @@ namespace xgboost::collective {
 Comm::Comm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
            std::int32_t retry, std::string task_id)
-    : timeout_{timeout},
-      retry_{retry},
-      tracker_{host, port, -1},
-      task_id_{std::move(task_id)},
-      loop_{std::shared_ptr<Loop>{new Loop{timeout}}} {}
+    : timeout_{timeout}, retry_{retry}, tracker_{host, port, -1}, task_id_{std::move(task_id)} {}
 
 Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, std::int32_t retry,
                           std::string const& task_id, TCPSocket* out, std::int32_t rank,
@@ -191,6 +188,7 @@ RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::se
                      std::int32_t retry, std::string task_id, StringView nccl_path)
     : HostComm{std::move(host), port, timeout, retry, std::move(task_id)},
       nccl_path_{std::move(nccl_path)} {
+  loop_.reset(new Loop{std::chrono::seconds{timeout_}});  // NOLINT
   auto rc = this->Bootstrap(timeout_, retry_, task_id_);
   if (!rc.OK()) {
     SafeColl(Fail("Failed to bootstrap the communication group.", std::move(rc)));
@@ -254,9 +252,6 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
   // get ring neighbors
   std::string snext;
   tracker.Recv(&snext);
-  if (!rc.OK()) {
-    return Fail("Failed to receive the rank for the next worker.", std::move(rc));
-  }
   auto jnext = Json::Load(StringView{snext});
 
   proto::PeerInfo ninfo{jnext};
@@ -295,6 +290,10 @@ RabitComm::~RabitComm() noexcept(false) {
 }
 
 [[nodiscard]] Result RabitComm::Shutdown() {
+  if (!this->IsDistributed()) {
+    return Success();
+  }
+
   TCPSocket tracker;
   return Success() << [&] {
     return ConnectTrackerImpl(tracker_, timeout_, retry_, task_id_, &tracker, Rank(), World());
@@ -308,6 +307,11 @@ RabitComm::~RabitComm() noexcept(false) {
     if (n_bytes != scmd.size()) {
      return Fail("Faled to send cmd.");
     }
+
+    this->ResetState();
+    return Success();
+  } << [&] {
+    this->channels_.clear();
     return Success();
   };
 }
diff --git a/src/collective/comm.cu b/src/collective/comm.cu
index 56681253c..8788a2436 100644
--- a/src/collective/comm.cu
+++ b/src/collective/comm.cu
@@ -80,7 +80,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
   auto s_this_uuid = s_uuid.subspan(root.Rank() * kUuidLength, kUuidLength);
   GetCudaUUID(s_this_uuid, ctx->Device());
 
-  auto rc = pimpl->Allgather(root, common::EraseType(s_uuid), s_this_uuid.size_bytes());
+  auto rc = pimpl->Allgather(root, common::EraseType(s_uuid));
 
   CHECK(rc.OK()) << rc.Report();
diff --git a/src/collective/comm.cuh b/src/collective/comm.cuh
index a818d95f8..4add9ca61 100644
--- a/src/collective/comm.cuh
+++ b/src/collective/comm.cuh
@@ -50,6 +50,10 @@ class NCCLComm : public Comm {
     auto rc = this->Stream().Sync(false);
     return GetCUDAResult(rc);
   }
+  [[nodiscard]] Result Shutdown() final {
+    this->ResetState();
+    return Success();
+  }
 };
 
 class NCCLChannel : public Channel {
diff --git a/src/collective/comm.h b/src/collective/comm.h
index 82aa2c45e..6ad5bc5c1 100644
--- a/src/collective/comm.h
+++ b/src/collective/comm.h
@@ -14,7 +14,7 @@
 #include "loop.h"                       // for Loop
 #include "protocol.h"                   // for PeerInfo
 #include "xgboost/collective/result.h"  // for Result
-#include "xgboost/collective/socket.h"  // for TCPSocket
+#include "xgboost/collective/socket.h"  // for TCPSocket, GetHostName
 #include "xgboost/context.h"            // for Context
 #include "xgboost/span.h"               // for Span
 
@@ -54,8 +54,12 @@ class Comm : public std::enable_shared_from_this<Comm> {
   std::thread error_worker_;
   std::string task_id_;
   std::vector<std::shared_ptr<Channel>> channels_;
-  std::shared_ptr<Loop> loop_{new Loop{std::chrono::seconds{
-      DefaultTimeoutSec()}}};  // fixme: require federated comm to have a timeout
+  std::shared_ptr<Loop> loop_{nullptr};  // fixme: require federated comm to have a timeout
+
+  void ResetState() {
+    this->world_ = -1;
+    this->rank_ = 0;
+  }
 
  public:
   Comm() = default;
@@ -78,7 +82,10 @@ class Comm : public std::enable_shared_from_this<Comm> {
   [[nodiscard]] auto Rank() const { return rank_; }
   [[nodiscard]] auto World() const { return IsDistributed() ? world_ : 1; }
   [[nodiscard]] bool IsDistributed() const { return world_ != -1; }
-  void Submit(Loop::Op op) const { loop_->Submit(op); }
+  void Submit(Loop::Op op) const {
+    CHECK(loop_);
+    loop_->Submit(op);
+  }
   [[nodiscard]] virtual Result Block() const { return loop_->Block(); }
 
   [[nodiscard]] virtual std::shared_ptr<Channel> Chan(std::int32_t rank) const {
@@ -88,6 +95,14 @@ class Comm : public std::enable_shared_from_this<Comm> {
 
   [[nodiscard]] virtual Result LogTracker(std::string msg) const = 0;
   [[nodiscard]] virtual Result SignalError(Result const&) { return Success(); }
+  /**
+   * @brief Get a string ID for the current process.
+   */
+  [[nodiscard]] virtual Result ProcessorName(std::string* out) const {
+    auto rc = GetHostName(out);
+    return rc;
+  }
+  [[nodiscard]] virtual Result Shutdown() = 0;
 };
 
 /**
@@ -105,7 +120,7 @@ class RabitComm : public HostComm {
   [[nodiscard]] Result Bootstrap(std::chrono::seconds timeout, std::int32_t retry,
                                  std::string task_id);
-  [[nodiscard]] Result Shutdown();
+  [[nodiscard]] Result Shutdown() final;
 
  public:
   // bootstrapping construction.
diff --git a/src/collective/comm_group.cc b/src/collective/comm_group.cc
index f7bbba754..7408882f6 100644
--- a/src/collective/comm_group.cc
+++ b/src/collective/comm_group.cc
@@ -1,22 +1,21 @@
 /**
- * Copyright 2023, XGBoost Contributors
+ * Copyright 2023-2024, XGBoost Contributors
  */
 #include "comm_group.h"
 
 #include <algorithm>  // for transform
+#include <cctype>     // for tolower
 #include <chrono>     // for seconds
 #include <cstdint>    // for int32_t
+#include <iterator>   // for back_inserter
 #include <memory>     // for shared_ptr, unique_ptr
 #include <string>     // for string
-#include <vector>     // for vector
 
-#include "../common/json_utils.h"       // for OptionalArg
-#include "coll.h"                       // for Coll
-#include "comm.h"                       // for Comm
-#include "tracker.h"                    // for GetHostAddress
-#include "xgboost/collective/result.h"  // for Result
-#include "xgboost/context.h"            // for DeviceOrd
-#include "xgboost/json.h"               // for Json
+#include "../common/json_utils.h"  // for OptionalArg
+#include "coll.h"                  // for Coll
+#include "comm.h"                  // for Comm
+#include "xgboost/context.h"       // for DeviceOrd
+#include "xgboost/json.h"          // for Json
 
 #if defined(XGBOOST_USE_FEDERATED)
 #include "../../plugin/federated/federated_coll.h"
@@ -117,6 +116,8 @@ void GlobalCommGroupInit(Json config) {
 
 void GlobalCommGroupFinalize() {
   auto& sptr = GlobalCommGroup();
+  auto rc = sptr->Finalize();
   sptr.reset();
+  SafeColl(rc);
 }
 }  // namespace xgboost::collective
diff --git a/src/collective/comm_group.h b/src/collective/comm_group.h
index 2f6f91d73..61a58ba56 100644
--- a/src/collective/comm_group.h
+++ b/src/collective/comm_group.h
@@ -9,7 +9,6 @@
 #include "coll.h"                       // for Comm
 #include "comm.h"                       // for Coll
 #include "xgboost/collective/result.h"  // for Result
-#include "xgboost/collective/socket.h"  // for GetHostName
 
 namespace xgboost::collective {
 /**
@@ -35,15 +34,31 @@ class CommGroup {
   [[nodiscard]] auto Rank() const { return comm_->Rank(); }
   [[nodiscard]] bool IsDistributed() const { return comm_->IsDistributed(); }
 
+  [[nodiscard]] Result Finalize() const {
+    return Success() << [this] {
+      if (gpu_comm_) {
+        return gpu_comm_->Shutdown();
+      }
+      return Success();
+    } << [&] {
+      return comm_->Shutdown();
+    };
+  }
+
   [[nodiscard]] static CommGroup* Create(Json config);
 
   [[nodiscard]] std::shared_ptr<Coll> Backend(DeviceOrd device) const;
+  /**
+   * @brief Decide the context to use for communication.
+   *
+   * @param ctx Global context, provides the CUDA stream and ordinal.
+   * @param device The device used by the data to be communicated.
+   */
   [[nodiscard]] Comm const& Ctx(Context const* ctx, DeviceOrd device) const;
 
   [[nodiscard]] Result SignalError(Result const& res) { return comm_->SignalError(res); }
   [[nodiscard]] Result ProcessorName(std::string* out) const {
-    auto rc = GetHostName(out);
-    return rc;
+    return this->comm_->ProcessorName(out);
   }
 };
diff --git a/src/collective/in_memory_handler.h b/src/collective/in_memory_handler.h
index f9ac52007..e9c69f537 100644
--- a/src/collective/in_memory_handler.h
+++ b/src/collective/in_memory_handler.h
@@ -32,7 +32,8 @@ class InMemoryHandler {
    *
    * This is used when the handler only needs to be initialized once with a known world size.
    */
-  explicit InMemoryHandler(std::size_t worldSize) : world_size_{worldSize} {}
+  explicit InMemoryHandler(std::int32_t worldSize)
+      : world_size_{static_cast<std::size_t>(worldSize)} {}
 
   /**
    * @brief Initialize the handler with the world size and rank.
diff --git a/tests/cpp/collective/test_allgather.cc b/tests/cpp/collective/test_allgather.cc
index decad8786..b6158693b 100644
--- a/tests/cpp/collective/test_allgather.cc
+++ b/tests/cpp/collective/test_allgather.cc
@@ -34,7 +34,7 @@ class Worker : public WorkerForTest {
     std::vector<std::int32_t> data(comm_.World(), 0);
     data[comm_.Rank()] = comm_.Rank();
 
-    auto rc = RingAllgather(this->comm_, common::Span{data.data(), data.size()}, 1);
+    auto rc = RingAllgather(this->comm_, common::Span{data.data(), data.size()});
     ASSERT_TRUE(rc.OK()) << rc.Report();
 
     for (std::int32_t r = 0; r < comm_.World(); ++r) {
@@ -51,7 +51,7 @@ class Worker : public WorkerForTest {
     auto seg = s_data.subspan(comm_.Rank() * n, n);
     std::iota(seg.begin(), seg.end(), comm_.Rank());
 
-    auto rc = RingAllgather(comm_, common::Span{data.data(), data.size()}, n);
+    auto rc = RingAllgather(comm_, common::Span{data.data(), data.size()});
     ASSERT_TRUE(rc.OK()) << rc.Report();
 
     for (std::int32_t r = 0; r < comm_.World(); ++r) {
@@ -104,7 +104,7 @@ class Worker : public WorkerForTest {
     std::vector<std::int64_t> sizes(comm_.World(), 0);
     sizes[comm_.Rank()] = s_data.size_bytes();
 
-    auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1);
+    auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()});
     ASSERT_TRUE(rc.OK()) << rc.Report();
 
     std::shared_ptr<Coll> pcoll{new Coll{}};
diff --git a/tests/cpp/collective/test_allgather.cu b/tests/cpp/collective/test_allgather.cu
index 236108198..98ece7d17 100644
--- a/tests/cpp/collective/test_allgather.cu
+++ b/tests/cpp/collective/test_allgather.cu
@@ -1,5 +1,5 @@
 /**
- * Copyright 2023, XGBoost Contributors
+ * Copyright 2023-2024, XGBoost Contributors
  */
 #if defined(XGBOOST_USE_NCCL)
 #include <gtest/gtest.h>
 
@@ -33,7 +33,7 @@ class Worker : public NCCLWorkerForTest {
     // get size
     std::vector<std::int64_t> sizes(comm_.World(), -1);
     sizes[comm_.Rank()] = s_data.size_bytes();
-    auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1);
+    auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()});
     ASSERT_TRUE(rc.OK()) << rc.Report();
     // create result
     dh::device_vector<std::int64_t> result(comm_.World(), -1);
@@ -57,7 +57,7 @@ class Worker : public NCCLWorkerForTest {
     // get size
     std::vector<std::int64_t> sizes(nccl_comm_->World(), 0);
     sizes[comm_.Rank()] = dh::ToSpan(data).size_bytes();
-    auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1);
+    auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()});
     ASSERT_TRUE(rc.OK()) << rc.Report();
     auto n_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0);
     // create result
diff --git a/tests/cpp/collective/test_allreduce.cc b/tests/cpp/collective/test_allreduce.cc
index 8359d17a6..457594cd9 100644
--- a/tests/cpp/collective/test_allreduce.cc
+++ b/tests/cpp/collective/test_allreduce.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2023, XGBoost Contributors
+ * Copyright 2023-2024, XGBoost Contributors
  */
 #include <gtest/gtest.h>
 
@@ -7,7 +7,6 @@
 #include "../../../src/collective/allreduce.h"
 #include "../../../src/collective/coll.h"  // for Coll
-#include "../../../src/collective/tracker.h"
 #include "../../../src/common/type.h"  // for EraseType
 #include "test_worker.h"               // for WorkerForTest, TestDistributed
diff --git a/tests/cpp/collective/test_allreduce.cu b/tests/cpp/collective/test_allreduce.cu
index 04ec9f773..f7e11dec2 100644
--- a/tests/cpp/collective/test_allreduce.cu
+++ b/tests/cpp/collective/test_allreduce.cu
@@ -5,7 +5,7 @@
 #include <gtest/gtest.h>
 #include <thrust/host_vector.h>  // for host_vector
 
-#include "../../../src/common/common.h"
+#include "../../../src/common/common.h"             // for AllVisibleGPUs
 #include "../../../src/common/device_helpers.cuh"   // for ToSpan, device_vector
 #include "../../../src/common/type.h"               // for EraseType
 #include "test_worker.cuh"                          // for NCCLWorkerForTest
diff --git a/tests/cpp/plugin/federated/test_federated_coll.cc b/tests/cpp/plugin/federated/test_federated_coll.cc
index ad053f286..6b7000ef9 100644
--- a/tests/cpp/plugin/federated/test_federated_coll.cc
+++ b/tests/cpp/plugin/federated/test_federated_coll.cc
@@ -60,8 +60,7 @@ TEST_F(FederatedCollTest, Allgather) {
     std::vector<std::int32_t> buffer(n_workers, 0);
     buffer[comm->Rank()] = comm->Rank();
-    auto rc = coll.Allgather(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}),
-                             sizeof(int));
+    auto rc = coll.Allgather(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}));
     ASSERT_TRUE(rc.OK());
     for (auto i = 0; i < n_workers; i++) {
       ASSERT_EQ(buffer[i], i);
diff --git a/tests/cpp/plugin/federated/test_federated_coll.cu b/tests/cpp/plugin/federated/test_federated_coll.cu
index a6ec7e352..237bdeb9d 100644
--- a/tests/cpp/plugin/federated/test_federated_coll.cu
+++ b/tests/cpp/plugin/federated/test_federated_coll.cu
@@ -5,13 +5,13 @@
 #include <gtest/gtest.h>
 #include <xgboost/collective/result.h>  // for Result
 
+#include "../../../../src/collective/allreduce.h"
 #include "../../../../src/common/common.h"            // for AllVisibleGPUs
 #include "../../../../src/common/device_helpers.cuh"  // for device_vector
 #include "../../../../src/common/type.h"              // for EraseType
 #include "../../collective/test_worker.h"             // for SocketTest
 #include "../../helpers.h"                             // for MakeCUDACtx
 #include "federated_coll.cuh"
-#include "federated_comm.cuh"
 #include "test_worker.h"  // for TestFederated
 
 namespace xgboost::collective {
@@ -71,7 +71,7 @@ void TestAllgather(std::shared_ptr<FederatedComm> comm, std::int32_t rank, std::
   dh::device_vector<std::int32_t> buffer(n_workers, 0);
   buffer[comm->Rank()] = comm->Rank();
-  auto rc = w.coll->Allgather(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer)), sizeof(int));
+  auto rc = w.coll->Allgather(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer)));
   ASSERT_TRUE(rc.OK());
   for (auto i = 0; i < n_workers; i++) {
     ASSERT_EQ(buffer[i], i);
diff --git a/tests/cpp/plugin/test_federated_adapter.cu b/tests/cpp/plugin/test_federated_adapter.cu
index cec180e70..b96524878 100644
--- a/tests/cpp/plugin/test_federated_adapter.cu
+++ b/tests/cpp/plugin/test_federated_adapter.cu
@@ -26,7 +26,6 @@ TEST(FederatedAdapterSimpleTest, ThrowOnInvalidDeviceOrdinal) {
 namespace {
 void VerifyAllReduceSum() {
   auto const world_size = collective::GetWorldSize();
-  auto const rank = collective::GetRank();
   auto const device = GPUIDX;
   int count = 3;
   common::SetDevice(device);
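
The thread running through every Allgather change above: the explicit per-worker size argument is gone, and each implementation now derives the segment as data.size_bytes() / comm.World(), with the last rank owning whatever the round-down leaves over. The standalone sketch below (illustrative only, not XGBoost code; all names are invented for the example) models that segment arithmetic, including the scratch-buffer size that the patched RingScatterReduceTyped allocates for the largest segment:

#include <cstddef>
#include <cstdio>

int main() {
  std::size_t n_bytes = 10;  // total buffer size in bytes
  int world = 4;             // number of workers

  // Round-down segment size; every rank except the last owns exactly this.
  std::size_t seg = n_bytes / world;
  for (int rank = 0; rank < world; ++rank) {
    std::size_t off = rank * seg;
    bool is_last = rank == world - 1;
    // The last rank absorbs the remainder, mirroring
    // `data.size_bytes() - send_off` in the patch.
    std::size_t len = is_last ? n_bytes - off : seg;
    std::printf("rank %d: offset %zu, %zu bytes\n", rank, off, len);
  }

  // The receive scratch buffer must hold the largest (i.e. last) segment,
  // mirroring `data.size_bytes() - (world - 1) * n_bytes_in_seg`.
  std::size_t scratch = n_bytes - (world - 1) * seg;
  std::printf("scratch: %zu bytes\n", scratch);
  return 0;
}

With n_bytes = 10 and world = 4, ranks 0 through 2 own 2 bytes each at offsets 0, 2, and 4, while rank 3 owns the 4-byte tail at offset 6, so the scratch buffer is 4 bytes.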
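RingAllreduceSmall covers the degenerate case the round-down would otherwise break: with fewer elements than workers, n / world is zero, so instead of scatter-reduce the patch allgathers every rank's full payload and folds the world copies locally. A minimal standalone model of that fold (illustrative only; it uses elementwise sum for the op and plain typed vectors in place of the gathered byte buffer):

#include <algorithm>  // for copy_n
#include <cstddef>    // for size_t
#include <cstdint>    // for int32_t
#include <cstdio>
#include <vector>

int main() {
  int world = 4;
  std::vector<std::int32_t> data{1, 2};  // fewer elements than workers

  // Stand-in for the allgather result: `world` back-to-back copies, as if
  // every rank had contributed {1, 2}.
  std::vector<std::int32_t> buffer;
  for (int r = 0; r < world; ++r) {
    buffer.insert(buffer.end(), data.begin(), data.end());
  }

  // Fold segments 1..world-1 into segment 0, mirroring the `op(buf, first)`
  // loop in RingAllreduceSmall.
  for (int r = 1; r < world; ++r) {
    for (std::size_t i = 0; i < data.size(); ++i) {
      buffer[i] += buffer[r * data.size() + i];
    }
  }
  std::copy_n(buffer.begin(), data.size(), data.begin());
  std::printf("%d %d\n", data[0], data[1]);  // prints: 4 8
  return 0;
}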