Jiaming Yuan b771f58453
[coll] Define interface for bridging. (#9695)
* Define the basic interface that will be shared by the NCCL, federated, and native implementations.
2023-10-20 16:20:48 +08:00

76 lines
2.5 KiB
C++

/**
* Copyright 2023, XGBoost Contributors
*/
#include "coll.h"
#include <algorithm> // for min, max
#include <cstddef> // for size_t
#include <cstdint> // for int8_t, int64_t
#include <functional> // for bit_and, bit_or, bit_xor, plus
#include "allgather.h" // for RingAllgatherV, RingAllgather
#include "allreduce.h" // for Allreduce
#include "broadcast.h" // for Broadcast
#include "comm.h" // for Comm
#include "xgboost/context.h" // for Context
namespace xgboost::collective {
/**
 * @brief CPU allreduce over a type-erased byte buffer.
 *
 * Each handled @p op is mapped to an element-wise binary functor. That functor is folded into
 * the output span by the templated `collective::Allreduce` from allreduce.h: for every index k,
 * out[k] = binary(in[k], out[k]).
 *
 * NOTE(review): the `ArrayInterfaceHandler::Type` argument is unnamed and never read here. The
 * reducer is therefore instantiated for the `std::int8_t` span it receives, and presumably
 * combines individual bytes rather than values of the caller's element type (e.g. float sums
 * would be computed byte-wise). Confirm against allreduce.h before relying on this for
 * non-byte data.
 *
 * @param comm Communicator forwarded to the underlying allreduce.
 * @param data In/out buffer reduced across workers.
 * @param op   Reduction operator to apply.
 *
 * @return Whatever the underlying allreduce returns for a handled @p op. If @p op matches none
 *         of the enumerators below, the result of `comm.Block()`.
 */
[[nodiscard]] Result Coll::Allreduce(Context const*, Comm const& comm,
                                     common::Span<std::int8_t> data, ArrayInterfaceHandler::Type,
                                     Op op) {
  // Wraps a binary functor into a span reducer that accumulates `in` into `inout` in place.
  auto make_reducer = [](auto binary) {
    return [binary](auto in, auto inout) {
      auto const* src = in.data();
      auto* dst = inout.data();
      std::size_t const n = in.size();
      for (std::size_t k = 0; k != n; ++k) {
        dst[k] = binary(src[k], dst[k]);
      }
    };
  };
  // Fully qualified on purpose: an unqualified `Allreduce` inside this member function would
  // resolve to `Coll::Allreduce` itself instead of the free template.
  auto run = [&](auto binary) {
    return ::xgboost::collective::Allreduce(comm, data, make_reducer(binary));
  };

  switch (op) {
    case Op::kSum:
      return run(std::plus<>{});
    case Op::kMax:
      return run([](auto a, auto b) { return std::max(a, b); });
    case Op::kMin:
      return run([](auto a, auto b) { return std::min(a, b); });
    case Op::kBitwiseAND:
      return run(std::bit_and<>{});
    case Op::kBitwiseOR:
      return run(std::bit_or<>{});
    case Op::kBitwiseXOR:
      return run(std::bit_xor<>{});
  }
  // Reached only for an `op` value outside the enumerators handled above. No reduction is
  // performed; the communicator's Block() result is returned instead.
  // NOTE(review): an unknown op therefore does not surface an error here — confirm intended.
  return comm.Block();
}

/**
 * @brief CPU broadcast of @p data rooted at worker @p root.
 *
 * Thin forwarder to `cpu_impl::Broadcast`. The Context argument is unused.
 */
[[nodiscard]] Result Coll::Broadcast(Context const*, Comm const& comm,
                                     common::Span<std::int8_t> data, std::int32_t root) {
  return cpu_impl::Broadcast(comm, data, root);
}

/**
 * @brief CPU allgather over @p data using the ring algorithm.
 *
 * Forwards to `RingAllgather`. The Context argument is unused. @p size is passed through
 * unchanged; presumably it is the per-worker segment size in bytes (confirm in allgather.h).
 */
[[nodiscard]] Result Coll::Allgather(Context const*, Comm const& comm,
                                     common::Span<std::int8_t> data, std::size_t size) {
  return RingAllgather(comm, data, size);
}

/**
 * @brief CPU variable-length allgather using the ring algorithm.
 *
 * Forwards to `cpu_impl::RingAllgatherV`. The Context argument is unused. Note that the helper
 * takes @p sizes before @p data, the reverse of this member's parameter order.
 */
[[nodiscard]] Result Coll::AllgatherV(Context const*, Comm const& comm,
                                      common::Span<std::int8_t const> data,
                                      common::Span<std::int64_t const> sizes,
                                      common::Span<std::int64_t> recv_segments,
                                      common::Span<std::int8_t> recv) {
  return cpu_impl::RingAllgatherV(comm, sizes, data, recv_segments, recv);
}
}  // namespace xgboost::collective