/*! * Copyright 2022 XGBoost contributors */ #pragma once #include #include #include #include "communicator.h" #include "xgboost/json.h" namespace xgboost { namespace collective { class RabitCommunicator : public Communicator { public: static Communicator *Create(Json const &config) { std::vector args_str; for (auto &items : get(config)) { switch (items.second.GetValue().Type()) { case xgboost::Value::ValueKind::kString: { args_str.push_back(items.first + "=" + get(items.second)); break; } case xgboost::Value::ValueKind::kInteger: { args_str.push_back(items.first + "=" + std::to_string(get(items.second))); break; } case xgboost::Value::ValueKind::kBoolean: { if (get(items.second)) { args_str.push_back(items.first + "=1"); } else { args_str.push_back(items.first + "=0"); } break; } default: break; } } std::vector args; for (auto &key_value : args_str) { args.push_back(&key_value[0]); } if (!rabit::Init(static_cast(args.size()), &args[0])) { LOG(FATAL) << "Failed to initialize Rabit"; } return new RabitCommunicator(rabit::GetWorldSize(), rabit::GetRank()); } RabitCommunicator(int world_size, int rank) : Communicator(world_size, rank) {} bool IsDistributed() const override { return rabit::IsDistributed(); } bool IsFederated() const override { return false; } void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, Operation op) override { switch (data_type) { case DataType::kInt8: DoAllReduce(send_receive_buffer, count, op); break; case DataType::kUInt8: DoAllReduce(send_receive_buffer, count, op); break; case DataType::kInt32: DoAllReduce(send_receive_buffer, count, op); break; case DataType::kUInt32: DoAllReduce(send_receive_buffer, count, op); break; case DataType::kInt64: DoAllReduce(send_receive_buffer, count, op); break; case DataType::kUInt64: DoAllReduce(send_receive_buffer, count, op); break; case DataType::kFloat: DoAllReduce(send_receive_buffer, count, op); break; case DataType::kDouble: DoAllReduce(send_receive_buffer, count, op); break; default: LOG(FATAL) << "Unknown data type"; } } void Broadcast(void *send_receive_buffer, std::size_t size, int root) override { rabit::Broadcast(send_receive_buffer, size, root); } std::string GetProcessorName() override { return rabit::GetProcessorName(); } void Print(const std::string &message) override { rabit::TrackerPrint(message); } protected: void Shutdown() override { rabit::Finalize(); } private: template ::value> * = nullptr> void DoBitwiseAllReduce(void *send_receive_buffer, std::size_t count, Operation op) { switch (op) { case Operation::kBitwiseAND: rabit::Allreduce(static_cast(send_receive_buffer), count); break; case Operation::kBitwiseOR: rabit::Allreduce(static_cast(send_receive_buffer), count); break; case Operation::kBitwiseXOR: rabit::Allreduce(static_cast(send_receive_buffer), count); break; default: LOG(FATAL) << "Unknown allreduce operation"; } } template ::value> * = nullptr> void DoBitwiseAllReduce(void *, std::size_t, Operation) { LOG(FATAL) << "Floating point types do not support bitwise operations."; } template void DoAllReduce(void *send_receive_buffer, std::size_t count, Operation op) { switch (op) { case Operation::kMax: rabit::Allreduce(static_cast(send_receive_buffer), count); break; case Operation::kMin: rabit::Allreduce(static_cast(send_receive_buffer), count); break; case Operation::kSum: rabit::Allreduce(static_cast(send_receive_buffer), count); break; case Operation::kBitwiseAND: case Operation::kBitwiseOR: case Operation::kBitwiseXOR: DoBitwiseAllReduce(send_receive_buffer, count, op); break; default: LOG(FATAL) << "Unknown allreduce operation"; } } }; } // namespace collective } // namespace xgboost