Revamp the rabit implementation. (#10112)

This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features:
- Federated learning for both CPU and GPU.
- NCCL.
- More data types.
- A unified interface for all the underlying implementations.
- Improved timeout handling for both tracker and workers.
- Exhaustive tests with metrics (fixed a couple of bugs along the way).
- A reusable tracker for Python and JVM packages.
This commit is contained in:
Jiaming Yuan
2024-05-20 11:56:23 +08:00
committed by GitHub
parent ba9b4cb1ee
commit a5a58102e5
195 changed files with 2768 additions and 9234 deletions

View File

@@ -15,9 +15,9 @@
#include <utility> // for pair
#include <vector> // for vector
#include "../collective/communicator-inl.h" // for Allreduce, Broadcast, Finalize, GetProcessor...
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
#include "../common/charconv.h" // for from_chars, to_chars, NumericLimits, from_ch...
#include "../common/error_msg.h" // for NoFederated
#include "../common/hist_util.h" // for HistogramCuts
#include "../common/io.h" // for FileExtension, LoadSequentialFile, MemoryBuf...
#include "../common/threading_utils.h" // for OmpGetNumThreads, ParallelFor
@@ -27,11 +27,10 @@
#include "../data/simple_dmatrix.h" // for SimpleDMatrix
#include "c_api_error.h" // for xgboost_CHECK_C_ARG_PTR, API_END, API_BEGIN
#include "c_api_utils.h" // for RequiredArg, OptionalArg, GetMissing, CastDM...
#include "dmlc/base.h" // for BeginPtr, DMLC_ATTRIBUTE_UNUSED
#include "dmlc/base.h" // for BeginPtr
#include "dmlc/io.h" // for Stream
#include "dmlc/parameter.h" // for FieldAccessEntry, FieldEntry, ParamManager
#include "dmlc/thread_local.h" // for ThreadLocalStore
#include "rabit/c_api.h" // for RabitLinkTag
#include "xgboost/base.h" // for bst_ulong, bst_float, GradientPair, bst_feat...
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for DMatrix, MetaInfo, DataType, ExtSparsePage
@@ -46,10 +45,6 @@
#include "xgboost/string_view.h" // for StringView, operator<<
#include "xgboost/version_config.h" // for XGBOOST_VER_MAJOR, XGBOOST_VER_MINOR, XGBOOS...
#if defined(XGBOOST_USE_FEDERATED)
#include "../../plugin/federated/federated_server.h"
#endif
using namespace xgboost; // NOLINT(*);
XGB_DLL void XGBoostVersion(int* major, int* minor, int* patch) {
@@ -1759,76 +1754,3 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *config,
*out_features = dmlc::BeginPtr(feature_names_c);
API_END();
}
XGB_DLL int XGCommunicatorInit(char const* json_config) {
API_BEGIN();
xgboost_CHECK_C_ARG_PTR(json_config);
Json config{Json::Load(StringView{json_config})};
collective::Init(config);
API_END();
}
XGB_DLL int XGCommunicatorFinalize() {
API_BEGIN();
collective::Finalize();
API_END();
}
XGB_DLL int XGCommunicatorGetRank(void) {
return collective::GetRank();
}
XGB_DLL int XGCommunicatorGetWorldSize(void) {
return collective::GetWorldSize();
}
XGB_DLL int XGCommunicatorIsDistributed(void) {
return collective::IsDistributed();
}
XGB_DLL int XGCommunicatorPrint(char const *message) {
API_BEGIN();
collective::Print(message);
API_END();
}
XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) {
API_BEGIN();
auto& local = *GlobalConfigAPIThreadLocalStore::Get();
local.ret_str = collective::GetProcessorName();
xgboost_CHECK_C_ARG_PTR(name_str);
*name_str = local.ret_str.c_str();
API_END();
}
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root) {
API_BEGIN();
collective::Broadcast(send_receive_buffer, size, root);
API_END();
}
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int enum_dtype,
int enum_op) {
API_BEGIN();
collective::Allreduce(send_receive_buffer, count, enum_dtype, enum_op);
API_END();
}
#if defined(XGBOOST_USE_FEDERATED)
XGB_DLL int XGBRunFederatedServer(int port, std::size_t world_size, char const *server_key_path,
char const *server_cert_path, char const *client_cert_path) {
API_BEGIN();
federated::RunServer(port, world_size, server_key_path, server_cert_path, client_cert_path);
API_END();
}
// Run a server without SSL for local testing.
XGB_DLL int XGBRunInsecureFederatedServer(int port, std::size_t world_size) {
API_BEGIN();
federated::RunInsecureServer(port, world_size);
API_END();
}
#endif
// force link rabit
static DMLC_ATTRIBUTE_UNUSED int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag();

View File

@@ -1,22 +1,28 @@
/*!
* Copyright (c) 2015 by Contributors
/**
* Copyright 2015-2023, XGBoost Contributors
* \file c_api_error.cc
* \brief C error handling
*/
#include <dmlc/thread_local.h>
#include "xgboost/c_api.h"
#include "./c_api_error.h"
#include <dmlc/thread_local.h>
#include "xgboost/c_api.h"
#include "../collective/comm.h"
#include "../collective/comm_group.h"
struct XGBAPIErrorEntry {
std::string last_error;
std::int32_t code{-1};
};
using XGBAPIErrorStore = dmlc::ThreadLocalStore<XGBAPIErrorEntry>;
XGB_DLL const char *XGBGetLastError() {
return XGBAPIErrorStore::Get()->last_error.c_str();
}
XGB_DLL const char* XGBGetLastError() { return XGBAPIErrorStore::Get()->last_error.c_str(); }
void XGBAPISetLastError(const char* msg) {
XGBAPIErrorStore::Get()->last_error = msg;
XGBAPIErrorStore::Get()->code = -1;
}
XGB_DLL int XGBGetLastErrorCode() { return XGBAPIErrorStore::Get()->code; }

View File

