76 lines
2.5 KiB
C++
76 lines
2.5 KiB
C++
/**
 * Copyright 2023, XGBoost Contributors
 */
|
|
#include "coll.h"
|
|
|
|
#include <algorithm> // for min, max
|
|
#include <cstddef> // for size_t
|
|
#include <cstdint> // for int8_t, int64_t
|
|
#include <functional> // for bit_and, bit_or, bit_xor, plus
|
|
|
|
#include "allgather.h" // for RingAllgatherV, RingAllgather
|
|
#include "allreduce.h" // for Allreduce
|
|
#include "broadcast.h" // for Broadcast
|
|
#include "comm.h" // for Comm
|
|
#include "xgboost/context.h" // for Context
|
|
|
|
namespace xgboost::collective {
|
|
/**
 * @brief CPU allreduce over a type-erased buffer.
 *
 * @param comm Communicator the reduction runs on.
 * @param data Raw bytes of the local buffer; the reduced values are written back into it.
 * @param type Element type the bytes actually hold. Max/min/sum are applied per element of
 *             this type. Bitwise ops are applied per byte, which yields the same bits for
 *             any element width.
 * @param op   Reduction operator.
 *
 * @return The result of the allreduce. A failure is returned for:
 *         - an op that is not one of the handled enumerators,
 *         - an element type without a CPU arithmetic path (half / 128-bit float),
 *         - a buffer whose byte size is not a multiple of the element size.
 */
[[nodiscard]] Result Coll::Allreduce(Context const*, Comm const& comm,
                                     common::Span<std::int8_t> data,
                                     ArrayInterfaceHandler::Type type, Op op) {
  namespace coll = ::xgboost::collective;

  // out[i] = elem_op(lhs[i], out[i]) for every element of the two spans.
  auto redop_fn = [](auto lhs, auto out, auto elem_op) {
    auto p_lhs = lhs.data();
    auto p_out = out.data();
    for (std::size_t i = 0; i < lhs.size(); ++i) {
      p_out[i] = elem_op(p_lhs[i], p_out[i]);
    }
  };
  // Allreduce `typed` in place, combining elements with `elem_op`.
  auto reduce = [&](auto typed, auto elem_op) {
    return coll::Allreduce(
        comm, typed, [redop_fn, elem_op](auto lhs, auto rhs) { redop_fn(lhs, rhs, elem_op); });
  };
  // Max/min/sum depend on the element type. View the erased bytes as `T` before reducing;
  // otherwise multi-byte values would be combined byte by byte.
  auto arithmetic = [&](auto type_tag) -> Result {
    using T = decltype(type_tag);
    if (data.size() % sizeof(T) != 0) {
      return Fail("Allreduce: the buffer size is not a multiple of the element size.");
    }
    // NOTE(review): assumes `data` was erased from a buffer of `T`, so this cast is suitably
    // aligned — confirm against callers.
    common::Span<T> typed{reinterpret_cast<T*>(data.data()), data.size() / sizeof(T)};
    switch (op) {
      case Op::kMax:
        return reduce(typed, [](auto l, auto r) { return std::max(l, r); });
      case Op::kMin:
        return reduce(typed, [](auto l, auto r) { return std::min(l, r); });
      case Op::kSum:
        return reduce(typed, std::plus<>{});
      default:
        break;
    }
    return Fail("Allreduce: invalid arithmetic op.");
  };

  switch (op) {
    // Bitwise operators act on each bit independently, so reducing the raw bytes produces the
    // same result as reducing the typed elements.
    case Op::kBitwiseAND: {
      return reduce(data, std::bit_and<>{});
    }
    case Op::kBitwiseOR: {
      return reduce(data, std::bit_or<>{});
    }
    case Op::kBitwiseXOR: {
      return reduce(data, std::bit_xor<>{});
    }
    case Op::kMax:
    case Op::kMin:
    case Op::kSum: {
      switch (type) {
        case ArrayInterfaceHandler::kF4:
          return arithmetic(float{});
        case ArrayInterfaceHandler::kF8:
          return arithmetic(double{});
        case ArrayInterfaceHandler::kI1:
          return arithmetic(std::int8_t{});
        case ArrayInterfaceHandler::kI2:
          return arithmetic(std::int16_t{});
        case ArrayInterfaceHandler::kI4:
          return arithmetic(std::int32_t{});
        case ArrayInterfaceHandler::kI8:
          return arithmetic(std::int64_t{});
        case ArrayInterfaceHandler::kU1:
          return arithmetic(std::uint8_t{});
        case ArrayInterfaceHandler::kU2:
          return arithmetic(std::uint16_t{});
        case ArrayInterfaceHandler::kU4:
          return arithmetic(std::uint32_t{});
        case ArrayInterfaceHandler::kU8:
          return arithmetic(std::uint64_t{});
        default:
          // Half-precision and 128-bit floats have no CPU arithmetic path here.
          break;
      }
      return Fail("Allreduce: unsupported element type for max/min/sum.");
    }
  }
  // Previously an unknown op silently succeeded without reducing anything.
  return Fail("Allreduce: invalid op.");
}
|
|
|
|
/**
 * @brief CPU broadcast entry point.
 *
 * Forwards the communicator, the byte buffer and the root rank unchanged to
 * `cpu_impl::Broadcast` and hands its result back to the caller. The `Context` argument is
 * not used by this implementation.
 */
[[nodiscard]] Result Coll::Broadcast(Context const*, Comm const& comm,
                                     common::Span<std::int8_t> data, std::int32_t root) {
  auto rc = cpu_impl::Broadcast(comm, data, root);
  return rc;
}
|
|
|
|
/**
 * @brief CPU allgather entry point.
 *
 * Delegates to `RingAllgather` (pulled in from allgather.h), passing the communicator, the byte
 * buffer and `size` through unchanged. The `Context` argument is not used by this
 * implementation.
 */
[[nodiscard]] Result Coll::Allgather(Context const*, Comm const& comm,
                                     common::Span<std::int8_t> data, std::size_t size) {
  auto rc = RingAllgather(comm, data, size);
  return rc;
}
|
|
|
|
/**
 * @brief CPU variable-length allgather entry point.
 *
 * Delegates to `cpu_impl::RingAllgatherV`. The `Context` argument is not used by this
 * implementation.
 *
 * Argument order differs from this method's signature: the ring implementation receives
 * `sizes` before `data`, followed by `recv_segments` and `recv`.
 */
[[nodiscard]] Result Coll::AllgatherV(Context const*, Comm const& comm,
                                      common::Span<std::int8_t const> data,
                                      common::Span<std::int64_t const> sizes,
                                      common::Span<std::int64_t> recv_segments,
                                      common::Span<std::int8_t> recv) {
  auto rc = cpu_impl::RingAllgatherV(comm, sizes, data, recv_segments, recv);
  return rc;
}
|
|
} // namespace xgboost::collective
|