Revamp the rabit implementation. (#10112)

This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features:
- Federated learning for both CPU and GPU.
- NCCL.
- More data types.
- A unified interface for all the underlying implementations.
- Improved timeout handling for both tracker and workers.
- Exhaustive tests with metrics (fixed a couple of bugs along the way).
- A reusable tracker for Python and JVM packages.
This commit is contained in:
Jiaming Yuan
2024-05-20 11:56:23 +08:00
committed by GitHub
parent ba9b4cb1ee
commit a5a58102e5
195 changed files with 2768 additions and 9234 deletions

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2023 by XGBoost contributors
* Copyright 2023-2024, XGBoost contributors
*
* Higher level functions built on top the Communicator API, taking care of behavioral differences
* between row-split vs column-split distributed training, and horizontal vs vertical federated
@@ -13,7 +13,8 @@
#include <utility>
#include <vector>
#include "communicator-inl.cuh"
#include "allreduce.h"
#include "xgboost/collective/result.h" // for Result
namespace xgboost::collective {
@@ -24,15 +25,17 @@ namespace xgboost::collective {
* column-wise (vertically), the original values are returned.
*
* @tparam T The type of the values.
*
* @param info MetaInfo about the DMatrix.
* @param device The device id.
* @param values Pointer to the inputs to sum.
* @param size Number of values to sum.
*/
template <typename T>
void GlobalSum(MetaInfo const& info, DeviceOrd device, T* values, size_t size) {
template <typename T, std::int32_t kDim>
[[nodiscard]] Result GlobalSum(Context const* ctx, MetaInfo const& info,
linalg::TensorView<T, kDim> values) {
if (info.IsRowSplit()) {
collective::AllReduce<collective::Operation::kSum>(device.ordinal, values, size);
return collective::Allreduce(ctx, values, collective::Op::kSum);
}
return Success();
}
} // namespace xgboost::collective

View File

@@ -11,11 +11,44 @@
#include <utility>
#include <vector>
#include "allreduce.h"
#include "broadcast.h"
#include "comm.h"
#include "communicator-inl.h"
#include "xgboost/collective/result.h" // for Result
#include "xgboost/data.h" // for MetaInfo
namespace xgboost::collective {
namespace detail {
template <typename Fn>
[[nodiscard]] Result TryApplyWithLabels(Context const* ctx, Fn&& fn) {
std::string msg;
if (collective::GetRank() == 0) {
try {
fn();
} catch (dmlc::Error const& e) {
msg = e.what();
}
}
std::size_t msg_size{msg.size()};
auto rc = Success() << [&] {
auto rc = collective::Broadcast(ctx, linalg::MakeVec(&msg_size, 1), 0);
return rc;
} << [&] {
if (msg_size > 0) {
msg.resize(msg_size);
return collective::Broadcast(ctx, linalg::MakeVec(msg.data(), msg.size()), 0);
}
return Success();
} << [&] {
if (msg_size > 0) {
LOG(FATAL) << msg;
}
return Success();
};
return rc;
}
} // namespace detail
/**
* @brief Apply the given function where the labels are.
@@ -30,29 +63,19 @@ namespace xgboost::collective {
* @param size The size of the buffer.
* @param function The function used to calculate the results.
*/
template <typename FN>
void ApplyWithLabels(Context const*, MetaInfo const& info, void* buffer, std::size_t size,
FN&& function) {
template <typename Fn>
void ApplyWithLabels(Context const* ctx, MetaInfo const& info, void* buffer, std::size_t size,
Fn&& fn) {
if (info.IsVerticalFederated()) {
// We assume labels are only available on worker 0, so the calculation is done there and result
// broadcast to other workers.
std::string message;
if (collective::GetRank() == 0) {
try {
std::forward<FN>(function)();
} catch (dmlc::Error& e) {
message = e.what();
}
}
collective::Broadcast(&message, 0);
if (message.empty()) {
collective::Broadcast(buffer, size, 0);
} else {
LOG(FATAL) << &message[0];
}
auto rc = detail::TryApplyWithLabels(ctx, fn) << [&] {
// We assume labels are only available on worker 0, so the calculation is done there and
// result broadcast to other workers.
return collective::Broadcast(
ctx, linalg::MakeVec(reinterpret_cast<std::int8_t*>(buffer), size), 0);
};
SafeColl(rc);
} else {
std::forward<FN>(function)();
std::forward<Fn>(fn)();
}
}
@@ -69,37 +92,24 @@ void ApplyWithLabels(Context const*, MetaInfo const& info, void* buffer, std::si
* @param result The HostDeviceVector storing the results.
* @param function The function used to calculate the results.
*/
template <typename T, typename Function>
void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector<T>* result,
Function&& function) {
template <typename T, typename Fn>
void ApplyWithLabels(Context const* ctx, MetaInfo const& info, HostDeviceVector<T>* result,
Fn&& fn) {
if (info.IsVerticalFederated()) {
// We assume labels are only available on worker 0, so the calculation is done there and result
// broadcast to other workers.
std::string message;
if (collective::GetRank() == 0) {
try {
std::forward<Function>(function)();
} catch (dmlc::Error& e) {
message = e.what();
}
}
auto rc = detail::TryApplyWithLabels(ctx, fn);
collective::Broadcast(&message, 0);
if (!message.empty()) {
LOG(FATAL) << &message[0];
return;
}
std::size_t size{};
if (collective::GetRank() == 0) {
size = result->Size();
}
collective::Broadcast(&size, sizeof(std::size_t), 0);
result->Resize(size);
collective::Broadcast(result->HostPointer(), size * sizeof(T), 0);
std::size_t size{result->Size()};
rc = std::move(rc) << [&] {
return collective::Broadcast(ctx, linalg::MakeVec(&size, 1), 0);
} << [&] {
result->Resize(size);
return collective::Broadcast(ctx, linalg::MakeVec(result->HostPointer(), size), 0);
};
SafeColl(rc);
} else {
std::forward<Function>(function)();
std::forward<Fn>(fn)();
}
}
@@ -115,11 +125,12 @@ void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector<T>*
* @return The global max of the input.
*/
template <typename T>
std::enable_if_t<std::is_trivially_copy_assignable_v<T>, T> GlobalMax(Context const*,
std::enable_if_t<std::is_trivially_copy_assignable_v<T>, T> GlobalMax(Context const* ctx,
MetaInfo const& info,
T value) {
if (info.IsRowSplit()) {
collective::Allreduce<collective::Operation::kMax>(&value, 1);
auto rc = collective::Allreduce(ctx, linalg::MakeVec(&value, 1), collective::Op::kMax);
SafeColl(rc);
}
return value;
}
@@ -136,19 +147,14 @@ std::enable_if_t<std::is_trivially_copy_assignable_v<T>, T> GlobalMax(Context co
* @param size Number of values to sum.
*/
template <typename T, std::int32_t kDim>
[[nodiscard]] Result GlobalSum(Context const*, MetaInfo const& info,
[[nodiscard]] Result GlobalSum(Context const* ctx, MetaInfo const& info,
linalg::TensorView<T, kDim> values) {
if (info.IsRowSplit()) {
collective::Allreduce<collective::Operation::kSum>(values.Values().data(), values.Size());
return collective::Allreduce(ctx, values, collective::Op::kSum);
}
return Success();
}
template <typename Container>
[[nodiscard]] Result GlobalSum(Context const* ctx, MetaInfo const& info, Container* values) {
return GlobalSum(ctx, info, values->data(), values->size());
}
/**
* @brief Find the global ratio of the given two values across all workers.
*

View File

@@ -47,7 +47,7 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
return comm.Block();
};
if (!rc.OK()) {
return rc;
return Fail("Ring allgather failed, current iteration:" + std::to_string(r), std::move(rc));
}
}
@@ -61,7 +61,8 @@ Result BroadcastAllgatherV(Comm const& comm, common::Span<std::int64_t const> si
auto as_bytes = sizes[r];
auto rc = Broadcast(comm, recv.subspan(offset, as_bytes), r);
if (!rc.OK()) {
return rc;
return Fail("Broadcast AllgatherV failed, current iteration:" + std::to_string(r),
std::move(rc));
}
offset += as_bytes;
}
@@ -102,7 +103,7 @@ namespace detail {
return prev_ch->Block();
};
if (!rc.OK()) {
return rc;
return Fail("Ring AllgatherV failed, current iteration:" + std::to_string(r), std::move(rc));
}
}
return comm.Block();

View File

@@ -36,7 +36,7 @@ Result RingAllreduceSmall(Comm const& comm, common::Span<std::int8_t> data, Func
auto rc = RingAllgather(comm, typed);
if (!rc.OK()) {
return rc;
return Fail("Ring allreduce small failed.", std::move(rc));
}
auto first = s_buffer.subspan(0, data.size_bytes());
CHECK_EQ(first.size(), data.size());
@@ -64,7 +64,7 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
auto next_ch = comm.Chan(dst_rank);
auto prev_ch = comm.Chan(src_rank);
std::vector<std::int8_t> buffer(data.size_bytes() - (world - 1) * n_bytes_in_seg, 0);
std::vector<std::int8_t> buffer(data.size_bytes() - (world - 1) * n_bytes_in_seg, -1);
auto s_buf = common::Span{buffer.data(), buffer.size()};
for (std::int32_t r = 0; r < world - 1; ++r) {
@@ -97,6 +97,10 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
} << [&] {
return comm.Block();
};
if (!rc.OK()) {
return Fail("Ring scatter reduce failed, current iteration:" + std::to_string(r),
std::move(rc));
}
// accumulate to recv_seg
CHECK_EQ(seg.size(), recv_seg.size());
@@ -128,7 +132,7 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func cons
auto n_bytes_in_seg = (n / world) * sizeof(T);
auto rc = RingScatterReduceTyped<T>(comm, data, n_bytes_in_seg, op);
if (!rc.OK()) {
return rc;
return Fail("Ring Allreduce failed.", std::move(rc));
}
auto prev = BootstrapPrev(comm.Rank(), comm.World());

View File

@@ -150,9 +150,12 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
}
auto rank = comm.Rank();
auto n_bytes = worker->SendAll(&rank, sizeof(comm.Rank()));
if (n_bytes != sizeof(comm.Rank())) {
return Fail("Failed to send rank.");
std::size_t n_bytes{0};
auto rc = worker->SendAll(&rank, sizeof(comm.Rank()), &n_bytes);
if (!rc.OK()) {
return rc;
} else if (n_bytes != sizeof(comm.Rank())) {
return Fail("Failed to send rank.", std::move(rc));
}
workers[r] = std::move(worker);
}
@@ -169,8 +172,11 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
return rc;
}
std::int32_t rank{-1};
auto n_bytes = peer->RecvAll(&rank, sizeof(rank));
if (n_bytes != sizeof(comm.Rank())) {
std::size_t n_bytes{0};
auto rc = peer->RecvAll(&rank, sizeof(rank), &n_bytes);
if (!rc.OK()) {
return rc;
} else if (n_bytes != sizeof(comm.Rank())) {
return Fail("Failed to recv rank.");
}
workers[rank] = std::move(peer);

View File

@@ -94,7 +94,7 @@ class Comm : public std::enable_shared_from_this<Comm> {
[[nodiscard]] bool IsDistributed() const noexcept { return world_ != -1; }
void Submit(Loop::Op op) const {
CHECK(loop_);
loop_->Submit(op);
loop_->Submit(std::move(op));
}
[[nodiscard]] virtual Result Block() const { return loop_->Block(); }

View File

@@ -76,7 +76,7 @@ CommGroup::CommGroup()
// Common args
auto retry = get_param("dmlc_retry", static_cast<Integer::Int>(DefaultRetry()), Integer{});
auto timeout =
get_param("dmlc_timeout_sec", static_cast<Integer::Int>(DefaultTimeoutSec()), Integer{});
get_param("dmlc_timeout", static_cast<Integer::Int>(DefaultTimeoutSec()), Integer{});
auto task_id = get_param("dmlc_task_id", std::string{}, String{});
if (type == "rabit") {
@@ -123,4 +123,30 @@ void GlobalCommGroupFinalize() {
sptr.reset();
SafeColl(rc);
}
void Init(Json const& config) { GlobalCommGroupInit(config); }
void Finalize() { GlobalCommGroupFinalize(); }
std::int32_t GetRank() noexcept { return GlobalCommGroup()->Rank(); }
std::int32_t GetWorldSize() noexcept { return GlobalCommGroup()->World(); }
bool IsDistributed() noexcept { return GlobalCommGroup()->IsDistributed(); }
[[nodiscard]] bool IsFederated() {
return GlobalCommGroup()->Ctx(nullptr, DeviceOrd::CPU()).IsFederated();
}
void Print(std::string const& message) {
auto rc = GlobalCommGroup()->Ctx(nullptr, DeviceOrd::CPU()).LogTracker(message);
SafeColl(rc);
}
std::string GetProcessorName() {
std::string out;
auto rc = GlobalCommGroup()->ProcessorName(&out);
SafeColl(rc);
return out;
}
} // namespace xgboost::collective

View File

@@ -1,34 +0,0 @@
/**
* Copyright 2024, XGBoost contributors
*/
#include "communicator-inl.h"
namespace xgboost::collective {
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
std::vector<std::vector<char>> const &input) {
auto n_inputs = input.size();
std::vector<std::int64_t> sizes(n_inputs);
std::transform(input.cbegin(), input.cend(), sizes.begin(),
[](auto const &vec) { return vec.size(); });
std::vector<std::int64_t> global_sizes = AllgatherV(sizes);
std::vector<std::int64_t> offset(global_sizes.size() + 1);
offset[0] = 0;
for (std::size_t i = 1; i < offset.size(); i++) {
offset[i] = offset[i - 1] + global_sizes[i - 1];
}
std::vector<char> collected;
for (auto const &vec : input) {
collected.insert(collected.end(), vec.cbegin(), vec.cend());
}
auto out = AllgatherV(collected);
std::vector<std::vector<char>> result;
for (std::size_t i = 1; i < offset.size(); ++i) {
std::vector<char> local(out.cbegin() + offset[i - 1], out.cbegin() + offset[i]);
result.emplace_back(std::move(local));
}
return result;
}
} // namespace xgboost::collective

View File

@@ -1,95 +0,0 @@
/**
* Copyright 2023 by XGBoost contributors
*/
#pragma once
#include <string>
#include <vector>
#include "communicator.h"
#include "device_communicator.cuh"
namespace xgboost {
namespace collective {
/**
* @brief Reduce values from all processes and distribute the result back to all processes.
* @param device ID of the device.
* @param send_receive_buffer Buffer storing the data.
* @param count Number of elements in the buffer.
*/
template <Operation op>
inline void AllReduce(int device, std::int8_t *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt8, op);
}
template <Operation op>
inline void AllReduce(int device, std::uint8_t *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt8, op);
}
template <Operation op>
inline void AllReduce(int device, std::int32_t *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt32, op);
}
template <Operation op>
inline void AllReduce(int device, std::uint32_t *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt32, op);
}
template <Operation op>
inline void AllReduce(int device, std::int64_t *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt64, op);
}
template <Operation op>
inline void AllReduce(int device, std::uint64_t *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
}
template <Operation op>
inline void AllReduce(int device, float *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kFloat, op);
}
template <Operation op>
inline void AllReduce(int device, double *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
}
/**
* @brief Gather values from all all processes.
*
* This assumes all ranks have the same size.
*
* @param send_buffer Buffer storing the data to be sent.
* @param receive_buffer Buffer storing the gathered data.
* @param send_size Size of the sent data in bytes.
*/
inline void AllGather(int device, void const *send_buffer, void *receive_buffer,
std::size_t send_size) {
Communicator::GetDevice(device)->AllGather(send_buffer, receive_buffer, send_size);
}
/**
* @brief Gather variable-length values from all processes.
* @param device ID of the device.
* @param send_buffer Buffer storing the input data.
* @param length_bytes Length in bytes of the input data.
* @param segments Size of each segment.
* @param receive_buffer Buffer storing the output data.
*/
inline void AllGatherV(int device, void const *send_buffer, size_t length_bytes,
std::vector<size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) {
Communicator::GetDevice(device)->AllGatherV(send_buffer, length_bytes, segments, receive_buffer);
}
/**
* @brief Synchronize device operations.
* @param device ID of the device.
*/
inline void Synchronize(int device) { Communicator::GetDevice(device)->Synchronize(); }
} // namespace collective
} // namespace xgboost

View File

@@ -3,308 +3,63 @@
*/
#pragma once
#include <string>
#include <vector>
#include "communicator.h"
#include "xgboost/json.h" // for Json
namespace xgboost {
namespace collective {
namespace xgboost::collective {
/**
* @brief Initialize the collective communicator.
*/
void Init(Json const& config);
/**
* \brief Initialize the collective communicator.
*
* Currently the communicator API is experimental, function signatures may change in the future
* without notice.
*
* Call this once before using anything.
*
* The additional configuration is not required. Usually the communicator will detect settings
* from environment variables.
*
* \param json_config JSON encoded configuration. Accepted JSON keys are:
* - xgboost_communicator: The type of the communicator. Can be set as an environment variable.
* * rabit: Use Rabit. This is the default if the type is unspecified.
* * mpi: Use MPI.
* * federated: Use the gRPC interface for Federated Learning.
* Only applicable to the Rabit communicator (these are case-sensitive):
* - rabit_tracker_uri: Hostname of the tracker.
* - rabit_tracker_port: Port number of the tracker.
* - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
* - rabit_world_size: Total number of workers.
* - rabit_hadoop_mode: Enable Hadoop support.
* - rabit_tree_reduce_minsize: Minimal size for tree reduce.
* - rabit_reduce_ring_mincount: Minimal count to perform ring reduce.
* - rabit_reduce_buffer: Size of the reduce buffer.
* - rabit_bootstrap_cache: Size of the bootstrap cache.
* - rabit_debug: Enable debugging.
* - rabit_timeout: Enable timeout.
* - rabit_timeout_sec: Timeout in seconds.
* - rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms.
* Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
* environment variables):
* - DMLC_TRACKER_URI: Hostname of the tracker.
* - DMLC_TRACKER_PORT: Port number of the tracker.
* - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment.
* - DMLC_ROLE: Role of the current task, "worker" or "server".
* - DMLC_NUM_ATTEMPT: Number of attempts after task failure.
* - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
* Only applicable to the Federated communicator (use upper case for environment variables, use
* lower case for runtime configuration):
* - federated_server_address: Address of the federated server.
* - federated_world_size: Number of federated workers.
* - federated_rank: Rank of the current worker.
* - federated_server_cert: Server certificate file path. Only needed for the SSL mode.
* - federated_client_key: Client key file path. Only needed for the SSL mode.
* - federated_client_cert: Client certificate file path. Only needed for the SSL mode.
*/
inline void Init(Json const &config) { Communicator::Init(config); }
/*!
* \brief Finalize the collective communicator.
* @brief Finalize the collective communicator.
*
* Call this function after you finished all jobs.
*/
inline void Finalize() { Communicator::Finalize(); }
void Finalize();
/*!
* \brief Get rank of current process.
/**
* @brief Get rank of current process.
*
* \return Rank of the worker.
* @return Rank of the worker.
*/
inline int GetRank() { return Communicator::Get()->GetRank(); }
[[nodiscard]] std::int32_t GetRank() noexcept;
/*!
* \brief Get total number of processes.
/**
* @brief Get total number of processes.
*
* \return Total world size.
* @return Total world size.
*/
inline int GetWorldSize() { return Communicator::Get()->GetWorldSize(); }
[[nodiscard]] std::int32_t GetWorldSize() noexcept;
/*!
* \brief Get if the communicator is distributed.
/**
* @brief Get if the communicator is distributed.
*
* \return True if the communicator is distributed.
* @return True if the communicator is distributed.
*/
inline bool IsDistributed() { return Communicator::Get()->IsDistributed(); }
[[nodiscard]] bool IsDistributed() noexcept;
/*!
* \brief Get if the communicator is federated.
/**
* @brief Get if the communicator is federated.
*
* \return True if the communicator is federated.
* @return True if the communicator is federated.
*/
inline bool IsFederated() { return Communicator::Get()->IsFederated(); }
[[nodiscard]] bool IsFederated();
/*!
* \brief Print the message to the communicator.
/**
* @brief Print the message to the communicator.
*
* This function can be used to communicate the information of the progress to the user who monitors
* the communicator.
*
* \param message The message to be printed.
* @param message The message to be printed.
*/
inline void Print(char const *message) { Communicator::Get()->Print(message); }
inline void Print(std::string const &message) { Communicator::Get()->Print(message); }
/*!
* \brief Get the name of the processor.
*
* \return Name of the processor.
*/
inline std::string GetProcessorName() { return Communicator::Get()->GetProcessorName(); }
/*!
* \brief Broadcast a memory region to all others from root. This function is NOT thread-safe.
*
* Example:
* int a = 1;
* Broadcast(&a, sizeof(a), root);
*
* \param send_receive_buffer Pointer to the send or receive buffer.
* \param size Size of the data.
* \param root The process rank to broadcast from.
*/
inline void Broadcast(void *send_receive_buffer, size_t size, int root) {
Communicator::Get()->Broadcast(send_receive_buffer, size, root);
}
inline void Broadcast(std::string *sendrecv_data, int root) {
size_t size = sendrecv_data->length();
Broadcast(&size, sizeof(size), root);
if (sendrecv_data->length() != size) {
sendrecv_data->resize(size);
}
if (size != 0) {
Broadcast(&(*sendrecv_data)[0], size * sizeof(char), root);
}
}
void Print(std::string const& message);
/**
* @brief Gathers a single value all processes and distributes the result to all processes.
* @brief Get the name of the processor.
*
* @param input The single value.
* @return Name of the processor.
*/
template <typename T>
inline std::vector<T> Allgather(T const &input) {
std::string_view str_input{reinterpret_cast<char const *>(&input), sizeof(T)};
auto const output = Communicator::Get()->AllGather(str_input);
CHECK_EQ(output.size() % sizeof(T), 0);
std::vector<T> result(output.size() / sizeof(T));
std::memcpy(reinterpret_cast<void *>(result.data()), output.data(), output.size());
return result;
}
/**
* @brief Gathers data from all processes and distributes it to all processes.
*
* This assumes all ranks have the same size.
*
* @param input Buffer storing the data.
*/
template <typename T>
inline std::vector<T> Allgather(std::vector<T> const &input) {
if (input.empty()) {
return input;
}
std::string_view str_input{reinterpret_cast<char const *>(input.data()),
input.size() * sizeof(T)};
auto const output = Communicator::Get()->AllGather(str_input);
CHECK_EQ(output.size() % sizeof(T), 0);
std::vector<T> result(output.size() / sizeof(T));
std::memcpy(reinterpret_cast<void *>(result.data()), output.data(), output.size());
return result;
}
/**
* @brief Gathers variable-length data from all processes and distributes it to all processes.
* @param input Buffer storing the data.
*/
template <typename T>
inline std::vector<T> AllgatherV(std::vector<T> const &input) {
std::string_view str_input{reinterpret_cast<char const *>(input.data()),
input.size() * sizeof(T)};
auto const output = Communicator::Get()->AllGatherV(str_input);
CHECK_EQ(output.size() % sizeof(T), 0);
std::vector<T> result(output.size() / sizeof(T));
if (!output.empty()) {
std::memcpy(reinterpret_cast<void *>(result.data()), output.data(), output.size());
}
return result;
}
/**
* @brief Gathers variable-length data from all processes and distributes it to all processes.
*
* @param inputs All the inputs from the local worker. The number of inputs can vary
* across different workers. Along with which, the size of each vector in
* the input can also vary.
*
* @return The AllgatherV result, containing vectors from all workers.
*/
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
std::vector<std::vector<char>> const &input);
/**
* @brief Gathers variable-length strings from all processes and distributes them to all processes.
* @param input Variable-length list of variable-length strings.
*/
inline std::vector<std::string> AllgatherStrings(std::vector<std::string> const &input) {
std::size_t total_size{0};
for (auto const &s : input) {
total_size += s.length() + 1; // +1 for null-terminators
}
std::string flat_string;
flat_string.reserve(total_size);
for (auto const &s : input) {
flat_string.append(s);
flat_string.push_back('\0'); // Append a null-terminator after each string
}
auto const output = Communicator::Get()->AllGatherV(flat_string);
std::vector<std::string> result;
std::size_t start_index = 0;
// Iterate through the output, find each null-terminated substring.
for (std::size_t i = 0; i < output.size(); i++) {
if (output[i] == '\0') {
// Construct a std::string from the char* substring
result.emplace_back(&output[start_index]);
// Move to the next substring
start_index = i + 1;
}
}
return result;
}
/*!
* \brief Perform in-place allreduce. This function is NOT thread-safe.
*
* Example Usage: the following code gives sum of the result
* vector<int> data(10);
* ...
* Allreduce(&data[0], data.size(), DataType:kInt32, Op::kSum);
* ...
* \param send_receive_buffer Buffer for both sending and receiving data.
* \param count Number of elements to be reduced.
* \param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h.
* \param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
*/
inline void Allreduce(void *send_receive_buffer, size_t count, int data_type, int op) {
Communicator::Get()->AllReduce(send_receive_buffer, count, static_cast<DataType>(data_type),
static_cast<Operation>(op));
}
inline void Allreduce(void *send_receive_buffer, size_t count, DataType data_type, Operation op) {
Communicator::Get()->AllReduce(send_receive_buffer, count, data_type, op);
}
template <Operation op>
inline void Allreduce(int8_t *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt8, op);
}
template <Operation op>
inline void Allreduce(uint8_t *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt8, op);
}
template <Operation op>
inline void Allreduce(int32_t *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt32, op);
}
template <Operation op>
inline void Allreduce(uint32_t *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt32, op);
}
template <Operation op>
inline void Allreduce(int64_t *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt64, op);
}
template <Operation op>
inline void Allreduce(uint64_t *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
}
// Specialization for size_t, which is implementation defined, so it might or might not
// be one of uint64_t/uint32_t/unsigned long long/unsigned long.
template <Operation op, typename T,
typename = std::enable_if_t<std::is_same<size_t, T>{} && !std::is_same<uint64_t, T>{}> >
inline void Allreduce(T *send_receive_buffer, size_t count) {
static_assert(sizeof(T) == sizeof(uint64_t));
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
}
template <Operation op>
inline void Allreduce(float *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kFloat, op);
}
template <Operation op>
inline void Allreduce(double *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
}
} // namespace collective
} // namespace xgboost
std::string GetProcessorName();
} // namespace xgboost::collective

