[coll] Define interface for bridging. (#9695)

* Define the basic interface that will be shared by nccl, federated and native.
This commit is contained in:
Jiaming Yuan 2023-10-20 16:20:48 +08:00 committed by GitHub
parent 6fbe6248f4
commit b771f58453
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 174 additions and 3 deletions

View File

@ -102,6 +102,7 @@ OBJECTS= \
$(PKGROOT)/src/collective/allreduce.o \
$(PKGROOT)/src/collective/broadcast.o \
$(PKGROOT)/src/collective/comm.o \
$(PKGROOT)/src/collective/coll.o \
$(PKGROOT)/src/collective/tracker.o \
$(PKGROOT)/src/collective/communicator.o \
$(PKGROOT)/src/collective/in_memory_communicator.o \

View File

@ -102,6 +102,7 @@ OBJECTS= \
$(PKGROOT)/src/collective/allreduce.o \
$(PKGROOT)/src/collective/broadcast.o \
$(PKGROOT)/src/collective/comm.o \
$(PKGROOT)/src/collective/coll.o \
$(PKGROOT)/src/collective/tracker.o \
$(PKGROOT)/src/collective/communicator.o \
$(PKGROOT)/src/collective/in_memory_communicator.o \

View File

@ -3,7 +3,7 @@
*/
#include "allgather.h"
#include <algorithm> // for min, copy_n
#include <algorithm> // for min, copy_n, fill_n
#include <cstddef> // for size_t
#include <cstdint> // for int8_t, int32_t, int64_t
#include <memory> // for shared_ptr
@ -45,6 +45,7 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
common::Span<std::int8_t const> data,
common::Span<std::int64_t> offset,
common::Span<std::int8_t> erased_result) {
auto world = comm.World();
auto rank = comm.Rank();
@ -56,7 +57,8 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
auto next_ch = comm.Chan(next);
// get worker offset
std::vector<std::int64_t> offset(world + 1, 0);
CHECK_EQ(world + 1, offset.size());
std::fill_n(offset.data(), offset.size(), 0);
std::partial_sum(sizes.cbegin(), sizes.cend(), offset.begin() + 1);
CHECK_EQ(*offset.cbegin(), 0);

View File

@ -26,6 +26,7 @@ namespace cpu_impl {
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
common::Span<std::int8_t const> data,
common::Span<std::int64_t> offset,
common::Span<std::int8_t> erased_result);
} // namespace cpu_impl
@ -66,7 +67,9 @@ template <typename T>
auto h_result = common::Span{result.data(), result.size()};
auto erased_result = EraseType(h_result);
auto erased_data = EraseType(data);
std::vector<std::int64_t> offset(world + 1);
return cpu_impl::RingAllgatherV(comm, sizes, erased_data, erased_result);
return cpu_impl::RingAllgatherV(comm, sizes, erased_data,
common::Span{offset.data(), offset.size()}, erased_result);
}
} // namespace xgboost::collective

75
src/collective/coll.cc Normal file
View File

@ -0,0 +1,75 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include "coll.h"
#include <algorithm>    // for min, max
#include <cstddef>      // for size_t
#include <cstdint>      // for int8_t, int64_t
#include <functional>   // for bit_and, bit_or, bit_xor, plus
#include <type_traits>  // for is_integral_v

