From 77b069c25d47c47ceebbaf6235d1b2d996174944 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Sat, 24 Dec 2022 14:40:05 -0800 Subject: [PATCH] Support bitwise allreduce operations in the communicator (#8623) --- plugin/federated/federated.proto | 3 + python-package/xgboost/collective.py | 3 + rabit/include/rabit/internal/engine.h | 4 +- rabit/include/rabit/internal/rabit-inl.h | 14 +++ rabit/include/rabit/rabit.h | 10 ++ rabit/src/rabit_c_api.cc | 36 ++++++- src/collective/communicator.h | 9 +- src/collective/in_memory_handler.cc | 28 ++++++ src/collective/rabit_communicator.h | 33 ++++++- .../collective/test_in_memory_communicator.cc | 93 +++++++++++++++---- 10 files changed, 207 insertions(+), 26 deletions(-) diff --git a/plugin/federated/federated.proto b/plugin/federated/federated.proto index 751861d9b..136687109 100644 --- a/plugin/federated/federated.proto +++ b/plugin/federated/federated.proto @@ -25,6 +25,9 @@ enum ReduceOperation { MAX = 0; MIN = 1; SUM = 2; + BITWISE_AND = 3; + BITWISE_OR = 4; + BITWISE_XOR = 5; } message AllreduceRequest { diff --git a/python-package/xgboost/collective.py b/python-package/xgboost/collective.py index 8021316e8..45d018cc7 100644 --- a/python-package/xgboost/collective.py +++ b/python-package/xgboost/collective.py @@ -191,6 +191,9 @@ class Op(IntEnum): MAX = 0 MIN = 1 SUM = 2 + BITWISE_AND = 3 + BITWISE_OR = 4 + BITWISE_XOR = 5 def allreduce( # pylint:disable=invalid-name diff --git a/rabit/include/rabit/internal/engine.h b/rabit/include/rabit/internal/engine.h index 4ebbf68db..aa074fb39 100644 --- a/rabit/include/rabit/internal/engine.h +++ b/rabit/include/rabit/internal/engine.h @@ -133,7 +133,9 @@ enum OpType { kMax = 0, kMin = 1, kSum = 2, - kBitwiseOR = 3 + kBitwiseAND = 3, + kBitwiseOR = 4, + kBitwiseXOR = 5, }; /*!\brief enum of supported data types */ enum DataType { diff --git a/rabit/include/rabit/internal/rabit-inl.h b/rabit/include/rabit/internal/rabit-inl.h index 1f4b2c0e2..49b086320 100644 --- a/rabit/include/rabit/internal/rabit-inl.h +++ b/rabit/include/rabit/internal/rabit-inl.h @@ -85,6 +85,13 @@ struct Sum { dst += src; } }; +struct BitAND { + static const engine::mpi::OpType kType = engine::mpi::kBitwiseAND; + template + inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*) + dst &= src; + } +}; struct BitOR { static const engine::mpi::OpType kType = engine::mpi::kBitwiseOR; template @@ -92,6 +99,13 @@ struct BitOR { dst |= src; } }; +struct BitXOR { + static const engine::mpi::OpType kType = engine::mpi::kBitwiseXOR; + template + inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*) + dst ^= src; + } +}; template inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &) { const DType *src = static_cast(src_); diff --git a/rabit/include/rabit/rabit.h b/rabit/include/rabit/rabit.h index 8284a4b6b..10ea9a47f 100644 --- a/rabit/include/rabit/rabit.h +++ b/rabit/include/rabit/rabit.h @@ -50,11 +50,21 @@ struct Min; * \brief sum reduction operator */ struct Sum; +/*! + * \class rabit::op::BitAND + * \brief bitwise AND reduction operator + */ +struct BitAND; /*! * \class rabit::op::BitOR * \brief bitwise OR reduction operator */ struct BitOR; +/*! + * \class rabit::op::BitXOR + * \brief bitwise XOR reduction operator + */ +struct BitXOR; } // namespace op /*! * \brief initializes rabit, call this once at the beginning of your program diff --git a/rabit/src/rabit_c_api.cc b/rabit/src/rabit_c_api.cc index 63abcf83f..c90fae830 100644 --- a/rabit/src/rabit_c_api.cc +++ b/rabit/src/rabit_c_api.cc @@ -23,6 +23,17 @@ struct FHelper { } }; +template +struct FHelper { + static void + Allreduce(DType *, + size_t , + void (*)(void *arg), + void *) { + utils::Error("DataType does not support bitwise AND operation"); + } +}; + template struct FHelper { static void @@ -30,7 +41,18 @@ struct FHelper { size_t , void (*)(void *arg), void *) { - utils::Error("DataType does not support bitwise or operation"); + utils::Error("DataType does not support bitwise OR operation"); + } +}; + +template +struct FHelper { + static void + Allreduce(DType *, + size_t , + void (*)(void *arg), + void *) { + utils::Error("DataType does not support bitwise XOR operation"); } }; @@ -111,12 +133,24 @@ void Allreduce(void *sendrecvbuf, count, enum_dtype, prepare_fun, prepare_arg); return; + case kBitwiseAND: + Allreduce + (sendrecvbuf, + count, enum_dtype, + prepare_fun, prepare_arg); + return; case kBitwiseOR: Allreduce (sendrecvbuf, count, enum_dtype, prepare_fun, prepare_arg); return; + case kBitwiseXOR: + Allreduce + (sendrecvbuf, + count, enum_dtype, + prepare_fun, prepare_arg); + return; default: utils::Error("unknown enum_op"); } } diff --git a/src/collective/communicator.h b/src/collective/communicator.h index 65da9320f..9ce637b94 100644 --- a/src/collective/communicator.h +++ b/src/collective/communicator.h @@ -58,7 +58,14 @@ inline std::size_t GetTypeSize(DataType data_type) { } /** @brief Defines the reduction operation. */ -enum class Operation { kMax = 0, kMin = 1, kSum = 2 }; +enum class Operation { + kMax = 0, + kMin = 1, + kSum = 2, + kBitwiseAND = 3, + kBitwiseOR = 4, + kBitwiseXOR = 5 +}; class DeviceCommunicator; diff --git a/src/collective/in_memory_handler.cc b/src/collective/in_memory_handler.cc index 790024402..09518fd96 100644 --- a/src/collective/in_memory_handler.cc +++ b/src/collective/in_memory_handler.cc @@ -30,6 +30,29 @@ class AllreduceFunctor { } private: + template ::value>* = nullptr> + void AccumulateBitwise(T* buffer, T const* input, std::size_t size, + Operation reduce_operation) const { + switch (reduce_operation) { + case Operation::kBitwiseAND: + std::transform(buffer, buffer + size, input, buffer, std::bit_and()); + break; + case Operation::kBitwiseOR: + std::transform(buffer, buffer + size, input, buffer, std::bit_or()); + break; + case Operation::kBitwiseXOR: + std::transform(buffer, buffer + size, input, buffer, std::bit_xor()); + break; + default: + throw std::invalid_argument("Invalid reduce operation"); + } + } + + template ::value>* = nullptr> + void AccumulateBitwise(T*, T const*, std::size_t, Operation) const { + LOG(FATAL) << "Floating point types do not support bitwise operations."; + } + template void Accumulate(T* buffer, T const* input, std::size_t size, Operation reduce_operation) const { switch (reduce_operation) { @@ -44,6 +67,11 @@ class AllreduceFunctor { case Operation::kSum: std::transform(buffer, buffer + size, input, buffer, std::plus()); break; + case Operation::kBitwiseAND: + case Operation::kBitwiseOR: + case Operation::kBitwiseXOR: + AccumulateBitwise(buffer, input, size, reduce_operation); + break; default: throw std::invalid_argument("Invalid reduce operation"); } diff --git a/src/collective/rabit_communicator.h b/src/collective/rabit_communicator.h index d17cabc01..712b76eff 100644 --- a/src/collective/rabit_communicator.h +++ b/src/collective/rabit_communicator.h @@ -96,11 +96,33 @@ class RabitCommunicator : public Communicator { void Print(const std::string &message) override { rabit::TrackerPrint(message); } protected: - void Shutdown() override { - rabit::Finalize(); - } + void Shutdown() override { rabit::Finalize(); } private: + template ::value> * = nullptr> + void DoBitwiseAllReduce(void *send_receive_buffer, std::size_t count, Operation op) { + switch (op) { + case Operation::kBitwiseAND: + rabit::Allreduce(static_cast(send_receive_buffer), + count); + break; + case Operation::kBitwiseOR: + rabit::Allreduce(static_cast(send_receive_buffer), count); + break; + case Operation::kBitwiseXOR: + rabit::Allreduce(static_cast(send_receive_buffer), + count); + break; + default: + LOG(FATAL) << "Unknown allreduce operation"; + } + } + + template ::value> * = nullptr> + void DoBitwiseAllReduce(void *send_receive_buffer, std::size_t count, Operation op) { + LOG(FATAL) << "Floating point types do not support bitwise operations."; + } + template void DoAllReduce(void *send_receive_buffer, std::size_t count, Operation op) { switch (op) { @@ -113,6 +135,11 @@ class RabitCommunicator : public Communicator { case Operation::kSum: rabit::Allreduce(static_cast(send_receive_buffer), count); break; + case Operation::kBitwiseAND: + case Operation::kBitwiseOR: + case Operation::kBitwiseXOR: + DoBitwiseAllReduce(send_receive_buffer, count, op); + break; default: LOG(FATAL) << "Unknown allreduce operation"; } diff --git a/tests/cpp/collective/test_in_memory_communicator.cc b/tests/cpp/collective/test_in_memory_communicator.cc index ef70e292e..1e4f6521f 100644 --- a/tests/cpp/collective/test_in_memory_communicator.cc +++ b/tests/cpp/collective/test_in_memory_communicator.cc @@ -4,6 +4,7 @@ #include #include +#include #include #include "../../../src/collective/in_memory_communicator.h" @@ -13,7 +14,37 @@ namespace collective { class InMemoryCommunicatorTest : public ::testing::Test { public: - static void VerifyAllreduce(int rank) { + static void Verify(void (*function)(int)) { + std::vector threads; + for (auto rank = 0; rank < kWorldSize; rank++) { + threads.emplace_back(function, rank); + } + for (auto &thread : threads) { + thread.join(); + } + } + + static void AllreduceMax(int rank) { + InMemoryCommunicator comm{kWorldSize, rank}; + int buffer[] = {1 + rank, 2 + rank, 3 + rank, 4 + rank, 5 + rank}; + comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kMax); + int expected[] = {3, 4, 5, 6, 7}; + for (auto i = 0; i < 5; i++) { + EXPECT_EQ(buffer[i], expected[i]); + } + } + + static void AllreduceMin(int rank) { + InMemoryCommunicator comm{kWorldSize, rank}; + int buffer[] = {1 + rank, 2 + rank, 3 + rank, 4 + rank, 5 + rank}; + comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kMin); + int expected[] = {1, 2, 3, 4, 5}; + for (auto i = 0; i < 5; i++) { + EXPECT_EQ(buffer[i], expected[i]); + } + } + + static void AllreduceSum(int rank) { InMemoryCommunicator comm{kWorldSize, rank}; int buffer[] = {1, 2, 3, 4, 5}; comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kSum); @@ -23,7 +54,35 @@ class InMemoryCommunicatorTest : public ::testing::Test { } } - static void VerifyBroadcast(int rank) { + static void AllreduceBitwiseAND(int rank) { + InMemoryCommunicator comm{kWorldSize, rank}; + std::bitset<2> original(rank); + auto buffer = original.to_ulong(); + comm.AllReduce(&buffer, 1, DataType::kUInt32, Operation::kBitwiseAND); + EXPECT_EQ(buffer, 0UL); + } + + static void AllreduceBitwiseOR(int rank) { + InMemoryCommunicator comm{kWorldSize, rank}; + std::bitset<2> original(rank); + auto buffer = original.to_ulong(); + comm.AllReduce(&buffer, 1, DataType::kUInt32, Operation::kBitwiseOR); + std::bitset<2> actual(buffer); + std::bitset<2> expected{0b11}; + EXPECT_EQ(actual, expected); + } + + static void AllreduceBitwiseXOR(int rank) { + InMemoryCommunicator comm{kWorldSize, rank}; + std::bitset<3> original(rank * 2); + auto buffer = original.to_ulong(); + comm.AllReduce(&buffer, 1, DataType::kUInt32, Operation::kBitwiseXOR); + std::bitset<3> actual(buffer); + std::bitset<3> expected{0b110}; + EXPECT_EQ(actual, expected); + } + + static void Broadcast(int rank) { InMemoryCommunicator comm{kWorldSize, rank}; if (rank == 0) { std::string buffer{"hello"}; @@ -88,25 +147,19 @@ TEST(InMemoryCommunicatorSimpleTest, IsDistributed) { EXPECT_TRUE(comm.IsDistributed()); } -TEST_F(InMemoryCommunicatorTest, Allreduce) { - std::vector threads; - for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(std::thread(&InMemoryCommunicatorTest::VerifyAllreduce, rank)); - } - for (auto &thread : threads) { - thread.join(); - } -} +TEST_F(InMemoryCommunicatorTest, AllreduceMax) { Verify(&AllreduceMax); } -TEST_F(InMemoryCommunicatorTest, Broadcast) { - std::vector threads; - for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(std::thread(&InMemoryCommunicatorTest::VerifyBroadcast, rank)); - } - for (auto &thread : threads) { - thread.join(); - } -} +TEST_F(InMemoryCommunicatorTest, AllreduceMin) { Verify(&AllreduceMin); } + +TEST_F(InMemoryCommunicatorTest, AllreduceSum) { Verify(&AllreduceSum); } + +TEST_F(InMemoryCommunicatorTest, AllreduceBitwiseAND) { Verify(&AllreduceBitwiseAND); } + +TEST_F(InMemoryCommunicatorTest, AllreduceBitwiseOR) { Verify(&AllreduceBitwiseOR); } + +TEST_F(InMemoryCommunicatorTest, AllreduceBitwiseXOR) { Verify(&AllreduceBitwiseXOR); } + +TEST_F(InMemoryCommunicatorTest, Broadcast) { Verify(&Broadcast); } } // namespace collective } // namespace xgboost