View File

@@ -1,63 +0,0 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include "communicator.h"
#include "comm.h"
#include "in_memory_communicator.h"
#include "noop_communicator.h"
#include "rabit_communicator.h"
#if defined(XGBOOST_USE_FEDERATED)
#include "../../plugin/federated/federated_communicator.h"
#endif
namespace xgboost::collective {
thread_local std::unique_ptr<Communicator> Communicator::communicator_{new NoOpCommunicator()};
thread_local CommunicatorType Communicator::type_{};
thread_local std::string Communicator::nccl_path_{};
void Communicator::Init(Json const& config) {
auto nccl = OptionalArg<String>(config, "dmlc_nccl_path", std::string{DefaultNcclName()});
nccl_path_ = nccl;
auto type = GetTypeFromEnv();
auto const arg = GetTypeFromConfig(config);
if (arg != CommunicatorType::kUnknown) {
type = arg;
}
if (type == CommunicatorType::kUnknown) {
// Default to Rabit if unspecified.
type = CommunicatorType::kRabit;
}
type_ = type;
switch (type) {
case CommunicatorType::kRabit: {
communicator_.reset(RabitCommunicator::Create(config));
break;
}
case CommunicatorType::kFederated: {
#if defined(XGBOOST_USE_FEDERATED)
communicator_.reset(FederatedCommunicator::Create(config));
#else
LOG(FATAL) << "XGBoost is not compiled with Federated Learning support.";
#endif
break;
}
case CommunicatorType::kInMemory:
case CommunicatorType::kInMemoryNccl: {
communicator_.reset(InMemoryCommunicator::Create(config));
break;
}
case CommunicatorType::kUnknown:
LOG(FATAL) << "Unknown communicator type.";
}
}
#ifndef XGBOOST_USE_CUDA
void Communicator::Finalize() {
communicator_->Shutdown();
communicator_.reset(new NoOpCommunicator());
}
#endif
} // namespace xgboost::collective