#include "allgather.h"        // for RingAllgatherV, RingAllgather
#include "allreduce.h"        // for Allreduce
#include "broadcast.h"        // for Broadcast
#include "comm.h"             // for Comm
#include "xgboost/context.h"  // for Context
namespace xgboost::collective {
[[nodiscard]] Result Coll::Allreduce(Context const*, Comm const& comm,
                                     common::Span<std::int8_t> data,
                                     ArrayInterfaceHandler::Type type, Op op) {
  namespace coll = ::xgboost::collective;

  // Restore the element type from `type` before reducing. The incoming buffer is
  // type-erased into bytes; applying min/max/sum directly on the bytes would be wrong
  // for any element wider than one byte, so we dispatch on the runtime type first.
  auto typed = [&](auto t) -> Result {
    using T = decltype(t);
    // Fold `lhs` element-wise into `out`; both spans view the same number of bytes.
    auto redop_fn = [](auto lhs, auto out, auto elem_op) {
      auto const* p_lhs = reinterpret_cast<T const*>(lhs.data());
      auto* p_out = reinterpret_cast<T*>(out.data());
      for (std::size_t i = 0, n = lhs.size() / sizeof(T); i < n; ++i) {
        p_out[i] = elem_op(p_lhs[i], p_out[i]);
      }
    };
    auto fn = [&](auto elem_op) {
      return coll::Allreduce(
          comm, data, [redop_fn, elem_op](auto lhs, auto rhs) { redop_fn(lhs, rhs, elem_op); });
    };
    switch (op) {
      case Op::kMax: {
        return fn([](auto l, auto r) { return std::max(l, r); });
      }
      case Op::kMin: {
        return fn([](auto l, auto r) { return std::min(l, r); });
      }
      case Op::kSum: {
        return fn(std::plus<>{});
      }
      case Op::kBitwiseAND:
      case Op::kBitwiseOR:
      case Op::kBitwiseXOR: {
        // Bitwise reductions are only meaningful for integral element types.
        if constexpr (std::is_integral_v<T>) {
          if (op == Op::kBitwiseAND) {
            return fn(std::bit_and<>{});
          }
          if (op == Op::kBitwiseOR) {
            return fn(std::bit_or<>{});
          }
          return fn(std::bit_xor<>{});
        } else {
          return Fail("Invalid type for bitwise operation.");
        }
      }
    }
    return comm.Block();
  };

  switch (type) {
    case ArrayInterfaceHandler::kI1:
      return typed(std::int8_t{});
    case ArrayInterfaceHandler::kI2:
      return typed(std::int16_t{});
    case ArrayInterfaceHandler::kI4:
      return typed(std::int32_t{});
    case ArrayInterfaceHandler::kI8:
      return typed(std::int64_t{});
    case ArrayInterfaceHandler::kU1:
      return typed(std::uint8_t{});
    case ArrayInterfaceHandler::kU2:
      return typed(std::uint16_t{});
    case ArrayInterfaceHandler::kU4:
      return typed(std::uint32_t{});
    case ArrayInterfaceHandler::kU8:
      return typed(std::uint64_t{});
    case ArrayInterfaceHandler::kF4:
      return typed(float{});
    case ArrayInterfaceHandler::kF8:
      return typed(double{});
    default:
      return Fail("Invalid data type for allreduce.");
  }
}
[[nodiscard]] Result Coll::Broadcast(Context const*, Comm const& comm,
                                     common::Span<std::int8_t> data, std::int32_t root) {
  // CPU base implementation: forward directly to the cpu_impl broadcast.
  auto rc = cpu_impl::Broadcast(comm, data, root);
  return rc;
}
[[nodiscard]] Result Coll::Allgather(Context const*, Comm const& comm,
                                     common::Span<std::int8_t> data, std::size_t size) {
  // CPU base implementation: each worker contributes `size` bytes; delegate to the
  // ring allgather.
  auto rc = RingAllgather(comm, data, size);
  return rc;
}
[[nodiscard]] Result Coll::AllgatherV(Context const*, Comm const& comm,
                                      common::Span<std::int8_t const> data,
                                      common::Span<std::int64_t const> sizes,
                                      common::Span<std::int64_t> recv_segments,
                                      common::Span<std::int8_t> recv) {
  // CPU base implementation. Note the impl takes `sizes` before `data`, unlike this
  // interface.
  auto rc = cpu_impl::RingAllgatherV(comm, sizes, data, recv_segments, recv);
  return rc;
}
} // namespace xgboost::collective

66
src/collective/coll.h Normal file
View File

@ -0,0 +1,66 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <cstddef> // for size_t
#include <cstdint> // for int8_t, int64_t
#include <memory> // for enable_shared_from_this
#include "../data/array_interface.h" // for ArrayInterfaceHandler
#include "comm.h" // for Comm
#include "xgboost/collective/result.h" // for Result
#include "xgboost/context.h" // for Context
#include "xgboost/span.h" // for Span
namespace xgboost::collective {
/**
 * @brief Interface and base implementation for collective operations.
 *
 * The base class provides the CPU implementation; per the commit description, this
 * interface is to be shared by the nccl, federated, and native backends, which
 * override the virtual methods. All data buffers are passed as type-erased byte
 * spans.
 */
class Coll : public std::enable_shared_from_this<Coll> {
 public:
  Coll() = default;
  // Deliberately noexcept(false) so implementations may surface errors during
  // teardown; hence the NOLINT.
  virtual ~Coll() noexcept(false) {}  // NOLINT

