Revamp the rabit implementation. (#10112)
This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features: - Federated learning for both CPU and GPU. - NCCL. - More data types. - A unified interface for all the underlying implementations. - Improved timeout handling for both tracker and workers. - Exhausted tests with metrics (fixed a couple of bugs along the way). - A reusable tracker for Python and JVM packages.
This commit is contained in:
@@ -15,9 +15,9 @@
|
||||
#include <utility> // for pair
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../collective/communicator-inl.h" // for Allreduce, Broadcast, Finalize, GetProcessor...
|
||||
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
|
||||
#include "../common/charconv.h" // for from_chars, to_chars, NumericLimits, from_ch...
|
||||
#include "../common/error_msg.h" // for NoFederated
|
||||
#include "../common/hist_util.h" // for HistogramCuts
|
||||
#include "../common/io.h" // for FileExtension, LoadSequentialFile, MemoryBuf...
|
||||
#include "../common/threading_utils.h" // for OmpGetNumThreads, ParallelFor
|
||||
@@ -27,11 +27,10 @@
|
||||
#include "../data/simple_dmatrix.h" // for SimpleDMatrix
|
||||
#include "c_api_error.h" // for xgboost_CHECK_C_ARG_PTR, API_END, API_BEGIN
|
||||
#include "c_api_utils.h" // for RequiredArg, OptionalArg, GetMissing, CastDM...
|
||||
#include "dmlc/base.h" // for BeginPtr, DMLC_ATTRIBUTE_UNUSED
|
||||
#include "dmlc/base.h" // for BeginPtr
|
||||
#include "dmlc/io.h" // for Stream
|
||||
#include "dmlc/parameter.h" // for FieldAccessEntry, FieldEntry, ParamManager
|
||||
#include "dmlc/thread_local.h" // for ThreadLocalStore
|
||||
#include "rabit/c_api.h" // for RabitLinkTag
|
||||
#include "xgboost/base.h" // for bst_ulong, bst_float, GradientPair, bst_feat...
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/data.h" // for DMatrix, MetaInfo, DataType, ExtSparsePage
|
||||
@@ -46,10 +45,6 @@
|
||||
#include "xgboost/string_view.h" // for StringView, operator<<
|
||||
#include "xgboost/version_config.h" // for XGBOOST_VER_MAJOR, XGBOOST_VER_MINOR, XGBOOS...
|
||||
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
#include "../../plugin/federated/federated_server.h"
|
||||
#endif
|
||||
|
||||
using namespace xgboost; // NOLINT(*);
|
||||
|
||||
XGB_DLL void XGBoostVersion(int* major, int* minor, int* patch) {
|
||||
@@ -1759,76 +1754,3 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *config,
|
||||
*out_features = dmlc::BeginPtr(feature_names_c);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGCommunicatorInit(char const* json_config) {
|
||||
API_BEGIN();
|
||||
xgboost_CHECK_C_ARG_PTR(json_config);
|
||||
Json config{Json::Load(StringView{json_config})};
|
||||
collective::Init(config);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGCommunicatorFinalize() {
|
||||
API_BEGIN();
|
||||
collective::Finalize();
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGCommunicatorGetRank(void) {
|
||||
return collective::GetRank();
|
||||
}
|
||||
|
||||
XGB_DLL int XGCommunicatorGetWorldSize(void) {
|
||||
return collective::GetWorldSize();
|
||||
}
|
||||
|
||||
XGB_DLL int XGCommunicatorIsDistributed(void) {
|
||||
return collective::IsDistributed();
|
||||
}
|
||||
|
||||
XGB_DLL int XGCommunicatorPrint(char const *message) {
|
||||
API_BEGIN();
|
||||
collective::Print(message);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) {
|
||||
API_BEGIN();
|
||||
auto& local = *GlobalConfigAPIThreadLocalStore::Get();
|
||||
local.ret_str = collective::GetProcessorName();
|
||||
xgboost_CHECK_C_ARG_PTR(name_str);
|
||||
*name_str = local.ret_str.c_str();
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root) {
|
||||
API_BEGIN();
|
||||
collective::Broadcast(send_receive_buffer, size, root);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int enum_dtype,
|
||||
int enum_op) {
|
||||
API_BEGIN();
|
||||
collective::Allreduce(send_receive_buffer, count, enum_dtype, enum_op);
|
||||
API_END();
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
XGB_DLL int XGBRunFederatedServer(int port, std::size_t world_size, char const *server_key_path,
|
||||
char const *server_cert_path, char const *client_cert_path) {
|
||||
API_BEGIN();
|
||||
federated::RunServer(port, world_size, server_key_path, server_cert_path, client_cert_path);
|
||||
API_END();
|
||||
}
|
||||
|
||||
// Run a server without SSL for local testing.
|
||||
XGB_DLL int XGBRunInsecureFederatedServer(int port, std::size_t world_size) {
|
||||
API_BEGIN();
|
||||
federated::RunInsecureServer(port, world_size);
|
||||
API_END();
|
||||
}
|
||||
#endif
|
||||
|
||||
// force link rabit
|
||||
static DMLC_ATTRIBUTE_UNUSED int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag();
|
||||
|
||||
@@ -1,22 +1,28 @@
|
||||
/*!
|
||||
* Copyright (c) 2015 by Contributors
|
||||
/**
|
||||
* Copyright 2015-2023, XGBoost Contributors
|
||||
* \file c_api_error.cc
|
||||
* \brief C error handling
|
||||
*/
|
||||
#include <dmlc/thread_local.h>
|
||||
#include "xgboost/c_api.h"
|
||||
#include "./c_api_error.h"
|
||||
|
||||
#include <dmlc/thread_local.h>
|
||||
|
||||
#include "xgboost/c_api.h"
|
||||
#include "../collective/comm.h"
|
||||
#include "../collective/comm_group.h"
|
||||
|
||||
struct XGBAPIErrorEntry {
|
||||
std::string last_error;
|
||||
std::int32_t code{-1};
|
||||
};
|
||||
|
||||
using XGBAPIErrorStore = dmlc::ThreadLocalStore<XGBAPIErrorEntry>;
|
||||
|
||||
XGB_DLL const char *XGBGetLastError() {
|
||||
return XGBAPIErrorStore::Get()->last_error.c_str();
|
||||
}
|
||||
XGB_DLL const char* XGBGetLastError() { return XGBAPIErrorStore::Get()->last_error.c_str(); }
|
||||
|
||||
void XGBAPISetLastError(const char* msg) {
|
||||
XGBAPIErrorStore::Get()->last_error = msg;
|
||||
XGBAPIErrorStore::Get()->code = -1;
|
||||
}
|
||||
|
||||
XGB_DLL int XGBGetLastErrorCode() { return XGBAPIErrorStore::Get()->code; }
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include <dmlc/logging.h>
|
||||
|
||||
#include "c_api_utils.h"
|
||||
#include "xgboost/collective/result.h"
|
||||
|
||||
/*! \brief macro to guard beginning and end section of all functions */
|
||||
#ifdef LOG_CAPI_INVOCATION
|
||||
@@ -30,7 +31,7 @@
|
||||
#define API_END() \
|
||||
} catch (dmlc::Error & _except_) { \
|
||||
return XGBAPIHandleException(_except_); \
|
||||
} catch (std::exception const &_except_) { \
|
||||
} catch (std::exception const& _except_) { \
|
||||
return XGBAPIHandleException(dmlc::Error(_except_.what())); \
|
||||
} \
|
||||
return 0; // NOLINT(*)
|
||||
@@ -48,7 +49,7 @@ void XGBAPISetLastError(const char* msg);
|
||||
* \param e the exception
|
||||
* \return the return value of API after exception is handled
|
||||
*/
|
||||
inline int XGBAPIHandleException(const dmlc::Error &e) {
|
||||
inline int XGBAPIHandleException(const dmlc::Error& e) {
|
||||
XGBAPISetLastError(e.what());
|
||||
return -1;
|
||||
}
|
||||
|
||||
@@ -9,10 +9,15 @@
|
||||
#include <type_traits> // for is_same_v, remove_pointer_t
|
||||
#include <utility> // for pair
|
||||
|
||||
#include "../collective/comm.h" // for DefaultTimeoutSec
|
||||
#include "../collective/tracker.h" // for RabitTracker
|
||||
#include "../common/timer.h" // for Timer
|
||||
#include "c_api_error.h" // for API_BEGIN
|
||||
#include "../collective/allgather.h" // for Allgather
|
||||
#include "../collective/allreduce.h" // for Allreduce
|
||||
#include "../collective/broadcast.h" // for Broadcast
|
||||
#include "../collective/comm.h" // for DefaultTimeoutSec
|
||||
#include "../collective/comm_group.h" // for GlobalCommGroup
|
||||
#include "../collective/communicator-inl.h" // for GetProcessorName
|
||||
#include "../collective/tracker.h" // for RabitTracker
|
||||
#include "../common/timer.h" // for Timer
|
||||
#include "c_api_error.h" // for API_BEGIN
|
||||
#include "xgboost/c_api.h"
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/json.h" // for Json
|
||||
@@ -20,10 +25,36 @@
|
||||
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
#include "../../plugin/federated/federated_tracker.h" // for FederatedTracker
|
||||
#else
|
||||
#include "../common/error_msg.h" // for NoFederated
|
||||
#endif
|
||||
|
||||
namespace xgboost::collective {
|
||||
void Allreduce(void *send_receive_buffer, std::size_t count, std::int32_t data_type, int op) {
|
||||
Context ctx;
|
||||
DispatchDType(static_cast<ArrayInterfaceHandler::Type>(data_type), [&](auto t) {
|
||||
using T = decltype(t);
|
||||
auto data = linalg::MakeTensorView(
|
||||
&ctx, common::Span{static_cast<T *>(send_receive_buffer), count}, count);
|
||||
auto rc = Allreduce(&ctx, *GlobalCommGroup(), data, static_cast<Op>(op));
|
||||
SafeColl(rc);
|
||||
});
|
||||
}
|
||||
|
||||
void Broadcast(void *send_receive_buffer, std::size_t size, int root) {
|
||||
Context ctx;
|
||||
auto rc = Broadcast(&ctx, *GlobalCommGroup(),
|
||||
linalg::MakeVec(static_cast<std::int8_t *>(send_receive_buffer), size), root);
|
||||
SafeColl(rc);
|
||||
}
|
||||
|
||||
void Allgather(void *send_receive_buffer, std::size_t size) {
|
||||
Context ctx;
|
||||
auto const &comm = GlobalCommGroup();
|
||||
auto rc = Allgather(&ctx, *comm,
|
||||
linalg::MakeVec(reinterpret_cast<std::int8_t *>(send_receive_buffer), size));
|
||||
SafeColl(rc);
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
using namespace xgboost; // NOLINT
|
||||
|
||||
namespace {
|
||||
@@ -44,7 +75,8 @@ using CollAPIThreadLocalStore = dmlc::ThreadLocalStore<CollAPIEntry>;
|
||||
|
||||
void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) {
|
||||
constexpr std::int64_t kDft{collective::DefaultTimeoutSec()};
|
||||
std::chrono::seconds wait_for{timeout.count() != 0 ? std::min(kDft, timeout.count()) : kDft};
|
||||
std::chrono::seconds wait_for{collective::HasTimeout(timeout) ? std::min(kDft, timeout.count())
|
||||
: kDft};
|
||||
|
||||
common::Timer timer;
|
||||
timer.Start();
|
||||
@@ -62,7 +94,7 @@ void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (timer.Duration() > timeout && timeout.count() != 0) {
|
||||
if (timer.Duration() > timeout && collective::HasTimeout(timeout)) {
|
||||
collective::SafeColl(collective::Fail("Timeout waiting for the tracker."));
|
||||
}
|
||||
}
|
||||
@@ -141,7 +173,7 @@ XGB_DLL int XGTrackerFree(TrackerHandle handle) {
|
||||
// Make sure no one else is waiting on the tracker.
|
||||
while (!ptr->first.unique()) {
|
||||
auto ela = timer.Duration().count();
|
||||
if (ela > ptr->first->Timeout().count()) {
|
||||
if (collective::HasTimeout(ptr->first->Timeout()) && ela > ptr->first->Timeout().count()) {
|
||||
LOG(WARNING) << "Time out " << ptr->first->Timeout().count()
|
||||
<< " seconds reached for TrackerFree, killing the tracker.";
|
||||
break;
|
||||
@@ -151,3 +183,71 @@ XGB_DLL int XGTrackerFree(TrackerHandle handle) {
|
||||
delete ptr;
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGCommunicatorInit(char const *json_config) {
|
||||
API_BEGIN();
|
||||
xgboost_CHECK_C_ARG_PTR(json_config);
|
||||
Json config{Json::Load(StringView{json_config})};
|
||||
collective::GlobalCommGroupInit(config);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGCommunicatorFinalize(void) {
|
||||
API_BEGIN();
|
||||
collective::GlobalCommGroupFinalize();
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGCommunicatorGetRank(void) {
|
||||
API_BEGIN();
|
||||
return collective::GetRank();
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGCommunicatorGetWorldSize(void) { return collective::GetWorldSize(); }
|
||||
|
||||
XGB_DLL int XGCommunicatorIsDistributed(void) { return collective::IsDistributed(); }
|
||||
|
||||
XGB_DLL int XGCommunicatorPrint(char const *message) {
|
||||
API_BEGIN();
|
||||
collective::Print(message);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) {
|
||||
API_BEGIN();
|
||||
auto &local = *CollAPIThreadLocalStore::Get();
|
||||
local.ret_str = collective::GetProcessorName();
|
||||
xgboost_CHECK_C_ARG_PTR(name_str);
|
||||
*name_str = local.ret_str.c_str();
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root) {
|
||||
API_BEGIN();
|
||||
collective::Broadcast(send_receive_buffer, size, root);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int enum_dtype,
|
||||
int enum_op) {
|
||||
API_BEGIN();
|
||||
collective::Allreduce(send_receive_buffer, count, enum_dtype, enum_op);
|
||||
API_END();
|
||||
}
|
||||
|
||||
// Not exposed to the public since the previous implementation didn't and we don't want to
|
||||
// add unnecessary communicator API to a machine learning library.
|
||||
XGB_DLL int XGCommunicatorAllgather(void *send_receive_buffer, size_t count) {
|
||||
API_BEGIN();
|
||||
collective::Allgather(send_receive_buffer, count);
|
||||
API_END();
|
||||
}
|
||||
|
||||
// Not yet exposed to the public, error recovery is still WIP.
|
||||
XGB_DLL int XGCommunicatorSignalError() {
|
||||
API_BEGIN();
|
||||
auto msg = XGBGetLastError();
|
||||
SafeColl(xgboost::collective::GlobalCommGroup()->SignalError(xgboost::collective::Fail(msg)));
|
||||
API_END()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user