Update collective implementation. (#10152)
* Update collective implementation. - Clean up resources during `Finalize` to avoid handling threads in the destructor. - Calculate the size for allgather automatically. - Use simple allgather for small (smaller than the number of workers) allreduce.
This commit is contained in:
parent
230010d9a0
commit
8bad677c2f
@ -89,19 +89,15 @@ Coll *FederatedColl::MakeCUDAVar() {
|
||||
|
||||
[[nodiscard]] Result FederatedColl::Broadcast(Comm const &comm, common::Span<std::int8_t> data,
|
||||
std::int32_t root) {
|
||||
if (comm.Rank() == root) {
|
||||
return BroadcastImpl(comm, &this->sequence_number_, data, root);
|
||||
} else {
|
||||
return BroadcastImpl(comm, &this->sequence_number_, data, root);
|
||||
}
|
||||
return BroadcastImpl(comm, &this->sequence_number_, data, root);
|
||||
}
|
||||
|
||||
[[nodiscard]] Result FederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data,
|
||||
std::int64_t size) {
|
||||
[[nodiscard]] Result FederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data) {
|
||||
using namespace federated; // NOLINT
|
||||
auto fed = dynamic_cast<FederatedComm const *>(&comm);
|
||||
CHECK(fed);
|
||||
auto stub = fed->Handle();
|
||||
auto size = data.size_bytes() / comm.World();
|
||||
|
||||
auto offset = comm.Rank() * size;
|
||||
auto segment = data.subspan(offset, size);
|
||||
|
||||
@ -53,8 +53,7 @@ Coll *FederatedColl::MakeCUDAVar() {
|
||||
};
|
||||
}
|
||||
|
||||
[[nodiscard]] Result CUDAFederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data,
|
||||
std::int64_t size) {
|
||||
[[nodiscard]] Result CUDAFederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data) {
|
||||
auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
|
||||
CHECK(cufed);
|
||||
std::vector<std::int8_t> h_data(data.size());
|
||||
@ -63,7 +62,7 @@ Coll *FederatedColl::MakeCUDAVar() {
|
||||
return GetCUDAResult(
|
||||
cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
|
||||
} << [&] {
|
||||
return p_impl_->Allgather(comm, common::Span{h_data.data(), h_data.size()}, size);
|
||||
return p_impl_->Allgather(comm, common::Span{h_data.data(), h_data.size()});
|
||||
} << [&] {
|
||||
return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(),
|
||||
cudaMemcpyHostToDevice, cufed->Stream()));
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost contributors
|
||||
* Copyright 2023-2024, XGBoost contributors
|
||||
*/
|
||||
#include "../../src/collective/comm.h" // for Comm, Coll
|
||||
#include "federated_coll.h" // for FederatedColl
|
||||
@ -16,8 +16,7 @@ class CUDAFederatedColl : public Coll {
|
||||
ArrayInterfaceHandler::Type type, Op op) override;
|
||||
[[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data,
|
||||
std::int32_t root) override;
|
||||
[[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data,
|
||||
std::int64_t size) override;
|
||||
[[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data) override;
|
||||
[[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data,
|
||||
common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int64_t> recv_segments,
|
||||
|
||||
@ -1,12 +1,9 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost contributors
|
||||
* Copyright 2023-2024, XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include "../../src/collective/coll.h" // for Coll
|
||||
#include "../../src/collective/comm.h" // for Comm
|
||||
#include "../../src/common/io.h" // for ReadAll
|
||||
#include "../../src/common/json_utils.h" // for OptionalArg
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost::collective {
|
||||
class FederatedColl : public Coll {
|
||||
@ -20,8 +17,7 @@ class FederatedColl : public Coll {
|
||||
ArrayInterfaceHandler::Type type, Op op) override;
|
||||
[[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data,
|
||||
std::int32_t root) override;
|
||||
[[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data,
|
||||
std::int64_t) override;
|
||||
[[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data) override;
|
||||
[[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data,
|
||||
common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int64_t> recv_segments,
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
@ -9,7 +9,6 @@
|
||||
#include "../../src/common/device_helpers.cuh" // for CUDAStreamView
|
||||
#include "federated_comm.h" // for FederatedComm
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/logging.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
class CUDAFederatedComm : public FederatedComm {
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost contributors
|
||||
* Copyright 2023-2024, XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
@ -11,7 +11,6 @@
|
||||
#include <string> // for string
|
||||
|
||||
#include "../../src/collective/comm.h" // for HostComm
|
||||
#include "../../src/common/json_utils.h" // for OptionalArg
|
||||
#include "xgboost/json.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
@ -51,6 +50,10 @@ class FederatedComm : public HostComm {
|
||||
std::int32_t rank) {
|
||||
this->Init(host, port, world, rank, {}, {}, {});
|
||||
}
|
||||
[[nodiscard]] Result Shutdown() final {
|
||||
this->ResetState();
|
||||
return Success();
|
||||
}
|
||||
~FederatedComm() override { stub_.reset(); }
|
||||
|
||||
[[nodiscard]] std::shared_ptr<Channel> Chan(std::int32_t) const override {
|
||||
@ -65,5 +68,13 @@ class FederatedComm : public HostComm {
|
||||
[[nodiscard]] federated::Federated::Stub* Handle() const { return stub_.get(); }
|
||||
|
||||
[[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
|
||||
/**
|
||||
* @brief Get a string ID for the current process.
|
||||
*/
|
||||
[[nodiscard]] Result ProcessorName(std::string* out) const final {
|
||||
auto rank = this->Rank();
|
||||
*out = "rank:" + std::to_string(rank);
|
||||
return Success();
|
||||
};
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@ -1,22 +1,18 @@
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
* Copyright 2022-2024, XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <federated.old.grpc.pb.h>
|
||||
|
||||
#include <cstdint> // for int32_t
|
||||
#include <future> // for future
|
||||
|
||||
#include "../../src/collective/in_memory_handler.h"
|
||||
#include "../../src/collective/tracker.h" // for Tracker
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
|
||||
namespace xgboost::federated {
|
||||
class FederatedService final : public Federated::Service {
|
||||
public:
|
||||
explicit FederatedService(std::int32_t world_size)
|
||||
: handler_{static_cast<std::size_t>(world_size)} {}
|
||||
explicit FederatedService(std::int32_t world_size) : handler_{world_size} {}
|
||||
|
||||
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
|
||||
AllgatherReply* reply) override;
|
||||
|
||||
@ -17,8 +17,7 @@ namespace xgboost::collective {
|
||||
namespace federated {
|
||||
class FederatedService final : public Federated::Service {
|
||||
public:
|
||||
explicit FederatedService(std::int32_t world_size)
|
||||
: handler_{static_cast<std::size_t>(world_size)} {}
|
||||
explicit FederatedService(std::int32_t world_size) : handler_{world_size} {}
|
||||
|
||||
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
|
||||
AllgatherReply* reply) override;
|
||||
|
||||
@ -694,9 +694,9 @@ XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field, void
|
||||
common::Span<T>{cast_d_ptr, static_cast<typename common::Span<T>::index_type>(size)},
|
||||
{size}, DeviceOrd::CPU());
|
||||
CHECK(t.CContiguous());
|
||||
Json interface{linalg::ArrayInterface(t)};
|
||||
CHECK(ArrayInterface<1>{interface}.is_contiguous);
|
||||
str = Json::Dump(interface);
|
||||
Json iface{linalg::ArrayInterface(t)};
|
||||
CHECK(ArrayInterface<1>{iface}.is_contiguous);
|
||||
str = Json::Dump(iface);
|
||||
return str;
|
||||
};
|
||||
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#include <chrono> // for seconds
|
||||
#include <cstddef> // for size_t
|
||||
#include <future> // for future
|
||||
#include <memory> // for unique_ptr
|
||||
#include <string> // for string
|
||||
@ -10,6 +9,7 @@
|
||||
#include <utility> // for pair
|
||||
|
||||
#include "../collective/tracker.h" // for RabitTracker
|
||||
#include "../common/timer.h" // for Timer
|
||||
#include "c_api_error.h" // for API_BEGIN
|
||||
#include "xgboost/c_api.h"
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
@ -40,17 +40,27 @@ struct CollAPIEntry {
|
||||
};
|
||||
using CollAPIThreadLocalStore = dmlc::ThreadLocalStore<CollAPIEntry>;
|
||||
|
||||
void WaitImpl(TrackerHandleT *ptr) {
|
||||
std::chrono::seconds wait_for{100};
|
||||
void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) {
|
||||
constexpr std::int64_t kDft{60};
|
||||
std::chrono::seconds wait_for{timeout.count() != 0 ? std::min(kDft, timeout.count()) : kDft};
|
||||
|
||||
common::Timer timer;
|
||||
timer.Start();
|
||||
|
||||
auto fut = ptr->second;
|
||||
while (fut.valid()) {
|
||||
auto res = fut.wait_for(wait_for);
|
||||
CHECK(res != std::future_status::deferred);
|
||||
|
||||
if (res == std::future_status::ready) {
|
||||
auto const &rc = ptr->second.get();
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
collective::SafeColl(rc);
|
||||
break;
|
||||
}
|
||||
|
||||
if (timer.Duration() > timeout && timeout.count() != 0) {
|
||||
collective::SafeColl(collective::Fail("Timeout waiting for the tracker."));
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
@ -106,14 +116,17 @@ XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config) {
|
||||
auto *ptr = GetTrackerHandle(handle);
|
||||
xgboost_CHECK_C_ARG_PTR(config);
|
||||
auto jconfig = Json::Load(StringView{config});
|
||||
WaitImpl(ptr);
|
||||
// Internally, 0 indicates no timeout, which is the default since we don't want to
|
||||
// interrupt the model training.
|
||||
auto timeout = OptionalArg<Integer>(jconfig, "timeout", std::int64_t{0});
|
||||
WaitImpl(ptr, std::chrono::seconds{timeout});
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGTrackerFree(TrackerHandle handle) {
|
||||
API_BEGIN();
|
||||
auto *ptr = GetTrackerHandle(handle);
|
||||
WaitImpl(ptr);
|
||||
WaitImpl(ptr, ptr->first->Timeout());
|
||||
delete ptr;
|
||||
API_END();
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#include "allgather.h"
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int8_t, int32_t, int64_t
|
||||
#include <memory> // for shared_ptr
|
||||
#include <utility> // for move
|
||||
|
||||
#include "broadcast.h"
|
||||
#include "comm.h" // for Comm, Channel
|
||||
@ -29,16 +30,20 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
|
||||
auto rc = Success() << [&] {
|
||||
auto send_rank = (rank + world - r + worker_off) % world;
|
||||
auto send_off = send_rank * segment_size;
|
||||
send_off = std::min(send_off, data.size_bytes());
|
||||
auto send_seg = data.subspan(send_off, std::min(segment_size, data.size_bytes() - send_off));
|
||||
bool is_last_segment = send_rank == (world - 1);
|
||||
auto send_nbytes = is_last_segment ? (data.size_bytes() - send_off) : segment_size;
|
||||
auto send_seg = data.subspan(send_off, send_nbytes);
|
||||
return next_ch->SendAll(send_seg.data(), send_seg.size_bytes());
|
||||
} << [&] {
|
||||
auto recv_rank = (rank + world - r - 1 + worker_off) % world;
|
||||
auto recv_off = recv_rank * segment_size;
|
||||
recv_off = std::min(recv_off, data.size_bytes());
|
||||
auto recv_seg = data.subspan(recv_off, std::min(segment_size, data.size_bytes() - recv_off));
|
||||
bool is_last_segment = recv_rank == (world - 1);
|
||||
auto recv_nbytes = is_last_segment ? (data.size_bytes() - recv_off) : segment_size;
|
||||
auto recv_seg = data.subspan(recv_off, recv_nbytes);
|
||||
return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
|
||||
} << [&] { return prev_ch->Block(); };
|
||||
} << [&] {
|
||||
return prev_ch->Block();
|
||||
};
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
@ -91,7 +96,9 @@ namespace detail {
|
||||
auto recv_size = sizes[recv_rank];
|
||||
auto recv_seg = erased_result.subspan(recv_off, recv_size);
|
||||
return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
|
||||
} << [&] { return prev_ch->Block(); };
|
||||
} << [&] {
|
||||
return prev_ch->Block();
|
||||
};
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
|
||||
@ -1,25 +1,27 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int32_t
|
||||
#include <memory> // for shared_ptr
|
||||
#include <numeric> // for accumulate
|
||||
#include <string> // for string
|
||||
#include <type_traits> // for remove_cv_t
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../common/type.h" // for EraseType
|
||||
#include "../common/type.h" // for EraseType
|
||||
#include "comm.h" // for Comm, Channel
|
||||
#include "comm_group.h" // for CommGroup
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/linalg.h"
|
||||
#include "xgboost/span.h" // for Span
|
||||
#include "xgboost/linalg.h" // for MakeVec
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace cpu_impl {
|
||||
/**
|
||||
* @param worker_off Segment offset. For example, if the rank 2 worker specifies
|
||||
* worker_off = 1, then it owns the third segment.
|
||||
* worker_off = 1, then it owns the third segment (2 + 1).
|
||||
*/
|
||||
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::size_t segment_size, std::int32_t worker_off,
|
||||
@ -51,8 +53,10 @@ inline void AllgatherVOffset(common::Span<std::int64_t const> sizes,
|
||||
} // namespace detail
|
||||
|
||||
template <typename T>
|
||||
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data, std::size_t size) {
|
||||
auto n_bytes = sizeof(T) * size;
|
||||
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data) {
|
||||
// This function is also used for ring allreduce, hence we allow the last segment to be
|
||||
// larger due to round-down.
|
||||
auto n_bytes_per_segment = data.size_bytes() / comm.World();
|
||||
auto erased = common::EraseType(data);
|
||||
|
||||
auto rank = comm.Rank();
|
||||
@ -61,7 +65,7 @@ template <typename T>
|
||||
|
||||
auto prev_ch = comm.Chan(prev);
|
||||
auto next_ch = comm.Chan(next);
|
||||
auto rc = cpu_impl::RingAllgather(comm, erased, n_bytes, 0, prev_ch, next_ch);
|
||||
auto rc = cpu_impl::RingAllgather(comm, erased, n_bytes_per_segment, 0, prev_ch, next_ch);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
@ -76,7 +80,7 @@ template <typename T>
|
||||
|
||||
std::vector<std::int64_t> sizes(world, 0);
|
||||
sizes[rank] = data.size_bytes();
|
||||
auto rc = RingAllgather(comm, common::Span{sizes.data(), sizes.size()}, 1);
|
||||
auto rc = RingAllgather(comm, common::Span{sizes.data(), sizes.size()});
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#include "allreduce.h"
|
||||
|
||||
@ -16,7 +16,44 @@
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective::cpu_impl {
|
||||
namespace {
|
||||
template <typename T>
|
||||
Result RingAllreduceSmall(Comm const& comm, common::Span<std::int8_t> data, Func const& op) {
|
||||
auto rank = comm.Rank();
|
||||
auto world = comm.World();
|
||||
|
||||
auto next_ch = comm.Chan(BootstrapNext(rank, world));
|
||||
auto prev_ch = comm.Chan(BootstrapPrev(rank, world));
|
||||
|
||||
std::vector<std::int8_t> buffer(data.size_bytes() * world, 0);
|
||||
auto s_buffer = common::Span{buffer.data(), buffer.size()};
|
||||
|
||||
auto offset = data.size_bytes() * rank;
|
||||
auto self = s_buffer.subspan(offset, data.size_bytes());
|
||||
std::copy_n(data.data(), data.size_bytes(), self.data());
|
||||
|
||||
auto typed = common::RestoreType<T>(s_buffer);
|
||||
auto rc = RingAllgather(comm, typed);
|
||||
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
auto first = s_buffer.subspan(0, data.size_bytes());
|
||||
CHECK_EQ(first.size(), data.size());
|
||||
|
||||
for (std::int32_t r = 1; r < world; ++r) {
|
||||
auto offset = data.size_bytes() * r;
|
||||
auto buf = s_buffer.subspan(offset, data.size_bytes());
|
||||
op(buf, first);
|
||||
}
|
||||
std::copy_n(first.data(), first.size(), data.data());
|
||||
|
||||
return Success();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
// note that n_bytes_in_seg is calculated with round-down.
|
||||
Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::size_t n_bytes_in_seg, Func const& op) {
|
||||
auto rank = comm.Rank();
|
||||
@ -27,14 +64,17 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
|
||||
auto next_ch = comm.Chan(dst_rank);
|
||||
auto prev_ch = comm.Chan(src_rank);
|
||||
|
||||
std::vector<std::int8_t> buffer(n_bytes_in_seg, 0);
|
||||
std::vector<std::int8_t> buffer(data.size_bytes() - (world - 1) * n_bytes_in_seg, 0);
|
||||
auto s_buf = common::Span{buffer.data(), buffer.size()};
|
||||
|
||||
for (std::int32_t r = 0; r < world - 1; ++r) {
|
||||
// send to ring next
|
||||
auto send_off = ((rank + world - r) % world) * n_bytes_in_seg;
|
||||
send_off = std::min(send_off, data.size_bytes());
|
||||
auto seg_nbytes = std::min(data.size_bytes() - send_off, n_bytes_in_seg);
|
||||
auto send_rank = (rank + world - r) % world;
|
||||
auto send_off = send_rank * n_bytes_in_seg;
|
||||
|
||||
bool is_last_segment = send_rank == (world - 1);
|
||||
|
||||
auto seg_nbytes = is_last_segment ? data.size_bytes() - send_off : n_bytes_in_seg;
|
||||
auto send_seg = data.subspan(send_off, seg_nbytes);
|
||||
|
||||
auto rc = next_ch->SendAll(send_seg);
|
||||
@ -43,14 +83,21 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
|
||||
}
|
||||
|
||||
// receive from ring prev
|
||||
auto recv_off = ((rank + world - r - 1) % world) * n_bytes_in_seg;
|
||||
recv_off = std::min(recv_off, data.size_bytes());
|
||||
seg_nbytes = std::min(data.size_bytes() - recv_off, n_bytes_in_seg);
|
||||
auto recv_rank = (rank + world - r - 1) % world;
|
||||
auto recv_off = recv_rank * n_bytes_in_seg;
|
||||
|
||||
is_last_segment = recv_rank == (world - 1);
|
||||
|
||||
seg_nbytes = is_last_segment ? data.size_bytes() - recv_off : n_bytes_in_seg;
|
||||
CHECK_EQ(seg_nbytes % sizeof(T), 0);
|
||||
auto recv_seg = data.subspan(recv_off, seg_nbytes);
|
||||
auto seg = s_buf.subspan(0, recv_seg.size());
|
||||
|
||||
rc = std::move(rc) << [&] { return prev_ch->RecvAll(seg); } << [&] { return comm.Block(); };
|
||||
rc = std::move(rc) << [&] {
|
||||
return prev_ch->RecvAll(seg);
|
||||
} << [&] {
|
||||
return comm.Block();
|
||||
};
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
@ -68,6 +115,9 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func cons
|
||||
if (comm.World() == 1) {
|
||||
return Success();
|
||||
}
|
||||
if (data.size_bytes() == 0) {
|
||||
return Success();
|
||||
}
|
||||
return DispatchDType(type, [&](auto t) {
|
||||
using T = decltype(t);
|
||||
// Divide the data into segments according to the number of workers.
|
||||
@ -75,7 +125,11 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func cons
|
||||
CHECK_EQ(data.size_bytes() % n_bytes_elem, 0);
|
||||
auto n = data.size_bytes() / n_bytes_elem;
|
||||
auto world = comm.World();
|
||||
auto n_bytes_in_seg = common::DivRoundUp(n, world) * sizeof(T);
|
||||
if (n < static_cast<decltype(n)>(world)) {
|
||||
return RingAllreduceSmall<T>(comm, data, op);
|
||||
}
|
||||
|
||||
auto n_bytes_in_seg = (n / world) * sizeof(T);
|
||||
auto rc = RingScatterReduceTyped<T>(comm, data, n_bytes_in_seg, op);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
@ -88,7 +142,9 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func cons
|
||||
|
||||
return std::move(rc) << [&] {
|
||||
return RingAllgather(comm, data, n_bytes_in_seg, 1, prev_ch, next_ch);
|
||||
} << [&] { return comm.Block(); };
|
||||
} << [&] {
|
||||
return comm.Block();
|
||||
};
|
||||
});
|
||||
}
|
||||
} // namespace xgboost::collective::cpu_impl
|
||||
|
||||
@ -104,9 +104,8 @@ bool constexpr IsFloatingPointV() {
|
||||
return cpu_impl::Broadcast(comm, data, root);
|
||||
}
|
||||
|
||||
[[nodiscard]] Result Coll::Allgather(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::int64_t size) {
|
||||
return RingAllgather(comm, data, size);
|
||||
[[nodiscard]] Result Coll::Allgather(Comm const& comm, common::Span<std::int8_t> data) {
|
||||
return RingAllgather(comm, data);
|
||||
}
|
||||
|
||||
[[nodiscard]] Result Coll::AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
|
||||
|
||||
@ -1,10 +1,9 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#if defined(XGBOOST_USE_NCCL)
|
||||
#include <cstdint> // for int8_t, int64_t
|
||||
|
||||
#include "../common/cuda_context.cuh"
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../data/array_interface.h"
|
||||
#include "allgather.h" // for AllgatherVOffset
|
||||
@ -162,14 +161,14 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {
|
||||
} << [&] { return nccl->Block(); };
|
||||
}
|
||||
|
||||
[[nodiscard]] Result NCCLColl::Allgather(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::int64_t size) {
|
||||
[[nodiscard]] Result NCCLColl::Allgather(Comm const& comm, common::Span<std::int8_t> data) {
|
||||
if (!comm.IsDistributed()) {
|
||||
return Success();
|
||||
}
|
||||
auto nccl = dynamic_cast<NCCLComm const*>(&comm);
|
||||
CHECK(nccl);
|
||||
auto stub = nccl->Stub();
|
||||
auto size = data.size_bytes() / comm.World();
|
||||
|
||||
auto send = data.subspan(comm.Rank() * size, size);
|
||||
return Success() << [&] {
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
@ -8,8 +8,7 @@
|
||||
#include "../data/array_interface.h" // for ArrayInterfaceHandler
|
||||
#include "coll.h" // for Coll
|
||||
#include "comm.h" // for Comm
|
||||
#include "nccl_stub.h"
|
||||
#include "xgboost/span.h" // for Span
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
class NCCLColl : public Coll {
|
||||
@ -20,8 +19,7 @@ class NCCLColl : public Coll {
|
||||
ArrayInterfaceHandler::Type type, Op op) override;
|
||||
[[nodiscard]] Result Broadcast(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::int32_t root) override;
|
||||
[[nodiscard]] Result Allgather(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::int64_t size) override;
|
||||
[[nodiscard]] Result Allgather(Comm const& comm, common::Span<std::int8_t> data) override;
|
||||
[[nodiscard]] Result AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
|
||||
common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int64_t> recv_segments,
|
||||
|
||||
@ -48,10 +48,8 @@ class Coll : public std::enable_shared_from_this<Coll> {
|
||||
* @brief Allgather
|
||||
*
|
||||
* @param [in,out] data Data buffer for input and output.
|
||||
* @param [in] size Size of data for each worker.
|
||||
*/
|
||||
[[nodiscard]] virtual Result Allgather(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::int64_t size);
|
||||
[[nodiscard]] virtual Result Allgather(Comm const& comm, common::Span<std::int8_t> data);
|
||||
/**
|
||||
* @brief Allgather with variable length.
|
||||
*
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#include "comm.h"
|
||||
|
||||
@ -9,8 +9,9 @@
|
||||
#include <memory> // for shared_ptr
|
||||
#include <string> // for string
|
||||
#include <utility> // for move, forward
|
||||
|
||||
#include "../common/common.h" // for AssertGPUSupport
|
||||
#if !defined(XGBOOST_USE_NCCL)
|
||||
#include "../common/common.h" // for AssertNCCLSupport
|
||||
#endif // !defined(XGBOOST_USE_NCCL)
|
||||
#include "allgather.h" // for RingAllgather
|
||||
#include "protocol.h" // for kMagic
|
||||
#include "xgboost/base.h" // for XGBOOST_STRICT_R_MODE
|
||||
@ -21,11 +22,7 @@
|
||||
namespace xgboost::collective {
|
||||
Comm::Comm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t retry, std::string task_id)
|
||||
: timeout_{timeout},
|
||||
retry_{retry},
|
||||
tracker_{host, port, -1},
|
||||
task_id_{std::move(task_id)},
|
||||
loop_{std::shared_ptr<Loop>{new Loop{timeout}}} {}
|
||||
: timeout_{timeout}, retry_{retry}, tracker_{host, port, -1}, task_id_{std::move(task_id)} {}
|
||||
|
||||
Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, std::int32_t retry,
|
||||
std::string const& task_id, TCPSocket* out, std::int32_t rank,
|
||||
@ -191,6 +188,7 @@ RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::se
|
||||
std::int32_t retry, std::string task_id, StringView nccl_path)
|
||||
: HostComm{std::move(host), port, timeout, retry, std::move(task_id)},
|
||||
nccl_path_{std::move(nccl_path)} {
|
||||
loop_.reset(new Loop{std::chrono::seconds{timeout_}}); // NOLINT
|
||||
auto rc = this->Bootstrap(timeout_, retry_, task_id_);
|
||||
if (!rc.OK()) {
|
||||
SafeColl(Fail("Failed to bootstrap the communication group.", std::move(rc)));
|
||||
@ -254,9 +252,6 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
|
||||
// get ring neighbors
|
||||
std::string snext;
|
||||
tracker.Recv(&snext);
|
||||
if (!rc.OK()) {
|
||||
return Fail("Failed to receive the rank for the next worker.", std::move(rc));
|
||||
}
|
||||
auto jnext = Json::Load(StringView{snext});
|
||||
|
||||
proto::PeerInfo ninfo{jnext};
|
||||
@ -295,6 +290,10 @@ RabitComm::~RabitComm() noexcept(false) {
|
||||
}
|
||||
|
||||
[[nodiscard]] Result RabitComm::Shutdown() {
|
||||
if (!this->IsDistributed()) {
|
||||
return Success();
|
||||
}
|
||||
|
||||
TCPSocket tracker;
|
||||
return Success() << [&] {
|
||||
return ConnectTrackerImpl(tracker_, timeout_, retry_, task_id_, &tracker, Rank(), World());
|
||||
@ -308,6 +307,11 @@ RabitComm::~RabitComm() noexcept(false) {
|
||||
if (n_bytes != scmd.size()) {
|
||||
return Fail("Faled to send cmd.");
|
||||
}
|
||||
|
||||
this->ResetState();
|
||||
return Success();
|
||||
} << [&] {
|
||||
this->channels_.clear();
|
||||
return Success();
|
||||
};
|
||||
}
|
||||
|
||||
@ -80,7 +80,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
|
||||
auto s_this_uuid = s_uuid.subspan(root.Rank() * kUuidLength, kUuidLength);
|
||||
GetCudaUUID(s_this_uuid, ctx->Device());
|
||||
|
||||
auto rc = pimpl->Allgather(root, common::EraseType(s_uuid), s_this_uuid.size_bytes());
|
||||
auto rc = pimpl->Allgather(root, common::EraseType(s_uuid));
|
||||
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
|
||||
|
||||
@ -50,6 +50,10 @@ class NCCLComm : public Comm {
|
||||
auto rc = this->Stream().Sync(false);
|
||||
return GetCUDAResult(rc);
|
||||
}
|
||||
[[nodiscard]] Result Shutdown() final {
|
||||
this->ResetState();
|
||||
return Success();
|
||||
}
|
||||
};
|
||||
|
||||
class NCCLChannel : public Channel {
|
||||
|
||||
@ -14,7 +14,7 @@
|
||||
#include "loop.h" // for Loop
|
||||
#include "protocol.h" // for PeerInfo
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/collective/socket.h" // for TCPSocket
|
||||
#include "xgboost/collective/socket.h" // for TCPSocket, GetHostName
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
@ -54,8 +54,12 @@ class Comm : public std::enable_shared_from_this<Comm> {
|
||||
std::thread error_worker_;
|
||||
std::string task_id_;
|
||||
std::vector<std::shared_ptr<Channel>> channels_;
|
||||
std::shared_ptr<Loop> loop_{new Loop{std::chrono::seconds{
|
||||
DefaultTimeoutSec()}}}; // fixme: require federated comm to have a timeout
|
||||
std::shared_ptr<Loop> loop_{nullptr}; // fixme: require federated comm to have a timeout
|
||||
|
||||
void ResetState() {
|
||||
this->world_ = -1;
|
||||
this->rank_ = 0;
|
||||
}
|
||||
|
||||
public:
|
||||
Comm() = default;
|
||||
@ -78,7 +82,10 @@ class Comm : public std::enable_shared_from_this<Comm> {
|
||||
[[nodiscard]] auto Rank() const { return rank_; }
|
||||
[[nodiscard]] auto World() const { return IsDistributed() ? world_ : 1; }
|
||||
[[nodiscard]] bool IsDistributed() const { return world_ != -1; }
|
||||
void Submit(Loop::Op op) const { loop_->Submit(op); }
|
||||
void Submit(Loop::Op op) const {
|
||||
CHECK(loop_);
|
||||
loop_->Submit(op);
|
||||
}
|
||||
[[nodiscard]] virtual Result Block() const { return loop_->Block(); }
|
||||
|
||||
[[nodiscard]] virtual std::shared_ptr<Channel> Chan(std::int32_t rank) const {
|
||||
@ -88,6 +95,14 @@ class Comm : public std::enable_shared_from_this<Comm> {
|
||||
[[nodiscard]] virtual Result LogTracker(std::string msg) const = 0;
|
||||
|
||||
[[nodiscard]] virtual Result SignalError(Result const&) { return Success(); }
|
||||
/**
|
||||
* @brief Get a string ID for the current process.
|
||||
*/
|
||||
[[nodiscard]] virtual Result ProcessorName(std::string* out) const {
|
||||
auto rc = GetHostName(out);
|
||||
return rc;
|
||||
}
|
||||
[[nodiscard]] virtual Result Shutdown() = 0;
|
||||
};
|
||||
|
||||
/**
|
||||
@ -105,7 +120,7 @@ class RabitComm : public HostComm {
|
||||
|
||||
[[nodiscard]] Result Bootstrap(std::chrono::seconds timeout, std::int32_t retry,
|
||||
std::string task_id);
|
||||
[[nodiscard]] Result Shutdown();
|
||||
[[nodiscard]] Result Shutdown() final;
|
||||
|
||||
public:
|
||||
// bootstrapping construction.
|
||||
|
||||
@ -1,22 +1,21 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#include "comm_group.h"
|
||||
|
||||
#include <algorithm> // for transform
|
||||
#include <cctype> // for tolower
|
||||
#include <chrono> // for seconds
|
||||
#include <cstdint> // for int32_t
|
||||
#include <iterator> // for back_inserter
|
||||
#include <memory> // for shared_ptr, unique_ptr
|
||||
#include <string> // for string
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../common/json_utils.h" // for OptionalArg
|
||||
#include "coll.h" // for Coll
|
||||
#include "comm.h" // for Comm
|
||||
#include "tracker.h" // for GetHostAddress
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/context.h" // for DeviceOrd
|
||||
#include "xgboost/json.h" // for Json
|
||||
#include "../common/json_utils.h" // for OptionalArg
|
||||
#include "coll.h" // for Coll
|
||||
#include "comm.h" // for Comm
|
||||
#include "xgboost/context.h" // for DeviceOrd
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
#include "../../plugin/federated/federated_coll.h"
|
||||
@ -117,6 +116,8 @@ void GlobalCommGroupInit(Json config) {
|
||||
|
||||
void GlobalCommGroupFinalize() {
|
||||
auto& sptr = GlobalCommGroup();
|
||||
auto rc = sptr->Finalize();
|
||||
sptr.reset();
|
||||
SafeColl(rc);
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@ -9,7 +9,6 @@
|
||||
#include "coll.h" // for Comm
|
||||
#include "comm.h" // for Coll
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/collective/socket.h" // for GetHostName
|
||||
|
||||
namespace xgboost::collective {
|
||||
/**
|
||||
@ -35,15 +34,31 @@ class CommGroup {
|
||||
[[nodiscard]] auto Rank() const { return comm_->Rank(); }
|
||||
[[nodiscard]] bool IsDistributed() const { return comm_->IsDistributed(); }
|
||||
|
||||
/**
 * @brief Shut down the communicators owned by this group.
 *
 * The steps are chained onto an initial Success() with `operator<<`, in this
 * order: the GPU communicator (only when one is present), then the main
 * communicator.
 *
 * @return The result of the chained shutdown steps.
 */
[[nodiscard]] Result Finalize() const {
  // Step 1: the GPU communicator is optional; report success when absent.
  auto shutdown_gpu = [this] {
    return gpu_comm_ ? gpu_comm_->Shutdown() : Success();
  };
  // Step 2: the main communicator.
  auto shutdown_main = [this] {
    return comm_->Shutdown();
  };
  return Success() << shutdown_gpu << shutdown_main;
}
|
||||
|
||||
[[nodiscard]] static CommGroup* Create(Json config);
|
||||
|
||||
[[nodiscard]] std::shared_ptr<Coll> Backend(DeviceOrd device) const;
|
||||
/**
|
||||
* @brief Decide the context to use for communication.
|
||||
*
|
||||
* @param ctx Global context, provides the CUDA stream and ordinal.
|
||||
* @param device The device used by the data to be communicated.
|
||||
*/
|
||||
[[nodiscard]] Comm const& Ctx(Context const* ctx, DeviceOrd device) const;
|
||||
/**
 * @brief Pass an error result on to the underlying communicator.
 *
 * @param res Result forwarded to `comm_->SignalError`.
 *
 * @return The value returned by `comm_->SignalError(res)`.
 */
[[nodiscard]] Result SignalError(Result const& res) {
  return this->comm_->SignalError(res);
}
|
||||
|
||||
/**
 * @brief Get a string ID for the current process.
 *
 * The query is delegated to the underlying communicator, so that a
 * communicator overriding `ProcessorName` is honoured. The previous body
 * returned the `GetHostName` result directly, which left this delegation
 * unreachable.
 *
 * @param out Destination pointer forwarded to `comm_->ProcessorName`.
 *
 * @return The result reported by `comm_->ProcessorName(out)`.
 */
[[nodiscard]] Result ProcessorName(std::string* out) const {
  return this->comm_->ProcessorName(out);
}
|
||||
};
|
||||
|
||||
|
||||
@ -32,7 +32,8 @@ class InMemoryHandler {
|
||||
*
|
||||
* This is used when the handler only needs to be initialized once with a known world size.
|
||||
*/
|
||||
explicit InMemoryHandler(std::size_t worldSize) : world_size_{worldSize} {}
|
||||
explicit InMemoryHandler(std::int32_t worldSize)
|
||||
: world_size_{static_cast<std::size_t>(worldSize)} {}
|
||||
|
||||
/**
|
||||
* @brief Initialize the handler with the world size and rank.
|
||||
|
||||
@ -34,7 +34,7 @@ class Worker : public WorkerForTest {
|
||||
std::vector<std::int32_t> data(comm_.World(), 0);
|
||||
data[comm_.Rank()] = comm_.Rank();
|
||||
|
||||
auto rc = RingAllgather(this->comm_, common::Span{data.data(), data.size()}, 1);
|
||||
auto rc = RingAllgather(this->comm_, common::Span{data.data(), data.size()});
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
|
||||
for (std::int32_t r = 0; r < comm_.World(); ++r) {
|
||||
@ -51,7 +51,7 @@ class Worker : public WorkerForTest {
|
||||
auto seg = s_data.subspan(comm_.Rank() * n, n);
|
||||
std::iota(seg.begin(), seg.end(), comm_.Rank());
|
||||
|
||||
auto rc = RingAllgather(comm_, common::Span{data.data(), data.size()}, n);
|
||||
auto rc = RingAllgather(comm_, common::Span{data.data(), data.size()});
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
|
||||
for (std::int32_t r = 0; r < comm_.World(); ++r) {
|
||||
@ -104,7 +104,7 @@ class Worker : public WorkerForTest {
|
||||
|
||||
std::vector<std::int64_t> sizes(comm_.World(), 0);
|
||||
sizes[comm_.Rank()] = s_data.size_bytes();
|
||||
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1);
|
||||
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()});
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
std::shared_ptr<Coll> pcoll{new Coll{}};
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#if defined(XGBOOST_USE_NCCL)
|
||||
#include <gtest/gtest.h>
|
||||
@ -33,7 +33,7 @@ class Worker : public NCCLWorkerForTest {
|
||||
// get size
|
||||
std::vector<std::int64_t> sizes(comm_.World(), -1);
|
||||
sizes[comm_.Rank()] = s_data.size_bytes();
|
||||
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1);
|
||||
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()});
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
// create result
|
||||
dh::device_vector<std::int32_t> result(comm_.World(), -1);
|
||||
@ -57,7 +57,7 @@ class Worker : public NCCLWorkerForTest {
|
||||
// get size
|
||||
std::vector<std::int64_t> sizes(nccl_comm_->World(), 0);
|
||||
sizes[comm_.Rank()] = dh::ToSpan(data).size_bytes();
|
||||
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1);
|
||||
auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()});
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
auto n_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0);
|
||||
// create result
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
@ -7,7 +7,6 @@
|
||||
|
||||
#include "../../../src/collective/allreduce.h"
|
||||
#include "../../../src/collective/coll.h" // for Coll
|
||||
#include "../../../src/collective/tracker.h"
|
||||
#include "../../../src/common/type.h" // for EraseType
|
||||
#include "test_worker.h" // for WorkerForTest, TestDistributed
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <thrust/host_vector.h> // for host_vector
|
||||
|
||||
#include "../../../src/common/common.h"
|
||||
#include "../../../src/common/common.h" // for AllVisibleGPUs
|
||||
#include "../../../src/common/device_helpers.cuh" // for ToSpan, device_vector
|
||||
#include "../../../src/common/type.h" // for EraseType
|
||||
#include "test_worker.cuh" // for NCCLWorkerForTest
|
||||
|
||||
@ -60,8 +60,7 @@ TEST_F(FederatedCollTest, Allgather) {
|
||||
|
||||
std::vector<std::int32_t> buffer(n_workers, 0);
|
||||
buffer[comm->Rank()] = comm->Rank();
|
||||
auto rc = coll.Allgather(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}),
|
||||
sizeof(int));
|
||||
auto rc = coll.Allgather(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}));
|
||||
ASSERT_TRUE(rc.OK());
|
||||
for (auto i = 0; i < n_workers; i++) {
|
||||
ASSERT_EQ(buffer[i], i);
|
||||
|
||||
@ -5,13 +5,13 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/collective/result.h> // for Result
|
||||
|
||||
#include "../../../../src/collective/allreduce.h"
|
||||
#include "../../../../src/common/common.h" // for AllVisibleGPUs
|
||||
#include "../../../../src/common/device_helpers.cuh" // for device_vector
|
||||
#include "../../../../src/common/type.h" // for EraseType
|
||||
#include "../../collective/test_worker.h" // for SocketTest
|
||||
#include "../../helpers.h" // for MakeCUDACtx
|
||||
#include "federated_coll.cuh"
|
||||
#include "federated_comm.cuh"
|
||||
#include "test_worker.h" // for TestFederated
|
||||
|
||||
namespace xgboost::collective {
|
||||
@ -71,7 +71,7 @@ void TestAllgather(std::shared_ptr<FederatedComm> comm, std::int32_t rank, std::
|
||||
|
||||
dh::device_vector<std::int32_t> buffer(n_workers, 0);
|
||||
buffer[comm->Rank()] = comm->Rank();
|
||||
auto rc = w.coll->Allgather(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer)), sizeof(int));
|
||||
auto rc = w.coll->Allgather(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer)));
|
||||
ASSERT_TRUE(rc.OK());
|
||||
for (auto i = 0; i < n_workers; i++) {
|
||||
ASSERT_EQ(buffer[i], i);
|
||||
|
||||
@ -26,7 +26,6 @@ TEST(FederatedAdapterSimpleTest, ThrowOnInvalidDeviceOrdinal) {
|
||||
namespace {
|
||||
void VerifyAllReduceSum() {
|
||||
auto const world_size = collective::GetWorldSize();
|
||||
auto const rank = collective::GetRank();
|
||||
auto const device = GPUIDX;
|
||||
int count = 3;
|
||||
common::SetDevice(device);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user