diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in index 808960319..37511ec62 100644 --- a/R-package/src/Makevars.in +++ b/R-package/src/Makevars.in @@ -99,6 +99,7 @@ OBJECTS= \ $(PKGROOT)/src/logging.o \ $(PKGROOT)/src/global_config.o \ $(PKGROOT)/src/collective/allgather.o \ + $(PKGROOT)/src/collective/allreduce.o \ $(PKGROOT)/src/collective/broadcast.o \ $(PKGROOT)/src/collective/comm.o \ $(PKGROOT)/src/collective/tracker.o \ diff --git a/R-package/src/Makevars.win b/R-package/src/Makevars.win index 43bfcf7c1..611cff874 100644 --- a/R-package/src/Makevars.win +++ b/R-package/src/Makevars.win @@ -99,6 +99,7 @@ OBJECTS= \ $(PKGROOT)/src/logging.o \ $(PKGROOT)/src/global_config.o \ $(PKGROOT)/src/collective/allgather.o \ + $(PKGROOT)/src/collective/allreduce.o \ $(PKGROOT)/src/collective/broadcast.o \ $(PKGROOT)/src/collective/comm.o \ $(PKGROOT)/src/collective/tracker.o \ diff --git a/src/collective/allreduce.cc b/src/collective/allreduce.cc new file mode 100644 index 000000000..6948f6758 --- /dev/null +++ b/src/collective/allreduce.cc @@ -0,0 +1,90 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#include "allreduce.h" + +#include // for min +#include // for size_t +#include // for int32_t, int8_t +#include // for vector + +#include "../data/array_interface.h" // for Type, DispatchDType +#include "allgather.h" // for RingAllgather +#include "comm.h" // for Comm +#include "xgboost/collective/result.h" // for Result +#include "xgboost/span.h" // for Span + +namespace xgboost::collective::cpu_impl { +template +Result RingScatterReduceTyped(Comm const& comm, common::Span data, + std::size_t n_bytes_in_seg, Func const& op) { + auto rank = comm.Rank(); + auto world = comm.World(); + + auto dst_rank = BootstrapNext(rank, world); + auto src_rank = BootstrapPrev(rank, world); + auto next_ch = comm.Chan(dst_rank); + auto prev_ch = comm.Chan(src_rank); + + std::vector buffer(n_bytes_in_seg, 0); + auto s_buf = common::Span{buffer.data(), buffer.size()}; + + for (std::int32_t r = 0; r < world - 1; ++r) { + // send to ring next + auto send_off = ((rank + world - r) % world) * n_bytes_in_seg; + send_off = std::min(send_off, data.size_bytes()); + auto seg_nbytes = std::min(data.size_bytes() - send_off, n_bytes_in_seg); + auto send_seg = data.subspan(send_off, seg_nbytes); + + next_ch->SendAll(send_seg); + + // receive from ring prev + auto recv_off = ((rank + world - r - 1) % world) * n_bytes_in_seg; + recv_off = std::min(recv_off, data.size_bytes()); + seg_nbytes = std::min(data.size_bytes() - recv_off, n_bytes_in_seg); + CHECK_EQ(seg_nbytes % sizeof(T), 0); + auto recv_seg = data.subspan(recv_off, seg_nbytes); + auto seg = s_buf.subspan(0, recv_seg.size()); + + prev_ch->RecvAll(seg); + auto rc = prev_ch->Block(); + if (!rc.OK()) { + return rc; + } + + // accumulate to recv_seg + CHECK_EQ(seg.size(), recv_seg.size()); + op(seg, recv_seg); + } + + return Success(); +} + +Result RingAllreduce(Comm const& comm, common::Span data, Func const& op, + ArrayInterfaceHandler::Type type) { + return DispatchDType(type, [&](auto t) { + using T = decltype(t); + // Divide the data into segments according to the number of workers. + auto n_bytes_elem = sizeof(T); + CHECK_EQ(data.size_bytes() % n_bytes_elem, 0); + auto n = data.size_bytes() / n_bytes_elem; + auto world = comm.World(); + auto n_bytes_in_seg = common::DivRoundUp(n, world) * sizeof(T); + auto rc = RingScatterReduceTyped(comm, data, n_bytes_in_seg, op); + if (!rc.OK()) { + return rc; + } + + auto prev = BootstrapPrev(comm.Rank(), comm.World()); + auto next = BootstrapNext(comm.Rank(), comm.World()); + auto prev_ch = comm.Chan(prev); + auto next_ch = comm.Chan(next); + + rc = RingAllgather(comm, data, n_bytes_in_seg, 1, prev_ch, next_ch); + if (!rc.OK()) { + return rc; + } + return comm.Block(); + }); +} +} // namespace xgboost::collective::cpu_impl diff --git a/src/collective/allreduce.h b/src/collective/allreduce.h new file mode 100644 index 000000000..e3f8ab5b8 --- /dev/null +++ b/src/collective/allreduce.h @@ -0,0 +1,39 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#pragma once +#include // for int8_t +#include // for function +#include // for is_invocable_v + +#include "../data/array_interface.h" // for ArrayInterfaceHandler +#include "comm.h" // for Comm, RestoreType +#include "xgboost/collective/result.h" // for Result +#include "xgboost/span.h" // for Span + +namespace xgboost::collective { +namespace cpu_impl { +using Func = + std::function lhs, common::Span out)>; + +Result RingAllreduce(Comm const& comm, common::Span data, Func const& op, + ArrayInterfaceHandler::Type type); +} // namespace cpu_impl + +template +std::enable_if_t, common::Span>, Result> Allreduce( + Comm const& comm, common::Span data, Fn redop) { + auto erased = EraseType(data); + auto type = ToDType::kType; + + auto erased_fn = [type, redop](common::Span lhs, + common::Span out) { + CHECK_EQ(lhs.size(), out.size()) << "Invalid input for reduction."; + auto lhs_t = RestoreType(lhs); + auto rhs_t = RestoreType(out); + redop(lhs_t, rhs_t); + }; + + return cpu_impl::RingAllreduce(comm, erased, erased_fn, type); +} +} // namespace xgboost::collective diff --git a/src/data/array_interface.h b/src/data/array_interface.h index c62a5cef2..0170e6a84 100644 --- a/src/data/array_interface.h +++ b/src/data/array_interface.h @@ -16,7 +16,7 @@ #include #include -#include "../common/bitfield.h" +#include "../common/bitfield.h" // for RBitField8 #include "../common/common.h" #include "../common/error_msg.h" // for NoF128 #include "xgboost/base.h" @@ -104,7 +104,20 @@ struct ArrayInterfaceErrors { */ class ArrayInterfaceHandler { public: - enum Type : std::int8_t { kF2, kF4, kF8, kF16, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 }; + enum Type : std::int8_t { + kF2 = 0, + kF4 = 1, + kF8 = 2, + kF16 = 3, + kI1 = 4, + kI2 = 5, + kI4 = 6, + kI8 = 7, + kU1 = 8, + kU2 = 9, + kU4 = 10, + kU8 = 11, + }; template static PtrType GetPtrFromArrayData(Object::Map const &obj) { @@ -587,6 +600,57 @@ class ArrayInterface { ArrayInterfaceHandler::Type type{ArrayInterfaceHandler::kF16}; }; +template +auto DispatchDType(ArrayInterfaceHandler::Type dtype, Fn dispatch) { + switch (dtype) { + case ArrayInterfaceHandler::kF2: { +#if defined(XGBOOST_USE_CUDA) + return dispatch(__half{}); +#else + LOG(FATAL) << "half type is only supported for CUDA input."; + break; +#endif + } + case ArrayInterfaceHandler::kF4: { + return dispatch(float{}); + } + case ArrayInterfaceHandler::kF8: { + return dispatch(double{}); + } + case ArrayInterfaceHandler::kF16: { + using T = long double; + CHECK(sizeof(T) == 16) << error::NoF128(); + return dispatch(T{}); + } + case ArrayInterfaceHandler::kI1: { + return dispatch(std::int8_t{}); + } + case ArrayInterfaceHandler::kI2: { + return dispatch(std::int16_t{}); + } + case ArrayInterfaceHandler::kI4: { + return dispatch(std::int32_t{}); + } + case ArrayInterfaceHandler::kI8: { + return dispatch(std::int64_t{}); + } + case ArrayInterfaceHandler::kU1: { + return dispatch(std::uint8_t{}); + } + case ArrayInterfaceHandler::kU2: { + return dispatch(std::uint16_t{}); + } + case ArrayInterfaceHandler::kU4: { + return dispatch(std::uint32_t{}); + } + case ArrayInterfaceHandler::kU8: { + return dispatch(std::uint64_t{}); + } + } + + return std::result_of_t(); +} + template void DispatchDType(ArrayInterface const array, DeviceOrd device, Fn fn) { // Only used for cuDF at the moment. @@ -602,60 +666,7 @@ void DispatchDType(ArrayInterface const array, DeviceOrd device, Fn fn) { std::numeric_limits::max()}, array.shape, array.strides, device}); }; - switch (array.type) { - case ArrayInterfaceHandler::kF2: { -#if defined(XGBOOST_USE_CUDA) - dispatch(__half{}); -#endif - break; - } - case ArrayInterfaceHandler::kF4: { - dispatch(float{}); - break; - } - case ArrayInterfaceHandler::kF8: { - dispatch(double{}); - break; - } - case ArrayInterfaceHandler::kF16: { - using T = long double; - CHECK(sizeof(long double) == 16) << error::NoF128(); - dispatch(T{}); - break; - } - case ArrayInterfaceHandler::kI1: { - dispatch(std::int8_t{}); - break; - } - case ArrayInterfaceHandler::kI2: { - dispatch(std::int16_t{}); - break; - } - case ArrayInterfaceHandler::kI4: { - dispatch(std::int32_t{}); - break; - } - case ArrayInterfaceHandler::kI8: { - dispatch(std::int64_t{}); - break; - } - case ArrayInterfaceHandler::kU1: { - dispatch(std::uint8_t{}); - break; - } - case ArrayInterfaceHandler::kU2: { - dispatch(std::uint16_t{}); - break; - } - case ArrayInterfaceHandler::kU4: { - dispatch(std::uint32_t{}); - break; - } - case ArrayInterfaceHandler::kU8: { - dispatch(std::uint64_t{}); - break; - } - } + DispatchDType(array.type, dispatch); } /** diff --git a/tests/cpp/collective/test_allreduce.cc b/tests/cpp/collective/test_allreduce.cc new file mode 100644 index 000000000..62b87e411 --- /dev/null +++ b/tests/cpp/collective/test_allreduce.cc @@ -0,0 +1,72 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#include + +#include "../../../src/collective/allreduce.h" +#include "../../../src/collective/tracker.h" +#include "test_worker.h" // for WorkerForTest, TestDistributed + +namespace xgboost::collective { + +namespace { +class AllreduceWorker : public WorkerForTest { + public: + using WorkerForTest::WorkerForTest; + + void Basic() { + { + std::vector data(13, 0.0); + Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) { + for (std::size_t i = 0; i < rhs.size(); ++i) { + rhs[i] += lhs[i]; + } + }); + ASSERT_EQ(std::accumulate(data.cbegin(), data.cend(), 0.0), 0.0); + } + { + std::vector data(1, 1.0); + Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) { + for (std::size_t i = 0; i < rhs.size(); ++i) { + rhs[i] += lhs[i]; + } + }); + ASSERT_EQ(data[0], static_cast(comm_.World())); + } + } + + void Acc() { + std::vector data(314, 1.5); + Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) { + for (std::size_t i = 0; i < rhs.size(); ++i) { + rhs[i] += lhs[i]; + } + }); + for (std::size_t i = 0; i < data.size(); ++i) { + auto v = data[i]; + ASSERT_EQ(v, 1.5 * static_cast(comm_.World())) << i; + } + } +}; + +class AllreduceTest : public SocketTest {}; +} // namespace + +TEST_F(AllreduceTest, Basic) { + std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency()); + TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout, + std::int32_t r) { + AllreduceWorker worker{host, port, timeout, n_workers, r}; + worker.Basic(); + }); +} + +TEST_F(AllreduceTest, Sum) { + std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency()); + TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout, + std::int32_t r) { + AllreduceWorker worker{host, port, timeout, n_workers, r}; + worker.Acc(); + }); +} +} // namespace xgboost::collective diff --git a/tests/cpp/collective/test_broadcast.cc b/tests/cpp/collective/test_broadcast.cc index 485f6dcdf..0ade86567 100644 --- a/tests/cpp/collective/test_broadcast.cc +++ b/tests/cpp/collective/test_broadcast.cc @@ -10,8 +10,8 @@ #include // for vector #include "../../../src/collective/broadcast.h" // for Broadcast -#include "../../../src/collective/tracker.h" // for GetHostAddress, Tracker -#include "test_worker.h" // for WorkerForTest +#include "../../../src/collective/tracker.h" // for GetHostAddress +#include "test_worker.h" // for WorkerForTest, TestDistributed namespace xgboost::collective { namespace { @@ -41,28 +41,11 @@ class BroadcastTest : public SocketTest {}; } // namespace TEST_F(BroadcastTest, Basic) { - std::int32_t n_workers = std::min(24u, std::thread::hardware_concurrency()); - std::chrono::seconds timeout{3}; - - std::string host; - ASSERT_TRUE(GetHostAddress(&host).OK()); - RabitTracker tracker{StringView{host}, n_workers, 0, timeout}; - auto fut = tracker.Run(); - - std::vector workers; - std::int32_t port = tracker.Port(); - - for (std::int32_t i = 0; i < n_workers; ++i) { - workers.emplace_back([=] { - Worker worker{host, port, timeout, n_workers, i}; - worker.Run(); - }); - } - - for (auto& t : workers) { - t.join(); - } - - ASSERT_TRUE(fut.get().OK()); + std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency()); + TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout, + std::int32_t r) { + Worker worker{host, port, timeout, n_workers, r}; + worker.Run(); + }); } } // namespace xgboost::collective diff --git a/tests/cpp/collective/test_worker.h b/tests/cpp/collective/test_worker.h index 3c9d02f03..a3d6de875 100644 --- a/tests/cpp/collective/test_worker.h +++ b/tests/cpp/collective/test_worker.h @@ -88,4 +88,27 @@ class TrackerTest : public SocketTest { ASSERT_TRUE(rc.OK()) << rc.Report(); } }; + +template +void TestDistributed(std::int32_t n_workers, WorkerFn worker_fn) { + std::chrono::seconds timeout{1}; + + std::string host; + ASSERT_TRUE(GetHostAddress(&host).OK()); + RabitTracker tracker{StringView{host}, n_workers, 0, timeout}; + auto fut = tracker.Run(); + + std::vector workers; + std::int32_t port = tracker.Port(); + + for (std::int32_t i = 0; i < n_workers; ++i) { + workers.emplace_back([=] { worker_fn(host, port, timeout, i); }); + } + + for (auto& t : workers) { + t.join(); + } + + ASSERT_TRUE(fut.get().OK()); +} } // namespace xgboost::collective