[coll] Define interface for bridging. (#9695)
* Define the basic interface that will be shared by nccl, federated and native.
This commit is contained in:
parent
6fbe6248f4
commit
b771f58453
@ -102,6 +102,7 @@ OBJECTS= \
|
|||||||
$(PKGROOT)/src/collective/allreduce.o \
|
$(PKGROOT)/src/collective/allreduce.o \
|
||||||
$(PKGROOT)/src/collective/broadcast.o \
|
$(PKGROOT)/src/collective/broadcast.o \
|
||||||
$(PKGROOT)/src/collective/comm.o \
|
$(PKGROOT)/src/collective/comm.o \
|
||||||
|
$(PKGROOT)/src/collective/coll.o \
|
||||||
$(PKGROOT)/src/collective/tracker.o \
|
$(PKGROOT)/src/collective/tracker.o \
|
||||||
$(PKGROOT)/src/collective/communicator.o \
|
$(PKGROOT)/src/collective/communicator.o \
|
||||||
$(PKGROOT)/src/collective/in_memory_communicator.o \
|
$(PKGROOT)/src/collective/in_memory_communicator.o \
|
||||||
|
|||||||
@ -102,6 +102,7 @@ OBJECTS= \
|
|||||||
$(PKGROOT)/src/collective/allreduce.o \
|
$(PKGROOT)/src/collective/allreduce.o \
|
||||||
$(PKGROOT)/src/collective/broadcast.o \
|
$(PKGROOT)/src/collective/broadcast.o \
|
||||||
$(PKGROOT)/src/collective/comm.o \
|
$(PKGROOT)/src/collective/comm.o \
|
||||||
|
$(PKGROOT)/src/collective/coll.o \
|
||||||
$(PKGROOT)/src/collective/tracker.o \
|
$(PKGROOT)/src/collective/tracker.o \
|
||||||
$(PKGROOT)/src/collective/communicator.o \
|
$(PKGROOT)/src/collective/communicator.o \
|
||||||
$(PKGROOT)/src/collective/in_memory_communicator.o \
|
$(PKGROOT)/src/collective/in_memory_communicator.o \
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
*/
|
*/
|
||||||
#include "allgather.h"
|
#include "allgather.h"
|
||||||
|
|
||||||
#include <algorithm> // for min, copy_n
|
#include <algorithm> // for min, copy_n, fill_n
|
||||||
#include <cstddef> // for size_t
|
#include <cstddef> // for size_t
|
||||||
#include <cstdint> // for int8_t, int32_t, int64_t
|
#include <cstdint> // for int8_t, int32_t, int64_t
|
||||||
#include <memory> // for shared_ptr
|
#include <memory> // for shared_ptr
|
||||||
@ -45,6 +45,7 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
|
|||||||
|
|
||||||
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
|
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
|
||||||
common::Span<std::int8_t const> data,
|
common::Span<std::int8_t const> data,
|
||||||
|
common::Span<std::int64_t> offset,
|
||||||
common::Span<std::int8_t> erased_result) {
|
common::Span<std::int8_t> erased_result) {
|
||||||
auto world = comm.World();
|
auto world = comm.World();
|
||||||
auto rank = comm.Rank();
|
auto rank = comm.Rank();
|
||||||
@ -56,7 +57,8 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
|
|||||||
auto next_ch = comm.Chan(next);
|
auto next_ch = comm.Chan(next);
|
||||||
|
|
||||||
// get worker offset
|
// get worker offset
|
||||||
std::vector<std::int64_t> offset(world + 1, 0);
|
CHECK_EQ(world + 1, offset.size());
|
||||||
|
std::fill_n(offset.data(), offset.size(), 0);
|
||||||
std::partial_sum(sizes.cbegin(), sizes.cend(), offset.begin() + 1);
|
std::partial_sum(sizes.cbegin(), sizes.cend(), offset.begin() + 1);
|
||||||
CHECK_EQ(*offset.cbegin(), 0);
|
CHECK_EQ(*offset.cbegin(), 0);
|
||||||
|
|
||||||
|
|||||||
@ -26,6 +26,7 @@ namespace cpu_impl {
|
|||||||
|
|
||||||
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
|
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
|
||||||
common::Span<std::int8_t const> data,
|
common::Span<std::int8_t const> data,
|
||||||
|
common::Span<std::int64_t> offset,
|
||||||
common::Span<std::int8_t> erased_result);
|
common::Span<std::int8_t> erased_result);
|
||||||
} // namespace cpu_impl
|
} // namespace cpu_impl
|
||||||
|
|
||||||
@ -66,7 +67,9 @@ template <typename T>
|
|||||||
auto h_result = common::Span{result.data(), result.size()};
|
auto h_result = common::Span{result.data(), result.size()};
|
||||||
auto erased_result = EraseType(h_result);
|
auto erased_result = EraseType(h_result);
|
||||||
auto erased_data = EraseType(data);
|
auto erased_data = EraseType(data);
|
||||||
|
std::vector<std::int64_t> offset(world + 1);
|
||||||
|
|
||||||
return cpu_impl::RingAllgatherV(comm, sizes, erased_data, erased_result);
|
return cpu_impl::RingAllgatherV(comm, sizes, erased_data,
|
||||||
|
common::Span{offset.data(), offset.size()}, erased_result);
|
||||||
}
|
}
|
||||||
} // namespace xgboost::collective
|
} // namespace xgboost::collective
|
||||||
|
|||||||
75
src/collective/coll.cc
Normal file
75
src/collective/coll.cc
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#include "coll.h"
|
||||||
|
|
||||||
|
#include <algorithm> // for min, max
|
||||||
|
#include <cstddef> // for size_t
|
||||||
|
#include <cstdint> // for int8_t, int64_t
|
||||||
|
#include <functional> // for bit_and, bit_or, bit_xor, plus
|
||||||
|
|
||||||
|
#include "allgather.h" // for RingAllgatherV, RingAllgather
|
||||||
|
#include "allreduce.h" // for Allreduce
|
||||||
|
#include "broadcast.h" // for Broadcast
|
||||||
|
#include "comm.h" // for Comm
|
||||||
|
#include "xgboost/context.h" // for Context
|
||||||
|
|
||||||
|
namespace xgboost::collective {
|
||||||
|
/**
 * @brief Base implementation of allreduce over a type-erased buffer.
 *
 * Picks an element-level binary functor from `op`, lifts it into a reduction over two
 * spans and forwards it, together with `comm` and `data`, to the free function
 * coll::Allreduce.
 *
 * The Context pointer and the ArrayInterfaceHandler::Type argument are unnamed and are
 * not consulted by this implementation.
 * NOTE(review): `data` is handed over as Span<std::int8_t>, so the functor presumably
 * sees std::int8_t elements; kMax/kMin/kSum would then act per byte rather than per
 * original element -- confirm this is intended for multi-byte element types.
 *
 * @param comm Communicator forwarded to coll::Allreduce.
 * @param data Buffer used for both input and output.
 * @param op   Reduction operation to apply.
 *
 * @return The result of coll::Allreduce when `op` matches one of the handled cases,
 *         otherwise the result of comm.Block().
 */
[[nodiscard]] Result Coll::Allreduce(Context const*, Comm const& comm,
                                     common::Span<std::int8_t> data, ArrayInterfaceHandler::Type,
                                     Op op) {
  namespace coll = ::xgboost::collective;

  // Wrap an element-level functor into a span-level reduction and run the allreduce.
  auto run_with = [&](auto binary_op) {
    return coll::Allreduce(comm, data, [binary_op](auto in, auto inout) {
      auto src = in.data();
      auto dst = inout.data();
      // dst[i] = binary_op(src[i], dst[i]) for each element of the incoming span.
      std::transform(src, src + in.size(), dst, dst, binary_op);
    });
  };

  switch (op) {
    case Op::kMax:
      return run_with([](auto l, auto r) { return std::max(l, r); });
    case Op::kMin:
      return run_with([](auto l, auto r) { return std::min(l, r); });
    case Op::kSum:
      return run_with(std::plus<>{});
    case Op::kBitwiseAND:
      return run_with(std::bit_and<>{});
    case Op::kBitwiseOR:
      return run_with(std::bit_or<>{});
    case Op::kBitwiseXOR:
      return run_with(std::bit_xor<>{});
  }
  // Reached only when `op` matches none of the cases above.
  return comm.Block();
}
|
||||||
|
|
||||||
|
/**
 * @brief Base implementation of broadcast; forwards to cpu_impl::Broadcast.
 *
 * The Context pointer is not used by this implementation.
 *
 * @param comm Communicator forwarded to the CPU implementation.
 * @param data Buffer used for both input and output.
 * @param root Root rank of the broadcast.
 *
 * @return Whatever cpu_impl::Broadcast returns.
 */
[[nodiscard]] Result Coll::Broadcast(Context const* /*ctx*/, Comm const& comm,
                                     common::Span<std::int8_t> data, std::int32_t root) {
  return cpu_impl::Broadcast(comm, data, root);
}
|
||||||
|
|
||||||
|
/**
 * @brief Base implementation of allgather; forwards to RingAllgather.
 *
 * The Context pointer is not used by this implementation.
 *
 * @param comm Communicator forwarded to RingAllgather.
 * @param data Buffer used for both input and output.
 * @param size Size argument forwarded unchanged to RingAllgather.
 *
 * @return Whatever RingAllgather returns.
 */
[[nodiscard]] Result Coll::Allgather(Context const* /*ctx*/, Comm const& comm,
                                     common::Span<std::int8_t> data, std::size_t size) {
  return RingAllgather(comm, data, size);
}
|
||||||
|
|
||||||
|
/**
 * @brief Base implementation of variable-length allgather; forwards to
 *        cpu_impl::RingAllgatherV.
 *
 * The Context pointer is not used by this implementation. Note the argument order of the
 * CPU routine: `sizes` comes before `data`.
 *
 * @param comm          Communicator forwarded to the CPU implementation.
 * @param data          Input data of the current worker.
 * @param sizes         Size of the input from each worker.
 * @param recv_segments Caller-provided storage for the per-worker offsets.
 * @param recv          Caller-provided output buffer.
 *
 * @return Whatever cpu_impl::RingAllgatherV returns.
 */
[[nodiscard]] Result Coll::AllgatherV(Context const* /*ctx*/, Comm const& comm,
                                      common::Span<std::int8_t const> data,
                                      common::Span<std::int64_t const> sizes,
                                      common::Span<std::int64_t> recv_segments,
                                      common::Span<std::int8_t> recv) {
  return cpu_impl::RingAllgatherV(comm, sizes, data, recv_segments, recv);
}
|
||||||
|
} // namespace xgboost::collective
|
||||||
66
src/collective/coll.h
Normal file
66
src/collective/coll.h
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
#include <cstddef> // for size_t
|
||||||
|
#include <cstdint> // for int8_t, int64_t
|
||||||
|
#include <memory> // for enable_shared_from_this
|
||||||
|
|
||||||
|
#include "../data/array_interface.h" // for ArrayInterfaceHandler
|
||||||
|
#include "comm.h" // for Comm
|
||||||
|
#include "xgboost/collective/result.h" // for Result
|
||||||
|
#include "xgboost/context.h" // for Context
|
||||||
|
#include "xgboost/span.h" // for Span
|
||||||
|
|
||||||
|
namespace xgboost::collective {
|
||||||
|
/**
|
||||||
|
* @brief Interface and base implementation for collective.
|
||||||
|
*/
|
||||||
|
class Coll : public std::enable_shared_from_this<Coll> {
|
||||||
|
public:
|
||||||
|
Coll() = default;
|
||||||
|
virtual ~Coll() noexcept(false) {} // NOLINT
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Allreduce
|
||||||
|
*
|
||||||
|
* @param [in,out] data Data buffer for input and output.
|
||||||
|
* @param [in] type data type.
|
||||||
|
* @param [in] op Reduce operation. For custom operation, user needs to reach down to
|
||||||
|
* the CPU implementation.
|
||||||
|
*/
|
||||||
|
[[nodiscard]] virtual Result Allreduce(Context const* ctx, Comm const& comm,
|
||||||
|
common::Span<std::int8_t> data,
|
||||||
|
ArrayInterfaceHandler::Type type, Op op);
|
||||||
|
/**
|
||||||
|
* @brief Broadcast
|
||||||
|
*
|
||||||
|
* @param [in,out] data Data buffer for input and output.
|
||||||
|
* @param [in] root Root rank for broadcast.
|
||||||
|
*/
|
||||||
|
[[nodiscard]] virtual Result Broadcast(Context const* ctx, Comm const& comm,
|
||||||
|
common::Span<std::int8_t> data, std::int32_t root);
|
||||||
|
/**
|
||||||
|
* @brief Allgather
|
||||||
|
*
|
||||||
|
* @param [in,out] data Data buffer for input and output.
|
||||||
|
* @param [in] size Size of data for each worker.
|
||||||
|
*/
|
||||||
|
[[nodiscard]] virtual Result Allgather(Context const* ctx, Comm const& comm,
|
||||||
|
common::Span<std::int8_t> data, std::size_t size);
|
||||||
|
/**
|
||||||
|
* @brief Allgather with variable length.
|
||||||
|
*
|
||||||
|
* @param [in] data Input data for the current worker.
|
||||||
|
* @param [in] sizes Size of the input from each worker.
|
||||||
|
* @param [out] recv_segments pre-allocated offset for each worker in the output, size
|
||||||
|
* should be equal to (world + 1).
|
||||||
|
* @param [out] recv pre-allocated buffer for output.
|
||||||
|
*/
|
||||||
|
[[nodiscard]] virtual Result AllgatherV(Context const* ctx, Comm const& comm,
|
||||||
|
common::Span<std::int8_t const> data,
|
||||||
|
common::Span<std::int64_t const> sizes,
|
||||||
|
common::Span<std::int64_t> recv_segments,
|
||||||
|
common::Span<std::int8_t> recv);
|
||||||
|
};
|
||||||
|
} // namespace xgboost::collective
|
||||||
@ -4,6 +4,7 @@
|
|||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
#include "../../../src/collective/allreduce.h"
|
#include "../../../src/collective/allreduce.h"
|
||||||
|
#include "../../../src/collective/coll.h" // for Coll
|
||||||
#include "../../../src/collective/tracker.h"
|
#include "../../../src/collective/tracker.h"
|
||||||
#include "test_worker.h" // for WorkerForTest, TestDistributed
|
#include "test_worker.h" // for WorkerForTest, TestDistributed
|
||||||
|
|
||||||
@ -47,6 +48,19 @@ class AllreduceWorker : public WorkerForTest {
|
|||||||
ASSERT_EQ(v, 1.5 * static_cast<double>(comm_.World())) << i;
|
ASSERT_EQ(v, 1.5 * static_cast<double>(comm_.World())) << i;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void BitOr() {
|
||||||
|
Context ctx;
|
||||||
|
std::vector<std::uint32_t> data(comm_.World(), 0);
|
||||||
|
data[comm_.Rank()] = ~std::uint32_t{0};
|
||||||
|
auto pcoll = std::make_shared<Coll>();
|
||||||
|
auto rc = pcoll->Allreduce(&ctx, comm_, EraseType(common::Span{data.data(), data.size()}),
|
||||||
|
ArrayInterfaceHandler::kU4, Op::kBitwiseOR);
|
||||||
|
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||||
|
for (auto v : data) {
|
||||||
|
ASSERT_EQ(v, ~std::uint32_t{0});
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class AllreduceTest : public SocketTest {};
|
class AllreduceTest : public SocketTest {};
|
||||||
@ -69,4 +83,13 @@ TEST_F(AllreduceTest, Sum) {
|
|||||||
worker.Acc();
|
worker.Acc();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Runs AllreduceWorker::BitOr on at most 7 workers, one worker per TestDistributed
// callback invocation.
// NOTE(review): std::thread::hardware_concurrency() is allowed to return 0, which would
// make n_workers 0 here -- confirm TestDistributed copes with that.
TEST_F(AllreduceTest, BitOr) {
  std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());

  auto run_worker = [n_workers](std::string host, std::int32_t port,
                                std::chrono::seconds timeout, std::int32_t r) {
    AllreduceWorker worker{host, port, timeout, n_workers, r};
    worker.BitOr();
  };
  TestDistributed(n_workers, run_worker);
}
|
||||||
} // namespace xgboost::collective
|
} // namespace xgboost::collective
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user