diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in index 541c0fb52..808960319 100644 --- a/R-package/src/Makevars.in +++ b/R-package/src/Makevars.in @@ -99,6 +99,7 @@ OBJECTS= \ $(PKGROOT)/src/logging.o \ $(PKGROOT)/src/global_config.o \ $(PKGROOT)/src/collective/allgather.o \ + $(PKGROOT)/src/collective/broadcast.o \ $(PKGROOT)/src/collective/comm.o \ $(PKGROOT)/src/collective/tracker.o \ $(PKGROOT)/src/collective/communicator.o \ diff --git a/R-package/src/Makevars.win b/R-package/src/Makevars.win index faacd6d8d..43bfcf7c1 100644 --- a/R-package/src/Makevars.win +++ b/R-package/src/Makevars.win @@ -99,6 +99,7 @@ OBJECTS= \ $(PKGROOT)/src/logging.o \ $(PKGROOT)/src/global_config.o \ $(PKGROOT)/src/collective/allgather.o \ + $(PKGROOT)/src/collective/broadcast.o \ $(PKGROOT)/src/collective/comm.o \ $(PKGROOT)/src/collective/tracker.o \ $(PKGROOT)/src/collective/communicator.o \ diff --git a/src/collective/broadcast.cc b/src/collective/broadcast.cc new file mode 100644 index 000000000..be7e8f55f --- /dev/null +++ b/src/collective/broadcast.cc @@ -0,0 +1,83 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#include "broadcast.h" + +#include // for int32_t, int8_t +#include // for move + +#include "../common/bitfield.h" // for TrailingZeroBits, RBitField32 +#include "comm.h" // for Comm +#include "xgboost/collective/result.h" // for Result +#include "xgboost/span.h" // for Span + +namespace xgboost::collective::cpu_impl { +namespace { +std::int32_t ShiftedParentRank(std::int32_t shifted_rank, std::int32_t depth) { + std::uint32_t mask{std::uint32_t{0} - 1}; // Oxff... + RBitField32 maskbits{common::Span{&mask, 1}}; + RBitField32 rankbits{ + common::Span{reinterpret_cast(&shifted_rank), 1}}; + // prepare for counting trailing zeros. + for (std::int32_t i = 0; i < depth + 1; ++i) { + if (rankbits.Check(i)) { + maskbits.Set(i); + } else { + maskbits.Clear(i); + } + } + + CHECK_NE(mask, 0); + auto k = TrailingZeroBits(mask); + auto shifted_parent = shifted_rank - (1 << k); + return shifted_parent; +} + +// Shift the root node to rank 0 +std::int32_t ShiftLeft(std::int32_t rank, std::int32_t world, std::int32_t root) { + auto shifted_rank = (rank + world - root) % world; + return shifted_rank; +} +// shift back to the original rank +std::int32_t ShiftRight(std::int32_t rank, std::int32_t world, std::int32_t root) { + auto orig = (rank + root) % world; + return orig; +} +} // namespace + +Result Broadcast(Comm const& comm, common::Span data, std::int32_t root) { + // Binomial tree broadcast + // * Wiki + // https://en.wikipedia.org/wiki/Broadcast_(parallel_pattern)#Binomial_Tree_Broadcast + // * Impl + // https://people.mpi-inf.mpg.de/~mehlhorn/ftp/NewToolbox/collective.pdf + + auto rank = comm.Rank(); + auto world = comm.World(); + + // shift root to rank 0 + auto shifted_rank = ShiftLeft(rank, world, root); + std::int32_t depth = std::ceil(std::log2(static_cast(world))) - 1; + + if (shifted_rank != 0) { // not root + auto parent = ShiftRight(ShiftedParentRank(shifted_rank, depth), world, root); + comm.Chan(parent)->RecvAll(data); + auto rc = comm.Chan(parent)->Block(); + if (!rc.OK()) { + return Fail("broadcast failed.", std::move(rc)); + } + } + + for (std::int32_t i = depth; i >= 0; --i) { + CHECK_GE((i + 1), 0); // weird clang-tidy error that i might be negative + if (shifted_rank % (1 << (i + 1)) == 0 && shifted_rank + (1 << i) < world) { + auto sft_peer = shifted_rank + (1 << i); + auto peer = ShiftRight(sft_peer, world, root); + CHECK_NE(peer, root); + comm.Chan(peer)->SendAll(data); + } + } + + return comm.Block(); +} +} // namespace xgboost::collective::cpu_impl diff --git a/src/collective/broadcast.h b/src/collective/broadcast.h new file mode 100644 index 000000000..28db83815 --- /dev/null +++ b/src/collective/broadcast.h @@ -0,0 +1,26 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#pragma once +#include // for int32_t, int8_t + +#include "comm.h" // for Comm +#include "xgboost/collective/result.h" // for +#include "xgboost/span.h" // for Span + +namespace xgboost::collective { +namespace cpu_impl { +Result Broadcast(Comm const& comm, common::Span data, std::int32_t root); +} + +/** + * @brief binomial tree broadcast is used on CPU with the default implementation. + */ +template +[[nodiscard]] Result Broadcast(Comm const& comm, common::Span data, std::int32_t root) { + auto n_total_bytes = data.size_bytes(); + auto erased = + common::Span{reinterpret_cast(data.data()), n_total_bytes}; + return cpu_impl::Broadcast(comm, erased, root); +} +} // namespace xgboost::collective diff --git a/src/common/bitfield.h b/src/common/bitfield.h index 6cdf4412e..efabaa834 100644 --- a/src/common/bitfield.h +++ b/src/common/bitfield.h @@ -5,22 +5,21 @@ #ifndef XGBOOST_COMMON_BITFIELD_H_ #define XGBOOST_COMMON_BITFIELD_H_ -#include -#include -#include -#include -#include -#include -#include +#include // for min +#include // for bitset +#include // for uint32_t, uint64_t, uint8_t +#include // for ostream +#include // for conditional_t, is_signed_v #if defined(__CUDACC__) #include #include + #include "device_helpers.cuh" #endif // defined(__CUDACC__) -#include "xgboost/span.h" #include "common.h" +#include "xgboost/span.h" // for Span namespace xgboost { @@ -75,7 +74,7 @@ struct BitFieldContainer { private: value_type* bits_{nullptr}; size_type n_values_{0}; - static_assert(!std::is_signed::value, "Must use an unsiged type as the underlying storage."); + static_assert(!std::is_signed_v, "Must use an unsiged type as the underlying storage."); public: XGBOOST_DEVICE static Pos ToBitPos(index_type pos) { @@ -240,11 +239,39 @@ struct RBitsPolicy : public BitFieldContainer> { // Format: BitField, underlying type // must be unsigned. -using LBitField64 = BitFieldContainer>; -using RBitField8 = BitFieldContainer>; +using LBitField64 = BitFieldContainer>; +using RBitField8 = BitFieldContainer>; -using LBitField32 = BitFieldContainer>; -using CLBitField32 = BitFieldContainer, true>; +using LBitField32 = BitFieldContainer>; +using CLBitField32 = BitFieldContainer, true>; +using RBitField32 = BitFieldContainer>; + +namespace detail { +inline std::uint32_t TrailingZeroBitsImpl(std::uint32_t value) { + auto n = sizeof(value) * 8; + std::uint32_t cnt{0}; + for (decltype(n) i = 0; i < n; i++) { + if ((value >> i) & 1) { + break; + } + cnt++; + } + return cnt; +} +} // namespace detail + +inline std::uint32_t TrailingZeroBits(std::uint32_t value) { + if (value == 0) { + return sizeof(value) * 8; + } +#if defined(__GNUC__) + return __builtin_ctz(value); +#elif defined(_MSC_VER) + return _tzcnt_u32(value); +#else + return detail::TrailingZeroBitsImpl(value); +#endif // __GNUC__ +} } // namespace xgboost #endif // XGBOOST_COMMON_BITFIELD_H_ diff --git a/src/common/common.h b/src/common/common.h index bedff80b3..2abb34cb2 100644 --- a/src/common/common.h +++ b/src/common/common.h @@ -6,7 +6,6 @@ #ifndef XGBOOST_COMMON_COMMON_H_ #define XGBOOST_COMMON_COMMON_H_ -#include // for max #include // for array #include // for ceil #include // for size_t @@ -181,7 +180,7 @@ inline void SetDevice(std::int32_t device) { #endif /** - * Last index of a group in a CSR style of index pointer. + * @brief Last index of a group in a CSR style of index pointer. */ template XGBOOST_DEVICE size_t LastOf(size_t group, Indexable const &indptr) { diff --git a/tests/cpp/collective/test_broadcast.cc b/tests/cpp/collective/test_broadcast.cc new file mode 100644 index 000000000..485f6dcdf --- /dev/null +++ b/tests/cpp/collective/test_broadcast.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#include +#include + +#include // for int32_t +#include // for string +#include // for thread +#include // for vector + +#include "../../../src/collective/broadcast.h" // for Broadcast +#include "../../../src/collective/tracker.h" // for GetHostAddress, Tracker +#include "test_worker.h" // for WorkerForTest + +namespace xgboost::collective { +namespace { +class Worker : public WorkerForTest { + public: + using WorkerForTest::WorkerForTest; + + void Run() { + for (std::int32_t r = 0; r < comm_.World(); ++r) { + // basic test + std::vector data(1, comm_.Rank()); + auto rc = Broadcast(this->comm_, common::Span{data.data(), data.size()}, r); + ASSERT_TRUE(rc.OK()) << rc.Report(); + ASSERT_EQ(data[0], r); + } + + for (std::int32_t r = 0; r < comm_.World(); ++r) { + std::vector data(1 << 16, comm_.Rank()); + auto rc = Broadcast(this->comm_, common::Span{data.data(), data.size()}, r); + ASSERT_TRUE(rc.OK()) << rc.Report(); + ASSERT_EQ(data[0], r); + } + } +}; + +class BroadcastTest : public SocketTest {}; +} // namespace + +TEST_F(BroadcastTest, Basic) { + std::int32_t n_workers = std::min(24u, std::thread::hardware_concurrency()); + std::chrono::seconds timeout{3}; + + std::string host; + ASSERT_TRUE(GetHostAddress(&host).OK()); + RabitTracker tracker{StringView{host}, n_workers, 0, timeout}; + auto fut = tracker.Run(); + + std::vector workers; + std::int32_t port = tracker.Port(); + + for (std::int32_t i = 0; i < n_workers; ++i) { + workers.emplace_back([=] { + Worker worker{host, port, timeout, n_workers, i}; + worker.Run(); + }); + } + + for (auto& t : workers) { + t.join(); + } + + ASSERT_TRUE(fut.get().OK()); +} +} // namespace xgboost::collective