Revamp the rabit implementation. (#10112)
This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features: - Federated learning for both CPU and GPU. - NCCL. - More data types. - A unified interface for all the underlying implementations. - Improved timeout handling for both tracker and workers. - Exhaustive tests with metrics (fixed a couple of bugs along the way). - A reusable tracker for Python and JVM packages.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023 by XGBoost contributors
|
||||
* Copyright 2023-2024, XGBoost contributors
|
||||
*
|
||||
* Higher level functions built on top of the Communicator API, taking care of behavioral differences
|
||||
* between row-split vs column-split distributed training, and horizontal vs vertical federated
|
||||
@@ -13,7 +13,8 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "communicator-inl.cuh"
|
||||
#include "allreduce.h"
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
|
||||
namespace xgboost::collective {
|
||||
|
||||
@@ -24,15 +25,17 @@ namespace xgboost::collective {
|
||||
* column-wise (vertically), the original values are returned.
|
||||
*
|
||||
* @tparam T The type of the values.
|
||||
*
|
||||
* @param info MetaInfo about the DMatrix.
|
||||
* @param device The device id.
|
||||
* @param values Pointer to the inputs to sum.
|
||||
* @param size Number of values to sum.
|
||||
*/
|
||||
template <typename T>
|
||||
void GlobalSum(MetaInfo const& info, DeviceOrd device, T* values, size_t size) {
|
||||
template <typename T, std::int32_t kDim>
|
||||
[[nodiscard]] Result GlobalSum(Context const* ctx, MetaInfo const& info,
|
||||
linalg::TensorView<T, kDim> values) {
|
||||
if (info.IsRowSplit()) {
|
||||
collective::AllReduce<collective::Operation::kSum>(device.ordinal, values, size);
|
||||
return collective::Allreduce(ctx, values, collective::Op::kSum);
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -11,11 +11,44 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "allreduce.h"
|
||||
#include "broadcast.h"
|
||||
#include "comm.h"
|
||||
#include "communicator-inl.h"
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/data.h" // for MetaInfo
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace detail {
|
||||
template <typename Fn>
|
||||
[[nodiscard]] Result TryApplyWithLabels(Context const* ctx, Fn&& fn) {
|
||||
std::string msg;
|
||||
if (collective::GetRank() == 0) {
|
||||
try {
|
||||
fn();
|
||||
} catch (dmlc::Error const& e) {
|
||||
msg = e.what();
|
||||
}
|
||||
}
|
||||
std::size_t msg_size{msg.size()};
|
||||
auto rc = Success() << [&] {
|
||||
auto rc = collective::Broadcast(ctx, linalg::MakeVec(&msg_size, 1), 0);
|
||||
return rc;
|
||||
} << [&] {
|
||||
if (msg_size > 0) {
|
||||
msg.resize(msg_size);
|
||||
return collective::Broadcast(ctx, linalg::MakeVec(msg.data(), msg.size()), 0);
|
||||
}
|
||||
return Success();
|
||||
} << [&] {
|
||||
if (msg_size > 0) {
|
||||
LOG(FATAL) << msg;
|
||||
}
|
||||
return Success();
|
||||
};
|
||||
return rc;
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
/**
|
||||
* @brief Apply the given function where the labels are.
|
||||
@@ -30,29 +63,19 @@ namespace xgboost::collective {
|
||||
* @param size The size of the buffer.
|
||||
* @param function The function used to calculate the results.
|
||||
*/
|
||||
template <typename FN>
|
||||
void ApplyWithLabels(Context const*, MetaInfo const& info, void* buffer, std::size_t size,
|
||||
FN&& function) {
|
||||
template <typename Fn>
|
||||
void ApplyWithLabels(Context const* ctx, MetaInfo const& info, void* buffer, std::size_t size,
|
||||
Fn&& fn) {
|
||||
if (info.IsVerticalFederated()) {
|
||||
// We assume labels are only available on worker 0, so the calculation is done there and result
|
||||
// broadcast to other workers.
|
||||
std::string message;
|
||||
if (collective::GetRank() == 0) {
|
||||
try {
|
||||
std::forward<FN>(function)();
|
||||
} catch (dmlc::Error& e) {
|
||||
message = e.what();
|
||||
}
|
||||
}
|
||||
|
||||
collective::Broadcast(&message, 0);
|
||||
if (message.empty()) {
|
||||
collective::Broadcast(buffer, size, 0);
|
||||
} else {
|
||||
LOG(FATAL) << &message[0];
|
||||
}
|
||||
auto rc = detail::TryApplyWithLabels(ctx, fn) << [&] {
|
||||
// We assume labels are only available on worker 0, so the calculation is done there and
|
||||
// result broadcast to other workers.
|
||||
return collective::Broadcast(
|
||||
ctx, linalg::MakeVec(reinterpret_cast<std::int8_t*>(buffer), size), 0);
|
||||
};
|
||||
SafeColl(rc);
|
||||
} else {
|
||||
std::forward<FN>(function)();
|
||||
std::forward<Fn>(fn)();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,37 +92,24 @@ void ApplyWithLabels(Context const*, MetaInfo const& info, void* buffer, std::si
|
||||
* @param result The HostDeviceVector storing the results.
|
||||
* @param function The function used to calculate the results.
|
||||
*/
|
||||
template <typename T, typename Function>
|
||||
void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector<T>* result,
|
||||
Function&& function) {
|
||||
template <typename T, typename Fn>
|
||||
void ApplyWithLabels(Context const* ctx, MetaInfo const& info, HostDeviceVector<T>* result,
|
||||
Fn&& fn) {
|
||||
if (info.IsVerticalFederated()) {
|
||||
// We assume labels are only available on worker 0, so the calculation is done there and result
|
||||
// broadcast to other workers.
|
||||
std::string message;
|
||||
if (collective::GetRank() == 0) {
|
||||
try {
|
||||
std::forward<Function>(function)();
|
||||
} catch (dmlc::Error& e) {
|
||||
message = e.what();
|
||||
}
|
||||
}
|
||||
auto rc = detail::TryApplyWithLabels(ctx, fn);
|
||||
|
||||
collective::Broadcast(&message, 0);
|
||||
if (!message.empty()) {
|
||||
LOG(FATAL) << &message[0];
|
||||
return;
|
||||
}
|
||||
|
||||
std::size_t size{};
|
||||
if (collective::GetRank() == 0) {
|
||||
size = result->Size();
|
||||
}
|
||||
collective::Broadcast(&size, sizeof(std::size_t), 0);
|
||||
|
||||
result->Resize(size);
|
||||
collective::Broadcast(result->HostPointer(), size * sizeof(T), 0);
|
||||
std::size_t size{result->Size()};
|
||||
rc = std::move(rc) << [&] {
|
||||
return collective::Broadcast(ctx, linalg::MakeVec(&size, 1), 0);
|
||||
} << [&] {
|
||||
result->Resize(size);
|
||||
return collective::Broadcast(ctx, linalg::MakeVec(result->HostPointer(), size), 0);
|
||||
};
|
||||
SafeColl(rc);
|
||||
} else {
|
||||
std::forward<Function>(function)();
|
||||
std::forward<Fn>(fn)();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -115,11 +125,12 @@ void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector<T>*
|
||||
* @return The global max of the input.
|
||||
*/
|
||||
template <typename T>
|
||||
std::enable_if_t<std::is_trivially_copy_assignable_v<T>, T> GlobalMax(Context const*,
|
||||
std::enable_if_t<std::is_trivially_copy_assignable_v<T>, T> GlobalMax(Context const* ctx,
|
||||
MetaInfo const& info,
|
||||
T value) {
|
||||
if (info.IsRowSplit()) {
|
||||
collective::Allreduce<collective::Operation::kMax>(&value, 1);
|
||||
auto rc = collective::Allreduce(ctx, linalg::MakeVec(&value, 1), collective::Op::kMax);
|
||||
SafeColl(rc);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
@@ -136,19 +147,14 @@ std::enable_if_t<std::is_trivially_copy_assignable_v<T>, T> GlobalMax(Context co
|
||||
* @param size Number of values to sum.
|
||||
*/
|
||||
template <typename T, std::int32_t kDim>
|
||||
[[nodiscard]] Result GlobalSum(Context const*, MetaInfo const& info,
|
||||
[[nodiscard]] Result GlobalSum(Context const* ctx, MetaInfo const& info,
|
||||
linalg::TensorView<T, kDim> values) {
|
||||
if (info.IsRowSplit()) {
|
||||
collective::Allreduce<collective::Operation::kSum>(values.Values().data(), values.Size());
|
||||
return collective::Allreduce(ctx, values, collective::Op::kSum);
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
|
||||
template <typename Container>
|
||||
[[nodiscard]] Result GlobalSum(Context const* ctx, MetaInfo const& info, Container* values) {
|
||||
return GlobalSum(ctx, info, values->data(), values->size());
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Find the global ratio of the given two values across all workers.
|
||||
*
|
||||
|
||||
@@ -47,7 +47,7 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
|
||||
return comm.Block();
|
||||
};
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
return Fail("Ring allgather failed, current iteration:" + std::to_string(r), std::move(rc));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,7 +61,8 @@ Result BroadcastAllgatherV(Comm const& comm, common::Span<std::int64_t const> si
|
||||
auto as_bytes = sizes[r];
|
||||
auto rc = Broadcast(comm, recv.subspan(offset, as_bytes), r);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
return Fail("Broadcast AllgatherV failed, current iteration:" + std::to_string(r),
|
||||
std::move(rc));
|
||||
}
|
||||
offset += as_bytes;
|
||||
}
|
||||
@@ -102,7 +103,7 @@ namespace detail {
|
||||
return prev_ch->Block();
|
||||
};
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
return Fail("Ring AllgatherV failed, current iterataion:" + std::to_string(r), std::move(rc));
|
||||
}
|
||||
}
|
||||
return comm.Block();
|
||||
|
||||
@@ -36,7 +36,7 @@ Result RingAllreduceSmall(Comm const& comm, common::Span<std::int8_t> data, Func
|
||||
auto rc = RingAllgather(comm, typed);
|
||||
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
return Fail("Ring allreduce small failed.", std::move(rc));
|
||||
}
|
||||
auto first = s_buffer.subspan(0, data.size_bytes());
|
||||
CHECK_EQ(first.size(), data.size());
|
||||
@@ -64,7 +64,7 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
|
||||
auto next_ch = comm.Chan(dst_rank);
|
||||
auto prev_ch = comm.Chan(src_rank);
|
||||
|
||||
std::vector<std::int8_t> buffer(data.size_bytes() - (world - 1) * n_bytes_in_seg, 0);
|
||||
std::vector<std::int8_t> buffer(data.size_bytes() - (world - 1) * n_bytes_in_seg, -1);
|
||||
auto s_buf = common::Span{buffer.data(), buffer.size()};
|
||||
|
||||
for (std::int32_t r = 0; r < world - 1; ++r) {
|
||||
@@ -97,6 +97,10 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
|
||||
} << [&] {
|
||||
return comm.Block();
|
||||
};
|
||||
if (!rc.OK()) {
|
||||
return Fail("Ring scatter reduce failed, current iteration:" + std::to_string(r),
|
||||
std::move(rc));
|
||||
}
|
||||
|
||||
// accumulate to recv_seg
|
||||
CHECK_EQ(seg.size(), recv_seg.size());
|
||||
@@ -128,7 +132,7 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func cons
|
||||
auto n_bytes_in_seg = (n / world) * sizeof(T);
|
||||
auto rc = RingScatterReduceTyped<T>(comm, data, n_bytes_in_seg, op);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
return Fail("Ring Allreduce failed.", std::move(rc));
|
||||
}
|
||||
|
||||
auto prev = BootstrapPrev(comm.Rank(), comm.World());
|
||||
|
||||
@@ -150,9 +150,12 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
|
||||
}
|
||||
|
||||
auto rank = comm.Rank();
|
||||
auto n_bytes = worker->SendAll(&rank, sizeof(comm.Rank()));
|
||||
if (n_bytes != sizeof(comm.Rank())) {
|
||||
return Fail("Failed to send rank.");
|
||||
std::size_t n_bytes{0};
|
||||
auto rc = worker->SendAll(&rank, sizeof(comm.Rank()), &n_bytes);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
} else if (n_bytes != sizeof(comm.Rank())) {
|
||||
return Fail("Failed to send rank.", std::move(rc));
|
||||
}
|
||||
workers[r] = std::move(worker);
|
||||
}
|
||||
@@ -169,8 +172,11 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
|
||||
return rc;
|
||||
}
|
||||
std::int32_t rank{-1};
|
||||
auto n_bytes = peer->RecvAll(&rank, sizeof(rank));
|
||||
if (n_bytes != sizeof(comm.Rank())) {
|
||||
std::size_t n_bytes{0};
|
||||
auto rc = peer->RecvAll(&rank, sizeof(rank), &n_bytes);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
} else if (n_bytes != sizeof(comm.Rank())) {
|
||||
return Fail("Failed to recv rank.");
|
||||
}
|
||||
workers[rank] = std::move(peer);
|
||||
|
||||
@@ -94,7 +94,7 @@ class Comm : public std::enable_shared_from_this<Comm> {
|
||||
[[nodiscard]] bool IsDistributed() const noexcept { return world_ != -1; }
|
||||
void Submit(Loop::Op op) const {
|
||||
CHECK(loop_);
|
||||
loop_->Submit(op);
|
||||
loop_->Submit(std::move(op));
|
||||
}
|
||||
[[nodiscard]] virtual Result Block() const { return loop_->Block(); }
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ CommGroup::CommGroup()
|
||||
// Common args
|
||||
auto retry = get_param("dmlc_retry", static_cast<Integer::Int>(DefaultRetry()), Integer{});
|
||||
auto timeout =
|
||||
get_param("dmlc_timeout_sec", static_cast<Integer::Int>(DefaultTimeoutSec()), Integer{});
|
||||
get_param("dmlc_timeout", static_cast<Integer::Int>(DefaultTimeoutSec()), Integer{});
|
||||
auto task_id = get_param("dmlc_task_id", std::string{}, String{});
|
||||
|
||||
if (type == "rabit") {
|
||||
@@ -123,4 +123,30 @@ void GlobalCommGroupFinalize() {
|
||||
sptr.reset();
|
||||
SafeColl(rc);
|
||||
}
|
||||
|
||||
void Init(Json const& config) { GlobalCommGroupInit(config); }
|
||||
|
||||
void Finalize() { GlobalCommGroupFinalize(); }
|
||||
|
||||
std::int32_t GetRank() noexcept { return GlobalCommGroup()->Rank(); }
|
||||
|
||||
std::int32_t GetWorldSize() noexcept { return GlobalCommGroup()->World(); }
|
||||
|
||||
bool IsDistributed() noexcept { return GlobalCommGroup()->IsDistributed(); }
|
||||
|
||||
[[nodiscard]] bool IsFederated() {
|
||||
return GlobalCommGroup()->Ctx(nullptr, DeviceOrd::CPU()).IsFederated();
|
||||
}
|
||||
|
||||
void Print(std::string const& message) {
|
||||
auto rc = GlobalCommGroup()->Ctx(nullptr, DeviceOrd::CPU()).LogTracker(message);
|
||||
SafeColl(rc);
|
||||
}
|
||||
|
||||
std::string GetProcessorName() {
|
||||
std::string out;
|
||||
auto rc = GlobalCommGroup()->ProcessorName(&out);
|
||||
SafeColl(rc);
|
||||
return out;
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
/**
|
||||
* Copyright 2024, XGBoost contributors
|
||||
*/
|
||||
#include "communicator-inl.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
|
||||
std::vector<std::vector<char>> const &input) {
|
||||
auto n_inputs = input.size();
|
||||
std::vector<std::int64_t> sizes(n_inputs);
|
||||
std::transform(input.cbegin(), input.cend(), sizes.begin(),
|
||||
[](auto const &vec) { return vec.size(); });
|
||||
|
||||
std::vector<std::int64_t> global_sizes = AllgatherV(sizes);
|
||||
std::vector<std::int64_t> offset(global_sizes.size() + 1);
|
||||
offset[0] = 0;
|
||||
for (std::size_t i = 1; i < offset.size(); i++) {
|
||||
offset[i] = offset[i - 1] + global_sizes[i - 1];
|
||||
}
|
||||
|
||||
std::vector<char> collected;
|
||||
for (auto const &vec : input) {
|
||||
collected.insert(collected.end(), vec.cbegin(), vec.cend());
|
||||
}
|
||||
auto out = AllgatherV(collected);
|
||||
|
||||
std::vector<std::vector<char>> result;
|
||||
for (std::size_t i = 1; i < offset.size(); ++i) {
|
||||
std::vector<char> local(out.cbegin() + offset[i - 1], out.cbegin() + offset[i]);
|
||||
result.emplace_back(std::move(local));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
@@ -1,95 +0,0 @@
|
||||
/**
|
||||
* Copyright 2023 by XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "communicator.h"
|
||||
#include "device_communicator.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
/**
|
||||
* @brief Reduce values from all processes and distribute the result back to all processes.
|
||||
* @param device ID of the device.
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param count Number of elements in the buffer.
|
||||
*/
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, std::int8_t *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt8, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, std::uint8_t *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt8, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, std::int32_t *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt32, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, std::uint32_t *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt32, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, std::int64_t *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt64, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, std::uint64_t *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, float *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kFloat, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void AllReduce(int device, double *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gather values from all processes.
|
||||
*
|
||||
* This assumes all ranks have the same size.
|
||||
*
|
||||
* @param send_buffer Buffer storing the data to be sent.
|
||||
* @param receive_buffer Buffer storing the gathered data.
|
||||
* @param send_size Size of the sent data in bytes.
|
||||
*/
|
||||
inline void AllGather(int device, void const *send_buffer, void *receive_buffer,
|
||||
std::size_t send_size) {
|
||||
Communicator::GetDevice(device)->AllGather(send_buffer, receive_buffer, send_size);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gather variable-length values from all processes.
|
||||
* @param device ID of the device.
|
||||
* @param send_buffer Buffer storing the input data.
|
||||
* @param length_bytes Length in bytes of the input data.
|
||||
* @param segments Size of each segment.
|
||||
* @param receive_buffer Buffer storing the output data.
|
||||
*/
|
||||
inline void AllGatherV(int device, void const *send_buffer, size_t length_bytes,
|
||||
std::vector<size_t> *segments,
|
||||
dh::caching_device_vector<char> *receive_buffer) {
|
||||
Communicator::GetDevice(device)->AllGatherV(send_buffer, length_bytes, segments, receive_buffer);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Synchronize device operations.
|
||||
* @param device ID of the device.
|
||||
*/
|
||||
inline void Synchronize(int device) { Communicator::GetDevice(device)->Synchronize(); }
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@@ -3,308 +3,63 @@
|
||||
*/
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "communicator.h"
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
namespace xgboost::collective {
|
||||
/**
|
||||
* @brief Initialize the collective communicator.
|
||||
*/
|
||||
void Init(Json const& config);
|
||||
|
||||
/**
|
||||
* \brief Initialize the collective communicator.
|
||||
*
|
||||
* Currently the communicator API is experimental, function signatures may change in the future
|
||||
* without notice.
|
||||
*
|
||||
* Call this once before using anything.
|
||||
*
|
||||
* The additional configuration is not required. Usually the communicator will detect settings
|
||||
* from environment variables.
|
||||
*
|
||||
* \param json_config JSON encoded configuration. Accepted JSON keys are:
|
||||
* - xgboost_communicator: The type of the communicator. Can be set as an environment variable.
|
||||
* * rabit: Use Rabit. This is the default if the type is unspecified.
|
||||
* * mpi: Use MPI.
|
||||
* * federated: Use the gRPC interface for Federated Learning.
|
||||
* Only applicable to the Rabit communicator (these are case-sensitive):
|
||||
* - rabit_tracker_uri: Hostname of the tracker.
|
||||
* - rabit_tracker_port: Port number of the tracker.
|
||||
* - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
|
||||
* - rabit_world_size: Total number of workers.
|
||||
* - rabit_hadoop_mode: Enable Hadoop support.
|
||||
* - rabit_tree_reduce_minsize: Minimal size for tree reduce.
|
||||
* - rabit_reduce_ring_mincount: Minimal count to perform ring reduce.
|
||||
* - rabit_reduce_buffer: Size of the reduce buffer.
|
||||
* - rabit_bootstrap_cache: Size of the bootstrap cache.
|
||||
* - rabit_debug: Enable debugging.
|
||||
* - rabit_timeout: Enable timeout.
|
||||
* - rabit_timeout_sec: Timeout in seconds.
|
||||
* - rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms.
|
||||
* Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
|
||||
* environment variables):
|
||||
* - DMLC_TRACKER_URI: Hostname of the tracker.
|
||||
* - DMLC_TRACKER_PORT: Port number of the tracker.
|
||||
* - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment.
|
||||
* - DMLC_ROLE: Role of the current task, "worker" or "server".
|
||||
* - DMLC_NUM_ATTEMPT: Number of attempts after task failure.
|
||||
* - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
|
||||
* Only applicable to the Federated communicator (use upper case for environment variables, use
|
||||
* lower case for runtime configuration):
|
||||
* - federated_server_address: Address of the federated server.
|
||||
* - federated_world_size: Number of federated workers.
|
||||
* - federated_rank: Rank of the current worker.
|
||||
* - federated_server_cert: Server certificate file path. Only needed for the SSL mode.
|
||||
* - federated_client_key: Client key file path. Only needed for the SSL mode.
|
||||
* - federated_client_cert: Client certificate file path. Only needed for the SSL mode.
|
||||
*/
|
||||
inline void Init(Json const &config) { Communicator::Init(config); }
|
||||
|
||||
/*!
|
||||
* \brief Finalize the collective communicator.
|
||||
* @brief Finalize the collective communicator.
|
||||
*
|
||||
* Call this function after you finished all jobs.
|
||||
*/
|
||||
inline void Finalize() { Communicator::Finalize(); }
|
||||
void Finalize();
|
||||
|
||||
/*!
|
||||
* \brief Get rank of current process.
|
||||
/**
|
||||
* @brief Get rank of current process.
|
||||
*
|
||||
* \return Rank of the worker.
|
||||
* @return Rank of the worker.
|
||||
*/
|
||||
inline int GetRank() { return Communicator::Get()->GetRank(); }
|
||||
[[nodiscard]] std::int32_t GetRank() noexcept;
|
||||
|
||||
/*!
|
||||
* \brief Get total number of processes.
|
||||
/**
|
||||
* @brief Get total number of processes.
|
||||
*
|
||||
* \return Total world size.
|
||||
* @return Total world size.
|
||||
*/
|
||||
inline int GetWorldSize() { return Communicator::Get()->GetWorldSize(); }
|
||||
[[nodiscard]] std::int32_t GetWorldSize() noexcept;
|
||||
|
||||
/*!
|
||||
* \brief Get if the communicator is distributed.
|
||||
/**
|
||||
* @brief Get if the communicator is distributed.
|
||||
*
|
||||
* \return True if the communicator is distributed.
|
||||
* @return True if the communicator is distributed.
|
||||
*/
|
||||
inline bool IsDistributed() { return Communicator::Get()->IsDistributed(); }
|
||||
[[nodiscard]] bool IsDistributed() noexcept;
|
||||
|
||||
/*!
|
||||
* \brief Get if the communicator is federated.
|
||||
/**
|
||||
* @brief Get if the communicator is federated.
|
||||
*
|
||||
* \return True if the communicator is federated.
|
||||
* @return True if the communicator is federated.
|
||||
*/
|
||||
inline bool IsFederated() { return Communicator::Get()->IsFederated(); }
|
||||
[[nodiscard]] bool IsFederated();
|
||||
|
||||
/*!
|
||||
* \brief Print the message to the communicator.
|
||||
/**
|
||||
* @brief Print the message to the communicator.
|
||||
*
|
||||
* This function can be used to communicate the information of the progress to the user who monitors
|
||||
* the communicator.
|
||||
*
|
||||
* \param message The message to be printed.
|
||||
* @param message The message to be printed.
|
||||
*/
|
||||
inline void Print(char const *message) { Communicator::Get()->Print(message); }
|
||||
|
||||
inline void Print(std::string const &message) { Communicator::Get()->Print(message); }
|
||||
|
||||
/*!
|
||||
* \brief Get the name of the processor.
|
||||
*
|
||||
* \return Name of the processor.
|
||||
*/
|
||||
inline std::string GetProcessorName() { return Communicator::Get()->GetProcessorName(); }
|
||||
|
||||
/*!
|
||||
* \brief Broadcast a memory region to all others from root. This function is NOT thread-safe.
|
||||
*
|
||||
* Example:
|
||||
* int a = 1;
|
||||
* Broadcast(&a, sizeof(a), root);
|
||||
*
|
||||
* \param send_receive_buffer Pointer to the send or receive buffer.
|
||||
* \param size Size of the data.
|
||||
* \param root The process rank to broadcast from.
|
||||
*/
|
||||
inline void Broadcast(void *send_receive_buffer, size_t size, int root) {
|
||||
Communicator::Get()->Broadcast(send_receive_buffer, size, root);
|
||||
}
|
||||
|
||||
inline void Broadcast(std::string *sendrecv_data, int root) {
|
||||
size_t size = sendrecv_data->length();
|
||||
Broadcast(&size, sizeof(size), root);
|
||||
if (sendrecv_data->length() != size) {
|
||||
sendrecv_data->resize(size);
|
||||
}
|
||||
if (size != 0) {
|
||||
Broadcast(&(*sendrecv_data)[0], size * sizeof(char), root);
|
||||
}
|
||||
}
|
||||
|
||||
void Print(std::string const& message);
|
||||
/**
|
||||
* @brief Gathers a single value from all processes and distributes the result to all processes.
|
||||
* @brief Get the name of the processor.
|
||||
*
|
||||
* @param input The single value.
|
||||
* @return Name of the processor.
|
||||
*/
|
||||
template <typename T>
|
||||
inline std::vector<T> Allgather(T const &input) {
|
||||
std::string_view str_input{reinterpret_cast<char const *>(&input), sizeof(T)};
|
||||
auto const output = Communicator::Get()->AllGather(str_input);
|
||||
CHECK_EQ(output.size() % sizeof(T), 0);
|
||||
std::vector<T> result(output.size() / sizeof(T));
|
||||
std::memcpy(reinterpret_cast<void *>(result.data()), output.data(), output.size());
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gathers data from all processes and distributes it to all processes.
|
||||
*
|
||||
* This assumes all ranks have the same size.
|
||||
*
|
||||
* @param input Buffer storing the data.
|
||||
*/
|
||||
template <typename T>
|
||||
inline std::vector<T> Allgather(std::vector<T> const &input) {
|
||||
if (input.empty()) {
|
||||
return input;
|
||||
}
|
||||
std::string_view str_input{reinterpret_cast<char const *>(input.data()),
|
||||
input.size() * sizeof(T)};
|
||||
auto const output = Communicator::Get()->AllGather(str_input);
|
||||
CHECK_EQ(output.size() % sizeof(T), 0);
|
||||
std::vector<T> result(output.size() / sizeof(T));
|
||||
std::memcpy(reinterpret_cast<void *>(result.data()), output.data(), output.size());
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gathers variable-length data from all processes and distributes it to all processes.
|
||||
* @param input Buffer storing the data.
|
||||
*/
|
||||
template <typename T>
|
||||
inline std::vector<T> AllgatherV(std::vector<T> const &input) {
|
||||
std::string_view str_input{reinterpret_cast<char const *>(input.data()),
|
||||
input.size() * sizeof(T)};
|
||||
auto const output = Communicator::Get()->AllGatherV(str_input);
|
||||
CHECK_EQ(output.size() % sizeof(T), 0);
|
||||
std::vector<T> result(output.size() / sizeof(T));
|
||||
if (!output.empty()) {
|
||||
std::memcpy(reinterpret_cast<void *>(result.data()), output.data(), output.size());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gathers variable-length data from all processes and distributes it to all processes.
|
||||
*
|
||||
* @param input All the inputs from the local worker. The number of inputs can vary
|
||||
* across different workers. In addition, the size of each vector in
|
||||
* the input can also vary.
|
||||
*
|
||||
* @return The AllgatherV result, containing vectors from all workers.
|
||||
*/
|
||||
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
|
||||
std::vector<std::vector<char>> const &input);
|
||||
|
||||
/**
|
||||
* @brief Gathers variable-length strings from all processes and distributes them to all processes.
|
||||
* @param input Variable-length list of variable-length strings.
|
||||
*/
|
||||
inline std::vector<std::string> AllgatherStrings(std::vector<std::string> const &input) {
|
||||
std::size_t total_size{0};
|
||||
for (auto const &s : input) {
|
||||
total_size += s.length() + 1; // +1 for null-terminators
|
||||
}
|
||||
std::string flat_string;
|
||||
flat_string.reserve(total_size);
|
||||
for (auto const &s : input) {
|
||||
flat_string.append(s);
|
||||
flat_string.push_back('\0'); // Append a null-terminator after each string
|
||||
}
|
||||
|
||||
auto const output = Communicator::Get()->AllGatherV(flat_string);
|
||||
|
||||
std::vector<std::string> result;
|
||||
std::size_t start_index = 0;
|
||||
// Iterate through the output, find each null-terminated substring.
|
||||
for (std::size_t i = 0; i < output.size(); i++) {
|
||||
if (output[i] == '\0') {
|
||||
// Construct a std::string from the char* substring
|
||||
result.emplace_back(&output[start_index]);
|
||||
// Move to the next substring
|
||||
start_index = i + 1;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Perform in-place allreduce. This function is NOT thread-safe.
|
||||
*
|
||||
* Example Usage: the following code gives sum of the result
|
||||
* vector<int> data(10);
|
||||
* ...
|
||||
* Allreduce(&data[0], data.size(), DataType::kInt32, Operation::kSum);
|
||||
* ...
|
||||
* \param send_receive_buffer Buffer for both sending and receiving data.
|
||||
* \param count Number of elements to be reduced.
|
||||
* \param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h.
|
||||
* \param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
|
||||
*/
|
||||
inline void Allreduce(void *send_receive_buffer, size_t count, int data_type, int op) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, static_cast<DataType>(data_type),
|
||||
static_cast<Operation>(op));
|
||||
}
|
||||
|
||||
inline void Allreduce(void *send_receive_buffer, size_t count, DataType data_type, Operation op) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, data_type, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(int8_t *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt8, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(uint8_t *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt8, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(int32_t *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt32, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(uint32_t *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt32, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(int64_t *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt64, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(uint64_t *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
|
||||
}
|
||||
|
||||
// Specialization for size_t, which is implementation defined, so it might or might not
|
||||
// be one of uint64_t/uint32_t/unsigned long long/unsigned long.
|
||||
template <Operation op, typename T,
|
||||
typename = std::enable_if_t<std::is_same<size_t, T>{} && !std::is_same<uint64_t, T>{}> >
|
||||
inline void Allreduce(T *send_receive_buffer, size_t count) {
|
||||
static_assert(sizeof(T) == sizeof(uint64_t));
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(float *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kFloat, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(double *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
|
||||
}
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
std::string GetProcessorName();
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#include "communicator.h"
|
||||
|
||||
#include "comm.h"
|
||||
#include "in_memory_communicator.h"
|
||||
#include "noop_communicator.h"
|
||||
#include "rabit_communicator.h"
|
||||
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
#include "../../plugin/federated/federated_communicator.h"
|
||||
#endif
|
||||
|
||||
namespace xgboost::collective {
|
||||
thread_local std::unique_ptr<Communicator> Communicator::communicator_{new NoOpCommunicator()};
|
||||
thread_local CommunicatorType Communicator::type_{};
|
||||
thread_local std::string Communicator::nccl_path_{};
|
||||
|
||||
/**
 * @brief Select and construct the communicator backend described by `config`.
 *
 * A type named in the JSON configuration overrides one named in the environment; when neither
 * names a type, Rabit is used. The optional "dmlc_nccl_path" entry is stored for later use.
 *
 * @param config JSON configuration for the communicator.
 */
void Communicator::Init(Json const& config) {
  nccl_path_ = OptionalArg<String>(config, "dmlc_nccl_path", std::string{DefaultNcclName()});

  // Query the environment first, then the configuration; the configuration wins when set.
  auto const from_env = GetTypeFromEnv();
  auto const from_config = GetTypeFromConfig(config);
  auto type = (from_config == CommunicatorType::kUnknown) ? from_env : from_config;
  if (type == CommunicatorType::kUnknown) {
    // Default to Rabit if unspecified.
    type = CommunicatorType::kRabit;
  }
  type_ = type;

  switch (type) {
    case CommunicatorType::kRabit: {
      communicator_.reset(RabitCommunicator::Create(config));
      break;
    }
    case CommunicatorType::kFederated: {
#if defined(XGBOOST_USE_FEDERATED)
      communicator_.reset(FederatedCommunicator::Create(config));
#else
      LOG(FATAL) << "XGBoost is not compiled with Federated Learning support.";
#endif
      break;
    }
    // Both in-memory flavours share the same host-side communicator.
    case CommunicatorType::kInMemory:
    case CommunicatorType::kInMemoryNccl: {
      communicator_.reset(InMemoryCommunicator::Create(config));
      break;
    }
    case CommunicatorType::kUnknown:
      LOG(FATAL) << "Unknown communicator type.";
  }
}
|
||||
|
||||
#ifndef XGBOOST_USE_CUDA
|
||||
void Communicator::Finalize() {
|
||||
communicator_->Shutdown();
|
||||
communicator_.reset(new NoOpCommunicator());
|
||||
}
|
||||
#endif
|
||||
} // namespace xgboost::collective
|
||||
@@ -1,54 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#include "communicator.h"
|
||||
#include "device_communicator.cuh"
|
||||
#include "device_communicator_adapter.cuh"
|
||||
#include "noop_communicator.h"
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
#include "nccl_device_communicator.cuh"
|
||||
#endif
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
thread_local std::unique_ptr<DeviceCommunicator> Communicator::device_communicator_{};
|
||||
|
||||
void Communicator::Finalize() {
|
||||
communicator_->Shutdown();
|
||||
communicator_.reset(new NoOpCommunicator());
|
||||
device_communicator_.reset(nullptr);
|
||||
}
|
||||
|
||||
/**
 * @brief Return the device communicator for `device_ordinal`, rebuilding it when needed.
 *
 * A new device communicator is created when none exists yet, when the requested device differs
 * from the one used for the previous construction, or when the world size has changed.
 *
 * @param device_ordinal ID of the device.
 * @return Non-owning pointer to the device communicator.
 */
DeviceCommunicator* Communicator::GetDevice(int device_ordinal) {
  thread_local auto cached_ordinal = -1;
  // If the number of GPUs changes, we need to re-initialize NCCL.
  thread_local auto cached_world_size = -1;

  auto const world_size = communicator_->GetWorldSize();
  bool const reusable = device_communicator_ && device_ordinal == cached_ordinal &&
                        world_size == cached_world_size;
  if (reusable) {
    return device_communicator_.get();
  }

  cached_ordinal = device_ordinal;
  cached_world_size = world_size;
#ifdef XGBOOST_USE_NCCL
  switch (type_) {
    case CommunicatorType::kFederated:
    case CommunicatorType::kInMemory:
      device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
      break;
    case CommunicatorType::kInMemoryNccl:
      device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, true, nccl_path_));
      break;
    case CommunicatorType::kRabit:
    default:
      // Rabit and any remaining type use NCCL with the flag set to false.
      device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false, nccl_path_));
      break;
  }
#else
  device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
#endif
  return device_communicator_.get();
}
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@@ -1,247 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <xgboost/json.h>
|
||||
#include <xgboost/logging.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
/**
 * @brief Element types supported by the collective operations.
 *
 * The numeric values are spelled out because an overload of Allreduce accepts them as plain
 * integers and casts them back to this enumeration.
 */
enum class DataType {
  kInt8 = 0,    // 8-bit signed integer
  kUInt8 = 1,   // 8-bit unsigned integer
  kInt32 = 2,   // 32-bit signed integer
  kUInt32 = 3,  // 32-bit unsigned integer
  kInt64 = 4,   // 64-bit signed integer
  kUInt64 = 5,  // 64-bit unsigned integer
  kFloat = 6,   // float
  kDouble = 7   // double
};
|
||||
|
||||
/** @brief Get the size of the data type. */
|
||||
inline std::size_t GetTypeSize(DataType data_type) {
|
||||
std::size_t size{0};
|
||||
switch (data_type) {
|
||||
case DataType::kInt8:
|
||||
size = sizeof(std::int8_t);
|
||||
break;
|
||||
case DataType::kUInt8:
|
||||
size = sizeof(std::uint8_t);
|
||||
break;
|
||||
case DataType::kInt32:
|
||||
size = sizeof(std::int32_t);
|
||||
break;
|
||||
case DataType::kUInt32:
|
||||
size = sizeof(std::uint32_t);
|
||||
break;
|
||||
case DataType::kInt64:
|
||||
size = sizeof(std::int64_t);
|
||||
break;
|
||||
case DataType::kUInt64:
|
||||
size = sizeof(std::uint64_t);
|
||||
break;
|
||||
case DataType::kFloat:
|
||||
size = sizeof(float);
|
||||
break;
|
||||
case DataType::kDouble:
|
||||
size = sizeof(double);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown data type.";
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
/**
 * @brief Reduction operations supported by allreduce.
 *
 * The numeric values are spelled out because an overload of Allreduce accepts them as plain
 * integers and casts them back to this enumeration.
 */
enum class Operation {
  kMax = 0,         // element-wise maximum
  kMin = 1,         // element-wise minimum
  kSum = 2,         // element-wise sum
  kBitwiseAND = 3,  // element-wise bitwise AND
  kBitwiseOR = 4,   // element-wise bitwise OR
  kBitwiseXOR = 5   // element-wise bitwise XOR
};
|
||||
|
||||
class DeviceCommunicator;
|
||||
|
||||
/** @brief The communicator backends that can be selected at initialization. */
enum class CommunicatorType {
  kUnknown,      // no backend named
  kRabit,        // selected by the string "rabit"
  kFederated,    // selected by the string "federated"
  kInMemory,     // selected by the string "in-memory"
  kInMemoryNccl  // selected by the string "in-memory-nccl"
};
|
||||
|
||||
/**
 * \brief Case-insensitive comparison of two C strings.
 *
 * Delegates to `_stricmp` when built with MSVC and to `strcasecmp` otherwise.
 *
 * \return Zero when the strings are equal ignoring case, otherwise the non-zero value reported
 *         by the underlying comparison function.
 */
inline int CompareStringsCaseInsensitive(const char *s1, const char *s2) {
#if defined(_MSC_VER)
  int const order = _stricmp(s1, s2);
#else   // _MSC_VER
  int const order = strcasecmp(s1, s2);
#endif  // _MSC_VER
  return order;
}
|
||||
|
||||
/**
|
||||
* @brief A communicator class that handles collective communication.
|
||||
*/
|
||||
class Communicator {
|
||||
public:
|
||||
/**
|
||||
* @brief Initialize the communicator. This can only be done once.
|
||||
*
|
||||
* @param config JSON configuration for the communicator.
|
||||
*/
|
||||
static void Init(Json const &config);
|
||||
|
||||
/** @brief Finalize the communicator. */
|
||||
static void Finalize();
|
||||
|
||||
/** @brief Get the communicator instance. */
|
||||
static Communicator *Get() { return communicator_.get(); }
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
/**
|
||||
* @brief Get the device communicator.
|
||||
*
|
||||
* @param device_ordinal ID of the device.
|
||||
* @return An instance of device communicator.
|
||||
*/
|
||||
static DeviceCommunicator *GetDevice(int device_ordinal);
|
||||
#endif
|
||||
|
||||
virtual ~Communicator() = default;
|
||||
|
||||
/** @brief Get the total number of processes. */
|
||||
int GetWorldSize() const { return world_size_; }
|
||||
|
||||
/** @brief Get the rank of the current processes. */
|
||||
int GetRank() const { return rank_; }
|
||||
|
||||
/** @brief Whether the communicator is running in distributed mode. */
|
||||
virtual bool IsDistributed() const = 0;
|
||||
|
||||
/** @brief Whether the communicator is running in federated mode. */
|
||||
virtual bool IsFederated() const = 0;
|
||||
|
||||
/**
|
||||
* @brief Gathers data from all processes and distributes it to all processes.
|
||||
*
|
||||
* This assumes all ranks have the same size.
|
||||
*
|
||||
* @param input Buffer storing the data.
|
||||
*/
|
||||
virtual std::string AllGather(std::string_view input) = 0;
|
||||
|
||||
/**
|
||||
* @brief Gathers variable-length data from all processes and distributes it to all processes.
|
||||
* @param input Buffer storing the data.
|
||||
*/
|
||||
virtual std::string AllGatherV(std::string_view input) = 0;
|
||||
|
||||
/**
|
||||
* @brief Combines values from all processes and distributes the result back to all processes.
|
||||
*
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param count Number of elements in the buffer.
|
||||
* @param data_type Data type stored in the buffer.
|
||||
* @param op The operation to perform.
|
||||
*/
|
||||
virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) = 0;
|
||||
|
||||
/**
|
||||
* @brief Broadcasts a message from the process with rank `root` to all other processes of the
|
||||
* group.
|
||||
*
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param size Size of the data in bytes.
|
||||
* @param root Rank of broadcast root.
|
||||
*/
|
||||
virtual void Broadcast(void *send_receive_buffer, std::size_t size, int root) = 0;
|
||||
|
||||
/**
|
||||
* @brief Gets the name of the processor.
|
||||
*/
|
||||
virtual std::string GetProcessorName() = 0;
|
||||
|
||||
/**
|
||||
* @brief Prints the message.
|
||||
*/
|
||||
virtual void Print(std::string const &message) = 0;
|
||||
|
||||
/** @brief Get the communicator type from environment variables. Visible for testing. */
|
||||
static CommunicatorType GetTypeFromEnv() {
|
||||
auto *env = std::getenv("XGBOOST_COMMUNICATOR");
|
||||
if (env != nullptr) {
|
||||
return StringToType(env);
|
||||
} else {
|
||||
return CommunicatorType::kUnknown;
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief Get the communicator type from runtime configuration. Visible for testing. */
|
||||
static CommunicatorType GetTypeFromConfig(Json const &config) {
|
||||
auto const &j_upper = config["XGBOOST_COMMUNICATOR"];
|
||||
if (IsA<String const>(j_upper)) {
|
||||
return StringToType(get<String const>(j_upper).c_str());
|
||||
}
|
||||
auto const &j_lower = config["xgboost_communicator"];
|
||||
if (IsA<String const>(j_lower)) {
|
||||
return StringToType(get<String const>(j_lower).c_str());
|
||||
}
|
||||
return CommunicatorType::kUnknown;
|
||||
}
|
||||
|
||||
protected:
|
||||
/**
|
||||
* @brief Construct a new communicator.
|
||||
*
|
||||
* @param world_size Total number of processes.
|
||||
* @param rank Rank of the current process.
|
||||
*/
|
||||
Communicator(int world_size, int rank) : world_size_(world_size), rank_(rank) {
|
||||
if (world_size < 1) {
|
||||
LOG(FATAL) << "World size " << world_size << " is less than 1.";
|
||||
}
|
||||
if (rank < 0) {
|
||||
LOG(FATAL) << "Rank " << rank << " is less than 0.";
|
||||
}
|
||||
if (rank >= world_size) {
|
||||
LOG(FATAL) << "Rank " << rank << " is greater than world_size - 1: " << world_size - 1 << ".";
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Shuts down the communicator.
|
||||
*/
|
||||
virtual void Shutdown() = 0;
|
||||
|
||||
private:
|
||||
static CommunicatorType StringToType(char const *str) {
|
||||
CommunicatorType result = CommunicatorType::kUnknown;
|
||||
if (!CompareStringsCaseInsensitive("rabit", str)) {
|
||||
result = CommunicatorType::kRabit;
|
||||
} else if (!CompareStringsCaseInsensitive("federated", str)) {
|
||||
result = CommunicatorType::kFederated;
|
||||
} else if (!CompareStringsCaseInsensitive("in-memory", str)) {
|
||||
result = CommunicatorType::kInMemory;
|
||||
} else if (!CompareStringsCaseInsensitive("in-memory-nccl", str)) {
|
||||
result = CommunicatorType::kInMemoryNccl;
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown communicator type " << str;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static thread_local std::unique_ptr<Communicator> communicator_;
|
||||
static thread_local CommunicatorType type_;
|
||||
static thread_local std::string nccl_path_;
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
static thread_local std::unique_ptr<DeviceCommunicator> device_communicator_;
|
||||
#endif
|
||||
|
||||
int const world_size_;
|
||||
int const rank_;
|
||||
};
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@@ -1,57 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <vector>
|
||||
|
||||
#include "../common/device_helpers.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
/**
|
||||
* @brief Collective communicator for device buffers.
|
||||
*/
|
||||
class DeviceCommunicator {
|
||||
public:
|
||||
virtual ~DeviceCommunicator() = default;
|
||||
|
||||
/**
|
||||
* @brief Combines values from all processes and distributes the result back to all processes.
|
||||
*
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param count Number of elements in the buffer.
|
||||
* @param data_type Data type stored in the buffer.
|
||||
* @param op The operation to perform.
|
||||
*/
|
||||
virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) = 0;
|
||||
|
||||
/**
|
||||
* @brief Gather values from all all processes.
|
||||
*
|
||||
* This assumes all ranks have the same size.
|
||||
*
|
||||
* @param send_buffer Buffer storing the data to be sent.
|
||||
* @param receive_buffer Buffer storing the gathered data.
|
||||
* @param send_size Size of the sent data in bytes.
|
||||
*/
|
||||
virtual void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) = 0;
|
||||
|
||||
/**
|
||||
* @brief Gather variable-length values from all processes.
|
||||
* @param send_buffer Buffer storing the input data.
|
||||
* @param length_bytes Length in bytes of the input data.
|
||||
* @param segments Size of each segment.
|
||||
* @param receive_buffer Buffer storing the output data.
|
||||
*/
|
||||
virtual void AllGatherV(void const *send_buffer, size_t length_bytes,
|
||||
std::vector<size_t> *segments,
|
||||
dh::caching_device_vector<char> *receive_buffer) = 0;
|
||||
|
||||
/** @brief Synchronize device operations. */
|
||||
virtual void Synchronize() = 0;
|
||||
};
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@@ -1,94 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <numeric> // for accumulate
|
||||
|
||||
#include "communicator.h"
|
||||
#include "device_communicator.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
||||
public:
|
||||
explicit DeviceCommunicatorAdapter(int device_ordinal)
|
||||
: device_ordinal_{device_ordinal}, world_size_{GetWorldSize()}, rank_{GetRank()} {
|
||||
if (device_ordinal_ < 0) {
|
||||
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
|
||||
}
|
||||
}
|
||||
|
||||
~DeviceCommunicatorAdapter() override = default;
|
||||
|
||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) override {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
auto size = count * GetTypeSize(data_type);
|
||||
host_buffer_.resize(size);
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
|
||||
Allreduce(host_buffer_.data(), count, data_type, op);
|
||||
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) override {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
host_buffer_.resize(send_size);
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_buffer, send_size, cudaMemcpyDefault));
|
||||
auto const output = Allgather(host_buffer_);
|
||||
dh::safe_cuda(cudaMemcpy(receive_buffer, output.data(), output.size(), cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
|
||||
dh::caching_device_vector<char> *receive_buffer) override {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
|
||||
segments->clear();
|
||||
segments->resize(world_size_, 0);
|
||||
segments->at(rank_) = length_bytes;
|
||||
Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
|
||||
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
|
||||
receive_buffer->resize(total_bytes);
|
||||
|
||||
host_buffer_.resize(total_bytes);
|
||||
size_t offset = 0;
|
||||
for (int32_t i = 0; i < world_size_; ++i) {
|
||||
size_t as_bytes = segments->at(i);
|
||||
if (i == rank_) {
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank_),
|
||||
cudaMemcpyDefault));
|
||||
}
|
||||
Broadcast(host_buffer_.data() + offset, as_bytes, i);
|
||||
offset += as_bytes;
|
||||
}
|
||||
dh::safe_cuda(cudaMemcpy(receive_buffer->data().get(), host_buffer_.data(), total_bytes,
|
||||
cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
void Synchronize() override {
|
||||
// Noop.
|
||||
}
|
||||
|
||||
private:
|
||||
int const device_ordinal_;
|
||||
int const world_size_;
|
||||
int const rank_;
|
||||
/// Host buffer used to call communicator functions.
|
||||
std::vector<char> host_buffer_{};
|
||||
};
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@@ -1,12 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#include "in_memory_communicator.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
InMemoryHandler InMemoryCommunicator::handler_{};
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@@ -15,14 +15,14 @@ namespace collective {
|
||||
/**
|
||||
* An in-memory communicator, useful for testing.
|
||||
*/
|
||||
class InMemoryCommunicator : public Communicator {
|
||||
class InMemoryCommunicator {
|
||||
public:
|
||||
/**
|
||||
* @brief Create a new communicator based on JSON configuration.
|
||||
* @param config JSON configuration.
|
||||
* @return Communicator as specified by the JSON configuration.
|
||||
*/
|
||||
static Communicator* Create(Json const& config) {
|
||||
static InMemoryCommunicator* Create(Json const& config) {
|
||||
int world_size{0};
|
||||
int rank{-1};
|
||||
|
||||
@@ -51,7 +51,7 @@ class InMemoryCommunicator : public Communicator {
|
||||
return new InMemoryCommunicator(world_size, rank);
|
||||
}
|
||||
|
||||
InMemoryCommunicator(int world_size, int rank) : Communicator(world_size, rank) {
|
||||
InMemoryCommunicator(int world_size, int rank) {
|
||||
handler_.Init(world_size, rank);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
*/
|
||||
#include "in_memory_handler.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include "comm.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
namespace xgboost::collective {
|
||||
/**
|
||||
* @brief Functor for allgather.
|
||||
*/
|
||||
@@ -16,7 +15,7 @@ class AllgatherFunctor {
|
||||
public:
|
||||
std::string const name{"Allgather"};
|
||||
|
||||
AllgatherFunctor(std::size_t world_size, std::size_t rank)
|
||||
AllgatherFunctor(std::int32_t world_size, std::int32_t rank)
|
||||
: world_size_{world_size}, rank_{rank} {}
|
||||
|
||||
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
|
||||
@@ -30,8 +29,8 @@ class AllgatherFunctor {
|
||||
}
|
||||
|
||||
private:
|
||||
std::size_t world_size_;
|
||||
std::size_t rank_;
|
||||
std::int32_t world_size_;
|
||||
std::int32_t rank_;
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -41,13 +40,13 @@ class AllgatherVFunctor {
|
||||
public:
|
||||
std::string const name{"AllgatherV"};
|
||||
|
||||
AllgatherVFunctor(std::size_t world_size, std::size_t rank,
|
||||
AllgatherVFunctor(std::int32_t world_size, std::int32_t rank,
|
||||
std::map<std::size_t, std::string_view>* data)
|
||||
: world_size_{world_size}, rank_{rank}, data_{data} {}
|
||||
|
||||
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
|
||||
data_->emplace(rank_, std::string_view{input, bytes});
|
||||
if (data_->size() == world_size_) {
|
||||
if (data_->size() == static_cast<std::size_t>(world_size_)) {
|
||||
for (auto const& kv : *data_) {
|
||||
buffer->append(kv.second);
|
||||
}
|
||||
@@ -56,8 +55,8 @@ class AllgatherVFunctor {
|
||||
}
|
||||
|
||||
private:
|
||||
std::size_t world_size_;
|
||||
std::size_t rank_;
|
||||
std::int32_t world_size_;
|
||||
std::int32_t rank_;
|
||||
std::map<std::size_t, std::string_view>* data_;
|
||||
};
|
||||
|
||||
@@ -68,7 +67,7 @@ class AllreduceFunctor {
|
||||
public:
|
||||
std::string const name{"Allreduce"};
|
||||
|
||||
AllreduceFunctor(DataType dataType, Operation operation)
|
||||
AllreduceFunctor(ArrayInterfaceHandler::Type dataType, Op operation)
|
||||
: data_type_{dataType}, operation_{operation} {}
|
||||
|
||||
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
|
||||
@@ -76,23 +75,23 @@ class AllreduceFunctor {
|
||||
// Copy the input if this is the first request.
|
||||
buffer->assign(input, bytes);
|
||||
} else {
|
||||
auto n_bytes_type = DispatchDType(data_type_, [](auto t) { return sizeof(t); });
|
||||
// Apply the reduce_operation to the input and the buffer.
|
||||
Accumulate(input, bytes / GetTypeSize(data_type_), &buffer->front());
|
||||
Accumulate(input, bytes / n_bytes_type, &buffer->front());
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
template <class T, std::enable_if_t<std::is_integral<T>::value>* = nullptr>
|
||||
void AccumulateBitwise(T* buffer, T const* input, std::size_t size,
|
||||
Operation reduce_operation) const {
|
||||
void AccumulateBitwise(T* buffer, T const* input, std::size_t size, Op reduce_operation) const {
|
||||
switch (reduce_operation) {
|
||||
case Operation::kBitwiseAND:
|
||||
case Op::kBitwiseAND:
|
||||
std::transform(buffer, buffer + size, input, buffer, std::bit_and<T>());
|
||||
break;
|
||||
case Operation::kBitwiseOR:
|
||||
case Op::kBitwiseOR:
|
||||
std::transform(buffer, buffer + size, input, buffer, std::bit_or<T>());
|
||||
break;
|
||||
case Operation::kBitwiseXOR:
|
||||
case Op::kBitwiseXOR:
|
||||
std::transform(buffer, buffer + size, input, buffer, std::bit_xor<T>());
|
||||
break;
|
||||
default:
|
||||
@@ -101,27 +100,27 @@ class AllreduceFunctor {
|
||||
}
|
||||
|
||||
template <class T, std::enable_if_t<std::is_floating_point<T>::value>* = nullptr>
|
||||
void AccumulateBitwise(T*, T const*, std::size_t, Operation) const {
|
||||
void AccumulateBitwise(T*, T const*, std::size_t, Op) const {
|
||||
LOG(FATAL) << "Floating point types do not support bitwise operations.";
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void Accumulate(T* buffer, T const* input, std::size_t size, Operation reduce_operation) const {
|
||||
void Accumulate(T* buffer, T const* input, std::size_t size, Op reduce_operation) const {
|
||||
switch (reduce_operation) {
|
||||
case Operation::kMax:
|
||||
case Op::kMax:
|
||||
std::transform(buffer, buffer + size, input, buffer,
|
||||
[](T a, T b) { return std::max(a, b); });
|
||||
break;
|
||||
case Operation::kMin:
|
||||
case Op::kMin:
|
||||
std::transform(buffer, buffer + size, input, buffer,
|
||||
[](T a, T b) { return std::min(a, b); });
|
||||
break;
|
||||
case Operation::kSum:
|
||||
case Op::kSum:
|
||||
std::transform(buffer, buffer + size, input, buffer, std::plus<T>());
|
||||
break;
|
||||
case Operation::kBitwiseAND:
|
||||
case Operation::kBitwiseOR:
|
||||
case Operation::kBitwiseXOR:
|
||||
case Op::kBitwiseAND:
|
||||
case Op::kBitwiseOR:
|
||||
case Op::kBitwiseXOR:
|
||||
AccumulateBitwise(buffer, input, size, reduce_operation);
|
||||
break;
|
||||
default:
|
||||
@@ -130,36 +129,37 @@ class AllreduceFunctor {
|
||||
}
|
||||
|
||||
void Accumulate(char const* input, std::size_t size, char* buffer) const {
|
||||
using Type = ArrayInterfaceHandler::Type;
|
||||
switch (data_type_) {
|
||||
case DataType::kInt8:
|
||||
case Type::kI1:
|
||||
Accumulate(reinterpret_cast<std::int8_t*>(buffer),
|
||||
reinterpret_cast<std::int8_t const*>(input), size, operation_);
|
||||
break;
|
||||
case DataType::kUInt8:
|
||||
case Type::kU1:
|
||||
Accumulate(reinterpret_cast<std::uint8_t*>(buffer),
|
||||
reinterpret_cast<std::uint8_t const*>(input), size, operation_);
|
||||
break;
|
||||
case DataType::kInt32:
|
||||
case Type::kI4:
|
||||
Accumulate(reinterpret_cast<std::int32_t*>(buffer),
|
||||
reinterpret_cast<std::int32_t const*>(input), size, operation_);
|
||||
break;
|
||||
case DataType::kUInt32:
|
||||
case Type::kU4:
|
||||
Accumulate(reinterpret_cast<std::uint32_t*>(buffer),
|
||||
reinterpret_cast<std::uint32_t const*>(input), size, operation_);
|
||||
break;
|
||||
case DataType::kInt64:
|
||||
case Type::kI8:
|
||||
Accumulate(reinterpret_cast<std::int64_t*>(buffer),
|
||||
reinterpret_cast<std::int64_t const*>(input), size, operation_);
|
||||
break;
|
||||
case DataType::kUInt64:
|
||||
case Type::kU8:
|
||||
Accumulate(reinterpret_cast<std::uint64_t*>(buffer),
|
||||
reinterpret_cast<std::uint64_t const*>(input), size, operation_);
|
||||
break;
|
||||
case DataType::kFloat:
|
||||
case Type::kF4:
|
||||
Accumulate(reinterpret_cast<float*>(buffer), reinterpret_cast<float const*>(input), size,
|
||||
operation_);
|
||||
break;
|
||||
case DataType::kDouble:
|
||||
case Type::kF8:
|
||||
Accumulate(reinterpret_cast<double*>(buffer), reinterpret_cast<double const*>(input), size,
|
||||
operation_);
|
||||
break;
|
||||
@@ -169,8 +169,8 @@ class AllreduceFunctor {
|
||||
}
|
||||
|
||||
private:
|
||||
DataType data_type_;
|
||||
Operation operation_;
|
||||
ArrayInterfaceHandler::Type data_type_;
|
||||
Op operation_;
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -180,7 +180,7 @@ class BroadcastFunctor {
|
||||
public:
|
||||
std::string const name{"Broadcast"};
|
||||
|
||||
BroadcastFunctor(std::size_t rank, std::size_t root) : rank_{rank}, root_{root} {}
|
||||
BroadcastFunctor(std::int32_t rank, std::int32_t root) : rank_{rank}, root_{root} {}
|
||||
|
||||
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
|
||||
if (rank_ == root_) {
|
||||
@@ -190,11 +190,11 @@ class BroadcastFunctor {
|
||||
}
|
||||
|
||||
private:
|
||||
std::size_t rank_;
|
||||
std::size_t root_;
|
||||
std::int32_t rank_;
|
||||
std::int32_t root_;
|
||||
};
|
||||
|
||||
void InMemoryHandler::Init(std::size_t world_size, std::size_t) {
|
||||
void InMemoryHandler::Init(std::int32_t world_size, std::int32_t) {
|
||||
CHECK(world_size_ < world_size) << "In memory handler already initialized.";
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
@@ -204,7 +204,7 @@ void InMemoryHandler::Init(std::size_t world_size, std::size_t) {
|
||||
cv_.notify_all();
|
||||
}
|
||||
|
||||
void InMemoryHandler::Shutdown(uint64_t sequence_number, std::size_t) {
|
||||
void InMemoryHandler::Shutdown(uint64_t sequence_number, std::int32_t) {
|
||||
CHECK(world_size_ > 0) << "In memory handler already shutdown.";
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
@@ -220,29 +220,29 @@ void InMemoryHandler::Shutdown(uint64_t sequence_number, std::size_t) {
|
||||
}
|
||||
|
||||
void InMemoryHandler::Allgather(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, std::size_t rank) {
|
||||
std::size_t sequence_number, std::int32_t rank) {
|
||||
Handle(input, bytes, output, sequence_number, rank, AllgatherFunctor{world_size_, rank});
|
||||
}
|
||||
|
||||
void InMemoryHandler::AllgatherV(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, std::size_t rank) {
|
||||
std::size_t sequence_number, std::int32_t rank) {
|
||||
Handle(input, bytes, output, sequence_number, rank, AllgatherVFunctor{world_size_, rank, &aux_});
|
||||
}
|
||||
|
||||
void InMemoryHandler::Allreduce(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, std::size_t rank, DataType data_type,
|
||||
Operation op) {
|
||||
std::size_t sequence_number, std::int32_t rank,
|
||||
ArrayInterfaceHandler::Type data_type, Op op) {
|
||||
Handle(input, bytes, output, sequence_number, rank, AllreduceFunctor{data_type, op});
|
||||
}
|
||||
|
||||
void InMemoryHandler::Broadcast(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, std::size_t rank, std::size_t root) {
|
||||
std::size_t sequence_number, std::int32_t rank, std::int32_t root) {
|
||||
Handle(input, bytes, output, sequence_number, rank, BroadcastFunctor{rank, root});
|
||||
}
|
||||
|
||||
template <class HandlerFunctor>
|
||||
void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, std::size_t rank,
|
||||
std::size_t sequence_number, std::int32_t rank,
|
||||
HandlerFunctor const& functor) {
|
||||
// Pass through if there is only 1 client.
|
||||
if (world_size_ == 1) {
|
||||
@@ -287,5 +287,4 @@ void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string*
|
||||
cv_.notify_all();
|
||||
}
|
||||
}
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <condition_variable>
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
#include "communicator.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
#include "../data/array_interface.h"
|
||||
#include "comm.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
/**
|
||||
* @brief Handles collective communication primitives in memory.
|
||||
*
|
||||
@@ -28,12 +27,11 @@ class InMemoryHandler {
|
||||
|
||||
/**
|
||||
* @brief Construct a handler with the given world size.
|
||||
* @param world_size Number of workers.
|
||||
* @param world Number of workers.
|
||||
*
|
||||
* This is used when the handler only needs to be initialized once with a known world size.
|
||||
*/
|
||||
explicit InMemoryHandler(std::int32_t worldSize)
|
||||
: world_size_{static_cast<std::size_t>(worldSize)} {}
|
||||
explicit InMemoryHandler(std::int32_t world) : world_size_{world} {}
|
||||
|
||||
/**
|
||||
* @brief Initialize the handler with the world size and rank.
|
||||
@@ -43,7 +41,7 @@ class InMemoryHandler {
|
||||
* This is used when multiple objects/threads are accessing the same handler and need to
|
||||
* initialize it collectively.
|
||||
*/
|
||||
void Init(std::size_t world_size, std::size_t rank);
|
||||
void Init(std::int32_t world_size, std::int32_t rank);
|
||||
|
||||
/**
|
||||
* @brief Shut down the handler.
|
||||
@@ -53,7 +51,7 @@ class InMemoryHandler {
|
||||
* This is used when multiple objects/threads are accessing the same handler and need to
|
||||
* shut it down collectively.
|
||||
*/
|
||||
void Shutdown(uint64_t sequence_number, std::size_t rank);
|
||||
void Shutdown(uint64_t sequence_number, std::int32_t rank);
|
||||
|
||||
/**
|
||||
* @brief Perform allgather.
|
||||
@@ -64,7 +62,7 @@ class InMemoryHandler {
|
||||
* @param rank Index of the worker.
|
||||
*/
|
||||
void Allgather(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, std::size_t rank);
|
||||
std::size_t sequence_number, std::int32_t rank);
|
||||
|
||||
/**
|
||||
* @brief Perform variable-length allgather.
|
||||
@@ -75,7 +73,7 @@ class InMemoryHandler {
|
||||
* @param rank Index of the worker.
|
||||
*/
|
||||
void AllgatherV(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, std::size_t rank);
|
||||
std::size_t sequence_number, std::int32_t rank);
|
||||
|
||||
/**
|
||||
* @brief Perform allreduce.
|
||||
@@ -88,7 +86,8 @@ class InMemoryHandler {
|
||||
* @param op The reduce operation.
|
||||
*/
|
||||
void Allreduce(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, std::size_t rank, DataType data_type, Operation op);
|
||||
std::size_t sequence_number, std::int32_t rank,
|
||||
ArrayInterfaceHandler::Type data_type, Op op);
|
||||
|
||||
/**
|
||||
* @brief Perform broadcast.
|
||||
@@ -100,7 +99,7 @@ class InMemoryHandler {
|
||||
* @param root Index of the worker to broadcast from.
|
||||
*/
|
||||
void Broadcast(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, std::size_t rank, std::size_t root);
|
||||
std::size_t sequence_number, std::int32_t rank, std::int32_t root);
|
||||
|
||||
private:
|
||||
/**
|
||||
@@ -115,17 +114,15 @@ class InMemoryHandler {
|
||||
*/
|
||||
template <class HandlerFunctor>
|
||||
void Handle(char const* input, std::size_t size, std::string* output, std::size_t sequence_number,
|
||||
std::size_t rank, HandlerFunctor const& functor);
|
||||
std::int32_t rank, HandlerFunctor const& functor);
|
||||
|
||||
std::size_t world_size_{}; /// Number of workers.
|
||||
std::size_t received_{}; /// Number of calls received with the current sequence.
|
||||
std::size_t sent_{}; /// Number of calls completed with the current sequence.
|
||||
std::int32_t world_size_{}; /// Number of workers.
|
||||
std::int64_t received_{}; /// Number of calls received with the current sequence.
|
||||
std::int64_t sent_{}; /// Number of calls completed with the current sequence.
|
||||
std::string buffer_{}; /// A shared common buffer.
|
||||
std::map<std::size_t, std::string_view> aux_{}; /// A shared auxiliary map.
|
||||
uint64_t sequence_number_{}; /// Call sequence number.
|
||||
mutable std::mutex mutex_; /// Lock.
|
||||
mutable std::condition_variable cv_; /// Conditional variable to wait on.
|
||||
};
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int32_t
|
||||
#include <exception> // for exception, current_exception, rethrow_exception
|
||||
#include <future> // for promise
|
||||
#include <memory> // for make_shared
|
||||
#include <mutex> // for lock_guard, unique_lock
|
||||
#include <queue> // for queue
|
||||
#include <string> // for string
|
||||
@@ -18,9 +20,10 @@
|
||||
#include "xgboost/logging.h" // for CHECK
|
||||
|
||||
namespace xgboost::collective {
|
||||
Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
|
||||
Result Loop::ProcessQueue(std::queue<Op>* p_queue) const {
|
||||
timer_.Start(__func__);
|
||||
auto error = [this] {
|
||||
auto error = [this](Op op) {
|
||||
op.pr->set_value();
|
||||
timer_.Stop(__func__);
|
||||
};
|
||||
|
||||
@@ -38,7 +41,7 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
|
||||
|
||||
// Iterate through all the ops for poll
|
||||
for (std::size_t i = 0; i < n_ops; ++i) {
|
||||
auto op = qcopy.front();
|
||||
auto op = std::move(qcopy.front());
|
||||
qcopy.pop();
|
||||
|
||||
switch (op.code) {
|
||||
@@ -54,12 +57,12 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
error();
|
||||
error(op);
|
||||
return Fail("Invalid socket operation.");
|
||||
}
|
||||
}
|
||||
|
||||
qcopy.push(op);
|
||||
qcopy.push(std::move(op));
|
||||
}
|
||||
|
||||
// poll, work on fds that are ready.
|
||||
@@ -67,18 +70,18 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
|
||||
if (!poll.fds.empty()) {
|
||||
auto rc = poll.Poll(timeout_);
|
||||
if (!rc.OK()) {
|
||||
error();
|
||||
timer_.Stop(__func__);
|
||||
return rc;
|
||||
}
|
||||
}
|
||||
timer_.Stop("poll");
|
||||
|
||||
// we wouldn't be here if the queue is empty.
|
||||
// We wouldn't be here if the queue is empty.
|
||||
CHECK(!qcopy.empty());
|
||||
|
||||
// Iterate through all the ops for performing the operations
|
||||
for (std::size_t i = 0; i < n_ops; ++i) {
|
||||
auto op = qcopy.front();
|
||||
auto op = std::move(qcopy.front());
|
||||
qcopy.pop();
|
||||
|
||||
std::int32_t n_bytes_done{0};
|
||||
@@ -93,8 +96,9 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
|
||||
if (poll.CheckRead(*op.sock)) {
|
||||
n_bytes_done = op.sock->Recv(op.ptr + op.off, op.n - op.off);
|
||||
if (n_bytes_done == 0) {
|
||||
error();
|
||||
return Fail("Encountered EOF. The other end is likely closed.");
|
||||
error(op);
|
||||
return Fail("Encountered EOF. The other end is likely closed.",
|
||||
op.sock->GetSockError());
|
||||
}
|
||||
}
|
||||
break;
|
||||
@@ -112,14 +116,14 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
error();
|
||||
error(op);
|
||||
return Fail("Invalid socket operation.");
|
||||
}
|
||||
}
|
||||
|
||||
if (n_bytes_done == -1 && !system::LastErrorWouldBlock()) {
|
||||
auto rc = system::FailWithCode("Invalid socket output.");
|
||||
error();
|
||||
error(op);
|
||||
return rc;
|
||||
}
|
||||
|
||||
@@ -127,14 +131,12 @@ Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
|
||||
CHECK_LE(op.off, op.n);
|
||||
|
||||
if (op.off != op.n) {
|
||||
// not yet finished, push back to queue for next round.
|
||||
// not yet finished, push back to queue for the next round.
|
||||
qcopy.push(op);
|
||||
} else {
|
||||
op.pr->set_value();
|
||||
}
|
||||
}
|
||||
|
||||
if (!blocking) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
timer_.Stop(__func__);
|
||||
@@ -148,8 +150,7 @@ void Loop::Process() {
|
||||
};
|
||||
|
||||
// This loop cannot exit unless `stop_` is set to true. There must always be a thread to
|
||||
// answer the blocking call even if there are errors, otherwise the blocking will wait
|
||||
// forever.
|
||||
// answer the call even if there are errors.
|
||||
while (true) {
|
||||
try {
|
||||
std::unique_lock lock{mu_};
|
||||
@@ -170,44 +171,15 @@ void Loop::Process() {
|
||||
// Move the global queue into a local variable to unblock it.
|
||||
std::queue<Op> qcopy;
|
||||
|
||||
bool is_blocking = false;
|
||||
while (!queue_.empty()) {
|
||||
auto op = queue_.front();
|
||||
auto op = std::move(queue_.front());
|
||||
queue_.pop();
|
||||
if (op.code == Op::kBlock) {
|
||||
is_blocking = true;
|
||||
} else {
|
||||
qcopy.push(op);
|
||||
}
|
||||
qcopy.push(op);
|
||||
}
|
||||
|
||||
lock.unlock();
|
||||
// Clear the local queue, if `is_blocking` is true, this is blocking the current
|
||||
// worker thread (but not the client thread), wait until all operations are
|
||||
// finished.
|
||||
auto rc = this->ProcessQueue(&qcopy, is_blocking);
|
||||
|
||||
if (is_blocking && rc.OK()) {
|
||||
CHECK(qcopy.empty());
|
||||
}
|
||||
// Push back the remaining operations.
|
||||
if (rc.OK()) {
|
||||
std::unique_lock lock{mu_};
|
||||
while (!qcopy.empty()) {
|
||||
queue_.push(qcopy.front());
|
||||
qcopy.pop();
|
||||
}
|
||||
}
|
||||
|
||||
// Notify the client thread who called block after all error conditions are set.
|
||||
auto notify_if_block = [&] {
|
||||
if (is_blocking) {
|
||||
std::unique_lock lock{mu_};
|
||||
block_done_ = true;
|
||||
lock.unlock();
|
||||
block_cv_.notify_one();
|
||||
}
|
||||
};
|
||||
// Clear the local queue.
|
||||
auto rc = this->ProcessQueue(&qcopy);
|
||||
|
||||
// Handle error
|
||||
if (!rc.OK()) {
|
||||
@@ -215,8 +187,6 @@ void Loop::Process() {
|
||||
} else {
|
||||
CHECK(qcopy.empty());
|
||||
}
|
||||
|
||||
notify_if_block();
|
||||
} catch (std::exception const& e) {
|
||||
curr_exce_ = std::current_exception();
|
||||
set_rc(Fail("Exception inside the event loop:" + std::string{e.what()}));
|
||||
@@ -256,20 +226,28 @@ Result Loop::Stop() {
|
||||
stop_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!this->worker_.joinable()) {
|
||||
std::lock_guard<std::mutex> guard{rc_lock_};
|
||||
return Fail("Worker has stopped.", std::move(rc_));
|
||||
}
|
||||
|
||||
this->Submit(Op{Op::kBlock});
|
||||
{
|
||||
// Wait for the block call to finish.
|
||||
std::unique_lock lock{mu_};
|
||||
block_cv_.wait(lock, [this] { return block_done_ || stop_; });
|
||||
block_done_ = false;
|
||||
cv_.notify_one();
|
||||
}
|
||||
|
||||
for (auto& fut : futures_) {
|
||||
if (fut.valid()) {
|
||||
try {
|
||||
fut.get();
|
||||
} catch (std::future_error const&) {
|
||||
// Do nothing. If something went wrong in the worker, we have a std::future_error
|
||||
// due to broken promise. This function will transfer the rc back to the caller.
|
||||
}
|
||||
}
|
||||
}
|
||||
futures_.clear();
|
||||
|
||||
{
|
||||
// Transfer the rc.
|
||||
std::lock_guard<std::mutex> lock{rc_lock_};
|
||||
@@ -278,13 +256,13 @@ Result Loop::Stop() {
|
||||
}
|
||||
|
||||
void Loop::Submit(Op op) {
|
||||
auto p = std::make_shared<std::promise<void>>();
|
||||
op.pr = std::move(p);
|
||||
futures_.emplace_back(op.pr->get_future());
|
||||
CHECK_NE(op.n, 0);
|
||||
|
||||
std::unique_lock lock{mu_};
|
||||
if (op.code != Op::kBlock) {
|
||||
CHECK_NE(op.n, 0);
|
||||
}
|
||||
queue_.push(op);
|
||||
lock.unlock();
|
||||
cv_.notify_one();
|
||||
}
|
||||
|
||||
Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} {
|
||||
|
||||
@@ -7,9 +7,12 @@
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int8_t, int32_t
|
||||
#include <exception> // for exception_ptr
|
||||
#include <mutex> // for unique_lock, mutex
|
||||
#include <future> // for future
|
||||
#include <memory> // for shared_ptr
|
||||
#include <mutex> // for mutex
|
||||
#include <queue> // for queue
|
||||
#include <thread> // for thread
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../common/timer.h" // for Monitor
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
@@ -20,14 +23,15 @@ class Loop {
|
||||
public:
|
||||
struct Op {
|
||||
// kSleep is only for testing
|
||||
enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2, kSleep = 4 } code;
|
||||
enum Code : std::int8_t { kRead = 0, kWrite = 1, kSleep = 3 } code;
|
||||
std::int32_t rank{-1};
|
||||
std::int8_t* ptr{nullptr};
|
||||
std::size_t n{0};
|
||||
TCPSocket* sock{nullptr};
|
||||
std::size_t off{0};
|
||||
std::shared_ptr<std::promise<void>> pr;
|
||||
|
||||
explicit Op(Code c) : code{c} { CHECK(c == kBlock || c == kSleep); }
|
||||
explicit Op(Code c) : code{c} { CHECK(c == kSleep); }
|
||||
Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off)
|
||||
: code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {}
|
||||
Op(Op const&) = default;
|
||||
@@ -45,12 +49,11 @@ class Loop {
|
||||
private:
|
||||
std::thread worker_; // thread worker to execute the tasks
|
||||
|
||||
std::condition_variable cv_; // CV used to notify a new submit call
|
||||
std::condition_variable block_cv_; // CV used to notify the blocking call
|
||||
bool block_done_{false}; // Flag to indicate whether the blocking call has finished.
|
||||
std::condition_variable cv_; // CV used to notify a new submit call
|
||||
|
||||
std::queue<Op> queue_; // event queue
|
||||
std::mutex mu_; // mutex to protect the queue, cv, and block_done
|
||||
std::vector<std::future<void>> futures_;
|
||||
std::mutex mu_; // mutex to protect the queue, cv, and block_done
|
||||
|
||||
std::chrono::seconds timeout_;
|
||||
|
||||
@@ -61,7 +64,7 @@ class Loop {
|
||||
std::exception_ptr curr_exce_{nullptr};
|
||||
common::Monitor mutable timer_;
|
||||
|
||||
Result ProcessQueue(std::queue<Op>* p_queue, bool blocking) const;
|
||||
Result ProcessQueue(std::queue<Op>* p_queue) const;
|
||||
// The consumer function that runs inside a worker thread.
|
||||
void Process();
|
||||
|
||||
|
||||
@@ -1,243 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2023 XGBoost contributors
|
||||
*/
|
||||
#if defined(XGBOOST_USE_NCCL)
|
||||
#include <numeric> // for accumulate
|
||||
|
||||
#include "comm.cuh"
|
||||
#include "nccl_device_communicator.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sync,
|
||||
StringView nccl_path)
|
||||
: device_ordinal_{device_ordinal},
|
||||
needs_sync_{needs_sync},
|
||||
world_size_{GetWorldSize()},
|
||||
rank_{GetRank()} {
|
||||
if (device_ordinal_ < 0) {
|
||||
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
|
||||
}
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
stub_ = std::make_shared<NcclStub>(std::move(nccl_path));
|
||||
|
||||
std::vector<uint64_t> uuids(world_size_ * kUuidLength, 0);
|
||||
auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
|
||||
auto s_this_uuid = s_uuid.subspan(rank_ * kUuidLength, kUuidLength);
|
||||
GetCudaUUID(s_this_uuid);
|
||||
|
||||
// TODO(rongou): replace this with allgather.
|
||||
Allreduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);
|
||||
|
||||
std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world_size_);
|
||||
size_t j = 0;
|
||||
for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
|
||||
converted[j] = xgboost::common::Span<uint64_t, kUuidLength>{uuids.data() + i, kUuidLength};
|
||||
j++;
|
||||
}
|
||||
|
||||
auto iter = std::unique(converted.begin(), converted.end());
|
||||
auto n_uniques = std::distance(converted.begin(), iter);
|
||||
|
||||
CHECK_EQ(n_uniques, world_size_)
|
||||
<< "Multiple processes within communication group running on same CUDA "
|
||||
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
|
||||
|
||||
nccl_unique_id_ = GetUniqueId();
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
auto rc = stub_->CommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
|
||||
NcclDeviceCommunicator::~NcclDeviceCommunicator() {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
if (nccl_comm_) {
|
||||
auto rc = stub_->CommDestroy(nccl_comm_);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
|
||||
LOG(CONSOLE) << "======== NCCL Statistics========";
|
||||
LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
|
||||
LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_ / 1048576;
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
ncclDataType_t GetNcclDataType(DataType const &data_type) {
|
||||
ncclDataType_t result{ncclInt8};
|
||||
switch (data_type) {
|
||||
case DataType::kInt8:
|
||||
result = ncclInt8;
|
||||
break;
|
||||
case DataType::kUInt8:
|
||||
result = ncclUint8;
|
||||
break;
|
||||
case DataType::kInt32:
|
||||
result = ncclInt32;
|
||||
break;
|
||||
case DataType::kUInt32:
|
||||
result = ncclUint32;
|
||||
break;
|
||||
case DataType::kInt64:
|
||||
result = ncclInt64;
|
||||
break;
|
||||
case DataType::kUInt64:
|
||||
result = ncclUint64;
|
||||
break;
|
||||
case DataType::kFloat:
|
||||
result = ncclFloat;
|
||||
break;
|
||||
case DataType::kDouble:
|
||||
result = ncclDouble;
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown data type.";
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool IsBitwiseOp(Operation const &op) {
|
||||
return op == Operation::kBitwiseAND || op == Operation::kBitwiseOR ||
|
||||
op == Operation::kBitwiseXOR;
|
||||
}
|
||||
|
||||
ncclRedOp_t GetNcclRedOp(Operation const &op) {
|
||||
ncclRedOp_t result{ncclMax};
|
||||
switch (op) {
|
||||
case Operation::kMax:
|
||||
result = ncclMax;
|
||||
break;
|
||||
case Operation::kMin:
|
||||
result = ncclMin;
|
||||
break;
|
||||
case Operation::kSum:
|
||||
result = ncclSum;
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unsupported reduce operation.";
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename Func>
|
||||
void RunBitwiseAllreduce(char *out_buffer, char const *device_buffer, Func func, int world_size,
|
||||
std::size_t size) {
|
||||
dh::LaunchN(size, [=] __device__(std::size_t idx) {
|
||||
auto result = device_buffer[idx];
|
||||
for (auto rank = 1; rank < world_size; rank++) {
|
||||
result = func(result, device_buffer[rank * size + idx]);
|
||||
}
|
||||
out_buffer[idx] = result;
|
||||
});
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::size_t count,
|
||||
DataType data_type, Operation op) {
|
||||
auto const size = count * GetTypeSize(data_type);
|
||||
dh::caching_device_vector<char> buffer(size * world_size_);
|
||||
auto *device_buffer = buffer.data().get();
|
||||
|
||||
// First gather data from all the workers.
|
||||
auto rc = stub_->Allgather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
|
||||
nccl_comm_, dh::DefaultStream());
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
if (needs_sync_) {
|
||||
dh::DefaultStream().Sync();
|
||||
}
|
||||
|
||||
// Then reduce locally.
|
||||
auto *out_buffer = static_cast<char *>(send_receive_buffer);
|
||||
switch (op) {
|
||||
case Operation::kBitwiseAND:
|
||||
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size_, size);
|
||||
break;
|
||||
case Operation::kBitwiseOR:
|
||||
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size_, size);
|
||||
break;
|
||||
case Operation::kBitwiseXOR:
|
||||
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size_, size);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Not a bitwise reduce operation.";
|
||||
}
|
||||
}
|
||||
|
||||
void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t count,
|
||||
DataType data_type, Operation op) {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
if (IsBitwiseOp(op)) {
|
||||
BitwiseAllReduce(send_receive_buffer, count, data_type, op);
|
||||
} else {
|
||||
auto rc = stub_->Allreduce(send_receive_buffer, send_receive_buffer, count,
|
||||
GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
|
||||
dh::DefaultStream());
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
allreduce_bytes_ += count * GetTypeSize(data_type);
|
||||
allreduce_calls_ += 1;
|
||||
}
|
||||
|
||||
void NcclDeviceCommunicator::AllGather(void const *send_buffer, void *receive_buffer,
|
||||
std::size_t send_size) {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
auto rc = stub_->Allgather(send_buffer, receive_buffer, send_size, ncclInt8, nccl_comm_,
|
||||
dh::DefaultStream());
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
|
||||
void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_bytes,
|
||||
std::vector<std::size_t> *segments,
|
||||
dh::caching_device_vector<char> *receive_buffer) {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
|
||||
segments->clear();
|
||||
segments->resize(world_size_, 0);
|
||||
segments->at(rank_) = length_bytes;
|
||||
Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
|
||||
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
|
||||
receive_buffer->resize(total_bytes);
|
||||
|
||||
size_t offset = 0;
|
||||
auto rc = Success() << [&] { return stub_->GroupStart(); } << [&] {
|
||||
for (int32_t i = 0; i < world_size_; ++i) {
|
||||
size_t as_bytes = segments->at(i);
|
||||
auto rc = stub_->Broadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
|
||||
ncclChar, i, nccl_comm_, dh::DefaultStream());
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
offset += as_bytes;
|
||||
}
|
||||
return Success();
|
||||
} << [&] { return stub_->GroupEnd(); };
|
||||
}
|
||||
|
||||
void NcclDeviceCommunicator::Synchronize() {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
dh::DefaultStream().Sync();
|
||||
}
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
#endif
|
||||
@@ -1,91 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022-2023 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "comm.cuh"
|
||||
#include "communicator.h"
|
||||
#include "device_communicator.cuh"
|
||||
#include "nccl_stub.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new NCCL communicator.
|
||||
* @param device_ordinal The GPU device id.
|
||||
* @param needs_sync Whether extra CUDA stream synchronization is needed.
|
||||
*
|
||||
* In multi-GPU tests when multiple NCCL communicators are created in the same process, sometimes
|
||||
* a deadlock happens because NCCL kernels are blocking. The extra CUDA stream synchronization
|
||||
* makes sure that the NCCL kernels are caught up, thus avoiding the deadlock.
|
||||
*
|
||||
* The Rabit communicator runs with one process per GPU, so the additional synchronization is not
|
||||
* needed. The in-memory communicator is used in tests with multiple threads, each thread
|
||||
* representing a rank/worker, so the additional synchronization is needed to avoid deadlocks.
|
||||
*/
|
||||
explicit NcclDeviceCommunicator(int device_ordinal, bool needs_sync, StringView nccl_path);
|
||||
~NcclDeviceCommunicator() override;
|
||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) override;
|
||||
void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) override;
|
||||
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
|
||||
dh::caching_device_vector<char> *receive_buffer) override;
|
||||
void Synchronize() override;
|
||||
|
||||
private:
|
||||
static constexpr std::size_t kUuidLength =
|
||||
sizeof(std::declval<cudaDeviceProp>().uuid) / sizeof(uint64_t);
|
||||
|
||||
void GetCudaUUID(xgboost::common::Span<uint64_t, kUuidLength> const &uuid) const {
|
||||
cudaDeviceProp prob{};
|
||||
dh::safe_cuda(cudaGetDeviceProperties(&prob, device_ordinal_));
|
||||
std::memcpy(uuid.data(), static_cast<void *>(&(prob.uuid)), sizeof(prob.uuid));
|
||||
}
|
||||
|
||||
static std::string PrintUUID(xgboost::common::Span<uint64_t, kUuidLength> const &uuid) {
|
||||
std::stringstream ss;
|
||||
for (auto v : uuid) {
|
||||
ss << std::hex << v;
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
/**
|
||||
* \fn ncclUniqueId GetUniqueId()
|
||||
*
|
||||
* \brief Gets the Unique ID from NCCL to be used in setting up interprocess
|
||||
* communication
|
||||
*
|
||||
* \return the Unique ID
|
||||
*/
|
||||
ncclUniqueId GetUniqueId() {
|
||||
static const int kRootRank = 0;
|
||||
ncclUniqueId id;
|
||||
if (rank_ == kRootRank) {
|
||||
auto rc = stub_->GetUniqueId(&id);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
}
|
||||
Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId), static_cast<int>(kRootRank));
|
||||
return id;
|
||||
}
|
||||
|
||||
void BitwiseAllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op);
|
||||
|
||||
int const device_ordinal_;
|
||||
bool const needs_sync_;
|
||||
int const world_size_;
|
||||
int const rank_;
|
||||
ncclComm_t nccl_comm_{};
|
||||
std::shared_ptr<NcclStub> stub_;
|
||||
ncclUniqueId nccl_unique_id_{};
|
||||
size_t allreduce_bytes_{0}; // Keep statistics of the number of bytes communicated.
|
||||
size_t allreduce_calls_{0}; // Keep statistics of the number of reduce calls.
|
||||
};
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@@ -1,32 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <string>
|
||||
|
||||
#include "communicator.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
/**
|
||||
* A no-op communicator, used for non-distributed training.
|
||||
*/
|
||||
class NoOpCommunicator : public Communicator {
|
||||
public:
|
||||
NoOpCommunicator() : Communicator(1, 0) {}
|
||||
bool IsDistributed() const override { return false; }
|
||||
bool IsFederated() const override { return false; }
|
||||
std::string AllGather(std::string_view) override { return {}; }
|
||||
std::string AllGatherV(std::string_view) override { return {}; }
|
||||
void AllReduce(void *, std::size_t, DataType, Operation) override {}
|
||||
void Broadcast(void *, std::size_t, int) override {}
|
||||
std::string GetProcessorName() override { return {}; }
|
||||
void Print(const std::string &message) override { LOG(CONSOLE) << message; }
|
||||
|
||||
protected:
|
||||
void Shutdown() override {}
|
||||
};
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@@ -41,20 +41,26 @@ struct Magic {
|
||||
|
||||
[[nodiscard]] Result Verify(xgboost::collective::TCPSocket* p_sock) {
|
||||
std::int32_t magic{kMagic};
|
||||
auto n_bytes = p_sock->SendAll(&magic, sizeof(magic));
|
||||
if (n_bytes != sizeof(magic)) {
|
||||
return Fail("Failed to verify.");
|
||||
}
|
||||
|
||||
magic = 0;
|
||||
n_bytes = p_sock->RecvAll(&magic, sizeof(magic));
|
||||
if (n_bytes != sizeof(magic)) {
|
||||
return Fail("Failed to verify.");
|
||||
}
|
||||
if (magic != kMagic) {
|
||||
return xgboost::collective::Fail("Invalid verification number.");
|
||||
}
|
||||
return Success();
|
||||
std::size_t n_sent{0};
|
||||
return Success() << [&] {
|
||||
return p_sock->SendAll(&magic, sizeof(magic), &n_sent);
|
||||
} << [&] {
|
||||
if (n_sent != sizeof(magic)) {
|
||||
return Fail("Failed to verify.");
|
||||
}
|
||||
return Success();
|
||||
} << [&] {
|
||||
magic = 0;
|
||||
return p_sock->RecvAll(&magic, sizeof(magic), &n_sent);
|
||||
} << [&] {
|
||||
if (n_sent != sizeof(magic)) {
|
||||
return Fail("Failed to verify.");
|
||||
}
|
||||
if (magic != kMagic) {
|
||||
return xgboost::collective::Fail("Invalid verification number.");
|
||||
}
|
||||
return Success();
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -227,31 +233,43 @@ struct Error {
|
||||
|
||||
[[nodiscard]] Result SignalError(TCPSocket* worker) const {
|
||||
std::int32_t err{ErrorSignal()};
|
||||
auto n_sent = worker->SendAll(&err, sizeof(err));
|
||||
if (n_sent == sizeof(err)) {
|
||||
return Success();
|
||||
}
|
||||
return Fail("Failed to send error signal");
|
||||
std::size_t n_sent{0};
|
||||
return Success() << [&] {
|
||||
return worker->SendAll(&err, sizeof(err), &n_sent);
|
||||
} << [&] {
|
||||
if (n_sent == sizeof(err)) {
|
||||
return Success();
|
||||
}
|
||||
return Fail("Failed to send error signal");
|
||||
};
|
||||
}
|
||||
// self is localhost, we are sending the signal to the error handling thread for it to
|
||||
// close.
|
||||
[[nodiscard]] Result SignalShutdown(TCPSocket* self) const {
|
||||
std::int32_t err{ShutdownSignal()};
|
||||
auto n_sent = self->SendAll(&err, sizeof(err));
|
||||
if (n_sent == sizeof(err)) {
|
||||
return Success();
|
||||
}
|
||||
return Fail("Failed to send shutdown signal");
|
||||
std::size_t n_sent{0};
|
||||
return Success() << [&] {
|
||||
return self->SendAll(&err, sizeof(err), &n_sent);
|
||||
} << [&] {
|
||||
if (n_sent == sizeof(err)) {
|
||||
return Success();
|
||||
}
|
||||
return Fail("Failed to send shutdown signal");
|
||||
};
|
||||
}
|
||||
// get signal, either for error or for shutdown.
|
||||
[[nodiscard]] Result RecvSignal(TCPSocket* peer, bool* p_is_error) const {
|
||||
std::int32_t err{ShutdownSignal()};
|
||||
auto n_recv = peer->RecvAll(&err, sizeof(err));
|
||||
if (n_recv == sizeof(err)) {
|
||||
*p_is_error = err == 1;
|
||||
return Success();
|
||||
}
|
||||
return Fail("Failed to receive error signal.");
|
||||
std::size_t n_recv{0};
|
||||
return Success() << [&] {
|
||||
return peer->RecvAll(&err, sizeof(err), &n_recv);
|
||||
} << [&] {
|
||||
if (n_recv == sizeof(err)) {
|
||||
*p_is_error = err == 1;
|
||||
return Success();
|
||||
}
|
||||
return Fail("Failed to receive error signal.");
|
||||
};
|
||||
}
|
||||
};
|
||||
} // namespace xgboost::collective::proto
|
||||
|
||||
@@ -1,175 +0,0 @@
|
||||
/**
|
||||
* Copyright 2022-2023 by XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <rabit/rabit.h>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "communicator-inl.h"
|
||||
#include "communicator.h"
|
||||
#include "xgboost/json.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
class RabitCommunicator : public Communicator {
|
||||
public:
|
||||
static Communicator *Create(Json const &config) {
|
||||
std::vector<std::string> args_str;
|
||||
for (auto &items : get<Object const>(config)) {
|
||||
switch (items.second.GetValue().Type()) {
|
||||
case xgboost::Value::ValueKind::kString: {
|
||||
args_str.push_back(items.first + "=" + get<String const>(items.second));
|
||||
break;
|
||||
}
|
||||
case xgboost::Value::ValueKind::kInteger: {
|
||||
args_str.push_back(items.first + "=" + std::to_string(get<Integer const>(items.second)));
|
||||
break;
|
||||
}
|
||||
case xgboost::Value::ValueKind::kBoolean: {
|
||||
if (get<Boolean const>(items.second)) {
|
||||
args_str.push_back(items.first + "=1");
|
||||
} else {
|
||||
args_str.push_back(items.first + "=0");
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
std::vector<char *> args;
|
||||
for (auto &key_value : args_str) {
|
||||
args.push_back(&key_value[0]);
|
||||
}
|
||||
if (!rabit::Init(static_cast<int>(args.size()), &args[0])) {
|
||||
LOG(FATAL) << "Failed to initialize Rabit";
|
||||
}
|
||||
return new RabitCommunicator(rabit::GetWorldSize(), rabit::GetRank());
|
||||
}
|
||||
|
||||
RabitCommunicator(int world_size, int rank) : Communicator(world_size, rank) {}
|
||||
|
||||
bool IsDistributed() const override { return rabit::IsDistributed(); }
|
||||
|
||||
bool IsFederated() const override { return false; }
|
||||
|
||||
std::string AllGather(std::string_view input) override {
|
||||
auto const per_rank = input.size();
|
||||
auto const total_size = per_rank * GetWorldSize();
|
||||
auto const index = per_rank * GetRank();
|
||||
std::string result(total_size, '\0');
|
||||
result.replace(index, per_rank, input);
|
||||
rabit::Allgather(result.data(), total_size, index, per_rank, per_rank);
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string AllGatherV(std::string_view input) override {
|
||||
auto const size_node_slice = input.size();
|
||||
auto const all_sizes = collective::Allgather(size_node_slice);
|
||||
auto const total_size = std::accumulate(all_sizes.cbegin(), all_sizes.cend(), 0ul);
|
||||
auto const begin_index =
|
||||
std::accumulate(all_sizes.cbegin(), all_sizes.cbegin() + GetRank(), 0ul);
|
||||
auto const size_prev_slice =
|
||||
GetRank() == 0 ? all_sizes[GetWorldSize() - 1] : all_sizes[GetRank() - 1];
|
||||
|
||||
std::string result(total_size, '\0');
|
||||
result.replace(begin_index, size_node_slice, input);
|
||||
rabit::Allgather(result.data(), total_size, begin_index, size_node_slice, size_prev_slice);
|
||||
return result;
|
||||
}
|
||||
|
||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) override {
|
||||
switch (data_type) {
|
||||
case DataType::kInt8:
|
||||
DoAllReduce<char>(send_receive_buffer, count, op);
|
||||
break;
|
||||
case DataType::kUInt8:
|
||||
DoAllReduce<unsigned char>(send_receive_buffer, count, op);
|
||||
break;
|
||||
case DataType::kInt32:
|
||||
DoAllReduce<std::int32_t>(send_receive_buffer, count, op);
|
||||
break;
|
||||
case DataType::kUInt32:
|
||||
DoAllReduce<std::uint32_t>(send_receive_buffer, count, op);
|
||||
break;
|
||||
case DataType::kInt64:
|
||||
DoAllReduce<std::int64_t>(send_receive_buffer, count, op);
|
||||
break;
|
||||
case DataType::kUInt64:
|
||||
DoAllReduce<std::uint64_t>(send_receive_buffer, count, op);
|
||||
break;
|
||||
case DataType::kFloat:
|
||||
DoAllReduce<float>(send_receive_buffer, count, op);
|
||||
break;
|
||||
case DataType::kDouble:
|
||||
DoAllReduce<double>(send_receive_buffer, count, op);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown data type";
|
||||
}
|
||||
}
|
||||
|
||||
void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {
|
||||
rabit::Broadcast(send_receive_buffer, size, root);
|
||||
}
|
||||
|
||||
std::string GetProcessorName() override { return rabit::GetProcessorName(); }
|
||||
|
||||
void Print(const std::string &message) override { rabit::TrackerPrint(message); }
|
||||
|
||||
protected:
|
||||
void Shutdown() override { rabit::Finalize(); }
|
||||
|
||||
private:
|
||||
template <typename DType, std::enable_if_t<std::is_integral<DType>::value> * = nullptr>
|
||||
void DoBitwiseAllReduce(void *send_receive_buffer, std::size_t count, Operation op) {
|
||||
switch (op) {
|
||||
case Operation::kBitwiseAND:
|
||||
rabit::Allreduce<rabit::op::BitAND, DType>(static_cast<DType *>(send_receive_buffer),
|
||||
count);
|
||||
break;
|
||||
case Operation::kBitwiseOR:
|
||||
rabit::Allreduce<rabit::op::BitOR, DType>(static_cast<DType *>(send_receive_buffer), count);
|
||||
break;
|
||||
case Operation::kBitwiseXOR:
|
||||
rabit::Allreduce<rabit::op::BitXOR, DType>(static_cast<DType *>(send_receive_buffer),
|
||||
count);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown allreduce operation";
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DType, std::enable_if_t<std::is_floating_point<DType>::value> * = nullptr>
|
||||
void DoBitwiseAllReduce(void *, std::size_t, Operation) {
|
||||
LOG(FATAL) << "Floating point types do not support bitwise operations.";
|
||||
}
|
||||
|
||||
template <typename DType>
|
||||
void DoAllReduce(void *send_receive_buffer, std::size_t count, Operation op) {
|
||||
switch (op) {
|
||||
case Operation::kMax:
|
||||
rabit::Allreduce<rabit::op::Max, DType>(static_cast<DType *>(send_receive_buffer), count);
|
||||
break;
|
||||
case Operation::kMin:
|
||||
rabit::Allreduce<rabit::op::Min, DType>(static_cast<DType *>(send_receive_buffer), count);
|
||||
break;
|
||||
case Operation::kSum:
|
||||
rabit::Allreduce<rabit::op::Sum, DType>(static_cast<DType *>(send_receive_buffer), count);
|
||||
break;
|
||||
case Operation::kBitwiseAND:
|
||||
case Operation::kBitwiseOR:
|
||||
case Operation::kBitwiseXOR:
|
||||
DoBitwiseAllReduce<DType>(send_receive_buffer, count, op);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown allreduce operation";
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@@ -62,20 +62,15 @@ void ResultImpl::Concat(std::unique_ptr<ResultImpl> rhs) {
|
||||
ptr->prev = std::move(rhs);
|
||||
}
|
||||
|
||||
#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__)
|
||||
std::string MakeMsg(std::string&& msg, char const*, std::int32_t) {
|
||||
return std::forward<std::string>(msg);
|
||||
}
|
||||
#else
|
||||
std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line) {
|
||||
auto name = std::filesystem::path{file}.filename();
|
||||
dmlc::DateLogger logger;
|
||||
if (file && line != -1) {
|
||||
return "[" + name.string() + ":" + std::to_string(line) + // NOLINT
|
||||
auto name = std::filesystem::path{ file }.filename();
|
||||
return "[" + name.string() + ":" + std::to_string(line) + "|" + logger.HumanDate() +
|
||||
"]: " + std::forward<std::string>(msg);
|
||||
}
|
||||
return std::forward<std::string>(msg);
|
||||
return std::string{"["} + logger.HumanDate() + "]" + std::forward<std::string>(msg); // NOLINT
|
||||
}
|
||||
#endif
|
||||
} // namespace detail
|
||||
|
||||
void SafeColl(Result const& rc) {
|
||||
|
||||
@@ -60,24 +60,46 @@ std::size_t TCPSocket::Send(StringView str) {
|
||||
CHECK(!this->IsClosed());
|
||||
CHECK_LT(str.size(), std::numeric_limits<std::int32_t>::max());
|
||||
std::int32_t len = static_cast<std::int32_t>(str.size());
|
||||
CHECK_EQ(this->SendAll(&len, sizeof(len)), sizeof(len)) << "Failed to send string length.";
|
||||
auto bytes = this->SendAll(str.c_str(), str.size());
|
||||
CHECK_EQ(bytes, str.size()) << "Failed to send string.";
|
||||
return bytes;
|
||||
std::size_t n_bytes{0};
|
||||
auto rc = Success() << [&] {
|
||||
return this->SendAll(&len, sizeof(len), &n_bytes);
|
||||
} << [&] {
|
||||
if (n_bytes != sizeof(len)) {
|
||||
return Fail("Failed to send string length.");
|
||||
}
|
||||
return Success();
|
||||
} << [&] {
|
||||
return this->SendAll(str.c_str(), str.size(), &n_bytes);
|
||||
} << [&] {
|
||||
if (n_bytes != str.size()) {
|
||||
return Fail("Failed to send string.");
|
||||
}
|
||||
return Success();
|
||||
};
|
||||
SafeColl(rc);
|
||||
return n_bytes;
|
||||
}
|
||||
|
||||
[[nodiscard]] Result TCPSocket::Recv(std::string *p_str) {
|
||||
CHECK(!this->IsClosed());
|
||||
std::int32_t len;
|
||||
if (this->RecvAll(&len, sizeof(len)) != sizeof(len)) {
|
||||
return Fail("Failed to recv string length.");
|
||||
}
|
||||
p_str->resize(len);
|
||||
auto bytes = this->RecvAll(&(*p_str)[0], len);
|
||||
if (static_cast<decltype(len)>(bytes) != len) {
|
||||
return Fail("Failed to recv string.");
|
||||
}
|
||||
return Success();
|
||||
std::size_t n_bytes{0};
|
||||
return Success() << [&] {
|
||||
return this->RecvAll(&len, sizeof(len), &n_bytes);
|
||||
} << [&] {
|
||||
if (n_bytes != sizeof(len)) {
|
||||
return Fail("Failed to recv string length.");
|
||||
}
|
||||
return Success();
|
||||
} << [&] {
|
||||
p_str->resize(len);
|
||||
return this->RecvAll(&(*p_str)[0], len, &n_bytes);
|
||||
} << [&] {
|
||||
if (static_cast<std::remove_reference_t<decltype(len)>>(n_bytes) != len) {
|
||||
return Fail("Failed to recv string.");
|
||||
}
|
||||
return Success();
|
||||
};
|
||||
}
|
||||
|
||||
[[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry,
|
||||
|
||||
@@ -31,14 +31,20 @@
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost::collective {
|
||||
|
||||
Tracker::Tracker(Json const& config)
|
||||
: sortby_{static_cast<SortBy>(
|
||||
OptionalArg<Integer const>(config, "sortby", static_cast<Integer::Int>(SortBy::kHost)))},
|
||||
n_workers_{
|
||||
static_cast<std::int32_t>(RequiredArg<Integer const>(config, "n_workers", __func__))},
|
||||
port_{static_cast<std::int32_t>(OptionalArg<Integer const>(config, "port", Integer::Int{0}))},
|
||||
timeout_{std::chrono::seconds{OptionalArg<Integer const>(
|
||||
config, "timeout", static_cast<std::int64_t>(collective::DefaultTimeoutSec()))}} {}
|
||||
timeout_{std::chrono::seconds{
|
||||
OptionalArg<Integer const>(config, "timeout", static_cast<std::int64_t>(0))}} {
|
||||
using std::chrono_literals::operator""s;
|
||||
// Some old configurations in JVM for the scala implementation (removed) use 0 to
|
||||
// indicate blocking. We continue that convention here.
|
||||
timeout_ = (timeout_ == 0s) ? -1s : timeout_;
|
||||
}
|
||||
|
||||
Result Tracker::WaitUntilReady() const {
|
||||
using namespace std::chrono_literals; // NOLINT
|
||||
@@ -49,7 +55,7 @@ Result Tracker::WaitUntilReady() const {
|
||||
timer.Start();
|
||||
while (!this->Ready()) {
|
||||
auto ela = timer.Duration().count();
|
||||
if (ela > this->Timeout().count()) {
|
||||
if (HasTimeout(this->Timeout()) && ela > this->Timeout().count()) {
|
||||
return Fail("Failed to start tracker, timeout:" + std::to_string(this->Timeout().count()) +
|
||||
" seconds.");
|
||||
}
|
||||
@@ -250,8 +256,10 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
|
||||
std::lock_guard lock{listener_mu_};
|
||||
return listener_.NonBlocking(true);
|
||||
} << [&] {
|
||||
std::lock_guard lock{listener_mu_};
|
||||
poll.WatchRead(listener_);
|
||||
{
|
||||
std::lock_guard lock{listener_mu_};
|
||||
poll.WatchRead(listener_);
|
||||
}
|
||||
if (state.running) {
|
||||
// Don't timeout if the communicator group is up and running.
|
||||
return poll.Poll(std::chrono::seconds{-1});
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost::collective {
|
||||
inline bool HasTimeout(std::chrono::seconds timeout) { return timeout.count() > 0; }
|
||||
/**
|
||||
*
|
||||
* @brief Implementation of RABIT tracker.
|
||||
@@ -52,7 +53,7 @@ class Tracker {
|
||||
protected:
|
||||
std::int32_t n_workers_{0};
|
||||
std::int32_t port_{-1};
|
||||
std::chrono::seconds timeout_{0};
|
||||
std::chrono::seconds timeout_{-1};
|
||||
std::atomic<bool> ready_{false};
|
||||
|
||||
public:
|
||||
|
||||
Reference in New Issue
Block a user