[coll] Add global functions. (#10203)
This commit is contained in:
@@ -165,7 +165,7 @@ template <typename T>
|
||||
T GlobalRatio(Context const* ctx, MetaInfo const& info, T dividend, T divisor) {
|
||||
std::array<T, 2> results{dividend, divisor};
|
||||
auto rc = GlobalSum(ctx, info, linalg::MakeVec(results.data(), results.size()));
|
||||
collective::SafeColl(rc);
|
||||
SafeColl(rc);
|
||||
std::tie(dividend, divisor) = std::tuple_cat(results);
|
||||
if (divisor <= 0) {
|
||||
return std::numeric_limits<T>::quiet_NaN();
|
||||
|
||||
@@ -33,6 +33,7 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
|
||||
bool is_last_segment = send_rank == (world - 1);
|
||||
auto send_nbytes = is_last_segment ? (data.size_bytes() - send_off) : segment_size;
|
||||
auto send_seg = data.subspan(send_off, send_nbytes);
|
||||
CHECK_NE(send_seg.size(), 0);
|
||||
return next_ch->SendAll(send_seg.data(), send_seg.size_bytes());
|
||||
} << [&] {
|
||||
auto recv_rank = (rank + world - r - 1 + worker_off) % world;
|
||||
@@ -40,9 +41,10 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
|
||||
bool is_last_segment = recv_rank == (world - 1);
|
||||
auto recv_nbytes = is_last_segment ? (data.size_bytes() - recv_off) : segment_size;
|
||||
auto recv_seg = data.subspan(recv_off, recv_nbytes);
|
||||
CHECK_NE(recv_seg.size(), 0);
|
||||
return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
|
||||
} << [&] {
|
||||
return prev_ch->Block();
|
||||
return comm.Block();
|
||||
};
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
@@ -106,4 +108,47 @@ namespace detail {
|
||||
return comm.Block();
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
/**
 * @brief Gather a list of variable-length byte vectors from every worker in `comm`.
 *
 * Two rounds of AllgatherV are used: the first exchanges the length of each local
 * vector, the second exchanges the concatenated payload. The flat payload is then cut
 * back into individual vectors using the gathered lengths.
 *
 * @param ctx   Context forwarded to the collective calls.
 * @param comm  Communicator group to gather across.
 * @param input Vectors owned by the local worker. Both the number of vectors and their
 *              lengths may differ between workers.
 *
 * @return One vector per gathered input, in the order the payload is laid out in the
 *         receive buffer. When `comm` is not distributed the local input is returned
 *         unchanged.
 */
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
    Context const* ctx, CommGroup const& comm, std::vector<std::vector<char>> const& input) {
  if (!comm.IsDistributed()) {
    // AllgatherV returns early without touching the receive buffer in this case, which
    // would otherwise make this function drop the local input and return nothing.
    return input;
  }

  std::vector<std::int64_t> sizes(input.size());
  std::transform(input.cbegin(), input.cend(), sizes.begin(),
                 [](auto const& vec) { return static_cast<std::int64_t>(vec.size()); });

  std::vector<std::int64_t> recv_segments(comm.World() + 1, 0);
  HostDeviceVector<std::int8_t> recv;

  // Round 1: exchange the length of every vector.
  auto rc =
      AllgatherV(ctx, comm, linalg::MakeVec(sizes.data(), sizes.size()), &recv_segments, &recv);
  SafeColl(rc);

  // `global_sizes` views `recv`, which is reused by the second round. The offsets must
  // therefore be computed before the payload is gathered.
  auto global_sizes = common::RestoreType<std::int64_t const>(recv.ConstHostSpan());
  // offset[i] is the first byte of the i-th gathered vector; offset.back() is the total.
  std::vector<std::int64_t> offset(global_sizes.size() + 1);
  offset[0] = 0;
  for (std::size_t i = 1; i < offset.size(); i++) {
    offset[i] = offset[i - 1] + global_sizes[i - 1];
  }

  // Round 2: exchange the concatenated local payload.
  std::size_t n_local_bytes = 0;
  for (auto const& vec : input) {
    n_local_bytes += vec.size();
  }
  std::vector<char> collected;
  collected.reserve(n_local_bytes);
  for (auto const& vec : input) {
    collected.insert(collected.end(), vec.cbegin(), vec.cend());
  }
  rc = AllgatherV(ctx, comm, linalg::MakeVec(collected.data(), collected.size()), &recv_segments,
                  &recv);
  SafeColl(rc);
  auto out = common::RestoreType<char const>(recv.ConstHostSpan());
  // Slicing below indexes `out` with the gathered offsets; refuse to read out of bounds.
  CHECK_EQ(static_cast<std::int64_t>(out.size()), offset.back())
      << "Gathered payload does not match the gathered sizes.";

  std::vector<std::vector<char>> result;
  result.reserve(offset.size() - 1);
  for (std::size_t i = 1; i < offset.size(); ++i) {
    result.emplace_back(out.cbegin() + offset[i - 1], out.cbegin() + offset[i]);
  }
  return result;
}
|
||||
|
||||
/**
 * @brief Overload of VectorAllgatherV that gathers across the global communicator group.
 */
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
    Context const* ctx, std::vector<std::vector<char>> const& input) {
  auto const& global_group = *GlobalCommGroup();
  return VectorAllgatherV(ctx, global_group, input);
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -102,4 +102,115 @@ template <typename T>
|
||||
|
||||
return detail::RingAllgatherV(comm, sizes, s_segments, erased_result);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
[[nodiscard]] Result Allgather(Context const* ctx, CommGroup const& comm,
|
||||
linalg::VectorView<T> data) {
|
||||
if (!comm.IsDistributed()) {
|
||||
return Success();
|
||||
}
|
||||
CHECK(data.Contiguous());
|
||||
auto erased = common::EraseType(data.Values());
|
||||
|
||||
auto const& cctx = comm.Ctx(ctx, data.Device());
|
||||
auto backend = comm.Backend(data.Device());
|
||||
return backend->Allgather(cctx, erased);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gather all data from all workers.
|
||||
*
|
||||
* @param data The input and output buffer, needs to be pre-allocated by the caller.
|
||||
*/
|
||||
template <typename T>
|
||||
[[nodiscard]] Result Allgather(Context const* ctx, linalg::VectorView<T> data) {
|
||||
auto const& cg = *GlobalCommGroup();
|
||||
if (data.Size() % cg.World() != 0) {
|
||||
return Fail("The total number of elements should be multiple of the number of workers.");
|
||||
}
|
||||
return Allgather(ctx, cg, data);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
[[nodiscard]] Result AllgatherV(Context const* ctx, CommGroup const& comm,
|
||||
linalg::VectorView<T> data,
|
||||
std::vector<std::int64_t>* recv_segments,
|
||||
HostDeviceVector<std::int8_t>* recv) {
|
||||
if (!comm.IsDistributed()) {
|
||||
return Success();
|
||||
}
|
||||
std::vector<std::int64_t> sizes(comm.World(), 0);
|
||||
sizes[comm.Rank()] = data.Values().size_bytes();
|
||||
auto erased_sizes = common::EraseType(common::Span{sizes.data(), sizes.size()});
|
||||
auto rc = comm.Backend(DeviceOrd::CPU())
|
||||
->Allgather(comm.Ctx(ctx, DeviceOrd::CPU()), erased_sizes);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
|
||||
recv_segments->resize(sizes.size() + 1);
|
||||
detail::AllgatherVOffset(sizes, common::Span{recv_segments->data(), recv_segments->size()});
|
||||
auto total_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0LL);
|
||||
recv->SetDevice(data.Device());
|
||||
recv->Resize(total_bytes);
|
||||
|
||||
auto s_segments = common::Span{recv_segments->data(), recv_segments->size()};
|
||||
|
||||
auto backend = comm.Backend(data.Device());
|
||||
auto erased = common::EraseType(data.Values());
|
||||
|
||||
return backend->AllgatherV(
|
||||
comm.Ctx(ctx, data.Device()), erased, common::Span{sizes.data(), sizes.size()}, s_segments,
|
||||
data.Device().IsCUDA() ? recv->DeviceSpan() : recv->HostSpan(), AllgatherVAlgo::kBcast);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Allgather with variable length data.
|
||||
*
|
||||
* @param data The input data.
|
||||
* @param recv_segments segment size for each worker. [0, 2, 5] means [0, 2) elements are
|
||||
* from the first worker, [2, 5) elements are from the second one.
|
||||
* @param recv The buffer storing the result.
|
||||
*/
|
||||
template <typename T>
|
||||
[[nodiscard]] Result AllgatherV(Context const* ctx, linalg::VectorView<T> data,
|
||||
std::vector<std::int64_t>* recv_segments,
|
||||
HostDeviceVector<std::int8_t>* recv) {
|
||||
return AllgatherV(ctx, *GlobalCommGroup(), data, recv_segments, recv);
|
||||
}
|
||||
|
||||
/**
 * @brief Gathers variable-length data from all workers in the given communicator group.
 *
 * @param ctx   Context forwarded to the underlying collective calls.
 * @param comm  Communicator group to gather across.
 * @param input All the inputs from the local worker.
 *
 * @return The gathered vectors.
 */
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
|
||||
Context const* ctx, CommGroup const& comm, std::vector<std::vector<char>> const& input);
|
||||
|
||||
/**
|
||||
* @brief Gathers variable-length data from all processes and distributes it to all processes.
|
||||
*
|
||||
* @param input All the inputs from the local worker. The number of inputs can vary
|
||||
* across different workers, and so can the size of each vector in
|
||||
* the input.
|
||||
*
|
||||
* @return The AllgatherV result, containing vectors from all workers.
|
||||
*/
|
||||
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
|
||||
Context const* ctx, std::vector<std::vector<char>> const& input);
|
||||
|
||||
/**
|
||||
* @brief Gathers variable-length strings from all processes and distributes them to all processes.
|
||||
* @param input Variable-length list of variable-length strings.
|
||||
*/
|
||||
[[nodiscard]] inline Result AllgatherStrings(std::vector<std::string> const& input,
|
||||
std::vector<std::string>* p_result) {
|
||||
std::vector<std::vector<char>> inputs(input.size());
|
||||
for (std::size_t i = 0; i < input.size(); ++i) {
|
||||
inputs[i] = {input[i].cbegin(), input[i].cend()};
|
||||
}
|
||||
Context ctx;
|
||||
auto out = VectorAllgatherV(&ctx, *GlobalCommGroup(), inputs);
|
||||
auto& result = *p_result;
|
||||
result.resize(out.size());
|
||||
for (std::size_t i = 0; i < out.size(); ++i) {
|
||||
result[i] = {out[i].cbegin(), out[i].cend()};
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -68,39 +68,35 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
|
||||
auto s_buf = common::Span{buffer.data(), buffer.size()};
|
||||
|
||||
for (std::int32_t r = 0; r < world - 1; ++r) {
|
||||
// send to ring next
|
||||
auto send_rank = (rank + world - r) % world;
|
||||
auto send_off = send_rank * n_bytes_in_seg;
|
||||
common::Span<std::int8_t> seg, recv_seg;
|
||||
auto rc = Success() << [&] {
|
||||
// send to ring next
|
||||
auto send_rank = (rank + world - r) % world;
|
||||
auto send_off = send_rank * n_bytes_in_seg;
|
||||
|
||||
bool is_last_segment = send_rank == (world - 1);
|
||||
bool is_last_segment = send_rank == (world - 1);
|
||||
|
||||
auto seg_nbytes = is_last_segment ? data.size_bytes() - send_off : n_bytes_in_seg;
|
||||
auto send_seg = data.subspan(send_off, seg_nbytes);
|
||||
auto seg_nbytes = is_last_segment ? data.size_bytes() - send_off : n_bytes_in_seg;
|
||||
CHECK_EQ(seg_nbytes % sizeof(T), 0);
|
||||
|
||||
auto rc = next_ch->SendAll(send_seg);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
auto send_seg = data.subspan(send_off, seg_nbytes);
|
||||
return next_ch->SendAll(send_seg);
|
||||
} << [&] {
|
||||
// receive from ring prev
|
||||
auto recv_rank = (rank + world - r - 1) % world;
|
||||
auto recv_off = recv_rank * n_bytes_in_seg;
|
||||
|
||||
// receive from ring prev
|
||||
auto recv_rank = (rank + world - r - 1) % world;
|
||||
auto recv_off = recv_rank * n_bytes_in_seg;
|
||||
bool is_last_segment = recv_rank == (world - 1);
|
||||
|
||||
is_last_segment = recv_rank == (world - 1);
|
||||
auto seg_nbytes = is_last_segment ? (data.size_bytes() - recv_off) : n_bytes_in_seg;
|
||||
CHECK_EQ(seg_nbytes % sizeof(T), 0);
|
||||
|
||||
seg_nbytes = is_last_segment ? data.size_bytes() - recv_off : n_bytes_in_seg;
|
||||
CHECK_EQ(seg_nbytes % sizeof(T), 0);
|
||||
auto recv_seg = data.subspan(recv_off, seg_nbytes);
|
||||
auto seg = s_buf.subspan(0, recv_seg.size());
|
||||
|
||||
rc = std::move(rc) << [&] {
|
||||
recv_seg = data.subspan(recv_off, seg_nbytes);
|
||||
seg = s_buf.subspan(0, recv_seg.size());
|
||||
return prev_ch->RecvAll(seg);
|
||||
} << [&] {
|
||||
return comm.Block();
|
||||
};
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
|
||||
// accumulate to recv_seg
|
||||
CHECK_EQ(seg.size(), recv_seg.size());
|
||||
|
||||
@@ -1,15 +1,18 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <cstdint> // for int8_t
|
||||
#include <functional> // for function
|
||||
#include <type_traits> // for is_invocable_v, enable_if_t
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../common/type.h" // for EraseType, RestoreType
|
||||
#include "../data/array_interface.h" // for ArrayInterfaceHandler
|
||||
#include "../data/array_interface.h" // for ToDType, ArrayInterfaceHandler
|
||||
#include "comm.h" // for Comm, RestoreType
|
||||
#include "comm_group.h" // for GlobalCommGroup
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
@@ -27,8 +30,7 @@ std::enable_if_t<std::is_invocable_v<Fn, common::Span<T const>, common::Span<T>>
|
||||
auto erased = common::EraseType(data);
|
||||
auto type = ToDType<T>::kType;
|
||||
|
||||
auto erased_fn = [type, redop](common::Span<std::int8_t const> lhs,
|
||||
common::Span<std::int8_t> out) {
|
||||
auto erased_fn = [redop](common::Span<std::int8_t const> lhs, common::Span<std::int8_t> out) {
|
||||
CHECK_EQ(lhs.size(), out.size()) << "Invalid input for reduction.";
|
||||
auto lhs_t = common::RestoreType<T const>(lhs);
|
||||
auto rhs_t = common::RestoreType<T>(out);
|
||||
@@ -37,4 +39,40 @@ std::enable_if_t<std::is_invocable_v<Fn, common::Span<T const>, common::Span<T>>
|
||||
|
||||
return cpu_impl::RingAllreduce(comm, erased, erased_fn, type);
|
||||
}
|
||||
|
||||
template <typename T, std::int32_t kDim>
|
||||
[[nodiscard]] Result Allreduce(Context const* ctx, CommGroup const& comm,
|
||||
linalg::TensorView<T, kDim> data, Op op) {
|
||||
if (!comm.IsDistributed()) {
|
||||
return Success();
|
||||
}
|
||||
CHECK(data.Contiguous());
|
||||
auto erased = common::EraseType(data.Values());
|
||||
auto type = ToDType<T>::kType;
|
||||
|
||||
auto backend = comm.Backend(data.Device());
|
||||
return backend->Allreduce(comm.Ctx(ctx, data.Device()), erased, type, op);
|
||||
}
|
||||
|
||||
template <typename T, std::int32_t kDim>
|
||||
[[nodiscard]] Result Allreduce(Context const* ctx, linalg::TensorView<T, kDim> data, Op op) {
|
||||
return Allreduce(ctx, *GlobalCommGroup(), data, op);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Specialization for std::vector.
|
||||
*/
|
||||
template <typename T, typename Alloc>
|
||||
[[nodiscard]] Result Allreduce(Context const* ctx, std::vector<T, Alloc>* data, Op op) {
|
||||
return Allreduce(ctx, linalg::MakeVec(data->data(), data->size()), op);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Specialization for scalar value.
|
||||
*/
|
||||
template <typename T>
|
||||
[[nodiscard]] std::enable_if_t<std::is_standard_layout_v<T> && std::is_trivial_v<T>, Result>
|
||||
Allreduce(Context const* ctx, T* data, Op op) {
|
||||
return Allreduce(ctx, linalg::MakeVec(data, 1), op);
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <cstdint> // for int32_t, int8_t
|
||||
|
||||
#include "comm.h" // for Comm
|
||||
#include "xgboost/collective/result.h" // for
|
||||
#include "../common/type.h"
|
||||
#include "comm.h" // for Comm, EraseType
|
||||
#include "comm_group.h" // for CommGroup
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/linalg.h" // for VectorView
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
@@ -23,4 +27,21 @@ template <typename T>
|
||||
common::Span<std::int8_t>{reinterpret_cast<std::int8_t*>(data.data()), n_total_bytes};
|
||||
return cpu_impl::Broadcast(comm, erased, root);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
[[nodiscard]] Result Broadcast(Context const* ctx, CommGroup const& comm,
|
||||
linalg::VectorView<T> data, std::int32_t root) {
|
||||
if (!comm.IsDistributed()) {
|
||||
return Success();
|
||||
}
|
||||
CHECK(data.Contiguous());
|
||||
auto erased = common::EraseType(data.Values());
|
||||
auto backend = comm.Backend(data.Device());
|
||||
return backend->Broadcast(comm.Ctx(ctx, data.Device()), erased, root);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
[[nodiscard]] Result Broadcast(Context const* ctx, linalg::VectorView<T> data, std::int32_t root) {
|
||||
return Broadcast(ctx, *GlobalCommGroup(), data, root);
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -27,7 +27,7 @@ Result GetUniqueId(Comm const& comm, std::shared_ptr<NcclStub> stub, std::shared
|
||||
ncclUniqueId id;
|
||||
if (comm.Rank() == kRootRank) {
|
||||
auto rc = stub->GetUniqueId(&id);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
SafeColl(rc);
|
||||
}
|
||||
auto rc = coll->Broadcast(
|
||||
comm, common::Span{reinterpret_cast<std::int8_t*>(&id), sizeof(ncclUniqueId)}, kRootRank);
|
||||
@@ -81,8 +81,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
|
||||
GetCudaUUID(s_this_uuid, ctx->Device());
|
||||
|
||||
auto rc = pimpl->Allgather(root, common::EraseType(s_uuid));
|
||||
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
SafeColl(rc);
|
||||
|
||||
std::vector<xgboost::common::Span<std::uint64_t, kUuidLength>> converted(root.World());
|
||||
std::size_t j = 0;
|
||||
@@ -103,7 +102,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
|
||||
[&] {
|
||||
return this->stub_->CommInitRank(&nccl_comm_, root.World(), nccl_unique_id_, root.Rank());
|
||||
};
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
SafeColl(rc);
|
||||
|
||||
for (std::int32_t r = 0; r < root.World(); ++r) {
|
||||
this->channels_.emplace_back(
|
||||
@@ -114,7 +113,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
|
||||
NCCLComm::~NCCLComm() {
|
||||
if (nccl_comm_) {
|
||||
auto rc = stub_->CommDestroy(nccl_comm_);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
SafeColl(rc);
|
||||
}
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
#include <thrust/iterator/transform_output_iterator.h> // make_transform_output_iterator
|
||||
#include <thrust/logical.h>
|
||||
#include <thrust/sequence.h>
|
||||
#include <thrust/sort.h>
|
||||
#include <thrust/system/cuda/error.h>
|
||||
#include <thrust/system_error.h>
|
||||
#include <thrust/transform_scan.h>
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#define COMMON_HIST_UTIL_CUH_
|
||||
|
||||
#include <thrust/host_vector.h>
|
||||
#include <thrust/sort.h> // for sort
|
||||
|
||||
#include <cstddef> // for size_t
|
||||
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
/**
|
||||
* Copyright 2019-2023 by XGBoost Contributors
|
||||
* Copyright 2019-2024, XGBoost Contributors
|
||||
*/
|
||||
#include <thrust/functional.h>
|
||||
#include <thrust/random.h>
|
||||
#include <thrust/sort.h> // for sort
|
||||
#include <thrust/transform.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
#include <xgboost/logging.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef> // for size_t
|
||||
#include <limits>
|
||||
#include <utility>
|
||||
|
||||
Reference in New Issue
Block a user