Revamp the rabit implementation. (#10112)
This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new implementation features: federated learning for both CPU and GPU; NCCL support; more data types; a unified interface for all the underlying implementations; improved timeout handling for both the tracker and the workers; exhaustive tests with metrics (a couple of bugs were fixed along the way); and a reusable tracker for the Python and JVM packages.
This commit is contained in:
@@ -15,9 +15,9 @@
|
||||
#include <utility> // for pair
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../collective/communicator-inl.h" // for Allreduce, Broadcast, Finalize, GetProcessor...
|
||||
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
|
||||
#include "../common/charconv.h" // for from_chars, to_chars, NumericLimits, from_ch...
|
||||
#include "../common/error_msg.h" // for NoFederated
|
||||
#include "../common/hist_util.h" // for HistogramCuts
|
||||
#include "../common/io.h" // for FileExtension, LoadSequentialFile, MemoryBuf...
|
||||
#include "../common/threading_utils.h" // for OmpGetNumThreads, ParallelFor
|
||||
@@ -27,11 +27,10 @@
|
||||
#include "../data/simple_dmatrix.h" // for SimpleDMatrix
|
||||
#include "c_api_error.h" // for xgboost_CHECK_C_ARG_PTR, API_END, API_BEGIN
|
||||
#include "c_api_utils.h" // for RequiredArg, OptionalArg, GetMissing, CastDM...
|
||||
#include "dmlc/base.h" // for BeginPtr, DMLC_ATTRIBUTE_UNUSED
|
||||
#include "dmlc/base.h" // for BeginPtr
|
||||
#include "dmlc/io.h" // for Stream
|
||||
#include "dmlc/parameter.h" // for FieldAccessEntry, FieldEntry, ParamManager
|
||||
#include "dmlc/thread_local.h" // for ThreadLocalStore
|
||||
#include "rabit/c_api.h" // for RabitLinkTag
|
||||
#include "xgboost/base.h" // for bst_ulong, bst_float, GradientPair, bst_feat...
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/data.h" // for DMatrix, MetaInfo, DataType, ExtSparsePage
|
||||
@@ -46,10 +45,6 @@
|
||||
#include "xgboost/string_view.h" // for StringView, operator<<
|
||||
#include "xgboost/version_config.h" // for XGBOOST_VER_MAJOR, XGBOOST_VER_MINOR, XGBOOS...
|
||||
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
#include "../../plugin/federated/federated_server.h"
|
||||
#endif
|
||||
|
||||
using namespace xgboost; // NOLINT(*);
|
||||
|
||||
XGB_DLL void XGBoostVersion(int* major, int* minor, int* patch) {
|
||||
@@ -1759,76 +1754,3 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *config,
|
||||
*out_features = dmlc::BeginPtr(feature_names_c);
|
||||
API_END();
|
||||
}
|
||||
|
||||
/**
 * @brief Initialise the collective communicator from a JSON configuration string.
 *
 * The string is parsed with Json::Load and the resulting document is handed to
 * collective::Init.
 *
 * @param json_config JSON text; checked with xgboost_CHECK_C_ARG_PTR before use.
 *
 * @return 0 when the guarded section completes, otherwise the value produced by API_END().
 */
XGB_DLL int XGCommunicatorInit(char const* json_config) {
  API_BEGIN();
  xgboost_CHECK_C_ARG_PTR(json_config);
  StringView const text{json_config};
  Json parsed{Json::Load(text)};
  collective::Init(parsed);
  API_END();
}
|
||||
|
||||
/**
 * @brief Finalise the collective communicator by calling collective::Finalize.
 *
 * @return 0 when the guarded section completes, otherwise the value produced by API_END().
 */
XGB_DLL int XGCommunicatorFinalize() {
  API_BEGIN();
  collective::Finalize();
  API_END();
}
|
||||
|
||||
/**
 * @brief Report the rank of the calling worker.
 *
 * @return The value of collective::GetRank(); no error code is involved.
 */
XGB_DLL int XGCommunicatorGetRank(void) {
  auto const rank = collective::GetRank();
  return rank;
}
|
||||
|
||||
/**
 * @brief Report the number of workers.
 *
 * @return The value of collective::GetWorldSize(); no error code is involved.
 */
XGB_DLL int XGCommunicatorGetWorldSize(void) {
  auto const world = collective::GetWorldSize();
  return world;
}
|
||||
|
||||
/**
 * @brief Report whether the communicator runs in distributed mode.
 *
 * @return The result of collective::IsDistributed() converted to int.
 */
XGB_DLL int XGCommunicatorIsDistributed(void) {
  return static_cast<int>(collective::IsDistributed());
}
|
||||
|
||||
/**
 * @brief Forward a message to collective::Print.
 *
 * @param message Text handed to collective::Print unchanged.
 *
 * @return 0 when the guarded section completes, otherwise the value produced by API_END().
 */
XGB_DLL int XGCommunicatorPrint(char const *message) {
  API_BEGIN();
  collective::Print(message);
  API_END();
}
|
||||
|
||||
/**
 * @brief Return the processor name reported by collective::GetProcessorName.
 *
 * The name is stored in the thread-local API entry, and `*name_str` points into that
 * stored string.
 *
 * @param name_str Output pointer; checked with xgboost_CHECK_C_ARG_PTR before anything else.
 *
 * @return 0 when the guarded section completes, otherwise the value produced by API_END().
 */
XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) {
  API_BEGIN();
  // Check the output pointer first: a call that is rejected should not overwrite the
  // thread-local string that an earlier successful call may still be pointing into.
  xgboost_CHECK_C_ARG_PTR(name_str);
  auto& local = *GlobalConfigAPIThreadLocalStore::Get();
  local.ret_str = collective::GetProcessorName();
  *name_str = local.ret_str.c_str();
  API_END();
}
|
||||
|
||||
/**
 * @brief Forward a broadcast request to collective::Broadcast.
 *
 * @param send_receive_buffer Buffer passed through unchanged.
 * @param size                Size argument passed through unchanged.
 * @param root                Root argument passed through unchanged.
 *
 * @return 0 when the guarded section completes, otherwise the value produced by API_END().
 */
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root) {
  API_BEGIN();
  collective::Broadcast(send_receive_buffer, size, root);
  API_END();
}
|
||||
|
||||
/**
 * @brief Forward an allreduce request to collective::Allreduce.
 *
 * @param send_receive_buffer Buffer passed through unchanged.
 * @param count               Element count passed through unchanged.
 * @param enum_dtype          Integer data-type selector passed through unchanged.
 * @param enum_op             Integer operation selector passed through unchanged.
 *
 * @return 0 when the guarded section completes, otherwise the value produced by API_END().
 */
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int enum_dtype,
                                    int enum_op) {
  API_BEGIN();
  collective::Allreduce(send_receive_buffer, count, enum_dtype, enum_op);
  API_END();
}
|
||||
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
/**
 * @brief Forward all arguments to federated::RunServer.
 *
 * Only compiled when XGBOOST_USE_FEDERATED is defined.
 *
 * @return 0 when the guarded section completes, otherwise the value produced by API_END().
 */
XGB_DLL int XGBRunFederatedServer(int port, std::size_t world_size, char const *server_key_path,
                                  char const *server_cert_path, char const *client_cert_path) {
  API_BEGIN();
  federated::RunServer(port, world_size, server_key_path, server_cert_path, client_cert_path);
  API_END();
}
|
||||
|
||||
// Run a server without SSL for local testing.
|
||||
XGB_DLL int XGBRunInsecureFederatedServer(int port, std::size_t world_size) {
|
||||
API_BEGIN();
|
||||
federated::RunInsecureServer(port, world_size);
|
||||
API_END();
|
||||
}
|
||||
#endif
|
||||
|
||||
// force link rabit
|
||||
static DMLC_ATTRIBUTE_UNUSED int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag();
|
||||
|
||||
@@ -1,22 +1,28 @@
|
||||
/*!
|
||||
* Copyright (c) 2015 by Contributors
|
||||
/**
|
||||
* Copyright 2015-2023, XGBoost Contributors
|
||||
* \file c_api_error.cc
|
||||
* \brief C error handling
|
||||
*/
|
||||
#include <dmlc/thread_local.h>
|
||||
#include "xgboost/c_api.h"
|
||||
#include "./c_api_error.h"
|
||||
|
||||
#include <dmlc/thread_local.h>
|
||||
|
||||
#include "xgboost/c_api.h"
|
||||
#include "../collective/comm.h"
|
||||
#include "../collective/comm_group.h"
|
||||
|
||||
/** @brief Record of the most recent C API error: its message and an integer code. */
struct XGBAPIErrorEntry {
  /** Message of the last recorded error; empty until one is stored. */
  std::string last_error;
  /** Code of the last recorded error; initialised to -1. */
  std::int32_t code = -1;
};
|
||||
|
||||
// Store through which the error helpers in this file reach their XGBAPIErrorEntry (via Get()).
using XGBAPIErrorStore = dmlc::ThreadLocalStore<XGBAPIErrorEntry>;
|
||||
|
||||
/**
 * @brief Return the message of the last error recorded in XGBAPIErrorStore.
 *
 * @return Pointer to the character data of the stored string; the storage stays owned by
 *         the store.
 */
XGB_DLL const char* XGBGetLastError() { return XGBAPIErrorStore::Get()->last_error.c_str(); }
|
||||
|
||||
/**
 * @brief Record an error message in XGBAPIErrorStore and reset the stored code to -1.
 *
 * @param msg Message to store. A null pointer is stored as an empty message.
 */
void XGBAPISetLastError(const char* msg) {
  auto* entry = XGBAPIErrorStore::Get();
  // Assigning a null `char const*` to std::string is undefined behaviour, so map it to "".
  entry->last_error = (msg != nullptr) ? msg : "";
  entry->code = -1;
}
|
||||
|
||||
/** @brief Return the error code currently held in XGBAPIErrorStore. */
XGB_DLL int XGBGetLastErrorCode() {
  auto const* entry = XGBAPIErrorStore::Get();
  return entry->code;
}
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include <dmlc/logging.h>
|
||||
|
||||
#include "c_api_utils.h"
|
||||
#include "xgboost/collective/result.h"
|
||||
|
||||
/*! \brief macro to guard beginning and end section of all functions */
|
||||
#ifdef LOG_CAPI_INVOCATION
|
||||
@@ -30,7 +31,7 @@
|
||||
/*!
 * \brief Closing half of the API guard.
 *
 * Converts a caught dmlc::Error, or any other std::exception, into the return value of
 * XGBAPIHandleException; returns 0 when nothing was thrown. There is exactly one handler
 * per exception type: a second, empty std::exception handler would swallow the error and
 * fall through to `return 0`.
 */
#define API_END()                                               \
  }                                                             \
  catch (dmlc::Error & _except_) {                              \
    return XGBAPIHandleException(_except_);                     \
  }                                                             \
  catch (std::exception const& _except_) {                      \
    return XGBAPIHandleException(dmlc::Error(_except_.what())); \
  }                                                             \
  return 0;  // NOLINT(*)