@@ -10,6 +10,7 @@
#include <dmlc/logging.h>
#include "c_api_utils.h"
#include "xgboost/collective/result.h"
/*! \brief macro to guard beginning and end section of all functions */
#ifdef LOG_CAPI_INVOCATION
@@ -30,7 +31,7 @@
#define API_END() \
} catch (dmlc::Error & _except_) { \
return XGBAPIHandleException(_except_); \
} catch (std::exception const &_except_) { \
} catch (std::exception const& _except_) { \
return XGBAPIHandleException(dmlc::Error(_except_.what())); \
} \
return 0; // NOLINT(*)
@@ -48,7 +49,7 @@ void XGBAPISetLastError(const char* msg);
* \param e the exception
* \return the return value of API after exception is handled
*/
inline int XGBAPIHandleException(const dmlc::Error &e) {
inline int XGBAPIHandleException(const dmlc::Error& e) {
XGBAPISetLastError(e.what());
return -1;
}

View File

@@ -9,10 +9,15 @@
#include <type_traits> // for is_same_v, remove_pointer_t
#include <utility> // for pair
#include "../collective/comm.h" // for DefaultTimeoutSec
#include "../collective/tracker.h" // for RabitTracker
#include "../common/timer.h" // for Timer
#include "c_api_error.h" // for API_BEGIN
#include "../collective/allgather.h" // for Allgather
#include "../collective/allreduce.h" // for Allreduce
#include "../collective/broadcast.h" // for Broadcast
#include "../collective/comm.h" // for DefaultTimeoutSec
#include "../collective/comm_group.h" // for GlobalCommGroup
#include "../collective/communicator-inl.h" // for GetProcessorName
#include "../collective/tracker.h" // for RabitTracker
#include "../common/timer.h" // for Timer
#include "c_api_error.h" // for API_BEGIN
#include "xgboost/c_api.h"
#include "xgboost/collective/result.h" // for Result
#include "xgboost/json.h" // for Json
@@ -20,10 +25,36 @@
#if defined(XGBOOST_USE_FEDERATED)
#include "../../plugin/federated/federated_tracker.h" // for FederatedTracker
#else
#include "../common/error_msg.h" // for NoFederated
#endif
namespace xgboost::collective {
void Allreduce(void *send_receive_buffer, std::size_t count, std::int32_t data_type, int op) {
Context ctx;
DispatchDType(static_cast<ArrayInterfaceHandler::Type>(data_type), [&](auto t) {
using T = decltype(t);
auto data = linalg::MakeTensorView(
&ctx, common::Span{static_cast<T *>(send_receive_buffer), count}, count);
auto rc = Allreduce(&ctx, *GlobalCommGroup(), data, static_cast<Op>(op));
SafeColl(rc);
});
}
void Broadcast(void *send_receive_buffer, std::size_t size, int root) {
Context ctx;
auto rc = Broadcast(&ctx, *GlobalCommGroup(),
linalg::MakeVec(static_cast<std::int8_t *>(send_receive_buffer), size), root);
SafeColl(rc);
}
void Allgather(void *send_receive_buffer, std::size_t size) {
Context ctx;
auto const &comm = GlobalCommGroup();
auto rc = Allgather(&ctx, *comm,
linalg::MakeVec(reinterpret_cast<std::int8_t *>(send_receive_buffer), size));
SafeColl(rc);
}
} // namespace xgboost::collective
using namespace xgboost; // NOLINT
namespace {
@@ -44,7 +75,8 @@ using CollAPIThreadLocalStore = dmlc::ThreadLocalStore<CollAPIEntry>;
void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) {
constexpr std::int64_t kDft{collective::DefaultTimeoutSec()};
std::chrono::seconds wait_for{timeout.count() != 0 ? std::min(kDft, timeout.count()) : kDft};
std::chrono::seconds wait_for{collective::HasTimeout(timeout) ? std::min(kDft, timeout.count())
: kDft};
common::Timer timer;
timer.Start();
@@ -62,7 +94,7 @@ void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) {
break;
}
if (timer.Duration() > timeout && timeout.count() != 0) {
if (timer.Duration() > timeout && collective::HasTimeout(timeout)) {
collective::SafeColl(collective::Fail("Timeout waiting for the tracker."));
}
}
@@ -141,7 +173,7 @@ XGB_DLL int XGTrackerFree(TrackerHandle handle) {
// Make sure no one else is waiting on the tracker.
while (!ptr->first.unique()) {
auto ela = timer.Duration().count();
if (ela > ptr->first->Timeout().count()) {
if (collective::HasTimeout(ptr->first->Timeout()) && ela > ptr->first->Timeout().count()) {
LOG(WARNING) << "Time out " << ptr->first->Timeout().count()
<< " seconds reached for TrackerFree, killing the tracker.";
break;
@@ -151,3 +183,71 @@ XGB_DLL int XGTrackerFree(TrackerHandle handle) {
delete ptr;
API_END();
}
XGB_DLL int XGCommunicatorInit(char const *json_config) {
API_BEGIN();
xgboost_CHECK_C_ARG_PTR(json_config);
Json config{Json::Load(StringView{json_config})};
collective::GlobalCommGroupInit(config);
API_END();
}
XGB_DLL int XGCommunicatorFinalize(void) {
API_BEGIN();
collective::GlobalCommGroupFinalize();
API_END();
}
XGB_DLL int XGCommunicatorGetRank(void) {
API_BEGIN();
return collective::GetRank();
API_END();
}
XGB_DLL int XGCommunicatorGetWorldSize(void) { return collective::GetWorldSize(); }
XGB_DLL int XGCommunicatorIsDistributed(void) { return collective::IsDistributed(); }
XGB_DLL int XGCommunicatorPrint(char const *message) {
API_BEGIN();
collective::Print(message);
API_END();
}
XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) {
API_BEGIN();
auto &local = *CollAPIThreadLocalStore::Get();
local.ret_str = collective::GetProcessorName();
xgboost_CHECK_C_ARG_PTR(name_str);
*name_str = local.ret_str.c_str();
API_END();
}
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root) {
API_BEGIN();
collective::Broadcast(send_receive_buffer, size, root);
API_END();
}
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int enum_dtype,
int enum_op) {
API_BEGIN();
collective::Allreduce(send_receive_buffer, count, enum_dtype, enum_op);
API_END();
}
// Not exposed to the public since the previous implementation didn't and we don't want to
// add unnecessary communicator API to a machine learning library.
XGB_DLL int XGCommunicatorAllgather(void *send_receive_buffer, size_t count) {
API_BEGIN();
collective::Allgather(send_receive_buffer, count);
API_END();
}
// Not yet exposed to the public, error recovery is still WIP.
XGB_DLL int XGCommunicatorSignalError() {
API_BEGIN();
auto msg = XGBGetLastError();
SafeColl(xgboost::collective::GlobalCommGroup()->SignalError(xgboost::collective::Fail(msg)));
API_END()
}

View File

@@ -22,7 +22,6 @@
#include <cstdio>
#include <cstring>
#include <vector>
#include "collective/communicator-inl.h"
#include "common/common.h"
#include "common/config.h"
#include "common/io.h"
@@ -193,10 +192,6 @@ class CLI {
void CLITrain() {
const double tstart_data_load = dmlc::GetTime();
if (collective::IsDistributed()) {
std::string pname = collective::GetProcessorName();
LOG(CONSOLE) << "start " << pname << ":" << collective::GetRank();
}
// load in data.
std::shared_ptr<DMatrix> dtrain(DMatrix::Load(
param_.train_path, ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
@@ -235,15 +230,9 @@ class CLI {
version += 1;
}
std::string res = learner_->EvalOneIter(i, eval_datasets, eval_data_names);
if (collective::IsDistributed()) {
if (collective::GetRank() == 0) {
LOG(TRACKER) << res;
}
} else {
LOG(CONSOLE) << res;
}
if (param_.save_period != 0 && (i + 1) % param_.save_period == 0 &&
collective::GetRank() == 0) {
LOG(CONSOLE) << res;
if (param_.save_period != 0 && (i + 1) % param_.save_period == 0) {
std::ostringstream os;
os << param_.model_dir << '/' << std::setfill('0') << std::setw(4)
<< i + 1 << ".model";
@@ -256,8 +245,7 @@ class CLI {
<< " sec";
// always save final round
if ((param_.save_period == 0 ||
param_.num_round % param_.save_period != 0) &&
collective::GetRank() == 0) {
param_.num_round % param_.save_period != 0)) {
std::ostringstream os;
if (param_.model_out == CLIParam::kNull) {
os << param_.model_dir << '/' << std::setfill('0') << std::setw(4)
@@ -465,13 +453,6 @@ class CLI {
}
}
// Initialize the collective communicator.
Json json{JsonObject()};
for (auto& kv : cfg) {
json[kv.first] = String(kv.second);
}
collective::Init(json);
param_.Configure(cfg);
}
@@ -507,10 +488,6 @@ class CLI {
}
return 0;
}
~CLI() {
collective::Finalize();
}
};
} // namespace xgboost

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2023 by XGBoost contributors
* Copyright 2023-2024, XGBoost contributors
*
* Higher level functions built on top the Communicator API, taking care of behavioral differences
* between row-split vs column-split distributed training, and horizontal vs vertical federated
@@ -13,7 +13,8 @@
#include <utility>
#include <vector>
#include "communicator-inl.cuh"
#include "allreduce.h"
#include "xgboost/collective/result.h" // for Result
namespace xgboost::collective {
@@ -24,15 +25,17 @@ namespace xgboost::collective {
* column-wise (vertically), the original values are returned.
*
* @tparam T The type of the values.
*
* @param info MetaInfo about the DMatrix.
* @param device The device id.
* @param values Pointer to the inputs to sum.
* @param size Number of values to sum.
*/
template <typename T>
void GlobalSum(MetaInfo const& info, DeviceOrd device, T* values, size_t size) {
template <typename T, std::int32_t kDim>
[[nodiscard]] Result GlobalSum(Context const* ctx, MetaInfo const& info,
linalg::TensorView<T, kDim> values) {
if (info.IsRowSplit()) {
collective::AllReduce<collective::Operation::kSum>(device.ordinal, values, size);
return collective::Allreduce(ctx, values, collective::Op::kSum);
}
return Success();
}
} // namespace xgboost::collective

View File

@@ -11,11 +11,44 @@
#include <utility>
#include <vector>
#include "allreduce.h"
#include "broadcast.h"
#include "comm.h"
#include "communicator-inl.h"
#include "xgboost/collective/result.h" // for Result
#include "xgboost/data.h" // for MetaInfo
namespace xgboost::collective {
namespace detail {
template <typename Fn>
[[nodiscard]] Result TryApplyWithLabels(Context const* ctx, Fn&& fn) {
std::string msg;
if (collective::GetRank() == 0) {
try {
fn();
} catch (dmlc::Error const& e) {
msg = e.what();
}
}
std::size_t msg_size{msg.size()};
auto rc = Success() << [&] {
auto rc = collective::Broadcast(ctx, linalg::MakeVec(&msg_size, 1), 0);
return rc;
} << [&] {
if (msg_size > 0) {
msg.resize(msg_size);
return collective::Broadcast(ctx, linalg::MakeVec(msg.data(), msg.size()), 0);
}
return Success();
} << [&] {
if (msg_size > 0) {
LOG(FATAL) << msg;
}
return Success();
};
return rc;
}
} // namespace detail
/**
* @brief Apply the given function where the labels are.
@@ -30,29 +63,19 @@ namespace xgboost::collective {
* @param size The size of the buffer.
* @param function The function used to calculate the results.
*/
template <typename FN>
void ApplyWithLabels(Context const*, MetaInfo const& info, void* buffer, std::size_t size,
FN&& function) {
template <typename Fn>
void ApplyWithLabels(Context const* ctx, MetaInfo const& info, void* buffer, std::size_t size,
Fn&& fn) {
if (info.IsVerticalFederated()) {
// We assume labels are only available on worker 0, so the calculation is done there and result
// broadcast to other workers.
std::string message;
if (collective::GetRank() == 0) {
try {
std::forward<FN>(function)();
} catch (dmlc::Error& e) {
message = e.what();
}
}
collective::Broadcast(&message, 0);
if (message.empty()) {
collective::Broadcast(buffer, size, 0);
} else {
LOG(FATAL) << &message[0];
}
auto rc = detail::TryApplyWithLabels(ctx, fn) << [&] {
// We assume labels are only available on worker 0, so the calculation is done there and
// result broadcast to other workers.
return collective::Broadcast(
ctx, linalg::MakeVec(reinterpret_cast<std::int8_t*>(buffer), size), 0);
};
SafeColl(rc);
} else {
std::forward<FN>(function)();
std::forward<Fn>(fn)();
}
}
@@ -69,37 +92,24 @@ void ApplyWithLabels(Context const*, MetaInfo const& info, void* buffer, std::si
* @param result The HostDeviceVector storing the results.
* @param function The function used to calculate the results.
*/
template <typename T, typename Function>
void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector<T>* result,
Function&& function) {
template <typename T, typename Fn>
void ApplyWithLabels(Context const* ctx, MetaInfo const& info, HostDeviceVector<T>* result,
Fn&& fn) {
if (info.IsVerticalFederated()) {
// We assume labels are only available on worker 0, so the calculation is done there and result
// broadcast to other workers.
std::string message;
if (collective::GetRank() == 0) {
try {
std::forward<Function>(function)();
} catch (dmlc::Error& e) {
message = e.what();
}
}
auto rc = detail::TryApplyWithLabels(ctx, fn);
collective::Broadcast(&message, 0);
if (!message.empty()) {
LOG(FATAL) << &message[0];
return;
}
std::size_t size{};
if (collective::GetRank() == 0) {
size = result->Size();
}
collective::Broadcast(&size, sizeof(std::size_t), 0);
result->Resize(size);
collective::Broadcast(result->HostPointer(), size * sizeof(T), 0);
std::size_t size{result->Size()};
rc = std::move(rc) << [&] {
return collective::Broadcast(ctx, linalg::MakeVec(&size, 1), 0);
} << [&] {
result->Resize(size);
return collective::Broadcast(ctx, linalg::MakeVec(result->HostPointer(), size), 0);
};
SafeColl(rc);
} else {
std::forward<Function>(function)();
std::forward<Fn>(fn)();
}
}
@@ -115,11 +125,12 @@ void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector<T>*
* @return The global max of the input.
*/
template <typename T>
std::enable_if_t<std::is_trivially_copy_assignable_v<T>, T> GlobalMax(Context const*,
std::enable_if_t<std::is_trivially_copy_assignable_v<T>, T> GlobalMax(Context const* ctx,
MetaInfo const& info,
T value) {
if (info.IsRowSplit()) {
collective::Allreduce<collective::Operation::kMax>(&value, 1);
auto rc = collective::Allreduce(ctx, linalg::MakeVec(&value, 1), collective::Op::kMax);
SafeColl(rc);
}
return value;
}
@@ -136,19 +147,14 @@ std::enable_if_t<std::is_trivially_copy_assignable_v<T>, T> GlobalMax(Context co
* @param size Number of values to sum.
*/
template <typename T, std::int32_t kDim>
[[nodiscard]] Result GlobalSum(Context const*, MetaInfo const& info,
[[nodiscard]] Result GlobalSum(Context const* ctx, MetaInfo const& info,
linalg::TensorView<T, kDim> values) {
if (info.IsRowSplit()) {
collective::Allreduce<collective::Operation::kSum>(values.Values().data(), values.Size());
return collective::Allreduce(ctx, values, collective::Op::kSum);
}
return Success();
}
template <typename Container>
[[nodiscard]] Result GlobalSum(Context const* ctx, MetaInfo const& info, Container* values) {
return GlobalSum(ctx, info, values->data(), values->size());
}
/**
* @brief Find the global ratio of the given two values across all workers.
*

View File

@@ -47,7 +47,7 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
return comm.Block();
};
if (!rc.OK()) {
return rc;
return Fail("Ring allgather failed, current iteration:" + std::to_string(r), std::move(rc));
}
}
@@ -61,7 +61,8 @@ Result BroadcastAllgatherV(Comm const& comm, common::Span<std::int64_t const> si
auto as_bytes = sizes[r];
auto rc = Broadcast(comm, recv.subspan(offset, as_bytes), r);
if (!rc.OK()) {
return rc;
return Fail("Broadcast AllgatherV failed, current iteration:" + std::to_string(r),
std::move(rc));
}
offset += as_bytes;
}
@@ -102,7 +103,7 @@ namespace detail {
return prev_ch->Block();
};
if (!rc.OK()) {
return rc;
return Fail("Ring AllgatherV failed, current iterataion:" + std::to_string(r), std::move(rc));
}
}
return comm.Block();

View File

@@ -36,7 +36,7 @@ Result RingAllreduceSmall(Comm const& comm, common::Span<std::int8_t> data, Func
auto rc = RingAllgather(comm, typed);
if (!rc.OK()) {
return rc;
return Fail("Ring allreduce small failed.", std::move(rc));
}
auto first = s_buffer.subspan(0, data.size_bytes());
CHECK_EQ(first.size(), data.size());
@@ -64,7 +64,7 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
auto next_ch = comm.Chan(dst_rank);
auto prev_ch = comm.Chan(src_rank);
std::vector<std::int8_t> buffer(data.size_bytes() - (world - 1) * n_bytes_in_seg, 0);
std::vector<std::int8_t> buffer(data.size_bytes() - (world - 1) * n_bytes_in_seg, -1);
auto s_buf = common::Span{buffer.data(), buffer.size()};
for (std::int32_t r = 0; r < world - 1; ++r) {
@@ -97,6 +97,10 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
} << [&] {
return comm.Block();
};
if (!rc.OK()) {
return Fail("Ring scatter reduce failed, current iteration:" + std::to_string(r),
std::move(rc));
}
// accumulate to recv_seg
CHECK_EQ(seg.size(), recv_seg.size());
@@ -128,7 +132,7 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func cons
auto n_bytes_in_seg = (n / world) * sizeof(T);
auto rc = RingScatterReduceTyped<T>(comm, data, n_bytes_in_seg, op);
if (!rc.OK()) {
return rc;
return Fail("Ring Allreduce failed.", std::move(rc));
}
auto prev = BootstrapPrev(comm.Rank(), comm.World());

View File

@@ -150,9 +150,12 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
}
auto rank = comm.Rank();
auto n_bytes = worker->SendAll(&rank, sizeof(comm.Rank()));
if (n_bytes != sizeof(comm.Rank())) {
return Fail("Failed to send rank.");
std::size_t n_bytes{0};
auto rc = worker->SendAll(&rank, sizeof(comm.Rank()), &n_bytes);
if (!rc.OK()) {
return rc;
} else if (n_bytes != sizeof(comm.Rank())) {
return Fail("Failed to send rank.", std::move(rc));
}
workers[r] = std::move(worker);
}
@@ -169,8 +172,11 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
return rc;
}
std::int32_t rank{-1};
auto n_bytes = peer->RecvAll(&rank, sizeof(rank));
if (n_bytes != sizeof(comm.Rank())) {
std::size_t n_bytes{0};
auto rc = peer->RecvAll(&rank, sizeof(rank), &n_bytes);
if (!rc.OK()) {
return rc;
} else if (n_bytes != sizeof(comm.Rank())) {
return Fail("Failed to recv rank.");
}
workers[rank] = std::move(peer);

View File

@@ -94,7 +94,7 @@ class Comm : public std::enable_shared_from_this<Comm> {
[[nodiscard]] bool IsDistributed() const noexcept { return world_ != -1; }
void Submit(Loop::Op op) const {
CHECK(loop_);
loop_->Submit(op);
loop_->Submit(std::move(op));
}
[[nodiscard]] virtual Result Block() const { return loop_->Block(); }

View File

@@ -76,7 +76,7 @@ CommGroup::CommGroup()
// Common args
auto retry = get_param("dmlc_retry", static_cast<Integer::Int>(DefaultRetry()), Integer{});
auto timeout =
get_param("dmlc_timeout_sec", static_cast<Integer::Int>(DefaultTimeoutSec()), Integer{});
get_param("dmlc_timeout", static_cast<Integer::Int>(DefaultTimeoutSec()), Integer{});
auto task_id = get_param("dmlc_task_id", std::string{}, String{});
if (type == "rabit") {
@@ -123,4 +123,30 @@ void GlobalCommGroupFinalize() {
sptr.reset();
SafeColl(rc);
}
void Init(Json const& config) { GlobalCommGroupInit(config); }
void Finalize() { GlobalCommGroupFinalize(); }
std::int32_t GetRank() noexcept { return GlobalCommGroup()->Rank(); }
std::int32_t GetWorldSize() noexcept { return GlobalCommGroup()->World(); }
bool IsDistributed() noexcept { return GlobalCommGroup()->IsDistributed(); }
[[nodiscard]] bool IsFederated() {
return GlobalCommGroup()->Ctx(nullptr, DeviceOrd::CPU()).IsFederated();
}
void Print(std::string const& message) {
auto rc = GlobalCommGroup()->Ctx(nullptr, DeviceOrd::CPU()).LogTracker(message);
SafeColl(rc);
}
std::string GetProcessorName() {
std::string out;
auto rc = GlobalCommGroup()->ProcessorName(&out);
SafeColl(rc);
return out;
}
} // namespace xgboost::collective

View File

@@ -1,34 +0,0 @@
/**
* Copyright 2024, XGBoost contributors
*/
#include "communicator-inl.h"
namespace xgboost::collective {
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
std::vector<std::vector<char>> const &input) {
auto n_inputs = input.size();
std::vector<std::int64_t> sizes(n_inputs);
std::transform(input.cbegin(), input.cend(), sizes.begin(),
[](auto const &vec) { return vec.size(); });
std::vector<std::int64_t> global_sizes = AllgatherV(sizes);
std::vector<std::int64_t> offset(global_sizes.size() + 1);
offset[0] = 0;
for (std::size_t i = 1; i < offset.size(); i++) {
offset[i] = offset[i - 1] + global_sizes[i - 1];
}
std::vector<char> collected;
for (auto const &vec : input) {
collected.insert(collected.end(), vec.cbegin(), vec.cend());
}
auto out = AllgatherV(collected);
std::vector<std::vector<char>> result;
for (std::size_t i = 1; i < offset.size(); ++i) {
std::vector<char> local(out.cbegin() + offset[i - 1], out.cbegin() + offset[i]);
result.emplace_back(std::move(local));
}
return result;
}
} // namespace xgboost::collective

View File

@@ -1,95 +0,0 @@
/**
* Copyright 2023 by XGBoost contributors
*/
#pragma once
#include <string>
#include <vector>
#include "communicator.h"
#include "device_communicator.cuh"
namespace xgboost {
namespace collective {
/**
* @brief Reduce values from all processes and distribute the result back to all processes.
* @param device ID of the device.
* @param send_receive_buffer Buffer storing the data.
* @param count Number of elements in the buffer.
*/
template <Operation op>
inline void AllReduce(int device, std::int8_t *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt8, op);
}
template <Operation op>
inline void AllReduce(int device, std::uint8_t *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt8, op);
}
template <Operation op>
inline void AllReduce(int device, std::int32_t *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt32, op);
}
template <Operation op>
inline void AllReduce(int device, std::uint32_t *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt32, op);
}
template <Operation op>
inline void AllReduce(int device, std::int64_t *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt64, op);
}
template <Operation op>
inline void AllReduce(int device, std::uint64_t *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
}
template <Operation op>
inline void AllReduce(int device, float *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kFloat, op);
}
template <Operation op>
inline void AllReduce(int device, double *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
}
/**
* @brief Gather values from all processes.
*
* This assumes all ranks have the same size.
*
* @param send_buffer Buffer storing the data to be sent.
* @param receive_buffer Buffer storing the gathered data.
* @param send_size Size of the sent data in bytes.
*/
inline void AllGather(int device, void const *send_buffer, void *receive_buffer,
std::size_t send_size) {
Communicator::GetDevice(device)->AllGather(send_buffer, receive_buffer, send_size);
}
/**
* @brief Gather variable-length values from all processes.
* @param device ID of the device.
* @param send_buffer Buffer storing the input data.
* @param length_bytes Length in bytes of the input data.
* @param segments Size of each segment.
* @param receive_buffer Buffer storing the output data.
*/
inline void AllGatherV(int device, void const *send_buffer, size_t length_bytes,
std::vector<size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) {
Communicator::GetDevice(device)->AllGatherV(send_buffer, length_bytes, segments, receive_buffer);
}
/**
* @brief Synchronize device operations.
* @param device ID of the device.
*/
inline void Synchronize(int device) { Communicator::GetDevice(device)->Synchronize(); }
} // namespace collective
} // namespace xgboost

View File

@@ -3,308 +3,63 @@
*/
#pragma once
#include <string>
#include <vector>
#include "communicator.h"
#include "xgboost/json.h" // for Json
namespace xgboost {
namespace collective {
namespace xgboost::collective {
/**
* @brief Initialize the collective communicator.
*/
void Init(Json const& config);
/**
* \brief Initialize the collective communicator.
*
* Currently the communicator API is experimental, function signatures may change in the future
* without notice.
*
* Call this once before using anything.
*
* The additional configuration is not required. Usually the communicator will detect settings
* from environment variables.
*
* \param json_config JSON encoded configuration. Accepted JSON keys are:
* - xgboost_communicator: The type of the communicator. Can be set as an environment variable.
* * rabit: Use Rabit. This is the default if the type is unspecified.
* * mpi: Use MPI.
* * federated: Use the gRPC interface for Federated Learning.
* Only applicable to the Rabit communicator (these are case-sensitive):
* - rabit_tracker_uri: Hostname of the tracker.
* - rabit_tracker_port: Port number of the tracker.
* - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
* - rabit_world_size: Total number of workers.
* - rabit_hadoop_mode: Enable Hadoop support.
* - rabit_tree_reduce_minsize: Minimal size for tree reduce.
* - rabit_reduce_ring_mincount: Minimal count to perform ring reduce.
* - rabit_reduce_buffer: Size of the reduce buffer.
* - rabit_bootstrap_cache: Size of the bootstrap cache.
* - rabit_debug: Enable debugging.
* - rabit_timeout: Enable timeout.
* - rabit_timeout_sec: Timeout in seconds.
* - rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms.
* Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
* environment variables):
* - DMLC_TRACKER_URI: Hostname of the tracker.
* - DMLC_TRACKER_PORT: Port number of the tracker.
* - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment.
* - DMLC_ROLE: Role of the current task, "worker" or "server".
* - DMLC_NUM_ATTEMPT: Number of attempts after task failure.
* - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
* Only applicable to the Federated communicator (use upper case for environment variables, use
* lower case for runtime configuration):
* - federated_server_address: Address of the federated server.
* - federated_world_size: Number of federated workers.
* - federated_rank: Rank of the current worker.
* - federated_server_cert: Server certificate file path. Only needed for the SSL mode.
* - federated_client_key: Client key file path. Only needed for the SSL mode.
* - federated_client_cert: Client certificate file path. Only needed for the SSL mode.
*/
inline void Init(Json const &config) { Communicator::Init(config); }
/*!
* \brief Finalize the collective communicator.
* @brief Finalize the collective communicator.
*
* Call this function after you finished all jobs.
*/
inline void Finalize() { Communicator::Finalize(); }
void Finalize();
/*!
* \brief Get rank of current process.
/**
* @brief Get rank of current process.
*
* \return Rank of the worker.
* @return Rank of the worker.
*/
inline int GetRank() { return Communicator::Get()->GetRank(); }
[[nodiscard]] std::int32_t GetRank() noexcept;
/*!
* \brief Get total number of processes.
/**
* @brief Get total number of processes.
*
* \return Total world size.
* @return Total world size.
*/
inline int GetWorldSize() { return Communicator::Get()->GetWorldSize(); }
[[nodiscard]] std::int32_t GetWorldSize() noexcept;
/*!
* \brief Get if the communicator is distributed.
/**
* @brief Get if the communicator is distributed.
*
* \return True if the communicator is distributed.
* @return True if the communicator is distributed.
*/
inline bool IsDistributed() { return Communicator::Get()->IsDistributed(); }
[[nodiscard]] bool IsDistributed() noexcept;
/*!
* \brief Get if the communicator is federated.
/**
* @brief Get if the communicator is federated.
*
* \return True if the communicator is federated.
* @return True if the communicator is federated.
*/
inline bool IsFederated() { return Communicator::Get()->IsFederated(); }
[[nodiscard]] bool IsFederated();
/*!
* \brief Print the message to the communicator.
/**
* @brief Print the message to the communicator.
*
* This function can be used to communicate the information of the progress to the user who monitors
* the communicator.
*
* \param message The message to be printed.
* @param message The message to be printed.
*/
inline void Print(char const *message) { Communicator::Get()->Print(message); }
inline void Print(std::string const &message) { Communicator::Get()->Print(message); }
/*!
* \brief Get the name of the processor.
*
* \return Name of the processor.
*/
inline std::string GetProcessorName() { return Communicator::Get()->GetProcessorName(); }
/*!
* \brief Broadcast a memory region to all others from root. This function is NOT thread-safe.
*
* Example:
* int a = 1;
* Broadcast(&a, sizeof(a), root);
*
* \param send_receive_buffer Pointer to the send or receive buffer.
* \param size Size of the data.
* \param root The process rank to broadcast from.
*/
inline void Broadcast(void *send_receive_buffer, size_t size, int root) {
Communicator::Get()->Broadcast(send_receive_buffer, size, root);
}
inline void Broadcast(std::string *sendrecv_data, int root) {
size_t size = sendrecv_data->length();
Broadcast(&size, sizeof(size), root);
if (sendrecv_data->length() != size) {
sendrecv_data->resize(size);
}
if (size != 0) {
Broadcast(&(*sendrecv_data)[0], size * sizeof(char), root);
}
}
void Print(std::string const& message);
/**
 * @brief Gathers a single value from all processes and distributes the result to all processes.
 *
 * @param input The single value.
 * @return The gathered values from all processes.
 */
template <typename T>
inline std::vector<T> Allgather(T const &input) {
  // View the single value as raw bytes for the byte-oriented communicator API.
  auto const buffer = Communicator::Get()->AllGather(
      std::string_view{reinterpret_cast<char const *>(&input), sizeof(T)});
  CHECK_EQ(buffer.size() % sizeof(T), 0);
  std::vector<T> gathered(buffer.size() / sizeof(T));
  std::memcpy(gathered.data(), buffer.data(), buffer.size());
  return gathered;
}
/**
* @brief Gathers data from all processes and distributes it to all processes.
*
* This assumes all ranks have the same size.
*
* @param input Buffer storing the data.
*/
template <typename T>
inline std::vector<T> Allgather(std::vector<T> const &input) {
  // Nothing to exchange; every rank is assumed to contribute the same size.
  if (input.empty()) {
    return input;
  }
  auto const buffer = Communicator::Get()->AllGather(std::string_view{
      reinterpret_cast<char const *>(input.data()), input.size() * sizeof(T)});
  CHECK_EQ(buffer.size() % sizeof(T), 0);
  std::vector<T> gathered(buffer.size() / sizeof(T));
  std::memcpy(gathered.data(), buffer.data(), buffer.size());
  return gathered;
}
/**
* @brief Gathers variable-length data from all processes and distributes it to all processes.
* @param input Buffer storing the data.
*/
template <typename T>
inline std::vector<T> AllgatherV(std::vector<T> const &input) {
  auto const buffer = Communicator::Get()->AllGatherV(std::string_view{
      reinterpret_cast<char const *>(input.data()), input.size() * sizeof(T)});
  CHECK_EQ(buffer.size() % sizeof(T), 0);
  std::vector<T> gathered(buffer.size() / sizeof(T));
  // The gathered buffer can be empty when every rank contributes nothing.
  if (!buffer.empty()) {
    std::memcpy(gathered.data(), buffer.data(), buffer.size());
  }
  return gathered;
}
/**
* @brief Gathers variable-length data from all processes and distributes it to all processes.
*
* @param inputs All the inputs from the local worker. The number of inputs can vary
* across different workers. Along with which, the size of each vector in
* the input can also vary.
*
* @return The AllgatherV result, containing vectors from all workers.
*/
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
std::vector<std::vector<char>> const &input);
/**
* @brief Gathers variable-length strings from all processes and distributes them to all processes.
* @param input Variable-length list of variable-length strings.
*/
inline std::vector<std::string> AllgatherStrings(std::vector<std::string> const &input) {
  // Flatten the strings into one buffer, separating them with '\0' so the
  // boundaries survive the round trip through the byte-oriented AllGatherV.
  std::size_t n_total{0};
  for (auto const &s : input) {
    n_total += s.length() + 1;  // +1 for the trailing '\0' of each string
  }
  std::string flat;
  flat.reserve(n_total);
  for (auto const &s : input) {
    flat.append(s);
    flat.push_back('\0');
  }
  auto const gathered = Communicator::Get()->AllGatherV(flat);
  // Split the gathered buffer back into individual strings at each '\0'.
  std::vector<std::string> result;
  std::size_t begin{0};
  for (std::size_t end = 0; end < gathered.size(); ++end) {
    if (gathered[end] == '\0') {
      result.emplace_back(gathered, begin, end - begin);
      begin = end + 1;
    }
  }
  return result;
}
/*!
* \brief Perform in-place allreduce. This function is NOT thread-safe.
*
* Example Usage: the following code gives sum of the result
* vector<int> data(10);
* ...
 *   Allreduce(&data[0], data.size(), DataType::kInt32, Operation::kSum);
* ...
* \param send_receive_buffer Buffer for both sending and receiving data.
* \param count Number of elements to be reduced.
* \param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h.
* \param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
*/
// Untyped entry point used by the C API: data_type and op arrive as plain
// ints and are cast to the strongly-typed enums.
inline void Allreduce(void *send_receive_buffer, size_t count, int data_type, int op) {
  Communicator::Get()->AllReduce(send_receive_buffer, count, static_cast<DataType>(data_type),
                                 static_cast<Operation>(op));
}
// Strongly-typed untemplated form.
inline void Allreduce(void *send_receive_buffer, size_t count, DataType data_type, Operation op) {
  Communicator::Get()->AllReduce(send_receive_buffer, count, data_type, op);
}
// Typed convenience overloads: the element type of the buffer selects the
// matching DataType tag so callers cannot mismatch buffer and tag.
template <Operation op>
inline void Allreduce(int8_t *send_receive_buffer, size_t count) {
  Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt8, op);
}
template <Operation op>
inline void Allreduce(uint8_t *send_receive_buffer, size_t count) {
  Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt8, op);
}
template <Operation op>
inline void Allreduce(int32_t *send_receive_buffer, size_t count) {
  Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt32, op);
}
template <Operation op>
inline void Allreduce(uint32_t *send_receive_buffer, size_t count) {
  Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt32, op);
}
template <Operation op>
inline void Allreduce(int64_t *send_receive_buffer, size_t count) {
  Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt64, op);
}
template <Operation op>
inline void Allreduce(uint64_t *send_receive_buffer, size_t count) {
  Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
}
// Specialization for size_t, which is implementation defined, so it might or might not
// be one of uint64_t/uint32_t/unsigned long long/unsigned long.
template <Operation op, typename T,
          typename = std::enable_if_t<std::is_same<size_t, T>{} && !std::is_same<uint64_t, T>{}> >
inline void Allreduce(T *send_receive_buffer, size_t count) {
  // Only instantiated when size_t is a distinct 64-bit type; reuse kUInt64.
  static_assert(sizeof(T) == sizeof(uint64_t));
  Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
}
template <Operation op>
inline void Allreduce(float *send_receive_buffer, size_t count) {
  Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kFloat, op);
}
template <Operation op>
inline void Allreduce(double *send_receive_buffer, size_t count) {
  Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
}
} // namespace collective
} // namespace xgboost
std::string GetProcessorName();
} // namespace xgboost::collective

