[Breaking] Switch from rabit to the collective communicator (#8257)

* Switch from rabit to the collective communicator

* fix size_t specialization

* really fix size_t

* try again

* add include

* more include

* fix lint errors

* remove rabit includes

* fix pylint error

* return dict from communicator context

* fix communicator shutdown

* fix dask test

* reset communicator mocklist

* fix distributed tests

* do not save device communicator

* fix jvm gpu tests

* add python test for federated communicator

* Update gputreeshap submodule

Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Rong Ou
2022-10-05 15:39:01 -07:00
committed by GitHub
parent e47b3a3da3
commit 668b8a0ea4
79 changed files with 805 additions and 2212 deletions

View File

@@ -1,11 +1,8 @@
// Copyright (c) 2014-2022 by Contributors
#include <rabit/rabit.h>
#include <rabit/c_api.h>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <algorithm>
#include <vector>
#include <string>
#include <memory>
@@ -22,12 +19,11 @@
#include "c_api_error.h"
#include "c_api_utils.h"
#include "../collective/communicator.h"
#include "../collective/communicator-inl.h"
#include "../common/io.h"
#include "../common/charconv.h"
#include "../data/adapter.h"
#include "../data/simple_dmatrix.h"
#include "../data/proxy_dmatrix.h"
#if defined(XGBOOST_USE_FEDERATED)
#include "../../plugin/federated/federated_server.h"
@@ -215,7 +211,7 @@ XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle
#if defined(XGBOOST_USE_FEDERATED)
LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers";
#else
if (rabit::IsDistributed()) {
if (collective::IsDistributed()) {
LOG(CONSOLE) << "XGBoost distributed mode detected, "
<< "will split data among workers";
load_row_split = true;
@@ -1560,44 +1556,42 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *config,
API_END();
}
using xgboost::collective::Communicator;
// C API entry point that initialises the communicator from a JSON string.
//
// json_config: text handed to Json::Load (via StringView) and then passed on
//              as the parsed configuration object.
//
// NOTE(review): this span comes from a diff view whose +/- markers were lost,
// so two generations of the body are interleaved: the `Json config { ... }` /
// `Communicator::Init` pair and the `Json config{...}` / `collective::Init`
// pair. Going by the hunk ordering the first pair looks like the removed one -
// confirm against the checked-in file, since both together redeclare `config`.
XGB_DLL int XGCommunicatorInit(char const* json_config) {
API_BEGIN();
// Pointer check on the argument happens before it is parsed
// (presumably rejects nullptr - verify in the macro definition).
xgboost_CHECK_C_ARG_PTR(json_config);
Json config { Json::Load(StringView{json_config}) };
Communicator::Init(config);
Json config{Json::Load(StringView{json_config})};
collective::Init(config);
API_END();
}
// C API entry point that shuts the communicator down. Takes no arguments and
// is wrapped in API_BEGIN/API_END like the other status-returning entry points
// in this hunk.
//
// NOTE(review): `Communicator::Finalize()` and `collective::Finalize()` are
// the before/after lines of a diff whose +/- markers were stripped; only one
// of them is expected to be present in the real source - confirm which.
XGB_DLL int XGCommunicatorFinalize() {
API_BEGIN();
Communicator::Finalize();
collective::Finalize();
API_END();
}
// C API entry point returning the rank reported by the communicator.
// Unlike the API_BEGIN/API_END wrapped functions nearby, the return value here
// is the rank itself, not a status code.
//
// NOTE(review): two versions are interleaved below because the diff markers
// were lost - the `()` signature with `Communicator::Get()->GetRank()` and the
// `(void)` signature with `collective::GetRank()`. As written the first
// signature's brace is never closed; keep only the version that matches the
// checked-in file.
XGB_DLL int XGCommunicatorGetRank() {
return Communicator::Get()->GetRank();
XGB_DLL int XGCommunicatorGetRank(void) {
return collective::GetRank();
}
// C API entry point returning the world size reported by the communicator.
// The value is returned directly; there is no API_BEGIN/API_END wrapper, so
// the result is not a status code.
//
// NOTE(review): the `()` / `Communicator::Get()->GetWorldSize()` lines and the
// `(void)` / `collective::GetWorldSize()` lines are the two sides of a diff
// with its +/- markers stripped - confirm which pair is current.
XGB_DLL int XGCommunicatorGetWorldSize() {
return Communicator::Get()->GetWorldSize();
XGB_DLL int XGCommunicatorGetWorldSize(void) {
return collective::GetWorldSize();
}
// C API entry point reporting whether the communicator is distributed.
// The result of the IsDistributed() call is returned as an int, directly and
// without the API_BEGIN/API_END wrapper.
//
// NOTE(review): both the `()` / `Communicator::Get()->IsDistributed()` lines
// and the `(void)` / `collective::IsDistributed()` lines appear because the
// diff markers were lost - confirm which pair belongs to the real file.
XGB_DLL int XGCommunicatorIsDistributed() {
return Communicator::Get()->IsDistributed();
XGB_DLL int XGCommunicatorIsDistributed(void) {
return collective::IsDistributed();
}
// C API entry point that forwards `message` to the communicator's Print.
//
// message: C string passed through unchanged.
//
// NOTE(review): no xgboost_CHECK_C_ARG_PTR is applied to `message` here,
// whereas XGCommunicatorInit checks its pointer argument - confirm whether a
// null message is meant to be tolerated.
// NOTE(review): `Communicator::Get()->Print` and `collective::Print` are the
// before/after lines of a diff with stripped +/- markers; as shown the message
// would be printed twice - keep only the line present in the real source.
XGB_DLL int XGCommunicatorPrint(char const *message) {
API_BEGIN();
Communicator::Get()->Print(message);
collective::Print(message);
API_END();
}
XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) {
API_BEGIN();
auto& local = *GlobalConfigAPIThreadLocalStore::Get();
local.ret_str = Communicator::Get()->GetProcessorName();
local.ret_str = collective::GetProcessorName();
xgboost_CHECK_C_ARG_PTR(name_str);
*name_str = local.ret_str.c_str();
API_END();
@@ -1605,16 +1599,14 @@ XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) {
// C API entry point that forwards a broadcast request to the communicator.
//
// send_receive_buffer: buffer handed to Broadcast as-is (the name suggests it
//                      is both source and destination - confirm in the
//                      collective header).
// size:                passed through unchanged; presumably a byte count -
//                      TODO confirm.
// root:                passed through unchanged; presumably the sending rank -
//                      TODO confirm.
//
// NOTE(review): the `Communicator::Get()->Broadcast` and
// `collective::Broadcast` calls are the two sides of a diff with lost +/-
// markers; only one should remain in the real file.
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root) {
API_BEGIN();
Communicator::Get()->Broadcast(send_receive_buffer, size, root);
collective::Broadcast(send_receive_buffer, size, root);
API_END();
}
// C API entry point that forwards an all-reduce request to the communicator.
//
// send_receive_buffer: buffer handed to the reduce call as-is.
// count:               passed through unchanged; presumably an element count -
//                      TODO confirm.
// enum_dtype, enum_op: integer codes selecting the data type and reduction
//                      operation.
//
// NOTE(review): two call forms are interleaved because the diff markers were
// stripped. The `Communicator::Get()->AllReduce(...)` form static_casts the
// two ints to xgboost::collective::DataType / Operation; the
// `collective::Allreduce(...)` form passes the raw ints and leaves any
// conversion to the callee. Keep only the form present in the real source.
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int enum_dtype,
int enum_op) {
API_BEGIN();
Communicator::Get()->AllReduce(
send_receive_buffer, count, static_cast<xgboost::collective::DataType>(enum_dtype),
static_cast<xgboost::collective::Operation>(enum_op));
collective::Allreduce(send_receive_buffer, count, enum_dtype, enum_op);
API_END();
}