|
||||
@@ -48,7 +49,7 @@ void XGBAPISetLastError(const char* msg);
|
||||
* \param e the exception
|
||||
* \return the return value of API after exception is handled
|
||||
*/
|
||||
inline int XGBAPIHandleException(const dmlc::Error &e) {
|
||||
inline int XGBAPIHandleException(const dmlc::Error& e) {
|
||||
XGBAPISetLastError(e.what());
|
||||
return -1;
|
||||
}
|
||||
|
||||
@@ -9,10 +9,15 @@
|
||||
#include <type_traits> // for is_same_v, remove_pointer_t
|
||||
#include <utility> // for pair
|
||||
|
||||
#include "../collective/comm.h" // for DefaultTimeoutSec
|
||||
#include "../collective/tracker.h" // for RabitTracker
|
||||
#include "../common/timer.h" // for Timer
|
||||
#include "c_api_error.h" // for API_BEGIN
|
||||
#include "../collective/allgather.h" // for Allgather
|
||||
#include "../collective/allreduce.h" // for Allreduce
|
||||
#include "../collective/broadcast.h" // for Broadcast
|
||||
#include "../collective/comm.h" // for DefaultTimeoutSec
|
||||
#include "../collective/comm_group.h" // for GlobalCommGroup
|
||||
#include "../collective/communicator-inl.h" // for GetProcessorName
|
||||
#include "../collective/tracker.h" // for RabitTracker
|
||||
#include "../common/timer.h" // for Timer
|
||||
#include "c_api_error.h" // for API_BEGIN
|
||||
#include "xgboost/c_api.h"
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/json.h" // for Json
|
||||
@@ -20,10 +25,36 @@
|
||||
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
#include "../../plugin/federated/federated_tracker.h" // for FederatedTracker
|
||||
#else
|
||||
#include "../common/error_msg.h" // for NoFederated
|
||||
#endif
|
||||
|
||||
namespace xgboost::collective {
|
||||
void Allreduce(void *send_receive_buffer, std::size_t count, std::int32_t data_type, int op) {
|
||||
Context ctx;
|
||||
DispatchDType(static_cast<ArrayInterfaceHandler::Type>(data_type), [&](auto t) {
|
||||
using T = decltype(t);
|
||||
auto data = linalg::MakeTensorView(
|
||||
&ctx, common::Span{static_cast<T *>(send_receive_buffer), count}, count);
|
||||
auto rc = Allreduce(&ctx, *GlobalCommGroup(), data, static_cast<Op>(op));
|
||||
SafeColl(rc);
|
||||
});
|
||||
}
|
||||
|
||||
void Broadcast(void *send_receive_buffer, std::size_t size, int root) {
|
||||
Context ctx;
|
||||
auto rc = Broadcast(&ctx, *GlobalCommGroup(),
|
||||
linalg::MakeVec(static_cast<std::int8_t *>(send_receive_buffer), size), root);
|
||||
SafeColl(rc);
|
||||
}
|
||||
|
||||
void Allgather(void *send_receive_buffer, std::size_t size) {
|
||||
Context ctx;
|
||||
auto const &comm = GlobalCommGroup();
|
||||
auto rc = Allgather(&ctx, *comm,
|
||||
linalg::MakeVec(reinterpret_cast<std::int8_t *>(send_receive_buffer), size));
|
||||
SafeColl(rc);
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
using namespace xgboost; // NOLINT
|
||||
|
||||
namespace {
|
||||
@@ -44,7 +75,8 @@ using CollAPIThreadLocalStore = dmlc::ThreadLocalStore<CollAPIEntry>;
|
||||
|
||||
void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) {
|
||||
constexpr std::int64_t kDft{collective::DefaultTimeoutSec()};
|
||||
std::chrono::seconds wait_for{timeout.count() != 0 ? std::min(kDft, timeout.count()) : kDft};
|
||||
std::chrono::seconds wait_for{collective::HasTimeout(timeout) ? std::min(kDft, timeout.count())
|
||||
: kDft};
|
||||
|
||||
common::Timer timer;
|
||||
timer.Start();
|
||||
@@ -62,7 +94,7 @@ void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (timer.Duration() > timeout && timeout.count() != 0) {
|
||||
if (timer.Duration() > timeout && collective::HasTimeout(timeout)) {
|
||||
collective::SafeColl(collective::Fail("Timeout waiting for the tracker."));
|
||||
}
|
||||
}
|
||||
@@ -141,7 +173,7 @@ XGB_DLL int XGTrackerFree(TrackerHandle handle) {
|
||||
// Make sure no one else is waiting on the tracker.
|
||||
while (!ptr->first.unique()) {
|
||||
auto ela = timer.Duration().count();
|
||||
if (ela > ptr->first->Timeout().count()) {
|
||||
if (collective::HasTimeout(ptr->first->Timeout()) && ela > ptr->first->Timeout().count()) {
|
||||
LOG(WARNING) << "Time out " << ptr->first->Timeout().count()
|
||||
<< " seconds reached for TrackerFree, killing the tracker.";
|
||||
break;
|
||||
@@ -151,3 +183,71 @@ XGB_DLL int XGTrackerFree(TrackerHandle handle) {
|
||||
delete ptr;
|
||||
API_END();
|
||||
}
|
||||
|
||||
/**
 * @brief Initialise the global communicator group from a JSON configuration string.
 *
 * The string is parsed with Json::Load and the resulting document is handed to
 * collective::GlobalCommGroupInit.
 *
 * @param json_config JSON text; checked with xgboost_CHECK_C_ARG_PTR before use.
 *
 * @return 0 when the guarded section completes, otherwise the value produced by API_END().
 */
XGB_DLL int XGCommunicatorInit(char const *json_config) {
  API_BEGIN();
  xgboost_CHECK_C_ARG_PTR(json_config);
  StringView const text{json_config};
  Json parsed{Json::Load(text)};
  collective::GlobalCommGroupInit(parsed);
  API_END();
}
|
||||
|
||||
/**
 * @brief Finalise the global communicator group via collective::GlobalCommGroupFinalize.
 *
 * @return 0 when the guarded section completes, otherwise the value produced by API_END().
 */
XGB_DLL int XGCommunicatorFinalize(void) {
  API_BEGIN();
  collective::GlobalCommGroupFinalize();
  API_END();
}
|
||||
|
||||
/**
 * @brief Report the rank of the calling worker.
 *
 * @return The value of collective::GetRank(). If an exception is caught by API_END(), the
 *         error value from XGBAPIHandleException is returned in place of a rank.
 */
XGB_DLL int XGCommunicatorGetRank(void) {
  API_BEGIN();
  // The rank is returned from inside the guarded section, so the trailing `return 0` in
  // API_END() is never reached on the success path.
  auto const rank = collective::GetRank();
  return rank;
  API_END();
}
|
||||
|
||||
/** @brief Report the number of workers, as given by collective::GetWorldSize(). */
XGB_DLL int XGCommunicatorGetWorldSize(void) {
  auto const world = collective::GetWorldSize();
  return world;
}
|
||||
|
||||
/** @brief Report collective::IsDistributed() converted to int. */
XGB_DLL int XGCommunicatorIsDistributed(void) {
  return static_cast<int>(collective::IsDistributed());
}
|
||||
|
||||
/**
 * @brief Forward a message to collective::Print.
 *
 * @param message Text to print; checked with xgboost_CHECK_C_ARG_PTR before use.
 *
 * @return 0 when the guarded section completes, otherwise the value produced by API_END().
 */
XGB_DLL int XGCommunicatorPrint(char const *message) {
  API_BEGIN();
  // The Print shown in this change takes a std::string const&; constructing a std::string
  // from a null pointer is undefined behaviour, so check the pointer like the other
  // pointer arguments in this file.
  xgboost_CHECK_C_ARG_PTR(message);
  collective::Print(message);
  API_END();
}
|
||||
|
||||
/**
 * @brief Return the processor name reported by collective::GetProcessorName.
 *
 * The name is stored in the thread-local collective API entry, and `*name_str` points into
 * that stored string.
 *
 * @param name_str Output pointer; checked with xgboost_CHECK_C_ARG_PTR before anything else.
 *
 * @return 0 when the guarded section completes, otherwise the value produced by API_END().
 */
XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) {
  API_BEGIN();
  // Check the output pointer first: a call that is rejected should not overwrite the
  // thread-local string that an earlier successful call may still be pointing into.
  xgboost_CHECK_C_ARG_PTR(name_str);
  auto &local = *CollAPIThreadLocalStore::Get();
  local.ret_str = collective::GetProcessorName();
  *name_str = local.ret_str.c_str();
  API_END();
}
|
||||
|
||||
/**
 * @brief Forward a broadcast request to collective::Broadcast.
 *
 * @param send_receive_buffer Buffer passed through unchanged.
 * @param size                Size argument passed through unchanged.
 * @param root                Root argument passed through unchanged.
 *
 * @return 0 when the guarded section completes, otherwise the value produced by API_END().
 */
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root) {
  API_BEGIN();
  collective::Broadcast(send_receive_buffer, size, root);
  API_END();
}
|
||||
|
||||
/**
 * @brief Forward an allreduce request to collective::Allreduce.
 *
 * @param send_receive_buffer Buffer passed through unchanged.
 * @param count               Element count passed through unchanged.
 * @param enum_dtype          Integer data-type selector passed through unchanged.
 * @param enum_op             Integer operation selector passed through unchanged.
 *
 * @return 0 when the guarded section completes, otherwise the value produced by API_END().
 */
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int enum_dtype,
                                    int enum_op) {
  API_BEGIN();
  collective::Allreduce(send_receive_buffer, count, enum_dtype, enum_op);
  API_END();
}
|
||||
|
||||
// Not exposed to the public since the previous implementation didn't and we don't want to
|
||||
// add unnecessary communicator API to a machine learning library.
|
||||
XGB_DLL int XGCommunicatorAllgather(void *send_receive_buffer, size_t count) {
|
||||
API_BEGIN();
|
||||
collective::Allgather(send_receive_buffer, count);
|
||||
API_END();
|
||||
}
|
||||
|
||||
// Not yet exposed to the public, error recovery is still WIP.
|
||||
XGB_DLL int XGCommunicatorSignalError() {
|
||||
API_BEGIN();
|
||||
auto msg = XGBGetLastError();
|
||||
SafeColl(xgboost::collective::GlobalCommGroup()->SignalError(xgboost::collective::Fail(msg)));
|
||||
API_END()
|
||||
}
|
||||
|
||||
@@ -22,7 +22,6 @@
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
#include "collective/communicator-inl.h"
|
||||
#include "common/common.h"
|
||||
#include "common/config.h"
|
||||
#include "common/io.h"
|
||||
@@ -193,10 +192,6 @@ class CLI {
|
||||
|
||||
void CLITrain() {
|
||||
const double tstart_data_load = dmlc::GetTime();
|
||||
if (collective::IsDistributed()) {
|
||||
std::string pname = collective::GetProcessorName();
|
||||
LOG(CONSOLE) << "start " << pname << ":" << collective::GetRank();
|
||||
}
|
||||
// load in data.
|
||||
std::shared_ptr<DMatrix> dtrain(DMatrix::Load(
|
||||
param_.train_path, ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
|
||||
@@ -235,15 +230,9 @@ class CLI {
|
||||
version += 1;
|
||||
}
|
||||
std::string res = learner_->EvalOneIter(i, eval_datasets, eval_data_names);
|
||||
if (collective::IsDistributed()) {
|
||||
if (collective::GetRank() == 0) {
|
||||
LOG(TRACKER) << res;
|
||||
}
|
||||
} else {
|
||||
LOG(CONSOLE) << res;
|
||||
}
|
||||
if (param_.save_period != 0 && (i + 1) % param_.save_period == 0 &&
|
||||
collective::GetRank() == 0) {
|
||||
LOG(CONSOLE) << res;
|
||||
|
||||
if (param_.save_period != 0 && (i + 1) % param_.save_period == 0) {
|
||||
std::ostringstream os;
|
||||
os << param_.model_dir << '/' << std::setfill('0') << std::setw(4)
|
||||
<< i + 1 << ".model";
|
||||
@@ -256,8 +245,7 @@ class CLI {
|
||||
<< " sec";
|
||||
// always save final round
|
||||
if ((param_.save_period == 0 ||
|
||||
param_.num_round % param_.save_period != 0) &&
|
||||
collective::GetRank() == 0) {
|
||||
param_.num_round % param_.save_period != 0)) {
|
||||
std::ostringstream os;
|
||||
if (param_.model_out == CLIParam::kNull) {
|
||||
os << param_.model_dir << '/' << std::setfill('0') << std::setw(4)
|
||||
@@ -465,13 +453,6 @@ class CLI {
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize the collective communicator.
|
||||
Json json{JsonObject()};
|
||||
for (auto& kv : cfg) {
|
||||
json[kv.first] = String(kv.second);
|
||||
}
|
||||
collective::Init(json);
|
||||
|
||||
param_.Configure(cfg);
|
||||
}
|
||||
|
||||
@@ -507,10 +488,6 @@ class CLI {
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
~CLI() {
|
||||
collective::Finalize();
|
||||
}
|
||||
};
|
||||
} // namespace xgboost
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023 by XGBoost contributors
|
||||
* Copyright 2023-2024, XGBoost contributors
|
||||
*
|
||||
* Higher level functions built on top of the Communicator API, taking care of behavioral differences
|
||||
* between row-split vs column-split distributed training, and horizontal vs vertical federated
|
||||
@@ -13,7 +13,8 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "communicator-inl.cuh"
|
||||
#include "allreduce.h"
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
|
||||
namespace xgboost::collective {
|
||||
|
||||
@@ -24,15 +25,17 @@ namespace xgboost::collective {
|
||||
* column-wise (vertically), the original values are returned.
|
||||
*
|
||||
* @tparam T The type of the values.
|
||||
*
|
||||
* @param info MetaInfo about the DMatrix.
|
||||
* @param device The device id.
|
||||
* @param values Pointer to the inputs to sum.
|
||||
* @param size Number of values to sum.
|
||||
*/
|
||||
template <typename T>
|
||||
void GlobalSum(MetaInfo const& info, DeviceOrd device, T* values, size_t size) {
|
||||
template <typename T, std::int32_t kDim>
|
||||
[[nodiscard]] Result GlobalSum(Context const* ctx, MetaInfo const& info,
|
||||
linalg::TensorView<T, kDim> values) {
|
||||
if (info.IsRowSplit()) {
|
||||
collective::AllReduce<collective::Operation::kSum>(device.ordinal, values, size);
|
||||
return collective::Allreduce(ctx, values, collective::Op::kSum);
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -11,11 +11,44 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "allreduce.h"
|
||||
#include "broadcast.h"
|
||||
#include "comm.h"
|
||||
#include "communicator-inl.h"
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/data.h" // for MetaInfo
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace detail {
|
||||
template <typename Fn>
|
||||
[[nodiscard]] Result TryApplyWithLabels(Context const* ctx, Fn&& fn) {
|
||||
std::string msg;
|
||||
if (collective::GetRank() == 0) {
|
||||
try {
|
||||
fn();
|
||||
} catch (dmlc::Error const& e) {
|
||||
msg = e.what();
|
||||
}
|
||||
}
|
||||
std::size_t msg_size{msg.size()};
|
||||
auto rc = Success() << [&] {
|
||||
auto rc = collective::Broadcast(ctx, linalg::MakeVec(&msg_size, 1), 0);
|
||||
return rc;
|
||||
} << [&] {
|
||||
if (msg_size > 0) {
|
||||
msg.resize(msg_size);
|
||||
return collective::Broadcast(ctx, linalg::MakeVec(msg.data(), msg.size()), 0);
|
||||
}
|
||||
return Success();
|
||||
} << [&] {
|
||||
if (msg_size > 0) {
|
||||
LOG(FATAL) << msg;
|
||||
}
|
||||
return Success();
|
||||
};
|
||||
return rc;
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
/**
|
||||
* @brief Apply the given function where the labels are.
|
||||
@@ -30,29 +63,19 @@ namespace xgboost::collective {
|
||||
* @param size The size of the buffer.
|
||||
* @param function The function used to calculate the results.
|
||||
*/
|
||||
template <typename FN>
|
||||
void ApplyWithLabels(Context const*, MetaInfo const& info, void* buffer, std::size_t size,
|
||||
FN&& function) {
|
||||
template <typename Fn>
|
||||
void ApplyWithLabels(Context const* ctx, MetaInfo const& info, void* buffer, std::size_t size,
|
||||
Fn&& fn) {
|
||||
if (info.IsVerticalFederated()) {
|
||||
// We assume labels are only available on worker 0, so the calculation is done there and result
|
||||
// broadcast to other workers.
|
||||
std::string message;
|
||||
if (collective::GetRank() == 0) {
|
||||
try {
|
||||
std::forward<FN>(function)();
|
||||
} catch (dmlc::Error& e) {
|
||||
message = e.what();
|
||||
}
|
||||
}
|
||||
|
||||
collective::Broadcast(&message, 0);
|
||||
if (message.empty()) {
|
||||
collective::Broadcast(buffer, size, 0);
|
||||
} else {
|
||||
LOG(FATAL) << &message[0];
|
||||
}
|
||||
auto rc = detail::TryApplyWithLabels(ctx, fn) << [&] {
|
||||
// We assume labels are only available on worker 0, so the calculation is done there and
|
||||
// result broadcast to other workers.
|
||||
return collective::Broadcast(
|
||||
ctx, linalg::MakeVec(reinterpret_cast<std::int8_t*>(buffer), size), 0);
|
||||
};
|
||||
SafeColl(rc);
|
||||
} else {
|
||||
std::forward<FN>(function)();
|
||||
std::forward<Fn>(fn)();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,37 +92,24 @@ void ApplyWithLabels(Context const*, MetaInfo const& info, void* buffer, std::si
|
||||
* @param result The HostDeviceVector storing the results.
|
||||
* @param function The function used to calculate the results.
|
||||
*/
|
||||
template <typename T, typename Function>
|
||||
void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector<T>* result,
|
||||
Function&& function) {
|
||||
template <typename T, typename Fn>
|
||||
void ApplyWithLabels(Context const* ctx, MetaInfo const& info, HostDeviceVector<T>* result,
|
||||
Fn&& fn) {
|
||||
if (info.IsVerticalFederated()) {
|
||||
// We assume labels are only available on worker 0, so the calculation is done there and result
|
||||
// broadcast to other workers.
|
||||
std::string message;
|
||||
if (collective::GetRank() == 0) {
|
||||
try {
|
||||
std::forward<Function>(function)();
|
||||
} catch (dmlc::Error& e) {
|
||||
message = e.what();
|
||||
}
|
||||
}
|
||||
auto rc = detail::TryApplyWithLabels(ctx, fn);
|
||||
|
||||
collective::Broadcast(&message, 0);
|
||||
if (!message.empty()) {
|
||||
LOG(FATAL) << &message[0];
|
||||
return;
|
||||
}
|
||||
|
||||
std::size_t size{};
|
||||
if (collective::GetRank() == 0) {
|
||||
size = result->Size();
|
||||
}
|
||||
collective::Broadcast(&size, sizeof(std::size_t), 0);
|
||||
|
||||
result->Resize(size);
|
||||
collective::Broadcast(result->HostPointer(), size * sizeof(T), 0);
|
||||
std::size_t size{result->Size()};
|
||||
rc = std::move(rc) << [&] {
|
||||
return collective::Broadcast(ctx, linalg::MakeVec(&size, 1), 0);
|
||||
} << [&] {
|
||||
result->Resize(size);
|
||||
return collective::Broadcast(ctx, linalg::MakeVec(result->HostPointer(), size), 0);
|
||||
};
|
||||
SafeColl(rc);
|
||||
} else {
|
||||
std::forward<Function>(function)();
|
||||
std::forward<Fn>(fn)();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -115,11 +125,12 @@ void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector<T>*
|
||||
* @return The global max of the input.
|
||||
*/
|
||||
template <typename T>
|
||||
std::enable_if_t<std::is_trivially_copy_assignable_v<T>, T> GlobalMax(Context const*,
|
||||
std::enable_if_t<std::is_trivially_copy_assignable_v<T>, T> GlobalMax(Context const* ctx,
|
||||
MetaInfo const& info,
|
||||
T value) {
|
||||
if (info.IsRowSplit()) {
|
||||
collective::Allreduce<collective::Operation::kMax>(&value, 1);
|
||||
auto rc = collective::Allreduce(ctx, linalg::MakeVec(&value, 1), collective::Op::kMax);
|
||||
SafeColl(rc);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
@@ -136,19 +147,14 @@ std::enable_if_t<std::is_trivially_copy_assignable_v<T>, T> GlobalMax(Context co
|
||||
* @param size Number of values to sum.
|
||||
*/
|
||||
template <typename T, std::int32_t kDim>
|
||||
[[nodiscard]] Result GlobalSum(Context const*, MetaInfo const& info,
|
||||
[[nodiscard]] Result GlobalSum(Context const* ctx, MetaInfo const& info,
|
||||
linalg::TensorView<T, kDim> values) {
|
||||
if (info.IsRowSplit()) {
|
||||
collective::Allreduce<collective::Operation::kSum>(values.Values().data(), values.Size());
|
||||
return collective::Allreduce(ctx, values, collective::Op::kSum);
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
|
||||
template <typename Container>
|
||||
[[nodiscard]] Result GlobalSum(Context const* ctx, MetaInfo const& info, Container* values) {
|
||||
return GlobalSum(ctx, info, values->data(), values->size());
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Find the global ratio of the given two values across all workers.
|
||||
*
|
||||
|
||||
@@ -47,7 +47,7 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
|
||||
return comm.Block();
|
||||
};
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
return Fail("Ring allgather failed, current iteration:" + std::to_string(r), std::move(rc));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,7 +61,8 @@ Result BroadcastAllgatherV(Comm const& comm, common::Span<std::int64_t const> si
|
||||
auto as_bytes = sizes[r];
|
||||
auto rc = Broadcast(comm, recv.subspan(offset, as_bytes), r);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
return Fail("Broadcast AllgatherV failed, current iteration:" + std::to_string(r),
|
||||
std::move(rc));
|
||||
}
|
||||
offset += as_bytes;
|
||||
}
|
||||
@@ -102,7 +103,7 @@ namespace detail {
|
||||
return prev_ch->Block();
|
||||
};
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
return Fail("Ring AllgatherV failed, current iterataion:" + std::to_string(r), std::move(rc));
|
||||
}
|
||||
}
|
||||
return comm.Block();
|
||||
|
||||
@@ -36,7 +36,7 @@ Result RingAllreduceSmall(Comm const& comm, common::Span<std::int8_t> data, Func
|
||||
auto rc = RingAllgather(comm, typed);
|
||||
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
return Fail("Ring allreduce small failed.", std::move(rc));
|
||||
}
|
||||
auto first = s_buffer.subspan(0, data.size_bytes());
|
||||
CHECK_EQ(first.size(), data.size());
|
||||
@@ -64,7 +64,7 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
|
||||
auto next_ch = comm.Chan(dst_rank);
|
||||
auto prev_ch = comm.Chan(src_rank);
|
||||
|
||||
std::vector<std::int8_t> buffer(data.size_bytes() - (world - 1) * n_bytes_in_seg, 0);
|
||||
std::vector<std::int8_t> buffer(data.size_bytes() - (world - 1) * n_bytes_in_seg, -1);
|
||||
auto s_buf = common::Span{buffer.data(), buffer.size()};
|
||||
|
||||
for (std::int32_t r = 0; r < world - 1; ++r) {
|
||||
@@ -97,6 +97,10 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
|
||||
} << [&] {
|
||||
return comm.Block();
|
||||
};
|
||||
if (!rc.OK()) {
|
||||
return Fail("Ring scatter reduce failed, current iteration:" + std::to_string(r),
|
||||
std::move(rc));
|
||||
}
|
||||
|
||||
// accumulate to recv_seg
|
||||
CHECK_EQ(seg.size(), recv_seg.size());
|
||||
@@ -128,7 +132,7 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func cons
|
||||
auto n_bytes_in_seg = (n / world) * sizeof(T);
|
||||
auto rc = RingScatterReduceTyped<T>(comm, data, n_bytes_in_seg, op);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
return Fail("Ring Allreduce failed.", std::move(rc));
|
||||
}
|
||||
|
||||
auto prev = BootstrapPrev(comm.Rank(), comm.World());
|
||||
|
||||
@@ -150,9 +150,12 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
|
||||
}
|
||||
|
||||
auto rank = comm.Rank();
|
||||
auto n_bytes = worker->SendAll(&rank, sizeof(comm.Rank()));
|
||||
if (n_bytes != sizeof(comm.Rank())) {
|
||||
return Fail("Failed to send rank.");
|
||||
std::size_t n_bytes{0};
|
||||
auto rc = worker->SendAll(&rank, sizeof(comm.Rank()), &n_bytes);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
} else if (n_bytes != sizeof(comm.Rank())) {
|
||||
return Fail("Failed to send rank.", std::move(rc));
|
||||
}
|
||||
workers[r] = std::move(worker);
|
||||
}
|
||||
@@ -169,8 +172,11 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
|
||||
return rc;
|
||||
}
|
||||
std::int32_t rank{-1};
|
||||
auto n_bytes = peer->RecvAll(&rank, sizeof(rank));
|
||||
if (n_bytes != sizeof(comm.Rank())) {
|
||||
std::size_t n_bytes{0};
|
||||
auto rc = peer->RecvAll(&rank, sizeof(rank), &n_bytes);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
} else if (n_bytes != sizeof(comm.Rank())) {
|
||||
return Fail("Failed to recv rank.");
|
||||
}
|
||||
workers[rank] = std::move(peer);
|
||||
|
||||
@@ -94,7 +94,7 @@ class Comm : public std::enable_shared_from_this<Comm> {
|
||||
[[nodiscard]] bool IsDistributed() const noexcept { return world_ != -1; }
|
||||
/**
 * @brief Hand an operation to the communication loop.
 *
 * @param op Operation to submit; it is moved into the loop.
 */
void Submit(Loop::Op op) const {
  CHECK(loop_);
  // Submit exactly once, moving the by-value parameter instead of copying it.
  loop_->Submit(std::move(op));
}
|
||||
[[nodiscard]] virtual Result Block() const { return loop_->Block(); }
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ CommGroup::CommGroup()
|
||||
// Common args
|
||||
auto retry = get_param("dmlc_retry", static_cast<Integer::Int>(DefaultRetry()), Integer{});
|
||||
auto timeout =
|
||||
get_param("dmlc_timeout_sec", static_cast<Integer::Int>(DefaultTimeoutSec()), Integer{});
|
||||
get_param("dmlc_timeout", static_cast<Integer::Int>(DefaultTimeoutSec()), Integer{});
|
||||
auto task_id = get_param("dmlc_task_id", std::string{}, String{});
|
||||
|
||||
if (type == "rabit") {
|
||||
@@ -123,4 +123,30 @@ void GlobalCommGroupFinalize() {
|
||||
sptr.reset();
|
||||
SafeColl(rc);
|
||||
}
|
||||
|
||||
/** @brief Initialise the global communicator group; forwards to GlobalCommGroupInit. */
void Init(Json const& config) {
  GlobalCommGroupInit(config);
}
|
||||
|
||||
/** @brief Finalise the global communicator group; forwards to GlobalCommGroupFinalize. */
void Finalize() {
  GlobalCommGroupFinalize();
}
|
||||
|
||||
/** @brief Rank reported by the global communicator group. */
std::int32_t GetRank() noexcept {
  auto const& group = GlobalCommGroup();
  return group->Rank();
}
|
||||
|
||||
/** @brief World size reported by the global communicator group. */
std::int32_t GetWorldSize() noexcept {
  auto const& group = GlobalCommGroup();
  return group->World();
}
|
||||
|
||||
/** @brief Whether the global communicator group reports itself as distributed. */
bool IsDistributed() noexcept {
  auto const& group = GlobalCommGroup();
  return group->IsDistributed();
}
|
||||
|
||||
[[nodiscard]] bool IsFederated() {
|
||||
return GlobalCommGroup()->Ctx(nullptr, DeviceOrd::CPU()).IsFederated();
|
||||
}
|
||||
|
||||
void Print(std::string const& message) {
|
||||
auto rc = GlobalCommGroup()->Ctx(nullptr, DeviceOrd::CPU()).LogTracker(message);
|
||||
SafeColl(rc);
|
||||
}
|
||||
|
||||
std::string GetProcessorName() {
|
||||
std::string out;
|
||||
auto rc = GlobalCommGroup()->ProcessorName(&out);
|
||||
SafeColl(rc);
|
||||
return out;
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
/**
|
||||
* Copyright 2024, XGBoost contributors
|
||||
*/
|
||||
#include "communicator-inl.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
|
||||
std::vector<std::vector<char>> const &input) {
|
||||
auto n_inputs = input.size();
|
||||
std::vector<std::int64_t> sizes(n_inputs);
|
||||
std::transform(input.cbegin(), input.cend(), sizes.begin(),
|
||||
[](auto const &vec) { return vec.size(); });
|
||||
|
||||
std::vector<std::int64_t> global_sizes = AllgatherV(sizes);
|
||||
std::vector<std::int64_t> offset(global_sizes.size() + 1);
|
||||
offset[0] = 0;
|
||||
for (std::size_t i = 1; i < offset.size(); i++) {
|
||||
offset[i] = offset[i - 1] + global_sizes[i - 1];
|
||||
}
|
||||
|
||||
std::vector<char> collected;
|
||||
for (auto const &vec : input) {
|
||||
collected.insert(collected.end(), vec.cbegin(), vec.cend());
|
||||
}
|
||||
auto out = AllgatherV(collected);
|
||||
|
||||
std::vector<std::vector<char>> result;
|
||||
for (std::size_t i = 1; i < offset.size(); ++i) {
|
||||
std::vector<char> local(out.cbegin() + offset[i - 1], out.cbegin() + offset[i]);
|
||||
result.emplace_back(std::move(local));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
@@ -1,95 +0,0 @@
|
||||
/**
|
||||
* Copyright 2023 by XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "communicator.h"
|
||||
#include "device_communicator.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
/**
|
||||
* @brief Reduce values from all processes and distribute the result back to all processes.
|
||||
* @param device ID of the device.
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param count Number of elements in the buffer.
|
||||
*/
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, std::int8_t *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt8, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, std::uint8_t *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt8, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, std::int32_t *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt32, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, std::uint32_t *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt32, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, std::int64_t *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt64, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, std::uint64_t *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, float *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kFloat, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, double *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gather values from all all processes.
|
||||
*
|
||||
* This assumes all ranks have the same size.
|
||||
*
|
||||
* @param send_buffer Buffer storing the data to be sent.
|
||||
* @param receive_buffer Buffer storing the gathered data.
|
||||
* @param send_size Size of the sent data in bytes.
|
||||
*/
|
||||
inline void AllGather(int device, void const *send_buffer, void *receive_buffer,
|
||||
std::size_t send_size) {
|
||||
Communicator::GetDevice(device)->AllGather(send_buffer, receive_buffer, send_size);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gather variable-length values from all processes.
|
||||
* @param device ID of the device.
|
||||
* @param send_buffer Buffer storing the input data.
|
||||
* @param length_bytes Length in bytes of the input data.
|
||||
* @param segments Size of each segment.
|
||||
* @param receive_buffer Buffer storing the output data.
|
||||
*/
|
||||
inline void AllGatherV(int device, void const *send_buffer, size_t length_bytes,
|
||||
std::vector<size_t> *segments,
|
||||
dh::caching_device_vector<char> *receive_buffer) {
|
||||
Communicator::GetDevice(device)->AllGatherV(send_buffer, length_bytes, segments, receive_buffer);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Synchronize device operations.
|
||||
* @param device ID of the device.
|
||||
*/
|
||||
inline void Synchronize(int device) { Communicator::GetDevice(device)->Synchronize(); }
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@@ -3,308 +3,63 @@
|
||||
*/
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "communicator.h"
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
namespace xgboost::collective {
|
||||
/**
|
||||
* @brief Initialize the collective communicator.
|
||||
*/
|
||||
void Init(Json const& config);
|
||||
|
||||
/**
|
||||
* \brief Initialize the collective communicator.
|
||||
*
|
||||
* Currently the communicator API is experimental, function signatures may change in the future
|
||||
* without notice.
|
||||
*
|
||||
* Call this once before using anything.
|
||||
*
|
||||
* The additional configuration is not required. Usually the communicator will detect settings
|
||||
* from environment variables.
|
||||
*
|
||||
* \param json_config JSON encoded configuration. Accepted JSON keys are:
|
||||
* - xgboost_communicator: The type of the communicator. Can be set as an environment variable.
|
||||
* * rabit: Use Rabit. This is the default if the type is unspecified.
|
||||
* * mpi: Use MPI.
|
||||
* * federated: Use the gRPC interface for Federated Learning.
|
||||
* Only applicable to the Rabit communicator (these are case-sensitive):
|
||||
* - rabit_tracker_uri: Hostname of the tracker.
|
||||
* - rabit_tracker_port: Port number of the tracker.
|
||||
* - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
|
||||
* - rabit_world_size: Total number of workers.
|
||||
* - rabit_hadoop_mode: Enable Hadoop support.
|
||||
* - rabit_tree_reduce_minsize: Minimal size for tree reduce.
|
||||
* - rabit_reduce_ring_mincount: Minimal count to perform ring reduce.
|
||||
* - rabit_reduce_buffer: Size of the reduce buffer.
|
||||
* - rabit_bootstrap_cache: Size of the bootstrap cache.
|
||||
* - rabit_debug: Enable debugging.
|
||||
* - rabit_timeout: Enable timeout.
|
||||
* - rabit_timeout_sec: Timeout in seconds.
|
||||
* - rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms.
|
||||
* Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
|
||||
* environment variables):
|
||||
* - DMLC_TRACKER_URI: Hostname of the tracker.
|
||||
* - DMLC_TRACKER_PORT: Port number of the tracker.
|
||||
* - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment.
|
||||
* - DMLC_ROLE: Role of the current task, "worker" or "server".
|
||||
* - DMLC_NUM_ATTEMPT: Number of attempts after task failure.
|
||||
* - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
|
||||
* Only applicable to the Federated communicator (use upper case for environment variables, use
|
||||
* lower case for runtime configuration):
|
||||
* - federated_server_address: Address of the federated server.
|
||||
* - federated_world_size: Number of federated workers.
|
||||
* - federated_rank: Rank of the current worker.
|
||||
* - federated_server_cert: Server certificate file path. Only needed for the SSL mode.
|
||||
* - federated_client_key: Client key file path. Only needed for the SSL mode.
|
||||
* - federated_client_cert: Client certificate file path. Only needed for the SSL mode.
|
||||
*/
|
||||
inline void Init(Json const &config) { Communicator::Init(config); }
|
||||
|
||||
/*!
|
||||
* \brief Finalize the collective communicator.
|
||||
* @brief Finalize the collective communicator.
|
||||
*
|
||||
* Call this function after you finished all jobs.
|
||||
*/
|
||||
inline void Finalize() { Communicator::Finalize(); }
|
||||
void Finalize();
|
||||
|
||||
/*!
|
||||
* \brief Get rank of current process.
|
||||
/**
|
||||
* @brief Get rank of current process.
|
||||
*
|
||||
* \return Rank of the worker.
|
||||
* @return Rank of the worker.
|
||||
*/
|
||||
inline int GetRank() { return Communicator::Get()->GetRank(); }
|
||||
[[nodiscard]] std::int32_t GetRank() noexcept;
|
||||
|
||||
/*!
|
||||
* \brief Get total number of processes.
|
||||
/**
|
||||
* @brief Get total number of processes.
|
||||
*
|
||||
* \return Total world size.
|
||||
* @return Total world size.
|
||||
*/
|
||||
inline int GetWorldSize() { return Communicator::Get()->GetWorldSize(); }
|
||||
[[nodiscard]] std::int32_t GetWorldSize() noexcept;
|
||||
|
||||
/*!
|
||||
* \brief Get if the communicator is distributed.
|
||||
/**
|
||||
* @brief Get if the communicator is distributed.
|
||||
*
|
||||
* \return True if the communicator is distributed.
|
||||
* @return True if the communicator is distributed.
|
||||
*/
|
||||
inline bool IsDistributed() { return Communicator::Get()->IsDistributed(); }
|
||||
[[nodiscard]] bool IsDistributed() noexcept;
|
||||
|
||||
/*!
|
||||
* \brief Get if the communicator is federated.
|
||||
/**
|
||||
* @brief Get if the communicator is federated.
|
||||
*
|
||||
* \return True if the communicator is federated.
|
||||
* @return True if the communicator is federated.
|
||||
*/
|
||||
inline bool IsFederated() { return Communicator::Get()->IsFederated(); }
|
||||
[[nodiscard]] bool IsFederated();
|
||||
|
||||
/*!
|
||||
* \brief Print the message to the communicator.
|
||||
/**
|
||||
* @brief Print the message to the communicator.
|
||||
*
|
||||
* This function can be used to communicate the information of the progress to the user who monitors
|
||||
* the communicator.
|
||||
*
|
||||
* \param message The message to be printed.
|
||||
* @param message The message to be printed.
|
||||
*/
|
||||
inline void Print(char const *message) { Communicator::Get()->Print(message); }
|
||||
|
||||
inline void Print(std::string const &message) { Communicator::Get()->Print(message); }
|
||||
|
||||
/*!
|
||||
* \brief Get the name of the processor.
|
||||
*
|
||||
* \return Name of the processor.
|
||||
*/
|
||||
inline std::string GetProcessorName() { return Communicator::Get()->GetProcessorName(); }
|
||||
|
||||
/*!
|
||||
* \brief Broadcast a memory region to all others from root. This function is NOT thread-safe.
|
||||
*
|
||||
* Example:
|
||||
* int a = 1;
|
||||
* Broadcast(&a, sizeof(a), root);
|
||||
*
|
||||
* \param send_receive_buffer Pointer to the send or receive buffer.
|
||||
* \param size Size of the data.
|
||||
* \param root The process rank to broadcast from.
|
||||
*/
|
||||
inline void Broadcast(void *send_receive_buffer, size_t size, int root) {
|
||||
Communicator::Get()->Broadcast(send_receive_buffer, size, root);
|
||||
}
|
||||
|
||||
inline void Broadcast(std::string *sendrecv_data, int root) {
|
||||
size_t size = sendrecv_data->length();
|
||||
Broadcast(&size, sizeof(size), root);
|
||||
if (sendrecv_data->length() != size) {
|
||||
sendrecv_data->resize(size);
|
||||
}
|
||||
if (size != 0) {
|
||||
Broadcast(&(*sendrecv_data)[0], size * sizeof(char), root);
|
||||
}
|
||||
}
|
||||
|
||||
void Print(std::string const& message);
|
||||
/**
|
||||
* @brief Gathers a single value from all processes and distributes the result to all processes.
|
||||
* @brief Get the name of the processor.
|
||||
*
|
||||
* @param input The single value.
|
||||
* @return Name of the processor.
|
||||
*/
|
||||
template <typename T>
|
||||
inline std::vector<T> Allgather(T const &input) {
|
||||
std::string_view str_input{reinterpret_cast<char const *>(&input), sizeof(T)};
|
||||
auto const output = Communicator::Get()->AllGather(str_input);
|
||||
CHECK_EQ(output.size() % sizeof(T), 0);
|
||||
std::vector<T> result(output.size() / sizeof(T));
|
||||
std::memcpy(reinterpret_cast<void *>(result.data()), output.data(), output.size());
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gathers data from all processes and distributes it to all processes.
|
||||
*
|
||||
* This assumes all ranks have the same size.
|
||||
*
|
||||
* @param input Buffer storing the data.
|
||||
*/
|
||||
template <typename T>
|
||||
inline std::vector<T> Allgather(std::vector<T> const &input) {
|
||||
if (input.empty()) {
|
||||
return input;
|
||||
}
|
||||
std::string_view str_input{reinterpret_cast<char const *>(input.data()),
|
||||
input.size() * sizeof(T)};
|
||||
auto const output = Communicator::Get()->AllGather(str_input);
|
||||
CHECK_EQ(output.size() % sizeof(T), 0);
|
||||
std::vector<T> result(output.size() / sizeof(T));
|
||||
std::memcpy(reinterpret_cast<void *>(result.data()), output.data(), output.size());
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gathers variable-length data from all processes and distributes it to all processes.
|
||||
* @param input Buffer storing the data.
|
||||
*/
|
||||
template <typename T>
|
||||
inline std::vector<T> AllgatherV(std::vector<T> const &input) {
|
||||
std::string_view str_input{reinterpret_cast<char const *>(input.data()),
|
||||
input.size() * sizeof(T)};
|
||||
auto const output = Communicator::Get()->AllGatherV(str_input);
|
||||
CHECK_EQ(output.size() % sizeof(T), 0);
|
||||
std::vector<T> result(output.size() / sizeof(T));
|
||||
if (!output.empty()) {
|
||||
std::memcpy(reinterpret_cast<void *>(result.data()), output.data(), output.size());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gathers variable-length data from all processes and distributes it to all processes.
|
||||
*
|
||||
* @param inputs All the inputs from the local worker. The number of inputs can vary
|
||||
* across different workers. Along with which, the size of each vector in
|
||||
* the input can also vary.
|
||||
*
|
||||
* @return The AllgatherV result, containing vectors from all workers.
|
||||
*/
|
||||
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
|
||||
std::vector<std::vector<char>> const &input);
|
||||
|
||||
/**
|
||||
* @brief Gathers variable-length strings from all processes and distributes them to all processes.
|
||||
* @param input Variable-length list of variable-length strings.
|
||||
*/
|
||||
inline std::vector<std::string> AllgatherStrings(std::vector<std::string> const &input) {
|
||||
std::size_t total_size{0};
|
||||
for (auto const &s : input) {
|
||||
total_size += s.length() + 1; // +1 for null-terminators
|
||||
}
|
||||
std::string flat_string;
|
||||
flat_string.reserve(total_size);
|
||||
for (auto const &s : input) {
|
||||
flat_string.append(s);
|
||||
flat_string.push_back('\0'); // Append a null-terminator after each string
|
||||
}
|
||||
|
||||
auto const output = Communicator::Get()->AllGatherV(flat_string);
|
||||
|
||||
std::vector<std::string> result;
|
||||
std::size_t start_index = 0;
|
||||
// Iterate through the output, find each null-terminated substring.
|
||||
for (std::size_t i = 0; i < output.size(); i++) {
|
||||
if (output[i] == '\0') {
|
||||
// Construct a std::string from the char* substring
|
||||
result.emplace_back(&output[start_index]);
|
||||
// Move to the next substring
|
||||
start_index = i + 1;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Perform in-place allreduce. This function is NOT thread-safe.
|
||||
*
|
||||
* Example Usage: the following code gives sum of the result
|
||||
* vector<int> data(10);
|
||||
* ...
|
||||
* Allreduce(&data[0], data.size(), DataType:kInt32, Op::kSum);
|
||||
* ...
|
||||
* \param send_receive_buffer Buffer for both sending and receiving data.
|
||||
* \param count Number of elements to be reduced.
|
||||
* \param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h.
|
||||
* \param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
|
||||
*/
|
||||
inline void Allreduce(void *send_receive_buffer, size_t count, int data_type, int op) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, static_cast<DataType>(data_type),
|
||||
static_cast<Operation>(op));
|
||||
}
|
||||
|
||||
inline void Allreduce(void *send_receive_buffer, size_t count, DataType data_type, Operation op) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, data_type, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(int8_t *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt8, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(uint8_t *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt8, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(int32_t *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt32, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(uint32_t *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt32, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(int64_t *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt64, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(uint64_t *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
|
||||
}
|
||||
|
||||
// Specialization for size_t, which is implementation defined, so it might or might not
|
||||
// be one of uint64_t/uint32_t/unsigned long long/unsigned long.
|
||||
template <Operation op, typename T,
|
||||
typename = std::enable_if_t<std::is_same<size_t, T>{} && !std::is_same<uint64_t, T>{}> >
|
||||
inline void Allreduce(T *send_receive_buffer, size_t count) {
|
||||
static_assert(sizeof(T) == sizeof(uint64_t));
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(float *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kFloat, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(double *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
|
||||
}
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
std::string GetProcessorName();
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#include "communicator.h"
|
||||
|
||||
#include "comm.h"
|
||||
#include "in_memory_communicator.h"
|
||||
#include "noop_communicator.h"
|
||||
#include "rabit_communicator.h"
|
||||
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
#include "../../plugin/federated/federated_communicator.h"
|
||||
#endif
|
||||
|
||||
namespace xgboost::collective {
|
||||
// Per-thread host communicator; it is initialized with a NoOpCommunicator instance.
thread_local std::unique_ptr<Communicator> Communicator::communicator_{new NoOpCommunicator()};
|
||||
// Per-thread communicator type, value-initialized here and assigned in Communicator::Init.
thread_local CommunicatorType Communicator::type_{};
|
||||
// Per-thread NCCL path string, empty here and assigned in Communicator::Init.
thread_local std::string Communicator::nccl_path_{};
|
||||
|
||||
/**
 * @brief Create the communicator selected by the environment and/or `config`.
 *
 * The NCCL path is read from the optional `dmlc_nccl_path` entry. The communicator
 * type from `config` takes precedence over the one from the environment; when
 * neither specifies a type, Rabit is used.
 *
 * @param config JSON configuration for the communicator.
 */
void Communicator::Init(Json const& config) {
  nccl_path_ = OptionalArg<String>(config, "dmlc_nccl_path", std::string{DefaultNcclName()});

  // Query the environment first, then the runtime configuration.
  auto const from_env = GetTypeFromEnv();
  auto const from_config = GetTypeFromConfig(config);

  auto resolved = (from_config != CommunicatorType::kUnknown) ? from_config : from_env;
  if (resolved == CommunicatorType::kUnknown) {
    // Default to Rabit if unspecified.
    resolved = CommunicatorType::kRabit;
  }
  type_ = resolved;

  switch (resolved) {
    case CommunicatorType::kRabit:
      communicator_.reset(RabitCommunicator::Create(config));
      break;
    case CommunicatorType::kFederated:
#if defined(XGBOOST_USE_FEDERATED)
      communicator_.reset(FederatedCommunicator::Create(config));
#else
      LOG(FATAL) << "XGBoost is not compiled with Federated Learning support.";
#endif
      break;
    case CommunicatorType::kInMemory:
    case CommunicatorType::kInMemoryNccl:
      communicator_.reset(InMemoryCommunicator::Create(config));
      break;
    case CommunicatorType::kUnknown:
      LOG(FATAL) << "Unknown communicator type.";
  }
}
|
||||
|
||||
#ifndef XGBOOST_USE_CUDA
|
||||
void Communicator::Finalize() {
|
||||
communicator_->Shutdown();
|
||||
communicator_.reset(new NoOpCommunicator());
|
||||
}
|
||||
#endif
|
||||
} // namespace xgboost::collective
|
||||
@@ -1,54 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#include "communicator.h"
|
||||
#include "device_communicator.cuh"
|
||||
#include "device_communicator_adapter.cuh"
|
||||
#include "noop_communicator.h"
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
#include "nccl_device_communicator.cuh"
|
||||
#endif
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
// Per-thread device communicator; empty here, created in GetDevice() and cleared in Finalize().
thread_local std::unique_ptr<DeviceCommunicator> Communicator::device_communicator_{};
|
||||
|
||||
/**
 * @brief Shut down the host communicator and drop the device communicator.
 *
 * The host communicator is replaced with a NoOpCommunicator afterwards.
 */
void Communicator::Finalize() {
  communicator_->Shutdown();
  communicator_.reset(new NoOpCommunicator());
  device_communicator_.reset();
}
|
||||
|
||||
/**
 * @brief Get the device communicator for `device_ordinal`, creating it when needed.
 *
 * The existing instance is reused unless none exists yet, the device ordinal
 * differs from the previous call, or the world size differs from the previous call.
 *
 * @param device_ordinal ID of the device.
 * @return Pointer to the device communicator owned by this class.
 */
DeviceCommunicator* Communicator::GetDevice(int device_ordinal) {
  thread_local auto cached_device_ordinal = -1;
  // If the number of GPUs changes, we need to re-initialize NCCL.
  thread_local auto cached_world_size = -1;

  bool const needs_rebuild = !device_communicator_ ||
                             cached_device_ordinal != device_ordinal ||
                             cached_world_size != communicator_->GetWorldSize();
  if (!needs_rebuild) {
    return device_communicator_.get();
  }

  cached_device_ordinal = device_ordinal;
  cached_world_size = communicator_->GetWorldSize();

#ifdef XGBOOST_USE_NCCL
  bool const use_adapter =
      type_ == CommunicatorType::kFederated || type_ == CommunicatorType::kInMemory;
  if (use_adapter) {
    device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
  } else {
    // Only the in-memory NCCL type passes `true` as the second constructor argument;
    // every remaining type passes `false`.
    bool const in_memory_nccl = type_ == CommunicatorType::kInMemoryNccl;
    device_communicator_.reset(
        new NcclDeviceCommunicator(device_ordinal, in_memory_nccl, nccl_path_));
  }
#else
  device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
#endif

  return device_communicator_.get();
}
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@@ -1,247 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <xgboost/json.h>
|
||||
#include <xgboost/logging.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
/**
 * @brief Element types understood by the collective operations.
 *
 * The numeric values are spelled out explicitly.
 */
enum class DataType {
  kInt8 = 0,    // std::int8_t
  kUInt8 = 1,   // std::uint8_t
  kInt32 = 2,   // std::int32_t
  kUInt32 = 3,  // std::uint32_t
  kInt64 = 4,   // std::int64_t
  kUInt64 = 5,  // std::uint64_t
  kFloat = 6,   // float
  kDouble = 7   // double
};
|
||||
|
||||
/** @brief Get the size of the data type. */
|
||||
inline std::size_t GetTypeSize(DataType data_type) {
|
||||
std::size_t size{0};
|
||||
switch (data_type) {
|
||||
case DataType::kInt8:
|
||||
size = sizeof(std::int8_t);
|
||||
break;
|
||||
case DataType::kUInt8:
|
||||
size = sizeof(std::uint8_t);
|
||||
break;
|
||||
case DataType::kInt32:
|
||||
size = sizeof(std::int32_t);
|
||||
break;
|
||||
case DataType::kUInt32:
|
||||
size = sizeof(std::uint32_t);
|
||||
break;
|
||||
case DataType::kInt64:
|
||||
size = sizeof(std::int64_t);
|
||||
break;
|
||||
case DataType::kUInt64:
|
||||
size = sizeof(std::uint64_t);
|
||||
break;
|
||||
case DataType::kFloat:
|
||||
size = sizeof(float);
|
||||
break;
|
||||
case DataType::kDouble:
|
||||
size = sizeof(double);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown data type.";
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
/**
 * @brief Reduction operations accepted by the all-reduce calls.
 *
 * The numeric values are spelled out explicitly.
 */
enum class Operation {
  kMax = 0,         // maximum
  kMin = 1,         // minimum
  kSum = 2,         // sum
  kBitwiseAND = 3,  // bitwise AND
  kBitwiseOR = 4,   // bitwise OR
  kBitwiseXOR = 5   // bitwise XOR
};
|
||||
|
||||
class DeviceCommunicator;
|
||||
|
||||
/** @brief Kinds of communicator; `kUnknown` is the first (zero) enumerator. */
enum class CommunicatorType {
  kUnknown,
  kRabit,
  kFederated,
  kInMemory,
  kInMemoryNccl
};
|
||||
|
||||
/**
 * @brief Case-insensitive string comparison.
 *
 * Uses `_stricmp` when `_MSC_VER` is defined and `strcasecmp` otherwise.
 *
 * @param s1 First null-terminated string.
 * @param s2 Second null-terminated string.
 * @return The result of the underlying comparison function; zero means equal.
 */
inline int CompareStringsCaseInsensitive(const char *s1, const char *s2) {
#if !defined(_MSC_VER)
  return strcasecmp(s1, s2);
#else   // !defined(_MSC_VER)
  return _stricmp(s1, s2);
#endif  // !defined(_MSC_VER)
}
|
||||
|
||||
/**
|
||||
* @brief A communicator class that handles collective communication.
|
||||
*/
|
||||
class Communicator {
|
||||
public:
|
||||
/**
|
||||
* @brief Initialize the communicator. This can only be done once.
|
||||
*
|
||||
* @param config JSON configuration for the communicator.
|
||||
*/
|
||||
static void Init(Json const &config);
|
||||
|
||||
/** @brief Finalize the communicator. */
|
||||
static void Finalize();
|
||||
|
||||
/** @brief Get the communicator instance. */
|
||||
static Communicator *Get() { return communicator_.get(); }
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
/**
|
||||
* @brief Get the device communicator.
|
||||
*
|
||||
* @param device_ordinal ID of the device.
|
||||
* @return An instance of device communicator.
|
||||
*/
|
||||
static DeviceCommunicator *GetDevice(int device_ordinal);
|
||||
#endif
|
||||
|
||||
virtual ~Communicator() = default;
|
||||
|
||||
/** @brief Get the total number of processes. */
|
||||
int GetWorldSize() const { return world_size_; }
|
||||
|
||||
/** @brief Get the rank of the current processes. */
|
||||
int GetRank() const { return rank_; }
|
||||
|
||||
/** @brief Whether the communicator is running in distributed mode. */
|
||||
virtual bool IsDistributed() const = 0;
|
||||
|
||||
/** @brief Whether the communicator is running in federated mode. */
|
||||
virtual bool IsFederated() const = 0;
|
||||
|
||||
/**
|
||||
* @brief Gathers data from all processes and distributes it to all processes.
|
||||
*
|
||||
* This assumes all ranks have the same size.
|
||||
*
|
||||
* @param input Buffer storing the data.
|
||||
*/
|
||||
virtual std::string AllGather(std::string_view input) = 0;
|
||||
|
||||
/**
|
||||
* @brief Gathers variable-length data from all processes and distributes it to all processes.
|
||||
* @param input Buffer storing the data.
|
||||
*/
|
||||
virtual std::string AllGatherV(std::string_view input) = 0;
|
||||
|
||||
/**
|
||||
* @brief Combines values from all processes and distributes the result back to all processes.
|
||||
*
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param count Number of elements in the buffer.
|
||||
* @param data_type Data type stored in the buffer.
|
||||
* @param op The operation to perform.
|
||||
*/
|
||||
virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) = 0;
|
||||
|
||||
/**
|
||||
* @brief Broadcasts a message from the process with rank `root` to all other processes of the
|
||||
* group.
|
||||
*
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param size Size of the data in bytes.
|
||||
* @param root Rank of broadcast root.
|
||||
*/
|
||||
virtual void Broadcast(void *send_receive_buffer, std::size_t size, int root) = 0;
|
||||
|
||||
/**
|
||||
* @brief Gets the name of the processor.
|
||||
*/
|
||||
virtual std::string GetProcessorName() = 0;
|
||||
|
||||
/**
|
||||
* @brief Prints the message.
|
||||
*/
|
||||
virtual void Print(std::string const &message) = 0;
|
||||
|
||||
/** @brief Get the communicator type from environment variables. Visible for testing. */
|
||||
static CommunicatorType GetTypeFromEnv() {
|
||||
auto *env = std::getenv("XGBOOST_COMMUNICATOR");
|
||||
if (env != nullptr) {
|
||||
return StringToType(env);
|
||||
} else {
|
||||
return CommunicatorType::kUnknown;
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief Get the communicator type from runtime configuration. Visible for testing. */
|
||||
static CommunicatorType GetTypeFromConfig(Json const &config) {
|
||||
auto const &j_upper = config["XGBOOST_COMMUNICATOR"];
|
||||
if (IsA<String const>(j_upper)) {
|
||||
return StringToType(get<String const>(j_upper).c_str());
|
||||
}
|
||||
auto const &j_lower = config["xgboost_communicator"];
|
||||
if (IsA<String const>(j_lower)) {
|
||||
return StringToType(get<String const>(j_lower).c_str());
|
||||
}
|
||||
return CommunicatorType::kUnknown;
|
||||
}
|
||||
|
||||
protected:
|
||||
/**
|
||||
* @brief Construct a new communicator.
|
||||
*
|
||||
* @param world_size Total number of processes.
|
||||
* @param rank Rank of the current process.
|
||||
*/
|
||||
Communicator(int world_size, int rank) : world_size_(world_size), rank_(rank) {
|
||||
if (world_size < 1) {
|
||||
LOG(FATAL) << "World size " << world_size << " is less than 1.";
|
||||
}
|
||||
if (rank < 0) {
|
||||
LOG(FATAL) << "Rank " << rank << " is less than 0.";
|
||||
}
|
||||
if (rank >= world_size) {
|
||||
LOG(FATAL) << "Rank " << rank << " is greater than world_size - 1: " << world_size - 1 << ".";
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Shuts down the communicator.
|
||||
*/
|
||||
virtual void Shutdown() = 0;
|
||||
|
||||
private:
|
||||
static CommunicatorType StringToType(char const *str) {
|
||||
CommunicatorType result = CommunicatorType::kUnknown;
|
||||
if (!CompareStringsCaseInsensitive("rabit", str)) {
|
||||
result = CommunicatorType::kRabit;
|
||||
} else if (!CompareStringsCaseInsensitive("federated", str)) {
|
||||
result = CommunicatorType::kFederated;
|
||||
} else if (!CompareStringsCaseInsensitive("in-memory", str)) {
|
||||
result = CommunicatorType::kInMemory;
|
||||
} else if (!CompareStringsCaseInsensitive("in-memory-nccl", str)) {
|
||||
result = CommunicatorType::kInMemoryNccl;
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown communicator type " << str;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static thread_local std::unique_ptr<Communicator> communicator_;
|
||||
static thread_local CommunicatorType type_;
|
||||
static thread_local std::string nccl_path_;
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
static thread_local std::unique_ptr<DeviceCommunicator> device_communicator_;
|
||||
#endif
|
||||
|
||||
int const world_size_;
|
||||
int const rank_;
|
||||
};
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@@ -1,57 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <vector>
|
||||
|
||||
#include "../common/device_helpers.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
/**
|
||||
* @brief Collective communicator for device buffers.
|
||||
*/
|
||||
class DeviceCommunicator {
|
||||
public:
|
||||
virtual ~DeviceCommunicator() = default;
|
||||
|
||||
/**
|
||||
* @brief Combines values from all processes and distributes the result back to all processes.
|
||||
*
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param count Number of elements in the buffer.
|
||||
* @param data_type Data type stored in the buffer.
|
||||
* @param op The operation to perform.
|
||||
*/
|
||||
virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) = 0;
|
||||
|
||||
/**
|
||||
* @brief Gather values from all all processes.
|
||||
*
|
||||
* This assumes all ranks have the same size.
|
||||
*
|
||||
* @param send_buffer Buffer storing the data to be sent.
|
||||
* @param receive_buffer Buffer storing the gathered data.
|
||||
* @param send_size Size of the sent data in bytes.
|
||||
*/
|
||||
virtual void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) = 0;
|
||||
|
||||
/**
|
||||
* @brief Gather variable-length values from all processes.
|
||||
* @param send_buffer Buffer storing the input data.
|
||||
* @param length_bytes Length in bytes of the input data.
|
||||
* @param segments Size of each segment.
|
||||
* @param receive_buffer Buffer storing the output data.
|
||||
*/
|
||||
virtual void AllGatherV(void const *send_buffer, size_t length_bytes,
|
||||
std::vector<size_t> *segments,
|
||||
dh::caching_device_vector<char> *receive_buffer) = 0;
|
||||
|
||||
/** @brief Synchronize device operations. */
|
||||
virtual void Synchronize() = 0;
|
||||
};
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@@ -1,94 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <numeric> // for accumulate
|
||||
|
||||
#include "communicator.h"
|
||||
#include "device_communicator.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
||||
public:
|
||||
explicit DeviceCommunicatorAdapter(int device_ordinal)
|
||||
: device_ordinal_{device_ordinal}, world_size_{GetWorldSize()}, rank_{GetRank()} {
|
||||
if (device_ordinal_ < 0) {
|
||||
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
|
||||
}
|
||||
}
|
||||
|
||||
~DeviceCommunicatorAdapter() override = default;
|
||||
|
||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) override {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
auto size = count * GetTypeSize(data_type);
|
||||
host_buffer_.resize(size);
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
|
||||
Allreduce(host_buffer_.data(), count, data_type, op);
|
||||
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) override {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
host_buffer_.resize(send_size);
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_buffer, send_size, cudaMemcpyDefault));
|
||||
auto const output = Allgather(host_buffer_);
|
||||
dh::safe_cuda(cudaMemcpy(receive_buffer, output.data(), output.size(), cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
|
||||
dh::caching_device_vector<char> *receive_buffer) override {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
|
||||
segments->clear();
|
||||
segments->resize(world_size_, 0);
|
||||
segments->at(rank_) = length_bytes;
|
||||
Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
|
||||
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
|
||||
receive_buffer->resize(total_bytes);
|
||||
|
||||
host_buffer_.resize(total_bytes);
|
||||
size_t offset = 0;
|
||||
for (int32_t i = 0; i < world_size_; ++i) {
|
||||
size_t as_bytes = segments->at(i);
|
||||
if (i == rank_) {
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank_),
|
||||
cudaMemcpyDefault));
|
||||
}
|
||||
Broadcast(host_buffer_.data() + offset, as_bytes, i);
|
||||
offset += as_bytes;
|
||||
}
|
||||
dh::safe_cuda(cudaMemcpy(receive_buffer->data().get(), host_buffer_.data(), total_bytes,
|
||||
cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
void Synchronize() override {
|
||||
// Noop.
|
||||
}
|
||||
|
||||
private:
|
||||
int const device_ordinal_;
|
||||
int const world_size_;
|
||||
int const rank_;
|
||||
/// Host buffer used to call communicator functions.
|
||||
std::vector<char> host_buffer_{};
|
||||
};
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@@ -1,12 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#include "in_memory_communicator.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
// Out-of-class definition of the static `handler_` member: a single value-initialized
// InMemoryHandler shared by every InMemoryCommunicator instance.
InMemoryHandler InMemoryCommunicator::handler_{};
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@@ -15,14 +15,14 @@ namespace collective {
|
||||
/**
|
||||
* An in-memory communicator, useful for testing.
|
||||
*/
|
||||
class InMemoryCommunicator : public Communicator {
|
||||
class InMemoryCommunicator {
|
||||
public:
|
||||
/**
|
||||
* @brief Create a new communicator based on JSON configuration.
|
||||
* @param config JSON configuration.
|
||||
* @return Communicator as specified by the JSON configuration.
|
||||
*/
|
||||
static Communicator* Create(Json const& config) {
|
||||
static InMemoryCommunicator* Create(Json const& config) {
|
||||
int world_size{0};
|
||||
int rank{-1};
|
||||
|
||||
@@ -51,7 +51,7 @@ class InMemoryCommunicator : public Communicator {
|
||||
return new InMemoryCommunicator(world_size, rank);
|
||||
}
|
||||
|
||||
InMemoryCommunicator(int world_size, int rank) : Communicator(world_size, rank) {
|
||||
InMemoryCommunicator(int world_size, int rank) {
|
||||
handler_.Init(world_size, rank);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
*/
|
||||
#include "in_memory_handler.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include "comm.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
namespace xgboost::collective {
|
||||
/**
|
||||
* @brief Functor for allgather.
|
||||
*/
|
||||
@@ -16,7 +15,7 @@ class AllgatherFunctor {
|
||||
public:
|
||||
std::string const name{"Allgather"};
|
||||
|
||||
AllgatherFunctor(std::size_t world_size, std::size_t rank)
|
||||
AllgatherFunctor(std::int32_t world_size, std::int32_t rank)
|
||||
: world_size_{world_size}, rank_{rank} {}
|
||||
|
||||
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
|
||||
@@ -30,8 +29,8 @@ class AllgatherFunctor {
|
||||
}
|
||||
|
||||
private:
|
||||
std::size_t world_size_;
|
||||
std::size_t rank_;
|
||||
std::int32_t world_size_;
|
||||
std::int32_t rank_;
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -41,13 +40,13 @@ class AllgatherVFunctor {
|
||||
public:
|
||||
std::string const name{"AllgatherV"};
|
||||
|
||||
AllgatherVFunctor(std::size_t world_size, std::size_t rank,
|
||||
AllgatherVFunctor(std::int32_t world_size, std::int32_t rank,
|
||||
std::map<std::size_t, std::string_view>* data)
|
||||
: world_size_{world_size}, rank_{rank}, data_{data} {}
|
||||
|
||||
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
|
||||
data_->emplace(rank_, std::string_view{input, bytes});
|
||||
if (data_->size() == world_size_) {
|
||||
if (data_->size() == static_cast<std::size_t>(world_size_)) {
|
||||
for (auto const& kv : *data_) {
|
||||
buffer->append(kv.second);
|
||||
}
|
||||
@@ -56,8 +55,8 @@ class AllgatherVFunctor {
|
||||
}
|
||||
|
||||
private:
|
||||
std::size_t world_size_;
|
||||
std::size_t rank_;
|
||||
std::int32_t world_size_;
|
||||
std::int32_t rank_;
|
||||
std::map<std::size_t, std::string_view>* data_;
|
||||
};
|
||||
|
||||
@@ -68,7 +67,7 @@ class AllreduceFunctor {
|
||||
public:
|
||||
std::string const name{"Allreduce"};
|
||||
|
||||
AllreduceFunctor(DataType dataType, Operation operation)
|
||||
AllreduceFunctor(ArrayInterfaceHandler::Type dataType, Op operation)
|
||||
: data_type_{dataType}, operation_{operation} {}
|
||||
|
||||
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
|
||||
@@ -76,23 +75,23 @@ class AllreduceFunctor {
|
||||
// Copy the input if this is the first request.
|
||||
buffer->assign(input, bytes);
|
||||
} else {
|
||||
auto n_bytes_type = DispatchDType(data_type_, [](auto t) { return sizeof(t); });
|
||||
// Apply the reduce_operation to the input and the buffer.
|
||||
Accumulate(input, bytes / GetTypeSize(data_type_), &buffer->front());
|
||||
Accumulate(input, bytes / n_bytes_type, &buffer->front());
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
template <class T, std::enable_if_t<std::is_integral<T>::value>* = nullptr>
|
||||
void AccumulateBitwise(T* buffer, T const* input, std::size_t size,
|
||||
Operation reduce_operation) const {
|
||||
void AccumulateBitwise(T* buffer, T const* input, std::size_t size, Op reduce_operation) const {
|
||||
switch (reduce_operation) {
|
||||
case Operation::kBitwiseAND:
|
||||
case Op::kBitwiseAND:
|
||||
std::transform(buffer, buffer + size, input, buffer, std::bit_and<T>());
|
||||
break;
|
||||
case Operation::kBitwiseOR:
|
||||
case Op::kBitwiseOR:
|
||||
std::transform(buffer, buffer + size, input, buffer, std::bit_or<T>());
|
||||
break;
|
||||
case Operation::kBitwiseXOR:
|
||||
case Op::kBitwiseXOR:
|
||||
std::transform(buffer, buffer + size, input, buffer, std::bit_xor<T>());
|
||||
break;
|
||||
default:
|
||||
@@ -101,27 +100,27 @@ class AllreduceFunctor {
|
||||
}
|
||||
|
||||
template <class T, std::enable_if_t<std::is_floating_point<T>::value>* = nullptr>
|
||||
void AccumulateBitwise(T*, T const*, std::size_t, Operation) const {
|
||||
void AccumulateBitwise(T*, T const*, std::size_t, Op) const {
|
||||
LOG(FATAL) << "Floating point types do not support bitwise operations.";
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void Accumulate(T* buffer, T const* input, std::size_t size, Operation reduce_operation) const {
|
||||
void Accumulate(T* buffer, T const* input, std::size_t size, Op reduce_operation) const {
|
||||
switch (reduce_operation) {
|
||||
case Operation::kMax:
|
||||
case Op::kMax:
|
||||
std::transform(buffer, buffer + size, input, buffer,
|
||||
[](T a, T b) { return std::max(a, b); });
|
||||
break;
|
||||
case Operation::kMin:
|
||||
case Op::kMin:
|
||||
std::transform(buffer, buffer + size, input, buffer,
|
||||
[](T a, T b) { return std::min(a, b); });
|
||||
break;
|
||||
case Operation::kSum:
|
||||
case Op::kSum:
|
||||
std::transform(buffer, buffer + size, input, buffer, std::plus<T>());
|
||||
break;
|
||||
case Operation::kBitwiseAND:
|
||||
case Operation::kBitwiseOR:
|
||||
case Operation::kBitwiseXOR:
|
||||
case Op::kBitwiseAND:
|
||||
case Op::kBitwiseOR:
|
||||
case Op::kBitwiseXOR:
|
||||
AccumulateBitwise(buffer, input, size, reduce_operation);
|
||||
break;
|
||||
default:
|
||||
@@ -130,36 +129,37 @@ class AllreduceFunctor {
|
||||
}
|
||||
|
||||
void Accumulate(char const* input, std::size_t size, char* buffer) const {
|
||||
using Type = ArrayInterfaceHandler::Type;
|
||||
switch (data_type_) {
|
||||
case DataType::kInt8:
|
||||
case Type::kI1:
|
||||
Accumulate(reinterpret_cast<std::int8_t*>(buffer),
|
||||
reinterpret_cast<std::int8_t const*>(input), size, operation_);
|
||||
break;
|
||||
case DataType::kUInt8:
|
||||
case Type::kU1:
|
||||
Accumulate(reinterpret_cast<std::uint8_t*>(buffer),
|
||||
reinterpret_cast<std::uint8_t const*>(input), size, operation_);
|
||||
break;
|
||||
case DataType::kInt32:
|
||||
case Type::kI4:
|
||||
Accumulate(reinterpret_cast<std::int32_t*>(buffer),
|
||||
reinterpret_cast<std::int32_t const*>(input), size, operation_);
|
||||
break;
|
||||
case DataType::kUInt32:
|
||||
case Type::kU4:
|
||||
Accumulate(reinterpret_cast<std::uint32_t*>(buffer),
|
||||
reinterpret_cast<std::uint32_t const*>(input), size, operation_);
|
||||
break;
|
||||
case DataType::kInt64:
|
||||
case Type::kI8:
|
||||
Accumulate(reinterpret_cast<std::int64_t*>(buffer),
|
||||
reinterpret_cast<std::int64_t const*>(input), size, operation_);
|
||||
break;
|
||||
case DataType::kUInt64:
|
||||
case Type::kU8:
|
||||
Accumulate(reinterpret_cast<std::uint64_t*>(buffer),
|
||||
reinterpret_cast<std::uint64_t const*>(input), size, operation_);
|
||||
break;
|
||||
case DataType::kFloat:
|
||||
case Type::kF4:
|
||||
Accumulate(reinterpret_cast<float*>(buffer), reinterpret_cast<float const*>(input), size,
|
||||
operation_);
|
||||
break;
|
||||
case DataType::kDouble:
|
||||
case Type::kF8:
|
||||
Accumulate(reinterpret_cast<double*>(buffer), reinterpret_cast<double const*>(input), size,
|
||||
operation_);
|
||||
break;
|
||||
@@ -169,8 +169,8 @@ class AllreduceFunctor {
|
||||
}
|
||||
|
||||
private:
|
||||
DataType data_type_;
|
||||
Operation operation_;
|
||||
ArrayInterfaceHandler::Type data_type_;
|
||||
Op operation_;
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -180,7 +180,7 @@ class BroadcastFunctor {
|
||||
public:
|
||||
std::string const name{"Broadcast"};
|
||||
|
||||
BroadcastFunctor(std::size_t rank, std::size_t root) : rank_{rank}, root_{root} {}
|
||||
BroadcastFunctor(std::int32_t rank, std::int32_t root) : rank_{rank}, root_{root} {}
|
||||
|
||||
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
|
||||
if (rank_ == root_) {
|
||||
@@ -190,11 +190,11 @@ class BroadcastFunctor {
|
||||
}
|
||||
|
||||
private:
|
||||
std::size_t rank_;
|
||||
std::size_t root_;
|
||||
std::int32_t rank_;
|
||||
std::int32_t root_;
|
||||
};
|
||||
|
||||
void InMemoryHandler::Init(std::size_t world_size, std::size_t) {
|
||||
void InMemoryHandler::Init(std::int32_t world_size, std::int32_t) {
|
||||
CHECK(world_size_ < world_size) << "In memory handler already initialized.";
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
@@ -204,7 +204,7 @@ void InMemoryHandler::Init(std::size_t world_size, std::size_t) {
|
||||
cv_.notify_all();
|
||||
}
|
||||
|
||||
void InMemoryHandler::Shutdown(uint64_t sequence_number, std::size_t) {
|
||||
void InMemoryHandler::Shutdown(uint64_t sequence_number, std::int32_t) {
|
||||
CHECK(world_size_ > 0) << "In memory handler already shutdown.";
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
@@ -220,29 +220,29 @@ void InMemoryHandler::Shutdown(uint64_t sequence_number, std::size_t) {
|
||||
}
|
||||
|
||||
void InMemoryHandler::Allgather(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, std::size_t rank) {
|
||||
std::size_t sequence_number, std::int32_t rank) {
|
||||
Handle(input, bytes, output, sequence_number, rank, AllgatherFunctor{world_size_, rank});
|
||||
}
|
||||
|
||||
void InMemoryHandler::AllgatherV(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, std::size_t rank) {
|
||||
std::size_t sequence_number, std::int32_t rank) {
|
||||
Handle(input, bytes, output, sequence_number, rank, AllgatherVFunctor{world_size_, rank, &aux_});
|
||||
}
|
||||
|
||||
void InMemoryHandler::Allreduce(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, std::size_t rank, DataType data_type,
|
||||
Operation op) {
|
||||
std::size_t sequence_number, std::int32_t rank,
|
||||
ArrayInterfaceHandler::Type data_type, Op op) {
|
||||
Handle(input, bytes, output, sequence_number, rank, AllreduceFunctor{data_type, op});
|
||||
}
|
||||
|
||||
void InMemoryHandler::Broadcast(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, std::size_t rank, std::size_t root) {
|
||||
std::size_t sequence_number, std::int32_t rank, std::int32_t root) {
|
||||
Handle(input, bytes, output, sequence_number, rank, BroadcastFunctor{rank, root});
|
||||
}
|
||||
|
||||
template <class HandlerFunctor>
|
||||
void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, std::size_t rank,
|
||||
std::size_t sequence_number, std::int32_t rank,
|
||||
HandlerFunctor const& functor) {
|
||||
// Pass through if there is only 1 client.
|
||||
if (world_size_ == 1) {
|
||||
@@ -287,5 +287,4 @@ void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string*
|
||||
cv_.notify_all();
|
||||
}
|
||||
}
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <condition_variable>
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
#include "communicator.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
#include "../data/array_interface.h"
|
||||
#include "comm.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
/**
|
||||
* @brief Handles collective communication primitives in memory.
|
||||
*
|
||||
@@ -28,12 +27,11 @@ class InMemoryHandler {
|
||||
|
||||
/**
|
||||
* @brief Construct a handler with the given world size.
|
||||
* @param world_size Number of workers.
|
||||
* @param world Number of workers.
|
||||
*
|
||||
* This is used when the handler only needs to be initialized once with a known world size.
|
||||
*/
|
||||
explicit InMemoryHandler(std::int32_t worldSize)
|
||||
: world_size_{static_cast<std::size_t>(worldSize)} {}
|
||||
explicit InMemoryHandler(std::int32_t world) : world_size_{world} {}
|
||||
|
||||
/**
|
||||
* @brief Initialize the handler with the world size and rank.
|
||||
@@ -43,7 +41,7 @@ class InMemoryHandler {
|
||||
* This is used when multiple objects/threads are accessing the same handler and need to
|
||||
* initialize it collectively.
|
||||
*/
|
||||
void Init(std::size_t world_size, std::size_t rank);
|
||||
void Init(std::int32_t world_size, std::int32_t rank);
|
||||
|
||||
/**
|
||||
* @brief Shut down the handler.
|
||||
@@ -53,7 +51,7 @@ class InMemoryHandler {
|
||||
* This is used when multiple objects/threads are accessing the same handler and need to
|
||||
* shut it down collectively.
|
||||
*/
|
||||
void Shutdown(uint64_t sequence_number, std::size_t rank);
|
||||
void Shutdown(uint64_t sequence_number, std::int32_t rank);
|
||||
|
||||
/**
|
||||
* @brief Perform allgather.
|
||||
@@ -64,7 +62,7 @@ class InMemoryHandler {
|
||||
* @param rank Index of the worker.
|
||||
*/
|
||||
void Allgather(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, std::size_t rank);
|
||||
std::size_t sequence_number, std::int32_t rank);
|
||||
|
||||
/**
|
||||
* @brief Perform variable-length allgather.
|
||||
@@ -75,7 +73,7 @@ class InMemoryHandler {
|
||||
* @param rank Index of the worker.
|
||||
*/
|
||||
void AllgatherV(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, std::size_t rank);
|
||||
std::size_t sequence_number, std::int32_t rank);
|
||||
|
||||
/**
|
||||
* @brief Perform allreduce.
|
||||
@@ -88,7 +86,8 @@ class InMemoryHandler {
|
||||
* @param op The reduce operation.
|
||||
*/
|
||||
void Allreduce(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, std::size_t rank, DataType data_type, Operation op);
|
||||
std::size_t sequence_number, std::int32_t rank,
|
||||
ArrayInterfaceHandler::Type data_type, Op op);
|
||||
|
||||
/**
|
||||
* @brief Perform broadcast.
|
||||
@@ -100,7 +99,7 @@ class InMemoryHandler {
|
||||
* @param root Index of the worker to broadcast from.
|
||||
*/
|
||||
void Broadcast(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, std::size_t rank, std::size_t root);
|
||||
std::size_t sequence_number, std::int32_t rank, std::int32_t root);
|
||||
|
||||
private:
|
||||
/**
|
||||
@@ -115,17 +114,15 @@ class InMemoryHandler {
|
||||
*/
|
||||
template <class HandlerFunctor>
|
||||
void Handle(char const* input, std::size_t size, std::string* output, std::size_t sequence_number,
|
||||
std::size_t rank, HandlerFunctor const& functor);
|
||||
std::int32_t rank, HandlerFunctor const& functor);
|
||||
|
||||
std::size_t world_size_{}; /// Number of workers.
|
||||
std::size_t received_{}; /// Number of calls received with the current sequence.
|
||||
std::size_t sent_{}; /// Number of calls completed with the current sequence.
|
||||
std::int32_t world_size_{}; /// Number of workers.
|
||||
std::int64_t received_{}; /// Number of calls received with the current sequence.
|
||||
std::int64_t sent_{}; /// Number of calls completed with the current sequence.
|
||||
std::string buffer_{}; /// A shared common buffer.
|
||||
std::map<std::size_t, std::string_view> aux_{}; /// A shared auxiliary map.
|
||||
uint64_t sequence_number_{}; /// Call sequence number.
|
||||
mutable std::mutex mutex_; /// Lock.
|
||||
mutable std::condition_variable cv_; /// Conditional variable to wait on.
|
||||
};
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int32_t
|
||||
#include <exception> // for exception, current_exception, rethrow_exception
|
||||
#include <future> // for promise
|
||||
#include <memory> // for make_shared
|
||||
#include <mutex> // for lock_guard, unique_lock
|
||||
#include <queue> // for queue
|
||||
#include <string> // for string
|
||||
@@ -18,9 +20,10 @@
|
||||
#include "xgboost/logging.h" // for CHECK
|
||||
|
||||
namespace xgboost::collective {
|
||||
Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
|
||||
Result Loop::ProcessQueue(std::queue<Op>* p_queue) const {
|
||||
timer_.Start(__func__);
|
||||
auto error = [this] {
|
||||
auto error = [this](Op op) {
|
||||
op.pr->set_value();
|
||||
timer_.Stop(__func__);
|
||||
};
|
||||
|
||||
@@ -38,7 +41,7 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
|
||||
|
||||
// Iterate through all the ops for poll
|
||||
for (std::size_t i = 0; i < n_ops; ++i) {
|
||||
auto op = qcopy.front();
|
||||
auto op = std::move(qcopy.front());
|
||||
qcopy.pop();
|
||||
|
||||
switch (op.code) {
|
||||
@@ -54,12 +57,12 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
error();
|
||||
error(op);
|
||||
return Fail("Invalid socket operation.");
|
||||
}
|
||||
}
|
||||
|
||||
qcopy.push(op);
|
||||
qcopy.push(std::move(op));
|
||||
}
|
||||
|
||||
// poll, work on fds that are ready.
|
||||
@@ -67,18 +70,18 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
|
||||
if (!poll.fds.empty()) {
|
||||
auto rc = poll.Poll(timeout_);
|
||||
if (!rc.OK()) {
|
||||
error();
|
||||
timer_.Stop(__func__);
|
||||
return rc;
|
||||
}
|
||||
}
|
||||
timer_.Stop("poll");
|
||||
|
||||
// we wonldn't be here if the queue is empty.
|
||||
// We wouldn't be here if the queue is empty.
|
||||
CHECK(!qcopy.empty());
|
||||
|
||||
// Iterate through all the ops for performing the operations
|
||||
for (std::size_t i = 0; i < n_ops; ++i) {
|
||||
auto op = qcopy.front();
|
||||
auto op = std::move(qcopy.front());
|
||||
qcopy.pop();
|
||||
|
||||
std::int32_t n_bytes_done{0};
|
||||
@@ -93,8 +96,9 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
|
||||
if (poll.CheckRead(*op.sock)) {
|
||||
n_bytes_done = op.sock->Recv(op.ptr + op.off, op.n - op.off);
|
||||
if (n_bytes_done == 0) {
|
||||
error();
|
||||
return Fail("Encountered EOF. The other end is likely closed.");
|
||||
error(op);
|
||||
return Fail("Encountered EOF. The other end is likely closed.",
|
||||
op.sock->GetSockError());
|
||||
}
|
||||
}
|
||||
break;
|
||||
@@ -112,14 +116,14 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
error();
|
||||
error(op);
|
||||
return Fail("Invalid socket operation.");
|
||||
}
|
||||
}
|
||||
|
||||
if (n_bytes_done == -1 && !system::LastErrorWouldBlock()) {
|
||||
auto rc = system::FailWithCode("Invalid socket output.");
|
||||
error();
|
||||
error(op);
|
||||
return rc;
|
||||
}
|
||||
|
||||
@@ -127,14 +131,12 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
|
||||
CHECK_LE(op.off, op.n);
|
||||
|
||||
if (op.off != op.n) {
|
||||
// not yet finished, push back to queue for next round.
|
||||
// not yet finished, push back to queue for the next round.
|
||||
qcopy.push(op);
|
||||
} else {
|
||||
op.pr->set_value();
|
||||
}
|
||||
}
|
||||
|
||||
if (!blocking) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
timer_.Stop(__func__);
|
||||
@@ -148,8 +150,7 @@ void Loop::Process() {
|
||||
};
|
||||
|
||||
// This loop cannot exit unless `stop_` is set to true. There must always be a thread to
|
||||
// answer the blocking call even if there are errors, otherwise the blocking will wait
|
||||
// forever.
|
||||
// answer the call even if there are errors.
|
||||
while (true) {
|
||||
try {
|
||||
std::unique_lock lock{mu_};
|
||||
@@ -170,44 +171,15 @@ void Loop::Process() {
|
||||
// Move the global queue into a local variable to unblock it.
|
||||
std::queue<Op> qcopy;
|
||||
|
||||
bool is_blocking = false;
|
||||
while (!queue_.empty()) {
|
||||
auto op = queue_.front();
|
||||
auto op = std::move(queue_.front());
|
||||
queue_.pop();
|
||||
if (op.code == Op::kBlock) {
|
||||
is_blocking = true;
|
||||
} else {
|
||||
qcopy.push(op);
|
||||
}
|
||||
qcopy.push(op);
|
||||
}
|
||||
|
||||
lock.unlock();
|
||||
// Clear the local queue, if `is_blocking` is true, this is blocking the current
|
||||
// worker thread (but not the client thread), wait until all operations are
|
||||
// finished.
|
||||
auto rc = this->ProcessQueue(&qcopy, is_blocking);
|
||||
|
||||
if (is_blocking && rc.OK()) {
|
||||
CHECK(qcopy.empty());
|
||||
}
|
||||
// Push back the remaining operations.
|
||||
if (rc.OK()) {
|
||||
std::unique_lock lock{mu_};
|
||||
while (!qcopy.empty()) {
|
||||
queue_.push(qcopy.front());
|
||||
qcopy.pop();
|
||||
}
|
||||
}
|
||||
|
||||
// Notify the client thread who called block after all error conditions are set.
|
||||
auto notify_if_block = [&] {
|
||||
if (is_blocking) {
|
||||
std::unique_lock lock{mu_};
|
||||
block_done_ = true;
|
||||
lock.unlock();
|
||||
block_cv_.notify_one();
|
||||
}
|
||||
};
|
||||
// Clear the local queue.
|
||||
auto rc = this->ProcessQueue(&qcopy);
|
||||
|
||||
// Handle error
|
||||
if (!rc.OK()) {
|
||||
@@ -215,8 +187,6 @@ void Loop::Process() {
|
||||
} else {
|
||||
CHECK(qcopy.empty());
|
||||
}
|
||||
|
||||
notify_if_block();
|
||||
} catch (std::exception const& e) {
|
||||
curr_exce_ = std::current_exception();
|
||||
set_rc(Fail("Exception inside the event loop:" + std::string{e.what()}));
|
||||
@@ -256,20 +226,28 @@ Result Loop::Stop() {
|
||||
stop_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!this->worker_.joinable()) {
|
||||
std::lock_guard<std::mutex> guard{rc_lock_};
|
||||
return Fail("Worker has stopped.", std::move(rc_));
|
||||
}
|
||||
|
||||
this->Submit(Op{Op::kBlock});
|
||||
{
|
||||
// Wait for the block call to finish.
|
||||
std::unique_lock lock{mu_};
|
||||
block_cv_.wait(lock, [this] { return block_done_ || stop_; });
|
||||
block_done_ = false;
|
||||
cv_.notify_one();
|
||||
}
|
||||
|
||||
for (auto& fut : futures_) {
|
||||
if (fut.valid()) {
|
||||
try {
|
||||
fut.get();
|
||||
} catch (std::future_error const&) {
|
||||
// Do nothing. If something went wrong in the worker, we have a std::future_error
|
||||
// due to broken promise. This function will transfer the rc back to the caller.
|
||||
}
|
||||
}
|
||||
}
|
||||
futures_.clear();
|
||||
|
||||
{
|
||||
// Transfer the rc.
|
||||
std::lock_guard<std::mutex> lock{rc_lock_};
|
||||
@@ -278,13 +256,13 @@ Result Loop::Stop() {
|
||||
}
|
||||
|
||||
void Loop::Submit(Op op) {
|
||||
auto p = std::make_shared<std::promise<void>>();
|
||||
op.pr = std::move(p);
|
||||
futures_.emplace_back(op.pr->get_future());
|
||||
CHECK_NE(op.n, 0);
|
||||
|
||||
std::unique_lock lock{mu_};
|
||||
if (op.code != Op::kBlock) {
|
||||
CHECK_NE(op.n, 0);
|
||||
}
|
||||
queue_.push(op);
|
||||
lock.unlock();
|
||||
cv_.notify_one();
|
||||
}
|
||||
|
||||
Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} {
|
||||
|
||||
@@ -7,9 +7,12 @@
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int8_t, int32_t
|
||||
#include <exception> // for exception_ptr
|
||||
#include <mutex> // for unique_lock, mutex
|
||||
#include <future> // for future
|
||||
#include <memory> // for shared_ptr
|
||||
#include <mutex> // for mutex
|
||||
#include <queue> // for queue
|
||||
#include <thread> // for thread
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../common/timer.h" // for Monitor
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
@@ -20,14 +23,15 @@ class Loop {
|
||||
public:
|
||||
struct Op {
|
||||
// kSleep is only for testing
|
||||
enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2, kSleep = 4 } code;
|
||||
enum Code : std::int8_t { kRead = 0, kWrite = 1, kSleep = 3 } code;
|
||||
std::int32_t rank{-1};
|
||||
std::int8_t* ptr{nullptr};
|
||||
std::size_t n{0};
|
||||
TCPSocket* sock{nullptr};
|
||||
std::size_t off{0};
|
||||
std::shared_ptr<std::promise<void>> pr;
|
||||
|
||||
explicit Op(Code c) : code{c} { CHECK(c == kBlock || c == kSleep); }
|
||||
explicit Op(Code c) : code{c} { CHECK(c == kSleep); }
|
||||
Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off)
|
||||
: code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {}
|
||||
Op(Op const&) = default;
|
||||
@@ -45,12 +49,11 @@ class Loop {
|
||||
private:
|
||||
std::thread worker_; // thread worker to execute the tasks
|
||||
|
||||
std::condition_variable cv_; // CV used to notify a new submit call
|
||||
std::condition_variable block_cv_; // CV used to notify the blocking call
|
||||
bool block_done_{false}; // Flag to indicate whether the blocking call has finished.
|
||||
std::condition_variable cv_; // CV used to notify a new submit call
|
||||
|
||||
std::queue<Op> queue_; // event queue
|
||||
std::mutex mu_; // mutex to protect the queue, cv, and block_done
|
||||
std::vector<std::future<void>> futures_;
|
||||
std::mutex mu_; // mutex to protect the queue, cv, and block_done
|
||||
|
||||
std::chrono::seconds timeout_;
|
||||
|
||||
@@ -61,7 +64,7 @@ class Loop {
|
||||
std::exception_ptr curr_exce_{nullptr};
|
||||
common::Monitor mutable timer_;
|
||||
|
||||
Result ProcessQueue(std::queue<Op>* p_queue, bool blocking) const;
|
||||
Result ProcessQueue(std::queue<Op>* p_queue) const;
|
||||
// The consumer function that runs inside a worker thread.
|
||||
void Process();
|
||||
|
||||
|
||||
@@ -1,243 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2023 XGBoost contributors
|
||||
*/
|
||||
#if defined(XGBOOST_USE_NCCL)
|
||||
#include <algorithm>  // for sort, unique, lexicographical_compare
#include <numeric>    // for accumulate
|
||||
|
||||
#include "comm.cuh"
|
||||
#include "nccl_device_communicator.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sync,
|
||||
StringView nccl_path)
|
||||
: device_ordinal_{device_ordinal},
|
||||
needs_sync_{needs_sync},
|
||||
world_size_{GetWorldSize()},
|
||||
rank_{GetRank()} {
|
||||
if (device_ordinal_ < 0) {
|
||||
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
|
||||
}
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
stub_ = std::make_shared<NcclStub>(std::move(nccl_path));
|
||||
|
||||
std::vector<uint64_t> uuids(world_size_ * kUuidLength, 0);
|
||||
auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
|
||||
auto s_this_uuid = s_uuid.subspan(rank_ * kUuidLength, kUuidLength);
|
||||
GetCudaUUID(s_this_uuid);
|
||||
|
||||
// TODO(rongou): replace this with allgather.
|
||||
Allreduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);
|
||||
|
||||
std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world_size_);
|
||||
size_t j = 0;
|
||||
for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
|
||||
converted[j] = xgboost::common::Span<uint64_t, kUuidLength>{uuids.data() + i, kUuidLength};
|
||||
j++;
|
||||
}
|
||||
|
||||
auto iter = std::unique(converted.begin(), converted.end());
|
||||
auto n_uniques = std::distance(converted.begin(), iter);
|
||||
|
||||
CHECK_EQ(n_uniques, world_size_)
|
||||
<< "Multiple processes within communication group running on same CUDA "
|
||||
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
|
||||
|
||||
nccl_unique_id_ = GetUniqueId();
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
auto rc = stub_->CommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
|
||||
/**
 * @brief Destroy the NCCL communicator (if one was created) and, at debug verbosity, print
 *        the accumulated all-reduce statistics. Does nothing for a single worker.
 */
NcclDeviceCommunicator::~NcclDeviceCommunicator() {
  if (world_size_ != 1) {
    if (nccl_comm_) {
      auto status = stub_->CommDestroy(nccl_comm_);
      CHECK(status.OK()) << status.Report();
    }

    bool const debug_enabled =
        xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug);
    if (debug_enabled) {
      LOG(CONSOLE) << "======== NCCL Statistics========";
      LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
      // Integer division: the figure is truncated to whole MiB.
      LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_ / 1048576;
    }
  }
}
|
||||
|
||||
namespace {
|
||||
/**
 * @brief Translate a collective DataType into the matching NCCL data type.
 *
 * An unrecognised value is reported through LOG(FATAL); ncclInt8 is the value returned on
 * that path.
 */
ncclDataType_t GetNcclDataType(DataType const &data_type) {
  switch (data_type) {
    case DataType::kInt8:
      return ncclInt8;
    case DataType::kUInt8:
      return ncclUint8;
    case DataType::kInt32:
      return ncclInt32;
    case DataType::kUInt32:
      return ncclUint32;
    case DataType::kInt64:
      return ncclInt64;
    case DataType::kUInt64:
      return ncclUint64;
    case DataType::kFloat:
      return ncclFloat;
    case DataType::kDouble:
      return ncclDouble;
    default:
      LOG(FATAL) << "Unknown data type.";
  }
  return ncclInt8;
}
|
||||
|
||||
/** @brief Whether `op` is one of the bitwise AND / OR / XOR reductions. */
bool IsBitwiseOp(Operation const &op) {
  switch (op) {
    case Operation::kBitwiseAND:
    case Operation::kBitwiseOR:
    case Operation::kBitwiseXOR:
      return true;
    default:
      return false;
  }
}
|
||||
|
||||
/**
 * @brief Translate a max / min / sum Operation into the matching NCCL reduction.
 *
 * Any other operation is reported through LOG(FATAL); ncclMax is the value returned on that
 * path.
 */
ncclRedOp_t GetNcclRedOp(Operation const &op) {
  switch (op) {
    case Operation::kMax:
      return ncclMax;
    case Operation::kMin:
      return ncclMin;
    case Operation::kSum:
      return ncclSum;
    default:
      LOG(FATAL) << "Unsupported reduce operation.";
  }
  return ncclMax;
}
|
||||
|
||||
template <typename Func>
|
||||
void RunBitwiseAllreduce(char *out_buffer, char const *device_buffer, Func func, int world_size,
|
||||
std::size_t size) {
|
||||
dh::LaunchN(size, [=] __device__(std::size_t idx) {
|
||||
auto result = device_buffer[idx];
|
||||
for (auto rank = 1; rank < world_size; rank++) {
|
||||
result = func(result, device_buffer[rank * size + idx]);
|
||||
}
|
||||
out_buffer[idx] = result;
|
||||
});
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::size_t count,
|
||||
DataType data_type, Operation op) {
|
||||
auto const size = count * GetTypeSize(data_type);
|
||||
dh::caching_device_vector<char> buffer(size * world_size_);
|
||||
auto *device_buffer = buffer.data().get();
|
||||
|
||||
// First gather data from all the workers.
|
||||
auto rc = stub_->Allgather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
|
||||
nccl_comm_, dh::DefaultStream());
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
if (needs_sync_) {
|
||||
dh::DefaultStream().Sync();
|
||||
}
|
||||
|
||||
// Then reduce locally.
|
||||
auto *out_buffer = static_cast<char *>(send_receive_buffer);
|
||||
switch (op) {
|
||||
case Operation::kBitwiseAND:
|
||||
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size_, size);
|
||||
break;
|
||||
case Operation::kBitwiseOR:
|
||||
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size_, size);
|
||||
break;
|
||||
case Operation::kBitwiseXOR:
|
||||
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size_, size);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Not a bitwise reduce operation.";
|
||||
}
|
||||
}
|
||||
|
||||
void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t count,
|
||||
DataType data_type, Operation op) {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
if (IsBitwiseOp(op)) {
|
||||
BitwiseAllReduce(send_receive_buffer, count, data_type, op);
|
||||
} else {
|
||||
auto rc = stub_->Allreduce(send_receive_buffer, send_receive_buffer, count,
|
||||
GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
|
||||
dh::DefaultStream());
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
allreduce_bytes_ += count * GetTypeSize(data_type);
|
||||
allreduce_calls_ += 1;
|
||||
}
|
||||
|
||||
void NcclDeviceCommunicator::AllGather(void const *send_buffer, void *receive_buffer,
|
||||
std::size_t send_size) {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
auto rc = stub_->Allgather(send_buffer, receive_buffer, send_size, ncclInt8, nccl_comm_,
|
||||
dh::DefaultStream());
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
|
||||
void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_bytes,
|
||||
std::vector<std::size_t> *segments,
|
||||
dh::caching_device_vector<char> *receive_buffer) {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
|
||||
segments->clear();
|
||||
segments->resize(world_size_, 0);
|
||||
segments->at(rank_) = length_bytes;
|
||||
Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
|
||||
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
|
||||
receive_buffer->resize(total_bytes);
|
||||
|
||||
size_t offset = 0;
|
||||
auto rc = Success() << [&] { return stub_->GroupStart(); } << [&] {
|
||||
for (int32_t i = 0; i < world_size_; ++i) {
|
||||
size_t as_bytes = segments->at(i);
|
||||
auto rc = stub_->Broadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
|
||||
ncclChar, i, nccl_comm_, dh::DefaultStream());
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
offset += as_bytes;
|
||||
}
|
||||
return Success();
|
||||
} << [&] { return stub_->GroupEnd(); };
|
||||
}
|
||||
|
||||
void NcclDeviceCommunicator::Synchronize() {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
dh::DefaultStream().Sync();
|
||||
}
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
#endif
|
||||
@@ -1,91 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022-2023 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "comm.cuh"
|
||||
#include "communicator.h"
|
||||
#include "device_communicator.cuh"
|
||||
#include "nccl_stub.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new NCCL communicator.
|
||||
* @param device_ordinal The GPU device id.
|
||||
* @param needs_sync Whether extra CUDA stream synchronization is needed.
|
||||
*
|
||||
* In multi-GPU tests when multiple NCCL communicators are created in the same process, sometimes
|
||||
* a deadlock happens because NCCL kernels are blocking. The extra CUDA stream synchronization
|
||||
* makes sure that the NCCL kernels are caught up, thus avoiding the deadlock.
|
||||
*
|
||||
* The Rabit communicator runs with one process per GPU, so the additional synchronization is not
|
||||
* needed. The in-memory communicator is used in tests with multiple threads, each thread
|
||||
* representing a rank/worker, so the additional synchronization is needed to avoid deadlocks.
|
||||
*/
|
||||
explicit NcclDeviceCommunicator(int device_ordinal, bool needs_sync, StringView nccl_path);
|
||||
~NcclDeviceCommunicator() override;
|
||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) override;
|
||||
void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) override;
|
||||
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
|
||||
dh::caching_device_vector<char> *receive_buffer) override;
|
||||
void Synchronize() override;
|
||||
|
||||
private:
|
||||
static constexpr std::size_t kUuidLength =
|
||||
sizeof(std::declval<cudaDeviceProp>().uuid) / sizeof(uint64_t);
|
||||
|
||||
void GetCudaUUID(xgboost::common::Span<uint64_t, kUuidLength> const &uuid) const {
|
||||
cudaDeviceProp prob{};
|
||||
dh::safe_cuda(cudaGetDeviceProperties(&prob, device_ordinal_));
|
||||
std::memcpy(uuid.data(), static_cast<void *>(&(prob.uuid)), sizeof(prob.uuid));
|
||||
}
|
||||
|
||||
static std::string PrintUUID(xgboost::common::Span<uint64_t, kUuidLength> const &uuid) {
|
||||
std::stringstream ss;
|
||||
for (auto v : uuid) {
|
||||
ss << std::hex << v;
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
/**
|
||||
* \fn ncclUniqueId GetUniqueId()
|
||||
*
|
||||
* \brief Gets the Unique ID from NCCL to be used in setting up interprocess
|
||||
* communication
|
||||
*
|
||||
* \return the Unique ID
|
||||
*/
|
||||
ncclUniqueId GetUniqueId() {
|
||||
static const int kRootRank = 0;
|
||||
ncclUniqueId id;
|
||||
if (rank_ == kRootRank) {
|
||||
auto rc = stub_->GetUniqueId(&id);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId), static_cast<int>(kRootRank));
|
||||
return id;
|
||||
}
|
||||
|
||||
void BitwiseAllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op);
|
||||
|
||||
int const device_ordinal_;
|
||||
bool const needs_sync_;
|
||||
int const world_size_;
|
||||
int const rank_;
|
||||
ncclComm_t nccl_comm_{};
|
||||
std::shared_ptr<NcclStub> stub_;
|
||||
ncclUniqueId nccl_unique_id_{};
|
||||
size_t allreduce_bytes_{0}; // Keep statistics of the number of bytes communicated.
|
||||
size_t allreduce_calls_{0}; // Keep statistics of the number of reduce calls.
|
||||
};
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@@ -1,32 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <string>
|
||||
|
||||
#include "communicator.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
/**
|
||||
* A no-op communicator, used for non-distributed training.
|
||||
*/
|
||||
class NoOpCommunicator : public Communicator {
|
||||
public:
|
||||
NoOpCommunicator() : Communicator(1, 0) {}
|
||||
bool IsDistributed() const override { return false; }
|
||||
bool IsFederated() const override { return false; }
|
||||
std::string AllGather(std::string_view) override { return {}; }
|
||||
std::string AllGatherV(std::string_view) override { return {}; }
|
||||
void AllReduce(void *, std::size_t, DataType, Operation) override {}
|
||||
void Broadcast(void *, std::size_t, int) override {}
|
||||
std::string GetProcessorName() override { return {}; }
|
||||
void Print(const std::string &message) override { LOG(CONSOLE) << message; }
|
||||
|
||||
protected:
|
||||
void Shutdown() override {}
|
||||
};
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@@ -41,20 +41,26 @@ struct Magic {
|
||||
|
||||
/**
 * @brief Exchange the magic number with the peer to verify the connection.
 *
 * Sends `kMagic`, then reads a 32-bit value back and requires it to equal `kMagic`.
 * The steps are chained on a Result so that a failure in any of them is what gets returned.
 *
 * The text of this function contained both the older size-returning SendAll/RecvAll sequence
 * (ending in an unconditional return) and the Result-chained sequence after it; only the
 * latter is kept.
 */
[[nodiscard]] Result Verify(xgboost::collective::TCPSocket* p_sock) {
  std::int32_t magic{kMagic};
  std::size_t n_bytes{0};
  return Success() << [&] {
    return p_sock->SendAll(&magic, sizeof(magic), &n_bytes);
  } << [&] {
    if (n_bytes != sizeof(magic)) {
      return Fail("Failed to verify.");
    }
    return Success();
  } << [&] {
    // Clear the value so a stale kMagic cannot pass the comparison below.
    magic = 0;
    return p_sock->RecvAll(&magic, sizeof(magic), &n_bytes);
  } << [&] {
    if (n_bytes != sizeof(magic)) {
      return Fail("Failed to verify.");
    }
    if (magic != kMagic) {
      return xgboost::collective::Fail("Invalid verification number.");
    }
    return Success();
  };
}
|
||||
};
|
||||
|
||||
@@ -227,31 +233,43 @@ struct Error {
|
||||
|
||||
/**
 * @brief Send the error signal value to `worker`.
 *
 * Fails if the send itself fails or if fewer bytes than the signal size were written. The
 * superseded size-returning SendAll sequence that preceded the Result chain in this function
 * is removed.
 */
[[nodiscard]] Result SignalError(TCPSocket* worker) const {
  std::int32_t err{ErrorSignal()};
  std::size_t n_sent{0};
  return Success() << [&] {
    return worker->SendAll(&err, sizeof(err), &n_sent);
  } << [&] {
    if (n_sent == sizeof(err)) {
      return Success();
    }
    return Fail("Failed to send error signal");
  };
}
|
||||
// self is localhost, we are sending the signal to the error handling thread for it to
|
||||
// close.
|
||||
[[nodiscard]] Result SignalShutdown(TCPSocket* self) const {
|
||||
std::int32_t err{ShutdownSignal()};
|
||||
auto n_sent = self->SendAll(&err, sizeof(err));
|
||||
if (n_sent == sizeof(err)) {
|
||||
return Success();
|
||||
}
|
||||
return Fail("Failed to send shutdown signal");
|
||||
std::size_t n_sent{0};
|
||||
return Success() << [&] {
|
||||
return self->SendAll(&err, sizeof(err), &n_sent);
|
||||
} << [&] {
|
||||
if (n_sent == sizeof(err)) {
|
||||
return Success();
|
||||
}
|
||||
return Fail("Failed to send shutdown signal");
|
||||
};
|
||||
}
|
||||
// get signal, either for error or for shutdown.
|
||||
[[nodiscard]] Result RecvSignal(TCPSocket* peer, bool* p_is_error) const {
|
||||
std::int32_t err{ShutdownSignal()};
|
||||
auto n_recv = peer->RecvAll(&err, sizeof(err));
|
||||
if (n_recv == sizeof(err)) {
|
||||
*p_is_error = err == 1;
|
||||
return Success();
|
||||
}
|
||||
return Fail("Failed to receive error signal.");
|
||||
std::size_t n_recv{0};
|
||||
return Success() << [&] {
|
||||
return peer->RecvAll(&err, sizeof(err), &n_recv);
|
||||
} << [&] {
|
||||
if (n_recv == sizeof(err)) {
|
||||
*p_is_error = err == 1;
|
||||
return Success();
|
||||
}
|
||||
return Fail("Failed to receive error signal.");
|
||||
};
|
||||
}
|
||||
};
|
||||
} // namespace xgboost::collective::proto
|
||||
|
||||
@@ -1,175 +0,0 @@
|
||||
/**
|
||||
* Copyright 2022-2023 by XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <rabit/rabit.h>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "communicator-inl.h"
|
||||
#include "communicator.h"
|
||||
#include "xgboost/json.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
class RabitCommunicator : public Communicator {
|
||||
public:
|
||||
static Communicator *Create(Json const &config) {
|
||||
std::vector<std::string> args_str;
|
||||
for (auto &items : get<Object const>(config)) {
|
||||
switch (items.second.GetValue().Type()) {
|
||||
case xgboost::Value::ValueKind::kString: {
|
||||
args_str.push_back(items.first + "=" + get<String const>(items.second));
|
||||
break;
|
||||
}
|
||||
case xgboost::Value::ValueKind::kInteger: {
|
||||
args_str.push_back(items.first + "=" + std::to_string(get<Integer const>(items.second)));
|
||||
break;
|
||||
}
|
||||
case xgboost::Value::ValueKind::kBoolean: {
|
||||
if (get<Boolean const>(items.second)) {
|
||||
args_str.push_back(items.first + "=1");
|
||||
} else {
|
||||
args_str.push_back(items.first + "=0");
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
std::vector<char *> args;
|
||||
for (auto &key_value : args_str) {
|
||||
args.push_back(&key_value[0]);
|
||||
}
|
||||
if (!rabit::Init(static_cast<int>(args.size()), &args[0])) {
|
||||
LOG(FATAL) << "Failed to initialize Rabit";
|
||||
}
|
||||
return new RabitCommunicator(rabit::GetWorldSize(), rabit::GetRank());
|
||||
}
|
||||
|
||||
RabitCommunicator(int world_size, int rank) : Communicator(world_size, rank) {}
|
||||
|
||||
bool IsDistributed() const override { return rabit::IsDistributed(); }
|
||||
|
||||
bool IsFederated() const override { return false; }
|
||||
|
||||
std::string AllGather(std::string_view input) override {
|
||||
auto const per_rank = input.size();
|
||||
auto const total_size = per_rank * GetWorldSize();
|
||||
auto const index = per_rank * GetRank();
|
||||
std::string result(total_size, '\0');
|
||||
result.replace(index, per_rank, input);
|
||||
rabit::Allgather(result.data(), total_size, index, per_rank, per_rank);
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string AllGatherV(std::string_view input) override {
|
||||
auto const size_node_slice = input.size();
|
||||
auto const all_sizes = collective::Allgather(size_node_slice);
|
||||
auto const total_size = std::accumulate(all_sizes.cbegin(), all_sizes.cend(), 0ul);
|
||||
auto const begin_index =
|
||||
std::accumulate(all_sizes.cbegin(), all_sizes.cbegin() + GetRank(), 0ul);
|
||||
auto const size_prev_slice =
|
||||
GetRank() == 0 ? all_sizes[GetWorldSize() - 1] : all_sizes[GetRank() - 1];
|
||||
|
||||
std::string result(total_size, '\0');
|
||||
result.replace(begin_index, size_node_slice, input);
|
||||
rabit::Allgather(result.data(), total_size, begin_index, size_node_slice, size_prev_slice);
|
||||
return result;
|
||||
}
|
||||
|
||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) override {
|
||||
switch (data_type) {
|
||||
case DataType::kInt8:
|
||||
DoAllReduce<char>(send_receive_buffer, count, op);
|
||||
break;
|
||||
case DataType::kUInt8:
|
||||
DoAllReduce<unsigned char>(send_receive_buffer, count, op);
|
||||
break;
|
||||
case DataType::kInt32:
|
||||
DoAllReduce<std::int32_t>(send_receive_buffer, count, op);
|
||||
break;
|
||||
case DataType::kUInt32:
|
||||
DoAllReduce<std::uint32_t>(send_receive_buffer, count, op);
|
||||
break;
|
||||
case DataType::kInt64:
|
||||
DoAllReduce<std::int64_t>(send_receive_buffer, count, op);
|
||||
break;
|
||||
case DataType::kUInt64:
|
||||
DoAllReduce<std::uint64_t>(send_receive_buffer, count, op);
|
||||
break;
|
||||
case DataType::kFloat:
|
||||
DoAllReduce<float>(send_receive_buffer, count, op);
|
||||
break;
|
||||
case DataType::kDouble:
|
||||
DoAllReduce<double>(send_receive_buffer, count, op);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown data type";
|
||||
}
|
||||
}
|
||||
|
||||
void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {
|
||||
rabit::Broadcast(send_receive_buffer, size, root);
|
||||
}
|
||||
|
||||
std::string GetProcessorName() override { return rabit::GetProcessorName(); }
|
||||
|
||||
void Print(const std::string &message) override { rabit::TrackerPrint(message); }
|
||||
|
||||
protected:
|
||||
void Shutdown() override { rabit::Finalize(); }
|
||||
|
||||
private:
|
||||
template <typename DType, std::enable_if_t<std::is_integral<DType>::value> * = nullptr>
|
||||
void DoBitwiseAllReduce(void *send_receive_buffer, std::size_t count, Operation op) {
|
||||
switch (op) {
|
||||
case Operation::kBitwiseAND:
|
||||
rabit::Allreduce<rabit::op::BitAND, DType>(static_cast<DType *>(send_receive_buffer),
|
||||
count);
|
||||
break;
|
||||
case Operation::kBitwiseOR:
|
||||
rabit::Allreduce<rabit::op::BitOR, DType>(static_cast<DType *>(send_receive_buffer), count);
|
||||
break;
|
||||
case Operation::kBitwiseXOR:
|
||||
rabit::Allreduce<rabit::op::BitXOR, DType>(static_cast<DType *>(send_receive_buffer),
|
||||
count);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown allreduce operation";
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DType, std::enable_if_t<std::is_floating_point<DType>::value> * = nullptr>
|
||||
void DoBitwiseAllReduce(void *, std::size_t, Operation) {
|
||||
LOG(FATAL) << "Floating point types do not support bitwise operations.";
|
||||
}
|
||||
|
||||
template <typename DType>
|
||||
void DoAllReduce(void *send_receive_buffer, std::size_t count, Operation op) {
|
||||
switch (op) {
|
||||
case Operation::kMax:
|
||||
rabit::Allreduce<rabit::op::Max, DType>(static_cast<DType *>(send_receive_buffer), count);
|
||||
break;
|
||||
case Operation::kMin:
|
||||
rabit::Allreduce<rabit::op::Min, DType>(static_cast<DType *>(send_receive_buffer), count);
|
||||
break;
|
||||
case Operation::kSum:
|
||||
rabit::Allreduce<rabit::op::Sum, DType>(static_cast<DType *>(send_receive_buffer), count);
|
||||
break;
|
||||
case Operation::kBitwiseAND:
|
||||
case Operation::kBitwiseOR:
|
||||
case Operation::kBitwiseXOR:
|
||||
DoBitwiseAllReduce<DType>(send_receive_buffer, count, op);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown allreduce operation";
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@@ -62,20 +62,15 @@ void ResultImpl::Concat(std::unique_ptr<ResultImpl> rhs) {
|
||||
ptr->prev = std::move(rhs);
|
||||
}
|
||||
|
||||
#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__)
|
||||
std::string MakeMsg(std::string&& msg, char const*, std::int32_t) {
|
||||
return std::forward<std::string>(msg);
|
||||
}
|
||||
#else
|
||||
std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line) {
|
||||
auto name = std::filesystem::path{file}.filename();
|
||||
dmlc::DateLogger logger;
|
||||
if (file && line != -1) {
|
||||
return "[" + name.string() + ":" + std::to_string(line) + // NOLINT
|
||||
auto name = std::filesystem::path{ file }.filename();
|
||||
return "[" + name.string() + ":" + std::to_string(line) + "|" + logger.HumanDate() +
|
||||
"]: " + std::forward<std::string>(msg);
|
||||
}
|
||||
return std::forward<std::string>(msg);
|
||||
return std::string{"["} + logger.HumanDate() + "]" + std::forward<std::string>(msg); // NOLINT
|
||||
}
|
||||
#endif
|
||||
} // namespace detail
|
||||
|
||||
void SafeColl(Result const& rc) {
|
||||
|
||||
@@ -60,24 +60,46 @@ std::size_t TCPSocket::Send(StringView str) {
|
||||
CHECK(!this->IsClosed());
|
||||
CHECK_LT(str.size(), std::numeric_limits<std::int32_t>::max());
|
||||
std::int32_t len = static_cast<std::int32_t>(str.size());
|
||||
CHECK_EQ(this->SendAll(&len, sizeof(len)), sizeof(len)) << "Failed to send string length.";
|
||||
auto bytes = this->SendAll(str.c_str(), str.size());
|
||||
CHECK_EQ(bytes, str.size()) << "Failed to send string.";
|
||||
return bytes;
|
||||
std::size_t n_bytes{0};
|
||||
auto rc = Success() << [&] {
|
||||
return this->SendAll(&len, sizeof(len), &n_bytes);
|
||||
} << [&] {
|
||||
if (n_bytes != sizeof(len)) {
|
||||
return Fail("Failed to send string length.");
|
||||
}
|
||||
return Success();
|
||||
} << [&] {
|
||||
return this->SendAll(str.c_str(), str.size(), &n_bytes);
|
||||
} << [&] {
|
||||
if (n_bytes != str.size()) {
|
||||
return Fail("Failed to send string.");
|
||||
}
|
||||
return Success();
|
||||
};
|
||||
SafeColl(rc);
|
||||
return n_bytes;
|
||||
}
|
||||
|
||||
/**
 * @brief Receive a length-prefixed string.
 *
 * Reads a 32-bit length followed by that many bytes into `*p_str`. The steps are chained on
 * a Result so that the first failure is what gets returned.
 *
 * The text of this function contained both the older size-returning RecvAll sequence and the
 * Result-chained sequence; only the latter is kept. In addition the length is initialised
 * and a negative length from the peer is rejected before it is used to resize the string.
 */
[[nodiscard]] Result TCPSocket::Recv(std::string *p_str) {
  CHECK(!this->IsClosed());
  std::int32_t len{0};
  std::size_t n_bytes{0};
  return Success() << [&] {
    return this->RecvAll(&len, sizeof(len), &n_bytes);
  } << [&] {
    if (n_bytes != sizeof(len)) {
      return Fail("Failed to recv string length.");
    }
    // The length comes from the network: a negative value would turn into an enormous
    // size when converted for resize().
    if (len < 0) {
      return Fail("Invalid string length received.");
    }
    return Success();
  } << [&] {
    p_str->resize(len);
    return this->RecvAll(&(*p_str)[0], len, &n_bytes);
  } << [&] {
    if (n_bytes != static_cast<std::size_t>(len)) {
      return Fail("Failed to recv string.");
    }
    return Success();
  };
}
|
||||
|
||||
[[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry,
|
||||
|
||||
@@ -31,14 +31,20 @@
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost::collective {
|
||||
|
||||
Tracker::Tracker(Json const& config)
|
||||
: sortby_{static_cast<SortBy>(
|
||||
OptionalArg<Integer const>(config, "sortby", static_cast<Integer::Int>(SortBy::kHost)))},
|
||||
n_workers_{
|
||||
static_cast<std::int32_t>(RequiredArg<Integer const>(config, "n_workers", __func__))},
|
||||
port_{static_cast<std::int32_t>(OptionalArg<Integer const>(config, "port", Integer::Int{0}))},
|
||||
timeout_{std::chrono::seconds{OptionalArg<Integer const>(
|
||||
config, "timeout", static_cast<std::int64_t>(collective::DefaultTimeoutSec()))}} {}
|
||||
timeout_{std::chrono::seconds{
|
||||
OptionalArg<Integer const>(config, "timeout", static_cast<std::int64_t>(0))}} {
|
||||
using std::chrono_literals::operator""s;
|
||||
// Some old configurations in JVM for the scala implementation (removed) use 0 to
|
||||
// indicate blocking. We continue that convention here.
|
||||
timeout_ = (timeout_ == 0s) ? -1s : timeout_;
|
||||
}
|
||||
|
||||
Result Tracker::WaitUntilReady() const {
|
||||
using namespace std::chrono_literals; // NOLINT
|
||||
@@ -49,7 +55,7 @@ Result Tracker::WaitUntilReady() const {
|
||||
timer.Start();
|
||||
while (!this->Ready()) {
|
||||
auto ela = timer.Duration().count();
|
||||
if (ela > this->Timeout().count()) {
|
||||
if (HasTimeout(this->Timeout()) && ela > this->Timeout().count()) {
|
||||
return Fail("Failed to start tracker, timeout:" + std::to_string(this->Timeout().count()) +
|
||||
" seconds.");
|
||||
}
|
||||
@@ -250,8 +256,10 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
|
||||
std::lock_guard lock{listener_mu_};
|
||||
return listener_.NonBlocking(true);
|
||||
} << [&] {
|
||||
std::lock_guard lock{listener_mu_};
|
||||
poll.WatchRead(listener_);
|
||||
{
|
||||
std::lock_guard lock{listener_mu_};
|
||||
poll.WatchRead(listener_);
|
||||
}
|
||||
if (state.running) {
|
||||
// Don't timeout if the communicator group is up and running.
|
||||
return poll.Poll(std::chrono::seconds{-1});
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost::collective {
|
||||
/** @brief Whether `timeout` is a real limit, i.e. strictly positive. */
inline bool HasTimeout(std::chrono::seconds timeout) {
  return timeout > std::chrono::seconds::zero();
}
|
||||
/**
|
||||
*
|
||||
* @brief Implementation of RABIT tracker.
|
||||
@@ -52,7 +53,7 @@ class Tracker {
|
||||
protected:
|
||||
std::int32_t n_workers_{0};
|
||||
std::int32_t port_{-1};
|
||||
std::chrono::seconds timeout_{0};
|
||||
std::chrono::seconds timeout_{-1};
|
||||
std::atomic<bool> ready_{false};
|
||||
|
||||
public:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2014-2023, XGBoost Contributors
|
||||
* Copyright 2014-2024, XGBoost Contributors
|
||||
* \file io.h
|
||||
* \brief general stream interface for serialization, I/O
|
||||
* \author Tianqi Chen
|
||||
@@ -8,7 +8,6 @@
|
||||
#define XGBOOST_COMMON_IO_H_
|
||||
|
||||
#include <dmlc/io.h>
|
||||
#include <rabit/internal/io.h> // for MemoryFixSizeBuffer, MemoryBufferStream
|
||||
|
||||
#include <algorithm> // for min, fill_n, copy_n
|
||||
#include <array> // for array
|
||||
@@ -23,12 +22,99 @@
|
||||
#include <utility> // for move
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "common.h"
|
||||
#include "common.h" // for DivRoundUp
|
||||
#include "xgboost/string_view.h" // for StringView
|
||||
|
||||
namespace xgboost::common {
|
||||
using MemoryFixSizeBuffer = rabit::utils::MemoryFixSizeBuffer;
|
||||
using MemoryBufferStream = rabit::utils::MemoryBufferStream;
|
||||
struct MemoryFixSizeBuffer : public dmlc::SeekStream {
|
||||
public:
|
||||
// similar to SEEK_END in libc
|
||||
static std::size_t constexpr kSeekEnd = std::numeric_limits<std::size_t>::max();
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Ctor
|
||||
*
|
||||
* @param p_buffer Pointer to the source buffer with size `buffer_size`.
|
||||
* @param buffer_size Size of the source buffer
|
||||
*/
|
||||
MemoryFixSizeBuffer(void *p_buffer, std::size_t buffer_size)
|
||||
: p_buffer_(reinterpret_cast<char *>(p_buffer)), buffer_size_(buffer_size) {}
|
||||
~MemoryFixSizeBuffer() override = default;
|
||||
|
||||
std::size_t Read(void *ptr, std::size_t size) override {
|
||||
std::size_t nread = std::min(buffer_size_ - curr_ptr_, size);
|
||||
if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread);
|
||||
curr_ptr_ += nread;
|
||||
return nread;
|
||||
}
|
||||
void Write(const void *ptr, std::size_t size) override {
|
||||
if (size == 0) return;
|
||||
CHECK_LE(curr_ptr_ + size, buffer_size_);
|
||||
std::memcpy(p_buffer_ + curr_ptr_, ptr, size);
|
||||
curr_ptr_ += size;
|
||||
}
|
||||
void Seek(std::size_t pos) override {
|
||||
if (pos == kSeekEnd) {
|
||||
curr_ptr_ = buffer_size_;
|
||||
} else {
|
||||
curr_ptr_ = static_cast<std::size_t>(pos);
|
||||
}
|
||||
}
|
||||
/**
|
||||
* @brief Current position in the buffer (stream).
|
||||
*/
|
||||
std::size_t Tell() override { return curr_ptr_; }
|
||||
[[nodiscard]] virtual bool AtEnd() const { return curr_ptr_ == buffer_size_; }
|
||||
|
||||
protected:
|
||||
/*! \brief in memory buffer */
|
||||
char *p_buffer_{nullptr};
|
||||
/*! \brief current pointer */
|
||||
std::size_t buffer_size_{0};
|
||||
/*! \brief current pointer */
|
||||
std::size_t curr_ptr_{0};
|
||||
};
|
||||
|
||||
/*! \brief a in memory buffer that can be read and write as stream interface */
|
||||
struct MemoryBufferStream : public dmlc::SeekStream {
|
||||
public:
|
||||
explicit MemoryBufferStream(std::string *p_buffer)
|
||||
: p_buffer_(p_buffer) {
|
||||
curr_ptr_ = 0;
|
||||
}
|
||||
~MemoryBufferStream() override = default;
|
||||
size_t Read(void *ptr, size_t size) override {
|
||||
CHECK_LE(curr_ptr_, p_buffer_->length()) << "read can not have position excceed buffer length";
|
||||
size_t nread = std::min(p_buffer_->length() - curr_ptr_, size);
|
||||
if (nread != 0) std::memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread);
|
||||
curr_ptr_ += nread;
|
||||
return nread;
|
||||
}
|
||||
void Write(const void *ptr, size_t size) override {
|
||||
if (size == 0) return;
|
||||
if (curr_ptr_ + size > p_buffer_->length()) {
|
||||
p_buffer_->resize(curr_ptr_+size);
|
||||
}
|
||||
std::memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size);
|
||||
curr_ptr_ += size;
|
||||
}
|
||||
void Seek(size_t pos) override {
|
||||
curr_ptr_ = static_cast<size_t>(pos);
|
||||
}
|
||||
size_t Tell() override {
|
||||
return curr_ptr_;
|
||||
}
|
||||
virtual bool AtEnd() const {
|
||||
return curr_ptr_ == p_buffer_->length();
|
||||
}
|
||||
|
||||
private:
|
||||
/*! \brief in memory buffer */
|
||||
std::string *p_buffer_;
|
||||
/*! \brief current pointer */
|
||||
size_t curr_ptr_;
|
||||
}; // class MemoryBufferStream
|
||||
|
||||
/*!
|
||||
* \brief Input stream that support additional PeekRead operation,
|
||||
|
||||
@@ -116,19 +116,19 @@ INSTANTIATE(ColumnarAdapterBatch)
|
||||
|
||||
namespace {
|
||||
/**
|
||||
* \brief A view over gathered sketch values.
|
||||
* @brief A view over gathered sketch values.
|
||||
*/
|
||||
template <typename T>
|
||||
struct QuantileAllreduce {
|
||||
common::Span<T> global_values;
|
||||
common::Span<bst_idx_t> worker_indptr;
|
||||
common::Span<bst_idx_t> feature_indptr;
|
||||
size_t n_features{0};
|
||||
bst_feature_t n_features{0};
|
||||
/**
|
||||
* \brief Get sketch values of the a feature from a worker.
|
||||
* @brief Get sketch values of the a feature from a worker.
|
||||
*
|
||||
* \param rank rank of target worker
|
||||
* \param fidx feature idx
|
||||
* @param rank rank of target worker
|
||||
* @param fidx feature idx
|
||||
*/
|
||||
[[nodiscard]] auto Values(int32_t rank, bst_feature_t fidx) const {
|
||||
// get span for worker
|
||||
@@ -154,7 +154,7 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
|
||||
worker_segments.resize(1, 0);
|
||||
auto world = collective::GetWorldSize();
|
||||
auto rank = collective::GetRank();
|
||||
auto n_columns = sketches_.size();
|
||||
bst_feature_t n_columns = sketches_.size();
|
||||
|
||||
// get the size of each feature.
|
||||
std::vector<bst_idx_t> sketch_size;
|
||||
@@ -165,7 +165,7 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
|
||||
sketch_size.push_back(reduced[i].size);
|
||||
}
|
||||
}
|
||||
// turn the size into CSC indptr
|
||||
// Turn the size into CSC indptr
|
||||
std::vector<bst_idx_t> &sketches_scan = *p_sketches_scan;
|
||||
sketches_scan.resize((n_columns + 1) * world, 0);
|
||||
size_t beg_scan = rank * (n_columns + 1); // starting storage for current worker.
|
||||
@@ -174,7 +174,10 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
|
||||
// Gather all column pointers
|
||||
auto rc =
|
||||
collective::GlobalSum(ctx, info, linalg::MakeVec(sketches_scan.data(), sketches_scan.size()));
|
||||
collective::SafeColl(rc);
|
||||
if (!rc.OK()) {
|
||||
collective::SafeColl(collective::Fail("Failed to get sketch scan.", std::move(rc)));
|
||||
}
|
||||
|
||||
for (int32_t i = 0; i < world; ++i) {
|
||||
size_t back = (i + 1) * (n_columns + 1) - 1;
|
||||
auto n_entries = sketches_scan.at(back);
|
||||
@@ -206,7 +209,9 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
|
||||
ctx, info,
|
||||
linalg::MakeVec(reinterpret_cast<float *>(global_sketches.data()),
|
||||
global_sketches.size() * sizeof(typename WQSketch::Entry) / sizeof(float)));
|
||||
collective::SafeColl(rc);
|
||||
if (!rc.OK()) {
|
||||
collective::SafeColl(collective::Fail("Failed to get sketch.", std::move(rc)));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename WQSketch>
|
||||
@@ -260,7 +265,7 @@ void SketchContainerImpl<WQSketch>::AllreduceCategories(Context const* ctx, Meta
|
||||
rc = collective::GlobalSum(ctx, info,
|
||||
linalg::MakeVec(global_categories.data(), global_categories.size()));
|
||||
QuantileAllreduce<float> allreduce_result{global_categories, global_worker_ptr, global_feat_ptrs,
|
||||
categories_.size()};
|
||||
static_cast<bst_feature_t>(categories_.size())};
|
||||
ParallelFor(categories_.size(), n_threads_, [&](auto fidx) {
|
||||
if (!IsCat(feature_types_, fidx)) {
|
||||
return;
|
||||
@@ -285,8 +290,9 @@ void SketchContainerImpl<WQSketch>::AllReduce(
|
||||
std::vector<typename WQSketch::SummaryContainer> *p_reduced, std::vector<int32_t> *p_num_cuts) {
|
||||
monitor_.Start(__func__);
|
||||
|
||||
size_t n_columns = sketches_.size();
|
||||
collective::Allreduce<collective::Operation::kMax>(&n_columns, 1);
|
||||
bst_feature_t n_columns = sketches_.size();
|
||||
auto rc = collective::Allreduce(ctx, &n_columns, collective::Op::kMax);
|
||||
collective::SafeColl(rc);
|
||||
CHECK_EQ(n_columns, sketches_.size()) << "Number of columns differs across workers";
|
||||
|
||||
AllreduceCategories(ctx, info);
|
||||
@@ -300,8 +306,8 @@ void SketchContainerImpl<WQSketch>::AllReduce(
|
||||
|
||||
// Prune the intermediate num cuts for synchronization.
|
||||
std::vector<bst_idx_t> global_column_size(columns_size_);
|
||||
auto rc = collective::GlobalSum(
|
||||
ctx, info, linalg::MakeVec(global_column_size.data(), global_column_size.size()));
|
||||
rc = collective::GlobalSum(ctx, info,
|
||||
linalg::MakeVec(global_column_size.data(), global_column_size.size()));
|
||||
collective::SafeColl(rc);
|
||||
|
||||
ParallelFor(sketches_.size(), n_threads_, [&](size_t i) {
|
||||
|
||||
@@ -12,7 +12,8 @@
|
||||
#include <numeric> // for partial_sum
|
||||
#include <utility>
|
||||
|
||||
#include "../collective/communicator-inl.cuh"
|
||||
#include "../collective/allgather.h"
|
||||
#include "../collective/allreduce.h"
|
||||
#include "categorical.h"
|
||||
#include "common.h"
|
||||
#include "device_helpers.cuh"
|
||||
@@ -499,7 +500,7 @@ void SketchContainer::FixError() {
|
||||
});
|
||||
}
|
||||
|
||||
void SketchContainer::AllReduce(Context const*, bool is_column_split) {
|
||||
void SketchContainer::AllReduce(Context const* ctx, bool is_column_split) {
|
||||
dh::safe_cuda(cudaSetDevice(device_.ordinal));
|
||||
auto world = collective::GetWorldSize();
|
||||
if (world == 1 || is_column_split) {
|
||||
@@ -508,16 +509,18 @@ void SketchContainer::AllReduce(Context const*, bool is_column_split) {
|
||||
|
||||
timer_.Start(__func__);
|
||||
// Reduce the overhead on syncing.
|
||||
size_t global_sum_rows = num_rows_;
|
||||
collective::Allreduce<collective::Operation::kSum>(&global_sum_rows, 1);
|
||||
size_t intermediate_num_cuts =
|
||||
bst_idx_t global_sum_rows = num_rows_;
|
||||
auto rc = collective::Allreduce(ctx, linalg::MakeVec(&global_sum_rows, 1), collective::Op::kSum);
|
||||
SafeColl(rc);
|
||||
bst_idx_t intermediate_num_cuts =
|
||||
std::min(global_sum_rows, static_cast<size_t>(num_bins_ * kFactor));
|
||||
this->Prune(intermediate_num_cuts);
|
||||
|
||||
auto d_columns_ptr = this->columns_ptr_.ConstDeviceSpan();
|
||||
CHECK_EQ(d_columns_ptr.size(), num_columns_ + 1);
|
||||
size_t n = d_columns_ptr.size();
|
||||
collective::Allreduce<collective::Operation::kMax>(&n, 1);
|
||||
rc = collective::Allreduce(ctx, linalg::MakeVec(&n, 1), collective::Op::kMax);
|
||||
SafeColl(rc);
|
||||
CHECK_EQ(n, d_columns_ptr.size()) << "Number of columns differs across workers";
|
||||
|
||||
// Get the columns ptr from all workers
|
||||
@@ -527,18 +530,25 @@ void SketchContainer::AllReduce(Context const*, bool is_column_split) {
|
||||
auto offset = rank * d_columns_ptr.size();
|
||||
thrust::copy(thrust::device, d_columns_ptr.data(), d_columns_ptr.data() + d_columns_ptr.size(),
|
||||
gathered_ptrs.begin() + offset);
|
||||
collective::AllReduce<collective::Operation::kSum>(device_.ordinal, gathered_ptrs.data().get(),
|
||||
gathered_ptrs.size());
|
||||
rc = collective::Allreduce(
|
||||
ctx, linalg::MakeVec(gathered_ptrs.data().get(), gathered_ptrs.size(), ctx->Device()),
|
||||
collective::Op::kSum);
|
||||
SafeColl(rc);
|
||||
|
||||
// Get the data from all workers.
|
||||
std::vector<size_t> recv_lengths;
|
||||
dh::caching_device_vector<char> recvbuf;
|
||||
collective::AllGatherV(device_.ordinal, this->Current().data().get(),
|
||||
dh::ToSpan(this->Current()).size_bytes(), &recv_lengths, &recvbuf);
|
||||
collective::Synchronize(device_.ordinal);
|
||||
std::vector<std::int64_t> recv_lengths;
|
||||
HostDeviceVector<std::int8_t> recvbuf;
|
||||
rc = collective::AllgatherV(
|
||||
ctx, linalg::MakeVec(this->Current().data().get(), this->Current().size(), device_),
|
||||
&recv_lengths, &recvbuf);
|
||||
collective::SafeColl(rc);
|
||||
for (std::size_t i = 0; i < recv_lengths.size() - 1; ++i) {
|
||||
recv_lengths[i] = recv_lengths[i + 1] - recv_lengths[i];
|
||||
}
|
||||
recv_lengths.resize(recv_lengths.size() - 1);
|
||||
|
||||
// Segment the received data.
|
||||
auto s_recvbuf = dh::ToSpan(recvbuf);
|
||||
auto s_recvbuf = recvbuf.DeviceSpan();
|
||||
std::vector<Span<SketchEntry>> allworkers;
|
||||
offset = 0;
|
||||
for (int32_t i = 0; i < world; ++i) {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2015-2020, XGBoost Contributors
|
||||
* Copyright 2015-2024, XGBoost Contributors
|
||||
* \file random.h
|
||||
* \brief Utility related to random.
|
||||
* \author Tianqi Chen
|
||||
@@ -19,11 +19,13 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "../collective/broadcast.h" // for Broadcast
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "algorithm.h" // ArgSort
|
||||
#include "common.h"
|
||||
#include "xgboost/context.h" // Context
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include "xgboost/linalg.h"
|
||||
|
||||
namespace xgboost::common {
|
||||
/*!
|
||||
@@ -227,9 +229,10 @@ class ColumnSampler {
|
||||
}
|
||||
};
|
||||
|
||||
inline auto MakeColumnSampler(Context const*) {
|
||||
inline auto MakeColumnSampler(Context const* ctx) {
|
||||
std::uint32_t seed = common::GlobalRandomEngine()();
|
||||
collective::Broadcast(&seed, sizeof(seed), 0);
|
||||
auto rc = collective::Broadcast(ctx, linalg::MakeVec(&seed, 1), 0);
|
||||
collective::SafeColl(rc);
|
||||
auto cs = std::make_shared<common::ColumnSampler>(seed);
|
||||
return cs;
|
||||
}
|
||||
|
||||
@@ -615,7 +615,12 @@ auto DispatchDType(ArrayInterfaceHandler::Type dtype, Fn dispatch) {
|
||||
case ArrayInterfaceHandler::kF16: {
|
||||
using T = long double;
|
||||
CHECK(sizeof(T) == 16) << error::NoF128();
|
||||
return dispatch(T{});
|
||||
// Avoid invalid type.
|
||||
if constexpr (sizeof(T) == 16) {
|
||||
return dispatch(T{});
|
||||
} else {
|
||||
return dispatch(double{});
|
||||
}
|
||||
}
|
||||
case ArrayInterfaceHandler::kI1: {
|
||||
return dispatch(std::int8_t{});
|
||||
|
||||
@@ -18,7 +18,8 @@
|
||||
#include <type_traits> // for remove_pointer_t, remove_reference
|
||||
|
||||
#include "../collective/communicator-inl.h" // for GetRank, GetWorldSize, Allreduce, IsFederated
|
||||
#include "../collective/communicator.h" // for Operation
|
||||
#include "../collective/allgather.h"
|
||||
#include "../collective/allreduce.h"
|
||||
#include "../common/algorithm.h" // for StableSort
|
||||
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
|
||||
#include "../common/error_msg.h" // for GroupSize, GroupWeight, InfInData
|
||||
@@ -601,41 +602,42 @@ void MetaInfo::GetInfo(char const* key, bst_ulong* out_len, DataType dtype,
|
||||
}
|
||||
|
||||
void MetaInfo::SetFeatureInfo(const char* key, const char **info, const bst_ulong size) {
|
||||
if (size != 0 && this->num_col_ != 0 && !IsColumnSplit()) {
|
||||
bool is_col_split = this->IsColumnSplit();
|
||||
|
||||
if (size != 0 && this->num_col_ != 0 && !is_col_split) {
|
||||
CHECK_EQ(size, this->num_col_) << "Length of " << key << " must be equal to number of columns.";
|
||||
CHECK(info);
|
||||
}
|
||||
|
||||
if (!std::strcmp(key, "feature_type")) {
|
||||
feature_type_names.clear();
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
auto elem = info[i];
|
||||
feature_type_names.emplace_back(elem);
|
||||
}
|
||||
if (IsColumnSplit()) {
|
||||
feature_type_names = collective::AllgatherStrings(feature_type_names);
|
||||
CHECK_EQ(feature_type_names.size(), num_col_)
|
||||
// Gather column info when data is split by columns
|
||||
auto gather_columns = [is_col_split, key, n_columns = this->num_col_](auto const& inputs) {
|
||||
if (is_col_split) {
|
||||
std::remove_const_t<std::remove_reference_t<decltype(inputs)>> result;
|
||||
auto rc = collective::AllgatherStrings(inputs, &result);
|
||||
collective::SafeColl(rc);
|
||||
CHECK_EQ(result.size(), n_columns)
|
||||
<< "Length of " << key << " must be equal to number of columns.";
|
||||
return result;
|
||||
}
|
||||
return inputs;
|
||||
};
|
||||
|
||||
if (StringView{key} == "feature_type") { // NOLINT
|
||||
this->feature_type_names.clear();
|
||||
std::copy(info, info + size, std::back_inserter(feature_type_names));
|
||||
feature_type_names = gather_columns(feature_type_names);
|
||||
auto& h_feature_types = feature_types.HostVector();
|
||||
this->has_categorical_ = LoadFeatureType(feature_type_names, &h_feature_types);
|
||||
} else if (!std::strcmp(key, "feature_name")) {
|
||||
if (IsColumnSplit()) {
|
||||
std::vector<std::string> local_feature_names{};
|
||||
} else if (StringView{key} == "feature_name") { // NOLINT
|
||||
feature_names.clear();
|
||||
if (is_col_split) {
|
||||
auto const rank = collective::GetRank();
|
||||
for (std::size_t i = 0; i < size; ++i) {
|
||||
auto elem = std::to_string(rank) + "." + info[i];
|
||||
local_feature_names.emplace_back(elem);
|
||||
}
|
||||
feature_names = collective::AllgatherStrings(local_feature_names);
|
||||
CHECK_EQ(feature_names.size(), num_col_)
|
||||
<< "Length of " << key << " must be equal to number of columns.";
|
||||
std::transform(info, info + size, std::back_inserter(feature_names),
|
||||
[rank](char const* elem) { return std::to_string(rank) + "." + elem; });
|
||||
} else {
|
||||
feature_names.clear();
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
feature_names.emplace_back(info[i]);
|
||||
}
|
||||
std::copy(info, info + size, std::back_inserter(feature_names));
|
||||
}
|
||||
feature_names = gather_columns(feature_names);
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown feature info name: " << key;
|
||||
}
|
||||
@@ -728,12 +730,10 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
|
||||
}
|
||||
}
|
||||
|
||||
void MetaInfo::SynchronizeNumberOfColumns(Context const*) {
|
||||
if (IsColumnSplit()) {
|
||||
collective::Allreduce<collective::Operation::kSum>(&num_col_, 1);
|
||||
} else {
|
||||
collective::Allreduce<collective::Operation::kMax>(&num_col_, 1);
|
||||
}
|
||||
void MetaInfo::SynchronizeNumberOfColumns(Context const* ctx) {
|
||||
auto op = IsColumnSplit() ? collective::Op::kSum : collective::Op::kMax;
|
||||
auto rc = collective::Allreduce(ctx, linalg::MakeVec(&num_col_, 1), op);
|
||||
collective::SafeColl(rc);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -9,11 +9,12 @@
|
||||
#include <type_traits> // for underlying_type_t
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/categorical.h" // common::IsCat
|
||||
#include "../collective/allreduce.h" // for Allreduce
|
||||
#include "../collective/communicator-inl.h" // for IsDistributed
|
||||
#include "../common/categorical.h" // common::IsCat
|
||||
#include "../common/column_matrix.h"
|
||||
#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter.
|
||||
#include "batch_utils.h" // for RegenGHist
|
||||
#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter.
|
||||
#include "batch_utils.h" // for RegenGHist
|
||||
#include "gradient_index.h"
|
||||
#include "proxy_dmatrix.h"
|
||||
#include "simple_batch_iterator.h"
|
||||
@@ -95,13 +96,13 @@ void GetCutsFromRef(Context const* ctx, std::shared_ptr<DMatrix> ref, bst_featur
|
||||
|
||||
namespace {
|
||||
// Synchronize feature type in case of empty DMatrix
|
||||
void SyncFeatureType(Context const*, std::vector<FeatureType>* p_h_ft) {
|
||||
void SyncFeatureType(Context const* ctx, std::vector<FeatureType>* p_h_ft) {
|
||||
if (!collective::IsDistributed()) {
|
||||
return;
|
||||
}
|
||||
auto& h_ft = *p_h_ft;
|
||||
auto n_ft = h_ft.size();
|
||||
collective::Allreduce<collective::Operation::kMax>(&n_ft, 1);
|
||||
bst_idx_t n_ft = h_ft.size();
|
||||
collective::SafeColl(collective::Allreduce(ctx, &n_ft, collective::Op::kMax));
|
||||
if (!h_ft.empty()) {
|
||||
// Check correct size if this is not an empty DMatrix.
|
||||
CHECK_EQ(h_ft.size(), n_ft);
|
||||
@@ -109,7 +110,8 @@ void SyncFeatureType(Context const*, std::vector<FeatureType>* p_h_ft) {
|
||||
if (n_ft > 0) {
|
||||
h_ft.resize(n_ft);
|
||||
auto ptr = reinterpret_cast<std::underlying_type_t<FeatureType>*>(h_ft.data());
|
||||
collective::Allreduce<collective::Operation::kMax>(ptr, h_ft.size());
|
||||
collective::SafeColl(
|
||||
collective::Allreduce(ctx, linalg::MakeVec(ptr, h_ft.size()), collective::Op::kMax));
|
||||
}
|
||||
}
|
||||
} // anonymous namespace
|
||||
@@ -175,7 +177,7 @@ void IterativeDMatrix::InitFromCPU(Context const* ctx, BatchParam const& p,
|
||||
// We use do while here as the first batch is fetched in ctor
|
||||
if (n_features == 0) {
|
||||
n_features = num_cols();
|
||||
collective::Allreduce<collective::Operation::kMax>(&n_features, 1);
|
||||
collective::SafeColl(collective::Allreduce(ctx, &n_features, collective::Op::kMax));
|
||||
column_sizes.clear();
|
||||
column_sizes.resize(n_features, 0);
|
||||
info_.num_col_ = n_features;
|
||||
|
||||
@@ -1,20 +1,18 @@
|
||||
/**
|
||||
* Copyright 2020-2023, XGBoost contributors
|
||||
* Copyright 2020-2024, XGBoost contributors
|
||||
*/
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
|
||||
#include "../collective/allreduce.h"
|
||||
#include "../common/hist_util.cuh"
|
||||
#include "batch_utils.h" // for RegenGHist
|
||||
#include "device_adapter.cuh"
|
||||
#include "ellpack_page.cuh"
|
||||
#include "gradient_index.h"
|
||||
#include "iterative_dmatrix.h"
|
||||
#include "proxy_dmatrix.cuh"
|
||||
#include "proxy_dmatrix.h"
|
||||
#include "simple_batch_iterator.h"
|
||||
#include "sparse_page_source.h"
|
||||
|
||||
namespace xgboost::data {
|
||||
void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
|
||||
@@ -63,7 +61,8 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
|
||||
dh::safe_cuda(cudaSetDevice(get_device().ordinal));
|
||||
if (cols == 0) {
|
||||
cols = num_cols();
|
||||
collective::Allreduce<collective::Operation::kMax>(&cols, 1);
|
||||
auto rc = collective::Allreduce(ctx, linalg::MakeVec(&cols, 1), collective::Op::kMax);
|
||||
SafeColl(rc);
|
||||
this->info_.num_col_ = cols;
|
||||
} else {
|
||||
CHECK_EQ(cols, num_cols()) << "Inconsistent number of columns.";
|
||||
|
||||
@@ -171,12 +171,13 @@ decltype(auto) HostAdapterDispatch(DMatrixProxy const* proxy, Fn fn, bool* type_
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown type: " << proxy->Adapter().type().name();
|
||||
}
|
||||
if constexpr (get_value) {
|
||||
return std::invoke_result_t<
|
||||
Fn, decltype(std::declval<std::shared_ptr<ArrayAdapter>>()->Value())>();
|
||||
} else {
|
||||
return std::invoke_result_t<Fn, decltype(std::declval<std::shared_ptr<ArrayAdapter>>())>();
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (get_value) {
|
||||
return std::invoke_result_t<Fn,
|
||||
decltype(std::declval<std::shared_ptr<ArrayAdapter>>()->Value())>();
|
||||
} else {
|
||||
return std::invoke_result_t<Fn, decltype(std::declval<std::shared_ptr<ArrayAdapter>>())>();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2014~2023, XGBoost Contributors
|
||||
* Copyright 2014-2024, XGBoost Contributors
|
||||
* \file simple_dmatrix.cc
|
||||
* \brief the input data structure for gradient boosting
|
||||
* \author Tianqi Chen
|
||||
@@ -13,6 +13,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "../collective/communicator-inl.h" // for GetWorldSize, GetRank, Allgather
|
||||
#include "../collective/allgather.h"
|
||||
#include "../common/error_msg.h" // for InconsistentMaxBin
|
||||
#include "./simple_batch_iterator.h"
|
||||
#include "adapter.h"
|
||||
@@ -76,8 +77,11 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
|
||||
|
||||
void SimpleDMatrix::ReindexFeatures(Context const* ctx) {
|
||||
if (info_.IsColumnSplit() && collective::GetWorldSize() > 1) {
|
||||
auto const cols = collective::Allgather(info_.num_col_);
|
||||
auto const offset = std::accumulate(cols.cbegin(), cols.cbegin() + collective::GetRank(), 0ul);
|
||||
std::vector<std::uint64_t> buffer(collective::GetWorldSize());
|
||||
buffer[collective::GetRank()] = info_.num_col_;
|
||||
auto rc = collective::Allgather(ctx, linalg::MakeVec(buffer.data(), buffer.size()));
|
||||
SafeColl(rc);
|
||||
auto offset = std::accumulate(buffer.cbegin(), buffer.cbegin() + collective::GetRank(), 0);
|
||||
if (offset == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include <future> // for async
|
||||
#include <memory> // for unique_ptr
|
||||
#include <mutex> // for mutex
|
||||
#include <numeric> // for partial_sum
|
||||
#include <string> // for string
|
||||
#include <utility> // for pair, move
|
||||
#include <vector> // for vector
|
||||
|
||||
@@ -35,7 +35,6 @@
|
||||
|
||||
#include "collective/aggregator.h" // for ApplyWithLabels
|
||||
#include "collective/communicator-inl.h" // for Allreduce, Broadcast, GetRank, IsDistributed
|
||||
#include "collective/communicator.h" // for Operation
|
||||
#include "common/api_entry.h" // for XGBAPIThreadLocalEntry
|
||||
#include "common/charconv.h" // for to_chars, to_chars_result, NumericLimits, from_...
|
||||
#include "common/common.h" // for ToString, Split
|
||||
@@ -208,7 +207,7 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
|
||||
return dmlc::Parameter<LearnerModelParamLegacy>::UpdateAllowUnknown(kwargs);
|
||||
}
|
||||
// sanity check
|
||||
void Validate(Context const*) {
|
||||
void Validate(Context const* ctx) {
|
||||
if (!collective::IsDistributed()) {
|
||||
return;
|
||||
}
|
||||
@@ -229,7 +228,8 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
|
||||
|
||||
std::array<std::int32_t, 6> sync;
|
||||
std::copy(data.cbegin(), data.cend(), sync.begin());
|
||||
collective::Broadcast(sync.data(), sync.size(), 0);
|
||||
auto rc = collective::Broadcast(ctx, linalg::MakeVec(sync.data(), sync.size()), 0);
|
||||
collective::SafeColl(rc);
|
||||
CHECK(std::equal(data.cbegin(), data.cend(), sync.cbegin()))
|
||||
<< "Different model parameter across workers.";
|
||||
}
|
||||
@@ -754,7 +754,9 @@ class LearnerConfiguration : public Learner {
|
||||
num_feature = std::max(num_feature, static_cast<uint32_t>(num_col));
|
||||
}
|
||||
|
||||
collective::Allreduce<collective::Operation::kMax>(&num_feature, 1);
|
||||
auto rc =
|
||||
collective::Allreduce(&ctx_, linalg::MakeVec(&num_feature, 1), collective::Op::kMax);
|
||||
collective::SafeColl(rc);
|
||||
if (num_feature > mparam_.num_feature) {
|
||||
mparam_.num_feature = num_feature;
|
||||
}
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
/*!
|
||||
* Copyright 2015-2018 by Contributors
|
||||
/**
|
||||
* Copyright 2015-2024, XGBoost Contributors
|
||||
* \file logging.cc
|
||||
* \brief Implementation of loggers.
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <iostream>
|
||||
|
||||
#include "xgboost/parameter.h"
|
||||
#include "xgboost/logging.h"
|
||||
|
||||
#include <string> // for string
|
||||
|
||||
#include "collective/communicator-inl.h"
|
||||
|
||||
#if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0
|
||||
|
||||
@@ -264,9 +264,14 @@ class EvalAUC : public MetricNoCache {
|
||||
info.weights_.SetDevice(ctx_->Device());
|
||||
}
|
||||
// We use the global size to handle empty dataset.
|
||||
std::array<size_t, 2> meta{info.labels.Size(), preds.Size()};
|
||||
std::array<bst_idx_t, 2> meta{info.labels.Size(), preds.Size()};
|
||||
if (!info.IsVerticalFederated()) {
|
||||
collective::Allreduce<collective::Operation::kMax>(meta.data(), meta.size());
|
||||
auto rc = collective::Allreduce(
|
||||
ctx_,
|
||||
linalg::MakeTensorView(DeviceOrd::CPU(), common::Span{meta.data(), meta.size()},
|
||||
meta.size()),
|
||||
collective::Op::kMax);
|
||||
collective::SafeColl(rc);
|
||||
}
|
||||
if (meta[0] == 0) {
|
||||
// Empty across all workers, which is not supported.
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
/**
|
||||
* Copyright 2021-2023 by XGBoost Contributors
|
||||
* Copyright 2021-2024, XGBoost Contributors
|
||||
*/
|
||||
#include <thrust/copy.h> // for copy
|
||||
#include <thrust/scan.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cub/cub.cuh> // NOLINT
|
||||
#include <limits>
|
||||
@@ -11,7 +11,7 @@
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
|
||||
#include "../collective/communicator-inl.cuh"
|
||||
#include "../collective/allreduce.h"
|
||||
#include "../common/algorithm.cuh" // SegmentedArgSort
|
||||
#include "../common/optional_weight.h" // OptionalWeights
|
||||
#include "../common/threading_utils.cuh" // UnravelTrapeziodIdx,SegmentedTrapezoidThreads
|
||||
@@ -201,13 +201,16 @@ void Transpose(common::Span<float const> in, common::Span<float> out, size_t m,
|
||||
});
|
||||
}
|
||||
|
||||
double ScaleClasses(Context const *ctx, common::Span<double> results,
|
||||
double ScaleClasses(Context const *ctx, bool is_column_split, common::Span<double> results,
|
||||
common::Span<double> local_area, common::Span<double> tp,
|
||||
common::Span<double> auc, size_t n_classes) {
|
||||
if (collective::IsDistributed()) {
|
||||
int32_t device = dh::CurrentDevice();
|
||||
// With vertical federated learning, only the root has label, other parties are not
|
||||
// evaluation metrics.
|
||||
if (collective::IsDistributed() && !(is_column_split && collective::IsFederated())) {
|
||||
std::int32_t device = dh::CurrentDevice();
|
||||
CHECK_EQ(dh::CudaGetPointerDevice(results.data()), device);
|
||||
collective::AllReduce<collective::Operation::kSum>(device, results.data(), results.size());
|
||||
auto rc = collective::Allreduce(
|
||||
ctx, linalg::MakeVec(results.data(), results.size(), ctx->Device()), collective::Op::kSum);
|
||||
}
|
||||
auto reduce_in = dh::MakeTransformIterator<Pair>(
|
||||
thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {
|
||||
@@ -334,7 +337,7 @@ double GPUMultiClassAUCOVR(Context const *ctx, MetaInfo const &info,
|
||||
auto local_area = d_results.subspan(0, n_classes);
|
||||
auto tp = d_results.subspan(2 * n_classes, n_classes);
|
||||
auto auc = d_results.subspan(3 * n_classes, n_classes);
|
||||
return ScaleClasses(ctx, d_results, local_area, tp, auc, n_classes);
|
||||
return ScaleClasses(ctx, info.IsColumnSplit(), d_results, local_area, tp, auc, n_classes);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -438,7 +441,7 @@ double GPUMultiClassAUCOVR(Context const *ctx, MetaInfo const &info,
|
||||
tp[c] = 1.0f;
|
||||
}
|
||||
});
|
||||
return ScaleClasses(ctx, d_results, local_area, tp, auc, n_classes);
|
||||
return ScaleClasses(ctx, info.IsColumnSplit(), d_results, local_area, tp, auc, n_classes);
|
||||
}
|
||||
|
||||
void MultiClassSortedIdx(Context const *ctx, common::Span<float const> predts,
|
||||
@@ -835,7 +838,7 @@ std::pair<double, std::uint32_t> GPURankingPRAUC(Context const *ctx,
|
||||
InitCacheOnce<false>(predts, p_cache);
|
||||
|
||||
dh::device_vector<bst_group_t> group_ptr(info.group_ptr_.size());
|
||||
thrust::copy(info.group_ptr_.begin(), info.group_ptr_.end(), group_ptr.begin());
|
||||
thrust::copy(info.group_ptr_.begin(), info.group_ptr_.end(), group_ptr.begin()); // NOLINT
|
||||
auto d_group_ptr = dh::ToSpan(group_ptr);
|
||||
CHECK_GE(info.group_ptr_.size(), 1) << "Must have at least 1 query group for LTR.";
|
||||
size_t n_groups = info.group_ptr_.size() - 1;
|
||||
|
||||
@@ -1,18 +1,14 @@
|
||||
/**
|
||||
* Copyright 2021-2023, XGBoost Contributors
|
||||
* Copyright 2021-2024, XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_METRIC_AUC_H_
|
||||
#define XGBOOST_METRIC_AUC_H_
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/common.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/metric.h"
|
||||
|
||||
@@ -9,8 +9,6 @@
|
||||
#include <vector> // std::vector
|
||||
|
||||
#include "../collective/aggregator.h"
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/common.h"
|
||||
#include "xgboost/base.h" // bst_node_t
|
||||
#include "xgboost/context.h" // Context
|
||||
#include "xgboost/data.h" // MetaInfo
|
||||
@@ -42,7 +40,7 @@ inline void UpdateLeafValues(Context const* ctx, std::vector<float>* p_quantiles
|
||||
auto& quantiles = *p_quantiles;
|
||||
auto const& h_node_idx = nidx;
|
||||
|
||||
size_t n_leaf = collective::GlobalMax(ctx, info, h_node_idx.size());
|
||||
bst_idx_t n_leaf = collective::GlobalMax(ctx, info, static_cast<bst_idx_t>(h_node_idx.size()));
|
||||
CHECK(quantiles.empty() || quantiles.size() == n_leaf);
|
||||
if (quantiles.empty()) {
|
||||
quantiles.resize(n_leaf, std::numeric_limits<float>::quiet_NaN());
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2017-2023 by XGBoost Contributors
|
||||
* Copyright 2017-2024, XGBoost Contributors
|
||||
*/
|
||||
#include <algorithm> // for max, fill, min
|
||||
#include <any> // for any, any_cast
|
||||
@@ -12,7 +12,7 @@
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../collective/communicator-inl.h" // for Allreduce, IsDistributed
|
||||
#include "../collective/communicator.h" // for Operation
|
||||
#include "../collective/allreduce.h"
|
||||
#include "../common/bitfield.h" // for RBitField8
|
||||
#include "../common/categorical.h" // for IsCat, Decision
|
||||
#include "../common/common.h" // for DivRoundUp
|
||||
@@ -461,11 +461,17 @@ class ColumnSplitHelper {
|
||||
return tree_offsets_[tree_index] * n_rows_ + row_id * tree_sizes_[tree_index] + node_id;
|
||||
}
|
||||
|
||||
void AllreduceBitVectors(Context const*) {
|
||||
collective::Allreduce<collective::Operation::kBitwiseOR>(decision_storage_.data(),
|
||||
decision_storage_.size());
|
||||
collective::Allreduce<collective::Operation::kBitwiseAND>(missing_storage_.data(),
|
||||
missing_storage_.size());
|
||||
void AllreduceBitVectors(Context const *ctx) {
|
||||
auto rc = collective::Success() << [&] {
|
||||
return collective::Allreduce(
|
||||
ctx, linalg::MakeVec(decision_storage_.data(), decision_storage_.size()),
|
||||
collective::Op::kBitwiseOR);
|
||||
} << [&] {
|
||||
return collective::Allreduce(
|
||||
ctx, linalg::MakeVec(missing_storage_.data(), missing_storage_.size()),
|
||||
collective::Op::kBitwiseAND);
|
||||
};
|
||||
collective::SafeColl(rc);
|
||||
}
|
||||
|
||||
void MaskOneTree(RegTree::FVec const &feat, std::size_t tree_id, std::size_t row_id) {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2017-2023 by XGBoost Contributors
|
||||
* Copyright 2017-2024, XGBoost Contributors
|
||||
*/
|
||||
#include <GPUTreeShap/gpu_treeshap.h>
|
||||
#include <thrust/copy.h>
|
||||
@@ -11,7 +11,7 @@
|
||||
#include <any> // for any, any_cast
|
||||
#include <memory>
|
||||
|
||||
#include "../collective/communicator-inl.cuh"
|
||||
#include "../collective/allreduce.h"
|
||||
#include "../common/bitfield.h"
|
||||
#include "../common/categorical.h"
|
||||
#include "../common/common.h"
|
||||
@@ -817,10 +817,18 @@ class ColumnSplitHelper {
|
||||
|
||||
void AllReduceBitVectors(dh::caching_device_vector<BitType>* decision_storage,
|
||||
dh::caching_device_vector<BitType>* missing_storage) const {
|
||||
collective::AllReduce<collective::Operation::kBitwiseOR>(
|
||||
ctx_->Ordinal(), decision_storage->data().get(), decision_storage->size());
|
||||
collective::AllReduce<collective::Operation::kBitwiseAND>(
|
||||
ctx_->Ordinal(), missing_storage->data().get(), missing_storage->size());
|
||||
auto rc = collective::Success() << [&] {
|
||||
return collective::Allreduce(
|
||||
ctx_,
|
||||
linalg::MakeVec(decision_storage->data().get(), decision_storage->size(), ctx_->Device()),
|
||||
collective::Op::kBitwiseOR);
|
||||
} << [&] {
|
||||
return collective::Allreduce(
|
||||
ctx_,
|
||||
linalg::MakeVec(missing_storage->data().get(), missing_storage->size(), ctx_->Device()),
|
||||
collective::Op::kBitwiseAND);
|
||||
};
|
||||
collective::SafeColl(rc);
|
||||
}
|
||||
|
||||
void ResizeBitVectors(dh::caching_device_vector<BitType>* decision_storage,
|
||||
|
||||
@@ -1,24 +1,28 @@
|
||||
/**
|
||||
* Copyright 2021-2023 XGBoost contributors
|
||||
* Copyright 2021-2023, XGBoost contributors
|
||||
* \file common_row_partitioner.h
|
||||
* \brief Common partitioner logic for hist and approx methods.
|
||||
*/
|
||||
#ifndef XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_
|
||||
#define XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_
|
||||
|
||||
#include <algorithm> // std::all_of
|
||||
#include <cinttypes> // std::uint32_t
|
||||
#include <limits> // std::numeric_limits
|
||||
#include <vector>
|
||||
#include <algorithm> // for all_of, fill
|
||||
#include <cinttypes> // for uint32_t
|
||||
#include <limits> // for numeric_limits
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/linalg_op.h" // cbegin
|
||||
#include "../common/numeric.h" // Iota
|
||||
#include "../common/partition_builder.h"
|
||||
#include "hist/expand_entry.h" // CPUExpandEntry
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/context.h" // Context
|
||||
#include "xgboost/linalg.h" // TensorView
|
||||
#include "../collective/allreduce.h" // for Allreduce
|
||||
#include "../common/bitfield.h" // for RBitField8
|
||||
#include "../common/linalg_op.h" // for cbegin
|
||||
#include "../common/numeric.h" // for Iota
|
||||
#include "../common/partition_builder.h" // for PartitionBuilder
|
||||
#include "../common/row_set.h" // for RowSetCollection
|
||||
#include "../common/threading_utils.h" // for ParallelFor2d
|
||||
#include "xgboost/base.h" // for bst_row_t
|
||||
#include "xgboost/collective/result.h" // for Success, SafeColl
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/linalg.h" // for TensorView
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::tree {
|
||||
|
||||
@@ -39,7 +43,7 @@ class ColumnSplitHelper {
|
||||
}
|
||||
|
||||
template <typename BinIdxType, bool any_missing, bool any_cat, typename ExpandEntry>
|
||||
void Partition(common::BlockedSpace2d const& space, std::int32_t n_threads,
|
||||
void Partition(Context const* ctx, common::BlockedSpace2d const& space, std::int32_t n_threads,
|
||||
GHistIndexMatrix const& gmat, common::ColumnMatrix const& column_matrix,
|
||||
std::vector<ExpandEntry> const& nodes,
|
||||
std::vector<int32_t> const& split_conditions, RegTree const* p_tree) {
|
||||
@@ -56,10 +60,12 @@ class ColumnSplitHelper {
|
||||
});
|
||||
|
||||
// Then aggregate the bit vectors across all the workers.
|
||||
collective::Allreduce<collective::Operation::kBitwiseOR>(decision_storage_.data(),
|
||||
decision_storage_.size());
|
||||
collective::Allreduce<collective::Operation::kBitwiseAND>(missing_storage_.data(),
|
||||
missing_storage_.size());
|
||||
auto rc = collective::Success() << [&] {
|
||||
return collective::Allreduce(ctx, &decision_storage_, collective::Op::kBitwiseOR);
|
||||
} << [&] {
|
||||
return collective::Allreduce(ctx, &missing_storage_, collective::Op::kBitwiseAND);
|
||||
};
|
||||
collective::SafeColl(rc);
|
||||
|
||||
// Finally use the bit vectors to partition the rows.
|
||||
common::ParallelFor2d(space, n_threads, [&](size_t node_in_set, common::Range1d r) {
|
||||
@@ -220,7 +226,7 @@ class CommonRowPartitioner {
|
||||
// Store results in intermediate buffers from partition_builder_
|
||||
if (is_col_split_) {
|
||||
column_split_helper_.Partition<BinIdxType, any_missing, any_cat>(
|
||||
space, ctx->Threads(), gmat, column_matrix, nodes, split_conditions, p_tree);
|
||||
ctx, space, ctx->Threads(), gmat, column_matrix, nodes, split_conditions, p_tree);
|
||||
} else {
|
||||
common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) {
|
||||
size_t begin = r.begin();
|
||||
|
||||
@@ -47,8 +47,10 @@ void FitStump(Context const* ctx, MetaInfo const& info,
|
||||
thrust::reduce_by_key(policy, key_it, key_it + gpair.Size(), grad_it,
|
||||
thrust::make_discard_iterator(), dh::tbegin(d_sum.Values()));
|
||||
|
||||
collective::GlobalSum(info, ctx->Device(), reinterpret_cast<double*>(d_sum.Values().data()),
|
||||
d_sum.Size() * 2);
|
||||
auto rc = collective::GlobalSum(ctx, info,
|
||||
linalg::MakeVec(reinterpret_cast<double*>(d_sum.Values().data()),
|
||||
d_sum.Size() * 2, ctx->Device()));
|
||||
SafeColl(rc);
|
||||
|
||||
thrust::for_each_n(policy, thrust::make_counting_iterator(0ul), n_targets,
|
||||
[=] XGBOOST_DEVICE(std::size_t i) mutable {
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
/**
|
||||
* Copyright 2020-2023, XGBoost Contributors
|
||||
* Copyright 2020-2024, XGBoost Contributors
|
||||
*/
|
||||
#include <algorithm> // std::max
|
||||
#include <vector>
|
||||
#include <limits>
|
||||
|
||||
#include "../../collective/communicator-inl.cuh"
|
||||
#include "../../collective/allgather.h"
|
||||
#include "../../common/categorical.h"
|
||||
#include "../../data/ellpack_page.cuh"
|
||||
#include "evaluate_splits.cuh"
|
||||
@@ -413,8 +413,14 @@ void GPUHistEvaluator::EvaluateSplits(Context const *ctx, const std::vector<bst_
|
||||
auto const world_size = collective::GetWorldSize();
|
||||
dh::TemporaryArray<DeviceSplitCandidate> all_candidate_storage(out_splits.size() * world_size);
|
||||
auto all_candidates = dh::ToSpan(all_candidate_storage);
|
||||
collective::AllGather(device_.ordinal, out_splits.data(), all_candidates.data(),
|
||||
out_splits.size() * sizeof(DeviceSplitCandidate));
|
||||
auto current_rank =
|
||||
all_candidates.subspan(collective::GetRank() * out_splits.size(), out_splits.size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(current_rank.data(), out_splits.data(),
|
||||
out_splits.size() * sizeof(DeviceSplitCandidate),
|
||||
cudaMemcpyDeviceToDevice));
|
||||
auto rc = collective::Allgather(
|
||||
ctx, linalg::MakeVec(all_candidates.data(), all_candidates.size(), ctx->Device()));
|
||||
collective::SafeColl(rc);
|
||||
|
||||
// Reduce to get the best candidate from all workers.
|
||||
dh::LaunchN(out_splits.size(), ctx->CUDACtx()->Stream(),
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include <utility> // for move
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../../collective/allgather.h"
|
||||
#include "../../common/categorical.h" // for CatBitField
|
||||
#include "../../common/hist_util.h" // for GHistRow, HistogramCuts
|
||||
#include "../../common/linalg_op.h" // for cbegin, cend, begin
|
||||
@@ -35,7 +36,7 @@ template <typename ExpandEntry>
|
||||
std::enable_if_t<std::is_same_v<ExpandEntry, CPUExpandEntry> ||
|
||||
std::is_same_v<ExpandEntry, MultiExpandEntry>,
|
||||
std::vector<ExpandEntry>>
|
||||
AllgatherColumnSplit(std::vector<ExpandEntry> const &entries) {
|
||||
AllgatherColumnSplit(Context const *ctx, std::vector<ExpandEntry> const &entries) {
|
||||
auto const n_entries = entries.size();
|
||||
|
||||
// First, gather all the primitive fields.
|
||||
@@ -52,7 +53,7 @@ AllgatherColumnSplit(std::vector<ExpandEntry> const &entries) {
|
||||
|
||||
serialized_entries.emplace_back(std::move(out));
|
||||
}
|
||||
auto all_serialized = collective::VectorAllgatherV(serialized_entries);
|
||||
auto all_serialized = collective::VectorAllgatherV(ctx, serialized_entries);
|
||||
CHECK_GE(all_serialized.size(), local_entries.size());
|
||||
|
||||
std::vector<ExpandEntry> all_entries(all_serialized.size());
|
||||
@@ -401,7 +402,7 @@ class HistEvaluator {
|
||||
if (is_col_split_) {
|
||||
// With column-wise data split, we gather the best splits from all the workers and update the
|
||||
// expand entries accordingly.
|
||||
auto all_entries = AllgatherColumnSplit(entries);
|
||||
auto all_entries = AllgatherColumnSplit(ctx_, entries);
|
||||
for (auto worker = 0; worker < collective::GetWorldSize(); ++worker) {
|
||||
for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
|
||||
entries[nidx_in_set].split.Update(
|
||||
@@ -632,7 +633,7 @@ class HistMultiEvaluator {
|
||||
if (is_col_split_) {
|
||||
// With column-wise data split, we gather the best splits from all the workers and update the
|
||||
// expand entries accordingly.
|
||||
auto all_entries = AllgatherColumnSplit(entries);
|
||||
auto all_entries = AllgatherColumnSplit(ctx_, entries);
|
||||
for (auto worker = 0; worker < collective::GetWorldSize(); ++worker) {
|
||||
for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
|
||||
entries[nidx_in_set].split.Update(
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2021-2023 by XGBoost Contributors
|
||||
* Copyright 2021-2024, XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_TREE_HIST_HISTOGRAM_H_
|
||||
#define XGBOOST_TREE_HIST_HISTOGRAM_H_
|
||||
@@ -7,26 +7,24 @@
|
||||
#include <algorithm> // for max
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int32_t
|
||||
#include <functional> // for function
|
||||
#include <utility> // for move
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../../collective/communicator-inl.h" // for Allreduce
|
||||
#include "../../collective/communicator.h" // for Operation
|
||||
#include "../../common/hist_util.h" // for GHistRow, ParallelGHi...
|
||||
#include "../../common/row_set.h" // for RowSetCollection
|
||||
#include "../../common/threading_utils.h" // for ParallelFor2d, Range1d, BlockedSpace2d
|
||||
#include "../../data/gradient_index.h" // for GHistIndexMatrix
|
||||
#include "expand_entry.h" // for MultiExpandEntry, CPUExpandEntry
|
||||
#include "hist_cache.h" // for BoundedHistCollection
|
||||
#include "param.h" // for HistMakerTrainParam
|
||||
#include "xgboost/base.h" // for bst_node_t, bst_target_t, bst_bin_t
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/data.h" // for BatchIterator, BatchSet
|
||||
#include "xgboost/linalg.h" // for MatrixView, All, Vect...
|
||||
#include "xgboost/logging.h" // for CHECK_GE
|
||||
#include "xgboost/span.h" // for Span
|
||||
#include "xgboost/tree_model.h" // for RegTree
|
||||
#include "../../collective/allreduce.h" // for Allreduce
|
||||
#include "../../common/hist_util.h" // for GHistRow, ParallelGHi...
|
||||
#include "../../common/row_set.h" // for RowSetCollection
|
||||
#include "../../common/threading_utils.h" // for ParallelFor2d, Range1d, BlockedSpace2d
|
||||
#include "../../data/gradient_index.h" // for GHistIndexMatrix
|
||||
#include "expand_entry.h" // for MultiExpandEntry, CPUExpandEntry
|
||||
#include "hist_cache.h" // for BoundedHistCollection
|
||||
#include "param.h" // for HistMakerTrainParam
|
||||
#include "xgboost/base.h" // for bst_node_t, bst_target_t, bst_bin_t
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/data.h" // for BatchIterator, BatchSet
|
||||
#include "xgboost/linalg.h" // for MatrixView, All, Vect...
|
||||
#include "xgboost/logging.h" // for CHECK_GE
|
||||
#include "xgboost/span.h" // for Span
|
||||
#include "xgboost/tree_model.h" // for RegTree
|
||||
|
||||
namespace xgboost::tree {
|
||||
/**
|
||||
@@ -171,7 +169,7 @@ class HistogramBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
void SyncHistogram(Context const *, RegTree const *p_tree,
|
||||
void SyncHistogram(Context const *ctx, RegTree const *p_tree,
|
||||
std::vector<bst_node_t> const &nodes_to_build,
|
||||
std::vector<bst_node_t> const &nodes_to_trick) {
|
||||
auto n_total_bins = buffer_.TotalBins();
|
||||
@@ -186,8 +184,10 @@ class HistogramBuilder {
|
||||
CHECK(!nodes_to_build.empty());
|
||||
auto first_nidx = nodes_to_build.front();
|
||||
std::size_t n = n_total_bins * nodes_to_build.size() * 2;
|
||||
collective::Allreduce<collective::Operation::kSum>(
|
||||
reinterpret_cast<double *>(this->hist_[first_nidx].data()), n);
|
||||
auto rc = collective::Allreduce(
|
||||
ctx, linalg::MakeVec(reinterpret_cast<double *>(this->hist_[first_nidx].data()), n),
|
||||
collective::Op::kSum);
|
||||
SafeColl(rc);
|
||||
}
|
||||
|
||||
common::BlockedSpace2d const &subspace =
|
||||
|
||||
@@ -1,18 +1,22 @@
|
||||
/**
|
||||
* Copyright 2021-2023, XGBoost Contributors
|
||||
* Copyright 2021-2024, XGBoost Contributors
|
||||
*/
|
||||
#include "param.h"
|
||||
|
||||
#include <ios> // for binary
|
||||
#include <string> // for string
|
||||
|
||||
#include "../../collective/communicator-inl.h" // for GetRank, Broadcast
|
||||
#include "../../collective/broadcast.h" // for Broadcast
|
||||
#include "../../collective/communicator-inl.h" // for GetRank
|
||||
#include "xgboost/json.h" // for Object, Json
|
||||
#include "xgboost/linalg.h" // for MakeVec
|
||||
#include "xgboost/tree_model.h" // for RegTree
|
||||
|
||||
namespace xgboost::tree {
|
||||
DMLC_REGISTER_PARAMETER(HistMakerTrainParam);
|
||||
|
||||
void HistMakerTrainParam::CheckTreesSynchronized(Context const*, RegTree const* local_tree) const {
|
||||
void HistMakerTrainParam::CheckTreesSynchronized(Context const* ctx,
|
||||
RegTree const* local_tree) const {
|
||||
if (!this->debug_synchronize) {
|
||||
return;
|
||||
}
|
||||
@@ -24,7 +28,15 @@ void HistMakerTrainParam::CheckTreesSynchronized(Context const*, RegTree const*
|
||||
local_tree->SaveModel(&model);
|
||||
}
|
||||
Json::Dump(model, &s_model, std::ios::binary);
|
||||
collective::Broadcast(&s_model, 0);
|
||||
|
||||
auto nchars{static_cast<std::int64_t>(s_model.size())};
|
||||
auto rc = collective::Success() << [&] {
|
||||
return collective::Broadcast(ctx, linalg::MakeVec(&nchars, 1), 0);
|
||||
} << [&] {
|
||||
s_model.resize(nchars);
|
||||
return collective::Broadcast(ctx, linalg::MakeVec(s_model.data(), s_model.size()), 0);
|
||||
};
|
||||
collective::SafeColl(rc);
|
||||
|
||||
RegTree ref_tree{}; // rank 0 tree
|
||||
auto j_ref_tree = Json::Load(StringView{s_model}, std::ios::binary);
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "../collective/aggregator.h"
|
||||
#include "../collective/aggregator.cuh"
|
||||
#include "../collective/broadcast.h"
|
||||
#include "../common/bitfield.h"
|
||||
#include "../common/categorical.h"
|
||||
#include "../common/cuda_context.cuh" // CUDAContext
|
||||
@@ -410,11 +410,16 @@ struct GPUHistMakerDevice {
|
||||
}
|
||||
});
|
||||
|
||||
collective::AllReduce<collective::Operation::kBitwiseOR>(
|
||||
ctx_->Ordinal(), decision_storage.data().get(), decision_storage.size());
|
||||
collective::AllReduce<collective::Operation::kBitwiseAND>(
|
||||
ctx_->Ordinal(), missing_storage.data().get(), missing_storage.size());
|
||||
collective::Synchronize(ctx_->Ordinal());
|
||||
auto rc = collective::Success() << [&] {
|
||||
return collective::Allreduce(
|
||||
ctx_, linalg::MakeTensorView(ctx_, dh::ToSpan(decision_storage), decision_storage.size()),
|
||||
collective::Op::kBitwiseOR);
|
||||
} << [&] {
|
||||
return collective::Allreduce(
|
||||
ctx_, linalg::MakeTensorView(ctx_, dh::ToSpan(missing_storage), missing_storage.size()),
|
||||
collective::Op::kBitwiseAND);
|
||||
};
|
||||
collective::SafeColl(rc);
|
||||
|
||||
row_partitioner->UpdatePositionBatch(
|
||||
nidx, left_nidx, right_nidx, split_data,
|
||||
@@ -611,8 +616,11 @@ struct GPUHistMakerDevice {
|
||||
monitor.Start("AllReduce");
|
||||
auto d_node_hist = hist.GetNodeHistogram(nidx).data();
|
||||
using ReduceT = typename std::remove_pointer<decltype(d_node_hist)>::type::ValueT;
|
||||
collective::GlobalSum(info_, ctx_->Device(), reinterpret_cast<ReduceT*>(d_node_hist),
|
||||
page->Cuts().TotalBins() * 2 * num_histograms);
|
||||
auto rc = collective::GlobalSum(
|
||||
ctx_, info_,
|
||||
linalg::MakeVec(reinterpret_cast<ReduceT*>(d_node_hist),
|
||||
page->Cuts().TotalBins() * 2 * num_histograms, ctx_->Device()));
|
||||
SafeColl(rc);
|
||||
|
||||
monitor.Stop("AllReduce");
|
||||
}
|
||||
@@ -860,7 +868,9 @@ class GPUHistMaker : public TreeUpdater {
|
||||
|
||||
// Synchronise the column sampling seed
|
||||
uint32_t column_sampling_seed = common::GlobalRandom()();
|
||||
collective::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);
|
||||
// MakeVec takes an element count, not a byte count: the seed is a single uint32_t.
|
||||
auto rc = collective::Broadcast(
|
||||
ctx_, linalg::MakeVec(&column_sampling_seed, 1), 0);
|
||||
SafeColl(rc);
|
||||
this->column_sampler_ = std::make_shared<common::ColumnSampler>(column_sampling_seed);
|
||||
|
||||
auto batch_param = BatchParam{param->max_bin, TrainParam::DftSparseThreshold()};
|
||||
@@ -1001,9 +1011,7 @@ class GPUGlobalApproxMaker : public TreeUpdater {
|
||||
|
||||
monitor_.Start(__func__);
|
||||
CHECK(ctx_->IsCUDA()) << error::InvalidCUDAOrdinal();
|
||||
// Synchronise the column sampling seed
|
||||
uint32_t column_sampling_seed = common::GlobalRandom()();
|
||||
collective::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);
|
||||
this->column_sampler_ = std::make_shared<common::ColumnSampler>(column_sampling_seed);
|
||||
|
||||
p_last_fmat_ = p_fmat;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2014-2023 by XGBoost Contributors
|
||||
* Copyright 2014-2024, XGBoost Contributors
|
||||
* \file updater_refresh.cc
|
||||
* \brief refresh the statistics and leaf value on the tree on the dataset
|
||||
* \author Tianqi Chen
|
||||
@@ -9,8 +9,7 @@
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/io.h"
|
||||
#include "../collective/allreduce.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../predictor/predict_fn.h"
|
||||
#include "./param.h"
|
||||
@@ -39,7 +38,7 @@ class TreeRefresher : public TreeUpdater {
|
||||
}
|
||||
CHECK_EQ(gpair->Shape(1), 1) << MTNotImplemented();
|
||||
const std::vector<GradientPair> &gpair_h = gpair->Data()->ConstHostVector();
|
||||
// thread temporal space
|
||||
// Thread local variables.
|
||||
std::vector<std::vector<GradStats> > stemp;
|
||||
std::vector<RegTree::FVec> fvec_temp;
|
||||
// setup temp space for each thread
|
||||
@@ -61,9 +60,8 @@ class TreeRefresher : public TreeUpdater {
|
||||
});
|
||||
}
|
||||
exc.Rethrow();
|
||||
// if it is C++11, use lazy evaluation for Allreduce,
|
||||
// to gain speedup in recovery
|
||||
auto lazy_get_stats = [&]() {
|
||||
|
||||
auto get_stats = [&]() {
|
||||
const MetaInfo &info = p_fmat->Info();
|
||||
// start accumulating statistics
|
||||
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
@@ -93,12 +91,17 @@ class TreeRefresher : public TreeUpdater {
|
||||
}
|
||||
});
|
||||
};
|
||||
lazy_get_stats();
|
||||
collective::Allreduce<collective::Operation::kSum>(&dmlc::BeginPtr(stemp[0])->sum_grad,
|
||||
stemp[0].size() * 2);
|
||||
int offset = 0;
|
||||
get_stats();
|
||||
// Synchronize the aggregated result: sum the per-node statistics across workers.
|
||||
auto &sum_grad = stemp[0];
|
||||
// x2 for gradient and hessian.
|
||||
auto rc = collective::Allreduce(
|
||||
ctx_, linalg::MakeVec(&sum_grad.data()->sum_grad, sum_grad.size() * 2),
|
||||
collective::Op::kSum);
|
||||
collective::SafeColl(rc);
|
||||
bst_node_t offset = 0;
|
||||
for (auto tree : trees) {
|
||||
this->Refresh(param, dmlc::BeginPtr(stemp[0]) + offset, 0, tree);
|
||||
this->Refresh(param, dmlc::BeginPtr(sum_grad) + offset, 0, tree);
|
||||
offset += tree->NumNodes();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
/**
|
||||
* Copyright 2014-2023 by XBGoost Contributors
|
||||
* Copyright 2014-2024, XGBoost Contributors
|
||||
* \file updater_sync.cc
|
||||
* \brief synchronize the tree in all distributed nodes
|
||||
*/
|
||||
#include <xgboost/tree_updater.h>
|
||||
|
||||
#include <limits>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "../collective/broadcast.h"
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/io.h"
|
||||
#include "xgboost/json.h"
|
||||
@@ -44,7 +44,8 @@ class TreeSyncher : public TreeUpdater {
|
||||
}
|
||||
}
|
||||
fs.Seek(0);
|
||||
collective::Broadcast(&s_model, 0);
|
||||
auto rc = collective::Broadcast(ctx_, linalg::MakeVec(s_model.data(), s_model.size()), 0);
|
||||
SafeColl(rc);
|
||||
for (auto tree : trees) {
|
||||
tree->Load(&fs);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user