View File

@@ -1,54 +0,0 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include "communicator.h"
#include "device_communicator.cuh"
#include "device_communicator_adapter.cuh"
#include "noop_communicator.h"
#ifdef XGBOOST_USE_NCCL
#include "nccl_device_communicator.cuh"
#endif
namespace xgboost {
namespace collective {
thread_local std::unique_ptr<DeviceCommunicator> Communicator::device_communicator_{};
void Communicator::Finalize() {
communicator_->Shutdown();
communicator_.reset(new NoOpCommunicator());
device_communicator_.reset(nullptr);
}
DeviceCommunicator* Communicator::GetDevice(int device_ordinal) {
thread_local auto old_device_ordinal = -1;
// If the number of GPUs changes, we need to re-initialize NCCL.
thread_local auto old_world_size = -1;
if (!device_communicator_ || device_ordinal != old_device_ordinal ||
communicator_->GetWorldSize() != old_world_size) {
old_device_ordinal = device_ordinal;
old_world_size = communicator_->GetWorldSize();
#ifdef XGBOOST_USE_NCCL
switch (type_) {
case CommunicatorType::kRabit:
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false, nccl_path_));
break;
case CommunicatorType::kFederated:
case CommunicatorType::kInMemory:
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
break;
case CommunicatorType::kInMemoryNccl:
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, true, nccl_path_));
break;
default:
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false, nccl_path_));
}
#else
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
#endif
}
return device_communicator_.get();
}
} // namespace collective
} // namespace xgboost

View File

@@ -1,247 +0,0 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <xgboost/json.h>
#include <xgboost/logging.h>
#include <memory>
#include <string>
namespace xgboost {
namespace collective {
/**
 * @brief Defines the integral and floating data types.
 *
 * Element sizes for each enumerator are provided by GetTypeSize() below.
 */
enum class DataType {
  kInt8 = 0,
  kUInt8 = 1,
  kInt32 = 2,
  kUInt32 = 3,
  kInt64 = 4,
  kUInt64 = 5,
  kFloat = 6,
  kDouble = 7
};
/**
 * @brief Get the size in bytes of a single element of the given data type.
 *
 * Aborts via LOG(FATAL) when the enumerator is not recognized.
 */
inline std::size_t GetTypeSize(DataType data_type) {
  switch (data_type) {
    case DataType::kInt8:
      return sizeof(std::int8_t);
    case DataType::kUInt8:
      return sizeof(std::uint8_t);
    case DataType::kInt32:
      return sizeof(std::int32_t);
    case DataType::kUInt32:
      return sizeof(std::uint32_t);
    case DataType::kInt64:
      return sizeof(std::int64_t);
    case DataType::kUInt64:
      return sizeof(std::uint64_t);
    case DataType::kFloat:
      return sizeof(float);
    case DataType::kDouble:
      return sizeof(double);
    default:
      LOG(FATAL) << "Unknown data type.";
  }
  return 0;  // Unreachable: LOG(FATAL) aborts.
}
/** @brief Defines the reduction operation. */
enum class Operation {
kMax = 0,
kMin = 1,
kSum = 2,
kBitwiseAND = 3,
kBitwiseOR = 4,
kBitwiseXOR = 5
};
class DeviceCommunicator;
enum class CommunicatorType { kUnknown, kRabit, kFederated, kInMemory, kInMemoryNccl };
/**
 * \brief Case-insensitive string comparison.
 *
 * Follows the strcmp()-style contract of the underlying platform primitive:
 * returns 0 when the two strings compare equal ignoring case.
 */
inline int CompareStringsCaseInsensitive(const char *s1, const char *s2) {
#if defined(_MSC_VER)
  return _stricmp(s1, s2);  // MSVC spelling of the POSIX strcasecmp.
#else   // defined(_MSC_VER)
  return strcasecmp(s1, s2);
#endif  // defined(_MSC_VER)
}
/**
 * @brief A communicator class that handles collective communication.
 *
 * The concrete backend is selected at Init() time and stored in a
 * thread-local singleton, so each thread owns its own communicator state.
 */
class Communicator {
 public:
  /**
   * @brief Initialize the communicator. This can only be done once.
   *
   * @param config JSON configuration for the communicator.
   */
  static void Init(Json const &config);

  /** @brief Finalize the communicator. */
  static void Finalize();

  /** @brief Get the communicator instance. */
  static Communicator *Get() { return communicator_.get(); }

#if defined(XGBOOST_USE_CUDA)
  /**
   * @brief Get the device communicator.
   *
   * @param device_ordinal ID of the device.
   * @return An instance of device communicator.
   */
  static DeviceCommunicator *GetDevice(int device_ordinal);
#endif

  virtual ~Communicator() = default;

  /** @brief Get the total number of processes. */
  int GetWorldSize() const { return world_size_; }

  /** @brief Get the rank of the current process. */
  int GetRank() const { return rank_; }

  /** @brief Whether the communicator is running in distributed mode. */
  virtual bool IsDistributed() const = 0;

