From 5d1bcde7196d34ef7ac030f4463e0b45d35a6f3d Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 19 Oct 2023 03:13:50 +0800 Subject: [PATCH] [coll] allgatherv. (#9688) --- src/collective/allgather.cc | 54 +++++++++++++++++++-- src/collective/allgather.h | 45 ++++++++++++++---- src/collective/broadcast.cc | 1 + src/collective/comm.cc | 2 + src/collective/comm.h | 22 ++++----- src/common/host_device_vector.cc | 1 + src/common/host_device_vector.cu | 1 + tests/cpp/collective/test_allgather.cc | 66 ++++++++++++++++++++++---- 8 files changed, 157 insertions(+), 35 deletions(-) diff --git a/src/collective/allgather.cc b/src/collective/allgather.cc index dba36c88c..378a06911 100644 --- a/src/collective/allgather.cc +++ b/src/collective/allgather.cc @@ -3,13 +3,16 @@ */ #include "allgather.h" -#include // for min +#include // for min, copy_n #include // for size_t -#include // for int8_t +#include // for int8_t, int32_t, int64_t #include // for shared_ptr +#include // for partial_sum +#include // for vector -#include "comm.h" // for Comm, Channel -#include "xgboost/span.h" // for Span +#include "comm.h" // for Comm, Channel +#include "xgboost/collective/result.h" // for Result +#include "xgboost/span.h" // for Span namespace xgboost::collective::cpu_impl { Result RingAllgather(Comm const& comm, common::Span data, std::size_t segment_size, @@ -39,4 +42,47 @@ Result RingAllgather(Comm const& comm, common::Span data, std::size return Success(); } + +[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span sizes, + common::Span data, + common::Span erased_result) { + auto world = comm.World(); + auto rank = comm.Rank(); + + auto prev = BootstrapPrev(rank, comm.World()); + auto next = BootstrapNext(rank, comm.World()); + + auto prev_ch = comm.Chan(prev); + auto next_ch = comm.Chan(next); + + // get worker offset + std::vector offset(world + 1, 0); + std::partial_sum(sizes.cbegin(), sizes.cend(), offset.begin() + 1); + CHECK_EQ(*offset.cbegin(), 0); + + // copy data + auto current = erased_result.subspan(offset[rank], data.size_bytes()); + auto erased_data = EraseType(data); + std::copy_n(erased_data.data(), erased_data.size(), current.data()); + + for (std::int32_t r = 0; r < world; ++r) { + auto send_rank = (rank + world - r) % world; + auto send_off = offset[send_rank]; + auto send_size = sizes[send_rank]; + auto send_seg = erased_result.subspan(send_off, send_size); + next_ch->SendAll(send_seg); + + auto recv_rank = (rank + world - r - 1) % world; + auto recv_off = offset[recv_rank]; + auto recv_size = sizes[recv_rank]; + auto recv_seg = erased_result.subspan(recv_off, recv_size); + prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes()); + + auto rc = prev_ch->Block(); + if (!rc.OK()) { + return rc; + } + } + return comm.Block(); +} } // namespace xgboost::collective::cpu_impl diff --git a/src/collective/allgather.h b/src/collective/allgather.h index 5dcb4ebdd..cb5f5b8af 100644 --- a/src/collective/allgather.h +++ b/src/collective/allgather.h @@ -2,12 +2,16 @@ * Copyright 2023, XGBoost Contributors */ #pragma once -#include // for size_t -#include // for int32_t -#include // for shared_ptr +#include // for size_t +#include // for int32_t +#include // for shared_ptr +#include // for accumulate +#include // for remove_cv_t +#include // for vector -#include "comm.h" // for Comm, Channel -#include "xgboost/span.h" // for Span +#include "comm.h" // for Comm, Channel, EraseType +#include "xgboost/collective/result.h" // for Result +#include "xgboost/span.h" // for Span namespace xgboost::collective { namespace cpu_impl { @@ -19,14 +23,16 @@ namespace cpu_impl { std::size_t segment_size, std::int32_t worker_off, std::shared_ptr prev_ch, std::shared_ptr next_ch); + +[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span sizes, + common::Span data, + common::Span erased_result); } // namespace cpu_impl template [[nodiscard]] Result RingAllgather(Comm const& comm, common::Span data, std::size_t size) { - auto n_total_bytes = data.size_bytes(); auto n_bytes = sizeof(T) * size; - auto erased = - common::Span{reinterpret_cast(data.data()), n_total_bytes}; + auto erased = EraseType(data); auto rank = comm.Rank(); auto prev = BootstrapPrev(rank, comm.World()); @@ -40,4 +46,27 @@ template } return comm.Block(); } + +template +[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span data, + std::vector>* p_out) { + auto world = comm.World(); + auto rank = comm.Rank(); + + std::vector sizes(world, 0); + sizes[rank] = data.size_bytes(); + auto rc = RingAllgather(comm, common::Span{sizes.data(), sizes.size()}, 1); + if (!rc.OK()) { + return rc; + } + + std::vector& result = *p_out; + auto n_total_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0); + result.resize(n_total_bytes / sizeof(T)); + auto h_result = common::Span{result.data(), result.size()}; + auto erased_result = EraseType(h_result); + auto erased_data = EraseType(data); + + return cpu_impl::RingAllgatherV(comm, sizes, erased_data, erased_result); +} } // namespace xgboost::collective diff --git a/src/collective/broadcast.cc b/src/collective/broadcast.cc index be7e8f55f..660bb9130 100644 --- a/src/collective/broadcast.cc +++ b/src/collective/broadcast.cc @@ -3,6 +3,7 @@ */ #include "broadcast.h" +#include // for ceil, log2 #include // for int32_t, int8_t #include // for move diff --git a/src/collective/comm.cc b/src/collective/comm.cc index 7e0af9c18..9ee1e0e6a 100644 --- a/src/collective/comm.cc +++ b/src/collective/comm.cc @@ -11,8 +11,10 @@ #include "allgather.h" #include "protocol.h" // for kMagic +#include "xgboost/base.h" // for XGBOOST_STRICT_R_MODE #include "xgboost/collective/socket.h" // for TCPSocket #include "xgboost/json.h" // for Json, Object +#include "xgboost/string_view.h" // for StringView namespace xgboost::collective { Comm::Comm(std::string const& host, std::int32_t port, std::chrono::seconds timeout, diff --git a/src/collective/comm.h b/src/collective/comm.h index f23810034..b501fcddd 100644 --- a/src/collective/comm.h +++ b/src/collective/comm.h @@ -2,20 +2,16 @@ * Copyright 2023, XGBoost Contributors */ #pragma once -#include // for seconds -#include // for condition_variable -#include // for size_t -#include // for int32_t -#include // for shared_ptr -#include // for mutex -#include // for queue -#include // for string -#include // for thread -#include // for remove_const_t -#include // for move -#include // for vector +#include // for seconds +#include // for size_t +#include // for int32_t +#include // for shared_ptr +#include // for string +#include // for thread +#include // for remove_const_t +#include // for move +#include // for vector -#include "../common/timer.h" #include "loop.h" // for Loop #include "protocol.h" // for PeerInfo #include "xgboost/collective/result.h" // for Result diff --git a/src/common/host_device_vector.cc b/src/common/host_device_vector.cc index 66d8024bd..a7a996c6c 100644 --- a/src/common/host_device_vector.cc +++ b/src/common/host_device_vector.cc @@ -175,6 +175,7 @@ template class HostDeviceVector; template class HostDeviceVector; template class HostDeviceVector; // bst_node_t template class HostDeviceVector; +template class HostDeviceVector; template class HostDeviceVector; template class HostDeviceVector; template class HostDeviceVector; // bst_row_t diff --git a/src/common/host_device_vector.cu b/src/common/host_device_vector.cu index 5f7b71043..4933a4b11 100644 --- a/src/common/host_device_vector.cu +++ b/src/common/host_device_vector.cu @@ -409,6 +409,7 @@ template class HostDeviceVector; template class HostDeviceVector; template class HostDeviceVector; // bst_node_t template class HostDeviceVector; +template class HostDeviceVector; template class HostDeviceVector; template class HostDeviceVector; template class HostDeviceVector; // bst_row_t diff --git a/tests/cpp/collective/test_allgather.cc b/tests/cpp/collective/test_allgather.cc index 49ba591d0..a74b9f149 100644 --- a/tests/cpp/collective/test_allgather.cc +++ b/tests/cpp/collective/test_allgather.cc @@ -1,18 +1,23 @@ /** * Copyright 2023, XGBoost Contributors */ -#include -#include // for Span +#include // for ASSERT_EQ +#include // for Span, oper... -#include // for int32_t -#include // for iota -#include // for string -#include // for thread -#include // for vector +#include // for min +#include // for seconds +#include // for size_t +#include // for int32_t +#include // for iota +#include // for string +#include // for thread +#include // for vector -#include "../../../src/collective/allgather.h" -#include "../../../src/collective/tracker.h" // for GetHostAddress, Tracker -#include "test_worker.h" // for TestDistributed== +#include "../../../src/collective/allgather.h" // for RingAllgather +#include "../../../src/collective/comm.h" // for RabitComm +#include "gtest/gtest.h" // for AssertionR... +#include "test_worker.h" // for TestDistri... +#include "xgboost/collective/result.h" // for Result namespace xgboost::collective { namespace { @@ -57,6 +62,38 @@ class Worker : public WorkerForTest { } } } + + void TestV() { + { + // basic test + std::int32_t n{comm_.Rank()}; + std::vector result; + auto rc = RingAllgatherV(comm_, common::Span{&n, 1}, &result); + ASSERT_TRUE(rc.OK()) << rc.Report(); + for (std::int32_t i = 0; i < comm_.World(); ++i) { + ASSERT_EQ(result[i], i); + } + } + + { + // V test + std::vector data(comm_.Rank() + 1, comm_.Rank()); + std::vector result; + auto rc = RingAllgatherV(comm_, common::Span{data.data(), data.size()}, &result); + ASSERT_TRUE(rc.OK()) << rc.Report(); + ASSERT_EQ(result.size(), (1 + comm_.World()) * comm_.World() / 2); + std::int32_t k{0}; + for (std::int32_t r = 0; r < comm_.World(); ++r) { + auto seg = common::Span{result.data(), result.size()}.subspan(k, (r + 1)); + if (comm_.Rank() == 0) { + for (auto v : seg) { + ASSERT_EQ(v, r); + } + k += seg.size(); + } + } + } + } }; } // namespace @@ -68,4 +105,13 @@ TEST_F(AllgatherTest, Basic) { worker.Run(); }); } + +TEST_F(AllgatherTest, V) { + std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency()); + TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout, + std::int32_t r) { + Worker worker{host, port, timeout, n_workers, r}; + worker.TestV(); + }); +} } // namespace xgboost::collective