Common interface for collective communication (#8057)

* implement broadcast for federated communicator

* implement allreduce

* add communicator factory

* add device adapter

* add device communicator to factory

* add rabit communicator

* add rabit communicator to the factory

* add nccl device communicator

* add synchronize to device communicator

* add back print and getprocessorname

* add python wrapper and c api

* clean up types

* fix non-gpu build

* try to fix ci

* fix std::size_t

* portable string compare ignore case

* c style size_t

* fix lint errors

* cross platform setenv

* fix memory leak

* fix lint errors

* address review feedback

* add python test for rabit communicator

* fix failing gtest

* use json to configure communicators

* fix lint error

* get rid of factories

* fix cpu build

* fix include

* fix python import

* don't export collective.py yet

* skip collective communicator pytest on windows

* add review feedback

* update documentation

* remove mpi communicator type

* fix tests

* shutdown the communicator separately

Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Rong Ou
2022-09-12 15:21:12 -07:00
committed by GitHub
parent bc818316f2
commit a2686543a9
25 changed files with 1771 additions and 95 deletions

View File

@@ -22,6 +22,7 @@
#include "c_api_error.h"
#include "c_api_utils.h"
#include "../collective/communicator.h"
#include "../common/io.h"
#include "../common/charconv.h"
#include "../data/adapter.h"
@@ -1370,6 +1371,62 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *json_config,
API_END();
}
using xgboost::collective::Communicator;
// Initialise the collective communicator from a JSON configuration string.
//
// json_config: C string holding the JSON document; it is parsed with
//              Json::Load and the resulting object is handed to
//              Communicator::Init.
// The return value is produced by the API_BEGIN/API_END macro pair
// (NOTE(review): presumably 0 on success and non-zero on a caught error --
// confirm against the macro definitions).
XGB_DLL int XGCommunicatorInit(char const* json_config) {
  API_BEGIN();
  StringView const config_text{json_config};
  auto parsed_config = Json::Load(config_text);
  xgboost::collective::Communicator::Init(parsed_config);
  API_END();
}
// Finalise the collective communicator by forwarding to
// Communicator::Finalize. Takes no arguments; the return value comes from the
// API_BEGIN/API_END macro pair (NOTE(review): presumably a status code --
// confirm against the macro definitions).
XGB_DLL int XGCommunicatorFinalize(void) {
  API_BEGIN();
  xgboost::collective::Communicator::Finalize();
  API_END();
}
// Return the rank reported by the current communicator.
//
// Unlike the neighbouring entry points, this function is not wrapped in
// API_BEGIN/API_END: its return value is the rank itself, not a status code.
XGB_DLL int XGCommunicatorGetRank(void) {
  int const rank = xgboost::collective::Communicator::Get()->GetRank();
  return rank;
}
// Return the world size reported by the current communicator.
//
// Not wrapped in API_BEGIN/API_END: the return value is the world size
// itself, not a status code.
XGB_DLL int XGCommunicatorGetWorldSize(void) {
  int const world_size = xgboost::collective::Communicator::Get()->GetWorldSize();
  return world_size;
}
// Return the communicator's IsDistributed() result, implicitly converted to
// int.
//
// Not wrapped in API_BEGIN/API_END: the return value is the flag itself, not
// a status code.
XGB_DLL int XGCommunicatorIsDistributed(void) {
  int const distributed = xgboost::collective::Communicator::Get()->IsDistributed();
  return distributed;
}
// Forward a message to the current communicator's Print method.
//
// message: C string passed through unchanged.
// The return value is produced by the API_BEGIN/API_END macro pair
// (NOTE(review): presumably a status code -- confirm against the macros).
XGB_DLL int XGCommunicatorPrint(char const *message) {
  API_BEGIN();
  auto &&comm = Communicator::Get();
  comm->Print(message);
  API_END();
}
// Fetch the processor name from the current communicator.
//
// name_str: output; receives a pointer to the characters of the `ret_str`
//           member of the object returned by
//           GlobalConfigAPIThreadLocalStore::Get(). The caller does not own
//           this memory. NOTE(review): the store looks thread-local, so the
//           pointer is presumably valid only until `ret_str` is next
//           overwritten on this thread -- confirm.
// The return value is produced by the API_BEGIN/API_END macro pair.
XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) {
  API_BEGIN();
  auto &entry = *GlobalConfigAPIThreadLocalStore::Get();
  auto &stored_name = entry.ret_str;
  // Copy into the store first so the returned pointer refers to storage that
  // survives past the end of this call.
  stored_name = Communicator::Get()->GetProcessorName();
  *name_str = stored_name.c_str();
  API_END();
}
// Forward a broadcast request to the current communicator.
//
// send_receive_buffer: buffer passed through to Communicator::Broadcast
//                      (the name indicates it is used for both sending and
//                      receiving).
// size:                passed through unchanged (NOTE(review): presumably
//                      the buffer size in bytes -- confirm).
// root:                passed through unchanged (NOTE(review): presumably
//                      the rank that supplies the data -- confirm).
// The return value is produced by the API_BEGIN/API_END macro pair.
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root) {
  API_BEGIN();
  auto &&comm = Communicator::Get();
  comm->Broadcast(send_receive_buffer, size, root);
  API_END();
}
// Forward an all-reduce request to the current communicator.
//
// send_receive_buffer: buffer passed through to Communicator::AllReduce.
// count:               passed through unchanged (NOTE(review): presumably
//                      the number of elements -- confirm).
// enum_dtype:          integer cast to xgboost::collective::DataType.
// enum_op:             integer cast to xgboost::collective::Operation.
// No range validation is performed on the two enum integers here.
// The return value is produced by the API_BEGIN/API_END macro pair.
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int enum_dtype,
                                    int enum_op) {
  API_BEGIN();
  auto const data_type = static_cast<xgboost::collective::DataType>(enum_dtype);
  auto const reduce_op = static_cast<xgboost::collective::Operation>(enum_op);
  Communicator::Get()->AllReduce(send_receive_buffer, count, data_type, reduce_op);
  API_END();
}
#if defined(XGBOOST_USE_FEDERATED)
XGB_DLL int XGBRunFederatedServer(int port, int world_size, char const *server_key_path,
char const *server_cert_path, char const *client_cert_path) {