  /** @brief Whether the communicator is running in federated mode. */
  virtual bool IsFederated() const = 0;

  /**
   * @brief Gathers data from all processes and distributes it to all processes.
   *
   * This assumes all ranks have the same size.
   *
   * @param input Buffer storing the data.
   */
  virtual std::string AllGather(std::string_view input) = 0;

  /**
   * @brief Gathers variable-length data from all processes and distributes it to all processes.
   *
   * @param input Buffer storing the data.
   */
  virtual std::string AllGatherV(std::string_view input) = 0;

  /**
   * @brief Combines values from all processes and distributes the result back to all processes.
   *
   * @param send_receive_buffer Buffer storing the data; overwritten in place with the result.
   * @param count Number of elements in the buffer.
   * @param data_type Data type stored in the buffer.
   * @param op The operation to perform.
   */
  virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
                         Operation op) = 0;

  /**
   * @brief Broadcasts a message from the process with rank `root` to all other processes of the
   * group.
   *
   * @param send_receive_buffer Buffer storing the data.
   * @param size Size of the data in bytes.
   * @param root Rank of broadcast root.
   */
  virtual void Broadcast(void *send_receive_buffer, std::size_t size, int root) = 0;

  /**
   * @brief Gets the name of the processor.
   */
  virtual std::string GetProcessorName() = 0;

  /**
   * @brief Prints the message.
   */
  virtual void Print(std::string const &message) = 0;

  /** @brief Get the communicator type from environment variables. Visible for testing. */
  static CommunicatorType GetTypeFromEnv() {
    auto *env = std::getenv("XGBOOST_COMMUNICATOR");
    if (env != nullptr) {
      return StringToType(env);
    } else {
      return CommunicatorType::kUnknown;
    }
  }

  /** @brief Get the communicator type from runtime configuration. Visible for testing. */
  static CommunicatorType GetTypeFromConfig(Json const &config) {
    // Accept both the upper- and lower-case key spellings; upper-case wins.
    auto const &j_upper = config["XGBOOST_COMMUNICATOR"];
    if (IsA<String const>(j_upper)) {
      return StringToType(get<String const>(j_upper).c_str());
    }
    auto const &j_lower = config["xgboost_communicator"];
    if (IsA<String const>(j_lower)) {
      return StringToType(get<String const>(j_lower).c_str());
    }
    return CommunicatorType::kUnknown;
  }

 protected:
  /**
   * @brief Construct a new communicator.
   *
   * Aborts via LOG(FATAL) when world_size < 1 or rank is outside
   * [0, world_size).
   *
   * @param world_size Total number of processes.
   * @param rank Rank of the current process.
   */
  Communicator(int world_size, int rank) : world_size_(world_size), rank_(rank) {
    if (world_size < 1) {
      LOG(FATAL) << "World size " << world_size << " is less than 1.";
    }
    if (rank < 0) {
      LOG(FATAL) << "Rank " << rank << " is less than 0.";
    }
    if (rank >= world_size) {
      LOG(FATAL) << "Rank " << rank << " is greater than world_size - 1: " << world_size - 1 << ".";
    }
  }

  /**
   * @brief Shuts down the communicator.
   */
  virtual void Shutdown() = 0;

 private:
  /// Map a case-insensitive backend name ("rabit", "federated", "in-memory",
  /// "in-memory-nccl") to the enum; aborts on any other name.
  static CommunicatorType StringToType(char const *str) {
    CommunicatorType result = CommunicatorType::kUnknown;
    if (!CompareStringsCaseInsensitive("rabit", str)) {
      result = CommunicatorType::kRabit;
    } else if (!CompareStringsCaseInsensitive("federated", str)) {
      result = CommunicatorType::kFederated;
    } else if (!CompareStringsCaseInsensitive("in-memory", str)) {
      result = CommunicatorType::kInMemory;
    } else if (!CompareStringsCaseInsensitive("in-memory-nccl", str)) {
      result = CommunicatorType::kInMemoryNccl;
    } else {
      LOG(FATAL) << "Unknown communicator type " << str;
    }
    return result;
  }

  // Thread-local singleton state: every thread sees its own instance.
  static thread_local std::unique_ptr<Communicator> communicator_;
  static thread_local CommunicatorType type_;
  static thread_local std::string nccl_path_;
#if defined(XGBOOST_USE_CUDA)
  static thread_local std::unique_ptr<DeviceCommunicator> device_communicator_;
#endif

  int const world_size_;  // Total number of processes; fixed at construction.
  int const rank_;        // Rank of this process, in [0, world_size_).
};
} // namespace collective
} // namespace xgboost

View File

@@ -1,57 +0,0 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <vector>
#include "../common/device_helpers.cuh"
namespace xgboost {
namespace collective {
/**
 * @brief Collective communicator for device buffers.
 *
 * Abstract interface; all buffer pointers refer to device memory.
 */
class DeviceCommunicator {
 public:
  virtual ~DeviceCommunicator() = default;

  /**
   * @brief Combines values from all processes and distributes the result back to all processes.
   *
   * @param send_receive_buffer Buffer storing the data; overwritten in place with the result.
   * @param count Number of elements in the buffer.
   * @param data_type Data type stored in the buffer.
   * @param op The operation to perform.
   */
  virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
                         Operation op) = 0;

  /**
   * @brief Gather values from all processes.
   *
   * This assumes all ranks have the same size.
   *
   * @param send_buffer Buffer storing the data to be sent.
   * @param receive_buffer Buffer storing the gathered data.
   * @param send_size Size of the sent data in bytes.
   */
  virtual void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) = 0;

  /**
   * @brief Gather variable-length values from all processes.
   *
   * @param send_buffer Buffer storing the input data.
   * @param length_bytes Length in bytes of the input data.
   * @param segments Output: size in bytes of each rank's segment.
   * @param receive_buffer Buffer storing the output data.
   */
  virtual void AllGatherV(void const *send_buffer, size_t length_bytes,
                          std::vector<size_t> *segments,
                          dh::caching_device_vector<char> *receive_buffer) = 0;

  /** @brief Synchronize device operations. Implementation-specific; may be a no-op. */
  virtual void Synchronize() = 0;
};
} // namespace collective
} // namespace xgboost

View File

@@ -1,94 +0,0 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <numeric> // for accumulate
#include "communicator.h"
#include "device_communicator.cuh"
namespace xgboost {
namespace collective {
/**
 * @brief Device "communicator" that stages every device buffer through the
 * host and delegates the actual collective to the host communicator.
 *
 * Used by backends without native device support (federated, in-memory).
 */
class DeviceCommunicatorAdapter : public DeviceCommunicator {
 public:
  /**
   * @brief Construct the adapter for the given CUDA device.
   *
   * Aborts via LOG(FATAL) when the ordinal is negative.
   */
  explicit DeviceCommunicatorAdapter(int device_ordinal)
      : device_ordinal_{device_ordinal}, world_size_{GetWorldSize()}, rank_{GetRank()} {
    if (device_ordinal_ < 0) {
      LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
    }
  }

  ~DeviceCommunicatorAdapter() override = default;

  void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
                 Operation op) override {
    if (world_size_ == 1) {
      return;
    }

    dh::safe_cuda(cudaSetDevice(device_ordinal_));
    auto size = count * GetTypeSize(data_type);
    // Device -> host, reduce on the host, then copy the result back.
    host_buffer_.resize(size);
    dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
    Allreduce(host_buffer_.data(), count, data_type, op);
    dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
  }

  void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) override {
    if (world_size_ == 1) {
      return;
    }

    dh::safe_cuda(cudaSetDevice(device_ordinal_));
    host_buffer_.resize(send_size);
    dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_buffer, send_size, cudaMemcpyDefault));
    auto const output = Allgather(host_buffer_);
    dh::safe_cuda(cudaMemcpy(receive_buffer, output.data(), output.size(), cudaMemcpyDefault));
  }

  void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
                  dh::caching_device_vector<char> *receive_buffer) override {
    if (world_size_ == 1) {
      return;
    }

    dh::safe_cuda(cudaSetDevice(device_ordinal_));

    // Exchange segment lengths: each rank contributes its own length at its
    // slot (all other slots are zero), so a max-reduction yields every length.
    // NOTE(review): treating size_t as kUInt64 assumes a 64-bit size_t.
    segments->clear();
    segments->resize(world_size_, 0);
    segments->at(rank_) = length_bytes;
    Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
    // Fix: use a std::size_t accumulator; the previous `0UL` literal is only
    // 32 bits on LLP64 platforms (64-bit Windows) and could overflow.
    auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), std::size_t{0});
    receive_buffer->resize(total_bytes);

    // Broadcast each rank's segment in turn through a shared host buffer.
    host_buffer_.resize(total_bytes);
    size_t offset = 0;
    for (int32_t i = 0; i < world_size_; ++i) {
      size_t as_bytes = segments->at(i);
      if (i == rank_) {
        dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank_),
                                 cudaMemcpyDefault));
      }
      Broadcast(host_buffer_.data() + offset, as_bytes, i);
      offset += as_bytes;
    }
    dh::safe_cuda(cudaMemcpy(receive_buffer->data().get(), host_buffer_.data(), total_bytes,
                             cudaMemcpyDefault));
  }

  void Synchronize() override {
    // Noop: every collective above completes synchronously on the host path.
  }

 private:
  int const device_ordinal_;
  int const world_size_;
  int const rank_;
  /// Host buffer used to call communicator functions.
  std::vector<char> host_buffer_{};
};
} // namespace collective
} // namespace xgboost

View File

@@ -1,12 +0,0 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include "in_memory_communicator.h"
namespace xgboost {
namespace collective {
// Definition of the static handler shared by all InMemoryCommunicator instances.
InMemoryHandler InMemoryCommunicator::handler_{};
} // namespace collective
} // namespace xgboost

View File