View File

@@ -1,63 +0,0 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include "communicator.h"
#include "comm.h"
#include "in_memory_communicator.h"
#include "noop_communicator.h"
#include "rabit_communicator.h"
#if defined(XGBOOST_USE_FEDERATED)
#include "../../plugin/federated/federated_communicator.h"
#endif
namespace xgboost::collective {
// Per-thread singleton state: each thread gets its own communicator instance,
// defaulting to a no-op communicator until Init() is called.
thread_local std::unique_ptr<Communicator> Communicator::communicator_{new NoOpCommunicator()};
thread_local CommunicatorType Communicator::type_{};
thread_local std::string Communicator::nccl_path_{};
// Initializes the thread-local communicator according to `config`.
//
// The communicator type is taken from the runtime configuration when present,
// otherwise from the XGBOOST_COMMUNICATOR environment variable, and defaults
// to Rabit when neither specifies one.
void Communicator::Init(Json const& config) {
  // Remember the NCCL shared-library path for later device-communicator setup.
  auto nccl = OptionalArg<String>(config, "dmlc_nccl_path", std::string{DefaultNcclName()});
  nccl_path_ = nccl;
  auto type = GetTypeFromEnv();
  auto const arg = GetTypeFromConfig(config);
  // Runtime configuration takes precedence over the environment variable.
  if (arg != CommunicatorType::kUnknown) {
    type = arg;
  }
  if (type == CommunicatorType::kUnknown) {
    // Default to Rabit if unspecified.
    type = CommunicatorType::kRabit;
  }
  type_ = type;
  switch (type) {
    case CommunicatorType::kRabit: {
      communicator_.reset(RabitCommunicator::Create(config));
      break;
    }
    case CommunicatorType::kFederated: {
#if defined(XGBOOST_USE_FEDERATED)
      communicator_.reset(FederatedCommunicator::Create(config));
#else
      LOG(FATAL) << "XGBoost is not compiled with Federated Learning support.";
#endif
      break;
    }
    case CommunicatorType::kInMemory:
    case CommunicatorType::kInMemoryNccl: {
      communicator_.reset(InMemoryCommunicator::Create(config));
      break;
    }
    case CommunicatorType::kUnknown:
      LOG(FATAL) << "Unknown communicator type.";
  }
}
#ifndef XGBOOST_USE_CUDA
// CPU-only build of Finalize; the CUDA build defines it in communicator.cu,
// where the device communicator is also torn down.
void Communicator::Finalize() {
  communicator_->Shutdown();
  // Restore the no-op communicator so later collective calls remain safe.
  communicator_.reset(new NoOpCommunicator());
}
#endif
} // namespace xgboost::collective

View File

@@ -1,54 +0,0 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include "communicator.h"
#include "device_communicator.cuh"
#include "device_communicator_adapter.cuh"
#include "noop_communicator.h"
#ifdef XGBOOST_USE_NCCL
#include "nccl_device_communicator.cuh"
#endif
namespace xgboost {
namespace collective {
thread_local std::unique_ptr<DeviceCommunicator> Communicator::device_communicator_{};
// CUDA build of Finalize: additionally releases the per-thread device
// communicator created by GetDevice().
void Communicator::Finalize() {
  communicator_->Shutdown();
  // Restore the no-op communicator so later collective calls remain safe.
  communicator_.reset(new NoOpCommunicator());
  device_communicator_.reset(nullptr);
}
// Returns the per-thread device communicator for `device_ordinal`, creating
// (or re-creating) it when the target device or the world size has changed.
DeviceCommunicator* Communicator::GetDevice(int device_ordinal) {
  thread_local auto old_device_ordinal = -1;
  // If the number of GPUs changes, we need to re-initialize NCCL.
  thread_local auto old_world_size = -1;
  if (!device_communicator_ || device_ordinal != old_device_ordinal ||
      communicator_->GetWorldSize() != old_world_size) {
    old_device_ordinal = device_ordinal;
    old_world_size = communicator_->GetWorldSize();
#ifdef XGBOOST_USE_NCCL
    switch (type_) {
      case CommunicatorType::kRabit:
        device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false, nccl_path_));
        break;
      // Federated and in-memory communicators stage device buffers through the
      // host-side communicator instead of using NCCL directly.
      case CommunicatorType::kFederated:
      case CommunicatorType::kInMemory:
        device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
        break;
      case CommunicatorType::kInMemoryNccl:
        device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, true, nccl_path_));
        break;
      default:
        device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false, nccl_path_));
    }
#else
    // Without NCCL, all device collectives are staged through the host.
    device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
#endif
  }
  return device_communicator_.get();
}
} // namespace collective
} // namespace xgboost

View File

@@ -1,247 +0,0 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <xgboost/json.h>
#include <xgboost/logging.h>
#include <memory>
#include <string>
namespace xgboost {
namespace collective {
/** @brief Defines the integral and floating data types. */
enum class DataType {
  // NOTE: the explicit numeric values are part of the C API contract — the
  // untyped Allreduce entry point casts plain ints to this enum.
  kInt8 = 0,
  kUInt8 = 1,
  kInt32 = 2,
  kUInt32 = 3,
  kInt64 = 4,
  kUInt64 = 5,
  kFloat = 6,
  kDouble = 7
};
/** @brief Get the size of the data type. */
inline std::size_t GetTypeSize(DataType data_type) {
  // Map each enumerator to the byte width of the corresponding C++ type.
  switch (data_type) {
    case DataType::kInt8:
      return sizeof(std::int8_t);
    case DataType::kUInt8:
      return sizeof(std::uint8_t);
    case DataType::kInt32:
      return sizeof(std::int32_t);
    case DataType::kUInt32:
      return sizeof(std::uint32_t);
    case DataType::kInt64:
      return sizeof(std::int64_t);
    case DataType::kUInt64:
      return sizeof(std::uint64_t);
    case DataType::kFloat:
      return sizeof(float);
    case DataType::kDouble:
      return sizeof(double);
    default:
      LOG(FATAL) << "Unknown data type.";
  }
  return 0;  // Unreachable: LOG(FATAL) aborts; mirrors the original's 0 result.
}
/** @brief Defines the reduction operation. */
enum class Operation {
  // NOTE: the explicit numeric values are part of the C API contract — the
  // untyped Allreduce entry point casts plain ints to this enum.
  kMax = 0,
  kMin = 1,
  kSum = 2,
  // Bitwise operations are only valid for integral data types.
  kBitwiseAND = 3,
  kBitwiseOR = 4,
  kBitwiseXOR = 5
};
class DeviceCommunicator;
enum class CommunicatorType { kUnknown, kRabit, kFederated, kInMemory, kInMemoryNccl };
/** \brief Case-insensitive string comparison. */
// NOTE(review): relies on a transitive include for _stricmp/strcasecmp
// (<cstring>/<strings.h>) — confirm the header pulls one in explicitly.
inline int CompareStringsCaseInsensitive(const char *lhs, const char *rhs) {
#if defined(_MSC_VER)
  return _stricmp(lhs, rhs);
#else   // !defined(_MSC_VER)
  return strcasecmp(lhs, rhs);
#endif  // !defined(_MSC_VER)
}
/**
* @brief A communicator class that handles collective communication.
*/
class Communicator {
 public:
  /**
   * @brief Initialize the communicator. This can only be done once.
   *
   * @param config JSON configuration for the communicator.
   */
  static void Init(Json const &config);

  /** @brief Finalize the communicator. */
  static void Finalize();

  /** @brief Get the communicator instance (thread-local singleton). */
  static Communicator *Get() { return communicator_.get(); }

#if defined(XGBOOST_USE_CUDA)
  /**
   * @brief Get the device communicator.
   *
   * @param device_ordinal ID of the device.
   * @return An instance of device communicator.
   */
  static DeviceCommunicator *GetDevice(int device_ordinal);
#endif

  virtual ~Communicator() = default;

  /** @brief Get the total number of processes. */
  int GetWorldSize() const { return world_size_; }

  /** @brief Get the rank of the current processes. */
  int GetRank() const { return rank_; }

  /** @brief Whether the communicator is running in distributed mode. */
  virtual bool IsDistributed() const = 0;

  /** @brief Whether the communicator is running in federated mode. */
  virtual bool IsFederated() const = 0;

  /**
   * @brief Gathers data from all processes and distributes it to all processes.
   *
   * This assumes all ranks have the same size.
   *
   * @param input Buffer storing the data.
   */
  virtual std::string AllGather(std::string_view input) = 0;

  /**
   * @brief Gathers variable-length data from all processes and distributes it to all processes.
   * @param input Buffer storing the data.
   */
  virtual std::string AllGatherV(std::string_view input) = 0;

  /**
   * @brief Combines values from all processes and distributes the result back to all processes.
   *
   * @param send_receive_buffer Buffer storing the data.
   * @param count Number of elements in the buffer.
   * @param data_type Data type stored in the buffer.
   * @param op The operation to perform.
   */
  virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
                         Operation op) = 0;

  /**
   * @brief Broadcasts a message from the process with rank `root` to all other processes of the
   * group.
   *
   * @param send_receive_buffer Buffer storing the data.
   * @param size Size of the data in bytes.
   * @param root Rank of broadcast root.
   */
  virtual void Broadcast(void *send_receive_buffer, std::size_t size, int root) = 0;

  /**
   * @brief Gets the name of the processor.
   */
  virtual std::string GetProcessorName() = 0;

  /**
   * @brief Prints the message.
   */
  virtual void Print(std::string const &message) = 0;

  /** @brief Get the communicator type from environment variables. Visible for testing. */
  static CommunicatorType GetTypeFromEnv() {
    auto *env = std::getenv("XGBOOST_COMMUNICATOR");
    if (env != nullptr) {
      return StringToType(env);
    } else {
      return CommunicatorType::kUnknown;
    }
  }

  /** @brief Get the communicator type from runtime configuration. Visible for testing. */
  static CommunicatorType GetTypeFromConfig(Json const &config) {
    // The upper-case key mirrors the environment variable spelling; it takes
    // precedence over the conventional lower-case runtime-config key.
    auto const &j_upper = config["XGBOOST_COMMUNICATOR"];
    if (IsA<String const>(j_upper)) {
      return StringToType(get<String const>(j_upper).c_str());
    }
    auto const &j_lower = config["xgboost_communicator"];
    if (IsA<String const>(j_lower)) {
      return StringToType(get<String const>(j_lower).c_str());
    }
    return CommunicatorType::kUnknown;
  }

 protected:
  /**
   * @brief Construct a new communicator.
   *
   * @param world_size Total number of processes. Must be at least 1.
   * @param rank Rank of the current process. Must be in [0, world_size).
   */
  Communicator(int world_size, int rank) : world_size_(world_size), rank_(rank) {
    if (world_size < 1) {
      LOG(FATAL) << "World size " << world_size << " is less than 1.";
    }
    if (rank < 0) {
      LOG(FATAL) << "Rank " << rank << " is less than 0.";
    }
    if (rank >= world_size) {
      LOG(FATAL) << "Rank " << rank << " is greater than world_size - 1: " << world_size - 1 << ".";
    }
  }

  /**
   * @brief Shuts down the communicator.
   */
  virtual void Shutdown() = 0;

 private:
  // Case-insensitive mapping from a type name string to the enum; aborts on
  // an unrecognized name.
  static CommunicatorType StringToType(char const *str) {
    CommunicatorType result = CommunicatorType::kUnknown;
    if (!CompareStringsCaseInsensitive("rabit", str)) {
      result = CommunicatorType::kRabit;
    } else if (!CompareStringsCaseInsensitive("federated", str)) {
      result = CommunicatorType::kFederated;
    } else if (!CompareStringsCaseInsensitive("in-memory", str)) {
      result = CommunicatorType::kInMemory;
    } else if (!CompareStringsCaseInsensitive("in-memory-nccl", str)) {
      result = CommunicatorType::kInMemoryNccl;
    } else {
      LOG(FATAL) << "Unknown communicator type " << str;
    }
    return result;
  }

  // Singleton state is thread-local: each thread owns its own communicator.
  static thread_local std::unique_ptr<Communicator> communicator_;
  static thread_local CommunicatorType type_;
  static thread_local std::string nccl_path_;
#if defined(XGBOOST_USE_CUDA)
  static thread_local std::unique_ptr<DeviceCommunicator> device_communicator_;
#endif

  int const world_size_;
  int const rank_;
};
} // namespace collective
} // namespace xgboost

View File

@@ -1,57 +0,0 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <vector>
#include "../common/device_helpers.cuh"
namespace xgboost {
namespace collective {
/**
* @brief Collective communicator for device buffers.
*/
class DeviceCommunicator {
 public:
  virtual ~DeviceCommunicator() = default;

  /**
   * @brief Combines values from all processes and distributes the result back to all processes.
   *
   * @param send_receive_buffer Buffer storing the data.
   * @param count Number of elements in the buffer.
   * @param data_type Data type stored in the buffer.
   * @param op The operation to perform.
   */
  virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
                         Operation op) = 0;

  /**
   * @brief Gather values from all processes.
   *
   * This assumes all ranks have the same size.
   *
   * @param send_buffer Buffer storing the data to be sent.
   * @param receive_buffer Buffer storing the gathered data.
   * @param send_size Size of the sent data in bytes.
   */
  virtual void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) = 0;

  /**
   * @brief Gather variable-length values from all processes.
   * @param send_buffer Buffer storing the input data.
   * @param length_bytes Length in bytes of the input data.
   * @param segments Output: size in bytes of each rank's contribution.
   * @param receive_buffer Buffer storing the output data.
   */
  virtual void AllGatherV(void const *send_buffer, size_t length_bytes,
                          std::vector<size_t> *segments,
                          dh::caching_device_vector<char> *receive_buffer) = 0;

  /** @brief Synchronize device operations. */
  virtual void Synchronize() = 0;
};
} // namespace collective
} // namespace xgboost

View File

@@ -1,94 +0,0 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <numeric> // for accumulate
#include "communicator.h"
#include "device_communicator.cuh"
namespace xgboost {
namespace collective {
/**
 * @brief Host-staged implementation of DeviceCommunicator.
 *
 * Copies device buffers to the host, runs the collective through the regular
 * host-side communicator, then copies the result back to the device. Used
 * when a native device communicator (NCCL) is unavailable.
 */
class DeviceCommunicatorAdapter : public DeviceCommunicator {
 public:
  explicit DeviceCommunicatorAdapter(int device_ordinal)
      : device_ordinal_{device_ordinal}, world_size_{GetWorldSize()}, rank_{GetRank()} {
    if (device_ordinal_ < 0) {
      LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
    }
  }

  ~DeviceCommunicatorAdapter() override = default;

  void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
                 Operation op) override {
    // With a single worker the reduction is a no-op.
    if (world_size_ == 1) {
      return;
    }

    dh::safe_cuda(cudaSetDevice(device_ordinal_));
    auto size = count * GetTypeSize(data_type);
    // Stage through the host: D2H copy, host allreduce, H2D copy back.
    host_buffer_.resize(size);
    dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
    Allreduce(host_buffer_.data(), count, data_type, op);
    dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
  }

  void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) override {
    if (world_size_ == 1) {
      return;
    }

    dh::safe_cuda(cudaSetDevice(device_ordinal_));
    // Stage through the host: D2H copy, host allgather, H2D copy back.
    host_buffer_.resize(send_size);
    dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_buffer, send_size, cudaMemcpyDefault));
    auto const output = Allgather(host_buffer_);
    dh::safe_cuda(cudaMemcpy(receive_buffer, output.data(), output.size(), cudaMemcpyDefault));
  }

  void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
                  dh::caching_device_vector<char> *receive_buffer) override {
    if (world_size_ == 1) {
      return;
    }

    dh::safe_cuda(cudaSetDevice(device_ordinal_));
    // Exchange per-rank lengths: each rank publishes its own length in its
    // slot, and a max-reduction fills every slot (other ranks contribute 0).
    // NOTE(review): uses DataType::kUInt64 for std::size_t and 0UL as the
    // accumulate init — assumes a 64-bit size_t; confirm for 32-bit targets.
    segments->clear();
    segments->resize(world_size_, 0);
    segments->at(rank_) = length_bytes;
    Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
    auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
    receive_buffer->resize(total_bytes);

    // Stage each rank's contribution through the host, one Broadcast per rank.
    host_buffer_.resize(total_bytes);
    size_t offset = 0;
    for (int32_t i = 0; i < world_size_; ++i) {
      size_t as_bytes = segments->at(i);
      if (i == rank_) {
        dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank_),
                                 cudaMemcpyDefault));
      }
      Broadcast(host_buffer_.data() + offset, as_bytes, i);
      offset += as_bytes;
    }
    dh::safe_cuda(cudaMemcpy(receive_buffer->data().get(), host_buffer_.data(), total_bytes,
                             cudaMemcpyDefault));
  }

  void Synchronize() override {
    // Noop.
  }

 private:
  int const device_ordinal_;
  int const world_size_;
  int const rank_;
  /// Host buffer used to call communicator functions.
  std::vector<char> host_buffer_{};
};
} // namespace collective
} // namespace xgboost

View File

@@ -1,12 +0,0 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include "in_memory_communicator.h"
namespace xgboost {
namespace collective {
InMemoryHandler InMemoryCommunicator::handler_{};
} // namespace collective
} // namespace xgboost

View File

