[coll] Broadcast. (#9659)

This commit is contained in:
Jiaming Yuan 2023-10-14 09:34:37 +08:00 committed by GitHub
parent 81a059864a
commit 53049b16b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 220 additions and 15 deletions

View File

@ -99,6 +99,7 @@ OBJECTS= \
$(PKGROOT)/src/logging.o \ $(PKGROOT)/src/logging.o \
$(PKGROOT)/src/global_config.o \ $(PKGROOT)/src/global_config.o \
$(PKGROOT)/src/collective/allgather.o \ $(PKGROOT)/src/collective/allgather.o \
$(PKGROOT)/src/collective/broadcast.o \
$(PKGROOT)/src/collective/comm.o \ $(PKGROOT)/src/collective/comm.o \
$(PKGROOT)/src/collective/tracker.o \ $(PKGROOT)/src/collective/tracker.o \
$(PKGROOT)/src/collective/communicator.o \ $(PKGROOT)/src/collective/communicator.o \

View File

@ -99,6 +99,7 @@ OBJECTS= \
$(PKGROOT)/src/logging.o \ $(PKGROOT)/src/logging.o \
$(PKGROOT)/src/global_config.o \ $(PKGROOT)/src/global_config.o \
$(PKGROOT)/src/collective/allgather.o \ $(PKGROOT)/src/collective/allgather.o \
$(PKGROOT)/src/collective/broadcast.o \
$(PKGROOT)/src/collective/comm.o \ $(PKGROOT)/src/collective/comm.o \
$(PKGROOT)/src/collective/tracker.o \ $(PKGROOT)/src/collective/tracker.o \
$(PKGROOT)/src/collective/communicator.o \ $(PKGROOT)/src/collective/communicator.o \

View File

@ -0,0 +1,83 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include "broadcast.h"
#include <cstdint> // for int32_t, int8_t
#include <utility> // for move
#include "../common/bitfield.h" // for TrailingZeroBits, RBitField32
#include "comm.h" // for Comm
#include "xgboost/collective/result.h" // for Result
#include "xgboost/span.h" // for Span
namespace xgboost::collective::cpu_impl {
namespace {
std::int32_t ShiftedParentRank(std::int32_t shifted_rank, std::int32_t depth) {
std::uint32_t mask{std::uint32_t{0} - 1}; // Oxff...
RBitField32 maskbits{common::Span<std::uint32_t>{&mask, 1}};
RBitField32 rankbits{
common::Span<std::uint32_t>{reinterpret_cast<std::uint32_t*>(&shifted_rank), 1}};
// prepare for counting trailing zeros.
for (std::int32_t i = 0; i < depth + 1; ++i) {
if (rankbits.Check(i)) {
maskbits.Set(i);
} else {
maskbits.Clear(i);
}
}
CHECK_NE(mask, 0);
auto k = TrailingZeroBits(mask);
auto shifted_parent = shifted_rank - (1 << k);
return shifted_parent;
}
// Re-index `rank` so that `root` maps to rank 0; the result wraps around modulo `world`.
std::int32_t ShiftLeft(std::int32_t rank, std::int32_t world, std::int32_t root) {
  // `world` is added before the subtraction so ranks smaller than the root wrap around
  // instead of producing a negative dividend.
  std::int32_t const dividend = rank + world - root;
  return dividend % world;
}
// Map a rank from the shifted space (root at 0) back to the original numbering.
std::int32_t ShiftRight(std::int32_t rank, std::int32_t world, std::int32_t root) {
  std::int32_t const unshifted = rank + root;
  return unshifted % world;
}
} // namespace
/**
 * @brief Broadcast the bytes in `data` from `root` to the other workers using a
 *        binomial tree.
 *
 * Ranks are first shifted so that the root becomes rank 0. A non-root worker receives
 * the payload from its parent and waits for that channel before forwarding; every
 * worker then sends the payload to its children, walking the tree levels from `depth`
 * down to 0.
 *
 * References:
 *   * Wiki
 *     https://en.wikipedia.org/wiki/Broadcast_(parallel_pattern)#Binomial_Tree_Broadcast
 *   * Impl
 *     https://people.mpi-inf.mpg.de/~mehlhorn/ftp/NewToolbox/collective.pdf
 *
 * @param comm The communicator providing the rank, the world size and the channels.
 * @param data Buffer that is sent by the root and overwritten on the other workers.
 * @param root Rank of the worker that owns the payload.
 *
 * @return A failure wrapping the channel error if receiving from the parent fails,
 *         otherwise the result of `comm.Block()`.
 */
Result Broadcast(Comm const& comm, common::Span<std::int8_t> data, std::int32_t root) {
  auto const self = comm.Rank();
  auto const n_workers = comm.World();
  // Work in the shifted rank space where the root is rank 0.
  auto const shifted_self = ShiftLeft(self, n_workers, root);
  std::int32_t const depth = std::ceil(std::log2(static_cast<double>(n_workers))) - 1;

  bool const is_root = shifted_self == 0;
  if (!is_root) {
    // Obtain the payload from the parent before forwarding it any further.
    auto const parent = ShiftRight(ShiftedParentRank(shifted_self, depth), n_workers, root);
    comm.Chan(parent)->RecvAll(data);
    auto rc = comm.Chan(parent)->Block();
    if (!rc.OK()) {
      return Fail("broadcast failed.", std::move(rc));
    }
  }

  for (std::int32_t level = depth; level >= 0; --level) {
    CHECK_GE((level + 1), 0);  // weird clang-tidy error that the level might be negative
    auto const stride = 1 << level;
    // Only nodes aligned to twice the stride send at this level, and only if the peer
    // actually exists.
    bool const is_sender = shifted_self % (1 << (level + 1)) == 0;
    if (is_sender && shifted_self + stride < n_workers) {
      auto const shifted_peer = shifted_self + stride;
      auto const peer = ShiftRight(shifted_peer, n_workers, root);
      CHECK_NE(peer, root);
      comm.Chan(peer)->SendAll(data);
    }
  }

  return comm.Block();
}
} // namespace xgboost::collective::cpu_impl

View File

@ -0,0 +1,26 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <cstdint> // for int32_t, int8_t
#include "comm.h" // for Comm
#include "xgboost/collective/result.h" // for Result
#include "xgboost/span.h" // for Span
namespace xgboost::collective {
namespace cpu_impl {
/**
 * @brief Byte-level broadcast on CPU.
 *
 * @param comm The communicator used for the broadcast.
 * @param data The bytes to broadcast, viewed as a span of int8_t.
 * @param root Rank of the worker the data is broadcast from.
 *
 * @return A Result describing whether the broadcast succeeded.
 */
Result Broadcast(Comm const& comm, common::Span<std::int8_t> data, std::int32_t root);
}  // namespace cpu_impl
/**
 * @brief Broadcast a typed span from `root`. Binomial tree broadcast is used on CPU
 *        with the default implementation.
 *
 * The span is reinterpreted as raw bytes and handed to the CPU implementation.
 *
 * @tparam T Element type of the span.
 *
 * @param comm The communicator used for the broadcast.
 * @param data The values to broadcast.
 * @param root Rank of the worker the data is broadcast from.
 *
 * @return The result reported by the CPU implementation.
 */
template <typename T>
[[nodiscard]] Result Broadcast(Comm const& comm, common::Span<T> data, std::int32_t root) {
  // Erase the element type: the implementation only deals with bytes.
  auto* bytes = reinterpret_cast<std::int8_t*>(data.data());
  return cpu_impl::Broadcast(comm, common::Span<std::int8_t>{bytes, data.size_bytes()}, root);
}
} // namespace xgboost::collective

View File

@ -5,22 +5,21 @@
#ifndef XGBOOST_COMMON_BITFIELD_H_ #ifndef XGBOOST_COMMON_BITFIELD_H_
#define XGBOOST_COMMON_BITFIELD_H_ #define XGBOOST_COMMON_BITFIELD_H_
#include <algorithm> #include <algorithm> // for min
#include <bitset> #include <bitset> // for bitset
#include <cinttypes> #include <cstdint> // for uint32_t, uint64_t, uint8_t
#include <iostream> #include <ostream> // for ostream
#include <sstream> #include <type_traits> // for conditional_t, is_signed_v
#include <string>
#include <vector>
#if defined(__CUDACC__) #if defined(__CUDACC__)
#include <thrust/copy.h> #include <thrust/copy.h>
#include <thrust/device_ptr.h> #include <thrust/device_ptr.h>
#include "device_helpers.cuh" #include "device_helpers.cuh"
#endif // defined(__CUDACC__) #endif // defined(__CUDACC__)
#include "xgboost/span.h"
#include "common.h" #include "common.h"
#include "xgboost/span.h" // for Span
namespace xgboost { namespace xgboost {
@ -75,7 +74,7 @@ struct BitFieldContainer {
private: private:
value_type* bits_{nullptr}; value_type* bits_{nullptr};
size_type n_values_{0}; size_type n_values_{0};
static_assert(!std::is_signed<VT>::value, "Must use an unsiged type as the underlying storage."); static_assert(!std::is_signed_v<VT>, "Must use an unsiged type as the underlying storage.");
public: public:
XGBOOST_DEVICE static Pos ToBitPos(index_type pos) { XGBOOST_DEVICE static Pos ToBitPos(index_type pos) {
@ -240,11 +239,39 @@ struct RBitsPolicy : public BitFieldContainer<VT, RBitsPolicy<VT>> {
// Format: <Const><Direction>BitField<size of underlying type in bits>, underlying type // Format: <Const><Direction>BitField<size of underlying type in bits>, underlying type
// must be unsigned. // must be unsigned.
using LBitField64 = BitFieldContainer<uint64_t, LBitsPolicy<uint64_t>>; using LBitField64 = BitFieldContainer<std::uint64_t, LBitsPolicy<std::uint64_t>>;
using RBitField8 = BitFieldContainer<uint8_t, RBitsPolicy<unsigned char>>; using RBitField8 = BitFieldContainer<std::uint8_t, RBitsPolicy<unsigned char>>;
using LBitField32 = BitFieldContainer<uint32_t, LBitsPolicy<uint32_t>>; using LBitField32 = BitFieldContainer<std::uint32_t, LBitsPolicy<std::uint32_t>>;
using CLBitField32 = BitFieldContainer<uint32_t, LBitsPolicy<uint32_t, true>, true>; using CLBitField32 = BitFieldContainer<std::uint32_t, LBitsPolicy<std::uint32_t, true>, true>;
using RBitField32 = BitFieldContainer<std::uint32_t, RBitsPolicy<std::uint32_t>>;
namespace detail {
/**
 * @brief Portable fallback that counts the zero bits below the lowest set bit.
 *
 * @param value The value to inspect.
 *
 * @return The number of trailing zero bits, which is the full bit width (32) when
 *         `value` is 0.
 */
inline std::uint32_t TrailingZeroBitsImpl(std::uint32_t value) {
  constexpr std::uint32_t kBits = sizeof(value) * 8;
  std::uint32_t n_zeros{0};
  // Walk up from the least significant bit until a set bit or the bit width is reached.
  while (n_zeros < kBits && ((value >> n_zeros) & 1u) == 0u) {
    ++n_zeros;
  }
  return n_zeros;
}
} // namespace detail
/**
 * @brief Count the zero bits below the lowest set bit of a 32-bit value.
 *
 * A compiler intrinsic is used with GCC-compatible compilers and with MSVC; other
 * compilers use the portable loop in the detail namespace.
 *
 * @param value The value to inspect.
 *
 * @return The number of trailing zero bits; the full bit width (32) when `value` is 0.
 */
inline std::uint32_t TrailingZeroBits(std::uint32_t value) {
  // Zero is handled separately, so the intrinsics below only ever see a non-zero input
  // (the result of __builtin_ctz is undefined for 0).
  if (value != 0) {
#if defined(__GNUC__)
    return __builtin_ctz(value);
#elif defined(_MSC_VER)
    return _tzcnt_u32(value);
#else
    return detail::TrailingZeroBitsImpl(value);
#endif  // __GNUC__
  }
  return sizeof(value) * 8;
}
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_COMMON_BITFIELD_H_ #endif // XGBOOST_COMMON_BITFIELD_H_

View File

@ -6,7 +6,6 @@
#ifndef XGBOOST_COMMON_COMMON_H_ #ifndef XGBOOST_COMMON_COMMON_H_
#define XGBOOST_COMMON_COMMON_H_ #define XGBOOST_COMMON_COMMON_H_
#include <algorithm> // for max
#include <array> // for array #include <array> // for array
#include <cmath> // for ceil #include <cmath> // for ceil
#include <cstddef> // for size_t #include <cstddef> // for size_t
@ -181,7 +180,7 @@ inline void SetDevice(std::int32_t device) {
#endif #endif
/** /**
* Last index of a group in a CSR style of index pointer. * @brief Last index of a group in a CSR style of index pointer.
*/ */
template <typename Indexable> template <typename Indexable>
XGBOOST_DEVICE size_t LastOf(size_t group, Indexable const &indptr) { XGBOOST_DEVICE size_t LastOf(size_t group, Indexable const &indptr) {

View File

@ -0,0 +1,68 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/collective/socket.h>
#include <cstdint> // for int32_t
#include <string> // for string
#include <thread> // for thread
#include <vector> // for vector
#include "../../../src/collective/broadcast.h" // for Broadcast
#include "../../../src/collective/tracker.h" // for GetHostAddress, Tracker
#include "test_worker.h" // for WorkerForTest
namespace xgboost::collective {
namespace {
/**
 * @brief Test worker that broadcasts from every rank in turn and verifies the payload
 *        it ends up with.
 */
class Worker : public WorkerForTest {
 public:
  using WorkerForTest::WorkerForTest;

  /**
   * @brief Run the broadcast checks, using each rank as the root once per payload size.
   *
   * Every worker fills its buffer with its own rank, so after a broadcast rooted at `r`
   * all elements must equal `r`.
   */
  void Run() {
    for (std::int32_t r = 0; r < comm_.World(); ++r) {
      // basic test with a single element
      std::vector<std::int32_t> data(1, comm_.Rank());
      auto rc = Broadcast(this->comm_, common::Span{data.data(), data.size()}, r);
      ASSERT_TRUE(rc.OK()) << rc.Report();
      ASSERT_EQ(data[0], r);
    }

    for (std::int32_t r = 0; r < comm_.World(); ++r) {
      // larger payload
      std::vector<std::int32_t> data(1 << 16, comm_.Rank());
      auto rc = Broadcast(this->comm_, common::Span{data.data(), data.size()}, r);
      ASSERT_TRUE(rc.OK()) << rc.Report();
      // Verify the whole buffer rather than only the first element, otherwise a
      // truncated or partially written transfer would go unnoticed. The mismatches are
      // counted first so that only one assertion is evaluated per broadcast.
      std::int32_t n_mismatches = 0;
      for (auto v : data) {
        if (v != r) {
          ++n_mismatches;
        }
      }
      ASSERT_EQ(n_mismatches, 0);
    }
  }
};
// Fixture for the broadcast tests; it adds nothing on top of SocketTest.
class BroadcastTest : public SocketTest {};
} // namespace
/**
 * Start a tracker and one worker thread per worker, let every worker run the broadcast
 * checks, then make sure the tracker finished without an error.
 */
TEST_F(BroadcastTest, Basic) {
  // hardware_concurrency() returns 0 when the value is not computable. Fall back to a
  // single worker in that case instead of starting the tracker with no workers at all.
  auto const n_threads = std::thread::hardware_concurrency();
  std::int32_t n_workers = n_threads == 0 ? 1u : std::min(24u, n_threads);
  std::chrono::seconds timeout{3};

  std::string host;
  ASSERT_TRUE(GetHostAddress(&host).OK());
  RabitTracker tracker{StringView{host}, n_workers, 0, timeout};
  auto fut = tracker.Run();

  std::vector<std::thread> workers;
  std::int32_t port = tracker.Port();
  for (std::int32_t i = 0; i < n_workers; ++i) {
    // Capture by value: each thread gets its own copy of the host, port and rank.
    workers.emplace_back([=] {
      Worker worker{host, port, timeout, n_workers, i};
      worker.Run();
    });
  }
  for (auto& t : workers) {
    t.join();
  }

  ASSERT_TRUE(fut.get().OK());
}
} // namespace xgboost::collective