merge latest change from upstream
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2014-2024 by XGBoost Contributors
|
||||
* Copyright 2014-2024, XGBoost Contributors
|
||||
*/
|
||||
#include "xgboost/c_api.h"
|
||||
|
||||
@@ -617,8 +617,8 @@ XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle, const char *field, const
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
xgboost_CHECK_C_ARG_PTR(field);
|
||||
auto const& p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
|
||||
p_fmat->SetInfo(field, info, xgboost::DataType::kFloat32, len);
|
||||
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
|
||||
p_fmat->SetInfo(field, linalg::Make1dInterface(info, len));
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -637,8 +637,9 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle, const char *field, const
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
xgboost_CHECK_C_ARG_PTR(field);
|
||||
LOG(WARNING) << error::DeprecatedFunc(__func__, "2.1.0", "XGDMatrixSetInfoFromInterface");
|
||||
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
|
||||
p_fmat->SetInfo(field, info, xgboost::DataType::kUInt32, len);
|
||||
p_fmat->SetInfo(field, linalg::Make1dInterface(info, len));
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -682,19 +683,52 @@ XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field, void
|
||||
xgboost::bst_ulong size, int type) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
LOG(WARNING) << error::DeprecatedFunc(__func__, "2.1.0", "XGDMatrixSetInfoFromInterface");
|
||||
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
|
||||
CHECK(type >= 1 && type <= 4);
|
||||
xgboost_CHECK_C_ARG_PTR(field);
|
||||
p_fmat->SetInfo(field, data, static_cast<DataType>(type), size);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle, const unsigned *group, xgboost::bst_ulong len) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
LOG(WARNING) << "XGDMatrixSetGroup is deprecated, use `XGDMatrixSetUIntInfo` instead.";
|
||||
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
|
||||
p_fmat->SetInfo("group", group, xgboost::DataType::kUInt32, len);
|
||||
Context ctx;
|
||||
auto dtype = static_cast<DataType>(type);
|
||||
std::string str;
|
||||
auto proc = [&](auto cast_d_ptr) {
|
||||
using T = std::remove_pointer_t<decltype(cast_d_ptr)>;
|
||||
auto t = linalg::TensorView<T, 1>(
|
||||
common::Span<T>{cast_d_ptr, static_cast<typename common::Span<T>::index_type>(size)},
|
||||
{size}, DeviceOrd::CPU());
|
||||
CHECK(t.CContiguous());
|
||||
Json iface{linalg::ArrayInterface(t)};
|
||||
CHECK(ArrayInterface<1>{iface}.is_contiguous);
|
||||
str = Json::Dump(iface);
|
||||
return str;
|
||||
};
|
||||
|
||||
// Legacy code using XGBoost dtype, which is a small subset of array interface types.
|
||||
switch (dtype) {
|
||||
case xgboost::DataType::kFloat32: {
|
||||
auto cast_ptr = reinterpret_cast<const float *>(data);
|
||||
p_fmat->Info().SetInfo(ctx, field, proc(cast_ptr));
|
||||
break;
|
||||
}
|
||||
case xgboost::DataType::kDouble: {
|
||||
auto cast_ptr = reinterpret_cast<const double *>(data);
|
||||
p_fmat->Info().SetInfo(ctx, field, proc(cast_ptr));
|
||||
break;
|
||||
}
|
||||
case xgboost::DataType::kUInt32: {
|
||||
auto cast_ptr = reinterpret_cast<const uint32_t *>(data);
|
||||
p_fmat->Info().SetInfo(ctx, field, proc(cast_ptr));
|
||||
break;
|
||||
}
|
||||
case xgboost::DataType::kUInt64: {
|
||||
auto cast_ptr = reinterpret_cast<const uint64_t *>(data);
|
||||
p_fmat->Info().SetInfo(ctx, field, proc(cast_ptr));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
LOG(FATAL) << "Unknown data type" << static_cast<uint8_t>(dtype);
|
||||
}
|
||||
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -990,7 +1024,7 @@ XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle, DMatrixHandle dtrain, bs
|
||||
bst_float *hess, xgboost::bst_ulong len) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
error::DeprecatedFunc(__func__, "2.1.0", "XGBoosterTrainOneIter");
|
||||
LOG(WARNING) << error::DeprecatedFunc(__func__, "2.1.0", "XGBoosterTrainOneIter");
|
||||
auto *learner = static_cast<Learner *>(handle);
|
||||
auto ctx = learner->Ctx()->MakeCPU();
|
||||
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
/**
|
||||
* Copyright 2021-2023, XGBoost Contributors
|
||||
* Copyright 2021-2024, XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_C_API_C_API_UTILS_H_
|
||||
#define XGBOOST_C_API_C_API_UTILS_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <functional>
|
||||
#include <memory> // for shared_ptr
|
||||
#include <string> // for string
|
||||
#include <tuple> // for make_tuple
|
||||
#include <utility> // for move
|
||||
#include <vector>
|
||||
#include <algorithm> // for min
|
||||
#include <cstddef> // for size_t
|
||||
#include <functional> // for multiplies
|
||||
#include <memory> // for shared_ptr
|
||||
#include <numeric> // for accumulate
|
||||
#include <string> // for string
|
||||
#include <tuple> // for make_tuple
|
||||
#include <utility> // for move
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../common/json_utils.h" // for TypeCheck
|
||||
#include "xgboost/c_api.h"
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#include <chrono> // for seconds
|
||||
#include <cstddef> // for size_t
|
||||
#include <future> // for future
|
||||
#include <memory> // for unique_ptr
|
||||
#include <string> // for string
|
||||
#include <thread> // for sleep_for
|
||||
#include <type_traits> // for is_same_v, remove_pointer_t
|
||||
#include <utility> // for pair
|
||||
|
||||
#include "../collective/comm.h" // for DefaultTimeoutSec
|
||||
#include "../collective/tracker.h" // for RabitTracker
|
||||
#include "../common/timer.h" // for Timer
|
||||
#include "c_api_error.h" // for API_BEGIN
|
||||
#include "xgboost/c_api.h"
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
@@ -26,7 +28,7 @@ using namespace xgboost; // NOLINT
|
||||
|
||||
namespace {
|
||||
using TrackerHandleT =
|
||||
std::pair<std::unique_ptr<collective::Tracker>, std::shared_future<collective::Result>>;
|
||||
std::pair<std::shared_ptr<collective::Tracker>, std::shared_future<collective::Result>>;
|
||||
|
||||
TrackerHandleT *GetTrackerHandle(TrackerHandle handle) {
|
||||
xgboost_CHECK_C_ARG_PTR(handle);
|
||||
@@ -40,17 +42,29 @@ struct CollAPIEntry {
|
||||
};
|
||||
using CollAPIThreadLocalStore = dmlc::ThreadLocalStore<CollAPIEntry>;
|
||||
|
||||
void WaitImpl(TrackerHandleT *ptr) {
|
||||
std::chrono::seconds wait_for{100};
|
||||
void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) {
|
||||
constexpr std::int64_t kDft{collective::DefaultTimeoutSec()};
|
||||
std::chrono::seconds wait_for{timeout.count() != 0 ? std::min(kDft, timeout.count()) : kDft};
|
||||
|
||||
common::Timer timer;
|
||||
timer.Start();
|
||||
|
||||
auto ref = ptr->first; // hold a reference so that free doesn't delete it while waiting.
|
||||
|
||||
auto fut = ptr->second;
|
||||
while (fut.valid()) {
|
||||
auto res = fut.wait_for(wait_for);
|
||||
CHECK(res != std::future_status::deferred);
|
||||
|
||||
if (res == std::future_status::ready) {
|
||||
auto const &rc = ptr->second.get();
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
collective::SafeColl(rc);
|
||||
break;
|
||||
}
|
||||
|
||||
if (timer.Duration() > timeout && timeout.count() != 0) {
|
||||
collective::SafeColl(collective::Fail("Timeout waiting for the tracker."));
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
@@ -62,15 +76,15 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle) {
|
||||
Json jconfig = Json::Load(config);
|
||||
|
||||
auto type = RequiredArg<String>(jconfig, "dmlc_communicator", __func__);
|
||||
std::unique_ptr<collective::Tracker> tptr;
|
||||
std::shared_ptr<collective::Tracker> tptr;
|
||||
if (type == "federated") {
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
tptr = std::make_unique<collective::FederatedTracker>(jconfig);
|
||||
tptr = std::make_shared<collective::FederatedTracker>(jconfig);
|
||||
#else
|
||||
LOG(FATAL) << error::NoFederated();
|
||||
#endif // defined(XGBOOST_USE_FEDERATED)
|
||||
} else if (type == "rabit") {
|
||||
tptr = std::make_unique<collective::RabitTracker>(jconfig);
|
||||
tptr = std::make_shared<collective::RabitTracker>(jconfig);
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown communicator:" << type;
|
||||
}
|
||||
@@ -93,7 +107,7 @@ XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args) {
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGTrackerRun(TrackerHandle handle) {
|
||||
XGB_DLL int XGTrackerRun(TrackerHandle handle, char const *) {
|
||||
API_BEGIN();
|
||||
auto *ptr = GetTrackerHandle(handle);
|
||||
CHECK(!ptr->second.valid()) << "Tracker is already running.";
|
||||
@@ -101,19 +115,39 @@ XGB_DLL int XGTrackerRun(TrackerHandle handle) {
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config) {
|
||||
XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config) {
|
||||
API_BEGIN();
|
||||
auto *ptr = GetTrackerHandle(handle);
|
||||
xgboost_CHECK_C_ARG_PTR(config);
|
||||
auto jconfig = Json::Load(StringView{config});
|
||||
WaitImpl(ptr);
|
||||
// Internally, 0 indicates no timeout, which is the default since we don't want to
|
||||
// interrupt the model training.
|
||||
xgboost_CHECK_C_ARG_PTR(config);
|
||||
auto timeout = OptionalArg<Integer>(jconfig, "timeout", std::int64_t{0});
|
||||
WaitImpl(ptr, std::chrono::seconds{timeout});
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGTrackerFree(TrackerHandle handle) {
|
||||
API_BEGIN();
|
||||
using namespace std::chrono_literals; // NOLINT
|
||||
auto *ptr = GetTrackerHandle(handle);
|
||||
WaitImpl(ptr);
|
||||
ptr->first->Stop();
|
||||
// The wait is not necessary since we just called stop, just reusing the function to do
|
||||
// any potential cleanups.
|
||||
WaitImpl(ptr, ptr->first->Timeout());
|
||||
common::Timer timer;
|
||||
timer.Start();
|
||||
// Make sure no one else is waiting on the tracker.
|
||||
while (!ptr->first.unique()) {
|
||||
auto ela = timer.Duration().count();
|
||||
if (ela > ptr->first->Timeout().count()) {
|
||||
LOG(WARNING) << "Time out " << ptr->first->Timeout().count()
|
||||
<< " seconds reached for TrackerFree, killing the tracker.";
|
||||
break;
|
||||
}
|
||||
std::this_thread::sleep_for(64ms);
|
||||
}
|
||||
delete ptr;
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -165,7 +165,7 @@ template <typename T>
|
||||
T GlobalRatio(Context const* ctx, MetaInfo const& info, T dividend, T divisor) {
|
||||
std::array<T, 2> results{dividend, divisor};
|
||||
auto rc = GlobalSum(ctx, info, linalg::MakeVec(results.data(), results.size()));
|
||||
collective::SafeColl(rc);
|
||||
SafeColl(rc);
|
||||
std::tie(dividend, divisor) = std::tuple_cat(results);
|
||||
if (divisor <= 0) {
|
||||
return std::numeric_limits<T>::quiet_NaN();
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#include "allgather.h"
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int8_t, int32_t, int64_t
|
||||
#include <memory> // for shared_ptr
|
||||
#include <utility> // for move
|
||||
|
||||
#include "broadcast.h"
|
||||
#include "comm.h" // for Comm, Channel
|
||||
@@ -29,16 +30,22 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
|
||||
auto rc = Success() << [&] {
|
||||
auto send_rank = (rank + world - r + worker_off) % world;
|
||||
auto send_off = send_rank * segment_size;
|
||||
send_off = std::min(send_off, data.size_bytes());
|
||||
auto send_seg = data.subspan(send_off, std::min(segment_size, data.size_bytes() - send_off));
|
||||
bool is_last_segment = send_rank == (world - 1);
|
||||
auto send_nbytes = is_last_segment ? (data.size_bytes() - send_off) : segment_size;
|
||||
auto send_seg = data.subspan(send_off, send_nbytes);
|
||||
CHECK_NE(send_seg.size(), 0);
|
||||
return next_ch->SendAll(send_seg.data(), send_seg.size_bytes());
|
||||
} << [&] {
|
||||
auto recv_rank = (rank + world - r - 1 + worker_off) % world;
|
||||
auto recv_off = recv_rank * segment_size;
|
||||
recv_off = std::min(recv_off, data.size_bytes());
|
||||
auto recv_seg = data.subspan(recv_off, std::min(segment_size, data.size_bytes() - recv_off));
|
||||
bool is_last_segment = recv_rank == (world - 1);
|
||||
auto recv_nbytes = is_last_segment ? (data.size_bytes() - recv_off) : segment_size;
|
||||
auto recv_seg = data.subspan(recv_off, recv_nbytes);
|
||||
CHECK_NE(recv_seg.size(), 0);
|
||||
return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
|
||||
} << [&] { return prev_ch->Block(); };
|
||||
} << [&] {
|
||||
return comm.Block();
|
||||
};
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
@@ -91,7 +98,9 @@ namespace detail {
|
||||
auto recv_size = sizes[recv_rank];
|
||||
auto recv_seg = erased_result.subspan(recv_off, recv_size);
|
||||
return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
|
||||
} << [&] { return prev_ch->Block(); };
|
||||
} << [&] {
|
||||
return prev_ch->Block();
|
||||
};
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
@@ -99,4 +108,47 @@ namespace detail {
|
||||
return comm.Block();
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
/**
 * @brief Gather a list of byte vectors from every worker in `comm`.
 *
 * Runs two `AllgatherV` rounds: the first exchanges the length of every local vector,
 * the second exchanges the concatenated payload. The payload is then cut back into
 * individual vectors using the gathered lengths.
 *
 * @param ctx   Context forwarded to `AllgatherV`.
 * @param comm  The communicator group to gather over.
 * @param input Local vectors; their number and lengths are taken from this worker only.
 *
 * @return One vector per gathered length, in the order the lengths were received.
 */
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
    Context const* ctx, CommGroup const& comm, std::vector<std::vector<char>> const& input) {
  // Length of each local vector, sent in the first round.
  std::vector<std::int64_t> local_lengths(input.size());
  for (std::size_t i = 0; i < input.size(); ++i) {
    local_lengths[i] = input[i].size();
  }

  std::vector<std::int64_t> recv_segments(comm.World() + 1, 0);
  HostDeviceVector<std::int8_t> recv;

  auto rc = AllgatherV(ctx, comm, linalg::MakeVec(local_lengths.data(), local_lengths.size()),
                       &recv_segments, &recv);
  SafeColl(rc);

  // Turn the gathered lengths into boundaries. This must happen before `recv` is reused
  // for the second round, because `all_lengths` is only a view into `recv`.
  auto all_lengths = common::RestoreType<std::int64_t const>(recv.ConstHostSpan());
  std::vector<std::int64_t> bounds(all_lengths.size() + 1);
  bounds[0] = 0;
  for (std::size_t k = 0; k + 1 < bounds.size(); ++k) {
    bounds[k + 1] = bounds[k] + all_lengths[k];
  }

  // Flatten the local vectors into one contiguous payload for the second round.
  std::vector<char> flattened;
  for (auto const& chunk : input) {
    flattened.insert(flattened.end(), chunk.cbegin(), chunk.cend());
  }
  rc = AllgatherV(ctx, comm, linalg::MakeVec(flattened.data(), flattened.size()), &recv_segments,
                  &recv);
  SafeColl(rc);
  auto payload = common::RestoreType<char const>(recv.ConstHostSpan());

  // Slice the gathered payload back into individual vectors.
  std::vector<std::vector<char>> gathered;
  for (std::size_t k = 0; k + 1 < bounds.size(); ++k) {
    gathered.emplace_back(payload.cbegin() + bounds[k], payload.cbegin() + bounds[k + 1]);
  }
  return gathered;
}
|
||||
|
||||
/**
 * @brief Same as the overload taking a `CommGroup`, using the group returned by
 *        `GlobalCommGroup()`.
 */
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
    Context const* ctx, std::vector<std::vector<char>> const& input) {
  auto const& global_group = *GlobalCommGroup();
  return VectorAllgatherV(ctx, global_group, input);
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -1,25 +1,27 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int32_t
|
||||
#include <memory> // for shared_ptr
|
||||
#include <numeric> // for accumulate
|
||||
#include <string> // for string
|
||||
#include <type_traits> // for remove_cv_t
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../common/type.h" // for EraseType
|
||||
#include "../common/type.h" // for EraseType
|
||||
#include "comm.h" // for Comm, Channel
|
||||
#include "comm_group.h" // for CommGroup
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/linalg.h"
|
||||
#include "xgboost/span.h" // for Span
|
||||
#include "xgboost/linalg.h" // for MakeVec
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace cpu_impl {
|
||||
/**
|
||||
* @param worker_off Segment offset. For example, if the rank 2 worker specifies
|
||||
* worker_off = 1, then it owns the third segment.
|
||||
* worker_off = 1, then it owns the third segment (2 + 1).
|
||||
*/
|
||||
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::size_t segment_size, std::int32_t worker_off,
|
||||
@@ -51,8 +53,10 @@ inline void AllgatherVOffset(common::Span<std::int64_t const> sizes,
|
||||
} // namespace detail
|
||||
|
||||
template <typename T>
|
||||
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data, std::size_t size) {
|
||||
auto n_bytes = sizeof(T) * size;
|
||||
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data) {
|
||||
// This function is also used for ring allreduce, hence we allow the last segment to be
|
||||
// larger due to round-down.
|
||||
auto n_bytes_per_segment = data.size_bytes() / comm.World();
|
||||
auto erased = common::EraseType(data);
|
||||
|
||||
auto rank = comm.Rank();
|
||||
@@ -61,7 +65,7 @@ template <typename T>
|
||||
|
||||
auto prev_ch = comm.Chan(prev);
|
||||
auto next_ch = comm.Chan(next);
|
||||
auto rc = cpu_impl::RingAllgather(comm, erased, n_bytes, 0, prev_ch, next_ch);
|
||||
auto rc = cpu_impl::RingAllgather(comm, erased, n_bytes_per_segment, 0, prev_ch, next_ch);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
@@ -76,7 +80,7 @@ template <typename T>
|
||||
|
||||
std::vector<std::int64_t> sizes(world, 0);
|
||||
sizes[rank] = data.size_bytes();
|
||||
auto rc = RingAllgather(comm, common::Span{sizes.data(), sizes.size()}, 1);
|
||||
auto rc = RingAllgather(comm, common::Span{sizes.data(), sizes.size()});
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
@@ -98,4 +102,115 @@ template <typename T>
|
||||
|
||||
return detail::RingAllgatherV(comm, sizes, s_segments, erased_result);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
[[nodiscard]] Result Allgather(Context const* ctx, CommGroup const& comm,
|
||||
linalg::VectorView<T> data) {
|
||||
if (!comm.IsDistributed()) {
|
||||
return Success();
|
||||
}
|
||||
CHECK(data.Contiguous());
|
||||
auto erased = common::EraseType(data.Values());
|
||||
|
||||
auto const& cctx = comm.Ctx(ctx, data.Device());
|
||||
auto backend = comm.Backend(data.Device());
|
||||
return backend->Allgather(cctx, erased);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gather all data from all workers.
|
||||
*
|
||||
* @param data The input and output buffer, needs to be pre-allocated by the caller.
|
||||
*/
|
||||
template <typename T>
|
||||
[[nodiscard]] Result Allgather(Context const* ctx, linalg::VectorView<T> data) {
|
||||
auto const& cg = *GlobalCommGroup();
|
||||
if (data.Size() % cg.World() != 0) {
|
||||
return Fail("The total number of elements should be multiple of the number of workers.");
|
||||
}
|
||||
return Allgather(ctx, cg, data);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
[[nodiscard]] Result AllgatherV(Context const* ctx, CommGroup const& comm,
|
||||
linalg::VectorView<T> data,
|
||||
std::vector<std::int64_t>* recv_segments,
|
||||
HostDeviceVector<std::int8_t>* recv) {
|
||||
if (!comm.IsDistributed()) {
|
||||
return Success();
|
||||
}
|
||||
std::vector<std::int64_t> sizes(comm.World(), 0);
|
||||
sizes[comm.Rank()] = data.Values().size_bytes();
|
||||
auto erased_sizes = common::EraseType(common::Span{sizes.data(), sizes.size()});
|
||||
auto rc = comm.Backend(DeviceOrd::CPU())
|
||||
->Allgather(comm.Ctx(ctx, DeviceOrd::CPU()), erased_sizes);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
|
||||
recv_segments->resize(sizes.size() + 1);
|
||||
detail::AllgatherVOffset(sizes, common::Span{recv_segments->data(), recv_segments->size()});
|
||||
auto total_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0LL);
|
||||
recv->SetDevice(data.Device());
|
||||
recv->Resize(total_bytes);
|
||||
|
||||
auto s_segments = common::Span{recv_segments->data(), recv_segments->size()};
|
||||
|
||||
auto backend = comm.Backend(data.Device());
|
||||
auto erased = common::EraseType(data.Values());
|
||||
|
||||
return backend->AllgatherV(
|
||||
comm.Ctx(ctx, data.Device()), erased, common::Span{sizes.data(), sizes.size()}, s_segments,
|
||||
data.Device().IsCUDA() ? recv->DeviceSpan() : recv->HostSpan(), AllgatherVAlgo::kBcast);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Allgather with variable length data.
|
||||
*
|
||||
* @param data The input data.
|
||||
* @param recv_segments segment size for each worker. [0, 2, 5] means [0, 2) elements are
|
||||
* from the first worker, [2, 5) elements are from the second one.
|
||||
* @param recv The buffer storing the result.
|
||||
*/
|
||||
template <typename T>
|
||||
[[nodiscard]] Result AllgatherV(Context const* ctx, linalg::VectorView<T> data,
|
||||
std::vector<std::int64_t>* recv_segments,
|
||||
HostDeviceVector<std::int8_t>* recv) {
|
||||
return AllgatherV(ctx, *GlobalCommGroup(), data, recv_segments, recv);
|
||||
}
|
||||
|
||||
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
|
||||
Context const* ctx, CommGroup const& comm, std::vector<std::vector<char>> const& input);
|
||||
|
||||
/**
|
||||
* @brief Gathers variable-length data from all processes and distributes it to all processes.
|
||||
*
|
||||
* @param input All the inputs from the local worker. Both the number of inputs and
|
||||
* the size of each individual vector can vary across different
|
||||
* workers.
|
||||
*
|
||||
* @return The AllgatherV result, containing vectors from all workers.
|
||||
*/
|
||||
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
|
||||
Context const* ctx, std::vector<std::vector<char>> const& input);
|
||||
|
||||
/**
|
||||
* @brief Gathers variable-length strings from all processes and distributes them to all processes.
|
||||
* @param input Variable-length list of variable-length strings.
|
||||
*/
|
||||
[[nodiscard]] inline Result AllgatherStrings(std::vector<std::string> const& input,
|
||||
std::vector<std::string>* p_result) {
|
||||
std::vector<std::vector<char>> inputs(input.size());
|
||||
for (std::size_t i = 0; i < input.size(); ++i) {
|
||||
inputs[i] = {input[i].cbegin(), input[i].cend()};
|
||||
}
|
||||
Context ctx;
|
||||
auto out = VectorAllgatherV(&ctx, *GlobalCommGroup(), inputs);
|
||||
auto& result = *p_result;
|
||||
result.resize(out.size());
|
||||
for (std::size_t i = 0; i < out.size(); ++i) {
|
||||
result[i] = {out[i].cbegin(), out[i].cend()};
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#include "allreduce.h"
|
||||
|
||||
@@ -16,7 +16,44 @@
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective::cpu_impl {
|
||||
namespace {
|
||||
template <typename T>
|
||||
Result RingAllreduceSmall(Comm const& comm, common::Span<std::int8_t> data, Func const& op) {
|
||||
auto rank = comm.Rank();
|
||||
auto world = comm.World();
|
||||
|
||||
auto next_ch = comm.Chan(BootstrapNext(rank, world));
|
||||
auto prev_ch = comm.Chan(BootstrapPrev(rank, world));
|
||||
|
||||
std::vector<std::int8_t> buffer(data.size_bytes() * world, 0);
|
||||
auto s_buffer = common::Span{buffer.data(), buffer.size()};
|
||||
|
||||
auto offset = data.size_bytes() * rank;
|
||||
auto self = s_buffer.subspan(offset, data.size_bytes());
|
||||
std::copy_n(data.data(), data.size_bytes(), self.data());
|
||||
|
||||
auto typed = common::RestoreType<T>(s_buffer);
|
||||
auto rc = RingAllgather(comm, typed);
|
||||
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
auto first = s_buffer.subspan(0, data.size_bytes());
|
||||
CHECK_EQ(first.size(), data.size());
|
||||
|
||||
for (std::int32_t r = 1; r < world; ++r) {
|
||||
auto offset = data.size_bytes() * r;
|
||||
auto buf = s_buffer.subspan(offset, data.size_bytes());
|
||||
op(buf, first);
|
||||
}
|
||||
std::copy_n(first.data(), first.size(), data.data());
|
||||
|
||||
return Success();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
// note that n_bytes_in_seg is calculated with round-down.
|
||||
Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::size_t n_bytes_in_seg, Func const& op) {
|
||||
auto rank = comm.Rank();
|
||||
@@ -27,33 +64,39 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
|
||||
auto next_ch = comm.Chan(dst_rank);
|
||||
auto prev_ch = comm.Chan(src_rank);
|
||||
|
||||
std::vector<std::int8_t> buffer(n_bytes_in_seg, 0);
|
||||
std::vector<std::int8_t> buffer(data.size_bytes() - (world - 1) * n_bytes_in_seg, 0);
|
||||
auto s_buf = common::Span{buffer.data(), buffer.size()};
|
||||
|
||||
for (std::int32_t r = 0; r < world - 1; ++r) {
|
||||
// send to ring next
|
||||
auto send_off = ((rank + world - r) % world) * n_bytes_in_seg;
|
||||
send_off = std::min(send_off, data.size_bytes());
|
||||
auto seg_nbytes = std::min(data.size_bytes() - send_off, n_bytes_in_seg);
|
||||
auto send_seg = data.subspan(send_off, seg_nbytes);
|
||||
common::Span<std::int8_t> seg, recv_seg;
|
||||
auto rc = Success() << [&] {
|
||||
// send to ring next
|
||||
auto send_rank = (rank + world - r) % world;
|
||||
auto send_off = send_rank * n_bytes_in_seg;
|
||||
|
||||
auto rc = next_ch->SendAll(send_seg);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
bool is_last_segment = send_rank == (world - 1);
|
||||
|
||||
// receive from ring prev
|
||||
auto recv_off = ((rank + world - r - 1) % world) * n_bytes_in_seg;
|
||||
recv_off = std::min(recv_off, data.size_bytes());
|
||||
seg_nbytes = std::min(data.size_bytes() - recv_off, n_bytes_in_seg);
|
||||
CHECK_EQ(seg_nbytes % sizeof(T), 0);
|
||||
auto recv_seg = data.subspan(recv_off, seg_nbytes);
|
||||
auto seg = s_buf.subspan(0, recv_seg.size());
|
||||
auto seg_nbytes = is_last_segment ? data.size_bytes() - send_off : n_bytes_in_seg;
|
||||
CHECK_EQ(seg_nbytes % sizeof(T), 0);
|
||||
|
||||
rc = std::move(rc) << [&] { return prev_ch->RecvAll(seg); } << [&] { return comm.Block(); };
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
auto send_seg = data.subspan(send_off, seg_nbytes);
|
||||
return next_ch->SendAll(send_seg);
|
||||
} << [&] {
|
||||
// receive from ring prev
|
||||
auto recv_rank = (rank + world - r - 1) % world;
|
||||
auto recv_off = recv_rank * n_bytes_in_seg;
|
||||
|
||||
bool is_last_segment = recv_rank == (world - 1);
|
||||
|
||||
auto seg_nbytes = is_last_segment ? (data.size_bytes() - recv_off) : n_bytes_in_seg;
|
||||
CHECK_EQ(seg_nbytes % sizeof(T), 0);
|
||||
|
||||
recv_seg = data.subspan(recv_off, seg_nbytes);
|
||||
seg = s_buf.subspan(0, recv_seg.size());
|
||||
return prev_ch->RecvAll(seg);
|
||||
} << [&] {
|
||||
return comm.Block();
|
||||
};
|
||||
|
||||
// accumulate to recv_seg
|
||||
CHECK_EQ(seg.size(), recv_seg.size());
|
||||
@@ -68,6 +111,9 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func cons
|
||||
if (comm.World() == 1) {
|
||||
return Success();
|
||||
}
|
||||
if (data.size_bytes() == 0) {
|
||||
return Success();
|
||||
}
|
||||
return DispatchDType(type, [&](auto t) {
|
||||
using T = decltype(t);
|
||||
// Divide the data into segments according to the number of workers.
|
||||
@@ -75,7 +121,11 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func cons
|
||||
CHECK_EQ(data.size_bytes() % n_bytes_elem, 0);
|
||||
auto n = data.size_bytes() / n_bytes_elem;
|
||||
auto world = comm.World();
|
||||
auto n_bytes_in_seg = common::DivRoundUp(n, world) * sizeof(T);
|
||||
if (n < static_cast<decltype(n)>(world)) {
|
||||
return RingAllreduceSmall<T>(comm, data, op);
|
||||
}
|
||||
|
||||
auto n_bytes_in_seg = (n / world) * sizeof(T);
|
||||
auto rc = RingScatterReduceTyped<T>(comm, data, n_bytes_in_seg, op);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
@@ -88,7 +138,9 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func cons
|
||||
|
||||
return std::move(rc) << [&] {
|
||||
return RingAllgather(comm, data, n_bytes_in_seg, 1, prev_ch, next_ch);
|
||||
} << [&] { return comm.Block(); };
|
||||
} << [&] {
|
||||
return comm.Block();
|
||||
};
|
||||
});
|
||||
}
|
||||
} // namespace xgboost::collective::cpu_impl
|
||||
|
||||
@@ -1,15 +1,18 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <cstdint> // for int8_t
|
||||
#include <functional> // for function
|
||||
#include <type_traits> // for is_invocable_v, enable_if_t
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../common/type.h" // for EraseType, RestoreType
|
||||
#include "../data/array_interface.h" // for ArrayInterfaceHandler
|
||||
#include "../data/array_interface.h" // for ToDType, ArrayInterfaceHandler
|
||||
#include "comm.h" // for Comm, RestoreType
|
||||
#include "comm_group.h" // for GlobalCommGroup
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
@@ -27,8 +30,7 @@ std::enable_if_t<std::is_invocable_v<Fn, common::Span<T const>, common::Span<T>>
|
||||
auto erased = common::EraseType(data);
|
||||
auto type = ToDType<T>::kType;
|
||||
|
||||
auto erased_fn = [type, redop](common::Span<std::int8_t const> lhs,
|
||||
common::Span<std::int8_t> out) {
|
||||
auto erased_fn = [redop](common::Span<std::int8_t const> lhs, common::Span<std::int8_t> out) {
|
||||
CHECK_EQ(lhs.size(), out.size()) << "Invalid input for reduction.";
|
||||
auto lhs_t = common::RestoreType<T const>(lhs);
|
||||
auto rhs_t = common::RestoreType<T>(out);
|
||||
@@ -37,4 +39,40 @@ std::enable_if_t<std::is_invocable_v<Fn, common::Span<T const>, common::Span<T>>
|
||||
|
||||
return cpu_impl::RingAllreduce(comm, erased, erased_fn, type);
|
||||
}
|
||||
|
||||
template <typename T, std::int32_t kDim>
|
||||
[[nodiscard]] Result Allreduce(Context const* ctx, CommGroup const& comm,
|
||||
linalg::TensorView<T, kDim> data, Op op) {
|
||||
if (!comm.IsDistributed()) {
|
||||
return Success();
|
||||
}
|
||||
CHECK(data.Contiguous());
|
||||
auto erased = common::EraseType(data.Values());
|
||||
auto type = ToDType<T>::kType;
|
||||
|
||||
auto backend = comm.Backend(data.Device());
|
||||
return backend->Allreduce(comm.Ctx(ctx, data.Device()), erased, type, op);
|
||||
}
|
||||
|
||||
template <typename T, std::int32_t kDim>
|
||||
[[nodiscard]] Result Allreduce(Context const* ctx, linalg::TensorView<T, kDim> data, Op op) {
|
||||
return Allreduce(ctx, *GlobalCommGroup(), data, op);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Specialization for std::vector.
|
||||
*/
|
||||
template <typename T, typename Alloc>
|
||||
[[nodiscard]] Result Allreduce(Context const* ctx, std::vector<T, Alloc>* data, Op op) {
|
||||
return Allreduce(ctx, linalg::MakeVec(data->data(), data->size()), op);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Specialization for scalar value.
|
||||
*/
|
||||
template <typename T>
|
||||
[[nodiscard]] std::enable_if_t<std::is_standard_layout_v<T> && std::is_trivial_v<T>, Result>
|
||||
Allreduce(Context const* ctx, T* data, Op op) {
|
||||
return Allreduce(ctx, linalg::MakeVec(data, 1), op);
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <cstdint> // for int32_t, int8_t
|
||||
|
||||
#include "comm.h" // for Comm
|
||||
#include "xgboost/collective/result.h" // for
|
||||
#include "../common/type.h"
|
||||
#include "comm.h" // for Comm, EraseType
|
||||
#include "comm_group.h" // for CommGroup
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/linalg.h" // for VectorView
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
@@ -23,4 +27,21 @@ template <typename T>
|
||||
common::Span<std::int8_t>{reinterpret_cast<std::int8_t*>(data.data()), n_total_bytes};
|
||||
return cpu_impl::Broadcast(comm, erased, root);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
[[nodiscard]] Result Broadcast(Context const* ctx, CommGroup const& comm,
|
||||
linalg::VectorView<T> data, std::int32_t root) {
|
||||
if (!comm.IsDistributed()) {
|
||||
return Success();
|
||||
}
|
||||
CHECK(data.Contiguous());
|
||||
auto erased = common::EraseType(data.Values());
|
||||
auto backend = comm.Backend(data.Device());
|
||||
return backend->Broadcast(comm.Ctx(ctx, data.Device()), erased, root);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
[[nodiscard]] Result Broadcast(Context const* ctx, linalg::VectorView<T> data, std::int32_t root) {
|
||||
return Broadcast(ctx, *GlobalCommGroup(), data, root);
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -42,6 +42,10 @@ bool constexpr IsFloatingPointV() {
|
||||
auto redop_fn = [](auto lhs, auto out, auto elem_op) {
|
||||
auto p_lhs = lhs.data();
|
||||
auto p_out = out.data();
|
||||
#if defined(__GNUC__) || defined(__clang__)
|
||||
// For the sum op, one can verify the simd by: addps %xmm15, %xmm14
|
||||
#pragma omp simd
|
||||
#endif
|
||||
for (std::size_t i = 0; i < lhs.size(); ++i) {
|
||||
p_out[i] = elem_op(p_lhs[i], p_out[i]);
|
||||
}
|
||||
@@ -108,9 +112,8 @@ bool constexpr IsFloatingPointV() {
|
||||
return cpu_impl::Broadcast(comm, data, root);
|
||||
}
|
||||
|
||||
[[nodiscard]] Result Coll::Allgather(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::int64_t size) {
|
||||
return RingAllgather(comm, data, size);
|
||||
[[nodiscard]] Result Coll::Allgather(Comm const& comm, common::Span<std::int8_t> data) {
|
||||
return RingAllgather(comm, data);
|
||||
}
|
||||
|
||||
[[nodiscard]] Result Coll::AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL)
|
||||
#include <cstdint> // for int8_t, int64_t
|
||||
|
||||
#include "../common/cuda_context.cuh"
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../data/array_interface.h"
|
||||
#include "allgather.h" // for AllgatherVOffset
|
||||
@@ -166,14 +165,14 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {
|
||||
} << [&] { return nccl->Block(); };
|
||||
}
|
||||
|
||||
[[nodiscard]] Result NCCLColl::Allgather(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::int64_t size) {
|
||||
[[nodiscard]] Result NCCLColl::Allgather(Comm const& comm, common::Span<std::int8_t> data) {
|
||||
if (!comm.IsDistributed()) {
|
||||
return Success();
|
||||
}
|
||||
auto nccl = dynamic_cast<NCCLComm const*>(&comm);
|
||||
CHECK(nccl);
|
||||
auto stub = nccl->Stub();
|
||||
auto size = data.size_bytes() / comm.World();
|
||||
|
||||
auto send = data.subspan(comm.Rank() * size, size);
|
||||
return Success() << [&] {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
@@ -8,8 +8,7 @@
|
||||
#include "../data/array_interface.h" // for ArrayInterfaceHandler
|
||||
#include "coll.h" // for Coll
|
||||
#include "comm.h" // for Comm
|
||||
#include "nccl_stub.h"
|
||||
#include "xgboost/span.h" // for Span
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
class NCCLColl : public Coll {
|
||||
@@ -20,8 +19,7 @@ class NCCLColl : public Coll {
|
||||
ArrayInterfaceHandler::Type type, Op op) override;
|
||||
[[nodiscard]] Result Broadcast(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::int32_t root) override;
|
||||
[[nodiscard]] Result Allgather(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::int64_t size) override;
|
||||
[[nodiscard]] Result Allgather(Comm const& comm, common::Span<std::int8_t> data) override;
|
||||
[[nodiscard]] Result AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
|
||||
common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int64_t> recv_segments,
|
||||
|
||||
@@ -48,10 +48,8 @@ class Coll : public std::enable_shared_from_this<Coll> {
|
||||
* @brief Allgather
|
||||
*
|
||||
* @param [in,out] data Data buffer for input and output.
|
||||
* @param [in] size Size of data for each worker.
|
||||
*/
|
||||
[[nodiscard]] virtual Result Allgather(Comm const& comm, common::Span<std::int8_t> data,
|
||||
std::int64_t size);
|
||||
[[nodiscard]] virtual Result Allgather(Comm const& comm, common::Span<std::int8_t> data);
|
||||
/**
|
||||
* @brief Allgather with variable length.
|
||||
*
|
||||
|
||||
@@ -1,16 +1,19 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#include "comm.h"
|
||||
|
||||
#include <algorithm> // for copy
|
||||
#include <chrono> // for seconds
|
||||
#include <cstdint> // for int32_t
|
||||
#include <cstdlib> // for exit
|
||||
#include <memory> // for shared_ptr
|
||||
#include <string> // for string
|
||||
#include <thread> // for thread
|
||||
#include <utility> // for move, forward
|
||||
|
||||
#include "../common/common.h" // for AssertGPUSupport
|
||||
#if !defined(XGBOOST_USE_NCCL)
|
||||
#include "../common/common.h" // for AssertNCCLSupport
|
||||
#endif // !defined(XGBOOST_USE_NCCL)
|
||||
#include "allgather.h" // for RingAllgather
|
||||
#include "protocol.h" // for kMagic
|
||||
#include "xgboost/base.h" // for XGBOOST_STRICT_R_MODE
|
||||
@@ -21,11 +24,7 @@
|
||||
namespace xgboost::collective {
|
||||
Comm::Comm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t retry, std::string task_id)
|
||||
: timeout_{timeout},
|
||||
retry_{retry},
|
||||
tracker_{host, port, -1},
|
||||
task_id_{std::move(task_id)},
|
||||
loop_{std::shared_ptr<Loop>{new Loop{timeout}}} {}
|
||||
: timeout_{timeout}, retry_{retry}, tracker_{host, port, -1}, task_id_{std::move(task_id)} {}
|
||||
|
||||
Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, std::int32_t retry,
|
||||
std::string const& task_id, TCPSocket* out, std::int32_t rank,
|
||||
@@ -187,12 +186,30 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
|
||||
return Success();
|
||||
}
|
||||
|
||||
RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t retry, std::string task_id, StringView nccl_path)
|
||||
: HostComm{std::move(host), port, timeout, retry, std::move(task_id)},
|
||||
namespace {
|
||||
std::string InitLog(std::string task_id, std::int32_t rank) {
|
||||
if (task_id.empty()) {
|
||||
return "Rank " + std::to_string(rank);
|
||||
}
|
||||
return "Task " + task_id + " got rank " + std::to_string(rank);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
RabitComm::RabitComm(std::string const& tracker_host, std::int32_t tracker_port,
|
||||
std::chrono::seconds timeout, std::int32_t retry, std::string task_id,
|
||||
StringView nccl_path)
|
||||
: HostComm{tracker_host, tracker_port, timeout, retry, std::move(task_id)},
|
||||
nccl_path_{std::move(nccl_path)} {
|
||||
if (this->TrackerInfo().host.empty()) {
|
||||
// Not in a distributed environment.
|
||||
LOG(CONSOLE) << InitLog(task_id_, rank_);
|
||||
return;
|
||||
}
|
||||
|
||||
loop_.reset(new Loop{std::chrono::seconds{timeout_}}); // NOLINT
|
||||
auto rc = this->Bootstrap(timeout_, retry_, task_id_);
|
||||
if (!rc.OK()) {
|
||||
this->ResetState();
|
||||
SafeColl(Fail("Failed to bootstrap the communication group.", std::move(rc)));
|
||||
}
|
||||
}
|
||||
@@ -219,20 +236,54 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
|
||||
|
||||
// Start command
|
||||
TCPSocket listener = TCPSocket::Create(tracker.Domain());
|
||||
std::int32_t lport = listener.BindHost();
|
||||
listener.Listen();
|
||||
std::int32_t lport{0};
|
||||
rc = std::move(rc) << [&] {
|
||||
return listener.BindHost(&lport);
|
||||
} << [&] {
|
||||
return listener.Listen();
|
||||
};
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
|
||||
// create worker for listening to error notice.
|
||||
auto domain = tracker.Domain();
|
||||
std::shared_ptr<TCPSocket> error_sock{TCPSocket::CreatePtr(domain)};
|
||||
auto eport = error_sock->BindHost();
|
||||
error_sock->Listen();
|
||||
std::int32_t eport{0};
|
||||
rc = std::move(rc) << [&] {
|
||||
return error_sock->BindHost(&eport);
|
||||
} << [&] {
|
||||
return error_sock->Listen();
|
||||
};
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
error_port_ = eport;
|
||||
|
||||
error_worker_ = std::thread{[error_sock = std::move(error_sock)] {
|
||||
auto conn = error_sock->Accept();
|
||||
TCPSocket conn;
|
||||
SockAddress addr;
|
||||
auto rc = error_sock->Accept(&conn, &addr);
|
||||
// On Linux, a shutdown causes an invalid argument error;
|
||||
if (rc.Code() == std::errc::invalid_argument) {
|
||||
return;
|
||||
}
|
||||
// On Windows, accept returns a closed socket after finalize.
|
||||
if (conn.IsClosed()) {
|
||||
return;
|
||||
}
|
||||
// The error signal is from the tracker, while shutdown signal is from the shutdown method
|
||||
// of the RabitComm class (this).
|
||||
bool is_error{false};
|
||||
rc = proto::Error{}.RecvSignal(&conn, &is_error);
|
||||
if (!rc.OK()) {
|
||||
LOG(WARNING) << rc.Report();
|
||||
return;
|
||||
}
|
||||
if (!is_error) {
|
||||
return; // shutdown
|
||||
}
|
||||
|
||||
LOG(WARNING) << "Another worker is running into error.";
|
||||
#if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0
|
||||
// exit is nicer than abort as the former performs cleanups.
|
||||
@@ -241,6 +292,9 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
|
||||
LOG(FATAL) << "abort";
|
||||
#endif
|
||||
}};
|
||||
// The worker thread is detached here to avoid the need to handle it later during
|
||||
// destruction. For C++, if a joinable thread is neither joined nor detached, std::terminate is called during
|
||||
// destruction.
|
||||
error_worker_.detach();
|
||||
|
||||
proto::Start start;
|
||||
@@ -253,7 +307,7 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
|
||||
|
||||
// get ring neighbors
|
||||
std::string snext;
|
||||
tracker.Recv(&snext);
|
||||
rc = tracker.Recv(&snext);
|
||||
if (!rc.OK()) {
|
||||
return Fail("Failed to receive the rank for the next worker.", std::move(rc));
|
||||
}
|
||||
@@ -273,14 +327,21 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
|
||||
CHECK(this->channels_.empty());
|
||||
for (auto& w : workers) {
|
||||
if (w) {
|
||||
rc = std::move(rc) << [&] { return w->SetNoDelay(); } << [&] { return w->NonBlocking(true); }
|
||||
<< [&] { return w->SetKeepAlive(); };
|
||||
rc = std::move(rc) << [&] {
|
||||
return w->SetNoDelay();
|
||||
} << [&] {
|
||||
return w->NonBlocking(true);
|
||||
} << [&] {
|
||||
return w->SetKeepAlive();
|
||||
};
|
||||
}
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
this->channels_.emplace_back(std::make_shared<Channel>(*this, w));
|
||||
}
|
||||
|
||||
LOG(CONSOLE) << InitLog(task_id_, rank_);
|
||||
return rc;
|
||||
}
|
||||
|
||||
@@ -288,6 +349,8 @@ RabitComm::~RabitComm() noexcept(false) {
|
||||
if (!this->IsDistributed()) {
|
||||
return;
|
||||
}
|
||||
LOG(WARNING) << "The communicator is being destroyed without a call to shutdown first. This can "
|
||||
"lead to undefined behaviour.";
|
||||
auto rc = this->Shutdown();
|
||||
if (!rc.OK()) {
|
||||
LOG(WARNING) << rc.Report();
|
||||
@@ -295,24 +358,52 @@ RabitComm::~RabitComm() noexcept(false) {
|
||||
}
|
||||
|
||||
[[nodiscard]] Result RabitComm::Shutdown() {
|
||||
if (!this->IsDistributed()) {
|
||||
return Success();
|
||||
}
|
||||
// Tell the tracker that this worker is shutting down.
|
||||
TCPSocket tracker;
|
||||
// Tell the error handling thread that we are shutting down.
|
||||
TCPSocket err_client;
|
||||
|
||||
return Success() << [&] {
|
||||
return ConnectTrackerImpl(tracker_, timeout_, retry_, task_id_, &tracker, Rank(), World());
|
||||
} << [&] {
|
||||
return this->Block();
|
||||
} << [&] {
|
||||
Json jcmd{Object{}};
|
||||
jcmd["cmd"] = Integer{static_cast<std::int32_t>(proto::CMD::kShutdown)};
|
||||
auto scmd = Json::Dump(jcmd);
|
||||
auto n_bytes = tracker.Send(scmd);
|
||||
if (n_bytes != scmd.size()) {
|
||||
return Fail("Faled to send cmd.");
|
||||
}
|
||||
return proto::ShutdownCMD{}.Send(&tracker);
|
||||
} << [&] {
|
||||
this->channels_.clear();
|
||||
return Success();
|
||||
} << [&] {
|
||||
// Use tracker address to determine whether we want to use IPv6.
|
||||
auto taddr = MakeSockAddress(xgboost::StringView{this->tracker_.host}, this->tracker_.port);
|
||||
// Shutdown the error handling thread. We signal the thread through socket,
|
||||
// alternatively, we can get the native handle and use pthread_cancel. But using a
|
||||
// socket seems to be clearer as we know what's happening.
|
||||
auto const& addr = taddr.IsV4() ? SockAddrV4::Loopback().Addr() : SockAddrV6::Loopback().Addr();
|
||||
// We use hardcoded 10 seconds and 1 retry here since we are just connecting to a
|
||||
// local socket. For a normal OS, this should be enough time to schedule the
|
||||
// connection.
|
||||
auto rc = Connect(StringView{addr}, this->error_port_, 1,
|
||||
std::min(std::chrono::seconds{10}, timeout_), &err_client);
|
||||
this->ResetState();
|
||||
if (!rc.OK()) {
|
||||
return Fail("Failed to connect to the error socket.", std::move(rc));
|
||||
}
|
||||
return rc;
|
||||
} << [&] {
|
||||
// We put error thread shutdown at the end so that we have a better chance to finish
|
||||
// the previous more important steps.
|
||||
return proto::Error{}.SignalShutdown(&err_client);
|
||||
};
|
||||
}
|
||||
|
||||
[[nodiscard]] Result RabitComm::LogTracker(std::string msg) const {
|
||||
if (!this->IsDistributed()) {
|
||||
LOG(CONSOLE) << msg;
|
||||
return Success();
|
||||
}
|
||||
TCPSocket out;
|
||||
proto::Print print;
|
||||
return Success() << [&] { return this->ConnectTracker(&out); }
|
||||
@@ -320,8 +411,11 @@ RabitComm::~RabitComm() noexcept(false) {
|
||||
}
|
||||
|
||||
[[nodiscard]] Result RabitComm::SignalError(Result const& res) {
|
||||
TCPSocket out;
|
||||
return Success() << [&] { return this->ConnectTracker(&out); }
|
||||
<< [&] { return proto::ErrorCMD{}.WorkerSend(&out, res); };
|
||||
TCPSocket tracker;
|
||||
return Success() << [&] {
|
||||
return this->ConnectTracker(&tracker);
|
||||
} << [&] {
|
||||
return proto::ErrorCMD{}.WorkerSend(&tracker, res);
|
||||
};
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -27,7 +27,7 @@ Result GetUniqueId(Comm const& comm, std::shared_ptr<NcclStub> stub, std::shared
|
||||
ncclUniqueId id;
|
||||
if (comm.Rank() == kRootRank) {
|
||||
auto rc = stub->GetUniqueId(&id);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
SafeColl(rc);
|
||||
}
|
||||
auto rc = coll->Broadcast(
|
||||
comm, common::Span{reinterpret_cast<std::int8_t*>(&id), sizeof(ncclUniqueId)}, kRootRank);
|
||||
@@ -90,9 +90,8 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
|
||||
auto s_this_uuid = s_uuid.subspan(root.Rank() * kUuidLength, kUuidLength);
|
||||
GetCudaUUID(s_this_uuid, ctx->Device());
|
||||
|
||||
auto rc = pimpl->Allgather(root, common::EraseType(s_uuid), s_this_uuid.size_bytes());
|
||||
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
auto rc = pimpl->Allgather(root, common::EraseType(s_uuid));
|
||||
SafeColl(rc);
|
||||
|
||||
std::vector<xgboost::common::Span<std::uint64_t, kUuidLength>> converted(root.World());
|
||||
std::size_t j = 0;
|
||||
@@ -113,7 +112,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
|
||||
[&] {
|
||||
return this->stub_->CommInitRank(&nccl_comm_, root.World(), nccl_unique_id_, root.Rank());
|
||||
};
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
SafeColl(rc);
|
||||
|
||||
for (std::int32_t r = 0; r < root.World(); ++r) {
|
||||
this->channels_.emplace_back(
|
||||
@@ -124,7 +123,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
|
||||
NCCLComm::~NCCLComm() {
|
||||
if (nccl_comm_) {
|
||||
auto rc = stub_->CommDestroy(nccl_comm_);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
SafeColl(rc);
|
||||
}
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -53,6 +53,10 @@ class NCCLComm : public Comm {
|
||||
auto rc = this->Stream().Sync(false);
|
||||
return GetCUDAResult(rc);
|
||||
}
|
||||
[[nodiscard]] Result Shutdown() final {
|
||||
this->ResetState();
|
||||
return Success();
|
||||
}
|
||||
};
|
||||
|
||||
class NCCLChannel : public Channel {
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <chrono> // for seconds
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int32_t
|
||||
#include <cstdint> // for int32_t, int64_t
|
||||
#include <memory> // for shared_ptr
|
||||
#include <string> // for string
|
||||
#include <thread> // for thread
|
||||
@@ -14,13 +14,13 @@
|
||||
#include "loop.h" // for Loop
|
||||
#include "protocol.h" // for PeerInfo
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/collective/socket.h" // for TCPSocket
|
||||
#include "xgboost/collective/socket.h" // for TCPSocket, GetHostName
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
|
||||
inline constexpr std::int32_t DefaultTimeoutSec() { return 300; } // 5min
|
||||
inline constexpr std::int64_t DefaultTimeoutSec() { return 300; } // 5min
|
||||
inline constexpr std::int32_t DefaultRetry() { return 3; }
|
||||
|
||||
// indexing into the ring
|
||||
@@ -51,11 +51,25 @@ class Comm : public std::enable_shared_from_this<Comm> {
|
||||
|
||||
proto::PeerInfo tracker_;
|
||||
SockDomain domain_{SockDomain::kV4};
|
||||
|
||||
std::thread error_worker_;
|
||||
std::int32_t error_port_;
|
||||
|
||||
std::string task_id_;
|
||||
std::vector<std::shared_ptr<Channel>> channels_;
|
||||
std::shared_ptr<Loop> loop_{new Loop{std::chrono::seconds{
|
||||
DefaultTimeoutSec()}}}; // fixme: require federated comm to have a timeout
|
||||
std::shared_ptr<Loop> loop_{nullptr}; // fixme: require federated comm to have a timeout
|
||||
|
||||
void ResetState() {
|
||||
this->world_ = -1;
|
||||
this->rank_ = 0;
|
||||
this->timeout_ = std::chrono::seconds{DefaultTimeoutSec()};
|
||||
|
||||
tracker_ = proto::PeerInfo{};
|
||||
this->task_id_.clear();
|
||||
channels_.clear();
|
||||
|
||||
loop_.reset();
|
||||
}
|
||||
|
||||
public:
|
||||
Comm() = default;
|
||||
@@ -75,10 +89,13 @@ class Comm : public std::enable_shared_from_this<Comm> {
|
||||
[[nodiscard]] auto Retry() const { return retry_; }
|
||||
[[nodiscard]] auto TaskID() const { return task_id_; }
|
||||
|
||||
[[nodiscard]] auto Rank() const { return rank_; }
|
||||
[[nodiscard]] auto World() const { return IsDistributed() ? world_ : 1; }
|
||||
[[nodiscard]] bool IsDistributed() const { return world_ != -1; }
|
||||
void Submit(Loop::Op op) const { loop_->Submit(op); }
|
||||
[[nodiscard]] auto Rank() const noexcept { return rank_; }
|
||||
[[nodiscard]] auto World() const noexcept { return IsDistributed() ? world_ : 1; }
|
||||
[[nodiscard]] bool IsDistributed() const noexcept { return world_ != -1; }
|
||||
void Submit(Loop::Op op) const {
|
||||
CHECK(loop_);
|
||||
loop_->Submit(op);
|
||||
}
|
||||
[[nodiscard]] virtual Result Block() const { return loop_->Block(); }
|
||||
|
||||
[[nodiscard]] virtual std::shared_ptr<Channel> Chan(std::int32_t rank) const {
|
||||
@@ -88,6 +105,14 @@ class Comm : public std::enable_shared_from_this<Comm> {
|
||||
[[nodiscard]] virtual Result LogTracker(std::string msg) const = 0;
|
||||
|
||||
[[nodiscard]] virtual Result SignalError(Result const&) { return Success(); }
|
||||
/**
|
||||
* @brief Get a string ID for the current process.
|
||||
*/
|
||||
[[nodiscard]] virtual Result ProcessorName(std::string* out) const {
|
||||
auto rc = GetHostName(out);
|
||||
return rc;
|
||||
}
|
||||
[[nodiscard]] virtual Result Shutdown() = 0;
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -105,20 +130,20 @@ class RabitComm : public HostComm {
|
||||
|
||||
[[nodiscard]] Result Bootstrap(std::chrono::seconds timeout, std::int32_t retry,
|
||||
std::string task_id);
|
||||
[[nodiscard]] Result Shutdown();
|
||||
|
||||
public:
|
||||
// bootstrapping construction.
|
||||
RabitComm() = default;
|
||||
// ctor for testing where environment is known.
|
||||
RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t retry, std::string task_id, StringView nccl_path);
|
||||
RabitComm(std::string const& tracker_host, std::int32_t tracker_port,
|
||||
std::chrono::seconds timeout, std::int32_t retry, std::string task_id,
|
||||
StringView nccl_path);
|
||||
~RabitComm() noexcept(false) override;
|
||||
|
||||
[[nodiscard]] bool IsFederated() const override { return false; }
|
||||
[[nodiscard]] Result LogTracker(std::string msg) const override;
|
||||
|
||||
[[nodiscard]] Result SignalError(Result const&) override;
|
||||
[[nodiscard]] Result Shutdown() final;
|
||||
|
||||
[[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
|
||||
};
|
||||
|
||||
@@ -1,22 +1,21 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#include "comm_group.h"
|
||||
|
||||
#include <algorithm> // for transform
|
||||
#include <cctype> // for tolower
|
||||
#include <chrono> // for seconds
|
||||
#include <cstdint> // for int32_t
|
||||
#include <iterator> // for back_inserter
|
||||
#include <memory> // for shared_ptr, unique_ptr
|
||||
#include <string> // for string
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../common/json_utils.h" // for OptionalArg
|
||||
#include "coll.h" // for Coll
|
||||
#include "comm.h" // for Comm
|
||||
#include "tracker.h" // for GetHostAddress
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/context.h" // for DeviceOrd
|
||||
#include "xgboost/json.h" // for Json
|
||||
#include "../common/json_utils.h" // for OptionalArg
|
||||
#include "coll.h" // for Coll
|
||||
#include "comm.h" // for Comm
|
||||
#include "xgboost/context.h" // for DeviceOrd
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
#include "../../plugin/federated/federated_coll.h"
|
||||
@@ -65,6 +64,9 @@ CommGroup::CommGroup()
|
||||
|
||||
auto const& obj = get<Object const>(config);
|
||||
auto it = obj.find(upper);
|
||||
if (it != obj.cend() && obj.find(name) != obj.cend()) {
|
||||
LOG(FATAL) << "Duplicated parameter:" << name;
|
||||
}
|
||||
if (it != obj.cend()) {
|
||||
return OptionalArg<decltype(t)>(config, upper, dft);
|
||||
} else {
|
||||
@@ -78,14 +80,14 @@ CommGroup::CommGroup()
|
||||
auto task_id = get_param("dmlc_task_id", std::string{}, String{});
|
||||
|
||||
if (type == "rabit") {
|
||||
auto host = get_param("dmlc_tracker_uri", std::string{}, String{});
|
||||
auto port = get_param("dmlc_tracker_port", static_cast<std::int64_t>(0), Integer{});
|
||||
auto tracker_host = get_param("dmlc_tracker_uri", std::string{}, String{});
|
||||
auto tracker_port = get_param("dmlc_tracker_port", static_cast<std::int64_t>(0), Integer{});
|
||||
auto nccl = get_param("dmlc_nccl_path", std::string{DefaultNcclName()}, String{});
|
||||
auto ptr =
|
||||
new CommGroup{std::shared_ptr<RabitComm>{new RabitComm{ // NOLINT
|
||||
host, static_cast<std::int32_t>(port), std::chrono::seconds{timeout},
|
||||
static_cast<std::int32_t>(retry), task_id, nccl}},
|
||||
std::shared_ptr<Coll>(new Coll{})}; // NOLINT
|
||||
auto ptr = new CommGroup{
|
||||
std::shared_ptr<RabitComm>{new RabitComm{ // NOLINT
|
||||
tracker_host, static_cast<std::int32_t>(tracker_port), std::chrono::seconds{timeout},
|
||||
static_cast<std::int32_t>(retry), task_id, nccl}},
|
||||
std::shared_ptr<Coll>(new Coll{})}; // NOLINT
|
||||
return ptr;
|
||||
} else if (type == "federated") {
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
@@ -117,6 +119,8 @@ void GlobalCommGroupInit(Json config) {
|
||||
|
||||
void GlobalCommGroupFinalize() {
|
||||
auto& sptr = GlobalCommGroup();
|
||||
auto rc = sptr->Finalize();
|
||||
sptr.reset();
|
||||
SafeColl(rc);
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -9,7 +9,6 @@
|
||||
#include "coll.h" // for Coll
|
||||
#include "comm.h" // for Comm
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/collective/socket.h" // for GetHostName
|
||||
|
||||
namespace xgboost::collective {
|
||||
/**
|
||||
@@ -31,19 +30,35 @@ class CommGroup {
|
||||
public:
|
||||
CommGroup();
|
||||
|
||||
[[nodiscard]] auto World() const { return comm_->World(); }
|
||||
[[nodiscard]] auto Rank() const { return comm_->Rank(); }
|
||||
[[nodiscard]] bool IsDistributed() const { return comm_->IsDistributed(); }
|
||||
[[nodiscard]] auto World() const noexcept { return comm_->World(); }
|
||||
[[nodiscard]] auto Rank() const noexcept { return comm_->Rank(); }
|
||||
[[nodiscard]] bool IsDistributed() const noexcept { return comm_->IsDistributed(); }
|
||||
|
||||
[[nodiscard]] Result Finalize() const {
|
||||
return Success() << [this] {
|
||||
if (gpu_comm_) {
|
||||
return gpu_comm_->Shutdown();
|
||||
}
|
||||
return Success();
|
||||
} << [&] {
|
||||
return comm_->Shutdown();
|
||||
};
|
||||
}
|
||||
|
||||
[[nodiscard]] static CommGroup* Create(Json config);
|
||||
|
||||
[[nodiscard]] std::shared_ptr<Coll> Backend(DeviceOrd device) const;
|
||||
/**
|
||||
* @brief Decide the context to use for communication.
|
||||
*
|
||||
* @param ctx Global context, provides the CUDA stream and ordinal.
|
||||
* @param device The device used by the data to be communicated.
|
||||
*/
|
||||
[[nodiscard]] Comm const& Ctx(Context const* ctx, DeviceOrd device) const;
|
||||
[[nodiscard]] Result SignalError(Result const& res) { return comm_->SignalError(res); }
|
||||
|
||||
[[nodiscard]] Result ProcessorName(std::string* out) const {
|
||||
auto rc = GetHostName(out);
|
||||
return rc;
|
||||
return this->comm_->ProcessorName(out);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -32,7 +32,8 @@ class InMemoryHandler {
|
||||
*
|
||||
* This is used when the handler only needs to be initialized once with a known world size.
|
||||
*/
|
||||
explicit InMemoryHandler(std::size_t worldSize) : world_size_{worldSize} {}
|
||||
explicit InMemoryHandler(std::int32_t worldSize)
|
||||
: world_size_{static_cast<std::size_t>(worldSize)} {}
|
||||
|
||||
/**
|
||||
* @brief Initialize the handler with the world size and rank.
|
||||
|
||||
@@ -18,9 +18,11 @@
|
||||
#include "xgboost/logging.h" // for CHECK
|
||||
|
||||
namespace xgboost::collective {
|
||||
Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
|
||||
Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
|
||||
timer_.Start(__func__);
|
||||
auto error = [this] { timer_.Stop(__func__); };
|
||||
auto error = [this] {
|
||||
timer_.Stop(__func__);
|
||||
};
|
||||
|
||||
if (stop_) {
|
||||
timer_.Stop(__func__);
|
||||
@@ -48,6 +50,9 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
|
||||
poll.WatchWrite(*op.sock);
|
||||
break;
|
||||
}
|
||||
case Op::kSleep: {
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
error();
|
||||
return Fail("Invalid socket operation.");
|
||||
@@ -59,12 +64,14 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
|
||||
|
||||
// poll, work on fds that are ready.
|
||||
timer_.Start("poll");
|
||||
auto rc = poll.Poll(timeout_);
|
||||
timer_.Stop("poll");
|
||||
if (!rc.OK()) {
|
||||
error();
|
||||
return rc;
|
||||
if (!poll.fds.empty()) {
|
||||
auto rc = poll.Poll(timeout_);
|
||||
if (!rc.OK()) {
|
||||
error();
|
||||
return rc;
|
||||
}
|
||||
}
|
||||
timer_.Stop("poll");
|
||||
|
||||
// we wouldn't be here if the queue is empty.
|
||||
CHECK(!qcopy.empty());
|
||||
@@ -75,12 +82,20 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
|
||||
qcopy.pop();
|
||||
|
||||
std::int32_t n_bytes_done{0};
|
||||
CHECK(op.sock->NonBlocking());
|
||||
if (!op.sock) {
|
||||
CHECK(op.code == Op::kSleep);
|
||||
} else {
|
||||
CHECK(op.sock->NonBlocking());
|
||||
}
|
||||
|
||||
switch (op.code) {
|
||||
case Op::kRead: {
|
||||
if (poll.CheckRead(*op.sock)) {
|
||||
n_bytes_done = op.sock->Recv(op.ptr + op.off, op.n - op.off);
|
||||
if (n_bytes_done == 0) {
|
||||
error();
|
||||
return Fail("Encountered EOF. The other end is likely closed.");
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
@@ -90,6 +105,12 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Op::kSleep: {
|
||||
// For testing only.
|
||||
std::this_thread::sleep_for(std::chrono::seconds{op.n});
|
||||
n_bytes_done = op.n;
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
error();
|
||||
return Fail("Invalid socket operation.");
|
||||
@@ -110,6 +131,10 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
|
||||
qcopy.push(op);
|
||||
}
|
||||
}
|
||||
|
||||
if (!blocking) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
timer_.Stop(__func__);
|
||||
@@ -128,6 +153,15 @@ void Loop::Process() {
|
||||
while (true) {
|
||||
try {
|
||||
std::unique_lock lock{mu_};
|
||||
// This can handle missed notification: wait(lock, predicate) is equivalent to:
|
||||
//
|
||||
// while (!predicate()) {
|
||||
// cv.wait(lock);
|
||||
// }
|
||||
//
|
||||
// As a result, if there's a missed notification, the queue wouldn't be empty, hence
|
||||
// the predicate would be true and the actual wait wouldn't be invoked. Therefore,
|
||||
// the blocking call can never go unanswered.
|
||||
cv_.wait(lock, [this] { return !this->queue_.empty() || stop_; });
|
||||
if (stop_) {
|
||||
break; // only point where this loop can exit.
|
||||
@@ -142,26 +176,27 @@ void Loop::Process() {
|
||||
queue_.pop();
|
||||
if (op.code == Op::kBlock) {
|
||||
is_blocking = true;
|
||||
// Block must be the last op in the current batch since no further submit can be
|
||||
// issued until the blocking call is finished.
|
||||
CHECK(queue_.empty());
|
||||
} else {
|
||||
qcopy.push(op);
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_blocking) {
|
||||
// Unblock, we can write to the global queue again.
|
||||
lock.unlock();
|
||||
lock.unlock();
|
||||
// Clear the local queue, if `is_blocking` is true, this is blocking the current
|
||||
// worker thread (but not the client thread), wait until all operations are
|
||||
// finished.
|
||||
auto rc = this->ProcessQueue(&qcopy, is_blocking);
|
||||
|
||||
if (is_blocking && rc.OK()) {
|
||||
CHECK(qcopy.empty());
|
||||
}
|
||||
|
||||
// Clear the local queue, this is blocking the current worker thread (but not the
|
||||
// client thread), wait until all operations are finished.
|
||||
auto rc = this->EmptyQueue(&qcopy);
|
||||
|
||||
if (is_blocking) {
|
||||
// The unlock is delayed if this is a blocking call
|
||||
lock.unlock();
|
||||
// Push back the remaining operations.
|
||||
if (rc.OK()) {
|
||||
std::unique_lock lock{mu_};
|
||||
while (!qcopy.empty()) {
|
||||
queue_.push(qcopy.front());
|
||||
qcopy.pop();
|
||||
}
|
||||
}
|
||||
|
||||
// Notify the client thread who called block after all error conditions are set.
|
||||
@@ -228,7 +263,6 @@ Result Loop::Stop() {
|
||||
}
|
||||
|
||||
this->Submit(Op{Op::kBlock});
|
||||
|
||||
{
|
||||
// Wait for the block call to finish.
|
||||
std::unique_lock lock{mu_};
|
||||
@@ -243,8 +277,20 @@ Result Loop::Stop() {
|
||||
}
|
||||
}
|
||||
|
||||
// Enqueue an operation for the worker thread and wake it up. Every operation other than
// `Op::kBlock` is required to have a non-zero `n`.
void Loop::Submit(Op op) {
  {
    std::lock_guard guard{mu_};
    if (op.code != Op::kBlock) {
      CHECK_NE(op.n, 0);
    }
    queue_.push(op);
  }
  // The mutex has been released by the time the worker is notified.
  cv_.notify_one();
}
|
||||
|
||||
// Construct the event loop: store the timeout, initialise the timer label, and start the
// worker thread that runs `Process`.
Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} {
  timer_.Init(__func__);
  // Launch the worker exactly once. Move-assigning a second std::thread onto an already
  // joinable one calls std::terminate.
  worker_ = std::thread{[this] {
    this->Process();
  }};
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -19,20 +19,27 @@ namespace xgboost::collective {
|
||||
class Loop {
|
||||
public:
|
||||
struct Op {
|
||||
enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2 } code;
|
||||
// kSleep is only for testing
|
||||
enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2, kSleep = 4 } code;
|
||||
std::int32_t rank{-1};
|
||||
std::int8_t* ptr{nullptr};
|
||||
std::size_t n{0};
|
||||
TCPSocket* sock{nullptr};
|
||||
std::size_t off{0};
|
||||
|
||||
explicit Op(Code c) : code{c} { CHECK(c == kBlock); }
|
||||
explicit Op(Code c) : code{c} { CHECK(c == kBlock || c == kSleep); }
|
||||
Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off)
|
||||
: code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {}
|
||||
Op(Op const&) = default;
|
||||
Op& operator=(Op const&) = default;
|
||||
Op(Op&&) = default;
|
||||
Op& operator=(Op&&) = default;
|
||||
// For testing purpose only
|
||||
[[nodiscard]] static Op Sleep(std::size_t seconds) {
|
||||
Op op{kSleep};
|
||||
op.n = seconds;
|
||||
return op;
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
@@ -54,7 +61,7 @@ class Loop {
|
||||
std::exception_ptr curr_exce_{nullptr};
|
||||
common::Monitor mutable timer_;
|
||||
|
||||
Result EmptyQueue(std::queue<Op>* p_queue) const;
|
||||
Result ProcessQueue(std::queue<Op>* p_queue, bool blocking) const;
|
||||
// The consumer function that runs inside a worker thread.
|
||||
void Process();
|
||||
|
||||
@@ -64,12 +71,7 @@ class Loop {
|
||||
*/
|
||||
Result Stop();
|
||||
|
||||
// Push an operation onto the queue and wake the worker thread. The definition lives out of
// line as `Loop::Submit`; keeping an inline body here as well would redefine the member.
void Submit(Op op);
|
||||
|
||||
/**
|
||||
* @brief Block the event loop until all ops are finished. In the case of failure, this
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
* Copyright 2023 XGBoost contributors
|
||||
*/
|
||||
#if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL)
|
||||
#include <numeric> // for accumulate
|
||||
|
||||
#include "comm.cuh"
|
||||
#include "nccl_device_communicator.cuh"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <cstdint> // for int32_t
|
||||
@@ -58,6 +58,7 @@ struct Magic {
|
||||
}
|
||||
};
|
||||
|
||||
// Basic commands for communication between workers and the tracker.
|
||||
enum class CMD : std::int32_t {
|
||||
kInvalid = 0,
|
||||
kStart = 1,
|
||||
@@ -84,7 +85,10 @@ struct Connect {
|
||||
[[nodiscard]] Result TrackerRecv(TCPSocket* sock, std::int32_t* world, std::int32_t* rank,
|
||||
std::string* task_id) const {
|
||||
std::string init;
|
||||
sock->Recv(&init);
|
||||
auto rc = sock->Recv(&init);
|
||||
if (!rc.OK()) {
|
||||
return Fail("Connect protocol failed.", std::move(rc));
|
||||
}
|
||||
auto jinit = Json::Load(StringView{init});
|
||||
*world = get<Integer const>(jinit["world_size"]);
|
||||
*rank = get<Integer const>(jinit["rank"]);
|
||||
@@ -122,9 +126,9 @@ class Start {
|
||||
}
|
||||
[[nodiscard]] Result WorkerRecv(TCPSocket* tracker, std::int32_t* p_world) const {
|
||||
std::string scmd;
|
||||
auto n_bytes = tracker->Recv(&scmd);
|
||||
if (n_bytes <= 0) {
|
||||
return Fail("Failed to recv init command from tracker.");
|
||||
auto rc = tracker->Recv(&scmd);
|
||||
if (!rc.OK()) {
|
||||
return Fail("Failed to recv init command from tracker.", std::move(rc));
|
||||
}
|
||||
auto jcmd = Json::Load(scmd);
|
||||
auto world = get<Integer const>(jcmd["world_size"]);
|
||||
@@ -132,7 +136,7 @@ class Start {
|
||||
return Fail("Invalid world size.");
|
||||
}
|
||||
*p_world = world;
|
||||
return Success();
|
||||
return rc;
|
||||
}
|
||||
[[nodiscard]] Result TrackerHandle(Json jcmd, std::int32_t* recv_world, std::int32_t world,
|
||||
std::int32_t* p_port, TCPSocket* p_sock,
|
||||
@@ -150,6 +154,7 @@ class Start {
|
||||
}
|
||||
};
|
||||
|
||||
// Protocol for communicating with the tracker for printing message.
|
||||
struct Print {
|
||||
[[nodiscard]] Result WorkerSend(TCPSocket* tracker, std::string msg) const {
|
||||
Json jcmd{Object{}};
|
||||
@@ -172,6 +177,7 @@ struct Print {
|
||||
}
|
||||
};
|
||||
|
||||
// Protocol for communicating with the tracker during error.
|
||||
struct ErrorCMD {
|
||||
[[nodiscard]] Result WorkerSend(TCPSocket* tracker, Result const& res) const {
|
||||
auto msg = res.Report();
|
||||
@@ -199,6 +205,7 @@ struct ErrorCMD {
|
||||
}
|
||||
};
|
||||
|
||||
// Protocol for communicating with the tracker during shutdown.
|
||||
struct ShutdownCMD {
|
||||
[[nodiscard]] Result Send(TCPSocket* peer) const {
|
||||
Json jcmd{Object{}};
|
||||
@@ -211,4 +218,40 @@ struct ShutdownCMD {
|
||||
return Success();
|
||||
}
|
||||
};
|
||||
|
||||
// Protocol for communicating with the local error handler during error or shutdown. Only
|
||||
// one protocol that doesn't have the tracker involved.
|
||||
struct Error {
|
||||
constexpr static std::int32_t ShutdownSignal() { return 0; }
|
||||
constexpr static std::int32_t ErrorSignal() { return -1; }
|
||||
|
||||
[[nodiscard]] Result SignalError(TCPSocket* worker) const {
|
||||
std::int32_t err{ErrorSignal()};
|
||||
auto n_sent = worker->SendAll(&err, sizeof(err));
|
||||
if (n_sent == sizeof(err)) {
|
||||
return Success();
|
||||
}
|
||||
return Fail("Failed to send error signal");
|
||||
}
|
||||
// self is localhost, we are sending the signal to the error handling thread for it to
|
||||
// close.
|
||||
[[nodiscard]] Result SignalShutdown(TCPSocket* self) const {
|
||||
std::int32_t err{ShutdownSignal()};
|
||||
auto n_sent = self->SendAll(&err, sizeof(err));
|
||||
if (n_sent == sizeof(err)) {
|
||||
return Success();
|
||||
}
|
||||
return Fail("Failed to send shutdown signal");
|
||||
}
|
||||
// get signal, either for error or for shutdown.
|
||||
[[nodiscard]] Result RecvSignal(TCPSocket* peer, bool* p_is_error) const {
|
||||
std::int32_t err{ShutdownSignal()};
|
||||
auto n_recv = peer->RecvAll(&err, sizeof(err));
|
||||
if (n_recv == sizeof(err)) {
|
||||
*p_is_error = err == 1;
|
||||
return Success();
|
||||
}
|
||||
return Fail("Failed to receive error signal.");
|
||||
}
|
||||
};
|
||||
} // namespace xgboost::collective::proto
|
||||
|
||||
86
src/collective/result.cc
Normal file
86
src/collective/result.cc
Normal file
@@ -0,0 +1,86 @@
|
||||
/**
|
||||
* Copyright 2024, XGBoost Contributors
|
||||
*/
|
||||
#include "xgboost/collective/result.h"
|
||||
|
||||
#include <filesystem> // for path
|
||||
#include <sstream> // for stringstream
|
||||
#include <stack> // for stack
|
||||
|
||||
#include "xgboost/logging.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace detail {
|
||||
[[nodiscard]] std::string ResultImpl::Report() const {
|
||||
std::stringstream ss;
|
||||
ss << "\n- " << this->message;
|
||||
if (this->errc != std::error_code{}) {
|
||||
ss << " system error:" << this->errc.message();
|
||||
}
|
||||
|
||||
auto ptr = prev.get();
|
||||
while (ptr) {
|
||||
ss << "\n- ";
|
||||
ss << ptr->message;
|
||||
|
||||
if (ptr->errc != std::error_code{}) {
|
||||
ss << " " << ptr->errc.message();
|
||||
}
|
||||
ptr = ptr->prev.get();
|
||||
}
|
||||
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
[[nodiscard]] std::error_code ResultImpl::Code() const {
|
||||
// Find the root error.
|
||||
std::stack<ResultImpl const*> stack;
|
||||
auto ptr = this;
|
||||
while (ptr) {
|
||||
stack.push(ptr);
|
||||
if (ptr->prev) {
|
||||
ptr = ptr->prev.get();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
while (!stack.empty()) {
|
||||
auto frame = stack.top();
|
||||
stack.pop();
|
||||
if (frame->errc != std::error_code{}) {
|
||||
return frame->errc;
|
||||
}
|
||||
}
|
||||
return std::error_code{};
|
||||
}
|
||||
|
||||
void ResultImpl::Concat(std::unique_ptr<ResultImpl> rhs) {
|
||||
auto ptr = this;
|
||||
while (ptr->prev) {
|
||||
ptr = ptr->prev.get();
|
||||
}
|
||||
ptr->prev = std::move(rhs);
|
||||
}
|
||||
|
||||
#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__)
// In this configuration the file and line arguments are ignored and the message is
// returned unchanged.
std::string MakeMsg(std::string&& msg, char const*, std::int32_t) {
  return std::move(msg);
}
#else
// Prefix `msg` with "[<file name>:<line>]: ". The message is returned unchanged when
// `file` is null or `line` is -1.
std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line) {
  // Check before touching `file`: building a path from a null pointer is invalid.
  if (file == nullptr || line == -1) {
    return std::move(msg);
  }
  auto name = std::filesystem::path{file}.filename();
  return "[" + name.string() + ":" + std::to_string(line) + "]: " + std::move(msg);
}
#endif
|
||||
} // namespace detail
|
||||
|
||||
// Log the result's report at FATAL level when `rc` is not OK; do nothing otherwise.
void SafeColl(Result const& rc) {
  if (rc.OK()) {
    return;
  }
  LOG(FATAL) << rc.Report();
}
|
||||
} // namespace xgboost::collective
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2022-2023 by XGBoost Contributors
|
||||
* Copyright 2022-2024, XGBoost Contributors
|
||||
*/
|
||||
#include "xgboost/collective/socket.h"
|
||||
|
||||
@@ -8,7 +8,8 @@
|
||||
#include <cstdint> // std::int32_t
|
||||
#include <cstring> // std::memcpy, std::memset
|
||||
#include <filesystem> // for path
|
||||
#include <system_error> // std::error_code, std::system_category
|
||||
#include <system_error> // for error_code, system_category
|
||||
#include <thread> // for sleep_for
|
||||
|
||||
#include "rabit/internal/socket.h" // for PollHelper
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
@@ -65,14 +66,18 @@ std::size_t TCPSocket::Send(StringView str) {
|
||||
return bytes;
|
||||
}
|
||||
|
||||
std::size_t TCPSocket::Recv(std::string *p_str) {
|
||||
[[nodiscard]] Result TCPSocket::Recv(std::string *p_str) {
|
||||
CHECK(!this->IsClosed());
|
||||
std::int32_t len;
|
||||
CHECK_EQ(this->RecvAll(&len, sizeof(len)), sizeof(len)) << "Failed to recv string length.";
|
||||
if (this->RecvAll(&len, sizeof(len)) != sizeof(len)) {
|
||||
return Fail("Failed to recv string length.");
|
||||
}
|
||||
p_str->resize(len);
|
||||
auto bytes = this->RecvAll(&(*p_str)[0], len);
|
||||
CHECK_EQ(bytes, len) << "Failed to recv string.";
|
||||
return bytes;
|
||||
if (static_cast<decltype(len)>(bytes) != len) {
|
||||
return Fail("Failed to recv string.");
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
|
||||
[[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry,
|
||||
@@ -110,11 +115,7 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
|
||||
for (std::int32_t attempt = 0; attempt < std::max(retry, 1); ++attempt) {
|
||||
if (attempt > 0) {
|
||||
LOG(WARNING) << "Retrying connection to " << host << " for the " << attempt << " time.";
|
||||
#if defined(_MSC_VER) || defined(__MINGW32__)
|
||||
Sleep(attempt << 1);
|
||||
#else
|
||||
sleep(attempt << 1);
|
||||
#endif
|
||||
std::this_thread::sleep_for(std::chrono::seconds{attempt << 1});
|
||||
}
|
||||
|
||||
auto rc = connect(conn.Handle(), addr_handle, addr_len);
|
||||
@@ -158,8 +159,8 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
|
||||
|
||||
std::stringstream ss;
|
||||
ss << "Failed to connect to " << host << ":" << port;
|
||||
conn.Close();
|
||||
return Fail(ss.str(), std::move(last_error));
|
||||
auto close_rc = conn.Close();
|
||||
return Fail(ss.str(), std::move(close_rc) + std::move(last_error));
|
||||
}
|
||||
|
||||
[[nodiscard]] Result GetHostName(std::string *p_out) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
/**
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#include "rabit/internal/socket.h"
|
||||
#if defined(__unix__) || defined(__APPLE__)
|
||||
#include <netdb.h> // gethostbyname
|
||||
#include <sys/socket.h> // socket, AF_INET6, AF_INET, connect, getsockname
|
||||
@@ -70,10 +71,13 @@ RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockA
|
||||
return proto::Connect{}.TrackerRecv(&sock_, &world_, &rank, &task_id_);
|
||||
} << [&] {
|
||||
std::string cmd;
|
||||
sock_.Recv(&cmd);
|
||||
auto rc = sock_.Recv(&cmd);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
jcmd = Json::Load(StringView{cmd});
|
||||
cmd_ = static_cast<proto::CMD>(get<Integer const>(jcmd["cmd"]));
|
||||
return Success();
|
||||
return rc;
|
||||
} << [&] {
|
||||
if (cmd_ == proto::CMD::kStart) {
|
||||
proto::Start start;
|
||||
@@ -100,14 +104,18 @@ RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockA
|
||||
|
||||
// Build the tracker from a JSON configuration: resolve the host address (the "host" option
// overrides the detected one), create a listener of the matching address family, bind it
// and start listening. Any failure in the chain is passed to `SafeColl`.
RabitTracker::RabitTracker(Json const& config) : Tracker{config} {
  std::string self;
  auto rc = Success() << [&] {
    return collective::GetHostAddress(&self);
  } << [&] {
    host_ = OptionalArg<String>(config, "host", self);

    auto addr = MakeSockAddress(xgboost::StringView{host_}, 0);
    listener_ = TCPSocket::Create(addr.IsV4() ? SockDomain::kV4 : SockDomain::kV6);
    return listener_.Bind(host_, &this->port_);
  } << [&] {
    // Listen is part of the checked chain so that its result is not discarded.
    return listener_.Listen();
  };
  SafeColl(rc);
}
|
||||
|
||||
Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
|
||||
@@ -220,9 +228,13 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
|
||||
//
|
||||
// retry is set to 1, just let the worker timeout or error. Otherwise the
|
||||
// tracker and the worker might be waiting for each other.
|
||||
auto rc = Connect(w.first, w.second, 1, timeout_, &out);
|
||||
auto rc = Success() << [&] {
|
||||
return Connect(w.first, w.second, 1, timeout_, &out);
|
||||
} << [&] {
|
||||
return proto::Error{}.SignalError(&out);
|
||||
};
|
||||
if (!rc.OK()) {
|
||||
return Fail("Failed to inform workers to stop.");
|
||||
return Fail("Failed to inform worker:" + w.first + " for error.", std::move(rc));
|
||||
}
|
||||
}
|
||||
return Success();
|
||||
@@ -231,13 +243,37 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
|
||||
return std::async(std::launch::async, [this, handle_error] {
|
||||
State state{this->n_workers_};
|
||||
|
||||
auto select_accept = [&](TCPSocket* sock, auto* addr) {
|
||||
// accept with poll so that we can enable timeout and interruption.
|
||||
rabit::utils::PollHelper poll;
|
||||
auto rc = Success() << [&] {
|
||||
std::lock_guard lock{listener_mu_};
|
||||
return listener_.NonBlocking(true);
|
||||
} << [&] {
|
||||
std::lock_guard lock{listener_mu_};
|
||||
poll.WatchRead(listener_);
|
||||
if (state.running) {
|
||||
// Don't timeout if the communicator group is up and running.
|
||||
return poll.Poll(std::chrono::seconds{-1});
|
||||
} else {
|
||||
// Have timeout for workers to bootstrap.
|
||||
return poll.Poll(timeout_);
|
||||
}
|
||||
} << [&] {
|
||||
// this->Stop() closes the socket with a lock. Therefore, when the accept returns
|
||||
// due to shutdown, the state is still valid (closed).
|
||||
return listener_.Accept(sock, addr);
|
||||
};
|
||||
return rc;
|
||||
};
|
||||
|
||||
while (state.ShouldContinue()) {
|
||||
TCPSocket sock;
|
||||
SockAddress addr;
|
||||
this->ready_ = true;
|
||||
auto rc = listener_.Accept(&sock, &addr);
|
||||
auto rc = select_accept(&sock, &addr);
|
||||
if (!rc.OK()) {
|
||||
return Fail("Failed to accept connection.", std::move(rc));
|
||||
return Fail("Failed to accept connection.", this->Stop() + std::move(rc));
|
||||
}
|
||||
|
||||
auto worker = WorkerProxy{n_workers_, std::move(sock), std::move(addr)};
|
||||
@@ -252,7 +288,7 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
|
||||
state.Error();
|
||||
rc = handle_error(worker);
|
||||
if (!rc.OK()) {
|
||||
return Fail("Failed to handle abort.", std::move(rc));
|
||||
return Fail("Failed to handle abort.", this->Stop() + std::move(rc));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -262,7 +298,7 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
|
||||
state.Bootstrap();
|
||||
}
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
return this->Stop() + std::move(rc);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
@@ -289,12 +325,11 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
|
||||
}
|
||||
case proto::CMD::kInvalid:
|
||||
default: {
|
||||
return Fail("Invalid command received.");
|
||||
return Fail("Invalid command received.", this->Stop());
|
||||
}
|
||||
}
|
||||
}
|
||||
ready_ = false;
|
||||
return Success();
|
||||
return this->Stop();
|
||||
});
|
||||
}
|
||||
|
||||
@@ -303,11 +338,30 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
|
||||
SafeColl(rc);
|
||||
|
||||
Json args{Object{}};
|
||||
args["DMLC_TRACKER_URI"] = String{host_};
|
||||
args["DMLC_TRACKER_PORT"] = this->Port();
|
||||
args["dmlc_tracker_uri"] = String{host_};
|
||||
args["dmlc_tracker_port"] = this->Port();
|
||||
return args;
|
||||
}
|
||||
|
||||
// Stop the tracker: clear the ready flag and, under the listener mutex, shut down and
// close the listening socket. Returns success without touching the socket when the
// tracker is not ready or the listener is already closed.
[[nodiscard]] Result RabitTracker::Stop() {
  if (this->Ready()) {
    ready_ = false;
    std::lock_guard lock{listener_mu_};
    if (!this->listener_.IsClosed()) {
      return Success() << [&] {
        // This should have the effect of stopping the `accept` call.
        return this->listener_.Shutdown();
      } << [&] {
        return listener_.Close();
      };
    }
  }
  return Success();
}
|
||||
|
||||
[[nodiscard]] Result GetHostAddress(std::string* out) {
|
||||
auto rc = GetHostName(out);
|
||||
if (!rc.OK()) {
|
||||
|
||||
@@ -36,15 +36,18 @@ namespace xgboost::collective {
|
||||
* signal an error to the tracker and the tracker will notify other workers.
|
||||
*/
|
||||
class Tracker {
|
||||
public:
|
||||
enum class SortBy : std::int8_t {
|
||||
kHost = 0,
|
||||
kTask = 1,
|
||||
};
|
||||
|
||||
protected:
|
||||
// How to sort the workers, either by host name or by task ID. When using a multi-GPU
|
||||
// setting, multiple workers can occupy the same host, in which case one should sort
|
||||
// workers by task. Due to compatibility reason, the task ID is not always available, so
|
||||
// we use host as the default.
|
||||
enum class SortBy : std::int8_t {
|
||||
kHost = 0,
|
||||
kTask = 1,
|
||||
} sortby_;
|
||||
SortBy sortby_;
|
||||
|
||||
protected:
|
||||
std::int32_t n_workers_{0};
|
||||
@@ -54,10 +57,7 @@ class Tracker {
|
||||
|
||||
public:
|
||||
explicit Tracker(Json const& config);
|
||||
Tracker(std::int32_t n_worders, std::int32_t port, std::chrono::seconds timeout)
|
||||
: n_workers_{n_worders}, port_{port}, timeout_{timeout} {}
|
||||
|
||||
virtual ~Tracker() noexcept(false){}; // NOLINT
|
||||
virtual ~Tracker() = default;
|
||||
|
||||
[[nodiscard]] Result WaitUntilReady() const;
|
||||
|
||||
@@ -69,6 +69,11 @@ class Tracker {
|
||||
* @brief Flag to indicate whether the server is running.
|
||||
*/
|
||||
[[nodiscard]] bool Ready() const { return ready_; }
|
||||
/**
|
||||
* @brief Shutdown the tracker, cannot be restarted again. Useful when the tracker hangs while
|
||||
* calling accept.
|
||||
*/
|
||||
virtual Result Stop() { return Success(); }
|
||||
};
|
||||
|
||||
class RabitTracker : public Tracker {
|
||||
@@ -127,28 +132,22 @@ class RabitTracker : public Tracker {
|
||||
// record for how to reach out to workers if error happens.
|
||||
std::vector<std::pair<std::string, std::int32_t>> worker_error_handles_;
|
||||
// listening socket for incoming workers.
|
||||
//
|
||||
// At the moment, the listener calls accept without first polling. We can add an
|
||||
// additional unix domain socket to allow cancelling the accept.
|
||||
TCPSocket listener_;
|
||||
// mutex for protecting the listener, used to prevent race when it's listening while
|
||||
// another thread tries to shut it down.
|
||||
std::mutex listener_mu_;
|
||||
|
||||
Result Bootstrap(std::vector<WorkerProxy>* p_workers);
|
||||
|
||||
public:
|
||||
explicit RabitTracker(StringView host, std::int32_t n_worders, std::int32_t port,
|
||||
std::chrono::seconds timeout)
|
||||
: Tracker{n_worders, port, timeout}, host_{host.c_str(), host.size()} {
|
||||
listener_ = TCPSocket::Create(SockDomain::kV4);
|
||||
auto rc = listener_.Bind(host, &this->port_);
|
||||
CHECK(rc.OK()) << rc.Report();
|
||||
listener_.Listen();
|
||||
}
|
||||
|
||||
explicit RabitTracker(Json const& config);
|
||||
~RabitTracker() noexcept(false) override = default;
|
||||
~RabitTracker() override = default;
|
||||
|
||||
std::future<Result> Run() override;
|
||||
[[nodiscard]] Json WorkerArgs() const override;
|
||||
// Stop the tracker without waiting. This is to prevent the tracker from hanging when
|
||||
// one of the workers fails to start.
|
||||
[[nodiscard]] Result Stop() override;
|
||||
};
|
||||
|
||||
// Probe the public IP address of the host, need a better method.
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
#include <thrust/iterator/transform_output_iterator.h> // make_transform_output_iterator
|
||||
#include <thrust/logical.h>
|
||||
#include <thrust/sequence.h>
|
||||
#include <thrust/sort.h>
|
||||
#include <thrust/system/cuda/error.h>
|
||||
#include <thrust/system_error.h>
|
||||
#include <thrust/transform_scan.h>
|
||||
@@ -301,21 +300,22 @@ class MemoryLogger {
|
||||
// Record an allocation of `n` bytes at `ptr` and update the running and peak byte counts.
void RegisterAllocation(void *ptr, size_t n) {
  device_allocations[ptr] = n;
  currently_allocated_bytes += n;
  // A single peak update is enough; repeating it changes nothing.
  peak_allocated_bytes = std::max(peak_allocated_bytes, currently_allocated_bytes);
  num_allocations++;
  CHECK_GT(num_allocations, num_deallocations);
}
|
||||
void RegisterDeallocation(void *ptr, size_t n, int current_device) {
|
||||
auto itr = device_allocations.find(ptr);
|
||||
if (itr == device_allocations.end()) {
|
||||
LOG(WARNING) << "Attempting to deallocate " << n << " bytes on device "
|
||||
<< current_device << " that was never allocated ";
|
||||
LOG(WARNING) << "Attempting to deallocate " << n << " bytes on device " << current_device
|
||||
<< " that was never allocated\n"
|
||||
<< dmlc::StackTrace();
|
||||
} else {
|
||||
num_deallocations++;
|
||||
CHECK_LE(num_deallocations, num_allocations);
|
||||
currently_allocated_bytes -= itr->second;
|
||||
device_allocations.erase(itr);
|
||||
}
|
||||
num_deallocations++;
|
||||
CHECK_LE(num_deallocations, num_allocations);
|
||||
currently_allocated_bytes -= itr->second;
|
||||
device_allocations.erase(itr);
|
||||
}
|
||||
};
|
||||
DeviceStats stats_;
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#include "xgboost/logging.h"
|
||||
|
||||
namespace xgboost::error {
|
||||
std::string DeprecatedFunc(StringView old, StringView since, StringView replacement) {
|
||||
[[nodiscard]] std::string DeprecatedFunc(StringView old, StringView since, StringView replacement) {
|
||||
std::stringstream ss;
|
||||
ss << "`" << old << "` is deprecated since" << since << ", use `" << replacement << "` instead.";
|
||||
return ss.str();
|
||||
|
||||
@@ -89,7 +89,7 @@ void WarnDeprecatedGPUId();
|
||||
|
||||
void WarnEmptyDataset();
|
||||
|
||||
std::string DeprecatedFunc(StringView old, StringView since, StringView replacement);
|
||||
[[nodiscard]] std::string DeprecatedFunc(StringView old, StringView since, StringView replacement);
|
||||
|
||||
constexpr StringView InvalidCUDAOrdinal() {
|
||||
return "Invalid device. `device` is required to be CUDA and there must be at least one GPU "
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#define COMMON_HIST_UTIL_CUH_
|
||||
|
||||
#include <thrust/host_vector.h>
|
||||
#include <thrust/sort.h> // for sort
|
||||
|
||||
#include <cstddef> // for size_t
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <mutex>
|
||||
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "quantile.h"
|
||||
|
||||
#include <limits>
|
||||
#include <numeric> // for partial_sum
|
||||
#include <utility>
|
||||
|
||||
#include "../collective/aggregator.h"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2020-2023 by XGBoost Contributors
|
||||
* Copyright 2020-2024, XGBoost Contributors
|
||||
*/
|
||||
#include <thrust/binary_search.h>
|
||||
#include <thrust/execution_policy.h>
|
||||
@@ -8,8 +8,8 @@
|
||||
#include <thrust/transform_scan.h>
|
||||
#include <thrust/unique.h>
|
||||
|
||||
#include <limits> // std::numeric_limits
|
||||
#include <memory>
|
||||
#include <limits> // std::numeric_limits
|
||||
#include <numeric> // for partial_sum
|
||||
#include <utility>
|
||||
|
||||
#include "../collective/communicator-inl.cuh"
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
/**
|
||||
* Copyright 2020-2024, XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_COMMON_QUANTILE_CUH_
|
||||
#define XGBOOST_COMMON_QUANTILE_CUH_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "xgboost/span.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "device_helpers.cuh"
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
/*!
|
||||
* Copyright by Contributors 2019
|
||||
/**
|
||||
* Copyright 2019-2024, XGBoost Contributors
|
||||
*/
|
||||
#include "timer.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <utility>
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
@@ -61,6 +60,9 @@ void Monitor::Print() const {
|
||||
kv.second.timer.elapsed)
|
||||
.count());
|
||||
}
|
||||
if (stat_map.empty()) {
|
||||
return;
|
||||
}
|
||||
LOG(CONSOLE) << "======== Monitor (" << rank << "): " << label_ << " ========";
|
||||
this->PrintStatistics(stat_map);
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
#include <cmath> // for abs
|
||||
#include <cstdint> // for uint64_t, int32_t, uint8_t, uint32_t
|
||||
#include <cstring> // for size_t, strcmp, memcpy
|
||||
#include <exception> // for exception
|
||||
#include <iostream> // for operator<<, basic_ostream, basic_ostream::op...
|
||||
#include <map> // for map, operator!=
|
||||
#include <numeric> // for accumulate, partial_sum
|
||||
@@ -22,7 +21,6 @@
|
||||
#include "../collective/communicator.h" // for Operation
|
||||
#include "../common/algorithm.h" // for StableSort
|
||||
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
|
||||
#include "../common/common.h" // for Split
|
||||
#include "../common/error_msg.h" // for GroupSize, GroupWeight, InfInData
|
||||
#include "../common/group_data.h" // for ParallelGroupBuilder
|
||||
#include "../common/io.h" // for PeekableInStream
|
||||
@@ -473,11 +471,11 @@ void MetaInfo::SetInfo(Context const& ctx, StringView key, StringView interface_
|
||||
<< ", must have at least 1 column even if it's empty.";
|
||||
auto const& first = get<Object const>(array.front());
|
||||
auto ptr = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(first);
|
||||
is_cuda = ArrayInterfaceHandler::IsCudaPtr(ptr);
|
||||
is_cuda = first.find("stream") != first.cend() || ArrayInterfaceHandler::IsCudaPtr(ptr);
|
||||
} else {
|
||||
auto const& first = get<Object const>(j_interface);
|
||||
auto ptr = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(first);
|
||||
is_cuda = ArrayInterfaceHandler::IsCudaPtr(ptr);
|
||||
is_cuda = first.find("stream") != first.cend() || ArrayInterfaceHandler::IsCudaPtr(ptr);
|
||||
}
|
||||
|
||||
if (is_cuda) {
|
||||
@@ -567,46 +565,6 @@ void MetaInfo::SetInfoFromHost(Context const& ctx, StringView key, Json arr) {
|
||||
}
|
||||
}
|
||||
|
||||
void MetaInfo::SetInfo(Context const& ctx, const char* key, const void* dptr, DataType dtype,
|
||||
size_t num) {
|
||||
CHECK(key);
|
||||
auto proc = [&](auto cast_d_ptr) {
|
||||
using T = std::remove_pointer_t<decltype(cast_d_ptr)>;
|
||||
auto t = linalg::TensorView<T, 1>(common::Span<T>{cast_d_ptr, num}, {num}, DeviceOrd::CPU());
|
||||
CHECK(t.CContiguous());
|
||||
Json interface {
|
||||
linalg::ArrayInterface(t)
|
||||
};
|
||||
assert(ArrayInterface<1>{interface}.is_contiguous);
|
||||
return interface;
|
||||
};
|
||||
// Legacy code using XGBoost dtype, which is a small subset of array interface types.
|
||||
switch (dtype) {
|
||||
case xgboost::DataType::kFloat32: {
|
||||
auto cast_ptr = reinterpret_cast<const float*>(dptr);
|
||||
this->SetInfoFromHost(ctx, key, proc(cast_ptr));
|
||||
break;
|
||||
}
|
||||
case xgboost::DataType::kDouble: {
|
||||
auto cast_ptr = reinterpret_cast<const double*>(dptr);
|
||||
this->SetInfoFromHost(ctx, key, proc(cast_ptr));
|
||||
break;
|
||||
}
|
||||
case xgboost::DataType::kUInt32: {
|
||||
auto cast_ptr = reinterpret_cast<const uint32_t*>(dptr);
|
||||
this->SetInfoFromHost(ctx, key, proc(cast_ptr));
|
||||
break;
|
||||
}
|
||||
case xgboost::DataType::kUInt64: {
|
||||
auto cast_ptr = reinterpret_cast<const uint64_t*>(dptr);
|
||||
this->SetInfoFromHost(ctx, key, proc(cast_ptr));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
LOG(FATAL) << "Unknown data type" << static_cast<uint8_t>(dtype);
|
||||
}
|
||||
}
|
||||
|
||||
void MetaInfo::GetInfo(char const* key, bst_ulong* out_len, DataType dtype,
|
||||
const void** out_dptr) const {
|
||||
if (dtype == DataType::kFloat32) {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2021-2023, XGBoost contributors
|
||||
* Copyright 2021-2024, XGBoost contributors
|
||||
*/
|
||||
#include "file_iterator.h"
|
||||
|
||||
@@ -10,7 +10,10 @@
|
||||
#include <ostream> // for operator<<, basic_ostream, istringstream
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../common/common.h" // for Split
|
||||
#include "../common/common.h" // for Split
|
||||
#include "xgboost/linalg.h" // for ArrayInterfaceStr, MakeVec
|
||||
#include "xgboost/linalg.h"
|
||||
#include "xgboost/logging.h" // for CHECK
|
||||
#include "xgboost/string_view.h" // for operator<<, StringView
|
||||
|
||||
namespace xgboost::data {
|
||||
@@ -28,10 +31,10 @@ std::string ValidateFileFormat(std::string const& uri) {
|
||||
for (size_t i = 0; i < arg_list.size(); ++i) {
|
||||
std::istringstream is(arg_list[i]);
|
||||
std::pair<std::string, std::string> kv;
|
||||
CHECK(std::getline(is, kv.first, '=')) << "Invalid uri argument format"
|
||||
<< " for key in arg " << i + 1;
|
||||
CHECK(std::getline(is, kv.second)) << "Invalid uri argument format"
|
||||
<< " for value in arg " << i + 1;
|
||||
CHECK(std::getline(is, kv.first, '='))
|
||||
<< "Invalid uri argument format" << " for key in arg " << i + 1;
|
||||
CHECK(std::getline(is, kv.second))
|
||||
<< "Invalid uri argument format" << " for value in arg " << i + 1;
|
||||
args.insert(kv);
|
||||
}
|
||||
if (args.find("format") == args.cend()) {
|
||||
@@ -48,4 +51,41 @@ std::string ValidateFileFormat(std::string const& uri) {
|
||||
return name_args[0] + "?" + name_args[1] + '#' + name_args_cache[1];
|
||||
}
|
||||
}
|
||||
|
||||
int FileIterator::Next() {
|
||||
CHECK(parser_);
|
||||
if (parser_->Next()) {
|
||||
row_block_ = parser_->Value();
|
||||
|
||||
indptr_ = linalg::Make1dInterface(row_block_.offset, row_block_.size + 1);
|
||||
values_ = linalg::Make1dInterface(row_block_.value, row_block_.offset[row_block_.size]);
|
||||
indices_ = linalg::Make1dInterface(row_block_.index, row_block_.offset[row_block_.size]);
|
||||
|
||||
size_t n_columns =
|
||||
*std::max_element(row_block_.index, row_block_.index + row_block_.offset[row_block_.size]);
|
||||
// dmlc parser converts 1-based indexing back to 0-based indexing so we can ignore
|
||||
// this condition and just add 1 to n_columns
|
||||
n_columns += 1;
|
||||
|
||||
XGProxyDMatrixSetDataCSR(proxy_, indptr_.c_str(), indices_.c_str(), values_.c_str(), n_columns);
|
||||
|
||||
if (row_block_.label) {
|
||||
auto str = linalg::Make1dInterface(row_block_.label, row_block_.size);
|
||||
XGDMatrixSetInfoFromInterface(proxy_, "label", str.c_str());
|
||||
}
|
||||
if (row_block_.qid) {
|
||||
auto str = linalg::Make1dInterface(row_block_.qid, row_block_.size);
|
||||
XGDMatrixSetInfoFromInterface(proxy_, "qid", str.c_str());
|
||||
}
|
||||
if (row_block_.weight) {
|
||||
auto str = linalg::Make1dInterface(row_block_.weight, row_block_.size);
|
||||
XGDMatrixSetInfoFromInterface(proxy_, "weight", str.c_str());
|
||||
}
|
||||
// Continue iteration
|
||||
return true;
|
||||
} else {
|
||||
// Stop iteration
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} // namespace xgboost::data
|
||||
|
||||
@@ -1,20 +1,16 @@
|
||||
/**
|
||||
* Copyright 2021-2023, XGBoost contributors
|
||||
* Copyright 2021-2024, XGBoost contributors
|
||||
*/
|
||||
#ifndef XGBOOST_DATA_FILE_ITERATOR_H_
|
||||
#define XGBOOST_DATA_FILE_ITERATOR_H_
|
||||
|
||||
#include <algorithm> // for max_element
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for uint32_t
|
||||
#include <memory> // for unique_ptr
|
||||
#include <string> // for string
|
||||
#include <utility> // for move
|
||||
|
||||
#include "dmlc/data.h" // for RowBlock, Parser
|
||||
#include "xgboost/c_api.h" // for XGDMatrixSetDenseInfo, XGDMatrixFree, XGProxyDMatrixCreate
|
||||
#include "xgboost/linalg.h" // for ArrayInterfaceStr, MakeVec
|
||||
#include "xgboost/logging.h" // for CHECK
|
||||
#include "xgboost/c_api.h" // for XGDMatrixFree, XGProxyDMatrixCreate
|
||||
|
||||
namespace xgboost::data {
|
||||
[[nodiscard]] std::string ValidateFileFormat(std::string const& uri);
|
||||
@@ -53,41 +49,7 @@ class FileIterator {
|
||||
XGDMatrixFree(proxy_);
|
||||
}
|
||||
|
||||
int Next() {
|
||||
CHECK(parser_);
|
||||
if (parser_->Next()) {
|
||||
row_block_ = parser_->Value();
|
||||
using linalg::MakeVec;
|
||||
|
||||
indptr_ = ArrayInterfaceStr(MakeVec(row_block_.offset, row_block_.size + 1));
|
||||
values_ = ArrayInterfaceStr(MakeVec(row_block_.value, row_block_.offset[row_block_.size]));
|
||||
indices_ = ArrayInterfaceStr(MakeVec(row_block_.index, row_block_.offset[row_block_.size]));
|
||||
|
||||
size_t n_columns = *std::max_element(row_block_.index,
|
||||
row_block_.index + row_block_.offset[row_block_.size]);
|
||||
// dmlc parser converts 1-based indexing back to 0-based indexing so we can ignore
|
||||
// this condition and just add 1 to n_columns
|
||||
n_columns += 1;
|
||||
|
||||
XGProxyDMatrixSetDataCSR(proxy_, indptr_.c_str(), indices_.c_str(),
|
||||
values_.c_str(), n_columns);
|
||||
|
||||
if (row_block_.label) {
|
||||
XGDMatrixSetDenseInfo(proxy_, "label", row_block_.label, row_block_.size, 1);
|
||||
}
|
||||
if (row_block_.qid) {
|
||||
XGDMatrixSetDenseInfo(proxy_, "qid", row_block_.qid, row_block_.size, 1);
|
||||
}
|
||||
if (row_block_.weight) {
|
||||
XGDMatrixSetDenseInfo(proxy_, "weight", row_block_.weight, row_block_.size, 1);
|
||||
}
|
||||
// Continue iteration
|
||||
return true;
|
||||
} else {
|
||||
// Stop iteration
|
||||
return false;
|
||||
}
|
||||
}
|
||||
int Next();
|
||||
|
||||
// Accessor for the proxy DMatrix handle stored in proxy_.
auto Proxy() -> decltype(proxy_) {
  return proxy_;
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2014-2023, XGBoost Contributors
|
||||
* Copyright 2014-2024, XGBoost Contributors
|
||||
* \file sparse_page_source.h
|
||||
*/
|
||||
#ifndef XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
|
||||
@@ -7,23 +7,26 @@
|
||||
|
||||
#include <algorithm> // for min
|
||||
#include <atomic> // for atomic
|
||||
#include <cstdio> // for remove
|
||||
#include <future> // for async
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <mutex> // for mutex
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <utility> // for pair, move
|
||||
#include <vector>
|
||||
#include <memory> // for unique_ptr
|
||||
#include <mutex> // for mutex
|
||||
#include <string> // for string
|
||||
#include <utility> // for pair, move
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../common/common.h"
|
||||
#include "../common/io.h" // for PrivateMmapConstStream
|
||||
#include "../common/timer.h" // for Monitor, Timer
|
||||
#include "adapter.h"
|
||||
#include "proxy_dmatrix.h" // for DMatrixProxy
|
||||
#include "sparse_page_writer.h" // for SparsePageFormat
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
#include "../common/common.h" // for AssertGPUSupport
|
||||
#endif // !defined(XGBOOST_USE_CUDA)
|
||||
|
||||
#include "../common/io.h" // for PrivateMmapConstStream
|
||||
#include "../common/timer.h" // for Monitor, Timer
|
||||
#include "proxy_dmatrix.h" // for DMatrixProxy
|
||||
#include "sparse_page_writer.h" // for SparsePageFormat
|
||||
#include "xgboost/base.h" // for bst_feature_t
|
||||
#include "xgboost/data.h" // for SparsePage, CSCPage
|
||||
#include "xgboost/global_config.h" // for GlobalConfigThreadLocalStore
|
||||
#include "xgboost/logging.h" // for CHECK_EQ
|
||||
|
||||
namespace xgboost::data {
|
||||
inline void TryDeleteCacheFile(const std::string& file) {
|
||||
@@ -185,6 +188,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
|
||||
|
||||
exce_.Rethrow();
|
||||
|
||||
auto const config = *GlobalConfigThreadLocalStore::Get();
|
||||
for (std::int32_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) {
|
||||
fetch_it %= n_batches_; // ring
|
||||
if (ring_->at(fetch_it).valid()) {
|
||||
@@ -192,7 +196,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
|
||||
}
|
||||
auto const* self = this; // make sure it's const
|
||||
CHECK_LT(fetch_it, cache_info_->offset.size());
|
||||
ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self, this]() {
|
||||
ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self, config, this]() {
|
||||
*GlobalConfigThreadLocalStore::Get() = config;
|
||||
auto page = std::make_shared<S>();
|
||||
this->exce_.Run([&] {
|
||||
std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")};
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2014-2023 by Contributors
|
||||
* Copyright 2014-2024, XGBoost Contributors
|
||||
* \file gbtree.cc
|
||||
* \brief gradient boosted tree implementation.
|
||||
* \author Tianqi Chen
|
||||
@@ -11,14 +11,12 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint> // std::int32_t
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <numeric> // for iota
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "../common/common.h"
|
||||
#include "../common/timer.h"
|
||||
#include "../tree/param.h" // TrainParam
|
||||
#include "gbtree_model.h"
|
||||
|
||||
@@ -10,15 +10,15 @@
|
||||
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
#include <numeric> // for accumulate
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/common.h" // MetricNoCache
|
||||
#include "../common/common.h" // for AssertGPUSupport
|
||||
#include "../common/math.h"
|
||||
#include "../common/optional_weight.h" // OptionalWeights
|
||||
#include "../common/pseudo_huber.h"
|
||||
#include "../common/quantile_loss_utils.h" // QuantileLossParam
|
||||
#include "../common/threading_utils.h"
|
||||
#include "metric_common.h"
|
||||
#include "metric_common.h" // MetricNoCache
|
||||
#include "xgboost/collective/result.h" // for SafeColl
|
||||
#include "xgboost/metric.h"
|
||||
|
||||
|
||||
@@ -9,8 +9,6 @@
|
||||
#include <string>
|
||||
|
||||
#include "../collective/aggregator.h"
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/common.h"
|
||||
#include "xgboost/metric.h"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
@@ -9,8 +9,8 @@
|
||||
#include <array>
|
||||
#include <atomic>
|
||||
#include <cmath>
|
||||
#include <numeric> // for accumulate
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/math.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "metric_common.h" // MetricNoCache
|
||||
|
||||
@@ -9,10 +9,9 @@
|
||||
|
||||
#include <array>
|
||||
#include <memory>
|
||||
#include <numeric> // for accumulate
|
||||
#include <vector>
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/math.h"
|
||||
#include "../common/survival_util.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "metric_common.h" // MetricNoCache
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
/**
|
||||
* Copyright 2019-2023 by XGBoost Contributors
|
||||
* Copyright 2019-2024, XGBoost Contributors
|
||||
*/
|
||||
#include <thrust/functional.h>
|
||||
#include <thrust/random.h>
|
||||
#include <thrust/sort.h> // for sort
|
||||
#include <thrust/transform.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
#include <xgboost/logging.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef> // for size_t
|
||||
#include <limits>
|
||||
#include <utility>
|
||||
|
||||
Reference in New Issue
Block a user