[Breaking] Switch from rabit to the collective communicator (#8257)
* Switch from rabit to the collective communicator * fix size_t specialization * really fix size_t * try again * add include * more include * fix lint errors * remove rabit includes * fix pylint error * return dict from communicator context * fix communicator shutdown * fix dask test * reset communicator mocklist * fix distributed tests * do not save device communicator * fix jvm gpu tests * add python test for federated communicator * Update gputreeshap submodule Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -1,11 +1,8 @@
|
||||
// Copyright (c) 2014-2022 by Contributors
|
||||
#include <rabit/rabit.h>
|
||||
#include <rabit/c_api.h>
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
@@ -22,12 +19,11 @@
|
||||
|
||||
#include "c_api_error.h"
|
||||
#include "c_api_utils.h"
|
||||
#include "../collective/communicator.h"
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/io.h"
|
||||
#include "../common/charconv.h"
|
||||
#include "../data/adapter.h"
|
||||
#include "../data/simple_dmatrix.h"
|
||||
#include "../data/proxy_dmatrix.h"
|
||||
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
#include "../../plugin/federated/federated_server.h"
|
||||
@@ -215,7 +211,7 @@ XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers";
|
||||
#else
|
||||
if (rabit::IsDistributed()) {
|
||||
if (collective::IsDistributed()) {
|
||||
LOG(CONSOLE) << "XGBoost distributed mode detected, "
|
||||
<< "will split data among workers";
|
||||
load_row_split = true;
|
||||
@@ -1560,44 +1556,42 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *config,
|
||||
API_END();
|
||||
}
|
||||
|
||||
using xgboost::collective::Communicator;
|
||||
|
||||
XGB_DLL int XGCommunicatorInit(char const* json_config) {
|
||||
API_BEGIN();
|
||||
xgboost_CHECK_C_ARG_PTR(json_config);
|
||||
Json config { Json::Load(StringView{json_config}) };
|
||||
Communicator::Init(config);
|
||||
Json config{Json::Load(StringView{json_config})};
|
||||
collective::Init(config);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGCommunicatorFinalize() {
|
||||
API_BEGIN();
|
||||
Communicator::Finalize();
|
||||
collective::Finalize();
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGCommunicatorGetRank() {
|
||||
return Communicator::Get()->GetRank();
|
||||
XGB_DLL int XGCommunicatorGetRank(void) {
|
||||
return collective::GetRank();
|
||||
}
|
||||
|
||||
XGB_DLL int XGCommunicatorGetWorldSize() {
|
||||
return Communicator::Get()->GetWorldSize();
|
||||
XGB_DLL int XGCommunicatorGetWorldSize(void) {
|
||||
return collective::GetWorldSize();
|
||||
}
|
||||
|
||||
XGB_DLL int XGCommunicatorIsDistributed() {
|
||||
return Communicator::Get()->IsDistributed();
|
||||
XGB_DLL int XGCommunicatorIsDistributed(void) {
|
||||
return collective::IsDistributed();
|
||||
}
|
||||
|
||||
XGB_DLL int XGCommunicatorPrint(char const *message) {
|
||||
API_BEGIN();
|
||||
Communicator::Get()->Print(message);
|
||||
collective::Print(message);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) {
|
||||
API_BEGIN();
|
||||
auto& local = *GlobalConfigAPIThreadLocalStore::Get();
|
||||
local.ret_str = Communicator::Get()->GetProcessorName();
|
||||
local.ret_str = collective::GetProcessorName();
|
||||
xgboost_CHECK_C_ARG_PTR(name_str);
|
||||
*name_str = local.ret_str.c_str();
|
||||
API_END();
|
||||
@@ -1605,16 +1599,14 @@ XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) {
|
||||
|
||||
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root) {
|
||||
API_BEGIN();
|
||||
Communicator::Get()->Broadcast(send_receive_buffer, size, root);
|
||||
collective::Broadcast(send_receive_buffer, size, root);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int enum_dtype,
|
||||
int enum_op) {
|
||||
API_BEGIN();
|
||||
Communicator::Get()->AllReduce(
|
||||
send_receive_buffer, count, static_cast<xgboost::collective::DataType>(enum_dtype),
|
||||
static_cast<xgboost::collective::Operation>(enum_op));
|
||||
collective::Allreduce(send_receive_buffer, count, enum_dtype, enum_op);
|
||||
API_END();
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user