[coll] Broadcast. (#9659)
This commit is contained in:
83
src/collective/broadcast.cc
Normal file
83
src/collective/broadcast.cc
Normal file
@@ -0,0 +1,83 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include "broadcast.h"

#include <cmath> // for ceil, log2
#include <cstdint> // for int32_t, int8_t
#include <utility> // for move

#include "../common/bitfield.h" // for TrailingZeroBits, RBitField32
#include "comm.h" // for Comm
#include "xgboost/collective/result.h" // for Result
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective::cpu_impl {
|
||||
namespace {
|
||||
std::int32_t ShiftedParentRank(std::int32_t shifted_rank, std::int32_t depth) {
|
||||
std::uint32_t mask{std::uint32_t{0} - 1}; // Oxff...
|
||||
RBitField32 maskbits{common::Span<std::uint32_t>{&mask, 1}};
|
||||
RBitField32 rankbits{
|
||||
common::Span<std::uint32_t>{reinterpret_cast<std::uint32_t*>(&shifted_rank), 1}};
|
||||
// prepare for counting trailing zeros.
|
||||
for (std::int32_t i = 0; i < depth + 1; ++i) {
|
||||
if (rankbits.Check(i)) {
|
||||
maskbits.Set(i);
|
||||
} else {
|
||||
maskbits.Clear(i);
|
||||
}
|
||||
}
|
||||
|
||||
CHECK_NE(mask, 0);
|
||||
auto k = TrailingZeroBits(mask);
|
||||
auto shifted_parent = shifted_rank - (1 << k);
|
||||
return shifted_parent;
|
||||
}
|
||||
|
||||
// Rotate the rank numbering so that the root worker is mapped to rank 0. `world` is added
// before subtracting `root` to keep the dividend non-negative for ranks below the root.
std::int32_t ShiftLeft(std::int32_t rank, std::int32_t world, std::int32_t root) {
  return (rank + world - root) % world;
}
|
||||
// Undo the rotation applied by shifting the root to rank 0: map a shifted rank back to the
// original rank numbering.
std::int32_t ShiftRight(std::int32_t rank, std::int32_t world, std::int32_t root) {
  return (rank + root) % world;
}
|
||||
} // namespace
|
||||
|
||||
/**
 * @brief Binomial tree broadcast of a byte buffer from `root` to all workers.
 *
 * Ranks are rotated so that `root` becomes rank 0. A non-root worker first receives the
 * buffer from its parent in the tree, then every worker forwards the buffer to its children,
 * walking the tree levels from the highest one down to level 0.
 *
 * References:
 * - Wiki: https://en.wikipedia.org/wiki/Broadcast_(parallel_pattern)#Binomial_Tree_Broadcast
 * - Impl: https://people.mpi-inf.mpg.de/~mehlhorn/ftp/NewToolbox/collective.pdf
 *
 * @param comm Communicator providing this worker's rank, the world size and peer channels.
 * @param data Buffer that is sent from the root and received into on every other worker.
 * @param root Original rank of the worker that owns the data.
 *             NOTE(review): assumed to be in [0, world); it is not validated here.
 *
 * @return A failure wrapping the parent channel's error if receiving failed, otherwise the
 *         result of `comm.Block()`.
 */
Result Broadcast(Comm const& comm, common::Span<std::int8_t> data, std::int32_t root) {
  auto rank = comm.Rank();
  auto world = comm.World();

  // shift root to rank 0
  auto shifted_rank = ShiftLeft(rank, world, root);
  // Highest tree level. Evaluates to -1 for a single worker, in which case the send loop
  // below is skipped. The conversion from double is spelled out instead of being implicit.
  auto const depth =
      static_cast<std::int32_t>(std::ceil(std::log2(static_cast<double>(world)))) - 1;

  if (shifted_rank != 0) {  // not root
    // Map the parent from shifted rank space back to the original rank before using it.
    auto parent = ShiftRight(ShiftedParentRank(shifted_rank, depth), world, root);
    comm.Chan(parent)->RecvAll(data);
    auto rc = comm.Chan(parent)->Block();
    if (!rc.OK()) {
      return Fail("broadcast failed.", std::move(rc));
    }
  }

  for (std::int32_t i = depth; i >= 0; --i) {
    CHECK_GE((i + 1), 0);  // weird clang-tidy error that i might be negative
    // This worker sends at level i only if its shifted rank is a multiple of 2^(i + 1) and
    // the child at distance 2^i exists within the world.
    if (shifted_rank % (1 << (i + 1)) == 0 && shifted_rank + (1 << i) < world) {
      auto sft_peer = shifted_rank + (1 << i);
      auto peer = ShiftRight(sft_peer, world, root);
      CHECK_NE(peer, root);
      comm.Chan(peer)->SendAll(data);
    }
  }

  return comm.Block();
}
|
||||
} // namespace xgboost::collective::cpu_impl
|
||||
26
src/collective/broadcast.h
Normal file
26
src/collective/broadcast.h
Normal file
@@ -0,0 +1,26 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <cstdint> // for int32_t, int8_t
|
||||
|
||||
#include "comm.h" // for Comm
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace cpu_impl {
|
||||
Result Broadcast(Comm const& comm, common::Span<std::int8_t> data, std::int32_t root);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief binomial tree broadcast is used on CPU with the default implementation.
|
||||
*/
|
||||
template <typename T>
|
||||
[[nodiscard]] Result Broadcast(Comm const& comm, common::Span<T> data, std::int32_t root) {
|
||||
auto n_total_bytes = data.size_bytes();
|
||||
auto erased =
|
||||
common::Span<std::int8_t>{reinterpret_cast<std::int8_t*>(data.data()), n_total_bytes};
|
||||
return cpu_impl::Broadcast(comm, erased, root);
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
Reference in New Issue
Block a user