/** * Copyright 2022-2023 by XGBoost contributors */ #pragma once #include #include #include #include "communicator-inl.h" #include "communicator.h" #include "xgboost/json.h" namespace xgboost { namespace collective { class RabitCommunicator : public Communicator { public: static Communicator *Create(Json const &config) { std::vector args_str; for (auto &items : get(config)) { switch (items.second.GetValue().Type()) { case xgboost::Value::ValueKind::kString: { args_str.push_back(items.first + "=" + get(items.second)); break; } case xgboost::Value::ValueKind::kInteger: { args_str.push_back(items.first + "=" + std::to_string(get(items.second))); break; } case xgboost::Value::ValueKind::kBoolean: { if (get(items.second)) { args_str.push_back(items.first + "=1"); } else { args_str.push_back(items.first + "=0"); } break; } default: break; } } std::vector args; for (auto &key_value : args_str) { args.push_back(&key_value[0]); } if (!rabit::Init(static_cast(args.size()), &args[0])) { LOG(FATAL) << "Failed to initialize Rabit"; } return new RabitCommunicator(rabit::GetWorldSize(), rabit::GetRank()); } RabitCommunicator(int world_size, int rank) : Communicator(world_size, rank) {} bool IsDistributed() const override { return rabit::IsDistributed(); } bool IsFederated() const override { return false; } std::string AllGather(std::string_view input) override { auto const per_rank = input.size(); auto const total_size = per_rank * GetWorldSize(); auto const index = per_rank * GetRank(); std::string result(total_size, '\0'); result.replace(index, per_rank, input); rabit::Allgather(result.data(), total_size, index, per_rank, per_rank); return result; } std::string AllGatherV(std::string_view input) override { auto const size_node_slice = input.size(); auto const all_sizes = collective::Allgather(size_node_slice); auto const total_size = std::accumulate(all_sizes.cbegin(), all_sizes.cend(), 0ul); auto const begin_index = std::accumulate(all_sizes.cbegin(), all_sizes.cbegin() + GetRank(), 0ul); auto const size_prev_slice = GetRank() == 0 ? all_sizes[GetWorldSize() - 1] : all_sizes[GetRank() - 1]; std::string result(total_size, '\0'); result.replace(begin_index, size_node_slice, input); rabit::Allgather(result.data(), total_size, begin_index, size_node_slice, size_prev_slice); return result; } void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, Operation op) override { switch (data_type) { case DataType::kInt8: DoAllReduce(send_receive_buffer, count, op); break; case DataType::kUInt8: DoAllReduce(send_receive_buffer, count, op); break; case DataType::kInt32: DoAllReduce(send_receive_buffer, count, op); break; case DataType::kUInt32: DoAllReduce(send_receive_buffer, count, op); break; case DataType::kInt64: DoAllReduce(send_receive_buffer, count, op); break; case DataType::kUInt64: DoAllReduce(send_receive_buffer, count, op); break; case DataType::kFloat: DoAllReduce(send_receive_buffer, count, op); break; case DataType::kDouble: DoAllReduce(send_receive_buffer, count, op); break; default: LOG(FATAL) << "Unknown data type"; } } void Broadcast(void *send_receive_buffer, std::size_t size, int root) override { rabit::Broadcast(send_receive_buffer, size, root); } std::string GetProcessorName() override { return rabit::GetProcessorName(); } void Print(const std::string &message) override { rabit::TrackerPrint(message); } protected: void Shutdown() override { rabit::Finalize(); } private: template ::value> * = nullptr> void DoBitwiseAllReduce(void *send_receive_buffer, std::size_t count, Operation op) { switch (op) { case Operation::kBitwiseAND: rabit::Allreduce(static_cast(send_receive_buffer), count); break; case Operation::kBitwiseOR: rabit::Allreduce(static_cast(send_receive_buffer), count); break; case Operation::kBitwiseXOR: rabit::Allreduce(static_cast(send_receive_buffer), count); break; default: LOG(FATAL) << "Unknown allreduce operation"; } } template ::value> * = nullptr> void DoBitwiseAllReduce(void *, std::size_t, Operation) { LOG(FATAL) << "Floating point types do not support bitwise operations."; } template void DoAllReduce(void *send_receive_buffer, std::size_t count, Operation op) { switch (op) { case Operation::kMax: rabit::Allreduce(static_cast(send_receive_buffer), count); break; case Operation::kMin: rabit::Allreduce(static_cast(send_receive_buffer), count); break; case Operation::kSum: rabit::Allreduce(static_cast(send_receive_buffer), count); break; case Operation::kBitwiseAND: case Operation::kBitwiseOR: case Operation::kBitwiseXOR: DoBitwiseAllReduce(send_receive_buffer, count, op); break; default: LOG(FATAL) << "Unknown allreduce operation"; } } }; } // namespace collective } // namespace xgboost