@@ -15,14 +15,14 @@ namespace collective {
/**
* An in-memory communicator, useful for testing.
*/
class InMemoryCommunicator : public Communicator {
class InMemoryCommunicator {
public:
/**
* @brief Create a new communicator based on JSON configuration.
* @param config JSON configuration.
* @return Communicator as specified by the JSON configuration.
*/
static Communicator* Create(Json const& config) {
static InMemoryCommunicator* Create(Json const& config) {
int world_size{0};
int rank{-1};
@@ -51,7 +51,7 @@ class InMemoryCommunicator : public Communicator {
return new InMemoryCommunicator(world_size, rank);
}
InMemoryCommunicator(int world_size, int rank) : Communicator(world_size, rank) {
InMemoryCommunicator(int world_size, int rank) {
handler_.Init(world_size, rank);
}

View File

@@ -1,14 +1,13 @@
/*!
* Copyright 2022 XGBoost contributors
/**
* Copyright 2022-2023, XGBoost contributors
*/
#include "in_memory_handler.h"
#include <algorithm>
#include <functional>
#include "comm.h"
namespace xgboost {
namespace collective {
namespace xgboost::collective {
/**
* @brief Functor for allgather.
*/
@@ -16,7 +15,7 @@ class AllgatherFunctor {
public:
std::string const name{"Allgather"};
AllgatherFunctor(std::size_t world_size, std::size_t rank)
AllgatherFunctor(std::int32_t world_size, std::int32_t rank)
: world_size_{world_size}, rank_{rank} {}
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
@@ -30,8 +29,8 @@ class AllgatherFunctor {
}
private:
std::size_t world_size_;
std::size_t rank_;
std::int32_t world_size_;
std::int32_t rank_;
};
/**
@@ -41,13 +40,13 @@ class AllgatherVFunctor {
public:
std::string const name{"AllgatherV"};
AllgatherVFunctor(std::size_t world_size, std::size_t rank,
AllgatherVFunctor(std::int32_t world_size, std::int32_t rank,
std::map<std::size_t, std::string_view>* data)
: world_size_{world_size}, rank_{rank}, data_{data} {}
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
data_->emplace(rank_, std::string_view{input, bytes});
if (data_->size() == world_size_) {
if (data_->size() == static_cast<std::size_t>(world_size_)) {
for (auto const& kv : *data_) {
buffer->append(kv.second);
}
@@ -56,8 +55,8 @@ class AllgatherVFunctor {
}
private:
std::size_t world_size_;
std::size_t rank_;
std::int32_t world_size_;
std::int32_t rank_;
std::map<std::size_t, std::string_view>* data_;
};
@@ -68,7 +67,7 @@ class AllreduceFunctor {
public:
std::string const name{"Allreduce"};
AllreduceFunctor(DataType dataType, Operation operation)
AllreduceFunctor(ArrayInterfaceHandler::Type dataType, Op operation)
: data_type_{dataType}, operation_{operation} {}
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
@@ -76,23 +75,23 @@ class AllreduceFunctor {
// Copy the input if this is the first request.
buffer->assign(input, bytes);
} else {
auto n_bytes_type = DispatchDType(data_type_, [](auto t) { return sizeof(t); });
// Apply the reduce_operation to the input and the buffer.
Accumulate(input, bytes / GetTypeSize(data_type_), &buffer->front());
Accumulate(input, bytes / n_bytes_type, &buffer->front());
}
}
private:
template <class T, std::enable_if_t<std::is_integral<T>::value>* = nullptr>
void AccumulateBitwise(T* buffer, T const* input, std::size_t size,
Operation reduce_operation) const {
void AccumulateBitwise(T* buffer, T const* input, std::size_t size, Op reduce_operation) const {
switch (reduce_operation) {
case Operation::kBitwiseAND:
case Op::kBitwiseAND:
std::transform(buffer, buffer + size, input, buffer, std::bit_and<T>());
break;
case Operation::kBitwiseOR:
case Op::kBitwiseOR:
std::transform(buffer, buffer + size, input, buffer, std::bit_or<T>());
break;
case Operation::kBitwiseXOR:
case Op::kBitwiseXOR:
std::transform(buffer, buffer + size, input, buffer, std::bit_xor<T>());
break;
default:
@@ -101,27 +100,27 @@ class AllreduceFunctor {
}
template <class T, std::enable_if_t<std::is_floating_point<T>::value>* = nullptr>
void AccumulateBitwise(T*, T const*, std::size_t, Operation) const {
void AccumulateBitwise(T*, T const*, std::size_t, Op) const {
LOG(FATAL) << "Floating point types do not support bitwise operations.";
}
template <class T>
void Accumulate(T* buffer, T const* input, std::size_t size, Operation reduce_operation) const {
void Accumulate(T* buffer, T const* input, std::size_t size, Op reduce_operation) const {
switch (reduce_operation) {
case Operation::kMax:
case Op::kMax:
std::transform(buffer, buffer + size, input, buffer,
[](T a, T b) { return std::max(a, b); });
break;
case Operation::kMin:
case Op::kMin:
std::transform(buffer, buffer + size, input, buffer,
[](T a, T b) { return std::min(a, b); });
break;
case Operation::kSum:
case Op::kSum:
std::transform(buffer, buffer + size, input, buffer, std::plus<T>());
break;
case Operation::kBitwiseAND:
case Operation::kBitwiseOR:
case Operation::kBitwiseXOR:
case Op::kBitwiseAND:
case Op::kBitwiseOR:
case Op::kBitwiseXOR:
AccumulateBitwise(buffer, input, size, reduce_operation);
break;
default:
@@ -130,36 +129,37 @@ class AllreduceFunctor {
}
void Accumulate(char const* input, std::size_t size, char* buffer) const {
using Type = ArrayInterfaceHandler::Type;
switch (data_type_) {
case DataType::kInt8:
case Type::kI1:
Accumulate(reinterpret_cast<std::int8_t*>(buffer),
reinterpret_cast<std::int8_t const*>(input), size, operation_);
break;
case DataType::kUInt8:
case Type::kU1:
Accumulate(reinterpret_cast<std::uint8_t*>(buffer),
reinterpret_cast<std::uint8_t const*>(input), size, operation_);
break;
case DataType::kInt32:
case Type::kI4:
Accumulate(reinterpret_cast<std::int32_t*>(buffer),
reinterpret_cast<std::int32_t const*>(input), size, operation_);
break;
case DataType::kUInt32:
case Type::kU4:
Accumulate(reinterpret_cast<std::uint32_t*>(buffer),
reinterpret_cast<std::uint32_t const*>(input), size, operation_);
break;
case DataType::kInt64:
case Type::kI8:
Accumulate(reinterpret_cast<std::int64_t*>(buffer),
reinterpret_cast<std::int64_t const*>(input), size, operation_);
break;
case DataType::kUInt64:
case Type::kU8:
Accumulate(reinterpret_cast<std::uint64_t*>(buffer),
reinterpret_cast<std::uint64_t const*>(input), size, operation_);
break;
case DataType::kFloat:
case Type::kF4:
Accumulate(reinterpret_cast<float*>(buffer), reinterpret_cast<float const*>(input), size,
operation_);
break;
case DataType::kDouble:
case Type::kF8:
Accumulate(reinterpret_cast<double*>(buffer), reinterpret_cast<double const*>(input), size,
operation_);
break;
@@ -169,8 +169,8 @@ class AllreduceFunctor {
}
private:
DataType data_type_;
Operation operation_;
ArrayInterfaceHandler::Type data_type_;
Op operation_;
};
/**
@@ -180,7 +180,7 @@ class BroadcastFunctor {
public:
std::string const name{"Broadcast"};
BroadcastFunctor(std::size_t rank, std::size_t root) : rank_{rank}, root_{root} {}
BroadcastFunctor(std::int32_t rank, std::int32_t root) : rank_{rank}, root_{root} {}
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
if (rank_ == root_) {
@@ -190,11 +190,11 @@ class BroadcastFunctor {
}
private:
std::size_t rank_;
std::size_t root_;
std::int32_t rank_;
std::int32_t root_;
};
void InMemoryHandler::Init(std::size_t world_size, std::size_t) {
void InMemoryHandler::Init(std::int32_t world_size, std::int32_t) {
CHECK(world_size_ < world_size) << "In memory handler already initialized.";
std::unique_lock<std::mutex> lock(mutex_);
@@ -204,7 +204,7 @@ void InMemoryHandler::Init(std::size_t world_size, std::size_t) {
cv_.notify_all();
}
void InMemoryHandler::Shutdown(uint64_t sequence_number, std::size_t) {
void InMemoryHandler::Shutdown(uint64_t sequence_number, std::int32_t) {
CHECK(world_size_ > 0) << "In memory handler already shutdown.";
std::unique_lock<std::mutex> lock(mutex_);
@@ -220,29 +220,29 @@ void InMemoryHandler::Shutdown(uint64_t sequence_number, std::size_t) {
}
void InMemoryHandler::Allgather(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::size_t rank) {
std::size_t sequence_number, std::int32_t rank) {
Handle(input, bytes, output, sequence_number, rank, AllgatherFunctor{world_size_, rank});
}
void InMemoryHandler::AllgatherV(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::size_t rank) {
std::size_t sequence_number, std::int32_t rank) {
Handle(input, bytes, output, sequence_number, rank, AllgatherVFunctor{world_size_, rank, &aux_});
}
void InMemoryHandler::Allreduce(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::size_t rank, DataType data_type,
Operation op) {
std::size_t sequence_number, std::int32_t rank,
ArrayInterfaceHandler::Type data_type, Op op) {
Handle(input, bytes, output, sequence_number, rank, AllreduceFunctor{data_type, op});
}
void InMemoryHandler::Broadcast(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::size_t rank, std::size_t root) {
std::size_t sequence_number, std::int32_t rank, std::int32_t root) {
Handle(input, bytes, output, sequence_number, rank, BroadcastFunctor{rank, root});
}
template <class HandlerFunctor>
void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::size_t rank,
std::size_t sequence_number, std::int32_t rank,
HandlerFunctor const& functor) {
// Pass through if there is only 1 client.
if (world_size_ == 1) {
@@ -287,5 +287,4 @@ void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string*
cv_.notify_all();
}
}
} // namespace collective
} // namespace xgboost
} // namespace xgboost::collective

View File

@@ -1,16 +1,15 @@
/*!
* Copyright 2022 XGBoost contributors
/**
* Copyright 2022-2023, XGBoost contributors
*/
#pragma once
#include <condition_variable>
#include <map>
#include <string>
#include "communicator.h"
namespace xgboost {
namespace collective {
#include "../data/array_interface.h"
#include "comm.h"
namespace xgboost::collective {
/**
* @brief Handles collective communication primitives in memory.
*
@@ -28,12 +27,11 @@ class InMemoryHandler {
/**
* @brief Construct a handler with the given world size.
* @param world_size Number of workers.
* @param world Number of workers.
*
* This is used when the handler only needs to be initialized once with a known world size.
*/
explicit InMemoryHandler(std::int32_t worldSize)
: world_size_{static_cast<std::size_t>(worldSize)} {}
explicit InMemoryHandler(std::int32_t world) : world_size_{world} {}
/**
* @brief Initialize the handler with the world size and rank.
@@ -43,7 +41,7 @@ class InMemoryHandler {
* This is used when multiple objects/threads are accessing the same handler and need to
* initialize it collectively.
*/
void Init(std::size_t world_size, std::size_t rank);
void Init(std::int32_t world_size, std::int32_t rank);
/**
* @brief Shut down the handler.
@@ -53,7 +51,7 @@ class InMemoryHandler {
* This is used when multiple objects/threads are accessing the same handler and need to
* shut it down collectively.
*/
void Shutdown(uint64_t sequence_number, std::size_t rank);
void Shutdown(uint64_t sequence_number, std::int32_t rank);
/**
* @brief Perform allgather.
@@ -64,7 +62,7 @@ class InMemoryHandler {
* @param rank Index of the worker.
*/
void Allgather(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::size_t rank);
std::size_t sequence_number, std::int32_t rank);
/**
* @brief Perform variable-length allgather.
@@ -75,7 +73,7 @@ class InMemoryHandler {
* @param rank Index of the worker.
*/
void AllgatherV(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::size_t rank);
std::size_t sequence_number, std::int32_t rank);
/**
* @brief Perform allreduce.
@@ -88,7 +86,8 @@ class InMemoryHandler {
* @param op The reduce operation.
*/
void Allreduce(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::size_t rank, DataType data_type, Operation op);
std::size_t sequence_number, std::int32_t rank,
ArrayInterfaceHandler::Type data_type, Op op);
/**
* @brief Perform broadcast.
@@ -100,7 +99,7 @@ class InMemoryHandler {
* @param root Index of the worker to broadcast from.
*/
void Broadcast(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::size_t rank, std::size_t root);
std::size_t sequence_number, std::int32_t rank, std::int32_t root);
private:
/**
@@ -115,17 +114,15 @@ class InMemoryHandler {
*/
template <class HandlerFunctor>
void Handle(char const* input, std::size_t size, std::string* output, std::size_t sequence_number,
std::size_t rank, HandlerFunctor const& functor);
std::int32_t rank, HandlerFunctor const& functor);
std::size_t world_size_{}; /// Number of workers.
std::size_t received_{}; /// Number of calls received with the current sequence.
std::size_t sent_{}; /// Number of calls completed with the current sequence.
std::int32_t world_size_{}; /// Number of workers.
std::int64_t received_{}; /// Number of calls received with the current sequence.
std::int64_t sent_{}; /// Number of calls completed with the current sequence.
std::string buffer_{}; /// A shared common buffer.
std::map<std::size_t, std::string_view> aux_{}; /// A shared auxiliary map.
uint64_t sequence_number_{}; /// Call sequence number.
mutable std::mutex mutex_; /// Lock.
mutable std::condition_variable cv_; /// Conditional variable to wait on.
};
} // namespace collective
} // namespace xgboost
} // namespace xgboost::collective

View File

@@ -6,6 +6,8 @@
#include <cstddef> // for size_t
#include <cstdint> // for int32_t
#include <exception> // for exception, current_exception, rethrow_exception
#include <future> // for promise
#include <memory> // for make_shared
#include <mutex> // for lock_guard, unique_lock
#include <queue> // for queue
#include <string> // for string
@@ -18,9 +20,10 @@
#include "xgboost/logging.h" // for CHECK
namespace xgboost::collective {
Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
Result Loop::ProcessQueue(std::queue<Op>* p_queue) const {
timer_.Start(__func__);
auto error = [this] {
auto error = [this](Op op) {
op.pr->set_value();
timer_.Stop(__func__);
};
@@ -38,7 +41,7 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
// Iterate through all the ops for poll
for (std::size_t i = 0; i < n_ops; ++i) {
auto op = qcopy.front();
auto op = std::move(qcopy.front());
qcopy.pop();
switch (op.code) {
@@ -54,12 +57,12 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
break;
}
default: {
error();
error(op);
return Fail("Invalid socket operation.");
}
}
qcopy.push(op);
qcopy.push(std::move(op));
}
// poll, work on fds that are ready.
@@ -67,18 +70,18 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
if (!poll.fds.empty()) {
auto rc = poll.Poll(timeout_);
if (!rc.OK()) {
error();
timer_.Stop(__func__);
return rc;
}
}
timer_.Stop("poll");
// we wonldn't be here if the queue is empty.
// We wonldn't be here if the queue is empty.
CHECK(!qcopy.empty());
// Iterate through all the ops for performing the operations
for (std::size_t i = 0; i < n_ops; ++i) {
auto op = qcopy.front();
auto op = std::move(qcopy.front());
qcopy.pop();
std::int32_t n_bytes_done{0};
@@ -93,8 +96,9 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
if (poll.CheckRead(*op.sock)) {
n_bytes_done = op.sock->Recv(op.ptr + op.off, op.n - op.off);
if (n_bytes_done == 0) {
error();
return Fail("Encountered EOF. The other end is likely closed.");
error(op);
return Fail("Encountered EOF. The other end is likely closed.",
op.sock->GetSockError());
}
}
break;
@@ -112,14 +116,14 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
break;
}
default: {
error();
error(op);
return Fail("Invalid socket operation.");
}
}
if (n_bytes_done == -1 && !system::LastErrorWouldBlock()) {
auto rc = system::FailWithCode("Invalid socket output.");
error();
error(op);
return rc;
}
@@ -127,14 +131,12 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
CHECK_LE(op.off, op.n);
if (op.off != op.n) {
// not yet finished, push back to queue for next round.
// not yet finished, push back to queue for the next round.
qcopy.push(op);
} else {
op.pr->set_value();
}
}
if (!blocking) {
break;
}
}
timer_.Stop(__func__);
@@ -148,8 +150,7 @@ void Loop::Process() {
};
// This loop cannot exit unless `stop_` is set to true. There must always be a thread to
// answer the blocking call even if there are errors, otherwise the blocking will wait
// forever.
// answer the call even if there are errors.
while (true) {
try {
std::unique_lock lock{mu_};
@@ -170,44 +171,15 @@ void Loop::Process() {
// Move the global queue into a local variable to unblock it.
std::queue<Op> qcopy;
bool is_blocking = false;
while (!queue_.empty()) {
auto op = queue_.front();
auto op = std::move(queue_.front());
queue_.pop();
if (op.code == Op::kBlock) {
is_blocking = true;
} else {
qcopy.push(op);
}
qcopy.push(op);
}
lock.unlock();
// Clear the local queue, if `is_blocking` is true, this is blocking the current
// worker thread (but not the client thread), wait until all operations are
// finished.
auto rc = this->ProcessQueue(&qcopy, is_blocking);
if (is_blocking && rc.OK()) {
CHECK(qcopy.empty());
}
// Push back the remaining operations.
if (rc.OK()) {
std::unique_lock lock{mu_};
while (!qcopy.empty()) {
queue_.push(qcopy.front());
qcopy.pop();
}
}
// Notify the client thread who called block after all error conditions are set.
auto notify_if_block = [&] {
if (is_blocking) {
std::unique_lock lock{mu_};
block_done_ = true;
lock.unlock();
block_cv_.notify_one();
}
};
// Clear the local queue.
auto rc = this->ProcessQueue(&qcopy);
// Handle error
if (!rc.OK()) {
@@ -215,8 +187,6 @@ void Loop::Process() {
} else {
CHECK(qcopy.empty());
}
notify_if_block();
} catch (std::exception const& e) {
curr_exce_ = std::current_exception();
set_rc(Fail("Exception inside the event loop:" + std::string{e.what()}));
@@ -256,20 +226,28 @@ Result Loop::Stop() {
stop_ = true;
}
}
if (!this->worker_.joinable()) {
std::lock_guard<std::mutex> guard{rc_lock_};
return Fail("Worker has stopped.", std::move(rc_));
}
this->Submit(Op{Op::kBlock});
{
// Wait for the block call to finish.
std::unique_lock lock{mu_};
block_cv_.wait(lock, [this] { return block_done_ || stop_; });
block_done_ = false;
cv_.notify_one();
}
for (auto& fut : futures_) {
if (fut.valid()) {
try {
fut.get();
} catch (std::future_error const&) {
// Do nothing. If something went wrong in the worker, we have a std::future_error
// due to broken promise. This function will transfer the rc back to the caller.
}
}
}
futures_.clear();
{
// Transfer the rc.
std::lock_guard<std::mutex> lock{rc_lock_};
@@ -278,13 +256,13 @@ Result Loop::Stop() {
}
void Loop::Submit(Op op) {
auto p = std::make_shared<std::promise<void>>();
op.pr = std::move(p);
futures_.emplace_back(op.pr->get_future());
CHECK_NE(op.n, 0);
std::unique_lock lock{mu_};
if (op.code != Op::kBlock) {
CHECK_NE(op.n, 0);
}
queue_.push(op);
lock.unlock();
cv_.notify_one();
}
Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} {

View File

@@ -7,9 +7,12 @@
#include <cstddef> // for size_t
#include <cstdint> // for int8_t, int32_t
#include <exception> // for exception_ptr
#include <mutex> // for unique_lock, mutex
#include <future> // for future
#include <memory> // for shared_ptr
#include <mutex> // for mutex
#include <queue> // for queue
#include <thread> // for thread
#include <vector> // for vector
#include "../common/timer.h" // for Monitor
#include "xgboost/collective/result.h" // for Result
@@ -20,14 +23,15 @@ class Loop {
public:
struct Op {
// kSleep is only for testing
enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2, kSleep = 4 } code;
enum Code : std::int8_t { kRead = 0, kWrite = 1, kSleep = 3 } code;
std::int32_t rank{-1};
std::int8_t* ptr{nullptr};
std::size_t n{0};
TCPSocket* sock{nullptr};
std::size_t off{0};
std::shared_ptr<std::promise<void>> pr;
explicit Op(Code c) : code{c} { CHECK(c == kBlock || c == kSleep); }
explicit Op(Code c) : code{c} { CHECK(c == kSleep); }
Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off)
: code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {}
Op(Op const&) = default;
@@ -45,12 +49,11 @@ class Loop {
private:
std::thread worker_; // thread worker to execute the tasks
std::condition_variable cv_; // CV used to notify a new submit call
std::condition_variable block_cv_; // CV used to notify the blocking call
bool block_done_{false}; // Flag to indicate whether the blocking call has finished.
std::condition_variable cv_; // CV used to notify a new submit call
std::queue<Op> queue_; // event queue
std::mutex mu_; // mutex to protect the queue, cv, and block_done
std::vector<std::future<void>> futures_;
std::mutex mu_; // mutex to protect the queue, cv, and block_done
std::chrono::seconds timeout_;
@@ -61,7 +64,7 @@ class Loop {
std::exception_ptr curr_exce_{nullptr};
common::Monitor mutable timer_;
Result ProcessQueue(std::queue<Op>* p_queue, bool blocking) const;
Result ProcessQueue(std::queue<Op>* p_queue) const;
// The cunsumer function that runs inside a worker thread.
void Process();

View File

@@ -1,243 +0,0 @@
/*!
* Copyright 2023 XGBoost contributors
*/
#if defined(XGBOOST_USE_NCCL)
#include <numeric> // for accumulate
#include "comm.cuh"
#include "nccl_device_communicator.cuh"
namespace xgboost {
namespace collective {
NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sync,
StringView nccl_path)
: device_ordinal_{device_ordinal},
needs_sync_{needs_sync},
world_size_{GetWorldSize()},
rank_{GetRank()} {
if (device_ordinal_ < 0) {
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
}
if (world_size_ == 1) {
return;
}
stub_ = std::make_shared<NcclStub>(std::move(nccl_path));
std::vector<uint64_t> uuids(world_size_ * kUuidLength, 0);
auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
auto s_this_uuid = s_uuid.subspan(rank_ * kUuidLength, kUuidLength);
GetCudaUUID(s_this_uuid);
// TODO(rongou): replace this with allgather.
Allreduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);
std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world_size_);
size_t j = 0;
for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
converted[j] = xgboost::common::Span<uint64_t, kUuidLength>{uuids.data() + i, kUuidLength};
j++;
}
auto iter = std::unique(converted.begin(), converted.end());
auto n_uniques = std::distance(converted.begin(), iter);
CHECK_EQ(n_uniques, world_size_)
<< "Multiple processes within communication group running on same CUDA "
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
nccl_unique_id_ = GetUniqueId();
dh::safe_cuda(cudaSetDevice(device_ordinal_));
auto rc = stub_->CommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_);
CHECK(rc.OK()) << rc.Report();
}
NcclDeviceCommunicator::~NcclDeviceCommunicator() {
if (world_size_ == 1) {
return;
}
if (nccl_comm_) {
auto rc = stub_->CommDestroy(nccl_comm_);
CHECK(rc.OK()) << rc.Report();
}
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
LOG(CONSOLE) << "======== NCCL Statistics========";
LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_ / 1048576;
}
}
namespace {
ncclDataType_t GetNcclDataType(DataType const &data_type) {
ncclDataType_t result{ncclInt8};
switch (data_type) {
case DataType::kInt8:
result = ncclInt8;
break;
case DataType::kUInt8:
result = ncclUint8;
break;
case DataType::kInt32:
result = ncclInt32;
break;
case DataType::kUInt32:
result = ncclUint32;
break;
case DataType::kInt64:
result = ncclInt64;
break;
case DataType::kUInt64:
result = ncclUint64;
break;
case DataType::kFloat:
result = ncclFloat;
break;
case DataType::kDouble:
result = ncclDouble;
break;
default:
LOG(FATAL) << "Unknown data type.";
}
return result;
}
bool IsBitwiseOp(Operation const &op) {
return op == Operation::kBitwiseAND || op == Operation::kBitwiseOR ||
op == Operation::kBitwiseXOR;
}
ncclRedOp_t GetNcclRedOp(Operation const &op) {
ncclRedOp_t result{ncclMax};
switch (op) {
case Operation::kMax:
result = ncclMax;
break;
case Operation::kMin:
result = ncclMin;
break;
case Operation::kSum:
result = ncclSum;
break;
default:
LOG(FATAL) << "Unsupported reduce operation.";
}
return result;
}
template <typename Func>
void RunBitwiseAllreduce(char *out_buffer, char const *device_buffer, Func func, int world_size,
std::size_t size) {
dh::LaunchN(size, [=] __device__(std::size_t idx) {
auto result = device_buffer[idx];
for (auto rank = 1; rank < world_size; rank++) {
result = func(result, device_buffer[rank * size + idx]);
}
out_buffer[idx] = result;
});
}
} // anonymous namespace
void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::size_t count,
DataType data_type, Operation op) {
auto const size = count * GetTypeSize(data_type);
dh::caching_device_vector<char> buffer(size * world_size_);
auto *device_buffer = buffer.data().get();
// First gather data from all the workers.
auto rc = stub_->Allgather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
nccl_comm_, dh::DefaultStream());
CHECK(rc.OK()) << rc.Report();
if (needs_sync_) {
dh::DefaultStream().Sync();
}
// Then reduce locally.
auto *out_buffer = static_cast<char *>(send_receive_buffer);
switch (op) {
case Operation::kBitwiseAND:
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size_, size);
break;
case Operation::kBitwiseOR:
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size_, size);
break;
case Operation::kBitwiseXOR:
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size_, size);
break;
default:
LOG(FATAL) << "Not a bitwise reduce operation.";
}
}
void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t count,
DataType data_type, Operation op) {
if (world_size_ == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
if (IsBitwiseOp(op)) {
BitwiseAllReduce(send_receive_buffer, count, data_type, op);
} else {
auto rc = stub_->Allreduce(send_receive_buffer, send_receive_buffer, count,
GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
dh::DefaultStream());
CHECK(rc.OK()) << rc.Report();
}
allreduce_bytes_ += count * GetTypeSize(data_type);
allreduce_calls_ += 1;
}
void NcclDeviceCommunicator::AllGather(void const *send_buffer, void *receive_buffer,
std::size_t send_size) {
if (world_size_ == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
auto rc = stub_->Allgather(send_buffer, receive_buffer, send_size, ncclInt8, nccl_comm_,
dh::DefaultStream());
CHECK(rc.OK()) << rc.Report();
}
void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_bytes,
std::vector<std::size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) {
if (world_size_ == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
segments->clear();
segments->resize(world_size_, 0);
segments->at(rank_) = length_bytes;
Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
receive_buffer->resize(total_bytes);
size_t offset = 0;
auto rc = Success() << [&] { return stub_->GroupStart(); } << [&] {
for (int32_t i = 0; i < world_size_; ++i) {
size_t as_bytes = segments->at(i);
auto rc = stub_->Broadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
ncclChar, i, nccl_comm_, dh::DefaultStream());
if (!rc.OK()) {
return rc;
}
offset += as_bytes;
}
return Success();
} << [&] { return stub_->GroupEnd(); };
}
void NcclDeviceCommunicator::Synchronize() {
if (world_size_ == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::DefaultStream().Sync();
}
} // namespace collective
} // namespace xgboost
#endif

