Revamp the rabit implementation. (#10112)

This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features:
- Federated learning for both CPU and GPU.
- NCCL.
- More data types.
- A unified interface for all the underlying implementations.
- Improved timeout handling for both tracker and workers.
- Exhaustive tests with metrics (fixed a couple of bugs along the way).
- A reusable tracker for Python and JVM packages.
This commit is contained in:
Jiaming Yuan
2024-05-20 11:56:23 +08:00
committed by GitHub
parent ba9b4cb1ee
commit a5a58102e5
195 changed files with 2768 additions and 9234 deletions

View File

@@ -15,9 +15,9 @@
#include <utility> // for pair
#include <vector> // for vector
#include "../collective/communicator-inl.h" // for Allreduce, Broadcast, Finalize, GetProcessor...
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
#include "../common/charconv.h" // for from_chars, to_chars, NumericLimits, from_ch...
#include "../common/error_msg.h" // for NoFederated
#include "../common/hist_util.h" // for HistogramCuts
#include "../common/io.h" // for FileExtension, LoadSequentialFile, MemoryBuf...
#include "../common/threading_utils.h" // for OmpGetNumThreads, ParallelFor
@@ -27,11 +27,10 @@
#include "../data/simple_dmatrix.h" // for SimpleDMatrix
#include "c_api_error.h" // for xgboost_CHECK_C_ARG_PTR, API_END, API_BEGIN
#include "c_api_utils.h" // for RequiredArg, OptionalArg, GetMissing, CastDM...
#include "dmlc/base.h" // for BeginPtr, DMLC_ATTRIBUTE_UNUSED
#include "dmlc/base.h" // for BeginPtr
#include "dmlc/io.h" // for Stream
#include "dmlc/parameter.h" // for FieldAccessEntry, FieldEntry, ParamManager
#include "dmlc/thread_local.h" // for ThreadLocalStore
#include "rabit/c_api.h" // for RabitLinkTag
#include "xgboost/base.h" // for bst_ulong, bst_float, GradientPair, bst_feat...
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for DMatrix, MetaInfo, DataType, ExtSparsePage
@@ -46,10 +45,6 @@
#include "xgboost/string_view.h" // for StringView, operator<<
#include "xgboost/version_config.h" // for XGBOOST_VER_MAJOR, XGBOOST_VER_MINOR, XGBOOS...
#if defined(XGBOOST_USE_FEDERATED)
#include "../../plugin/federated/federated_server.h"
#endif
using namespace xgboost; // NOLINT(*);
XGB_DLL void XGBoostVersion(int* major, int* minor, int* patch) {
@@ -1759,76 +1754,3 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *config,
*out_features = dmlc::BeginPtr(feature_names_c);
API_END();
}
/**
 * \brief Initialise the collective communicator from a JSON document.
 *
 * \param json_config JSON text; checked for null, parsed, then passed to collective::Init.
 * \return 0 on success, the value produced by API_END's exception handler otherwise.
 */
XGB_DLL int XGCommunicatorInit(char const* json_config) {
  API_BEGIN();
  xgboost_CHECK_C_ARG_PTR(json_config);
  Json parsed = Json::Load(StringView{json_config});
  collective::Init(parsed);
  API_END();
}
/**
 * \brief Shut the collective communicator down by calling collective::Finalize.
 *
 * \return 0 on success, the value produced by API_END's exception handler otherwise.
 */
XGB_DLL int XGCommunicatorFinalize() {
  API_BEGIN();
  collective::Finalize();
  API_END();
}
/**
 * \brief Report the value of collective::GetRank for the calling worker.
 *
 * Not wrapped in API_BEGIN/API_END, so the return value is the rank itself.
 */
XGB_DLL int XGCommunicatorGetRank(void) {
  auto const rank = collective::GetRank();
  return rank;
}
/**
 * \brief Report the value of collective::GetWorldSize.
 *
 * Not wrapped in API_BEGIN/API_END, so the return value is the world size itself.
 */
XGB_DLL int XGCommunicatorGetWorldSize(void) {
  auto const world_size = collective::GetWorldSize();
  return world_size;
}
/**
 * \brief Report the value of collective::IsDistributed, converted to int.
 *
 * Not wrapped in API_BEGIN/API_END, so the return value is the flag itself.
 */
XGB_DLL int XGCommunicatorIsDistributed(void) {
  auto const distributed = collective::IsDistributed();
  return distributed;
}
/**
 * \brief Forward a message to collective::Print.
 *
 * \param message Null-terminated text to print. A null pointer is rejected through
 *                xgboost_CHECK_C_ARG_PTR instead of being forwarded, matching the
 *                argument validation done by XGCommunicatorInit.
 * \return 0 on success, the value produced by API_END's exception handler otherwise.
 */
XGB_DLL int XGCommunicatorPrint(char const *message) {
  API_BEGIN();
  xgboost_CHECK_C_ARG_PTR(message);
  collective::Print(message);
  API_END();
}
/**
 * \brief Return the string produced by collective::GetProcessorName.
 *
 * The text is stored in the `ret_str` member of the thread-local entry obtained from
 * GlobalConfigAPIThreadLocalStore, and `*name_str` is pointed at that buffer.
 * NOTE(review): the pointer presumably stays valid only until `ret_str` is next
 * overwritten on the same thread - confirm against other users of the entry.
 *
 * \param name_str Output pointer; must not be null.
 * \return 0 on success, the value produced by API_END's exception handler otherwise.
 */
XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) {
  API_BEGIN();
  // Validate the output argument first so an invalid call does no work and leaves the
  // thread-local buffer untouched.
  xgboost_CHECK_C_ARG_PTR(name_str);
  auto &local = *GlobalConfigAPIThreadLocalStore::Get();
  local.ret_str = collective::GetProcessorName();
  *name_str = local.ret_str.c_str();
  API_END();
}
/**
 * \brief C entry point that forwards its arguments unchanged to collective::Broadcast.
 *
 * \param send_receive_buffer Buffer handed to collective::Broadcast.
 * \param size                Size argument handed to collective::Broadcast.
 * \param root                Root argument handed to collective::Broadcast.
 * \return 0 on success, the value produced by API_END's exception handler otherwise.
 */
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root) {
  API_BEGIN();
  collective::Broadcast(send_receive_buffer, size, root);
  API_END();
}
/**
 * \brief C entry point that forwards its arguments unchanged to collective::Allreduce.
 *
 * \param send_receive_buffer Buffer handed to collective::Allreduce.
 * \param count               Count argument handed to collective::Allreduce.
 * \param enum_dtype          Integer data-type code, forwarded as-is.
 * \param enum_op             Integer operation code, forwarded as-is.
 * \return 0 on success, the value produced by API_END's exception handler otherwise.
 */
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int enum_dtype,
                                    int enum_op) {
  API_BEGIN();
  collective::Allreduce(send_receive_buffer, count, enum_dtype, enum_op);
  API_END();
}
#if defined(XGBOOST_USE_FEDERATED)
/**
 * \brief C entry point that forwards its arguments unchanged to federated::RunServer.
 *
 * \param port             Port argument for the server.
 * \param world_size       World-size argument for the server.
 * \param server_key_path  Path to the server key, forwarded as-is.
 * \param server_cert_path Path to the server certificate, forwarded as-is.
 * \param client_cert_path Path to the client certificate, forwarded as-is.
 * \return 0 on success, the value produced by API_END's exception handler otherwise.
 */
