enable ROCm on latest XGBoost
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
*/
|
||||
#include "allgather.h"
|
||||
|
||||
#include <algorithm> // for min, copy_n
|
||||
#include <algorithm> // for min, copy_n, fill_n
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int8_t, int32_t, int64_t
|
||||
#include <memory> // for shared_ptr
|
||||
@@ -45,6 +45,7 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
|
||||
|
||||
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int8_t const> data,
|
||||
common::Span<std::int64_t> offset,
|
||||
common::Span<std::int8_t> erased_result) {
|
||||
auto world = comm.World();
|
||||
auto rank = comm.Rank();
|
||||
@@ -56,7 +57,8 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
|
||||
auto next_ch = comm.Chan(next);
|
||||
|
||||
// get worker offset
|
||||
std::vector<std::int64_t> offset(world + 1, 0);
|
||||
CHECK_EQ(world + 1, offset.size());
|
||||
std::fill_n(offset.data(), offset.size(), 0);
|
||||
std::partial_sum(sizes.cbegin(), sizes.cend(), offset.begin() + 1);
|
||||
CHECK_EQ(*offset.cbegin(), 0);
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ namespace cpu_impl {
|
||||
|
||||
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int8_t const> data,
|
||||
common::Span<std::int64_t> offset,
|
||||
common::Span<std::int8_t> erased_result);
|
||||
} // namespace cpu_impl
|
||||
|
||||
@@ -66,7 +67,9 @@ template <typename T>
|
||||
auto h_result = common::Span{result.data(), result.size()};
|
||||
auto erased_result = EraseType(h_result);
|
||||
auto erased_data = EraseType(data);
|
||||
std::vector<std::int64_t> offset(world + 1);
|
||||
|
||||
return cpu_impl::RingAllgatherV(comm, sizes, erased_data, erased_result);
|
||||
return cpu_impl::RingAllgatherV(comm, sizes, erased_data,
|
||||
common::Span{offset.data(), offset.size()}, erased_result);
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
75
src/collective/coll.cc
Normal file
75
src/collective/coll.cc
Normal file
@@ -0,0 +1,75 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include "coll.h"
|
||||
|
||||
#include <algorithm> // for min, max
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int8_t, int64_t
|
||||
#include <functional> // for bit_and, bit_or, bit_xor, plus
|
||||
|
||||
#include "allgather.h" // for RingAllgatherV, RingAllgather
|
||||
#include "allreduce.h" // for Allreduce
|
||||
#include "broadcast.h" // for Broadcast
|
||||
#include "comm.h" // for Comm
|
||||
#include "xgboost/context.h" // for Context
|
||||
|
||||
namespace xgboost::collective {
|
||||
[[nodiscard]] Result Coll::Allreduce(Context const*, Comm const& comm,
|
||||
common::Span<std::int8_t> data, ArrayInterfaceHandler::Type,
|
||||
Op op) {
|
||||
namespace coll = ::xgboost::collective;
|
||||
|
||||
auto redop_fn = [](auto lhs, auto out, auto elem_op) {
|
||||
auto p_lhs = lhs.data();
|
||||
auto p_out = out.data();
|
||||
for (std::size_t i = 0; i < lhs.size(); ++i) {
|
||||
p_out[i] = elem_op(p_lhs[i], p_out[i]);
|
||||
}
|
||||
};
|
||||
auto fn = [&](auto elem_op) {
|
||||
return coll::Allreduce(
|
||||
comm, data, [redop_fn, elem_op](auto lhs, auto rhs) { redop_fn(lhs, rhs, elem_op); });
|
||||
};
|
||||
|
||||
switch (op) {
|
||||
case Op::kMax: {
|
||||
return fn([](auto l, auto r) { return std::max(l, r); });
|
||||
}
|
||||
case Op::kMin: {
|
||||
return fn([](auto l, auto r) { return std::min(l, r); });
|
||||
}
|
||||
case Op::kSum: {
|
||||
return fn(std::plus<>{});
|
||||
}
|
||||
case Op::kBitwiseAND: {
|
||||
return fn(std::bit_and<>{});
|
||||
}
|
||||
case Op::kBitwiseOR: {
|
||||
return fn(std::bit_or<>{});
|
||||
}
|
||||
case Op::kBitwiseXOR: {
|
||||
return fn(std::bit_xor<>{});
|
||||
}
|
||||
}
|
||||
return comm.Block();
|
||||
}
|
||||
|
||||
[[nodiscard]] Result Coll::Broadcast(Context const*, Comm const& comm,
|
||||
common::Span<std::int8_t> data, std::int32_t root) {
|
||||
return cpu_impl::Broadcast(comm, data, root);
|
||||
}
|
||||
|
||||
[[nodiscard]] Result Coll::Allgather(Context const*, Comm const& comm,
|
||||
common::Span<std::int8_t> data, std::size_t size) {
|
||||
return RingAllgather(comm, data, size);
|
||||
}
|
||||
|
||||
[[nodiscard]] Result Coll::AllgatherV(Context const*, Comm const& comm,
|
||||
common::Span<std::int8_t const> data,
|
||||
common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int64_t> recv_segments,
|
||||
common::Span<std::int8_t> recv) {
|
||||
return cpu_impl::RingAllgatherV(comm, sizes, data, recv_segments, recv);
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
66
src/collective/coll.h
Normal file
66
src/collective/coll.h
Normal file
@@ -0,0 +1,66 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int8_t, int64_t
|
||||
#include <memory> // for enable_shared_from_this
|
||||
|
||||
#include "../data/array_interface.h" // for ArrayInterfaceHandler
|
||||
#include "comm.h" // for Comm
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
/**
|
||||
* @brief Interface and base implementation for collective.
|
||||
*/
|
||||
class Coll : public std::enable_shared_from_this<Coll> {
|
||||
public:
|
||||
Coll() = default;
|
||||
virtual ~Coll() noexcept(false) {} // NOLINT
|
||||
|
||||
/**
|
||||
* @brief Allreduce
|
||||
*
|
||||
* @param [in,out] data Data buffer for input and output.
|
||||
* @param [in] type data type.
|
||||
* @param [in] op Reduce operation. For custom operation, user needs to reach down to
|
||||
* the CPU implementation.
|
||||
*/
|
||||
[[nodiscard]] virtual Result Allreduce(Context const* ctx, Comm const& comm,
|
||||
common::Span<std::int8_t> data,
|
||||
ArrayInterfaceHandler::Type type, Op op);
|
||||
/**
|
||||
* @brief Broadcast
|
||||
*
|
||||
* @param [in,out] data Data buffer for input and output.
|
||||
* @param [in] root Root rank for broadcast.
|
||||
*/
|
||||
[[nodiscard]] virtual Result Broadcast(Context const* ctx, Comm const& comm,
|
||||
common::Span<std::int8_t> data, std::int32_t root);
|
||||
/**
|
||||
* @brief Allgather
|
||||
*
|
||||
* @param [in,out] data Data buffer for input and output.
|
||||
* @param [in] size Size of data for each worker.
|
||||
*/
|
||||
[[nodiscard]] virtual Result Allgather(Context const* ctx, Comm const& comm,
|
||||
common::Span<std::int8_t> data, std::size_t size);
|
||||
/**
|
||||
* @brief Allgather with variable length.
|
||||
*
|
||||
* @param [in] data Input data for the current worker.
|
||||
* @param [in] sizes Size of the input from each worker.
|
||||
* @param [out] recv_segments pre-allocated offset for each worker in the output, size
|
||||
* should be equal to (world + 1).
|
||||
* @param [out] recv pre-allocated buffer for output.
|
||||
*/
|
||||
[[nodiscard]] virtual Result AllgatherV(Context const* ctx, Comm const& comm,
|
||||
common::Span<std::int8_t const> data,
|
||||
common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int64_t> recv_segments,
|
||||
common::Span<std::int8_t> recv);
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
@@ -23,7 +23,7 @@ Comm::Comm(std::string const& host, std::int32_t port, std::chrono::seconds time
|
||||
retry_{retry},
|
||||
tracker_{host, port, -1},
|
||||
task_id_{std::move(task_id)},
|
||||
loop_{std::make_shared<Loop>(timeout)} {}
|
||||
loop_{std::shared_ptr<Loop>{new Loop{timeout}}} {}
|
||||
|
||||
Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, std::int32_t retry,
|
||||
std::string const& task_id, TCPSocket* out, std::int32_t rank,
|
||||
|
||||
Reference in New Issue
Block a user