View File

@@ -1,91 +0,0 @@
/*!
* Copyright 2022-2023 XGBoost contributors
*/
#pragma once
#include "../common/device_helpers.cuh"
#include "comm.cuh"
#include "communicator.h"
#include "device_communicator.cuh"
#include "nccl_stub.h"
namespace xgboost {
namespace collective {
class NcclDeviceCommunicator : public DeviceCommunicator {
public:
/**
* @brief Construct a new NCCL communicator.
* @param device_ordinal The GPU device id.
* @param needs_sync Whether extra CUDA stream synchronization is needed.
*
* In multi-GPU tests when multiple NCCL communicators are created in the same process, sometimes
* a deadlock happens because NCCL kernels are blocking. The extra CUDA stream synchronization
* makes sure that the NCCL kernels are caught up, thus avoiding the deadlock.
*
* The Rabit communicator runs with one process per GPU, so the additional synchronization is not
* needed. The in-memory communicator is used in tests with multiple threads, each thread
* representing a rank/worker, so the additional synchronization is needed to avoid deadlocks.
*/
explicit NcclDeviceCommunicator(int device_ordinal, bool needs_sync, StringView nccl_path);
~NcclDeviceCommunicator() override;
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override;
void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) override;
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) override;
void Synchronize() override;
private:
static constexpr std::size_t kUuidLength =
sizeof(std::declval<cudaDeviceProp>().uuid) / sizeof(uint64_t);
void GetCudaUUID(xgboost::common::Span<uint64_t, kUuidLength> const &uuid) const {
cudaDeviceProp prob{};
dh::safe_cuda(cudaGetDeviceProperties(&prob, device_ordinal_));
std::memcpy(uuid.data(), static_cast<void *>(&(prob.uuid)), sizeof(prob.uuid));
}
static std::string PrintUUID(xgboost::common::Span<uint64_t, kUuidLength> const &uuid) {
std::stringstream ss;
for (auto v : uuid) {
ss << std::hex << v;
}
return ss.str();
}
/**
* \fn ncclUniqueId GetUniqueId()
*
* \brief Gets the Unique ID from NCCL to be used in setting up interprocess
* communication
*
* \return the Unique ID
*/
ncclUniqueId GetUniqueId() {
static const int kRootRank = 0;
ncclUniqueId id;
if (rank_ == kRootRank) {
auto rc = stub_->GetUniqueId(&id);
CHECK(rc.OK()) << rc.Report();
}
Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId), static_cast<int>(kRootRank));
return id;
}
void BitwiseAllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op);
int const device_ordinal_;
bool const needs_sync_;
int const world_size_;
int const rank_;
ncclComm_t nccl_comm_{};
std::shared_ptr<NcclStub> stub_;
ncclUniqueId nccl_unique_id_{};
size_t allreduce_bytes_{0}; // Keep statistics of the number of bytes communicated.
size_t allreduce_calls_{0}; // Keep statistics of the number of reduce calls.
};
} // namespace collective
} // namespace xgboost

View File

@@ -1,32 +0,0 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <string>
#include "communicator.h"
namespace xgboost {
namespace collective {
/**
* A no-op communicator, used for non-distributed training.
*/
class NoOpCommunicator : public Communicator {
public:
NoOpCommunicator() : Communicator(1, 0) {}
bool IsDistributed() const override { return false; }
bool IsFederated() const override { return false; }
std::string AllGather(std::string_view) override { return {}; }
std::string AllGatherV(std::string_view) override { return {}; }
void AllReduce(void *, std::size_t, DataType, Operation) override {}
void Broadcast(void *, std::size_t, int) override {}
std::string GetProcessorName() override { return {}; }
void Print(const std::string &message) override { LOG(CONSOLE) << message; }
protected:
void Shutdown() override {}
};
} // namespace collective
} // namespace xgboost

View File

@@ -41,20 +41,26 @@ struct Magic {
[[nodiscard]] Result Verify(xgboost::collective::TCPSocket* p_sock) {
std::int32_t magic{kMagic};
auto n_bytes = p_sock->SendAll(&magic, sizeof(magic));
if (n_bytes != sizeof(magic)) {
return Fail("Failed to verify.");
}
magic = 0;
n_bytes = p_sock->RecvAll(&magic, sizeof(magic));
if (n_bytes != sizeof(magic)) {
return Fail("Failed to verify.");
}
if (magic != kMagic) {
return xgboost::collective::Fail("Invalid verification number.");
}
return Success();
std::size_t n_sent{0};
return Success() << [&] {
return p_sock->SendAll(&magic, sizeof(magic), &n_sent);
} << [&] {
if (n_sent != sizeof(magic)) {
return Fail("Failed to verify.");
}
return Success();
} << [&] {
magic = 0;
return p_sock->RecvAll(&magic, sizeof(magic), &n_sent);
} << [&] {
if (n_sent != sizeof(magic)) {
return Fail("Failed to verify.");
}
if (magic != kMagic) {
return xgboost::collective::Fail("Invalid verification number.");
}
return Success();
};
}
};
@@ -227,31 +233,43 @@ struct Error {
[[nodiscard]] Result SignalError(TCPSocket* worker) const {
std::int32_t err{ErrorSignal()};
auto n_sent = worker->SendAll(&err, sizeof(err));
if (n_sent == sizeof(err)) {
return Success();
}
return Fail("Failed to send error signal");
std::size_t n_sent{0};
return Success() << [&] {
return worker->SendAll(&err, sizeof(err), &n_sent);
} << [&] {
if (n_sent == sizeof(err)) {
return Success();
}
return Fail("Failed to send error signal");
};
}
// self is localhost, we are sending the signal to the error handling thread for it to
// close.
[[nodiscard]] Result SignalShutdown(TCPSocket* self) const {
std::int32_t err{ShutdownSignal()};
auto n_sent = self->SendAll(&err, sizeof(err));
if (n_sent == sizeof(err)) {
return Success();
}
return Fail("Failed to send shutdown signal");
std::size_t n_sent{0};
return Success() << [&] {
return self->SendAll(&err, sizeof(err), &n_sent);
} << [&] {
if (n_sent == sizeof(err)) {
return Success();
}
return Fail("Failed to send shutdown signal");
};
}
// get signal, either for error or for shutdown.
[[nodiscard]] Result RecvSignal(TCPSocket* peer, bool* p_is_error) const {
std::int32_t err{ShutdownSignal()};
auto n_recv = peer->RecvAll(&err, sizeof(err));
if (n_recv == sizeof(err)) {
*p_is_error = err == 1;
return Success();
}
return Fail("Failed to receive error signal.");
std::size_t n_recv{0};
return Success() << [&] {
return peer->RecvAll(&err, sizeof(err), &n_recv);
} << [&] {
if (n_recv == sizeof(err)) {
*p_is_error = err == 1;
return Success();
}
return Fail("Failed to receive error signal.");
};
}
};
} // namespace xgboost::collective::proto

View File

@@ -1,175 +0,0 @@
/**
* Copyright 2022-2023 by XGBoost contributors
*/
#pragma once
#include <rabit/rabit.h>
#include <string>
#include <vector>
#include "communicator-inl.h"
#include "communicator.h"
#include "xgboost/json.h"
namespace xgboost {
namespace collective {
class RabitCommunicator : public Communicator {
public:
static Communicator *Create(Json const &config) {
std::vector<std::string> args_str;
for (auto &items : get<Object const>(config)) {
switch (items.second.GetValue().Type()) {
case xgboost::Value::ValueKind::kString: {
args_str.push_back(items.first + "=" + get<String const>(items.second));
break;
}
case xgboost::Value::ValueKind::kInteger: {
args_str.push_back(items.first + "=" + std::to_string(get<Integer const>(items.second)));
break;
}
case xgboost::Value::ValueKind::kBoolean: {
if (get<Boolean const>(items.second)) {
args_str.push_back(items.first + "=1");
} else {
args_str.push_back(items.first + "=0");
}
break;
}
default:
break;
}
}
std::vector<char *> args;
for (auto &key_value : args_str) {
args.push_back(&key_value[0]);
}
if (!rabit::Init(static_cast<int>(args.size()), &args[0])) {
LOG(FATAL) << "Failed to initialize Rabit";
}
return new RabitCommunicator(rabit::GetWorldSize(), rabit::GetRank());
}
RabitCommunicator(int world_size, int rank) : Communicator(world_size, rank) {}
bool IsDistributed() const override { return rabit::IsDistributed(); }
bool IsFederated() const override { return false; }
std::string AllGather(std::string_view input) override {
auto const per_rank = input.size();
auto const total_size = per_rank * GetWorldSize();
auto const index = per_rank * GetRank();
std::string result(total_size, '\0');
result.replace(index, per_rank, input);
rabit::Allgather(result.data(), total_size, index, per_rank, per_rank);
return result;
}
std::string AllGatherV(std::string_view input) override {
auto const size_node_slice = input.size();
auto const all_sizes = collective::Allgather(size_node_slice);
auto const total_size = std::accumulate(all_sizes.cbegin(), all_sizes.cend(), 0ul);
auto const begin_index =
std::accumulate(all_sizes.cbegin(), all_sizes.cbegin() + GetRank(), 0ul);
auto const size_prev_slice =
GetRank() == 0 ? all_sizes[GetWorldSize() - 1] : all_sizes[GetRank() - 1];
std::string result(total_size, '\0');
result.replace(begin_index, size_node_slice, input);
rabit::Allgather(result.data(), total_size, begin_index, size_node_slice, size_prev_slice);
return result;
}
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override {
switch (data_type) {
case DataType::kInt8:
DoAllReduce<char>(send_receive_buffer, count, op);
break;
case DataType::kUInt8:
DoAllReduce<unsigned char>(send_receive_buffer, count, op);
break;
case DataType::kInt32:
DoAllReduce<std::int32_t>(send_receive_buffer, count, op);
break;
case DataType::kUInt32:
DoAllReduce<std::uint32_t>(send_receive_buffer, count, op);
break;
case DataType::kInt64:
DoAllReduce<std::int64_t>(send_receive_buffer, count, op);
break;
case DataType::kUInt64:
DoAllReduce<std::uint64_t>(send_receive_buffer, count, op);
break;
case DataType::kFloat:
DoAllReduce<float>(send_receive_buffer, count, op);
break;
case DataType::kDouble:
DoAllReduce<double>(send_receive_buffer, count, op);
break;
default:
LOG(FATAL) << "Unknown data type";
}
}
void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {
rabit::Broadcast(send_receive_buffer, size, root);
}
std::string GetProcessorName() override { return rabit::GetProcessorName(); }
void Print(const std::string &message) override { rabit::TrackerPrint(message); }
protected:
void Shutdown() override { rabit::Finalize(); }
private:
template <typename DType, std::enable_if_t<std::is_integral<DType>::value> * = nullptr>
void DoBitwiseAllReduce(void *send_receive_buffer, std::size_t count, Operation op) {
switch (op) {
case Operation::kBitwiseAND:
rabit::Allreduce<rabit::op::BitAND, DType>(static_cast<DType *>(send_receive_buffer),
count);
break;
case Operation::kBitwiseOR:
rabit::Allreduce<rabit::op::BitOR, DType>(static_cast<DType *>(send_receive_buffer), count);
break;
case Operation::kBitwiseXOR:
rabit::Allreduce<rabit::op::BitXOR, DType>(static_cast<DType *>(send_receive_buffer),
count);
break;
default:
LOG(FATAL) << "Unknown allreduce operation";
}
}
template <typename DType, std::enable_if_t<std::is_floating_point<DType>::value> * = nullptr>
void DoBitwiseAllReduce(void *, std::size_t, Operation) {
LOG(FATAL) << "Floating point types do not support bitwise operations.";
}
template <typename DType>
void DoAllReduce(void *send_receive_buffer, std::size_t count, Operation op) {
switch (op) {
case Operation::kMax:
rabit::Allreduce<rabit::op::Max, DType>(static_cast<DType *>(send_receive_buffer), count);
break;
case Operation::kMin:
rabit::Allreduce<rabit::op::Min, DType>(static_cast<DType *>(send_receive_buffer), count);
break;
case Operation::kSum:
rabit::Allreduce<rabit::op::Sum, DType>(static_cast<DType *>(send_receive_buffer), count);
break;
case Operation::kBitwiseAND:
case Operation::kBitwiseOR:
case Operation::kBitwiseXOR:
DoBitwiseAllReduce<DType>(send_receive_buffer, count, op);
break;
default:
LOG(FATAL) << "Unknown allreduce operation";
}
}
};
} // namespace collective
} // namespace xgboost

