[coll] Broadcast. (#9659)
This commit is contained in:
parent
81a059864a
commit
53049b16b8
@ -99,6 +99,7 @@ OBJECTS= \
|
|||||||
$(PKGROOT)/src/logging.o \
|
$(PKGROOT)/src/logging.o \
|
||||||
$(PKGROOT)/src/global_config.o \
|
$(PKGROOT)/src/global_config.o \
|
||||||
$(PKGROOT)/src/collective/allgather.o \
|
$(PKGROOT)/src/collective/allgather.o \
|
||||||
|
$(PKGROOT)/src/collective/broadcast.o \
|
||||||
$(PKGROOT)/src/collective/comm.o \
|
$(PKGROOT)/src/collective/comm.o \
|
||||||
$(PKGROOT)/src/collective/tracker.o \
|
$(PKGROOT)/src/collective/tracker.o \
|
||||||
$(PKGROOT)/src/collective/communicator.o \
|
$(PKGROOT)/src/collective/communicator.o \
|
||||||
|
|||||||
@ -99,6 +99,7 @@ OBJECTS= \
|
|||||||
$(PKGROOT)/src/logging.o \
|
$(PKGROOT)/src/logging.o \
|
||||||
$(PKGROOT)/src/global_config.o \
|
$(PKGROOT)/src/global_config.o \
|
||||||
$(PKGROOT)/src/collective/allgather.o \
|
$(PKGROOT)/src/collective/allgather.o \
|
||||||
|
$(PKGROOT)/src/collective/broadcast.o \
|
||||||
$(PKGROOT)/src/collective/comm.o \
|
$(PKGROOT)/src/collective/comm.o \
|
||||||
$(PKGROOT)/src/collective/tracker.o \
|
$(PKGROOT)/src/collective/tracker.o \
|
||||||
$(PKGROOT)/src/collective/communicator.o \
|
$(PKGROOT)/src/collective/communicator.o \
|
||||||
|
|||||||
83
src/collective/broadcast.cc
Normal file
83
src/collective/broadcast.cc
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#include "broadcast.h"
|
||||||
|
|
||||||
|
#include <cstdint> // for int32_t, int8_t
|
||||||
|
#include <utility> // for move
|
||||||
|
|
||||||
|
#include "../common/bitfield.h" // for TrailingZeroBits, RBitField32
|
||||||
|
#include "comm.h" // for Comm
|
||||||
|
#include "xgboost/collective/result.h" // for Result
|
||||||
|
#include "xgboost/span.h" // for Span
|
||||||
|
|
||||||
|
namespace xgboost::collective::cpu_impl {
|
||||||
|
namespace {
|
||||||
|
std::int32_t ShiftedParentRank(std::int32_t shifted_rank, std::int32_t depth) {
|
||||||
|
std::uint32_t mask{std::uint32_t{0} - 1}; // Oxff...
|
||||||
|
RBitField32 maskbits{common::Span<std::uint32_t>{&mask, 1}};
|
||||||
|
RBitField32 rankbits{
|
||||||
|
common::Span<std::uint32_t>{reinterpret_cast<std::uint32_t*>(&shifted_rank), 1}};
|
||||||
|
// prepare for counting trailing zeros.
|
||||||
|
for (std::int32_t i = 0; i < depth + 1; ++i) {
|
||||||
|
if (rankbits.Check(i)) {
|
||||||
|
maskbits.Set(i);
|
||||||
|
} else {
|
||||||
|
maskbits.Clear(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
CHECK_NE(mask, 0);
|
||||||
|
auto k = TrailingZeroBits(mask);
|
||||||
|
auto shifted_parent = shifted_rank - (1 << k);
|
||||||
|
return shifted_parent;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rotate the ranks so that the root node ends up at rank 0: the result is
// (rank - root) taken modulo world, with world added first to keep the dividend
// non-negative for ranks smaller than root.
std::int32_t ShiftLeft(std::int32_t rank, std::int32_t world, std::int32_t root) {
  return (rank + world - root) % world;
}
|
||||||
|
// Undo the rotation applied by ShiftLeft: map a shifted rank back to the original
// rank by adding the root and wrapping around at world.
std::int32_t ShiftRight(std::int32_t rank, std::int32_t world, std::int32_t root) {
  return (rank + root) % world;
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
/**
 * @brief Broadcast a byte buffer from `root` to every other worker of `comm`.
 *
 * Ranks are first rotated so that `root` becomes rank 0. A non-root worker receives the
 * buffer from its parent and waits on that channel before forwarding it; every worker
 * then sends the buffer to its children, and the function finishes by blocking on the
 * whole communicator.
 *
 * @param comm Communicator providing the rank, the world size and the peer channels.
 * @param data Buffer that is sent by the root and overwritten on all other workers.
 * @param root Rank of the worker that owns the data.
 *             NOTE(review): presumably expected to be in [0, world); this is not
 *             validated here - confirm against callers.
 *
 * @return Failure wrapping the channel error if receiving from the parent failed,
 *         otherwise the result of `comm.Block()`.
 */
Result Broadcast(Comm const& comm, common::Span<std::int8_t> data, std::int32_t root) {
  // Binomial tree broadcast
  // * Wiki
  // https://en.wikipedia.org/wiki/Broadcast_(parallel_pattern)#Binomial_Tree_Broadcast
  // * Impl
  // https://people.mpi-inf.mpg.de/~mehlhorn/ftp/NewToolbox/collective.pdf

  auto rank = comm.Rank();
  auto world = comm.World();

  // shift root to rank 0
  auto shifted_rank = ShiftLeft(rank, world, root);

  // depth == ceil(log2(world)) - 1, which is -1 for a single worker. It is counted with
  // integers: this keeps the result exact, stays well defined for a non-positive world
  // size, and does not rely on <cmath> being pulled in by another header. The counter
  // is 64-bit so that doubling it cannot overflow before it reaches `world`.
  std::int32_t depth = -1;
  for (std::int64_t reach = 1; reach < static_cast<std::int64_t>(world); reach *= 2) {
    ++depth;
  }

  if (shifted_rank != 0) {  // not root
    auto parent = ShiftRight(ShiftedParentRank(shifted_rank, depth), world, root);
    comm.Chan(parent)->RecvAll(data);
    // The data must have arrived before it can be forwarded to the children below.
    auto rc = comm.Chan(parent)->Block();
    if (!rc.OK()) {
      return Fail("broadcast failed.", std::move(rc));
    }
  }

  // Walk the tree levels from the top; at level i this worker owns a child at distance
  // 2^i if its shifted rank is a multiple of 2^(i + 1) and that child exists.
  for (std::int32_t i = depth; i >= 0; --i) {
    CHECK_GE((i + 1), 0);  // weird clang-tidy error that i might be negative
    if (shifted_rank % (1 << (i + 1)) == 0 && shifted_rank + (1 << i) < world) {
      auto sft_peer = shifted_rank + (1 << i);
      auto peer = ShiftRight(sft_peer, world, root);
      CHECK_NE(peer, root);
      comm.Chan(peer)->SendAll(data);
    }
  }

  return comm.Block();
}
|
||||||
|
} // namespace xgboost::collective::cpu_impl
|
||||||
26
src/collective/broadcast.h
Normal file
26
src/collective/broadcast.h
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
#include <cstdint> // for int32_t, int8_t
|
||||||
|
|
||||||
|
#include "comm.h" // for Comm
|
||||||
|
#include "xgboost/collective/result.h"  // for Result
|
||||||
|
#include "xgboost/span.h" // for Span
|
||||||
|
|
||||||
|
namespace xgboost::collective {
|
||||||
|
namespace cpu_impl {
|
||||||
|
Result Broadcast(Comm const& comm, common::Span<std::int8_t> data, std::int32_t root);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief binomial tree broadcast is used on CPU with the default implementation.
|
||||||
|
*/
|
||||||
|
template <typename T>
|
||||||
|
[[nodiscard]] Result Broadcast(Comm const& comm, common::Span<T> data, std::int32_t root) {
|
||||||
|
auto n_total_bytes = data.size_bytes();
|
||||||
|
auto erased =
|
||||||
|
common::Span<std::int8_t>{reinterpret_cast<std::int8_t*>(data.data()), n_total_bytes};
|
||||||
|
return cpu_impl::Broadcast(comm, erased, root);
|
||||||
|
}
|
||||||
|
} // namespace xgboost::collective
|
||||||
@ -5,22 +5,21 @@
|
|||||||
#ifndef XGBOOST_COMMON_BITFIELD_H_
|
#ifndef XGBOOST_COMMON_BITFIELD_H_
|
||||||
#define XGBOOST_COMMON_BITFIELD_H_
|
#define XGBOOST_COMMON_BITFIELD_H_
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm> // for min
|
||||||
#include <bitset>
|
#include <bitset> // for bitset
|
||||||
#include <cinttypes>
|
#include <cstdint> // for uint32_t, uint64_t, uint8_t
|
||||||
#include <iostream>
|
#include <ostream> // for ostream
|
||||||
#include <sstream>
|
#include <type_traits> // for conditional_t, is_signed_v
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#if defined(__CUDACC__)
|
#if defined(__CUDACC__)
|
||||||
#include <thrust/copy.h>
|
#include <thrust/copy.h>
|
||||||
#include <thrust/device_ptr.h>
|
#include <thrust/device_ptr.h>
|
||||||
|
|
||||||
#include "device_helpers.cuh"
|
#include "device_helpers.cuh"
|
||||||
#endif // defined(__CUDACC__)
|
#endif // defined(__CUDACC__)
|
||||||
|
|
||||||
#include "xgboost/span.h"
|
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
|
#include "xgboost/span.h" // for Span
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
|
|
||||||
@ -75,7 +74,7 @@ struct BitFieldContainer {
|
|||||||
private:
|
private:
|
||||||
value_type* bits_{nullptr};
|
value_type* bits_{nullptr};
|
||||||
size_type n_values_{0};
|
size_type n_values_{0};
|
||||||
static_assert(!std::is_signed<VT>::value, "Must use an unsiged type as the underlying storage.");
|
static_assert(!std::is_signed_v<VT>, "Must use an unsiged type as the underlying storage.");
|
||||||
|
|
||||||
public:
|
public:
|
||||||
XGBOOST_DEVICE static Pos ToBitPos(index_type pos) {
|
XGBOOST_DEVICE static Pos ToBitPos(index_type pos) {
|
||||||
@ -240,11 +239,39 @@ struct RBitsPolicy : public BitFieldContainer<VT, RBitsPolicy<VT>> {
|
|||||||
|
|
||||||
// Format: <Const><Direction>BitField<size of underlying type in bits>, underlying type
|
// Format: <Const><Direction>BitField<size of underlying type in bits>, underlying type
|
||||||
// must be unsigned.
|
// must be unsigned.
|
||||||
using LBitField64 = BitFieldContainer<uint64_t, LBitsPolicy<uint64_t>>;
|
using LBitField64 = BitFieldContainer<std::uint64_t, LBitsPolicy<std::uint64_t>>;
|
||||||
using RBitField8 = BitFieldContainer<uint8_t, RBitsPolicy<unsigned char>>;
|
using RBitField8 = BitFieldContainer<std::uint8_t, RBitsPolicy<unsigned char>>;
|
||||||
|
|
||||||
using LBitField32 = BitFieldContainer<uint32_t, LBitsPolicy<uint32_t>>;
|
using LBitField32 = BitFieldContainer<std::uint32_t, LBitsPolicy<std::uint32_t>>;
|
||||||
using CLBitField32 = BitFieldContainer<uint32_t, LBitsPolicy<uint32_t, true>, true>;
|
using CLBitField32 = BitFieldContainer<std::uint32_t, LBitsPolicy<std::uint32_t, true>, true>;
|
||||||
|
using RBitField32 = BitFieldContainer<std::uint32_t, RBitsPolicy<std::uint32_t>>;
|
||||||
|
|
||||||
|
namespace detail {
/**
 * @brief Portable count of the consecutive zero bits starting at the least significant
 *        bit of `value`.
 *
 * @return Number of trailing zero bits; the full bit width (32) when `value` is 0.
 */
inline std::uint32_t TrailingZeroBitsImpl(std::uint32_t value) {
  constexpr std::uint32_t kBits = sizeof(value) * 8;
  std::uint32_t n_zeros{0};
  // Stop at the first set bit, or once every bit has been inspected.
  while (n_zeros < kBits && ((value >> n_zeros) & 1u) == 0) {
    ++n_zeros;
  }
  return n_zeros;
}
}  // namespace detail
|
||||||
|
|
||||||
|
/**
 * @brief Count the trailing zero bits of a 32-bit value.
 *
 * Uses the compiler intrinsic where one is selected by the preprocessor and falls back
 * to the portable loop otherwise.
 *
 * @return Number of trailing zero bits; the full bit width (32) when `value` is 0.
 */
inline std::uint32_t TrailingZeroBits(std::uint32_t value) {
  // Zero is answered separately (at the end) instead of being passed to an intrinsic;
  // the result of __builtin_ctz is undefined for 0.
  if (value != 0) {
#if defined(__GNUC__)
    return __builtin_ctz(value);
#elif defined(_MSC_VER)
    return _tzcnt_u32(value);
#else
    return detail::TrailingZeroBitsImpl(value);
#endif  // __GNUC__
  }
  return sizeof(value) * 8;
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|
||||||
#endif // XGBOOST_COMMON_BITFIELD_H_
|
#endif // XGBOOST_COMMON_BITFIELD_H_
|
||||||
|
|||||||
@ -6,7 +6,6 @@
|
|||||||
#ifndef XGBOOST_COMMON_COMMON_H_
|
#ifndef XGBOOST_COMMON_COMMON_H_
|
||||||
#define XGBOOST_COMMON_COMMON_H_
|
#define XGBOOST_COMMON_COMMON_H_
|
||||||
|
|
||||||
#include <algorithm> // for max
|
|
||||||
#include <array> // for array
|
#include <array> // for array
|
||||||
#include <cmath> // for ceil
|
#include <cmath> // for ceil
|
||||||
#include <cstddef> // for size_t
|
#include <cstddef> // for size_t
|
||||||
@ -181,7 +180,7 @@ inline void SetDevice(std::int32_t device) {
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Last index of a group in a CSR style of index pointer.
|
* @brief Last index of a group in a CSR style of index pointer.
|
||||||
*/
|
*/
|
||||||
template <typename Indexable>
|
template <typename Indexable>
|
||||||
XGBOOST_DEVICE size_t LastOf(size_t group, Indexable const &indptr) {
|
XGBOOST_DEVICE size_t LastOf(size_t group, Indexable const &indptr) {
|
||||||
|
|||||||
68
tests/cpp/collective/test_broadcast.cc
Normal file
68
tests/cpp/collective/test_broadcast.cc
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include <xgboost/collective/socket.h>
|
||||||
|
|
||||||
|
#include <cstdint> // for int32_t
|
||||||
|
#include <string> // for string
|
||||||
|
#include <thread> // for thread
|
||||||
|
#include <vector> // for vector
|
||||||
|
|
||||||
|
#include "../../../src/collective/broadcast.h" // for Broadcast
|
||||||
|
#include "../../../src/collective/tracker.h" // for GetHostAddress, Tracker
|
||||||
|
#include "test_worker.h" // for WorkerForTest
|
||||||
|
|
||||||
|
namespace xgboost::collective {
|
||||||
|
namespace {
|
||||||
|
class Worker : public WorkerForTest {
|
||||||
|
public:
|
||||||
|
using WorkerForTest::WorkerForTest;
|
||||||
|
|
||||||
|
void Run() {
|
||||||
|
for (std::int32_t r = 0; r < comm_.World(); ++r) {
|
||||||
|
// basic test
|
||||||
|
std::vector<std::int32_t> data(1, comm_.Rank());
|
||||||
|
auto rc = Broadcast(this->comm_, common::Span{data.data(), data.size()}, r);
|
||||||
|
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||||
|
ASSERT_EQ(data[0], r);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (std::int32_t r = 0; r < comm_.World(); ++r) {
|
||||||
|
std::vector<std::int32_t> data(1 << 16, comm_.Rank());
|
||||||
|
auto rc = Broadcast(this->comm_, common::Span{data.data(), data.size()}, r);
|
||||||
|
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||||
|
ASSERT_EQ(data[0], r);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Fixture for the broadcast tests; it adds nothing on top of SocketTest.
class BroadcastTest : public SocketTest {
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Start a tracker and a group of worker threads, let every worker run the broadcast
// checks, and verify that the tracker finishes without an error.
TEST_F(BroadcastTest, Basic) {
  // hardware_concurrency() is allowed to return 0 when the value cannot be determined;
  // always run at least one worker, and at most 24.
  std::int32_t n_workers = std::min(24u, std::max(1u, std::thread::hardware_concurrency()));
  std::chrono::seconds timeout{3};

  std::string host;
  ASSERT_TRUE(GetHostAddress(&host).OK());
  RabitTracker tracker{StringView{host}, n_workers, 0, timeout};
  auto fut = tracker.Run();

  std::vector<std::thread> workers;
  std::int32_t port = tracker.Port();

  for (std::int32_t i = 0; i < n_workers; ++i) {
    // Everything is captured by value: each thread keeps its own copy of the host name
    // and of its rank `i`.
    workers.emplace_back([=] {
      Worker worker{host, port, timeout, n_workers, i};
      worker.Run();
    });
  }

  for (auto& t : workers) {
    t.join();
  }

  ASSERT_TRUE(fut.get().OK());
}
|
||||||
|
} // namespace xgboost::collective
|
||||||
Loading…
x
Reference in New Issue
Block a user