XGB_DLL int XGBRunFederatedServer(int port, std::size_t world_size, char const *server_key_path,
                                  char const *server_cert_path, char const *client_cert_path) {
  API_BEGIN();
  federated::RunServer(port, world_size, server_key_path, server_cert_path, client_cert_path);
  API_END();
}
/**
 * \brief Run a server without SSL for local testing.
 *
 * Forwards both arguments unchanged to federated::RunInsecureServer.
 *
 * \return 0 on success, the value produced by API_END's exception handler otherwise.
 */
XGB_DLL int XGBRunInsecureFederatedServer(int port, std::size_t world_size) {
  API_BEGIN();
  federated::RunInsecureServer(port, world_size);
  API_END();
}
#endif
// Force link rabit: this file-local variable is initialised by calling RabitLinkTag(),
// which makes this translation unit reference the rabit C API. The variable carries
// DMLC_ATTRIBUTE_UNUSED since nothing visible here reads it afterwards.
static DMLC_ATTRIBUTE_UNUSED int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag();

View File

@@ -1,22 +1,28 @@
/*!
* Copyright (c) 2015 by Contributors
/**
* Copyright 2015-2023, XGBoost Contributors
* \file c_api_error.cc
* \brief C error handling
*/
#include <dmlc/thread_local.h>
#include "xgboost/c_api.h"
#include "./c_api_error.h"
#include <dmlc/thread_local.h>
#include "xgboost/c_api.h"
#include "../collective/comm.h"
#include "../collective/comm_group.h"
/** \brief Error record for the C API: the latest message together with its code. */
struct XGBAPIErrorEntry {
  /** \brief Text of the most recent error; empty until one is recorded. */
  std::string last_error;
  /** \brief Code associated with the most recent error; starts out as -1. */
  std::int32_t code{-1};
};
using XGBAPIErrorStore = dmlc::ThreadLocalStore<XGBAPIErrorEntry>;
/**
 * \brief Return the message stored in the calling thread's XGBAPIErrorStore entry.
 *
 * The pointer refers to the entry's `last_error` string, not to a copy.
 */