View File

@@ -62,20 +62,15 @@ void ResultImpl::Concat(std::unique_ptr<ResultImpl> rhs) {
ptr->prev = std::move(rhs);
}
#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__)
std::string MakeMsg(std::string&& msg, char const*, std::int32_t) {
return std::forward<std::string>(msg);
}
#else
std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line) {
auto name = std::filesystem::path{file}.filename();
dmlc::DateLogger logger;
if (file && line != -1) {
return "[" + name.string() + ":" + std::to_string(line) + // NOLINT
auto name = std::filesystem::path{ file }.filename();
return "[" + name.string() + ":" + std::to_string(line) + "|" + logger.HumanDate() +
"]: " + std::forward<std::string>(msg);
}
return std::forward<std::string>(msg);
return std::string{"["} + logger.HumanDate() + "]" + std::forward<std::string>(msg); // NOLINT
}
#endif
} // namespace detail
void SafeColl(Result const& rc) {

View File

@@ -60,24 +60,46 @@ std::size_t TCPSocket::Send(StringView str) {
CHECK(!this->IsClosed());
CHECK_LT(str.size(), std::numeric_limits<std::int32_t>::max());
std::int32_t len = static_cast<std::int32_t>(str.size());
CHECK_EQ(this->SendAll(&len, sizeof(len)), sizeof(len)) << "Failed to send string length.";
auto bytes = this->SendAll(str.c_str(), str.size());
CHECK_EQ(bytes, str.size()) << "Failed to send string.";
return bytes;
std::size_t n_bytes{0};
auto rc = Success() << [&] {
return this->SendAll(&len, sizeof(len), &n_bytes);
} << [&] {
if (n_bytes != sizeof(len)) {
return Fail("Failed to send string length.");
}
return Success();
} << [&] {
return this->SendAll(str.c_str(), str.size(), &n_bytes);
} << [&] {
if (n_bytes != str.size()) {
return Fail("Failed to send string.");
}
return Success();
};
SafeColl(rc);
return n_bytes;
}
[[nodiscard]] Result TCPSocket::Recv(std::string *p_str) {
CHECK(!this->IsClosed());
std::int32_t len;
if (this->RecvAll(&len, sizeof(len)) != sizeof(len)) {
return Fail("Failed to recv string length.");
}
p_str->resize(len);
auto bytes = this->RecvAll(&(*p_str)[0], len);
if (static_cast<decltype(len)>(bytes) != len) {
return Fail("Failed to recv string.");
}
return Success();
std::size_t n_bytes{0};
return Success() << [&] {
return this->RecvAll(&len, sizeof(len), &n_bytes);
} << [&] {
if (n_bytes != sizeof(len)) {
return Fail("Failed to recv string length.");
}
return Success();
} << [&] {
p_str->resize(len);
return this->RecvAll(&(*p_str)[0], len, &n_bytes);
} << [&] {
if (static_cast<std::remove_reference_t<decltype(len)>>(n_bytes) != len) {
return Fail("Failed to recv string.");
}
return Success();
};
}
[[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry,

View File

@@ -31,14 +31,20 @@
#include "xgboost/json.h" // for Json
namespace xgboost::collective {
Tracker::Tracker(Json const& config)
: sortby_{static_cast<SortBy>(
OptionalArg<Integer const>(config, "sortby", static_cast<Integer::Int>(SortBy::kHost)))},
n_workers_{
static_cast<std::int32_t>(RequiredArg<Integer const>(config, "n_workers", __func__))},
port_{static_cast<std::int32_t>(OptionalArg<Integer const>(config, "port", Integer::Int{0}))},
timeout_{std::chrono::seconds{OptionalArg<Integer const>(
config, "timeout", static_cast<std::int64_t>(collective::DefaultTimeoutSec()))}} {}
timeout_{std::chrono::seconds{
OptionalArg<Integer const>(config, "timeout", static_cast<std::int64_t>(0))}} {
using std::chrono_literals::operator""s;
// Some old configurations in JVM for the scala implementation (removed) use 0 to
// indicate blocking. We continue that convention here.
timeout_ = (timeout_ == 0s) ? -1s : timeout_;
}
Result Tracker::WaitUntilReady() const {
using namespace std::chrono_literals; // NOLINT
@@ -49,7 +55,7 @@ Result Tracker::WaitUntilReady() const {
timer.Start();
while (!this->Ready()) {
auto ela = timer.Duration().count();
if (ela > this->Timeout().count()) {
if (HasTimeout(this->Timeout()) && ela > this->Timeout().count()) {
return Fail("Failed to start tracker, timeout:" + std::to_string(this->Timeout().count()) +
" seconds.");
}
@@ -250,8 +256,10 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
std::lock_guard lock{listener_mu_};
return listener_.NonBlocking(true);
} << [&] {
std::lock_guard lock{listener_mu_};
poll.WatchRead(listener_);
{
std::lock_guard lock{listener_mu_};
poll.WatchRead(listener_);
}
if (state.running) {
// Don't timeout if the communicator group is up and running.
return poll.Poll(std::chrono::seconds{-1});

View File

@@ -15,6 +15,7 @@
#include "xgboost/json.h" // for Json
namespace xgboost::collective {
inline bool HasTimeout(std::chrono::seconds timeout) { return timeout.count() > 0; }
/**
*
* @brief Implementation of RABIT tracker.
@@ -52,7 +53,7 @@ class Tracker {
protected:
std::int32_t n_workers_{0};
std::int32_t port_{-1};
std::chrono::seconds timeout_{0};
std::chrono::seconds timeout_{-1};
std::atomic<bool> ready_{false};
public:

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2014-2023, XGBoost Contributors
* Copyright 2014-2024, XGBoost Contributors
* \file io.h
* \brief general stream interface for serialization, I/O
* \author Tianqi Chen
@@ -8,7 +8,6 @@
#define XGBOOST_COMMON_IO_H_
#include <dmlc/io.h>
#include <rabit/internal/io.h> // for MemoryFixSizeBuffer, MemoryBufferStream
#include <algorithm> // for min, fill_n, copy_n
#include <array> // for array
@@ -23,12 +22,99 @@
#include <utility> // for move
#include <vector> // for vector
#include "common.h"
#include "common.h" // for DivRoundUp
#include "xgboost/string_view.h" // for StringView
namespace xgboost::common {
using MemoryFixSizeBuffer = rabit::utils::MemoryFixSizeBuffer;
using MemoryBufferStream = rabit::utils::MemoryBufferStream;
struct MemoryFixSizeBuffer : public dmlc::SeekStream {
public:
// similar to SEEK_END in libc
static std::size_t constexpr kSeekEnd = std::numeric_limits<std::size_t>::max();
public:
/**
* @brief Ctor
*
* @param p_buffer Pointer to the source buffer with size `buffer_size`.
* @param buffer_size Size of the source buffer
*/
MemoryFixSizeBuffer(void *p_buffer, std::size_t buffer_size)
: p_buffer_(reinterpret_cast<char *>(p_buffer)), buffer_size_(buffer_size) {}
~MemoryFixSizeBuffer() override = default;
std::size_t Read(void *ptr, std::size_t size) override {
std::size_t nread = std::min(buffer_size_ - curr_ptr_, size);
if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread);
curr_ptr_ += nread;
return nread;
}
void Write(const void *ptr, std::size_t size) override {
if (size == 0) return;
CHECK_LE(curr_ptr_ + size, buffer_size_);
std::memcpy(p_buffer_ + curr_ptr_, ptr, size);
curr_ptr_ += size;
}
void Seek(std::size_t pos) override {
if (pos == kSeekEnd) {
curr_ptr_ = buffer_size_;
} else {
curr_ptr_ = static_cast<std::size_t>(pos);
}
}
/**
* @brief Current position in the buffer (stream).
*/
std::size_t Tell() override { return curr_ptr_; }
[[nodiscard]] virtual bool AtEnd() const { return curr_ptr_ == buffer_size_; }
protected:
/*! \brief in memory buffer */
char *p_buffer_{nullptr};
/*! \brief current pointer */
std::size_t buffer_size_{0};
/*! \brief current pointer */
std::size_t curr_ptr_{0};
};
/*! \brief a in memory buffer that can be read and write as stream interface */
struct MemoryBufferStream : public dmlc::SeekStream {
public:
explicit MemoryBufferStream(std::string *p_buffer)
: p_buffer_(p_buffer) {
curr_ptr_ = 0;
}
~MemoryBufferStream() override = default;
size_t Read(void *ptr, size_t size) override {
CHECK_LE(curr_ptr_, p_buffer_->length()) << "read can not have position excceed buffer length";
size_t nread = std::min(p_buffer_->length() - curr_ptr_, size);
if (nread != 0) std::memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread);
curr_ptr_ += nread;
return nread;
}
void Write(const void *ptr, size_t size) override {
if (size == 0) return;
if (curr_ptr_ + size > p_buffer_->length()) {
p_buffer_->resize(curr_ptr_+size);
}
std::memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size);
curr_ptr_ += size;
}
void Seek(size_t pos) override {
curr_ptr_ = static_cast<size_t>(pos);
}
size_t Tell() override {
return curr_ptr_;
}
virtual bool AtEnd() const {
return curr_ptr_ == p_buffer_->length();
}
private:
/*! \brief in memory buffer */
std::string *p_buffer_;
/*! \brief current pointer */
size_t curr_ptr_;
}; // class MemoryBufferStream
/*!
* \brief Input stream that support additional PeekRead operation,

View File

@@ -116,19 +116,19 @@ INSTANTIATE(ColumnarAdapterBatch)
namespace {
/**
* \brief A view over gathered sketch values.
* @brief A view over gathered sketch values.
*/
template <typename T>
struct QuantileAllreduce {
common::Span<T> global_values;
common::Span<bst_idx_t> worker_indptr;
common::Span<bst_idx_t> feature_indptr;
size_t n_features{0};
bst_feature_t n_features{0};
/**
* \brief Get sketch values of the a feature from a worker.
* @brief Get sketch values of the a feature from a worker.
*
* \param rank rank of target worker
* \param fidx feature idx
* @param rank rank of target worker
* @param fidx feature idx
*/
[[nodiscard]] auto Values(int32_t rank, bst_feature_t fidx) const {
// get span for worker
@@ -154,7 +154,7 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
worker_segments.resize(1, 0);
auto world = collective::GetWorldSize();
auto rank = collective::GetRank();
auto n_columns = sketches_.size();
bst_feature_t n_columns = sketches_.size();
// get the size of each feature.
std::vector<bst_idx_t> sketch_size;
@@ -165,7 +165,7 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
sketch_size.push_back(reduced[i].size);
}
}
// turn the size into CSC indptr
// Turn the size into CSC indptr
std::vector<bst_idx_t> &sketches_scan = *p_sketches_scan;
sketches_scan.resize((n_columns + 1) * world, 0);
size_t beg_scan = rank * (n_columns + 1); // starting storage for current worker.
@@ -174,7 +174,10 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
// Gather all column pointers
auto rc =
collective::GlobalSum(ctx, info, linalg::MakeVec(sketches_scan.data(), sketches_scan.size()));
collective::SafeColl(rc);
if (!rc.OK()) {
collective::SafeColl(collective::Fail("Failed to get sketch scan.", std::move(rc)));
}
for (int32_t i = 0; i < world; ++i) {
size_t back = (i + 1) * (n_columns + 1) - 1;
auto n_entries = sketches_scan.at(back);
@@ -206,7 +209,9 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
ctx, info,
linalg::MakeVec(reinterpret_cast<float *>(global_sketches.data()),
global_sketches.size() * sizeof(typename WQSketch::Entry) / sizeof(float)));
collective::SafeColl(rc);
if (!rc.OK()) {
collective::SafeColl(collective::Fail("Failed to get sketch.", std::move(rc)));
}
}
template <typename WQSketch>
@@ -260,7 +265,7 @@ void SketchContainerImpl<WQSketch>::AllreduceCategories(Context const* ctx, Meta
rc = collective::GlobalSum(ctx, info,
linalg::MakeVec(global_categories.data(), global_categories.size()));
QuantileAllreduce<float> allreduce_result{global_categories, global_worker_ptr, global_feat_ptrs,
categories_.size()};
static_cast<bst_feature_t>(categories_.size())};
ParallelFor(categories_.size(), n_threads_, [&](auto fidx) {
if (!IsCat(feature_types_, fidx)) {
return;
@@ -285,8 +290,9 @@ void SketchContainerImpl<WQSketch>::AllReduce(
std::vector<typename WQSketch::SummaryContainer> *p_reduced, std::vector<int32_t> *p_num_cuts) {
monitor_.Start(__func__);
size_t n_columns = sketches_.size();
collective::Allreduce<collective::Operation::kMax>(&n_columns, 1);
bst_feature_t n_columns = sketches_.size();
auto rc = collective::Allreduce(ctx, &n_columns, collective::Op::kMax);
collective::SafeColl(rc);
CHECK_EQ(n_columns, sketches_.size()) << "Number of columns differs across workers";
AllreduceCategories(ctx, info);
@@ -300,8 +306,8 @@ void SketchContainerImpl<WQSketch>::AllReduce(
// Prune the intermediate num cuts for synchronization.
std::vector<bst_idx_t> global_column_size(columns_size_);
auto rc = collective::GlobalSum(
ctx, info, linalg::MakeVec(global_column_size.data(), global_column_size.size()));
rc = collective::GlobalSum(ctx, info,
linalg::MakeVec(global_column_size.data(), global_column_size.size()));
collective::SafeColl(rc);
ParallelFor(sketches_.size(), n_threads_, [&](size_t i) {

View File

@@ -12,7 +12,8 @@
#include <numeric> // for partial_sum
#include <utility>
#include "../collective/communicator-inl.cuh"
#include "../collective/allgather.h"
#include "../collective/allreduce.h"
#include "categorical.h"
#include "common.h"
#include "device_helpers.cuh"
@@ -499,7 +500,7 @@ void SketchContainer::FixError() {
});
}
void SketchContainer::AllReduce(Context const*, bool is_column_split) {
void SketchContainer::AllReduce(Context const* ctx, bool is_column_split) {
dh::safe_cuda(cudaSetDevice(device_.ordinal));
auto world = collective::GetWorldSize();
if (world == 1 || is_column_split) {
@@ -508,16 +509,18 @@ void SketchContainer::AllReduce(Context const*, bool is_column_split) {
timer_.Start(__func__);
// Reduce the overhead on syncing.
size_t global_sum_rows = num_rows_;
collective::Allreduce<collective::Operation::kSum>(&global_sum_rows, 1);
size_t intermediate_num_cuts =
bst_idx_t global_sum_rows = num_rows_;
auto rc = collective::Allreduce(ctx, linalg::MakeVec(&global_sum_rows, 1), collective::Op::kSum);
SafeColl(rc);
bst_idx_t intermediate_num_cuts =
std::min(global_sum_rows, static_cast<size_t>(num_bins_ * kFactor));
this->Prune(intermediate_num_cuts);
auto d_columns_ptr = this->columns_ptr_.ConstDeviceSpan();
CHECK_EQ(d_columns_ptr.size(), num_columns_ + 1);
size_t n = d_columns_ptr.size();
collective::Allreduce<collective::Operation::kMax>(&n, 1);
rc = collective::Allreduce(ctx, linalg::MakeVec(&n, 1), collective::Op::kMax);
SafeColl(rc);
CHECK_EQ(n, d_columns_ptr.size()) << "Number of columns differs across workers";
// Get the columns ptr from all workers
@@ -527,18 +530,25 @@ void SketchContainer::AllReduce(Context const*, bool is_column_split) {
auto offset = rank * d_columns_ptr.size();
thrust::copy(thrust::device, d_columns_ptr.data(), d_columns_ptr.data() + d_columns_ptr.size(),
gathered_ptrs.begin() + offset);
collective::AllReduce<collective::Operation::kSum>(device_.ordinal, gathered_ptrs.data().get(),
gathered_ptrs.size());
rc = collective::Allreduce(
ctx, linalg::MakeVec(gathered_ptrs.data().get(), gathered_ptrs.size(), ctx->Device()),
collective::Op::kSum);
SafeColl(rc);
// Get the data from all workers.
std::vector<size_t> recv_lengths;
dh::caching_device_vector<char> recvbuf;
collective::AllGatherV(device_.ordinal, this->Current().data().get(),
dh::ToSpan(this->Current()).size_bytes(), &recv_lengths, &recvbuf);
collective::Synchronize(device_.ordinal);
std::vector<std::int64_t> recv_lengths;
HostDeviceVector<std::int8_t> recvbuf;
rc = collective::AllgatherV(
ctx, linalg::MakeVec(this->Current().data().get(), this->Current().size(), device_),
&recv_lengths, &recvbuf);
collective::SafeColl(rc);
for (std::size_t i = 0; i < recv_lengths.size() - 1; ++i) {
recv_lengths[i] = recv_lengths[i + 1] - recv_lengths[i];
}
recv_lengths.resize(recv_lengths.size() - 1);
// Segment the received data.
auto s_recvbuf = dh::ToSpan(recvbuf);
auto s_recvbuf = recvbuf.DeviceSpan();
std::vector<Span<SketchEntry>> allworkers;
offset = 0;
for (int32_t i = 0; i < world; ++i) {

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2015-2020, XGBoost Contributors
* Copyright 2015-2024, XGBoost Contributors
* \file random.h
* \brief Utility related to random.
* \author Tianqi Chen
@@ -19,11 +19,13 @@
#include <utility>
#include <vector>
#include "../collective/broadcast.h" // for Broadcast
#include "../collective/communicator-inl.h"
#include "algorithm.h" // ArgSort
#include "common.h"
#include "xgboost/context.h" // Context
#include "xgboost/host_device_vector.h"
#include "xgboost/linalg.h"
namespace xgboost::common {
/*!
@@ -227,9 +229,10 @@ class ColumnSampler {
}
};
inline auto MakeColumnSampler(Context const*) {
inline auto MakeColumnSampler(Context const* ctx) {
std::uint32_t seed = common::GlobalRandomEngine()();
collective::Broadcast(&seed, sizeof(seed), 0);
auto rc = collective::Broadcast(ctx, linalg::MakeVec(&seed, 1), 0);
collective::SafeColl(rc);
auto cs = std::make_shared<common::ColumnSampler>(seed);
return cs;
}

View File

@@ -615,7 +615,12 @@ auto DispatchDType(ArrayInterfaceHandler::Type dtype, Fn dispatch) {
case ArrayInterfaceHandler::kF16: {
using T = long double;
CHECK(sizeof(T) == 16) << error::NoF128();
return dispatch(T{});
// Avoid invalid type.
if constexpr (sizeof(T) == 16) {
return dispatch(T{});
} else {
return dispatch(double{});
}
}
case ArrayInterfaceHandler::kI1: {
return dispatch(std::int8_t{});

View File

@@ -18,7 +18,8 @@
#include <type_traits> // for remove_pointer_t, remove_reference
#include "../collective/communicator-inl.h" // for GetRank, GetWorldSize, Allreduce, IsFederated
#include "../collective/communicator.h" // for Operation
#include "../collective/allgather.h"
#include "../collective/allreduce.h"
#include "../common/algorithm.h" // for StableSort
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
#include "../common/error_msg.h" // for GroupSize, GroupWeight, InfInData
@@ -601,41 +602,42 @@ void MetaInfo::GetInfo(char const* key, bst_ulong* out_len, DataType dtype,
}
void MetaInfo::SetFeatureInfo(const char* key, const char **info, const bst_ulong size) {
if (size != 0 && this->num_col_ != 0 && !IsColumnSplit()) {
bool is_col_split = this->IsColumnSplit();
if (size != 0 && this->num_col_ != 0 && !is_col_split) {
CHECK_EQ(size, this->num_col_) << "Length of " << key << " must be equal to number of columns.";
CHECK(info);
}
if (!std::strcmp(key, "feature_type")) {
feature_type_names.clear();
for (size_t i = 0; i < size; ++i) {
auto elem = info[i];
feature_type_names.emplace_back(elem);
}
if (IsColumnSplit()) {
feature_type_names = collective::AllgatherStrings(feature_type_names);
CHECK_EQ(feature_type_names.size(), num_col_)
// Gather column info when data is split by columns
auto gather_columns = [is_col_split, key, n_columns = this->num_col_](auto const& inputs) {
if (is_col_split) {
std::remove_const_t<std::remove_reference_t<decltype(inputs)>> result;
auto rc = collective::AllgatherStrings(inputs, &result);
collective::SafeColl(rc);
CHECK_EQ(result.size(), n_columns)
<< "Length of " << key << " must be equal to number of columns.";
return result;
}
return inputs;
};
if (StringView{key} == "feature_type") { // NOLINT
this->feature_type_names.clear();
std::copy(info, info + size, std::back_inserter(feature_type_names));
feature_type_names = gather_columns(feature_type_names);
auto& h_feature_types = feature_types.HostVector();
this->has_categorical_ = LoadFeatureType(feature_type_names, &h_feature_types);
} else if (!std::strcmp(key, "feature_name")) {
if (IsColumnSplit()) {
std::vector<std::string> local_feature_names{};
} else if (StringView{key} == "feature_name") { // NOLINT
feature_names.clear();
if (is_col_split) {
auto const rank = collective::GetRank();
for (std::size_t i = 0; i < size; ++i) {
auto elem = std::to_string(rank) + "." + info[i];
local_feature_names.emplace_back(elem);
}
feature_names = collective::AllgatherStrings(local_feature_names);
CHECK_EQ(feature_names.size(), num_col_)
<< "Length of " << key << " must be equal to number of columns.";
std::transform(info, info + size, std::back_inserter(feature_names),
[rank](char const* elem) { return std::to_string(rank) + "." + elem; });
} else {
feature_names.clear();
for (size_t i = 0; i < size; ++i) {
feature_names.emplace_back(info[i]);
}
std::copy(info, info + size, std::back_inserter(feature_names));
}
feature_names = gather_columns(feature_names);
} else {
LOG(FATAL) << "Unknown feature info name: " << key;
}
@@ -728,12 +730,10 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
}
}
void MetaInfo::SynchronizeNumberOfColumns(Context const*) {
if (IsColumnSplit()) {
collective::Allreduce<collective::Operation::kSum>(&num_col_, 1);
} else {
collective::Allreduce<collective::Operation::kMax>(&num_col_, 1);
}
void MetaInfo::SynchronizeNumberOfColumns(Context const* ctx) {
auto op = IsColumnSplit() ? collective::Op::kSum : collective::Op::kMax;
auto rc = collective::Allreduce(ctx, linalg::MakeVec(&num_col_, 1), op);
collective::SafeColl(rc);
}
namespace {

View File

@@ -9,11 +9,12 @@
#include <type_traits> // for underlying_type_t
#include <vector> // for vector
#include "../collective/communicator-inl.h"
#include "../common/categorical.h" // common::IsCat
#include "../collective/allreduce.h" // for Allreduce
#include "../collective/communicator-inl.h" // for IsDistributed
#include "../common/categorical.h" // common::IsCat
#include "../common/column_matrix.h"
#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter.
#include "batch_utils.h" // for RegenGHist
#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter.
#include "batch_utils.h" // for RegenGHist
#include "gradient_index.h"
#include "proxy_dmatrix.h"
#include "simple_batch_iterator.h"
@@ -95,13 +96,13 @@ void GetCutsFromRef(Context const* ctx, std::shared_ptr<DMatrix> ref, bst_featur
namespace {
// Synchronize feature type in case of empty DMatrix
void SyncFeatureType(Context const*, std::vector<FeatureType>* p_h_ft) {
void SyncFeatureType(Context const* ctx, std::vector<FeatureType>* p_h_ft) {
if (!collective::IsDistributed()) {
return;
}
auto& h_ft = *p_h_ft;
auto n_ft = h_ft.size();
collective::Allreduce<collective::Operation::kMax>(&n_ft, 1);
bst_idx_t n_ft = h_ft.size();
collective::SafeColl(collective::Allreduce(ctx, &n_ft, collective::Op::kMax));
if (!h_ft.empty()) {
// Check correct size if this is not an empty DMatrix.
CHECK_EQ(h_ft.size(), n_ft);
@@ -109,7 +110,8 @@ void SyncFeatureType(Context const*, std::vector<FeatureType>* p_h_ft) {
if (n_ft > 0) {
h_ft.resize(n_ft);
auto ptr = reinterpret_cast<std::underlying_type_t<FeatureType>*>(h_ft.data());
collective::Allreduce<collective::Operation::kMax>(ptr, h_ft.size());
collective::SafeColl(
collective::Allreduce(ctx, linalg::MakeVec(ptr, h_ft.size()), collective::Op::kMax));
}
}
} // anonymous namespace
@@ -175,7 +177,7 @@ void IterativeDMatrix::InitFromCPU(Context const* ctx, BatchParam const& p,
// We use do while here as the first batch is fetched in ctor
if (n_features == 0) {
n_features = num_cols();
collective::Allreduce<collective::Operation::kMax>(&n_features, 1);
collective::SafeColl(collective::Allreduce(ctx, &n_features, collective::Op::kMax));
column_sizes.clear();
column_sizes.resize(n_features, 0);
info_.num_col_ = n_features;

View File

@@ -1,20 +1,18 @@
/**
* Copyright 2020-2023, XGBoost contributors
* Copyright 2020-2024, XGBoost contributors
*/
#include <algorithm>
#include <memory>
#include <type_traits>
#include "../collective/allreduce.h"
#include "../common/hist_util.cuh"
#include "batch_utils.h" // for RegenGHist
#include "device_adapter.cuh"
#include "ellpack_page.cuh"
#include "gradient_index.h"
#include "iterative_dmatrix.h"
#include "proxy_dmatrix.cuh"
#include "proxy_dmatrix.h"
#include "simple_batch_iterator.h"
#include "sparse_page_source.h"
namespace xgboost::data {
void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
@@ -63,7 +61,8 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
dh::safe_cuda(cudaSetDevice(get_device().ordinal));
if (cols == 0) {
cols = num_cols();
collective::Allreduce<collective::Operation::kMax>(&cols, 1);
auto rc = collective::Allreduce(ctx, linalg::MakeVec(&cols, 1), collective::Op::kMax);
SafeColl(rc);
this->info_.num_col_ = cols;
} else {
CHECK_EQ(cols, num_cols()) << "Inconsistent number of columns.";

View File

@@ -171,12 +171,13 @@ decltype(auto) HostAdapterDispatch(DMatrixProxy const* proxy, Fn fn, bool* type_
} else {
LOG(FATAL) << "Unknown type: " << proxy->Adapter().type().name();
}
if constexpr (get_value) {
return std::invoke_result_t<
Fn, decltype(std::declval<std::shared_ptr<ArrayAdapter>>()->Value())>();
} else {
return std::invoke_result_t<Fn, decltype(std::declval<std::shared_ptr<ArrayAdapter>>())>();
}
}
if constexpr (get_value) {
return std::invoke_result_t<Fn,
decltype(std::declval<std::shared_ptr<ArrayAdapter>>()->Value())>();
} else {
return std::invoke_result_t<Fn, decltype(std::declval<std::shared_ptr<ArrayAdapter>>())>();
}
}

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2014~2023, XGBoost Contributors
* Copyright 2014-2024, XGBoost Contributors
* \file simple_dmatrix.cc
* \brief the input data structure for gradient boosting
* \author Tianqi Chen
@@ -13,6 +13,7 @@
#include <vector>
#include "../collective/communicator-inl.h" // for GetWorldSize, GetRank, Allgather
#include "../collective/allgather.h"
#include "../common/error_msg.h" // for InconsistentMaxBin
#include "./simple_batch_iterator.h"
#include "adapter.h"
@@ -76,8 +77,11 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
void SimpleDMatrix::ReindexFeatures(Context const* ctx) {
if (info_.IsColumnSplit() && collective::GetWorldSize() > 1) {
auto const cols = collective::Allgather(info_.num_col_);
auto const offset = std::accumulate(cols.cbegin(), cols.cbegin() + collective::GetRank(), 0ul);
std::vector<std::uint64_t> buffer(collective::GetWorldSize());
buffer[collective::GetRank()] = info_.num_col_;
auto rc = collective::Allgather(ctx, linalg::MakeVec(buffer.data(), buffer.size()));
SafeColl(rc);
auto offset = std::accumulate(buffer.cbegin(), buffer.cbegin() + collective::GetRank(), 0);
if (offset == 0) {
return;
}

View File

@@ -11,6 +11,7 @@
#include <future> // for async
#include <memory> // for unique_ptr
#include <mutex> // for mutex
#include <numeric> // for partial_sum
#include <string> // for string
#include <utility> // for pair, move
#include <vector> // for vector

View File

@@ -35,7 +35,6 @@
#include "collective/aggregator.h" // for ApplyWithLabels
#include "collective/communicator-inl.h" // for Allreduce, Broadcast, GetRank, IsDistributed
#include "collective/communicator.h" // for Operation
#include "common/api_entry.h" // for XGBAPIThreadLocalEntry
#include "common/charconv.h" // for to_chars, to_chars_result, NumericLimits, from_...
#include "common/common.h" // for ToString, Split
@@ -208,7 +207,7 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
return dmlc::Parameter<LearnerModelParamLegacy>::UpdateAllowUnknown(kwargs);
}
// sanity check
void Validate(Context const*) {
void Validate(Context const* ctx) {
if (!collective::IsDistributed()) {
return;
}
@@ -229,7 +228,8 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
std::array<std::int32_t, 6> sync;
std::copy(data.cbegin(), data.cend(), sync.begin());
collective::Broadcast(sync.data(), sync.size(), 0);
auto rc = collective::Broadcast(ctx, linalg::MakeVec(sync.data(), sync.size()), 0);
collective::SafeColl(rc);
CHECK(std::equal(data.cbegin(), data.cend(), sync.cbegin()))
<< "Different model parameter across workers.";
}
@@ -754,7 +754,9 @@ class LearnerConfiguration : public Learner {
num_feature = std::max(num_feature, static_cast<uint32_t>(num_col));
}
collective::Allreduce<collective::Operation::kMax>(&num_feature, 1);
auto rc =
collective::Allreduce(&ctx_, linalg::MakeVec(&num_feature, 1), collective::Op::kMax);
collective::SafeColl(rc);
if (num_feature > mparam_.num_feature) {
mparam_.num_feature = num_feature;
}

View File

@@ -1,14 +1,13 @@
/*!
* Copyright 2015-2018 by Contributors
/**
* Copyright 2015-2024, XGBoost Contributors
* \file logging.cc
* \brief Implementation of loggers.
* \author Tianqi Chen
*/
#include <iostream>
#include "xgboost/parameter.h"
#include "xgboost/logging.h"
#include <string> // for string
#include "collective/communicator-inl.h"
#if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0

View File

@@ -264,9 +264,14 @@ class EvalAUC : public MetricNoCache {
info.weights_.SetDevice(ctx_->Device());
}
// We use the global size to handle empty dataset.
std::array<size_t, 2> meta{info.labels.Size(), preds.Size()};
std::array<bst_idx_t, 2> meta{info.labels.Size(), preds.Size()};
if (!info.IsVerticalFederated()) {
collective::Allreduce<collective::Operation::kMax>(meta.data(), meta.size());
auto rc = collective::Allreduce(
ctx_,
linalg::MakeTensorView(DeviceOrd::CPU(), common::Span{meta.data(), meta.size()},
meta.size()),
collective::Op::kMax);
collective::SafeColl(rc);
}
if (meta[0] == 0) {
// Empty across all workers, which is not supported.

View File

@@ -1,9 +1,9 @@
/**
* Copyright 2021-2023 by XGBoost Contributors
* Copyright 2021-2024, XGBoost Contributors
*/
#include <thrust/copy.h> // for copy
#include <thrust/scan.h>
#include <algorithm>
#include <cassert>
#include <cub/cub.cuh> // NOLINT
#include <limits>
@@ -11,7 +11,7 @@
#include <tuple>
#include <utility>
#include "../collective/communicator-inl.cuh"
#include "../collective/allreduce.h"
#include "../common/algorithm.cuh" // SegmentedArgSort
#include "../common/optional_weight.h" // OptionalWeights
#include "../common/threading_utils.cuh" // UnravelTrapeziodIdx,SegmentedTrapezoidThreads
@@ -201,13 +201,16 @@ void Transpose(common::Span<float const> in, common::Span<float> out, size_t m,
});
}
double ScaleClasses(Context const *ctx, common::Span<double> results,
double ScaleClasses(Context const *ctx, bool is_column_split, common::Span<double> results,
common::Span<double> local_area, common::Span<double> tp,
common::Span<double> auc, size_t n_classes) {
if (collective::IsDistributed()) {
int32_t device = dh::CurrentDevice();
// With vertical federated learning, only the root has label, other parties are not
// evaluation metrics.
if (collective::IsDistributed() && !(is_column_split && collective::IsFederated())) {
std::int32_t device = dh::CurrentDevice();
CHECK_EQ(dh::CudaGetPointerDevice(results.data()), device);
collective::AllReduce<collective::Operation::kSum>(device, results.data(), results.size());
auto rc = collective::Allreduce(
ctx, linalg::MakeVec(results.data(), results.size(), ctx->Device()), collective::Op::kSum);
}
auto reduce_in = dh::MakeTransformIterator<Pair>(
thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {
@@ -334,7 +337,7 @@ double GPUMultiClassAUCOVR(Context const *ctx, MetaInfo const &info,
auto local_area = d_results.subspan(0, n_classes);
auto tp = d_results.subspan(2 * n_classes, n_classes);
auto auc = d_results.subspan(3 * n_classes, n_classes);
return ScaleClasses(ctx, d_results, local_area, tp, auc, n_classes);
return ScaleClasses(ctx, info.IsColumnSplit(), d_results, local_area, tp, auc, n_classes);
}
/**
@@ -438,7 +441,7 @@ double GPUMultiClassAUCOVR(Context const *ctx, MetaInfo const &info,
tp[c] = 1.0f;
}
});
return ScaleClasses(ctx, d_results, local_area, tp, auc, n_classes);
return ScaleClasses(ctx, info.IsColumnSplit(), d_results, local_area, tp, auc, n_classes);
}
void MultiClassSortedIdx(Context const *ctx, common::Span<float const> predts,
@@ -835,7 +838,7 @@ std::pair<double, std::uint32_t> GPURankingPRAUC(Context const *ctx,
InitCacheOnce<false>(predts, p_cache);
dh::device_vector<bst_group_t> group_ptr(info.group_ptr_.size());
thrust::copy(info.group_ptr_.begin(), info.group_ptr_.end(), group_ptr.begin());
thrust::copy(info.group_ptr_.begin(), info.group_ptr_.end(), group_ptr.begin()); // NOLINT
auto d_group_ptr = dh::ToSpan(group_ptr);
CHECK_GE(info.group_ptr_.size(), 1) << "Must have at least 1 query group for LTR.";
size_t n_groups = info.group_ptr_.size() - 1;

View File

@@ -1,18 +1,14 @@
/**
* Copyright 2021-2023, XGBoost Contributors
* Copyright 2021-2024, XGBoost Contributors
*/
#ifndef XGBOOST_METRIC_AUC_H_
#define XGBOOST_METRIC_AUC_H_
#include <array>
#include <cmath>
#include <limits>
#include <memory>
#include <tuple>
#include <utility>
#include "../collective/communicator-inl.h"
#include "../common/common.h"
#include "../common/threading_utils.h"
#include "xgboost/base.h"
#include "xgboost/data.h"
#include "xgboost/metric.h"

View File

@@ -9,8 +9,6 @@
#include <vector> // std::vector
#include "../collective/aggregator.h"
#include "../collective/communicator-inl.h"
#include "../common/common.h"
#include "xgboost/base.h" // bst_node_t
#include "xgboost/context.h" // Context
#include "xgboost/data.h" // MetaInfo
@@ -42,7 +40,7 @@ inline void UpdateLeafValues(Context const* ctx, std::vector<float>* p_quantiles
auto& quantiles = *p_quantiles;
auto const& h_node_idx = nidx;
size_t n_leaf = collective::GlobalMax(ctx, info, h_node_idx.size());
bst_idx_t n_leaf = collective::GlobalMax(ctx, info, static_cast<bst_idx_t>(h_node_idx.size()));
CHECK(quantiles.empty() || quantiles.size() == n_leaf);
if (quantiles.empty()) {
quantiles.resize(n_leaf, std::numeric_limits<float>::quiet_NaN());

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2017-2023 by XGBoost Contributors
* Copyright 2017-2024, XGBoost Contributors
*/
#include <algorithm> // for max, fill, min
#include <any> // for any, any_cast
@@ -12,7 +12,7 @@
#include <vector> // for vector
#include "../collective/communicator-inl.h" // for Allreduce, IsDistributed
#include "../collective/communicator.h" // for Operation
#include "../collective/allreduce.h"
#include "../common/bitfield.h" // for RBitField8
#include "../common/categorical.h" // for IsCat, Decision
#include "../common/common.h" // for DivRoundUp
@@ -461,11 +461,17 @@ class ColumnSplitHelper {
return tree_offsets_[tree_index] * n_rows_ + row_id * tree_sizes_[tree_index] + node_id;
}
void AllreduceBitVectors(Context const*) {
collective::Allreduce<collective::Operation::kBitwiseOR>(decision_storage_.data(),
decision_storage_.size());
collective::Allreduce<collective::Operation::kBitwiseAND>(missing_storage_.data(),
missing_storage_.size());
void AllreduceBitVectors(Context const *ctx) {
auto rc = collective::Success() << [&] {
return collective::Allreduce(
ctx, linalg::MakeVec(decision_storage_.data(), decision_storage_.size()),
collective::Op::kBitwiseOR);
} << [&] {
return collective::Allreduce(
ctx, linalg::MakeVec(missing_storage_.data(), missing_storage_.size()),
collective::Op::kBitwiseAND);
};
collective::SafeColl(rc);
}
void MaskOneTree(RegTree::FVec const &feat, std::size_t tree_id, std::size_t row_id) {

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2017-2023 by XGBoost Contributors
* Copyright 2017-2024, XGBoost Contributors
*/
#include <GPUTreeShap/gpu_treeshap.h>
#include <thrust/copy.h>
@@ -11,7 +11,7 @@
#include <any> // for any, any_cast
#include <memory>
#include "../collective/communicator-inl.cuh"
#include "../collective/allreduce.h"
#include "../common/bitfield.h"
#include "../common/categorical.h"
#include "../common/common.h"
@@ -817,10 +817,18 @@ class ColumnSplitHelper {
void AllReduceBitVectors(dh::caching_device_vector<BitType>* decision_storage,
dh::caching_device_vector<BitType>* missing_storage) const {
collective::AllReduce<collective::Operation::kBitwiseOR>(
ctx_->Ordinal(), decision_storage->data().get(), decision_storage->size());
collective::AllReduce<collective::Operation::kBitwiseAND>(
ctx_->Ordinal(), missing_storage->data().get(), missing_storage->size());
auto rc = collective::Success() << [&] {
return collective::Allreduce(
ctx_,
linalg::MakeVec(decision_storage->data().get(), decision_storage->size(), ctx_->Device()),
collective::Op::kBitwiseOR);
} << [&] {
return collective::Allreduce(
ctx_,
linalg::MakeVec(missing_storage->data().get(), missing_storage->size(), ctx_->Device()),
collective::Op::kBitwiseAND);
};
collective::SafeColl(rc);
}
void ResizeBitVectors(dh::caching_device_vector<BitType>* decision_storage,

View File

@@ -1,24 +1,28 @@
/**
* Copyright 2021-2023 XGBoost contributors
* Copyright 2021-2023, XGBoost contributors
* \file common_row_partitioner.h
* \brief Common partitioner logic for hist and approx methods.
*/
#ifndef XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_
#define XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_
#include <algorithm> // std::all_of
#include <cinttypes> // std::uint32_t
#include <limits> // std::numeric_limits
#include <vector>
#include <algorithm> // for all_of, fill
#include <cinttypes> // for uint32_t
#include <limits> // for numeric_limits
#include <vector> // for vector
#include "../collective/communicator-inl.h"
#include "../common/linalg_op.h" // cbegin
#include "../common/numeric.h" // Iota
#include "../common/partition_builder.h"
#include "hist/expand_entry.h" // CPUExpandEntry
#include "xgboost/base.h"
#include "xgboost/context.h" // Context
#include "xgboost/linalg.h" // TensorView
#include "../collective/allreduce.h" // for Allreduce
#include "../common/bitfield.h" // for RBitField8
#include "../common/linalg_op.h" // for cbegin
#include "../common/numeric.h" // for Iota
#include "../common/partition_builder.h" // for PartitionBuilder
#include "../common/row_set.h" // for RowSetCollection
#include "../common/threading_utils.h" // for ParallelFor2d
#include "xgboost/base.h" // for bst_row_t
#include "xgboost/collective/result.h" // for Success, SafeColl
#include "xgboost/context.h" // for Context
#include "xgboost/linalg.h" // for TensorView
#include "xgboost/span.h" // for Span
namespace xgboost::tree {
@@ -39,7 +43,7 @@ class ColumnSplitHelper {
}
template <typename BinIdxType, bool any_missing, bool any_cat, typename ExpandEntry>
void Partition(common::BlockedSpace2d const& space, std::int32_t n_threads,
void Partition(Context const* ctx, common::BlockedSpace2d const& space, std::int32_t n_threads,
GHistIndexMatrix const& gmat, common::ColumnMatrix const& column_matrix,
std::vector<ExpandEntry> const& nodes,
std::vector<int32_t> const& split_conditions, RegTree const* p_tree) {
@@ -56,10 +60,12 @@ class ColumnSplitHelper {
});
// Then aggregate the bit vectors across all the workers.
collective::Allreduce<collective::Operation::kBitwiseOR>(decision_storage_.data(),
decision_storage_.size());
collective::Allreduce<collective::Operation::kBitwiseAND>(missing_storage_.data(),
missing_storage_.size());
auto rc = collective::Success() << [&] {
return collective::Allreduce(ctx, &decision_storage_, collective::Op::kBitwiseOR);
} << [&] {
return collective::Allreduce(ctx, &missing_storage_, collective::Op::kBitwiseAND);
};
collective::SafeColl(rc);
// Finally use the bit vectors to partition the rows.
common::ParallelFor2d(space, n_threads, [&](size_t node_in_set, common::Range1d r) {
@@ -220,7 +226,7 @@ class CommonRowPartitioner {
// Store results in intermediate buffers from partition_builder_
if (is_col_split_) {
column_split_helper_.Partition<BinIdxType, any_missing, any_cat>(
space, ctx->Threads(), gmat, column_matrix, nodes, split_conditions, p_tree);
ctx, space, ctx->Threads(), gmat, column_matrix, nodes, split_conditions, p_tree);
} else {
common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) {
size_t begin = r.begin();

View File

@@ -47,8 +47,10 @@ void FitStump(Context const* ctx, MetaInfo const& info,
thrust::reduce_by_key(policy, key_it, key_it + gpair.Size(), grad_it,
thrust::make_discard_iterator(), dh::tbegin(d_sum.Values()));
collective::GlobalSum(info, ctx->Device(), reinterpret_cast<double*>(d_sum.Values().data()),
d_sum.Size() * 2);
auto rc = collective::GlobalSum(ctx, info,
linalg::MakeVec(reinterpret_cast<double*>(d_sum.Values().data()),
d_sum.Size() * 2, ctx->Device()));
SafeColl(rc);
thrust::for_each_n(policy, thrust::make_counting_iterator(0ul), n_targets,
[=] XGBOOST_DEVICE(std::size_t i) mutable {

View File

@@ -1,11 +1,11 @@
/**
* Copyright 2020-2023, XGBoost Contributors
* Copyright 2020-2024, XGBoost Contributors
*/
#include <algorithm> // std::max
#include <vector>
#include <limits>
#include "../../collective/communicator-inl.cuh"
#include "../../collective/allgather.h"
#include "../../common/categorical.h"
#include "../../data/ellpack_page.cuh"
#include "evaluate_splits.cuh"
@@ -413,8 +413,14 @@ void GPUHistEvaluator::EvaluateSplits(Context const *ctx, const std::vector<bst_
auto const world_size = collective::GetWorldSize();
dh::TemporaryArray<DeviceSplitCandidate> all_candidate_storage(out_splits.size() * world_size);
auto all_candidates = dh::ToSpan(all_candidate_storage);
collective::AllGather(device_.ordinal, out_splits.data(), all_candidates.data(),
out_splits.size() * sizeof(DeviceSplitCandidate));
auto current_rank =
all_candidates.subspan(collective::GetRank() * out_splits.size(), out_splits.size());
dh::safe_cuda(cudaMemcpyAsync(current_rank.data(), out_splits.data(),
out_splits.size() * sizeof(DeviceSplitCandidate),
cudaMemcpyDeviceToDevice));
auto rc = collective::Allgather(
ctx, linalg::MakeVec(all_candidates.data(), all_candidates.size(), ctx->Device()));
collective::SafeColl(rc);
// Reduce to get the best candidate from all workers.
dh::LaunchN(out_splits.size(), ctx->CUDACtx()->Stream(),

View File

@@ -12,6 +12,7 @@
#include <utility> // for move
#include <vector> // for vector
#include "../../collective/allgather.h"
#include "../../common/categorical.h" // for CatBitField
#include "../../common/hist_util.h" // for GHistRow, HistogramCuts
#include "../../common/linalg_op.h" // for cbegin, cend, begin
@@ -35,7 +36,7 @@ template <typename ExpandEntry>
std::enable_if_t<std::is_same_v<ExpandEntry, CPUExpandEntry> ||
std::is_same_v<ExpandEntry, MultiExpandEntry>,
std::vector<ExpandEntry>>
AllgatherColumnSplit(std::vector<ExpandEntry> const &entries) {
AllgatherColumnSplit(Context const *ctx, std::vector<ExpandEntry> const &entries) {
auto const n_entries = entries.size();
// First, gather all the primitive fields.
@@ -52,7 +53,7 @@ AllgatherColumnSplit(std::vector<ExpandEntry> const &entries) {
serialized_entries.emplace_back(std::move(out));
}
auto all_serialized = collective::VectorAllgatherV(serialized_entries);
auto all_serialized = collective::VectorAllgatherV(ctx, serialized_entries);
CHECK_GE(all_serialized.size(), local_entries.size());
std::vector<ExpandEntry> all_entries(all_serialized.size());
@@ -401,7 +402,7 @@ class HistEvaluator {
if (is_col_split_) {
// With column-wise data split, we gather the best splits from all the workers and update the
// expand entries accordingly.
auto all_entries = AllgatherColumnSplit(entries);
auto all_entries = AllgatherColumnSplit(ctx_, entries);
for (auto worker = 0; worker < collective::GetWorldSize(); ++worker) {
for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
entries[nidx_in_set].split.Update(
@@ -632,7 +633,7 @@ class HistMultiEvaluator {
if (is_col_split_) {
// With column-wise data split, we gather the best splits from all the workers and update the
// expand entries accordingly.
auto all_entries = AllgatherColumnSplit(entries);
auto all_entries = AllgatherColumnSplit(ctx_, entries);
for (auto worker = 0; worker < collective::GetWorldSize(); ++worker) {
for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
entries[nidx_in_set].split.Update(

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2021-2023 by XGBoost Contributors
* Copyright 2021-2024, XGBoost Contributors
*/
#ifndef XGBOOST_TREE_HIST_HISTOGRAM_H_
#define XGBOOST_TREE_HIST_HISTOGRAM_H_
@@ -7,26 +7,24 @@
#include <algorithm> // for max
#include <cstddef> // for size_t
#include <cstdint> // for int32_t
#include <functional> // for function
#include <utility> // for move
#include <vector> // for vector
#include "../../collective/communicator-inl.h" // for Allreduce
#include "../../collective/communicator.h" // for Operation
#include "../../common/hist_util.h" // for GHistRow, ParallelGHi...
#include "../../common/row_set.h" // for RowSetCollection
#include "../../common/threading_utils.h" // for ParallelFor2d, Range1d, BlockedSpace2d
#include "../../data/gradient_index.h" // for GHistIndexMatrix
#include "expand_entry.h" // for MultiExpandEntry, CPUExpandEntry
#include "hist_cache.h" // for BoundedHistCollection
#include "param.h" // for HistMakerTrainParam
#include "xgboost/base.h" // for bst_node_t, bst_target_t, bst_bin_t
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for BatchIterator, BatchSet
#include "xgboost/linalg.h" // for MatrixView, All, Vect...
#include "xgboost/logging.h" // for CHECK_GE
#include "xgboost/span.h" // for Span
#include "xgboost/tree_model.h" // for RegTree
#include "../../collective/allreduce.h" // for Allreduce
#include "../../common/hist_util.h" // for GHistRow, ParallelGHi...
#include "../../common/row_set.h" // for RowSetCollection
#include "../../common/threading_utils.h" // for ParallelFor2d, Range1d, BlockedSpace2d
#include "../../data/gradient_index.h" // for GHistIndexMatrix
#include "expand_entry.h" // for MultiExpandEntry, CPUExpandEntry
#include "hist_cache.h" // for BoundedHistCollection
#include "param.h" // for HistMakerTrainParam
#include "xgboost/base.h" // for bst_node_t, bst_target_t, bst_bin_t
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for BatchIterator, BatchSet
#include "xgboost/linalg.h" // for MatrixView, All, Vect...
#include "xgboost/logging.h" // for CHECK_GE
#include "xgboost/span.h" // for Span
#include "xgboost/tree_model.h" // for RegTree
namespace xgboost::tree {
/**
@@ -171,7 +169,7 @@ class HistogramBuilder {
}
}
void SyncHistogram(Context const *, RegTree const *p_tree,
void SyncHistogram(Context const *ctx, RegTree const *p_tree,
std::vector<bst_node_t> const &nodes_to_build,
std::vector<bst_node_t> const &nodes_to_trick) {
auto n_total_bins = buffer_.TotalBins();
@@ -186,8 +184,10 @@ class HistogramBuilder {
CHECK(!nodes_to_build.empty());
auto first_nidx = nodes_to_build.front();
std::size_t n = n_total_bins * nodes_to_build.size() * 2;
collective::Allreduce<collective::Operation::kSum>(
reinterpret_cast<double *>(this->hist_[first_nidx].data()), n);
auto rc = collective::Allreduce(
ctx, linalg::MakeVec(reinterpret_cast<double *>(this->hist_[first_nidx].data()), n),
collective::Op::kSum);
SafeColl(rc);
}
common::BlockedSpace2d const &subspace =

View File

@@ -1,18 +1,22 @@
/**
* Copyright 2021-2023, XGBoost Contributors
* Copyright 2021-2024, XGBoost Contributors
*/
#include "param.h"
#include <ios> // for binary
#include <string> // for string
#include "../../collective/communicator-inl.h" // for GetRank, Broadcast
#include "../../collective/broadcast.h" // for Broadcast
#include "../../collective/communicator-inl.h" // for GetRank
#include "xgboost/json.h" // for Object, Json
#include "xgboost/linalg.h" // for MakeVec
#include "xgboost/tree_model.h" // for RegTree
namespace xgboost::tree {
DMLC_REGISTER_PARAMETER(HistMakerTrainParam);
void HistMakerTrainParam::CheckTreesSynchronized(Context const*, RegTree const* local_tree) const {
void HistMakerTrainParam::CheckTreesSynchronized(Context const* ctx,
RegTree const* local_tree) const {
if (!this->debug_synchronize) {
return;
}
@@ -24,7 +28,15 @@ void HistMakerTrainParam::CheckTreesSynchronized(Context const*, RegTree const*
local_tree->SaveModel(&model);
}
Json::Dump(model, &s_model, std::ios::binary);
collective::Broadcast(&s_model, 0);
auto nchars{static_cast<std::int64_t>(s_model.size())};
auto rc = collective::Success() << [&] {
return collective::Broadcast(ctx, linalg::MakeVec(&nchars, 1), 0);
} << [&] {
s_model.resize(nchars);
return collective::Broadcast(ctx, linalg::MakeVec(s_model.data(), s_model.size()), 0);
};
collective::SafeColl(rc);
RegTree ref_tree{}; // rank 0 tree
auto j_ref_tree = Json::Load(StringView{s_model}, std::ios::binary);

View File

@@ -13,7 +13,7 @@
#include <vector>
#include "../collective/aggregator.h"
#include "../collective/aggregator.cuh"
#include "../collective/broadcast.h"
#include "../common/bitfield.h"
#include "../common/categorical.h"
#include "../common/cuda_context.cuh" // CUDAContext
@@ -410,11 +410,16 @@ struct GPUHistMakerDevice {
}
});
collective::AllReduce<collective::Operation::kBitwiseOR>(
ctx_->Ordinal(), decision_storage.data().get(), decision_storage.size());
collective::AllReduce<collective::Operation::kBitwiseAND>(
ctx_->Ordinal(), missing_storage.data().get(), missing_storage.size());
collective::Synchronize(ctx_->Ordinal());
auto rc = collective::Success() << [&] {
return collective::Allreduce(
ctx_, linalg::MakeTensorView(ctx_, dh::ToSpan(decision_storage), decision_storage.size()),
collective::Op::kBitwiseOR);
} << [&] {
return collective::Allreduce(
ctx_, linalg::MakeTensorView(ctx_, dh::ToSpan(missing_storage), missing_storage.size()),
collective::Op::kBitwiseAND);
};
collective::SafeColl(rc);
row_partitioner->UpdatePositionBatch(
nidx, left_nidx, right_nidx, split_data,
@@ -611,8 +616,11 @@ struct GPUHistMakerDevice {
monitor.Start("AllReduce");
auto d_node_hist = hist.GetNodeHistogram(nidx).data();
using ReduceT = typename std::remove_pointer<decltype(d_node_hist)>::type::ValueT;
collective::GlobalSum(info_, ctx_->Device(), reinterpret_cast<ReduceT*>(d_node_hist),
page->Cuts().TotalBins() * 2 * num_histograms);
auto rc = collective::GlobalSum(
ctx_, info_,
linalg::MakeVec(reinterpret_cast<ReduceT*>(d_node_hist),
page->Cuts().TotalBins() * 2 * num_histograms, ctx_->Device()));
SafeColl(rc);
monitor.Stop("AllReduce");
}
@@ -860,7 +868,9 @@ class GPUHistMaker : public TreeUpdater {
// Synchronise the column sampling seed
uint32_t column_sampling_seed = common::GlobalRandom()();
collective::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);
auto rc = collective::Broadcast(
ctx_, linalg::MakeVec(&column_sampling_seed, sizeof(column_sampling_seed)), 0);
SafeColl(rc);
this->column_sampler_ = std::make_shared<common::ColumnSampler>(column_sampling_seed);
auto batch_param = BatchParam{param->max_bin, TrainParam::DftSparseThreshold()};
@@ -1001,9 +1011,7 @@ class GPUGlobalApproxMaker : public TreeUpdater {
monitor_.Start(__func__);
CHECK(ctx_->IsCUDA()) << error::InvalidCUDAOrdinal();
// Synchronise the column sampling seed
uint32_t column_sampling_seed = common::GlobalRandom()();
collective::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);
this->column_sampler_ = std::make_shared<common::ColumnSampler>(column_sampling_seed);
p_last_fmat_ = p_fmat;

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2014-2023 by XGBoost Contributors
* Copyright 2014-2024, XGBoost Contributors
* \file updater_refresh.cc
* \brief refresh the statistics and leaf value on the tree on the dataset
* \author Tianqi Chen
@@ -9,8 +9,7 @@
#include <limits>
#include <vector>
#include "../collective/communicator-inl.h"
#include "../common/io.h"
#include "../collective/allreduce.h"
#include "../common/threading_utils.h"
#include "../predictor/predict_fn.h"
#include "./param.h"
@@ -39,7 +38,7 @@ class TreeRefresher : public TreeUpdater {
}
CHECK_EQ(gpair->Shape(1), 1) << MTNotImplemented();
const std::vector<GradientPair> &gpair_h = gpair->Data()->ConstHostVector();
// thread temporal space
// Thread local variables.
std::vector<std::vector<GradStats> > stemp;
std::vector<RegTree::FVec> fvec_temp;
// setup temp space for each thread
@@ -61,9 +60,8 @@ class TreeRefresher : public TreeUpdater {
});
}
exc.Rethrow();
// if it is C++11, use lazy evaluation for Allreduce,
// to gain speedup in recovery
auto lazy_get_stats = [&]() {
auto get_stats = [&]() {
const MetaInfo &info = p_fmat->Info();
// start accumulating statistics
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
@@ -93,12 +91,17 @@ class TreeRefresher : public TreeUpdater {
}
});
};
lazy_get_stats();
collective::Allreduce<collective::Operation::kSum>(&dmlc::BeginPtr(stemp[0])->sum_grad,
stemp[0].size() * 2);
int offset = 0;
get_stats();
// Synchronize the aggregated result.
auto &sum_grad = stemp[0];
// x2 for gradient and hessian.
auto rc = collective::Allreduce(
ctx_, linalg::MakeVec(&sum_grad.data()->sum_grad, sum_grad.size() * 2),
collective::Op::kMax);
collective::SafeColl(rc);
bst_node_t offset = 0;
for (auto tree : trees) {
this->Refresh(param, dmlc::BeginPtr(stemp[0]) + offset, 0, tree);
this->Refresh(param, dmlc::BeginPtr(sum_grad) + offset, 0, tree);
offset += tree->NumNodes();
}
}

View File

@@ -1,14 +1,14 @@
/**
* Copyright 2014-2023 by XBGoost Contributors
* Copyright 2014-2024, XBGoost Contributors
* \file updater_sync.cc
* \brief synchronize the tree in all distributed nodes
*/
#include <xgboost/tree_updater.h>
#include <limits>
#include <string>
#include <vector>
#include "../collective/broadcast.h"
#include "../collective/communicator-inl.h"
#include "../common/io.h"
#include "xgboost/json.h"
@@ -44,7 +44,8 @@ class TreeSyncher : public TreeUpdater {
}
}
fs.Seek(0);
collective::Broadcast(&s_model, 0);
auto rc = collective::Broadcast(ctx_, linalg::MakeVec(s_model.data(), s_model.size()), 0);
SafeColl(rc);
for (auto tree : trees) {
tree->Load(&fs);
}