@@ -15,14 +15,14 @@ namespace collective {
/**
* An in-memory communicator, useful for testing.
*/
class InMemoryCommunicator : public Communicator {
class InMemoryCommunicator {
public:
/**
* @brief Create a new communicator based on JSON configuration.
* @param config JSON configuration.
* @return Communicator as specified by the JSON configuration.
*/
static Communicator* Create(Json const& config) {
static InMemoryCommunicator* Create(Json const& config) {
int world_size{0};
int rank{-1};
@@ -51,7 +51,7 @@ class InMemoryCommunicator : public Communicator {
return new InMemoryCommunicator(world_size, rank);
}
InMemoryCommunicator(int world_size, int rank) : Communicator(world_size, rank) {
InMemoryCommunicator(int world_size, int rank) {
handler_.Init(world_size, rank);
}

View File

@@ -1,14 +1,13 @@
/*!
* Copyright 2022 XGBoost contributors
/**
* Copyright 2022-2023, XGBoost contributors
*/
#include "in_memory_handler.h"
#include <algorithm>
#include <functional>
#include "comm.h"
namespace xgboost {
namespace collective {
namespace xgboost::collective {
/**
* @brief Functor for allgather.
*/
@@ -16,7 +15,7 @@ class AllgatherFunctor {
public:
std::string const name{"Allgather"};
AllgatherFunctor(std::size_t world_size, std::size_t rank)
AllgatherFunctor(std::int32_t world_size, std::int32_t rank)
: world_size_{world_size}, rank_{rank} {}
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
@@ -30,8 +29,8 @@ class AllgatherFunctor {
}
private:
std::size_t world_size_;
std::size_t rank_;
std::int32_t world_size_;
std::int32_t rank_;
};
/**
@@ -41,13 +40,13 @@ class AllgatherVFunctor {
public:
std::string const name{"AllgatherV"};
AllgatherVFunctor(std::size_t world_size, std::size_t rank,
AllgatherVFunctor(std::int32_t world_size, std::int32_t rank,
std::map<std::size_t, std::string_view>* data)
: world_size_{world_size}, rank_{rank}, data_{data} {}
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
data_->emplace(rank_, std::string_view{input, bytes});
if (data_->size() == world_size_) {
if (data_->size() == static_cast<std::size_t>(world_size_)) {
for (auto const& kv : *data_) {
buffer->append(kv.second);
}
@@ -56,8 +55,8 @@ class AllgatherVFunctor {
}
private:
std::size_t world_size_;
std::size_t rank_;
std::int32_t world_size_;
std::int32_t rank_;
std::map<std::size_t, std::string_view>* data_;
};
@@ -68,7 +67,7 @@ class AllreduceFunctor {
public:
std::string const name{"Allreduce"};
AllreduceFunctor(DataType dataType, Operation operation)
AllreduceFunctor(ArrayInterfaceHandler::Type dataType, Op operation)
: data_type_{dataType}, operation_{operation} {}
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
@@ -76,23 +75,23 @@ class AllreduceFunctor {
// Copy the input if this is the first request.
buffer->assign(input, bytes);
} else {
auto n_bytes_type = DispatchDType(data_type_, [](auto t) { return sizeof(t); });
// Apply the reduce_operation to the input and the buffer.
Accumulate(input, bytes / GetTypeSize(data_type_), &buffer->front());
Accumulate(input, bytes / n_bytes_type, &buffer->front());
}
}
private:
template <class T, std::enable_if_t<std::is_integral<T>::value>* = nullptr>
void AccumulateBitwise(T* buffer, T const* input, std::size_t size,
Operation reduce_operation) const {
void AccumulateBitwise(T* buffer, T const* input, std::size_t size, Op reduce_operation) const {
switch (reduce_operation) {
case Operation::kBitwiseAND:
case Op::kBitwiseAND:
std::transform(buffer, buffer + size, input, buffer, std::bit_and<T>());
break;
case Operation::kBitwiseOR:
case Op::kBitwiseOR:
std::transform(buffer, buffer + size, input, buffer, std::bit_or<T>());
break;
case Operation::kBitwiseXOR:
case Op::kBitwiseXOR:
std::transform(buffer, buffer + size, input, buffer, std::bit_xor<T>());
break;
default:
@@ -101,27 +100,27 @@ class AllreduceFunctor {
}
template <class T, std::enable_if_t<std::is_floating_point<T>::value>* = nullptr>
void AccumulateBitwise(T*, T const*, std::size_t, Operation) const {
void AccumulateBitwise(T*, T const*, std::size_t, Op) const {
LOG(FATAL) << "Floating point types do not support bitwise operations.";
}
template <class T>
void Accumulate(T* buffer, T const* input, std::size_t size, Operation reduce_operation) const {
void Accumulate(T* buffer, T const* input, std::size_t size, Op reduce_operation) const {
switch (reduce_operation) {
case Operation::kMax:
case Op::kMax:
std::transform(buffer, buffer + size, input, buffer,
[](T a, T b) { return std::max(a, b); });
break;
case Operation::kMin:
case Op::kMin:
std::transform(buffer, buffer + size, input, buffer,
[](T a, T b) { return std::min(a, b); });
break;
case Operation::kSum:
case Op::kSum:
std::transform(buffer, buffer + size, input, buffer, std::plus<T>());
break;
case Operation::kBitwiseAND:
case Operation::kBitwiseOR:
case Operation::kBitwiseXOR:
case Op::kBitwiseAND:
case Op::kBitwiseOR:
case Op::kBitwiseXOR:
AccumulateBitwise(buffer, input, size, reduce_operation);
break;
default:
@@ -130,36 +129,37 @@ class AllreduceFunctor {
}
void Accumulate(char const* input, std::size_t size, char* buffer) const {
using Type = ArrayInterfaceHandler::Type;
switch (data_type_) {
case DataType::kInt8:
case Type::kI1:
Accumulate(reinterpret_cast<std::int8_t*>(buffer),
reinterpret_cast<std::int8_t const*>(input), size, operation_);
break;
case DataType::kUInt8:
case Type::kU1:
Accumulate(reinterpret_cast<std::uint8_t*>(buffer),
reinterpret_cast<std::uint8_t const*>(input), size, operation_);
break;
case DataType::kInt32:
case Type::kI4:
Accumulate(reinterpret_cast<std::int32_t*>(buffer),
reinterpret_cast<std::int32_t const*>(input), size, operation_);
break;
case DataType::kUInt32:
case Type::kU4:
Accumulate(reinterpret_cast<std::uint32_t*>(buffer),
reinterpret_cast<std::uint32_t const*>(input), size, operation_);
break;
case DataType::kInt64:
case Type::kI8:
Accumulate(reinterpret_cast<std::int64_t*>(buffer),
reinterpret_cast<std::int64_t const*>(input), size, operation_);
break;
case DataType::kUInt64:
case Type::kU8:
Accumulate(reinterpret_cast<std::uint64_t*>(buffer),
reinterpret_cast<std::uint64_t const*>(input), size, operation_);
break;
case DataType::kFloat:
case Type::kF4:
Accumulate(reinterpret_cast<float*>(buffer), reinterpret_cast<float const*>(input), size,
operation_);
break;
case DataType::kDouble:
case Type::kF8:
Accumulate(reinterpret_cast<double*>(buffer), reinterpret_cast<double const*>(input), size,
operation_);
break;
@@ -169,8 +169,8 @@ class AllreduceFunctor {
}
private:
DataType data_type_;
Operation operation_;
ArrayInterfaceHandler::Type data_type_;
Op operation_;
};
/**
@@ -180,7 +180,7 @@ class BroadcastFunctor {
public:
std::string const name{"Broadcast"};
BroadcastFunctor(std::size_t rank, std::size_t root) : rank_{rank}, root_{root} {}
BroadcastFunctor(std::int32_t rank, std::int32_t root) : rank_{rank}, root_{root} {}
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
if (rank_ == root_) {
@@ -190,11 +190,11 @@ class BroadcastFunctor {
}
private:
std::size_t rank_;
std::size_t root_;
std::int32_t rank_;
std::int32_t root_;
};
void InMemoryHandler::Init(std::size_t world_size, std::size_t) {
void InMemoryHandler::Init(std::int32_t world_size, std::int32_t) {
CHECK(world_size_ < world_size) << "In memory handler already initialized.";
std::unique_lock<std::mutex> lock(mutex_);
@@ -204,7 +204,7 @@ void InMemoryHandler::Init(std::size_t world_size, std::size_t) {
cv_.notify_all();
}
void InMemoryHandler::Shutdown(uint64_t sequence_number, std::size_t) {
void InMemoryHandler::Shutdown(uint64_t sequence_number, std::int32_t) {
CHECK(world_size_ > 0) << "In memory handler already shutdown.";
std::unique_lock<std::mutex> lock(mutex_);
@@ -220,29 +220,29 @@ void InMemoryHandler::Shutdown(uint64_t sequence_number, std::size_t) {
}
void InMemoryHandler::Allgather(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::size_t rank) {
std::size_t sequence_number, std::int32_t rank) {
Handle(input, bytes, output, sequence_number, rank, AllgatherFunctor{world_size_, rank});
}
void InMemoryHandler::AllgatherV(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::size_t rank) {
std::size_t sequence_number, std::int32_t rank) {
Handle(input, bytes, output, sequence_number, rank, AllgatherVFunctor{world_size_, rank, &aux_});
}
void InMemoryHandler::Allreduce(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::size_t rank, DataType data_type,
Operation op) {
std::size_t sequence_number, std::int32_t rank,
ArrayInterfaceHandler::Type data_type, Op op) {
Handle(input, bytes, output, sequence_number, rank, AllreduceFunctor{data_type, op});
}
void InMemoryHandler::Broadcast(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::size_t rank, std::size_t root) {
std::size_t sequence_number, std::int32_t rank, std::int32_t root) {
Handle(input, bytes, output, sequence_number, rank, BroadcastFunctor{rank, root});
}
template <class HandlerFunctor>
void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::size_t rank,
std::size_t sequence_number, std::int32_t rank,
HandlerFunctor const& functor) {
// Pass through if there is only 1 client.
if (world_size_ == 1) {
@@ -287,5 +287,4 @@ void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string*
cv_.notify_all();
}
}
} // namespace collective
} // namespace xgboost
} // namespace xgboost::collective

View File

@@ -1,16 +1,15 @@
/*!
* Copyright 2022 XGBoost contributors
/**
* Copyright 2022-2023, XGBoost contributors
*/
#pragma once
#include <condition_variable>
#include <map>
#include <string>
#include "communicator.h"
namespace xgboost {
namespace collective {
#include "../data/array_interface.h"
#include "comm.h"
namespace xgboost::collective {
/**
* @brief Handles collective communication primitives in memory.
*
@@ -28,12 +27,11 @@ class InMemoryHandler {
/**
* @brief Construct a handler with the given world size.
* @param world_size Number of workers.
* @param world Number of workers.
*
* This is used when the handler only needs to be initialized once with a known world size.
*/
explicit InMemoryHandler(std::int32_t worldSize)
: world_size_{static_cast<std::size_t>(worldSize)} {}
explicit InMemoryHandler(std::int32_t world) : world_size_{world} {}
/**
* @brief Initialize the handler with the world size and rank.
@@ -43,7 +41,7 @@ class InMemoryHandler {
* This is used when multiple objects/threads are accessing the same handler and need to
* initialize it collectively.
*/
void Init(std::size_t world_size, std::size_t rank);
void Init(std::int32_t world_size, std::int32_t rank);
/**
* @brief Shut down the handler.
@@ -53,7 +51,7 @@ class InMemoryHandler {
* This is used when multiple objects/threads are accessing the same handler and need to
* shut it down collectively.
*/
void Shutdown(uint64_t sequence_number, std::size_t rank);
void Shutdown(uint64_t sequence_number, std::int32_t rank);
/**
* @brief Perform allgather.
@@ -64,7 +62,7 @@ class InMemoryHandler {
* @param rank Index of the worker.
*/
void Allgather(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::size_t rank);
std::size_t sequence_number, std::int32_t rank);
/**
* @brief Perform variable-length allgather.
@@ -75,7 +73,7 @@ class InMemoryHandler {
* @param rank Index of the worker.
*/
void AllgatherV(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::size_t rank);
std::size_t sequence_number, std::int32_t rank);
/**
* @brief Perform allreduce.
@@ -88,7 +86,8 @@ class InMemoryHandler {
* @param op The reduce operation.
*/
void Allreduce(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::size_t rank, DataType data_type, Operation op);
std::size_t sequence_number, std::int32_t rank,
ArrayInterfaceHandler::Type data_type, Op op);
/**
* @brief Perform broadcast.
@@ -100,7 +99,7 @@ class InMemoryHandler {
* @param root Index of the worker to broadcast from.
*/
void Broadcast(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, std::size_t rank, std::size_t root);
std::size_t sequence_number, std::int32_t rank, std::int32_t root);
private:
/**
@@ -115,17 +114,15 @@ class InMemoryHandler {
*/
template <class HandlerFunctor>
void Handle(char const* input, std::size_t size, std::string* output, std::size_t sequence_number,
std::size_t rank, HandlerFunctor const& functor);
std::int32_t rank, HandlerFunctor const& functor);
std::size_t world_size_{}; /// Number of workers.
std::size_t received_{}; /// Number of calls received with the current sequence.
std::size_t sent_{}; /// Number of calls completed with the current sequence.
std::int32_t world_size_{}; /// Number of workers.
std::int64_t received_{}; /// Number of calls received with the current sequence.
std::int64_t sent_{}; /// Number of calls completed with the current sequence.
std::string buffer_{}; /// A shared common buffer.
std::map<std::size_t, std::string_view> aux_{}; /// A shared auxiliary map.
uint64_t sequence_number_{}; /// Call sequence number.
mutable std::mutex mutex_; /// Lock.
mutable std::condition_variable cv_; /// Conditional variable to wait on.
};
} // namespace collective
} // namespace xgboost
} // namespace xgboost::collective

View File

@@ -6,6 +6,8 @@
#include <cstddef> // for size_t
#include <cstdint> // for int32_t
#include <exception> // for exception, current_exception, rethrow_exception
#include <future> // for promise
#include <memory> // for make_shared
#include <mutex> // for lock_guard, unique_lock
#include <queue> // for queue
#include <string> // for string
@@ -18,9 +20,10 @@
#include "xgboost/logging.h" // for CHECK
namespace xgboost::collective {
Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
Result Loop::ProcessQueue(std::queue<Op>* p_queue) const {
timer_.Start(__func__);
auto error = [this] {
auto error = [this](Op op) {
op.pr->set_value();
timer_.Stop(__func__);
};
@@ -38,7 +41,7 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
// Iterate through all the ops for poll
for (std::size_t i = 0; i < n_ops; ++i) {
auto op = qcopy.front();
auto op = std::move(qcopy.front());
qcopy.pop();
switch (op.code) {
@@ -54,12 +57,12 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
break;
}
default: {
error();
error(op);
return Fail("Invalid socket operation.");
}
}
qcopy.push(op);
qcopy.push(std::move(op));
}
// poll, work on fds that are ready.
@@ -67,18 +70,18 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
if (!poll.fds.empty()) {
auto rc = poll.Poll(timeout_);
if (!rc.OK()) {
error();
timer_.Stop(__func__);
return rc;
}
}
timer_.Stop("poll");
// we wouldn't be here if the queue is empty.
// We wouldn't be here if the queue is empty.
CHECK(!qcopy.empty());
// Iterate through all the ops for performing the operations
for (std::size_t i = 0; i < n_ops; ++i) {
auto op = qcopy.front();
auto op = std::move(qcopy.front());
qcopy.pop();
std::int32_t n_bytes_done{0};
@@ -93,8 +96,9 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
if (poll.CheckRead(*op.sock)) {
n_bytes_done = op.sock->Recv(op.ptr + op.off, op.n - op.off);
if (n_bytes_done == 0) {
error();
return Fail("Encountered EOF. The other end is likely closed.");
error(op);
return Fail("Encountered EOF. The other end is likely closed.",
op.sock->GetSockError());
}
}
break;
@@ -112,14 +116,14 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
break;
}
default: {
error();
error(op);
return Fail("Invalid socket operation.");
}
}
if (n_bytes_done == -1 && !system::LastErrorWouldBlock()) {
auto rc = system::FailWithCode("Invalid socket output.");
error();
error(op);
return rc;
}
@@ -127,14 +131,12 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
CHECK_LE(op.off, op.n);
if (op.off != op.n) {
// not yet finished, push back to queue for next round.
// not yet finished, push back to queue for the next round.
qcopy.push(op);
} else {
op.pr->set_value();
}
}
if (!blocking) {
break;
}
}
timer_.Stop(__func__);
@@ -148,8 +150,7 @@ void Loop::Process() {
};
// This loop cannot exit unless `stop_` is set to true. There must always be a thread to
// answer the blocking call even if there are errors, otherwise the blocking will wait
// forever.
// answer the call even if there are errors.
while (true) {
try {
std::unique_lock lock{mu_};
@@ -170,44 +171,15 @@ void Loop::Process() {
// Move the global queue into a local variable to unblock it.
std::queue<Op> qcopy;
bool is_blocking = false;
while (!queue_.empty()) {
auto op = queue_.front();
auto op = std::move(queue_.front());
queue_.pop();
if (op.code == Op::kBlock) {
is_blocking = true;
} else {
qcopy.push(op);
}
qcopy.push(op);
}
lock.unlock();
// Clear the local queue, if `is_blocking` is true, this is blocking the current
// worker thread (but not the client thread), wait until all operations are
// finished.
auto rc = this->ProcessQueue(&qcopy, is_blocking);
if (is_blocking && rc.OK()) {
CHECK(qcopy.empty());
}
// Push back the remaining operations.
if (rc.OK()) {
std::unique_lock lock{mu_};
while (!qcopy.empty()) {
queue_.push(qcopy.front());
qcopy.pop();
}
}
// Notify the client thread who called block after all error conditions are set.
auto notify_if_block = [&] {
if (is_blocking) {
std::unique_lock lock{mu_};
block_done_ = true;
lock.unlock();
block_cv_.notify_one();
}
};
// Clear the local queue.
auto rc = this->ProcessQueue(&qcopy);
// Handle error
if (!rc.OK()) {
@@ -215,8 +187,6 @@ void Loop::Process() {
} else {
CHECK(qcopy.empty());
}
notify_if_block();
} catch (std::exception const& e) {
curr_exce_ = std::current_exception();
set_rc(Fail("Exception inside the event loop:" + std::string{e.what()}));
@@ -256,20 +226,28 @@ Result Loop::Stop() {
stop_ = true;
}
}
if (!this->worker_.joinable()) {
std::lock_guard<std::mutex> guard{rc_lock_};
return Fail("Worker has stopped.", std::move(rc_));
}
this->Submit(Op{Op::kBlock});
{
// Wait for the block call to finish.
std::unique_lock lock{mu_};
block_cv_.wait(lock, [this] { return block_done_ || stop_; });
block_done_ = false;
cv_.notify_one();
}
for (auto& fut : futures_) {
if (fut.valid()) {
try {
fut.get();
} catch (std::future_error const&) {
// Do nothing. If something went wrong in the worker, we have a std::future_error
// due to broken promise. This function will transfer the rc back to the caller.
}
}
}
futures_.clear();
{
// Transfer the rc.
std::lock_guard<std::mutex> lock{rc_lock_};
@@ -278,13 +256,13 @@ Result Loop::Stop() {
}
void Loop::Submit(Op op) {
auto p = std::make_shared<std::promise<void>>();
op.pr = std::move(p);
futures_.emplace_back(op.pr->get_future());
CHECK_NE(op.n, 0);
std::unique_lock lock{mu_};
if (op.code != Op::kBlock) {
CHECK_NE(op.n, 0);
}
queue_.push(op);
lock.unlock();
cv_.notify_one();
}
Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} {

View File

@@ -7,9 +7,12 @@
#include <cstddef> // for size_t
#include <cstdint> // for int8_t, int32_t
#include <exception> // for exception_ptr
#include <mutex> // for unique_lock, mutex
#include <future> // for future
#include <memory> // for shared_ptr
#include <mutex> // for mutex
#include <queue> // for queue
#include <thread> // for thread
#include <vector> // for vector
#include "../common/timer.h" // for Monitor
#include "xgboost/collective/result.h" // for Result
@@ -20,14 +23,15 @@ class Loop {
public:
struct Op {
// kSleep is only for testing
enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2, kSleep = 4 } code;
enum Code : std::int8_t { kRead = 0, kWrite = 1, kSleep = 3 } code;
std::int32_t rank{-1};
std::int8_t* ptr{nullptr};
std::size_t n{0};
TCPSocket* sock{nullptr};
std::size_t off{0};
std::shared_ptr<std::promise<void>> pr;
explicit Op(Code c) : code{c} { CHECK(c == kBlock || c == kSleep); }
explicit Op(Code c) : code{c} { CHECK(c == kSleep); }
Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off)
: code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {}
Op(Op const&) = default;
@@ -45,12 +49,11 @@ class Loop {
private:
std::thread worker_; // thread worker to execute the tasks
std::condition_variable cv_; // CV used to notify a new submit call
std::condition_variable block_cv_; // CV used to notify the blocking call
bool block_done_{false}; // Flag to indicate whether the blocking call has finished.
std::condition_variable cv_; // CV used to notify a new submit call
std::queue<Op> queue_; // event queue
std::mutex mu_; // mutex to protect the queue, cv, and block_done
std::vector<std::future<void>> futures_;
std::mutex mu_; // mutex to protect the queue, cv, and block_done
std::chrono::seconds timeout_;
@@ -61,7 +64,7 @@ class Loop {
std::exception_ptr curr_exce_{nullptr};
common::Monitor mutable timer_;
Result ProcessQueue(std::queue<Op>* p_queue, bool blocking) const;
Result ProcessQueue(std::queue<Op>* p_queue) const;
// The cunsumer function that runs inside a worker thread.
void Process();

View File

@@ -1,243 +0,0 @@
/*!
* Copyright 2023 XGBoost contributors
*/
#if defined(XGBOOST_USE_NCCL)
#include <numeric> // for accumulate
#include "comm.cuh"
#include "nccl_device_communicator.cuh"
namespace xgboost {
namespace collective {
NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sync,
StringView nccl_path)
: device_ordinal_{device_ordinal},
needs_sync_{needs_sync},
world_size_{GetWorldSize()},
rank_{GetRank()} {
if (device_ordinal_ < 0) {
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
}
if (world_size_ == 1) {
return;
}
stub_ = std::make_shared<NcclStub>(std::move(nccl_path));
std::vector<uint64_t> uuids(world_size_ * kUuidLength, 0);
auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
auto s_this_uuid = s_uuid.subspan(rank_ * kUuidLength, kUuidLength);
GetCudaUUID(s_this_uuid);
// TODO(rongou): replace this with allgather.
Allreduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);
std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world_size_);
size_t j = 0;
for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
converted[j] = xgboost::common::Span<uint64_t, kUuidLength>{uuids.data() + i, kUuidLength};
j++;
}
auto iter = std::unique(converted.begin(), converted.end());
auto n_uniques = std::distance(converted.begin(), iter);
CHECK_EQ(n_uniques, world_size_)
<< "Multiple processes within communication group running on same CUDA "
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
nccl_unique_id_ = GetUniqueId();
dh::safe_cuda(cudaSetDevice(device_ordinal_));
auto rc = stub_->CommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_);
CHECK(rc.OK()) << rc.Report();
}
NcclDeviceCommunicator::~NcclDeviceCommunicator() {
if (world_size_ == 1) {
return;
}
if (nccl_comm_) {
auto rc = stub_->CommDestroy(nccl_comm_);
CHECK(rc.OK()) << rc.Report();
}
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
LOG(CONSOLE) << "======== NCCL Statistics========";
LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_ / 1048576;
}
}
namespace {
ncclDataType_t GetNcclDataType(DataType const &data_type) {
ncclDataType_t result{ncclInt8};
switch (data_type) {
case DataType::kInt8:
result = ncclInt8;
break;
case DataType::kUInt8:
result = ncclUint8;
break;
case DataType::kInt32:
result = ncclInt32;
break;
case DataType::kUInt32:
result = ncclUint32;
break;
case DataType::kInt64:
result = ncclInt64;
break;
case DataType::kUInt64:
result = ncclUint64;
break;
case DataType::kFloat:
result = ncclFloat;
break;
case DataType::kDouble:
result = ncclDouble;
break;
default:
LOG(FATAL) << "Unknown data type.";
}
return result;
}
bool IsBitwiseOp(Operation const &op) {
return op == Operation::kBitwiseAND || op == Operation::kBitwiseOR ||
op == Operation::kBitwiseXOR;
}
ncclRedOp_t GetNcclRedOp(Operation const &op) {
ncclRedOp_t result{ncclMax};
switch (op) {
case Operation::kMax:
result = ncclMax;
break;
case Operation::kMin:
result = ncclMin;
break;
case Operation::kSum:
result = ncclSum;
break;
default:
LOG(FATAL) << "Unsupported reduce operation.";
}
return result;
}
template <typename Func>
void RunBitwiseAllreduce(char *out_buffer, char const *device_buffer, Func func, int world_size,
std::size_t size) {
dh::LaunchN(size, [=] __device__(std::size_t idx) {
auto result = device_buffer[idx];
for (auto rank = 1; rank < world_size; rank++) {
result = func(result, device_buffer[rank * size + idx]);
}
out_buffer[idx] = result;
});
}
} // anonymous namespace
void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::size_t count,
DataType data_type, Operation op) {
auto const size = count * GetTypeSize(data_type);
dh::caching_device_vector<char> buffer(size * world_size_);
auto *device_buffer = buffer.data().get();
// First gather data from all the workers.
auto rc = stub_->Allgather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
nccl_comm_, dh::DefaultStream());
CHECK(rc.OK()) << rc.Report();
if (needs_sync_) {
dh::DefaultStream().Sync();
}
// Then reduce locally.
auto *out_buffer = static_cast<char *>(send_receive_buffer);
switch (op) {
case Operation::kBitwiseAND:
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size_, size);
break;
case Operation::kBitwiseOR:
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size_, size);
break;
case Operation::kBitwiseXOR:
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size_, size);
break;
default:
LOG(FATAL) << "Not a bitwise reduce operation.";
}
}
void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t count,
DataType data_type, Operation op) {
if (world_size_ == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
if (IsBitwiseOp(op)) {
BitwiseAllReduce(send_receive_buffer, count, data_type, op);
} else {
auto rc = stub_->Allreduce(send_receive_buffer, send_receive_buffer, count,
GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
dh::DefaultStream());
CHECK(rc.OK()) << rc.Report();
}
allreduce_bytes_ += count * GetTypeSize(data_type);
allreduce_calls_ += 1;
}
void NcclDeviceCommunicator::AllGather(void const *send_buffer, void *receive_buffer,
std::size_t send_size) {
if (world_size_ == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
auto rc = stub_->Allgather(send_buffer, receive_buffer, send_size, ncclInt8, nccl_comm_,
dh::DefaultStream());
CHECK(rc.OK()) << rc.Report();
}
void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_bytes,
std::vector<std::size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) {
if (world_size_ == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
segments->clear();
segments->resize(world_size_, 0);
segments->at(rank_) = length_bytes;
Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
receive_buffer->resize(total_bytes);
size_t offset = 0;
auto rc = Success() << [&] { return stub_->GroupStart(); } << [&] {
for (int32_t i = 0; i < world_size_; ++i) {
size_t as_bytes = segments->at(i);
auto rc = stub_->Broadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
ncclChar, i, nccl_comm_, dh::DefaultStream());
if (!rc.OK()) {
return rc;
}
offset += as_bytes;
}
return Success();
} << [&] { return stub_->GroupEnd(); };
}
void NcclDeviceCommunicator::Synchronize() {
if (world_size_ == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::DefaultStream().Sync();
}
} // namespace collective
} // namespace xgboost
#endif