XGB_DLL const char *XGBGetLastError() {
  auto const *entry = XGBAPIErrorStore::Get();
  return entry->last_error.c_str();
}
/**
 * \brief Return the message stored in the calling thread's XGBAPIErrorStore entry.
 *
 * The pointer refers to the entry's `last_error` string, not to a copy.
 */
XGB_DLL const char* XGBGetLastError() {
  auto const& message = XGBAPIErrorStore::Get()->last_error;
  return message.c_str();
}
/**
 * \brief Record an error message in the XGBAPIErrorStore entry and reset its code to -1.
 *
 * \param msg Null-terminated message copied into the entry's `last_error` string.
 */
void XGBAPISetLastError(const char* msg) {
  auto* entry = XGBAPIErrorStore::Get();
  entry->last_error = msg;
  entry->code = -1;
}
/** \brief Return the `code` field of the XGBAPIErrorStore entry. */
XGB_DLL int XGBGetLastErrorCode() {
  auto const* entry = XGBAPIErrorStore::Get();
  return entry->code;
}

View File

@@ -10,6 +10,7 @@
#include <dmlc/logging.h>
#include "c_api_utils.h"
#include "xgboost/collective/result.h"
/*! \brief macro to guard beginning and end section of all functions */
#ifdef LOG_CAPI_INVOCATION
@@ -30,7 +31,7 @@
#define API_END() \
} catch (dmlc::Error & _except_) { \
return XGBAPIHandleException(_except_); \
} catch (std::exception const &_except_) { \
} catch (std::exception const& _except_) { \
return XGBAPIHandleException(dmlc::Error(_except_.what())); \
} \
return 0; // NOLINT(*)
@@ -48,7 +49,7 @@ void XGBAPISetLastError(const char* msg);
* \param e the exception
* \return the return value of API after exception is handled
*/
inline int XGBAPIHandleException(const dmlc::Error &e) {
inline int XGBAPIHandleException(const dmlc::Error& e) {
  // Store the exception text through XGBAPISetLastError, then report failure as -1.
  constexpr int kFailure = -1;
  XGBAPISetLastError(e.what());
  return kFailure;
}

View File

@@ -9,10 +9,15 @@
#include <type_traits> // for is_same_v, remove_pointer_t
#include <utility> // for pair
#include "../collective/comm.h" // for DefaultTimeoutSec
#include "../collective/tracker.h" // for RabitTracker
#include "../common/timer.h" // for Timer
#include "c_api_error.h" // for API_BEGIN
#include "../collective/allgather.h" // for Allgather
#include "../collective/allreduce.h" // for Allreduce
#include "../collective/broadcast.h" // for Broadcast
#include "../collective/comm.h" // for DefaultTimeoutSec
#include "../collective/comm_group.h" // for GlobalCommGroup
#include "../collective/communicator-inl.h" // for GetProcessorName
#include "../collective/tracker.h" // for RabitTracker
#include "../common/timer.h" // for Timer
#include "c_api_error.h" // for API_BEGIN
#include "xgboost/c_api.h"
#include "xgboost/collective/result.h" // for Result
#include "xgboost/json.h" // for Json
@@ -20,10 +25,36 @@
#if defined(XGBOOST_USE_FEDERATED)
#include "../../plugin/federated/federated_tracker.h" // for FederatedTracker
#else
#include "../common/error_msg.h" // for NoFederated
#endif
namespace xgboost::collective {
void Allreduce(void *send_receive_buffer, std::size_t count, std::int32_t data_type, int op) {
Context ctx;
DispatchDType(static_cast<ArrayInterfaceHandler::Type>(data_type), [&](auto t) {
using T = decltype(t);
auto data = linalg::MakeTensorView(
&ctx, common::Span{static_cast<T *>(send_receive_buffer), count}, count);
auto rc = Allreduce(&ctx, *GlobalCommGroup(), data, static_cast<Op>(op));
SafeColl(rc);
});
}
/**
 * \brief Type-erased broadcast on the global communicator group.
 *
 * The buffer is viewed as `size` elements of std::int8_t and passed to the typed
 * Broadcast overload together with `root`. Any failure in the returned result is
 * handed to SafeColl.
 */
void Broadcast(void *send_receive_buffer, std::size_t size, int root) {
  Context ctx;
  auto *bytes = static_cast<std::int8_t *>(send_receive_buffer);
  auto rc = Broadcast(&ctx, *GlobalCommGroup(), linalg::MakeVec(bytes, size), root);
  SafeColl(rc);
}
/**
 * \brief Type-erased allgather on the global communicator group.
 *
 * The buffer is viewed as `size` elements of std::int8_t and passed to the typed
 * Allgather overload. Any failure in the returned result is handed to SafeColl.
 */
void Allgather(void *send_receive_buffer, std::size_t size) {
  Context ctx;
  auto *bytes = static_cast<std::int8_t *>(send_receive_buffer);
  auto rc = Allgather(&ctx, *GlobalCommGroup(), linalg::MakeVec(bytes, size));
  SafeColl(rc);
}
} // namespace xgboost::collective
using namespace xgboost; // NOLINT
namespace {
@@ -44,7 +75,8 @@ using CollAPIThreadLocalStore = dmlc::ThreadLocalStore<CollAPIEntry>;
void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) {
constexpr std::int64_t kDft{collective::DefaultTimeoutSec()};
std::chrono::seconds wait_for{timeout.count() != 0 ? std::min(kDft, timeout.count()) : kDft};
std::chrono::seconds wait_for{collective::HasTimeout(timeout) ? std::min(kDft, timeout.count())
: kDft};
common::Timer timer;
timer.Start();
@@ -62,7 +94,7 @@ void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) {
break;
}
if (timer.Duration() > timeout && timeout.count() != 0) {
if (timer.Duration() > timeout && collective::HasTimeout(timeout)) {
collective::SafeColl(collective::Fail("Timeout waiting for the tracker."));
}
}
@@ -141,7 +173,7 @@ XGB_DLL int XGTrackerFree(TrackerHandle handle) {
// Make sure no one else is waiting on the tracker.
while (!ptr->first.unique()) {
auto ela = timer.Duration().count();
if (ela > ptr->first->Timeout().count()) {
if (collective::HasTimeout(ptr->first->Timeout()) && ela > ptr->first->Timeout().count()) {
LOG(WARNING) << "Time out " << ptr->first->Timeout().count()
<< " seconds reached for TrackerFree, killing the tracker.";
break;
@@ -151,3 +183,71 @@ XGB_DLL int XGTrackerFree(TrackerHandle handle) {
delete ptr;
API_END();
}
/**
 * \brief Initialise the global communicator group from a JSON document.
 *
 * \param json_config JSON text; checked for null, parsed, then passed to
 *                    collective::GlobalCommGroupInit.
 * \return 0 on success, the value produced by API_END's exception handler otherwise.
 */
XGB_DLL int XGCommunicatorInit(char const *json_config) {
  API_BEGIN();
  xgboost_CHECK_C_ARG_PTR(json_config);
  Json parsed = Json::Load(StringView{json_config});
  collective::GlobalCommGroupInit(parsed);
  API_END();
}
/**
 * \brief Tear down the global communicator group via collective::GlobalCommGroupFinalize.
 *
 * \return 0 on success, the value produced by API_END's exception handler otherwise.
 */
XGB_DLL int XGCommunicatorFinalize(void) {
  API_BEGIN();
  collective::GlobalCommGroupFinalize();
  API_END();
}
/**
 * \brief Report the value of collective::GetRank for the calling worker.
 *
 * The rank is returned directly from inside the API_BEGIN/API_END guard, so the
 * trailing `return 0` of API_END is only reached through the exception handlers.
 *
 * \return The rank, or the value produced by API_END's exception handler on failure.
 */
XGB_DLL int XGCommunicatorGetRank(void) {
  API_BEGIN();
  auto const rank = collective::GetRank();
  return rank;
  API_END();
}
/**
 * \brief Report the value of collective::GetWorldSize.
 *
 * Not wrapped in API_BEGIN/API_END, so the return value is the world size itself.
 */
XGB_DLL int XGCommunicatorGetWorldSize(void) {
  auto const world_size = collective::GetWorldSize();
  return world_size;
}
/**
 * \brief Report the value of collective::IsDistributed, converted to int.
 *
 * Not wrapped in API_BEGIN/API_END, so the return value is the flag itself.
 */
XGB_DLL int XGCommunicatorIsDistributed(void) {
  auto const distributed = collective::IsDistributed();
  return distributed;
}
/**
 * \brief Forward a message to collective::Print.
 *
 * \param message Null-terminated text to print. A null pointer is rejected through
 *                xgboost_CHECK_C_ARG_PTR instead of being forwarded, matching the
 *                argument validation done by XGCommunicatorInit.
 * \return 0 on success, the value produced by API_END's exception handler otherwise.
 */
XGB_DLL int XGCommunicatorPrint(char const *message) {
  API_BEGIN();
  xgboost_CHECK_C_ARG_PTR(message);
  collective::Print(message);
  API_END();
}
/**
 * \brief Return the string produced by collective::GetProcessorName.
 *
 * The text is stored in the `ret_str` member of the thread-local entry obtained from
 * CollAPIThreadLocalStore, and `*name_str` is pointed at that buffer.
 * NOTE(review): the pointer presumably stays valid only until `ret_str` is next
 * overwritten on the same thread - confirm against other users of the entry.
 *
 * \param name_str Output pointer; must not be null.
 * \return 0 on success, the value produced by API_END's exception handler otherwise.
 */
XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) {
  API_BEGIN();
  // Validate the output argument first so an invalid call does no work and leaves the
  // thread-local buffer untouched.
  xgboost_CHECK_C_ARG_PTR(name_str);
  auto &local = *CollAPIThreadLocalStore::Get();
  local.ret_str = collective::GetProcessorName();
  *name_str = local.ret_str.c_str();
  API_END();
}
/**
 * \brief C entry point that forwards its arguments unchanged to collective::Broadcast.
 *
 * \param send_receive_buffer Buffer handed to collective::Broadcast.
 * \param size                Size argument handed to collective::Broadcast.
 * \param root                Root argument handed to collective::Broadcast.
 * \return 0 on success, the value produced by API_END's exception handler otherwise.
 */
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root) {
  API_BEGIN();
  collective::Broadcast(send_receive_buffer, size, root);
  API_END();
}
/**
 * \brief C entry point that forwards its arguments unchanged to collective::Allreduce.
 *
 * \param send_receive_buffer Buffer handed to collective::Allreduce.
 * \param count               Count argument handed to collective::Allreduce.
 * \param enum_dtype          Integer data-type code, forwarded as-is.
 * \param enum_op             Integer operation code, forwarded as-is.
 * \return 0 on success, the value produced by API_END's exception handler otherwise.
 */
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int enum_dtype,
                                    int enum_op) {
  API_BEGIN();
  collective::Allreduce(send_receive_buffer, count, enum_dtype, enum_op);
  API_END();
}
/**
 * \brief C entry point that forwards its arguments unchanged to collective::Allgather.
 *
 * Not exposed to the public since the previous implementation didn't and we don't want
 * to add unnecessary communicator API to a machine learning library.
 *
 * \return 0 on success, the value produced by API_END's exception handler otherwise.
 */
XGB_DLL int XGCommunicatorAllgather(void *send_receive_buffer, size_t count) {
  API_BEGIN();
  collective::Allgather(send_receive_buffer, count);
  API_END();
}
/**
 * \brief Wrap the last recorded C API error in a failed result and pass it to the
 *        global communicator group's SignalError.
 *
 * Not yet exposed to the public, error recovery is still WIP.
 *
 * \return 0 on success, the value produced by API_END's exception handler otherwise.
 */
XGB_DLL int XGCommunicatorSignalError() {
  API_BEGIN();
  auto const *last_error = XGBGetLastError();
  auto *group = xgboost::collective::GlobalCommGroup();
  // Whatever SignalError returns is itself checked by SafeColl.
  SafeColl(group->SignalError(xgboost::collective::Fail(last_error)));
  API_END();
}