  /**
   * @brief Allreduce
   *
   * @param [in] ctx Runtime context (unused by the CPU base implementation).
   * @param [in] comm Communicator describing the worker group.
   * @param [in,out] data Data buffer for input and output.
   * @param [in] type data type.
   * @param [in] op Reduce operation. For custom operation, user needs to reach down to
   *      the CPU implementation.
   */
  [[nodiscard]] virtual Result Allreduce(Context const* ctx, Comm const& comm,
                                         common::Span<std::int8_t> data,
                                         ArrayInterfaceHandler::Type type, Op op);
  /**
   * @brief Broadcast
   *
   * @param [in] ctx Runtime context (unused by the CPU base implementation).
   * @param [in] comm Communicator describing the worker group.
   * @param [in,out] data Data buffer for input and output.
   * @param [in] root Root rank for broadcast.
   */
  [[nodiscard]] virtual Result Broadcast(Context const* ctx, Comm const& comm,
                                         common::Span<std::int8_t> data, std::int32_t root);
  /**
   * @brief Allgather
   *
   * @param [in] ctx Runtime context (unused by the CPU base implementation).
   * @param [in] comm Communicator describing the worker group.
   * @param [in,out] data Data buffer for input and output.
   * @param [in] size Size of data for each worker.
   */
  [[nodiscard]] virtual Result Allgather(Context const* ctx, Comm const& comm,
                                         common::Span<std::int8_t> data, std::size_t size);
  /**
   * @brief Allgather with variable length.
   *
   * @param [in] ctx Runtime context (unused by the CPU base implementation).
   * @param [in] comm Communicator describing the worker group.
   * @param [in] data Input data for the current worker.
   * @param [in] sizes Size of the input from each worker.
   * @param [out] recv_segments pre-allocated offset for each worker in the output, size
   *      should be equal to (world + 1).
   * @param [out] recv pre-allocated buffer for output.
   */
  [[nodiscard]] virtual Result AllgatherV(Context const* ctx, Comm const& comm,
                                          common::Span<std::int8_t const> data,
                                          common::Span<std::int64_t const> sizes,
                                          common::Span<std::int64_t> recv_segments,
                                          common::Span<std::int8_t> recv);
};
} // namespace xgboost::collective

View File

@ -4,6 +4,7 @@
#include <gtest/gtest.h>
#include "../../../src/collective/allreduce.h"
#include "../../../src/collective/coll.h" // for Coll
#include "../../../src/collective/tracker.h"
#include "test_worker.h" // for WorkerForTest, TestDistributed
@ -47,6 +48,19 @@ class AllreduceWorker : public WorkerForTest {
ASSERT_EQ(v, 1.5 * static_cast<double>(comm_.World())) << i;
}
}
void BitOr() {
Context ctx;
std::vector<std::uint32_t> data(comm_.World(), 0);
data[comm_.Rank()] = ~std::uint32_t{0};
auto pcoll = std::make_shared<Coll>();
auto rc = pcoll->Allreduce(&ctx, comm_, EraseType(common::Span{data.data(), data.size()}),
ArrayInterfaceHandler::kU4, Op::kBitwiseOR);
ASSERT_TRUE(rc.OK()) << rc.Report();
for (auto v : data) {
ASSERT_EQ(v, ~std::uint32_t{0});
}
}
};
class AllreduceTest : public SocketTest {};
@ -69,4 +83,13 @@ TEST_F(AllreduceTest, Sum) {
worker.Acc();
});
}
TEST_F(AllreduceTest, BitOr) {
  // Run with up to 7 workers, bounded by the available hardware threads.
  std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
  TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
                                 std::int32_t rank) {
    AllreduceWorker worker{host, port, timeout, n_workers, rank};
    worker.BitOr();
  });
}
} // namespace xgboost::collective