/**
 * Copyright 2023-2024, XGBoost Contributors
 */
#include "coll.h"

#include <algorithm>    // for min, max, copy_n
#include <cstddef>      // for size_t
#include <cstdint>      // for int8_t, int64_t
#include <functional>   // for bit_and, bit_or, bit_xor, plus
#include <string>       // for string
#include <type_traits>  // for is_floating_point_v, is_same_v
#include <utility>      // for move

#include "../data/array_interface.h"  // for ArrayInterfaceHandler
#include "allgather.h"                // for RingAllgatherV, RingAllgather
#include "allreduce.h"                // for Allreduce
#include "broadcast.h"                // for Broadcast
#include "comm.h"                     // for Comm

#if defined(XGBOOST_USE_CUDA)
#include "cuda_fp16.h"  // for __half
#endif

namespace xgboost::collective {
/**
 * @brief Whether `T` counts as a floating-point type for the purpose of rejecting
 *        bit-wise reductions. Includes CUDA `__half` when built with CUDA support.
 */
template <typename T>
bool constexpr IsFloatingPointV() {
#if defined(XGBOOST_USE_CUDA)
  return std::is_floating_point_v<T> || std::is_same_v<T, __half>;
#else
  return std::is_floating_point_v<T>;
#endif  // defined(XGBOOST_USE_CUDA)
}

/**
 * @brief CPU allreduce over a type-erased buffer.
 *
 * @param data Type-erased (int8) buffer holding `type`-typed elements; reduced in place.
 * @param type Runtime element type used to restore the typed view.
 * @param op   Reduction operator. Bit-wise ops fail for floating-point element types.
 *
 * @return Result of the ring allreduce, chained with a final barrier on the comm.
 */
[[nodiscard]] Result Coll::Allreduce(Comm const& comm, common::Span<std::int8_t> data,
                                     ArrayInterfaceHandler::Type type, Op op) {
  namespace coll = ::xgboost::collective;

  // Element-wise reduction over two typed spans: out[i] = elem_op(lhs[i], out[i]).
  auto redop_fn = [](auto lhs, auto out, auto elem_op) {
    auto p_lhs = lhs.data();
    auto p_out = out.data();
#if defined(__GNUC__) || defined(__clang__)
    // For the sum op, one can verify the simd by: addps %xmm15, %xmm14
#pragma omp simd
#endif
    for (std::size_t i = 0; i < lhs.size(); ++i) {
      p_out[i] = elem_op(p_lhs[i], p_out[i]);
    }
  };

  // Wrap a typed element operator into a type-erased reducer and run the ring allreduce.
  auto fn = [&](auto elem_op, auto t) {
    using T = decltype(t);
    auto erased_fn = [redop_fn, elem_op](common::Span<std::int8_t const> lhs,
                                         common::Span<std::int8_t> out) {
      CHECK_EQ(lhs.size(), out.size()) << "Invalid input for reduction.";
      // Restore the typed views from the erased byte spans before reducing.
      auto lhs_t = common::RestoreType<T const>(lhs);
      auto rhs_t = common::RestoreType<T>(out);
      redop_fn(lhs_t, rhs_t, elem_op);
    };

    return cpu_impl::RingAllreduce(comm, data, erased_fn, type);
  };

  std::string msg{"Floating point is not supported for bit wise collective operations."};

  auto rc = DispatchDType(type, [&](auto t) {
    using T = decltype(t);
    switch (op) {
      case Op::kMax: {
        return fn([](auto l, auto r) { return std::max(l, r); }, t);
      }
      case Op::kMin: {
        return fn([](auto l, auto r) { return std::min(l, r); }, t);
      }
      case Op::kSum: {
        return fn(std::plus<>{}, t);
      }
      case Op::kBitwiseAND: {
        if constexpr (IsFloatingPointV<T>()) {
          return Fail(msg);
        } else {
          return fn(std::bit_and<>{}, t);
        }
      }
      case Op::kBitwiseOR: {
        if constexpr (IsFloatingPointV<T>()) {
          return Fail(msg);
        } else {
          return fn(std::bit_or<>{}, t);
        }
      }
      case Op::kBitwiseXOR: {
        if constexpr (IsFloatingPointV<T>()) {
          return Fail(msg);
        } else {
          return fn(std::bit_xor<>{}, t);
        }
      }
    }
    return Fail("Invalid op.");
  });
  // Synchronize all workers before reporting the result.
  return std::move(rc) << [&] { return comm.Block(); };
}

/**
 * @brief Broadcast `data` from `root` to every worker in `comm`.
 */
[[nodiscard]] Result Coll::Broadcast(Comm const& comm, common::Span<std::int8_t> data,
                                     std::int32_t root) {
  return cpu_impl::Broadcast(comm, data, root);
}

/**
 * @brief Ring allgather of equally-sized per-worker segments held in `data`.
 */
[[nodiscard]] Result Coll::Allgather(Comm const& comm, common::Span<std::int8_t> data) {
  return RingAllgather(comm, data);
}

/**
 * @brief Variable-length allgather.
 *
 * @param data          This worker's input segment.
 * @param sizes         Per-worker segment sizes in bytes.
 * @param recv_segments Output offsets of each worker's segment inside `recv`,
 *                      filled by `detail::AllgatherVOffset`.
 * @param recv          Receive buffer for the concatenated segments.
 * @param algo          Ring-based or broadcast-based implementation.
 */
[[nodiscard]] Result Coll::AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
                                      common::Span<std::int64_t const> sizes,
                                      common::Span<std::int64_t> recv_segments,
                                      common::Span<std::int8_t> recv, AllgatherVAlgo algo) {
  // get worker offset
  detail::AllgatherVOffset(sizes, recv_segments);
  // copy data into this worker's slot, unless caller already aliased it there
  auto current = recv.subspan(recv_segments[comm.Rank()], data.size_bytes());
  if (current.data() != data.data()) {
    std::copy_n(data.data(), data.size(), current.data());
  }
  switch (algo) {
    case AllgatherVAlgo::kRing:
      return detail::RingAllgatherV(comm, sizes, recv_segments, recv);
    case AllgatherVAlgo::kBcast:
      return cpu_impl::BroadcastAllgatherV(comm, sizes, recv);
    default: {
      return Fail("Unknown algorithm for allgather-v");
    }
  }
}

#if !defined(XGBOOST_USE_NCCL)
// Device variant is only available when built with NCCL.
Coll* Coll::MakeCUDAVar() {
  LOG(FATAL) << "NCCL is required for device communication.";
  return nullptr;
}
#endif
}  // namespace xgboost::collective