Common interface for collective communication (#8057)

* implement broadcast for federated communicator

* implement allreduce

* add communicator factory

* add device adapter

* add device communicator to factory

* add rabit communicator

* add rabit communicator to the factory

* add nccl device communicator

* add synchronize to device communicator

* add back print and getprocessorname

* add python wrapper and c api

* clean up types

* fix non-gpu build

* try to fix ci

* fix std::size_t

* portable string compare ignore case

* c style size_t

* fix lint errors

* cross platform setenv

* fix memory leak

* fix lint errors

* address review feedback

* add python test for rabit communicator

* fix failing gtest

* use json to configure communicators

* fix lint error

* get rid of factories

* fix cpu build

* fix include

* fix python import

* don't export collective.py yet

* skip collective communicator pytest on windows

* add review feedback

* update documentation

* remove mpi communicator type

* fix tests

* shutdown the communicator separately

Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Rong Ou
2022-09-12 15:21:12 -07:00
committed by GitHub
parent bc818316f2
commit a2686543a9
25 changed files with 1771 additions and 95 deletions

View File

@@ -0,0 +1,120 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <rabit/rabit.h>
#include <string>
#include <vector>
#include "communicator.h"
#include "xgboost/json.h"
namespace xgboost {
namespace collective {
class RabitCommunicator : public Communicator {
public:
static Communicator *Create(Json const &config) {
std::vector<std::string> args_str;
for (auto &items : get<Object const>(config)) {
switch (items.second.GetValue().Type()) {
case xgboost::Value::ValueKind::kString: {
args_str.push_back(items.first + "=" + get<String const>(items.second));
break;
}
case xgboost::Value::ValueKind::kInteger: {
args_str.push_back(items.first + "=" + std::to_string(get<Integer const>(items.second)));
break;
}
case xgboost::Value::ValueKind::kBoolean: {
if (get<Boolean const>(items.second)) {
args_str.push_back(items.first + "=1");
} else {
args_str.push_back(items.first + "=0");
}
break;
}
default:
break;
}
}
std::vector<char *> args;
for (auto &key_value : args_str) {
args.push_back(&key_value[0]);
}
if (!rabit::Init(static_cast<int>(args.size()), &args[0])) {
LOG(FATAL) << "Failed to initialize Rabit";
}
return new RabitCommunicator(rabit::GetWorldSize(), rabit::GetRank());
}
RabitCommunicator(int world_size, int rank) : Communicator(world_size, rank) {}
bool IsDistributed() const override { return rabit::IsDistributed(); }
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override {
switch (data_type) {
case DataType::kInt8:
DoAllReduce<char>(send_receive_buffer, count, op);
break;
case DataType::kUInt8:
DoAllReduce<unsigned char>(send_receive_buffer, count, op);
break;
case DataType::kInt32:
DoAllReduce<std::int32_t>(send_receive_buffer, count, op);
break;
case DataType::kUInt32:
DoAllReduce<std::uint32_t>(send_receive_buffer, count, op);
break;
case DataType::kInt64:
DoAllReduce<std::int64_t>(send_receive_buffer, count, op);
break;
case DataType::kUInt64:
DoAllReduce<std::uint64_t>(send_receive_buffer, count, op);
break;
case DataType::kFloat:
DoAllReduce<float>(send_receive_buffer, count, op);
break;
case DataType::kDouble:
DoAllReduce<double>(send_receive_buffer, count, op);
break;
default:
LOG(FATAL) << "Unknown data type";
}
}
void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {
rabit::Broadcast(send_receive_buffer, size, root);
}
std::string GetProcessorName() override { return rabit::GetProcessorName(); }
void Print(const std::string &message) override { rabit::TrackerPrint(message); }
protected:
void Shutdown() override {
rabit::Finalize();
}
private:
template <typename DType>
void DoAllReduce(void *send_receive_buffer, std::size_t count, Operation op) {
switch (op) {
case Operation::kMax:
rabit::Allreduce<rabit::op::Max, DType>(static_cast<DType *>(send_receive_buffer), count);
break;
case Operation::kMin:
rabit::Allreduce<rabit::op::Min, DType>(static_cast<DType *>(send_receive_buffer), count);
break;
case Operation::kSum:
rabit::Allreduce<rabit::op::Sum, DType>(static_cast<DType *>(send_receive_buffer), count);
break;
default:
LOG(FATAL) << "Unknown allreduce operation";
}
}
};
} // namespace collective
} // namespace xgboost