View File

@@ -1,91 +0,0 @@
/*!
* Copyright 2022-2023 XGBoost contributors
*/
#pragma once
#include "../common/device_helpers.cuh"
#include "comm.cuh"
#include "communicator.h"
#include "device_communicator.cuh"
#include "nccl_stub.h"
namespace xgboost {
namespace collective {
class NcclDeviceCommunicator : public DeviceCommunicator {
public:
/**
* @brief Construct a new NCCL communicator.
* @param device_ordinal The GPU device id.
* @param needs_sync Whether extra CUDA stream synchronization is needed.
*
* In multi-GPU tests when multiple NCCL communicators are created in the same process, sometimes
* a deadlock happens because NCCL kernels are blocking. The extra CUDA stream synchronization
* makes sure that the NCCL kernels are caught up, thus avoiding the deadlock.
*
* The Rabit communicator runs with one process per GPU, so the additional synchronization is not
* needed. The in-memory communicator is used in tests with multiple threads, each thread
* representing a rank/worker, so the additional synchronization is needed to avoid deadlocks.
*/
explicit NcclDeviceCommunicator(int device_ordinal, bool needs_sync, StringView nccl_path);
~NcclDeviceCommunicator() override;
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override;
void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) override;
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) override;
void Synchronize() override;
private:
static constexpr std::size_t kUuidLength =
sizeof(std::declval<cudaDeviceProp>().uuid) / sizeof(uint64_t);
void GetCudaUUID(xgboost::common::Span<uint64_t, kUuidLength> const &uuid) const {
cudaDeviceProp prob{};
dh::safe_cuda(cudaGetDeviceProperties(&prob, device_ordinal_));
std::memcpy(uuid.data(), static_cast<void *>(&(prob.uuid)), sizeof(prob.uuid));
}
static std::string PrintUUID(xgboost::common::Span<uint64_t, kUuidLength> const &uuid) {
std::stringstream ss;
for (auto v : uuid) {
ss << std::hex << v;
}
return ss.str();
}
/**
* \fn ncclUniqueId GetUniqueId()
*
* \brief Gets the Unique ID from NCCL to be used in setting up interprocess
* communication
*
* \return the Unique ID
*/
ncclUniqueId GetUniqueId() {
static const int kRootRank = 0;
ncclUniqueId id;
if (rank_ == kRootRank) {
auto rc = stub_->GetUniqueId(&id);
CHECK(rc.OK()) << rc.Report();
}
Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId), static_cast<int>(kRootRank));
return id;
}
void BitwiseAllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op);
int const device_ordinal_;
bool const needs_sync_;
int const world_size_;
int const rank_;
ncclComm_t nccl_comm_{};
std::shared_ptr<NcclStub> stub_;
ncclUniqueId nccl_unique_id_{};
size_t allreduce_bytes_{0}; // Keep statistics of the number of bytes communicated.
size_t allreduce_calls_{0}; // Keep statistics of the number of reduce calls.
};
} // namespace collective
} // namespace xgboost

