[coll] Allreduce. (#9679)
This commit is contained in:
parent
da6803b75b
commit
48ac9b6cbe
@ -99,6 +99,7 @@ OBJECTS= \
|
|||||||
$(PKGROOT)/src/logging.o \
|
$(PKGROOT)/src/logging.o \
|
||||||
$(PKGROOT)/src/global_config.o \
|
$(PKGROOT)/src/global_config.o \
|
||||||
$(PKGROOT)/src/collective/allgather.o \
|
$(PKGROOT)/src/collective/allgather.o \
|
||||||
|
$(PKGROOT)/src/collective/allreduce.o \
|
||||||
$(PKGROOT)/src/collective/broadcast.o \
|
$(PKGROOT)/src/collective/broadcast.o \
|
||||||
$(PKGROOT)/src/collective/comm.o \
|
$(PKGROOT)/src/collective/comm.o \
|
||||||
$(PKGROOT)/src/collective/tracker.o \
|
$(PKGROOT)/src/collective/tracker.o \
|
||||||
|
|||||||
@ -99,6 +99,7 @@ OBJECTS= \
|
|||||||
$(PKGROOT)/src/logging.o \
|
$(PKGROOT)/src/logging.o \
|
||||||
$(PKGROOT)/src/global_config.o \
|
$(PKGROOT)/src/global_config.o \
|
||||||
$(PKGROOT)/src/collective/allgather.o \
|
$(PKGROOT)/src/collective/allgather.o \
|
||||||
|
$(PKGROOT)/src/collective/allreduce.o \
|
||||||
$(PKGROOT)/src/collective/broadcast.o \
|
$(PKGROOT)/src/collective/broadcast.o \
|
||||||
$(PKGROOT)/src/collective/comm.o \
|
$(PKGROOT)/src/collective/comm.o \
|
||||||
$(PKGROOT)/src/collective/tracker.o \
|
$(PKGROOT)/src/collective/tracker.o \
|
||||||
|
|||||||
90
src/collective/allreduce.cc
Normal file
90
src/collective/allreduce.cc
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#include "allreduce.h"
|
||||||
|
|
||||||
|
#include <algorithm> // for min
|
||||||
|
#include <cstddef> // for size_t
|
||||||
|
#include <cstdint> // for int32_t, int8_t
|
||||||
|
#include <vector> // for vector
|
||||||
|
|
||||||
|
#include "../data/array_interface.h" // for Type, DispatchDType
|
||||||
|
#include "allgather.h" // for RingAllgather
|
||||||
|
#include "comm.h" // for Comm
|
||||||
|
#include "xgboost/collective/result.h" // for Result
|
||||||
|
#include "xgboost/span.h" // for Span
|
||||||
|
|
||||||
|
namespace xgboost::collective::cpu_impl {
|
||||||
|
template <typename T>
|
||||||
|
Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
|
||||||
|
std::size_t n_bytes_in_seg, Func const& op) {
|
||||||
|
auto rank = comm.Rank();
|
||||||
|
auto world = comm.World();
|
||||||
|
|
||||||
|
auto dst_rank = BootstrapNext(rank, world);
|
||||||
|
auto src_rank = BootstrapPrev(rank, world);
|
||||||
|
auto next_ch = comm.Chan(dst_rank);
|
||||||
|
auto prev_ch = comm.Chan(src_rank);
|
||||||
|
|
||||||
|
std::vector<std::int8_t> buffer(n_bytes_in_seg, 0);
|
||||||
|
auto s_buf = common::Span{buffer.data(), buffer.size()};
|
||||||
|
|
||||||
|
for (std::int32_t r = 0; r < world - 1; ++r) {
|
||||||
|
// send to ring next
|
||||||
|
auto send_off = ((rank + world - r) % world) * n_bytes_in_seg;
|
||||||
|
send_off = std::min(send_off, data.size_bytes());
|
||||||
|
auto seg_nbytes = std::min(data.size_bytes() - send_off, n_bytes_in_seg);
|
||||||
|
auto send_seg = data.subspan(send_off, seg_nbytes);
|
||||||
|
|
||||||
|
next_ch->SendAll(send_seg);
|
||||||
|
|
||||||
|
// receive from ring prev
|
||||||
|
auto recv_off = ((rank + world - r - 1) % world) * n_bytes_in_seg;
|
||||||
|
recv_off = std::min(recv_off, data.size_bytes());
|
||||||
|
seg_nbytes = std::min(data.size_bytes() - recv_off, n_bytes_in_seg);
|
||||||
|
CHECK_EQ(seg_nbytes % sizeof(T), 0);
|
||||||
|
auto recv_seg = data.subspan(recv_off, seg_nbytes);
|
||||||
|
auto seg = s_buf.subspan(0, recv_seg.size());
|
||||||
|
|
||||||
|
prev_ch->RecvAll(seg);
|
||||||
|
auto rc = prev_ch->Block();
|
||||||
|
if (!rc.OK()) {
|
||||||
|
return rc;
|
||||||
|
}
|
||||||
|
|
||||||
|
// accumulate to recv_seg
|
||||||
|
CHECK_EQ(seg.size(), recv_seg.size());
|
||||||
|
op(seg, recv_seg);
|
||||||
|
}
|
||||||
|
|
||||||
|
return Success();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * @brief Ring-based allreduce on a type-erased buffer.
 *
 * The element type is recovered from `type`. The buffer is divided into one segment per
 * worker (rounded up to whole elements), the segments are reduced with a scatter-reduce
 * pass, and the reduced segments are then shared with `RingAllgather`.
 *
 * @param comm Communicator providing the rank, the world size and the channels.
 * @param data In/out buffer; its size in bytes must be a multiple of the element size.
 * @param op   Type-erased reduction callback forwarded to the scatter-reduce pass.
 * @param type Runtime tag of the element type stored in `data`.
 *
 * @return The first failing result from either phase, otherwise the result of
 *         `comm.Block()`.
 */
Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func const& op,
                     ArrayInterfaceHandler::Type type) {
  return DispatchDType(type, [&](auto elem_tag) {
    using T = decltype(elem_tag);

    // Split the elements into as many segments as there are workers.
    auto elem_bytes = sizeof(T);
    CHECK_EQ(data.size_bytes() % elem_bytes, 0);
    auto n_elems = data.size_bytes() / elem_bytes;
    auto world = comm.World();
    auto seg_bytes = common::DivRoundUp(n_elems, world) * sizeof(T);

    if (auto rc = RingScatterReduceTyped<T>(comm, data, seg_bytes, op); !rc.OK()) {
      return rc;
    }

    auto prev_rank = BootstrapPrev(comm.Rank(), world);
    auto next_rank = BootstrapNext(comm.Rank(), world);
    auto prev_ch = comm.Chan(prev_rank);
    auto next_ch = comm.Chan(next_rank);

    if (auto rc = RingAllgather(comm, data, seg_bytes, 1, prev_ch, next_ch); !rc.OK()) {
      return rc;
    }
    return comm.Block();
  });
}
|
||||||
|
} // namespace xgboost::collective::cpu_impl
|
||||||
39
src/collective/allreduce.h
Normal file
39
src/collective/allreduce.h
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
#include <cstdint> // for int8_t
|
||||||
|
#include <functional> // for function
|
||||||
|
#include <type_traits> // for is_invocable_v
|
||||||
|
|
||||||
|
#include "../data/array_interface.h" // for ArrayInterfaceHandler
|
||||||
|
#include "comm.h" // for Comm, RestoreType
|
||||||
|
#include "xgboost/collective/result.h" // for Result
|
||||||
|
#include "xgboost/span.h" // for Span
|
||||||
|
|
||||||
|
namespace xgboost::collective {
|
||||||
|
namespace cpu_impl {
|
||||||
|
using Func =
|
||||||
|
std::function<void(common::Span<std::int8_t const> lhs, common::Span<std::int8_t> out)>;
|
||||||
|
|
||||||
|
Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func const& op,
|
||||||
|
ArrayInterfaceHandler::Type type);
|
||||||
|
} // namespace cpu_impl
|
||||||
|
|
||||||
|
template <typename T, typename Fn>
|
||||||
|
std::enable_if_t<std::is_invocable_v<Fn, common::Span<T const>, common::Span<T>>, Result> Allreduce(
|
||||||
|
Comm const& comm, common::Span<T> data, Fn redop) {
|
||||||
|
auto erased = EraseType(data);
|
||||||
|
auto type = ToDType<T>::kType;
|
||||||
|
|
||||||
|
auto erased_fn = [type, redop](common::Span<std::int8_t const> lhs,
|
||||||
|
common::Span<std::int8_t> out) {
|
||||||
|
CHECK_EQ(lhs.size(), out.size()) << "Invalid input for reduction.";
|
||||||
|
auto lhs_t = RestoreType<T const>(lhs);
|
||||||
|
auto rhs_t = RestoreType<T>(out);
|
||||||
|
redop(lhs_t, rhs_t);
|
||||||
|
};
|
||||||
|
|
||||||
|
return cpu_impl::RingAllreduce(comm, erased, erased_fn, type);
|
||||||
|
}
|
||||||
|
} // namespace xgboost::collective
|
||||||
@ -16,7 +16,7 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "../common/bitfield.h"
|
#include "../common/bitfield.h" // for RBitField8
|
||||||
#include "../common/common.h"
|
#include "../common/common.h"
|
||||||
#include "../common/error_msg.h" // for NoF128
|
#include "../common/error_msg.h" // for NoF128
|
||||||
#include "xgboost/base.h"
|
#include "xgboost/base.h"
|
||||||
@ -104,7 +104,20 @@ struct ArrayInterfaceErrors {
|
|||||||
*/
|
*/
|
||||||
class ArrayInterfaceHandler {
|
class ArrayInterfaceHandler {
|
||||||
public:
|
public:
|
||||||
enum Type : std::int8_t { kF2, kF4, kF8, kF16, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 };
|
enum Type : std::int8_t {
|
||||||
|
kF2 = 0,
|
||||||
|
kF4 = 1,
|
||||||
|
kF8 = 2,
|
||||||
|
kF16 = 3,
|
||||||
|
kI1 = 4,
|
||||||
|
kI2 = 5,
|
||||||
|
kI4 = 6,
|
||||||
|
kI8 = 7,
|
||||||
|
kU1 = 8,
|
||||||
|
kU2 = 9,
|
||||||
|
kU4 = 10,
|
||||||
|
kU8 = 11,
|
||||||
|
};
|
||||||
|
|
||||||
template <typename PtrType>
|
template <typename PtrType>
|
||||||
static PtrType GetPtrFromArrayData(Object::Map const &obj) {
|
static PtrType GetPtrFromArrayData(Object::Map const &obj) {
|
||||||
@ -587,6 +600,57 @@ class ArrayInterface {
|
|||||||
ArrayInterfaceHandler::Type type{ArrayInterfaceHandler::kF16};
|
ArrayInterfaceHandler::Type type{ArrayInterfaceHandler::kF16};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename Fn>
|
||||||
|
auto DispatchDType(ArrayInterfaceHandler::Type dtype, Fn dispatch) {
|
||||||
|
switch (dtype) {
|
||||||
|
case ArrayInterfaceHandler::kF2: {
|
||||||
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
|
return dispatch(__half{});
|
||||||
|
#else
|
||||||
|
LOG(FATAL) << "half type is only supported for CUDA input.";
|
||||||
|
break;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
case ArrayInterfaceHandler::kF4: {
|
||||||
|
return dispatch(float{});
|
||||||
|
}
|
||||||
|
case ArrayInterfaceHandler::kF8: {
|
||||||
|
return dispatch(double{});
|
||||||
|
}
|
||||||
|
case ArrayInterfaceHandler::kF16: {
|
||||||
|
using T = long double;
|
||||||
|
CHECK(sizeof(T) == 16) << error::NoF128();
|
||||||
|
return dispatch(T{});
|
||||||
|
}
|
||||||
|
case ArrayInterfaceHandler::kI1: {
|
||||||
|
return dispatch(std::int8_t{});
|
||||||
|
}
|
||||||
|
case ArrayInterfaceHandler::kI2: {
|
||||||
|
return dispatch(std::int16_t{});
|
||||||
|
}
|
||||||
|
case ArrayInterfaceHandler::kI4: {
|
||||||
|
return dispatch(std::int32_t{});
|
||||||
|
}
|
||||||
|
case ArrayInterfaceHandler::kI8: {
|
||||||
|
return dispatch(std::int64_t{});
|
||||||
|
}
|
||||||
|
case ArrayInterfaceHandler::kU1: {
|
||||||
|
return dispatch(std::uint8_t{});
|
||||||
|
}
|
||||||
|
case ArrayInterfaceHandler::kU2: {
|
||||||
|
return dispatch(std::uint16_t{});
|
||||||
|
}
|
||||||
|
case ArrayInterfaceHandler::kU4: {
|
||||||
|
return dispatch(std::uint32_t{});
|
||||||
|
}
|
||||||
|
case ArrayInterfaceHandler::kU8: {
|
||||||
|
return dispatch(std::uint64_t{});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::result_of_t<Fn(std::int8_t)>();
|
||||||
|
}
|
||||||
|
|
||||||
template <std::int32_t D, typename Fn>
|
template <std::int32_t D, typename Fn>
|
||||||
void DispatchDType(ArrayInterface<D> const array, DeviceOrd device, Fn fn) {
|
void DispatchDType(ArrayInterface<D> const array, DeviceOrd device, Fn fn) {
|
||||||
// Only used for cuDF at the moment.
|
// Only used for cuDF at the moment.
|
||||||
@ -602,60 +666,7 @@ void DispatchDType(ArrayInterface<D> const array, DeviceOrd device, Fn fn) {
|
|||||||
std::numeric_limits<std::size_t>::max()},
|
std::numeric_limits<std::size_t>::max()},
|
||||||
array.shape, array.strides, device});
|
array.shape, array.strides, device});
|
||||||
};
|
};
|
||||||
switch (array.type) {
|
DispatchDType(array.type, dispatch);
|
||||||
case ArrayInterfaceHandler::kF2: {
|
|
||||||
#if defined(XGBOOST_USE_CUDA)
|
|
||||||
dispatch(__half{});
|
|
||||||
#endif
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case ArrayInterfaceHandler::kF4: {
|
|
||||||
dispatch(float{});
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case ArrayInterfaceHandler::kF8: {
|
|
||||||
dispatch(double{});
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case ArrayInterfaceHandler::kF16: {
|
|
||||||
using T = long double;
|
|
||||||
CHECK(sizeof(long double) == 16) << error::NoF128();
|
|
||||||
dispatch(T{});
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case ArrayInterfaceHandler::kI1: {
|
|
||||||
dispatch(std::int8_t{});
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case ArrayInterfaceHandler::kI2: {
|
|
||||||
dispatch(std::int16_t{});
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case ArrayInterfaceHandler::kI4: {
|
|
||||||
dispatch(std::int32_t{});
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case ArrayInterfaceHandler::kI8: {
|
|
||||||
dispatch(std::int64_t{});
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case ArrayInterfaceHandler::kU1: {
|
|
||||||
dispatch(std::uint8_t{});
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case ArrayInterfaceHandler::kU2: {
|
|
||||||
dispatch(std::uint16_t{});
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case ArrayInterfaceHandler::kU4: {
|
|
||||||
dispatch(std::uint32_t{});
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case ArrayInterfaceHandler::kU8: {
|
|
||||||
dispatch(std::uint64_t{});
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
72
tests/cpp/collective/test_allreduce.cc
Normal file
72
tests/cpp/collective/test_allreduce.cc
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include "../../../src/collective/allreduce.h"
|
||||||
|
#include "../../../src/collective/tracker.h"
|
||||||
|
#include "test_worker.h" // for WorkerForTest, TestDistributed
|
||||||
|
|
||||||
|
namespace xgboost::collective {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class AllreduceWorker : public WorkerForTest {
|
||||||
|
public:
|
||||||
|
using WorkerForTest::WorkerForTest;
|
||||||
|
|
||||||
|
void Basic() {
|
||||||
|
{
|
||||||
|
std::vector<double> data(13, 0.0);
|
||||||
|
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
|
||||||
|
for (std::size_t i = 0; i < rhs.size(); ++i) {
|
||||||
|
rhs[i] += lhs[i];
|
||||||
|
}
|
||||||
|
});
|
||||||
|
ASSERT_EQ(std::accumulate(data.cbegin(), data.cend(), 0.0), 0.0);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
std::vector<double> data(1, 1.0);
|
||||||
|
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
|
||||||
|
for (std::size_t i = 0; i < rhs.size(); ++i) {
|
||||||
|
rhs[i] += lhs[i];
|
||||||
|
}
|
||||||
|
});
|
||||||
|
ASSERT_EQ(data[0], static_cast<double>(comm_.World()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Acc() {
|
||||||
|
std::vector<double> data(314, 1.5);
|
||||||
|
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
|
||||||
|
for (std::size_t i = 0; i < rhs.size(); ++i) {
|
||||||
|
rhs[i] += lhs[i];
|
||||||
|
}
|
||||||
|
});
|
||||||
|
for (std::size_t i = 0; i < data.size(); ++i) {
|
||||||
|
auto v = data[i];
|
||||||
|
ASSERT_EQ(v, 1.5 * static_cast<double>(comm_.World())) << i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Fixture for the allreduce test cases; adds nothing on top of SocketTest.
class AllreduceTest : public SocketTest {};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Runs AllreduceWorker::Basic on up to seven workers (fewer if the reported hardware
// concurrency is lower).
TEST_F(AllreduceTest, Basic) {
  std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
  auto run_worker = [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
                        std::int32_t rank) {
    AllreduceWorker worker{host, port, timeout, n_workers, rank};
    worker.Basic();
  };
  TestDistributed(n_workers, run_worker);
}
|
||||||
|
|
||||||
|
// Runs AllreduceWorker::Acc on up to seven workers (fewer if the reported hardware
// concurrency is lower).
TEST_F(AllreduceTest, Sum) {
  std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
  auto run_worker = [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
                        std::int32_t rank) {
    AllreduceWorker worker{host, port, timeout, n_workers, rank};
    worker.Acc();
  };
  TestDistributed(n_workers, run_worker);
}
|
||||||
|
} // namespace xgboost::collective
|
||||||
@ -10,8 +10,8 @@
|
|||||||
#include <vector> // for vector
|
#include <vector> // for vector
|
||||||
|
|
||||||
#include "../../../src/collective/broadcast.h" // for Broadcast
|
#include "../../../src/collective/broadcast.h" // for Broadcast
|
||||||
#include "../../../src/collective/tracker.h" // for GetHostAddress, Tracker
|
#include "../../../src/collective/tracker.h" // for GetHostAddress
|
||||||
#include "test_worker.h" // for WorkerForTest
|
#include "test_worker.h" // for WorkerForTest, TestDistributed
|
||||||
|
|
||||||
namespace xgboost::collective {
|
namespace xgboost::collective {
|
||||||
namespace {
|
namespace {
|
||||||
@ -41,28 +41,11 @@ class BroadcastTest : public SocketTest {};
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
TEST_F(BroadcastTest, Basic) {
|
TEST_F(BroadcastTest, Basic) {
|
||||||
std::int32_t n_workers = std::min(24u, std::thread::hardware_concurrency());
|
std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
|
||||||
std::chrono::seconds timeout{3};
|
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||||
|
std::int32_t r) {
|
||||||
std::string host;
|
Worker worker{host, port, timeout, n_workers, r};
|
||||||
ASSERT_TRUE(GetHostAddress(&host).OK());
|
|
||||||
RabitTracker tracker{StringView{host}, n_workers, 0, timeout};
|
|
||||||
auto fut = tracker.Run();
|
|
||||||
|
|
||||||
std::vector<std::thread> workers;
|
|
||||||
std::int32_t port = tracker.Port();
|
|
||||||
|
|
||||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
|
||||||
workers.emplace_back([=] {
|
|
||||||
Worker worker{host, port, timeout, n_workers, i};
|
|
||||||
worker.Run();
|
worker.Run();
|
||||||
});
|
});
|
||||||
}
|
|
||||||
|
|
||||||
for (auto& t : workers) {
|
|
||||||
t.join();
|
|
||||||
}
|
|
||||||
|
|
||||||
ASSERT_TRUE(fut.get().OK());
|
|
||||||
}
|
}
|
||||||
} // namespace xgboost::collective
|
} // namespace xgboost::collective
|
||||||
|
|||||||
@ -88,4 +88,27 @@ class TrackerTest : public SocketTest {
|
|||||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename WorkerFn>
|
||||||
|
void TestDistributed(std::int32_t n_workers, WorkerFn worker_fn) {
|
||||||
|
std::chrono::seconds timeout{1};
|
||||||
|
|
||||||
|
std::string host;
|
||||||
|
ASSERT_TRUE(GetHostAddress(&host).OK());
|
||||||
|
RabitTracker tracker{StringView{host}, n_workers, 0, timeout};
|
||||||
|
auto fut = tracker.Run();
|
||||||
|
|
||||||
|
std::vector<std::thread> workers;
|
||||||
|
std::int32_t port = tracker.Port();
|
||||||
|
|
||||||
|
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||||
|
workers.emplace_back([=] { worker_fn(host, port, timeout, i); });
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto& t : workers) {
|
||||||
|
t.join();
|
||||||
|
}
|
||||||
|
|
||||||
|
ASSERT_TRUE(fut.get().OK());
|
||||||
|
}
|
||||||
} // namespace xgboost::collective
|
} // namespace xgboost::collective
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user