[coll] Allreduce. (#9679)

Authored by Jiaming Yuan on 2023-10-17 13:57:14 +08:00; committed by GitHub
parent da6803b75b · commit 48ac9b6cbe
8 changed files with 301 additions and 81 deletions

src/collective/allreduce.cc

@@ -0,0 +1,90 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include "allreduce.h"
#include <algorithm> // for min
#include <cstddef> // for size_t
#include <cstdint> // for int32_t, int8_t
#include <vector> // for vector
#include "../data/array_interface.h" // for Type, DispatchDType
#include "allgather.h" // for RingAllgather
#include "comm.h" // for Comm
#include "xgboost/collective/result.h" // for Result
#include "xgboost/span.h" // for Span
namespace xgboost::collective::cpu_impl {
template <typename T>
Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
std::size_t n_bytes_in_seg, Func const& op) {
auto rank = comm.Rank();
auto world = comm.World();
auto dst_rank = BootstrapNext(rank, world);
auto src_rank = BootstrapPrev(rank, world);
auto next_ch = comm.Chan(dst_rank);
auto prev_ch = comm.Chan(src_rank);
std::vector<std::int8_t> buffer(n_bytes_in_seg, 0);
auto s_buf = common::Span{buffer.data(), buffer.size()};
for (std::int32_t r = 0; r < world - 1; ++r) {
// send to ring next
auto send_off = ((rank + world - r) % world) * n_bytes_in_seg;
send_off = std::min(send_off, data.size_bytes());
auto seg_nbytes = std::min(data.size_bytes() - send_off, n_bytes_in_seg);
auto send_seg = data.subspan(send_off, seg_nbytes);
next_ch->SendAll(send_seg);
// receive from ring prev
auto recv_off = ((rank + world - r - 1) % world) * n_bytes_in_seg;
recv_off = std::min(recv_off, data.size_bytes());
seg_nbytes = std::min(data.size_bytes() - recv_off, n_bytes_in_seg);
CHECK_EQ(seg_nbytes % sizeof(T), 0);
auto recv_seg = data.subspan(recv_off, seg_nbytes);
auto seg = s_buf.subspan(0, recv_seg.size());
prev_ch->RecvAll(seg);
auto rc = prev_ch->Block();
if (!rc.OK()) {
return rc;
}
// accumulate to recv_seg
CHECK_EQ(seg.size(), recv_seg.size());
op(seg, recv_seg);
}
return Success();
}

Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func const& op,
ArrayInterfaceHandler::Type type) {
return DispatchDType(type, [&](auto t) {
using T = decltype(t);
// Divide the data into segments according to the number of workers.
auto n_bytes_elem = sizeof(T);
CHECK_EQ(data.size_bytes() % n_bytes_elem, 0);
auto n = data.size_bytes() / n_bytes_elem;
auto world = comm.World();
auto n_bytes_in_seg = common::DivRoundUp(n, world) * sizeof(T);
auto rc = RingScatterReduceTyped<T>(comm, data, n_bytes_in_seg, op);
if (!rc.OK()) {
return rc;
}
auto prev = BootstrapPrev(comm.Rank(), comm.World());
auto next = BootstrapNext(comm.Rank(), comm.World());
auto prev_ch = comm.Chan(prev);
auto next_ch = comm.Chan(next);
rc = RingAllgather(comm, data, n_bytes_in_seg, 1, prev_ch, next_ch);
if (!rc.OK()) {
return rc;
}
return comm.Block();
});
}
} // namespace xgboost::collective::cpu_impl
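
To make the segment arithmetic concrete, here is a minimal single-process sketch of the same two-phase schedule (scatter-reduce, then ring allgather), with channel sends and receives replaced by direct reads between simulated rank buffers. Everything below is editorial illustration, not part of the commit; only the offset formulas mirror the code above.

#include <algorithm>  // for fill, min
#include <cstddef>    // for size_t
#include <iostream>   // for cout
#include <utility>    // for pair
#include <vector>     // for vector

int main() {
  int const world = 4;
  std::size_t const n = 10;                         // deliberately not divisible by world
  std::size_t const seg = (n + world - 1) / world;  // mirrors common::DivRoundUp(n, world)
  // Rank k contributes n copies of (k + 1); the allreduce sum is world*(world+1)/2.
  std::vector<std::vector<int>> bufs(world, std::vector<int>(n));
  for (int k = 0; k < world; ++k) {
    std::fill(bufs[k].begin(), bufs[k].end(), k + 1);
  }
  // Clamp a segment index to [offset, offset + length), mirroring the std::min
  // clamping in RingScatterReduceTyped (here in elements rather than bytes).
  auto bounds = [&](int idx) {
    std::size_t off = std::min(static_cast<std::size_t>(idx) * seg, n);
    return std::pair{off, std::min(n - off, seg)};
  };
  // Phase 1: scatter-reduce. In round r, rank k accumulates segment
  // (k - r - 1) mod world "received" from its ring predecessor. The sequential
  // rank loop is safe: within a round no rank reads a segment another rank
  // writes. After world - 1 rounds, rank k holds the fully reduced segment
  // (k + 1) mod world.
  for (int r = 0; r < world - 1; ++r) {
    for (int k = 0; k < world; ++k) {
      int const prev = (k + world - 1) % world;
      auto [off, len] = bounds((k + world - r - 1) % world);
      for (std::size_t i = off; i < off + len; ++i) {
        bufs[k][i] += bufs[prev][i];  // op(seg, recv_seg) with op = sum
      }
    }
  }
  // Phase 2: ring allgather. In round r, rank k copies the fully reduced
  // segment (k - r) mod world from its predecessor.
  for (int r = 0; r < world - 1; ++r) {
    for (int k = 0; k < world; ++k) {
      int const prev = (k + world - 1) % world;
      auto [off, len] = bounds((k + world - r) % world);
      for (std::size_t i = off; i < off + len; ++i) {
        bufs[k][i] = bufs[prev][i];
      }
    }
  }
  int const expected = world * (world + 1) / 2;
  for (auto const& buf : bufs) {
    for (int v : buf) {
      if (v != expected) {
        std::cout << "mismatch\n";
        return 1;
      }
    }
  }
  std::cout << "every rank holds the full sum: " << expected << "\n";
  return 0;
}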

src/collective/allreduce.h

@@ -0,0 +1,39 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once

#include <cstdint> // for int8_t
#include <functional> // for function
#include <type_traits> // for is_invocable_v
#include "../data/array_interface.h" // for ArrayInterfaceHandler
#include "comm.h" // for Comm, RestoreType
#include "xgboost/collective/result.h" // for Result
#include "xgboost/span.h" // for Span
namespace xgboost::collective {
namespace cpu_impl {
using Func =
std::function<void(common::Span<std::int8_t const> lhs, common::Span<std::int8_t> out)>;
Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func const& op,
ArrayInterfaceHandler::Type type);
} // namespace cpu_impl

template <typename T, typename Fn>
std::enable_if_t<std::is_invocable_v<Fn, common::Span<T const>, common::Span<T>>, Result> Allreduce(
Comm const& comm, common::Span<T> data, Fn redop) {
auto erased = EraseType(data);
auto type = ToDType<T>::kType;
auto erased_fn = [type, redop](common::Span<std::int8_t const> lhs,
common::Span<std::int8_t> out) {
CHECK_EQ(lhs.size(), out.size()) << "Invalid input for reduction.";
auto lhs_t = RestoreType<T const>(lhs);
auto rhs_t = RestoreType<T>(out);
redop(lhs_t, rhs_t);
};
return cpu_impl::RingAllreduce(comm, erased, erased_fn, type);
}
} // namespace xgboost::collective
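
A usage sketch of the typed front end, assuming an already-bootstrapped Comm; the helper below is hypothetical and not part of this commit:

// Sketch only: sum-allreduce a buffer of doubles across all workers. Assumes
// <vector> and the declarations above are in scope, inside (or qualifying)
// namespace xgboost::collective.
Result SumAllreduce(Comm const& comm, std::vector<double>* values) {
  return Allreduce(comm, common::Span<double>{values->data(), values->size()},
                   [](common::Span<double const> lhs, common::Span<double> out) {
                     // Element-wise accumulation; `out` carries the running result.
                     for (std::size_t i = 0; i < lhs.size(); ++i) {
                       out[i] += lhs[i];
                     }
                   });
}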

src/data/array_interface.h

@@ -16,7 +16,7 @@
#include <utility>
#include <vector>
#include "../common/bitfield.h"
#include "../common/bitfield.h" // for RBitField8
#include "../common/common.h"
#include "../common/error_msg.h" // for NoF128
#include "xgboost/base.h"
@@ -104,7 +104,20 @@ struct ArrayInterfaceErrors {
*/
class ArrayInterfaceHandler {
public:
-  enum Type : std::int8_t { kF2, kF4, kF8, kF16, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 };
+  enum Type : std::int8_t {
+    kF2 = 0,
+    kF4 = 1,
+    kF8 = 2,
+    kF16 = 3,
+    kI1 = 4,
+    kI2 = 5,
+    kI4 = 6,
+    kI8 = 7,
+    kU1 = 8,
+    kU2 = 9,
+    kU4 = 10,
+    kU8 = 11,
+  };
template <typename PtrType>
static PtrType GetPtrFromArrayData(Object::Map const &obj) {
@@ -587,6 +600,57 @@ class ArrayInterface {
ArrayInterfaceHandler::Type type{ArrayInterfaceHandler::kF16};
};

template <typename Fn>
auto DispatchDType(ArrayInterfaceHandler::Type dtype, Fn dispatch) {
switch (dtype) {
case ArrayInterfaceHandler::kF2: {
#if defined(XGBOOST_USE_CUDA)
return dispatch(__half{});
#else
LOG(FATAL) << "half type is only supported for CUDA input.";
break;
#endif
}
case ArrayInterfaceHandler::kF4: {
return dispatch(float{});
}
case ArrayInterfaceHandler::kF8: {
return dispatch(double{});
}
case ArrayInterfaceHandler::kF16: {
using T = long double;
CHECK(sizeof(T) == 16) << error::NoF128();
return dispatch(T{});
}
case ArrayInterfaceHandler::kI1: {
return dispatch(std::int8_t{});
}
case ArrayInterfaceHandler::kI2: {
return dispatch(std::int16_t{});
}
case ArrayInterfaceHandler::kI4: {
return dispatch(std::int32_t{});
}
case ArrayInterfaceHandler::kI8: {
return dispatch(std::int64_t{});
}
case ArrayInterfaceHandler::kU1: {
return dispatch(std::uint8_t{});
}
case ArrayInterfaceHandler::kU2: {
return dispatch(std::uint16_t{});
}
case ArrayInterfaceHandler::kU4: {
return dispatch(std::uint32_t{});
}
case ArrayInterfaceHandler::kU8: {
return dispatch(std::uint64_t{});
}
}
return std::result_of_t<Fn(std::int8_t)>();
}
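// Editorial usage sketch, not part of the commit: the callback receives a
// default-constructed value whose static type encodes the runtime dtype, so
// decltype() recovers the C++ type. The helper name is hypothetical.
inline std::size_t DTypeBytes(ArrayInterfaceHandler::Type t) {
  return DispatchDType(t, [](auto v) { return sizeof(decltype(v)); });
}
// For instance, DTypeBytes(ArrayInterfaceHandler::kF4) == 4 and
// DTypeBytes(ArrayInterfaceHandler::kI8) == 8.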
template <std::int32_t D, typename Fn>
void DispatchDType(ArrayInterface<D> const array, DeviceOrd device, Fn fn) {
// Only used for cuDF at the moment.
@@ -602,60 +666,7 @@ void DispatchDType(ArrayInterface<D> const array, DeviceOrd device, Fn fn) {
std::numeric_limits<std::size_t>::max()},
array.shape, array.strides, device});
};
-  switch (array.type) {
-    case ArrayInterfaceHandler::kF2: {
-#if defined(XGBOOST_USE_CUDA)
-      dispatch(__half{});
-#endif
-      break;
-    }
-    case ArrayInterfaceHandler::kF4: {
-      dispatch(float{});
-      break;
-    }
-    case ArrayInterfaceHandler::kF8: {
-      dispatch(double{});
-      break;
-    }
-    case ArrayInterfaceHandler::kF16: {
-      using T = long double;
-      CHECK(sizeof(long double) == 16) << error::NoF128();
-      dispatch(T{});
-      break;
-    }
-    case ArrayInterfaceHandler::kI1: {
-      dispatch(std::int8_t{});
-      break;
-    }
-    case ArrayInterfaceHandler::kI2: {
-      dispatch(std::int16_t{});
-      break;
-    }
-    case ArrayInterfaceHandler::kI4: {
-      dispatch(std::int32_t{});
-      break;
-    }
-    case ArrayInterfaceHandler::kI8: {
-      dispatch(std::int64_t{});
-      break;
-    }
-    case ArrayInterfaceHandler::kU1: {
-      dispatch(std::uint8_t{});
-      break;
-    }
-    case ArrayInterfaceHandler::kU2: {
-      dispatch(std::uint16_t{});
-      break;
-    }
-    case ArrayInterfaceHandler::kU4: {
-      dispatch(std::uint32_t{});
-      break;
-    }
-    case ArrayInterfaceHandler::kU8: {
-      dispatch(std::uint64_t{});
-      break;
-    }
-  }
+  DispatchDType(array.type, dispatch);
}
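// Editorial usage sketch, not part of the commit: dispatch a parsed array to a
// typed linalg::TensorView and read its element count. Construction of the
// ArrayInterface from an __array_interface__ JSON document is omitted, and the
// helper name is hypothetical.
template <std::int32_t D>
std::size_t NumElems(ArrayInterface<D> const& array, DeviceOrd device) {
  std::size_t n{0};
  DispatchDType(array, device, [&](auto tensor) { n = tensor.Size(); });
  return n;
}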
/**