View File

@@ -1,32 +0,0 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <string>
#include "communicator.h"
namespace xgboost {
namespace collective {
/**
* A no-op communicator, used for non-distributed training.
*/
class NoOpCommunicator : public Communicator {
public:
NoOpCommunicator() : Communicator(1, 0) {}
bool IsDistributed() const override { return false; }
bool IsFederated() const override { return false; }
std::string AllGather(std::string_view) override { return {}; }
std::string AllGatherV(std::string_view) override { return {}; }
void AllReduce(void *, std::size_t, DataType, Operation) override {}
void Broadcast(void *, std::size_t, int) override {}
std::string GetProcessorName() override { return {}; }
void Print(const std::string &message) override { LOG(CONSOLE) << message; }
protected:
void Shutdown() override {}
};
} // namespace collective
} // namespace xgboost

View File

@@ -41,20 +41,26 @@ struct Magic {
[[nodiscard]] Result Verify(xgboost::collective::TCPSocket* p_sock) {
std::int32_t magic{kMagic};
auto n_bytes = p_sock->SendAll(&magic, sizeof(magic));
if (n_bytes != sizeof(magic)) {
return Fail("Failed to verify.");
}
magic = 0;
n_bytes = p_sock->RecvAll(&magic, sizeof(magic));
if (n_bytes != sizeof(magic)) {
return Fail("Failed to verify.");
}
if (magic != kMagic) {
return xgboost::collective::Fail("Invalid verification number.");
}
return Success();
std::size_t n_sent{0};
return Success() << [&] {
return p_sock->SendAll(&magic, sizeof(magic), &n_sent);
} << [&] {
if (n_sent != sizeof(magic)) {
return Fail("Failed to verify.");
}
return Success();
} << [&] {
magic = 0;
return p_sock->RecvAll(&magic, sizeof(magic), &n_sent);
} << [&] {
if (n_sent != sizeof(magic)) {
return Fail("Failed to verify.");
}
if (magic != kMagic) {
return xgboost::collective::Fail("Invalid verification number.");
}
return Success();
};
}
};
@@ -227,31 +233,43 @@ struct Error {
[[nodiscard]] Result SignalError(TCPSocket* worker) const {
std::int32_t err{ErrorSignal()};
auto n_sent = worker->SendAll(&err, sizeof(err));
if (n_sent == sizeof(err)) {
return Success();
}
return Fail("Failed to send error signal");
std::size_t n_sent{0};
return Success() << [&] {
return worker->SendAll(&err, sizeof(err), &n_sent);
} << [&] {
if (n_sent == sizeof(err)) {
return Success();
}
return Fail("Failed to send error signal");
};
}
// self is localhost, we are sending the signal to the error handling thread for it to
// close.
[[nodiscard]] Result SignalShutdown(TCPSocket* self) const {
std::int32_t err{ShutdownSignal()};
auto n_sent = self->SendAll(&err, sizeof(err));
if (n_sent == sizeof(err)) {
return Success();
}
return Fail("Failed to send shutdown signal");
std::size_t n_sent{0};
return Success() << [&] {
return self->SendAll(&err, sizeof(err), &n_sent);
} << [&] {
if (n_sent == sizeof(err)) {
return Success();
}
return Fail("Failed to send shutdown signal");
};
}
// get signal, either for error or for shutdown.
[[nodiscard]] Result RecvSignal(TCPSocket* peer, bool* p_is_error) const {
std::int32_t err{ShutdownSignal()};
auto n_recv = peer->RecvAll(&err, sizeof(err));
if (n_recv == sizeof(err)) {
*p_is_error = err == 1;
return Success();
}
return Fail("Failed to receive error signal.");
std::size_t n_recv{0};
return Success() << [&] {
return peer->RecvAll(&err, sizeof(err), &n_recv);
} << [&] {
if (n_recv == sizeof(err)) {
*p_is_error = err == 1;
return Success();
}
return Fail("Failed to receive error signal.");
};
}
};
} // namespace xgboost::collective::proto

View File

