[coll] Improvements and fixes for tracker and allreduce. (#9745)
- Allow the tracker to wait. - Fix allreduce type cast - Return args from the federated tracker.
This commit is contained in:
parent
0ff8572737
commit
4da4e092b5
@ -6,7 +6,6 @@
|
|||||||
#include <grpcpp/security/server_credentials.h> // for InsecureServerCredentials, ...
|
#include <grpcpp/security/server_credentials.h> // for InsecureServerCredentials, ...
|
||||||
#include <grpcpp/server_builder.h> // for ServerBuilder
|
#include <grpcpp/server_builder.h> // for ServerBuilder
|
||||||
|
|
||||||
#include <chrono> // for ms
|
|
||||||
#include <cstdint> // for int32_t
|
#include <cstdint> // for int32_t
|
||||||
#include <exception> // for exception
|
#include <exception> // for exception
|
||||||
#include <limits> // for numeric_limits
|
#include <limits> // for numeric_limits
|
||||||
@ -61,7 +60,7 @@ FederatedTracker::FederatedTracker(Json const& config) : Tracker{config} {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::future<Result> FederatedTracker::Run() {
|
std::future<Result> FederatedTracker::Run() {
|
||||||
return std::async([this]() {
|
return std::async(std::launch::async, [this]() {
|
||||||
std::string const server_address = "0.0.0.0:" + std::to_string(this->port_);
|
std::string const server_address = "0.0.0.0:" + std::to_string(this->port_);
|
||||||
xgboost::collective::federated::FederatedService service{
|
xgboost::collective::federated::FederatedService service{
|
||||||
static_cast<std::int32_t>(this->n_workers_)};
|
static_cast<std::int32_t>(this->n_workers_)};
|
||||||
@ -98,10 +97,13 @@ std::future<Result> FederatedTracker::Run() {
|
|||||||
|
|
||||||
try {
|
try {
|
||||||
server_ = builder.BuildAndStart();
|
server_ = builder.BuildAndStart();
|
||||||
|
ready_ = true;
|
||||||
server_->Wait();
|
server_->Wait();
|
||||||
} catch (std::exception const& e) {
|
} catch (std::exception const& e) {
|
||||||
return collective::Fail(std::string{e.what()});
|
return collective::Fail(std::string{e.what()});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ready_ = false;
|
||||||
return collective::Success();
|
return collective::Success();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -109,18 +111,8 @@ std::future<Result> FederatedTracker::Run() {
|
|||||||
FederatedTracker::~FederatedTracker() = default;
|
FederatedTracker::~FederatedTracker() = default;
|
||||||
|
|
||||||
Result FederatedTracker::Shutdown() {
|
Result FederatedTracker::Shutdown() {
|
||||||
common::Timer timer;
|
auto rc = this->WaitUntilReady();
|
||||||
timer.Start();
|
CHECK(rc.OK()) << rc.Report();
|
||||||
using namespace std::chrono_literals;
|
|
||||||
while (!server_) {
|
|
||||||
timer.Stop();
|
|
||||||
auto ela = timer.ElapsedSeconds();
|
|
||||||
if (ela > this->Timeout().count()) {
|
|
||||||
return Fail("Failed to shutdown, timeout:" + std::to_string(this->Timeout().count()) +
|
|
||||||
" seconds.");
|
|
||||||
}
|
|
||||||
std::this_thread::sleep_for(10ms);
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
server_->Shutdown();
|
server_->Shutdown();
|
||||||
@ -130,4 +122,17 @@ Result FederatedTracker::Shutdown() {
|
|||||||
|
|
||||||
return Success();
|
return Success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] Json FederatedTracker::WorkerArgs() const {
|
||||||
|
auto rc = this->WaitUntilReady();
|
||||||
|
CHECK(rc.OK()) << rc.Report();
|
||||||
|
|
||||||
|
std::string host;
|
||||||
|
rc = GetHostAddress(&host);
|
||||||
|
CHECK(rc.OK());
|
||||||
|
Json args{Object{}};
|
||||||
|
args["DMLC_TRACKER_URI"] = String{host};
|
||||||
|
args["DMLC_TRACKER_PORT"] = this->Port();
|
||||||
|
return args;
|
||||||
|
}
|
||||||
} // namespace xgboost::collective
|
} // namespace xgboost::collective
|
||||||
|
|||||||
@ -57,9 +57,8 @@ class FederatedTracker : public collective::Tracker {
|
|||||||
explicit FederatedTracker(Json const& config);
|
explicit FederatedTracker(Json const& config);
|
||||||
~FederatedTracker() override;
|
~FederatedTracker() override;
|
||||||
std::future<Result> Run() override;
|
std::future<Result> Run() override;
|
||||||
// federated tracker do not provide initialization parameters, users have to provide it
|
|
||||||
// themseleves.
|
[[nodiscard]] Json WorkerArgs() const override;
|
||||||
[[nodiscard]] Json WorkerArgs() const override { return Json{Null{}}; }
|
|
||||||
[[nodiscard]] Result Shutdown();
|
[[nodiscard]] Result Shutdown();
|
||||||
};
|
};
|
||||||
} // namespace xgboost::collective
|
} // namespace xgboost::collective
|
||||||
|
|||||||
@ -3,19 +3,35 @@
|
|||||||
*/
|
*/
|
||||||
#include "coll.h"
|
#include "coll.h"
|
||||||
|
|
||||||
#include <algorithm> // for min, max
|
#include <algorithm> // for min, max, copy_n
|
||||||
#include <cstddef> // for size_t
|
#include <cstddef> // for size_t
|
||||||
#include <cstdint> // for int8_t, int64_t
|
#include <cstdint> // for int8_t, int64_t
|
||||||
#include <functional> // for bit_and, bit_or, bit_xor, plus
|
#include <functional> // for bit_and, bit_or, bit_xor, plus
|
||||||
|
#include <type_traits> // for is_floating_point_v, is_same_v
|
||||||
|
#include <utility> // for move
|
||||||
|
|
||||||
|
#include "../data/array_interface.h" // for ArrayInterfaceHandler
|
||||||
#include "allgather.h" // for RingAllgatherV, RingAllgather
|
#include "allgather.h" // for RingAllgatherV, RingAllgather
|
||||||
#include "allreduce.h" // for Allreduce
|
#include "allreduce.h" // for Allreduce
|
||||||
#include "broadcast.h" // for Broadcast
|
#include "broadcast.h" // for Broadcast
|
||||||
#include "comm.h" // for Comm
|
#include "comm.h" // for Comm
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
|
#include "cuda_fp16.h" // for __half
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace xgboost::collective {
|
namespace xgboost::collective {
|
||||||
|
template <typename T>
|
||||||
|
bool constexpr IsFloatingPointV() {
|
||||||
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
|
return std::is_floating_point_v<T> || std::is_same_v<T, __half>;
|
||||||
|
#else
|
||||||
|
return std::is_floating_point_v<T>;
|
||||||
|
#endif // defined(XGBOOST_USE_CUDA)
|
||||||
|
}
|
||||||
|
|
||||||
[[nodiscard]] Result Coll::Allreduce(Comm const& comm, common::Span<std::int8_t> data,
|
[[nodiscard]] Result Coll::Allreduce(Comm const& comm, common::Span<std::int8_t> data,
|
||||||
ArrayInterfaceHandler::Type, Op op) {
|
ArrayInterfaceHandler::Type type, Op op) {
|
||||||
namespace coll = ::xgboost::collective;
|
namespace coll = ::xgboost::collective;
|
||||||
|
|
||||||
auto redop_fn = [](auto lhs, auto out, auto elem_op) {
|
auto redop_fn = [](auto lhs, auto out, auto elem_op) {
|
||||||
@ -25,32 +41,59 @@ namespace xgboost::collective {
|
|||||||
p_out[i] = elem_op(p_lhs[i], p_out[i]);
|
p_out[i] = elem_op(p_lhs[i], p_out[i]);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
auto fn = [&](auto elem_op) {
|
|
||||||
return coll::Allreduce(
|
auto fn = [&](auto elem_op, auto t) {
|
||||||
comm, data, [redop_fn, elem_op](auto lhs, auto rhs) { redop_fn(lhs, rhs, elem_op); });
|
using T = decltype(t);
|
||||||
|
auto erased_fn = [redop_fn, elem_op](common::Span<std::int8_t const> lhs,
|
||||||
|
common::Span<std::int8_t> out) {
|
||||||
|
CHECK_EQ(lhs.size(), out.size()) << "Invalid input for reduction.";
|
||||||
|
auto lhs_t = common::RestoreType<T const>(lhs);
|
||||||
|
auto rhs_t = common::RestoreType<T>(out);
|
||||||
|
|
||||||
|
redop_fn(lhs_t, rhs_t, elem_op);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
return cpu_impl::RingAllreduce(comm, data, erased_fn, type);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto rc = DispatchDType(type, [&](auto t) {
|
||||||
|
using T = decltype(t);
|
||||||
switch (op) {
|
switch (op) {
|
||||||
case Op::kMax: {
|
case Op::kMax: {
|
||||||
return fn([](auto l, auto r) { return std::max(l, r); });
|
return fn([](auto l, auto r) { return std::max(l, r); }, t);
|
||||||
}
|
}
|
||||||
case Op::kMin: {
|
case Op::kMin: {
|
||||||
return fn([](auto l, auto r) { return std::min(l, r); });
|
return fn([](auto l, auto r) { return std::min(l, r); }, t);
|
||||||
}
|
}
|
||||||
case Op::kSum: {
|
case Op::kSum: {
|
||||||
return fn(std::plus<>{});
|
return fn(std::plus<>{}, t);
|
||||||
}
|
}
|
||||||
case Op::kBitwiseAND: {
|
case Op::kBitwiseAND: {
|
||||||
return fn(std::bit_and<>{});
|
if constexpr (IsFloatingPointV<T>()) {
|
||||||
|
return Fail("Invalid type.");
|
||||||
|
} else {
|
||||||
|
return fn(std::bit_and<>{}, t);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case Op::kBitwiseOR: {
|
case Op::kBitwiseOR: {
|
||||||
return fn(std::bit_or<>{});
|
if constexpr (IsFloatingPointV<T>()) {
|
||||||
|
return Fail("Invalid type.");
|
||||||
|
} else {
|
||||||
|
return fn(std::bit_or<>{}, t);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case Op::kBitwiseXOR: {
|
case Op::kBitwiseXOR: {
|
||||||
return fn(std::bit_xor<>{});
|
if constexpr (IsFloatingPointV<T>()) {
|
||||||
|
return Fail("Invalid type.");
|
||||||
|
} else {
|
||||||
|
return fn(std::bit_xor<>{}, t);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return comm.Block();
|
}
|
||||||
|
return Fail("Invalid op.");
|
||||||
|
});
|
||||||
|
|
||||||
|
return std::move(rc) << [&] { return comm.Block(); };
|
||||||
}
|
}
|
||||||
|
|
||||||
[[nodiscard]] Result Coll::Broadcast(Comm const& comm, common::Span<std::int8_t> data,
|
[[nodiscard]] Result Coll::Broadcast(Comm const& comm, common::Span<std::int8_t> data,
|
||||||
|
|||||||
@ -16,7 +16,7 @@
|
|||||||
#endif // defined(_WIN32)
|
#endif // defined(_WIN32)
|
||||||
|
|
||||||
#include <algorithm> // for sort
|
#include <algorithm> // for sort
|
||||||
#include <chrono> // for seconds
|
#include <chrono> // for seconds, ms
|
||||||
#include <cstdint> // for int32_t
|
#include <cstdint> // for int32_t
|
||||||
#include <string> // for string
|
#include <string> // for string
|
||||||
#include <utility> // for move, forward
|
#include <utility> // for move, forward
|
||||||
@ -37,6 +37,25 @@ Tracker::Tracker(Json const& config)
|
|||||||
timeout_{std::chrono::seconds{OptionalArg<Integer const>(
|
timeout_{std::chrono::seconds{OptionalArg<Integer const>(
|
||||||
config, "timeout", static_cast<std::int64_t>(collective::DefaultTimeoutSec()))}} {}
|
config, "timeout", static_cast<std::int64_t>(collective::DefaultTimeoutSec()))}} {}
|
||||||
|
|
||||||
|
Result Tracker::WaitUntilReady() const {
|
||||||
|
using namespace std::chrono_literals; // NOLINT
|
||||||
|
|
||||||
|
// Busy waiting. The function is mostly for waiting for the OS to launch an async
|
||||||
|
// thread, which should be reasonably fast.
|
||||||
|
common::Timer timer;
|
||||||
|
timer.Start();
|
||||||
|
while (!this->Ready()) {
|
||||||
|
auto ela = timer.Duration().count();
|
||||||
|
if (ela > this->Timeout().count()) {
|
||||||
|
return Fail("Failed to start tracker, timeout:" + std::to_string(this->Timeout().count()) +
|
||||||
|
" seconds.");
|
||||||
|
}
|
||||||
|
std::this_thread::sleep_for(100ms);
|
||||||
|
}
|
||||||
|
|
||||||
|
return Success();
|
||||||
|
}
|
||||||
|
|
||||||
RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockAddrV4 addr)
|
RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockAddrV4 addr)
|
||||||
: sock_{std::move(sock)} {
|
: sock_{std::move(sock)} {
|
||||||
auto host = addr.Addr();
|
auto host = addr.Addr();
|
||||||
@ -76,6 +95,7 @@ RabitTracker::RabitTracker(Json const& config) : Tracker{config} {
|
|||||||
auto rc = collective::GetHostAddress(&self);
|
auto rc = collective::GetHostAddress(&self);
|
||||||
auto host = OptionalArg<String>(config, "host", self);
|
auto host = OptionalArg<String>(config, "host", self);
|
||||||
|
|
||||||
|
host_ = host;
|
||||||
listener_ = TCPSocket::Create(SockDomain::kV4);
|
listener_ = TCPSocket::Create(SockDomain::kV4);
|
||||||
rc = listener_.Bind(host, &this->port_);
|
rc = listener_.Bind(host, &this->port_);
|
||||||
CHECK(rc.OK()) << rc.Report();
|
CHECK(rc.OK()) << rc.Report();
|
||||||
@ -173,6 +193,7 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
|
|||||||
while (state.ShouldContinue()) {
|
while (state.ShouldContinue()) {
|
||||||
TCPSocket sock;
|
TCPSocket sock;
|
||||||
SockAddrV4 addr;
|
SockAddrV4 addr;
|
||||||
|
this->ready_ = true;
|
||||||
auto rc = listener_.Accept(&sock, &addr);
|
auto rc = listener_.Accept(&sock, &addr);
|
||||||
if (!rc.OK()) {
|
if (!rc.OK()) {
|
||||||
return Fail("Failed to accept connection.", std::move(rc));
|
return Fail("Failed to accept connection.", std::move(rc));
|
||||||
@ -237,10 +258,21 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
ready_ = false;
|
||||||
return Success();
|
return Success();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] Json RabitTracker::WorkerArgs() const {
|
||||||
|
auto rc = this->WaitUntilReady();
|
||||||
|
CHECK(rc.OK()) << rc.Report();
|
||||||
|
|
||||||
|
Json args{Object{}};
|
||||||
|
args["DMLC_TRACKER_URI"] = String{host_};
|
||||||
|
args["DMLC_TRACKER_PORT"] = this->Port();
|
||||||
|
return args;
|
||||||
|
}
|
||||||
|
|
||||||
[[nodiscard]] Result GetHostAddress(std::string* out) {
|
[[nodiscard]] Result GetHostAddress(std::string* out) {
|
||||||
auto rc = GetHostName(out);
|
auto rc = GetHostName(out);
|
||||||
if (!rc.OK()) {
|
if (!rc.OK()) {
|
||||||
|
|||||||
@ -40,6 +40,7 @@ class Tracker {
|
|||||||
std::int32_t n_workers_{0};
|
std::int32_t n_workers_{0};
|
||||||
std::int32_t port_{-1};
|
std::int32_t port_{-1};
|
||||||
std::chrono::seconds timeout_{0};
|
std::chrono::seconds timeout_{0};
|
||||||
|
std::atomic<bool> ready_{false};
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit Tracker(Json const& config);
|
explicit Tracker(Json const& config);
|
||||||
@ -47,10 +48,17 @@ class Tracker {
|
|||||||
: n_workers_{n_worders}, port_{port}, timeout_{timeout} {}
|
: n_workers_{n_worders}, port_{port}, timeout_{timeout} {}
|
||||||
|
|
||||||
virtual ~Tracker() noexcept(false){}; // NOLINT
|
virtual ~Tracker() noexcept(false){}; // NOLINT
|
||||||
|
|
||||||
|
[[nodiscard]] Result WaitUntilReady() const;
|
||||||
|
|
||||||
[[nodiscard]] virtual std::future<Result> Run() = 0;
|
[[nodiscard]] virtual std::future<Result> Run() = 0;
|
||||||
[[nodiscard]] virtual Json WorkerArgs() const = 0;
|
[[nodiscard]] virtual Json WorkerArgs() const = 0;
|
||||||
[[nodiscard]] std::chrono::seconds Timeout() const { return timeout_; }
|
[[nodiscard]] std::chrono::seconds Timeout() const { return timeout_; }
|
||||||
[[nodiscard]] virtual std::int32_t Port() const { return port_; }
|
[[nodiscard]] virtual std::int32_t Port() const { return port_; }
|
||||||
|
/**
|
||||||
|
* @brief Flag to indicate whether the server is running.
|
||||||
|
*/
|
||||||
|
[[nodiscard]] bool Ready() const { return ready_; }
|
||||||
};
|
};
|
||||||
|
|
||||||
class RabitTracker : public Tracker {
|
class RabitTracker : public Tracker {
|
||||||
@ -124,13 +132,7 @@ class RabitTracker : public Tracker {
|
|||||||
~RabitTracker() noexcept(false) override = default;
|
~RabitTracker() noexcept(false) override = default;
|
||||||
|
|
||||||
std::future<Result> Run() override;
|
std::future<Result> Run() override;
|
||||||
|
[[nodiscard]] Json WorkerArgs() const override;
|
||||||
[[nodiscard]] Json WorkerArgs() const override {
|
|
||||||
Json args{Object{}};
|
|
||||||
args["DMLC_TRACKER_URI"] = String{host_};
|
|
||||||
args["DMLC_TRACKER_PORT"] = this->Port();
|
|
||||||
return args;
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Prob the public IP address of the host, need a better method.
|
// Prob the public IP address of the host, need a better method.
|
||||||
|
|||||||
@ -29,6 +29,7 @@ struct Timer {
|
|||||||
void Start() { start = ClockT::now(); }
|
void Start() { start = ClockT::now(); }
|
||||||
void Stop() { elapsed += ClockT::now() - start; }
|
void Stop() { elapsed += ClockT::now() - start; }
|
||||||
double ElapsedSeconds() const { return SecondsT(elapsed).count(); }
|
double ElapsedSeconds() const { return SecondsT(elapsed).count(); }
|
||||||
|
SecondsT Duration() const { return ClockT::now() - start; }
|
||||||
void PrintElapsed(std::string label) {
|
void PrintElapsed(std::string label) {
|
||||||
char buffer[255];
|
char buffer[255];
|
||||||
snprintf(buffer, sizeof(buffer), "%s:\t %fs", label.c_str(),
|
snprintf(buffer, sizeof(buffer), "%s:\t %fs", label.c_str(),
|
||||||
|
|||||||
@ -27,9 +27,15 @@ class PrintWorker : public WorkerForTest {
|
|||||||
|
|
||||||
TEST_F(TrackerTest, Bootstrap) {
|
TEST_F(TrackerTest, Bootstrap) {
|
||||||
RabitTracker tracker{host, n_workers, 0, timeout};
|
RabitTracker tracker{host, n_workers, 0, timeout};
|
||||||
|
ASSERT_FALSE(tracker.Ready());
|
||||||
auto fut = tracker.Run();
|
auto fut = tracker.Run();
|
||||||
|
|
||||||
std::vector<std::thread> workers;
|
std::vector<std::thread> workers;
|
||||||
|
|
||||||
|
auto args = tracker.WorkerArgs();
|
||||||
|
ASSERT_TRUE(tracker.Ready());
|
||||||
|
ASSERT_EQ(get<String const>(args["DMLC_TRACKER_URI"]), host);
|
||||||
|
|
||||||
std::int32_t port = tracker.Port();
|
std::int32_t port = tracker.Port();
|
||||||
|
|
||||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||||
@ -47,6 +53,9 @@ TEST_F(TrackerTest, Print) {
|
|||||||
auto fut = tracker.Run();
|
auto fut = tracker.Run();
|
||||||
|
|
||||||
std::vector<std::thread> workers;
|
std::vector<std::thread> workers;
|
||||||
|
auto rc = tracker.WaitUntilReady();
|
||||||
|
ASSERT_TRUE(rc.OK());
|
||||||
|
|
||||||
std::int32_t port = tracker.Port();
|
std::int32_t port = tracker.Port();
|
||||||
|
|
||||||
for (std::int32_t i = 0; i < n_workers; ++i) {
|
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||||
|
|||||||
36
tests/cpp/plugin/federated/test_federated_tracker.cc
Normal file
36
tests/cpp/plugin/federated/test_federated_tracker.cc
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <memory> // for make_unique
|
||||||
|
#include <string> // for string
|
||||||
|
|
||||||
|
#include "../../../../src/collective/tracker.h" // for GetHostAddress
|
||||||
|
#include "federated_tracker.h"
|
||||||
|
#include "test_worker.h"
|
||||||
|
#include "xgboost/json.h" // for Json
|
||||||
|
|
||||||
|
namespace xgboost::collective {
|
||||||
|
TEST(FederatedTrackerTest, Basic) {
|
||||||
|
Json config{Object()};
|
||||||
|
config["federated_secure"] = Boolean{false};
|
||||||
|
config["n_workers"] = Integer{3};
|
||||||
|
|
||||||
|
auto tracker = std::make_unique<FederatedTracker>(config);
|
||||||
|
ASSERT_FALSE(tracker->Ready());
|
||||||
|
auto fut = tracker->Run();
|
||||||
|
auto args = tracker->WorkerArgs();
|
||||||
|
ASSERT_TRUE(tracker->Ready());
|
||||||
|
|
||||||
|
ASSERT_GE(tracker->Port(), 1);
|
||||||
|
std::string host;
|
||||||
|
auto rc = GetHostAddress(&host);
|
||||||
|
ASSERT_EQ(get<String const>(args["DMLC_TRACKER_URI"]), host);
|
||||||
|
|
||||||
|
rc = tracker->Shutdown();
|
||||||
|
ASSERT_TRUE(rc.OK());
|
||||||
|
ASSERT_TRUE(fut.get().OK());
|
||||||
|
ASSERT_FALSE(tracker->Ready());
|
||||||
|
}
|
||||||
|
} // namespace xgboost::collective
|
||||||
Loading…
x
Reference in New Issue
Block a user