@@ -1,175 +0,0 @@
/**
* Copyright 2022-2023 by XGBoost contributors
*/
#pragma once
#include <rabit/rabit.h>
#include <string>
#include <vector>
#include "communicator-inl.h"
#include "communicator.h"
#include "xgboost/json.h"
namespace xgboost {
namespace collective {
class RabitCommunicator : public Communicator {
public:
static Communicator *Create(Json const &config) {
std::vector<std::string> args_str;
for (auto &items : get<Object const>(config)) {
switch (items.second.GetValue().Type()) {
case xgboost::Value::ValueKind::kString: {
args_str.push_back(items.first + "=" + get<String const>(items.second));
break;
}
case xgboost::Value::ValueKind::kInteger: {
args_str.push_back(items.first + "=" + std::to_string(get<Integer const>(items.second)));
break;
}
case xgboost::Value::ValueKind::kBoolean: {
if (get<Boolean const>(items.second)) {
args_str.push_back(items.first + "=1");
} else {
args_str.push_back(items.first + "=0");
}
break;
}
default:
break;
}
}
std::vector<char *> args;
for (auto &key_value : args_str) {
args.push_back(&key_value[0]);
}
if (!rabit::Init(static_cast<int>(args.size()), &args[0])) {
LOG(FATAL) << "Failed to initialize Rabit";
}
return new RabitCommunicator(rabit::GetWorldSize(), rabit::GetRank());
}
RabitCommunicator(int world_size, int rank) : Communicator(world_size, rank) {}
bool IsDistributed() const override { return rabit::IsDistributed(); }
bool IsFederated() const override { return false; }
std::string AllGather(std::string_view input) override {
auto const per_rank = input.size();
auto const total_size = per_rank * GetWorldSize();
auto const index = per_rank * GetRank();
std::string result(total_size, '\0');
result.replace(index, per_rank, input);
rabit::Allgather(result.data(), total_size, index, per_rank, per_rank);
return result;
}
std::string AllGatherV(std::string_view input) override {
auto const size_node_slice = input.size();
auto const all_sizes = collective::Allgather(size_node_slice);
auto const total_size = std::accumulate(all_sizes.cbegin(), all_sizes.cend(), 0ul);
auto const begin_index =
std::accumulate(all_sizes.cbegin(), all_sizes.cbegin() + GetRank(), 0ul);
auto const size_prev_slice =
GetRank() == 0 ? all_sizes[GetWorldSize() - 1] : all_sizes[GetRank() - 1];
std::string result(total_size, '\0');
result.replace(begin_index, size_node_slice, input);
rabit::Allgather(result.data(), total_size, begin_index, size_node_slice, size_prev_slice);
return result;
}
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override {
switch (data_type) {
case DataType::kInt8:
DoAllReduce<char>(send_receive_buffer, count, op);
break;
case DataType::kUInt8:
DoAllReduce<unsigned char>(send_receive_buffer, count, op);
break;
case DataType::kInt32:
DoAllReduce<std::int32_t>(send_receive_buffer, count, op);
break;
case DataType::kUInt32:
DoAllReduce<std::uint32_t>(send_receive_buffer, count, op);
break;
case DataType::kInt64:
DoAllReduce<std::int64_t>(send_receive_buffer, count, op);
break;
case DataType::kUInt64:
DoAllReduce<std::uint64_t>(send_receive_buffer, count, op);
break;
case DataType::kFloat:
DoAllReduce<float>(send_receive_buffer, count, op);
break;
case DataType::kDouble:
DoAllReduce<double>(send_receive_buffer, count, op);
break;
default:
LOG(FATAL) << "Unknown data type";
}
}
void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {
rabit::Broadcast(send_receive_buffer, size, root);
}
std::string GetProcessorName() override { return rabit::GetProcessorName(); }
void Print(const std::string &message) override { rabit::TrackerPrint(message); }
protected:
void Shutdown() override { rabit::Finalize(); }
private:
template <typename DType, std::enable_if_t<std::is_integral<DType>::value> * = nullptr>
void DoBitwiseAllReduce(void *send_receive_buffer, std::size_t count, Operation op) {
switch (op) {
case Operation::kBitwiseAND:
rabit::Allreduce<rabit::op::BitAND, DType>(static_cast<DType *>(send_receive_buffer),
count);
break;
case Operation::kBitwiseOR:
rabit::Allreduce<rabit::op::BitOR, DType>(static_cast<DType *>(send_receive_buffer), count);
break;
case Operation::kBitwiseXOR:
rabit::Allreduce<rabit::op::BitXOR, DType>(static_cast<DType *>(send_receive_buffer),
count);
break;
default:
LOG(FATAL) << "Unknown allreduce operation";
}
}
template <typename DType, std::enable_if_t<std::is_floating_point<DType>::value> * = nullptr>
void DoBitwiseAllReduce(void *, std::size_t, Operation) {
LOG(FATAL) << "Floating point types do not support bitwise operations.";
}
template <typename DType>
void DoAllReduce(void *send_receive_buffer, std::size_t count, Operation op) {
switch (op) {
case Operation::kMax:
rabit::Allreduce<rabit::op::Max, DType>(static_cast<DType *>(send_receive_buffer), count);
break;
case Operation::kMin:
rabit::Allreduce<rabit::op::Min, DType>(static_cast<DType *>(send_receive_buffer), count);
break;
case Operation::kSum:
rabit::Allreduce<rabit::op::Sum, DType>(static_cast<DType *>(send_receive_buffer), count);
break;
case Operation::kBitwiseAND:
case Operation::kBitwiseOR:
case Operation::kBitwiseXOR:
DoBitwiseAllReduce<DType>(send_receive_buffer, count, op);
break;
default:
LOG(FATAL) << "Unknown allreduce operation";
}
}
};
} // namespace collective
} // namespace xgboost

View File

@@ -62,20 +62,15 @@ void ResultImpl::Concat(std::unique_ptr<ResultImpl> rhs) {
ptr->prev = std::move(rhs);
}
#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__)
std::string MakeMsg(std::string&& msg, char const*, std::int32_t) {
return std::forward<std::string>(msg);
}
#else
std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line) {
auto name = std::filesystem::path{file}.filename();
dmlc::DateLogger logger;
if (file && line != -1) {
return "[" + name.string() + ":" + std::to_string(line) + // NOLINT
auto name = std::filesystem::path{ file }.filename();
return "[" + name.string() + ":" + std::to_string(line) + "|" + logger.HumanDate() +
"]: " + std::forward<std::string>(msg);
}
return std::forward<std::string>(msg);
return std::string{"["} + logger.HumanDate() + "]" + std::forward<std::string>(msg); // NOLINT
}
#endif
} // namespace detail
void SafeColl(Result const& rc) {

View File

@@ -60,24 +60,46 @@ std::size_t TCPSocket::Send(StringView str) {
CHECK(!this->IsClosed());
CHECK_LT(str.size(), std::numeric_limits<std::int32_t>::max());
std::int32_t len = static_cast<std::int32_t>(str.size());
CHECK_EQ(this->SendAll(&len, sizeof(len)), sizeof(len)) << "Failed to send string length.";
auto bytes = this->SendAll(str.c_str(), str.size());
CHECK_EQ(bytes, str.size()) << "Failed to send string.";
return bytes;
std::size_t n_bytes{0};
auto rc = Success() << [&] {
return this->SendAll(&len, sizeof(len), &n_bytes);
} << [&] {
if (n_bytes != sizeof(len)) {
return Fail("Failed to send string length.");
}
return Success();
} << [&] {
return this->SendAll(str.c_str(), str.size(), &n_bytes);
} << [&] {
if (n_bytes != str.size()) {
return Fail("Failed to send string.");
}
return Success();
};
SafeColl(rc);
return n_bytes;
}
[[nodiscard]] Result TCPSocket::Recv(std::string *p_str) {
CHECK(!this->IsClosed());
std::int32_t len;
if (this->RecvAll(&len, sizeof(len)) != sizeof(len)) {
return Fail("Failed to recv string length.");
}
p_str->resize(len);
auto bytes = this->RecvAll(&(*p_str)[0], len);
if (static_cast<decltype(len)>(bytes) != len) {
return Fail("Failed to recv string.");
}
return Success();
std::size_t n_bytes{0};
return Success() << [&] {
return this->RecvAll(&len, sizeof(len), &n_bytes);
} << [&] {
if (n_bytes != sizeof(len)) {
return Fail("Failed to recv string length.");
}
return Success();
} << [&] {
p_str->resize(len);
return this->RecvAll(&(*p_str)[0], len, &n_bytes);
} << [&] {
if (static_cast<std::remove_reference_t<decltype(len)>>(n_bytes) != len) {
return Fail("Failed to recv string.");
}
return Success();
};
}
[[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry,

View File

@@ -31,14 +31,20 @@
#include "xgboost/json.h" // for Json
namespace xgboost::collective {
Tracker::Tracker(Json const& config)
: sortby_{static_cast<SortBy>(
OptionalArg<Integer const>(config, "sortby", static_cast<Integer::Int>(SortBy::kHost)))},
n_workers_{
static_cast<std::int32_t>(RequiredArg<Integer const>(config, "n_workers", __func__))},
port_{static_cast<std::int32_t>(OptionalArg<Integer const>(config, "port", Integer::Int{0}))},
timeout_{std::chrono::seconds{OptionalArg<Integer const>(
config, "timeout", static_cast<std::int64_t>(collective::DefaultTimeoutSec()))}} {}
timeout_{std::chrono::seconds{
OptionalArg<Integer const>(config, "timeout", static_cast<std::int64_t>(0))}} {
using std::chrono_literals::operator""s;
// Some old configurations in JVM for the scala implementation (removed) use 0 to
// indicate blocking. We continue that convention here.
timeout_ = (timeout_ == 0s) ? -1s : timeout_;
}
Result Tracker::WaitUntilReady() const {
using namespace std::chrono_literals; // NOLINT
@@ -49,7 +55,7 @@ Result Tracker::WaitUntilReady() const {
timer.Start();
while (!this->Ready()) {
auto ela = timer.Duration().count();
if (ela > this->Timeout().count()) {
if (HasTimeout(this->Timeout()) && ela > this->Timeout().count()) {
return Fail("Failed to start tracker, timeout:" + std::to_string(this->Timeout().count()) +
" seconds.");
}
@@ -250,8 +256,10 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
std::lock_guard lock{listener_mu_};
return listener_.NonBlocking(true);
} << [&] {
std::lock_guard lock{listener_mu_};
poll.WatchRead(listener_);
{
std::lock_guard lock{listener_mu_};
poll.WatchRead(listener_);
}
if (state.running) {
// Don't timeout if the communicator group is up and running.
return poll.Poll(std::chrono::seconds{-1});

View File

@@ -15,6 +15,7 @@
#include "xgboost/json.h" // for Json
namespace xgboost::collective {
inline bool HasTimeout(std::chrono::seconds timeout) { return timeout.count() > 0; }
/**
*
* @brief Implementation of RABIT tracker.
@@ -52,7 +53,7 @@ class Tracker {
protected:
std::int32_t n_workers_{0};
std::int32_t port_{-1};
std::chrono::seconds timeout_{0};
std::chrono::seconds timeout_{-1};
std::atomic<bool> ready_{false};
public: