[coll] Improve event loop. (#10199)

- Add a test for blocking calls.
- Do not require the queue to be empty after waking up; this frees up the thread to answer blocking calls.
- Handle EOF in read.
- Improve the error message in the result. Allow concatenation of multiple results.
This commit is contained in:
Jiaming Yuan 2024-04-18 03:29:52 +08:00 committed by GitHub
parent 7c0c9677a9
commit 4b10200456
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 312 additions and 111 deletions

View File

@ -99,6 +99,7 @@ OBJECTS= \
$(PKGROOT)/src/context.o \ $(PKGROOT)/src/context.o \
$(PKGROOT)/src/logging.o \ $(PKGROOT)/src/logging.o \
$(PKGROOT)/src/global_config.o \ $(PKGROOT)/src/global_config.o \
$(PKGROOT)/src/collective/result.o \
$(PKGROOT)/src/collective/allgather.o \ $(PKGROOT)/src/collective/allgather.o \
$(PKGROOT)/src/collective/allreduce.o \ $(PKGROOT)/src/collective/allreduce.o \
$(PKGROOT)/src/collective/broadcast.o \ $(PKGROOT)/src/collective/broadcast.o \

View File

@ -99,6 +99,7 @@ OBJECTS= \
$(PKGROOT)/src/context.o \ $(PKGROOT)/src/context.o \
$(PKGROOT)/src/logging.o \ $(PKGROOT)/src/logging.o \
$(PKGROOT)/src/global_config.o \ $(PKGROOT)/src/global_config.o \
$(PKGROOT)/src/collective/result.o \
$(PKGROOT)/src/collective/allgather.o \ $(PKGROOT)/src/collective/allgather.o \
$(PKGROOT)/src/collective/allreduce.o \ $(PKGROOT)/src/collective/allreduce.o \
$(PKGROOT)/src/collective/broadcast.o \ $(PKGROOT)/src/collective/broadcast.o \

View File

@ -40,7 +40,7 @@ def main(client):
# you can pass output directly into `predict` too. # you can pass output directly into `predict` too.
prediction = dxgb.predict(client, bst, dtrain) prediction = dxgb.predict(client, bst, dtrain)
print("Evaluation history:", history) print("Evaluation history:", history)
return prediction print("Error:", da.sqrt((prediction - y) ** 2).mean().compute())
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -144,6 +144,14 @@ which provides higher flexibility. For example:
ctest --verbose ctest --verbose
If you need to debug errors on Windows using the debugger from VS, you can append the gtest flags in `test_main.cc`:
.. code-block::
::testing::GTEST_FLAG(filter) = "Suite.Test";
::testing::GTEST_FLAG(repeat) = 10;
*********************************************** ***********************************************
Sanitizers: Detect memory errors and data races Sanitizers: Detect memory errors and data races
*********************************************** ***********************************************

View File

@ -3,13 +3,11 @@
*/ */
#pragma once #pragma once
#include <xgboost/logging.h> #include <cstdint> // for int32_t
#include <memory> // for unique_ptr
#include <memory> // for unique_ptr #include <string> // for string
#include <sstream> // for stringstream #include <system_error> // for error_code
#include <stack> // for stack #include <utility> // for move
#include <string> // for string
#include <utility> // for move
namespace xgboost::collective { namespace xgboost::collective {
namespace detail { namespace detail {
@ -48,48 +46,19 @@ struct ResultImpl {
return cur_eq; return cur_eq;
} }
[[nodiscard]] std::string Report() { [[nodiscard]] std::string Report() const;
std::stringstream ss; [[nodiscard]] std::error_code Code() const;
ss << "\n- " << this->message;
if (this->errc != std::error_code{}) {
ss << " system error:" << this->errc.message();
}
auto ptr = prev.get(); void Concat(std::unique_ptr<ResultImpl> rhs);
while (ptr) {
ss << "\n- ";
ss << ptr->message;
if (ptr->errc != std::error_code{}) {
ss << " " << ptr->errc.message();
}
ptr = ptr->prev.get();
}
return ss.str();
}
[[nodiscard]] auto Code() const {
// Find the root error.
std::stack<ResultImpl const*> stack;
auto ptr = this;
while (ptr) {
stack.push(ptr);
if (ptr->prev) {
ptr = ptr->prev.get();
} else {
break;
}
}
while (!stack.empty()) {
auto frame = stack.top();
stack.pop();
if (frame->errc != std::error_code{}) {
return frame->errc;
}
}
return std::error_code{};
}
}; };
#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__)
#define __builtin_FILE() nullptr
#define __builtin_LINE() (-1)
std::string MakeMsg(std::string&& msg, char const*, std::int32_t);
#else
std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line);
#endif
} // namespace detail } // namespace detail
/** /**
@ -131,8 +100,21 @@ struct Result {
} }
return *impl_ == *that.impl_; return *impl_ == *that.impl_;
} }
friend Result operator+(Result&& lhs, Result&& rhs);
}; };
/**
 * @brief Combine two results into one.
 *
 * If either side is a success the other side is returned unchanged. When both sides
 * are failures, the error chain of `rhs` is appended to the end of the chain of `lhs`
 * and `lhs` is returned.
 */
[[nodiscard]] inline Result operator+(Result&& lhs, Result&& rhs) {
  // Nothing to merge when the left-hand side carries no error.
  if (lhs.OK()) {
    return std::move(rhs);
  }
  // Both failed: hand the right-hand chain over to the left-hand one.
  if (!rhs.OK()) {
    lhs.impl_->Concat(std::move(rhs.impl_));
  }
  return std::move(lhs);
}
/** /**
* @brief Return success. * @brief Return success.
*/ */
@ -140,38 +122,43 @@ struct Result {
/** /**
* @brief Return failure. * @brief Return failure.
*/ */
[[nodiscard]] inline auto Fail(std::string msg) { return Result{std::move(msg)}; } [[nodiscard]] inline auto Fail(std::string msg, char const* file = __builtin_FILE(),
std::int32_t line = __builtin_LINE()) {
return Result{detail::MakeMsg(std::move(msg), file, line)};
}
/** /**
* @brief Return failure with `errno`. * @brief Return failure with `errno`.
*/ */
[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc) { [[nodiscard]] inline auto Fail(std::string msg, std::error_code errc,
return Result{std::move(msg), std::move(errc)}; char const* file = __builtin_FILE(),
std::int32_t line = __builtin_LINE()) {
return Result{detail::MakeMsg(std::move(msg), file, line), std::move(errc)};
} }
/** /**
* @brief Return failure with a previous error. * @brief Return failure with a previous error.
*/ */
[[nodiscard]] inline auto Fail(std::string msg, Result&& prev) { [[nodiscard]] inline auto Fail(std::string msg, Result&& prev, char const* file = __builtin_FILE(),
return Result{std::move(msg), std::forward<Result>(prev)}; std::int32_t line = __builtin_LINE()) {
return Result{detail::MakeMsg(std::move(msg), file, line), std::forward<Result>(prev)};
} }
/** /**
* @brief Return failure with a previous error and a new `errno`. * @brief Return failure with a previous error and a new `errno`.
*/ */
[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, Result&& prev) { [[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, Result&& prev,
return Result{std::move(msg), std::move(errc), std::forward<Result>(prev)}; char const* file = __builtin_FILE(),
std::int32_t line = __builtin_LINE()) {
return Result{detail::MakeMsg(std::move(msg), file, line), std::move(errc),
std::forward<Result>(prev)};
} }
// We don't have monad, a simple helper would do. // We don't have monad, a simple helper would do.
template <typename Fn> template <typename Fn>
[[nodiscard]] Result operator<<(Result&& r, Fn&& fn) { [[nodiscard]] std::enable_if_t<std::is_invocable_v<Fn>, Result> operator<<(Result&& r, Fn&& fn) {
if (!r.OK()) { if (!r.OK()) {
return std::forward<Result>(r); return std::forward<Result>(r);
} }
return fn(); return fn();
} }
inline void SafeColl(Result const& rc) { void SafeColl(Result const& rc);
if (!rc.OK()) {
LOG(FATAL) << rc.Report();
}
}
} // namespace xgboost::collective } // namespace xgboost::collective

View File

@ -1,5 +1,5 @@
/** /**
* Copyright (c) 2022-2023, XGBoost Contributors * Copyright (c) 2022-2024, XGBoost Contributors
*/ */
#pragma once #pragma once
@ -12,7 +12,6 @@
#include <cstddef> // std::size_t #include <cstddef> // std::size_t
#include <cstdint> // std::int32_t, std::uint16_t #include <cstdint> // std::int32_t, std::uint16_t
#include <cstring> // memset #include <cstring> // memset
#include <limits> // std::numeric_limits
#include <string> // std::string #include <string> // std::string
#include <system_error> // std::error_code, std::system_category #include <system_error> // std::error_code, std::system_category
#include <utility> // std::swap #include <utility> // std::swap
@ -468,19 +467,30 @@ class TCPSocket {
*addr = SockAddress{SockAddrV6{caddr}}; *addr = SockAddress{SockAddrV6{caddr}};
*out = TCPSocket{newfd}; *out = TCPSocket{newfd};
} }
// On MacOS, this is automatically set to async socket if the parent socket is async
// We make sure all socket are blocking by default.
//
// On Windows, a closed socket is returned during shutdown. We guard against it when
// setting non-blocking.
if (!out->IsClosed()) {
return out->NonBlocking(false);
}
return Success(); return Success();
} }
~TCPSocket() { ~TCPSocket() {
if (!IsClosed()) { if (!IsClosed()) {
Close(); auto rc = this->Close();
if (!rc.OK()) {
LOG(WARNING) << rc.Report();
}
} }
} }
TCPSocket(TCPSocket const &that) = delete; TCPSocket(TCPSocket const &that) = delete;
TCPSocket(TCPSocket &&that) noexcept(true) { std::swap(this->handle_, that.handle_); } TCPSocket(TCPSocket &&that) noexcept(true) { std::swap(this->handle_, that.handle_); }
TCPSocket &operator=(TCPSocket const &that) = delete; TCPSocket &operator=(TCPSocket const &that) = delete;
TCPSocket &operator=(TCPSocket &&that) { TCPSocket &operator=(TCPSocket &&that) noexcept(true) {
std::swap(this->handle_, that.handle_); std::swap(this->handle_, that.handle_);
return *this; return *this;
} }
@ -635,22 +645,26 @@ class TCPSocket {
*/ */
std::size_t Recv(std::string *p_str); std::size_t Recv(std::string *p_str);
/** /**
* \brief Close the socket, called automatically in destructor if the socket is not closed. * @brief Close the socket, called automatically in destructor if the socket is not closed.
*/ */
void Close() { Result Close() {
if (InvalidSocket() != handle_) { if (InvalidSocket() != handle_) {
#if defined(_WIN32)
auto rc = system::CloseSocket(handle_); auto rc = system::CloseSocket(handle_);
#if defined(_WIN32)
// it's possible that we close TCP sockets after finalizing WSA due to detached thread. // it's possible that we close TCP sockets after finalizing WSA due to detached thread.
if (rc != 0 && system::LastError() != WSANOTINITIALISED) { if (rc != 0 && system::LastError() != WSANOTINITIALISED) {
system::ThrowAtError("close", rc); return system::FailWithCode("Failed to close the socket.");
} }
#else #else
xgboost_CHECK_SYS_CALL(system::CloseSocket(handle_), 0); if (rc != 0) {
return system::FailWithCode("Failed to close the socket.");
}
#endif #endif
handle_ = InvalidSocket(); handle_ = InvalidSocket();
} }
return Success();
} }
/** /**
* \brief Create a TCP socket on specified domain. * \brief Create a TCP socket on specified domain.
*/ */

View File

@ -18,9 +18,11 @@
#include "xgboost/logging.h" // for CHECK #include "xgboost/logging.h" // for CHECK
namespace xgboost::collective { namespace xgboost::collective {
Result Loop::EmptyQueue(std::queue<Op>* p_queue) const { Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
timer_.Start(__func__); timer_.Start(__func__);
auto error = [this] { timer_.Stop(__func__); }; auto error = [this] {
timer_.Stop(__func__);
};
if (stop_) { if (stop_) {
timer_.Stop(__func__); timer_.Stop(__func__);
@ -48,6 +50,9 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
poll.WatchWrite(*op.sock); poll.WatchWrite(*op.sock);
break; break;
} }
case Op::kSleep: {
break;
}
default: { default: {
error(); error();
return Fail("Invalid socket operation."); return Fail("Invalid socket operation.");
@ -59,12 +64,14 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
// poll, work on fds that are ready. // poll, work on fds that are ready.
timer_.Start("poll"); timer_.Start("poll");
auto rc = poll.Poll(timeout_); if (!poll.fds.empty()) {
timer_.Stop("poll"); auto rc = poll.Poll(timeout_);
if (!rc.OK()) { if (!rc.OK()) {
error(); error();
return rc; return rc;
}
} }
timer_.Stop("poll");
// we wouldn't be here if the queue is empty. // we wouldn't be here if the queue is empty.
CHECK(!qcopy.empty()); CHECK(!qcopy.empty());
@ -75,12 +82,20 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
qcopy.pop(); qcopy.pop();
std::int32_t n_bytes_done{0}; std::int32_t n_bytes_done{0};
CHECK(op.sock->NonBlocking()); if (!op.sock) {
CHECK(op.code == Op::kSleep);
} else {
CHECK(op.sock->NonBlocking());
}
switch (op.code) { switch (op.code) {
case Op::kRead: { case Op::kRead: {
if (poll.CheckRead(*op.sock)) { if (poll.CheckRead(*op.sock)) {
n_bytes_done = op.sock->Recv(op.ptr + op.off, op.n - op.off); n_bytes_done = op.sock->Recv(op.ptr + op.off, op.n - op.off);
if (n_bytes_done == 0) {
error();
return Fail("Encountered EOF. The other end is likely closed.");
}
} }
break; break;
} }
@ -90,6 +105,12 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
} }
break; break;
} }
case Op::kSleep: {
// For testing only.
std::this_thread::sleep_for(std::chrono::seconds{op.n});
n_bytes_done = op.n;
break;
}
default: { default: {
error(); error();
return Fail("Invalid socket operation."); return Fail("Invalid socket operation.");
@ -110,6 +131,10 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
qcopy.push(op); qcopy.push(op);
} }
} }
if (!blocking) {
break;
}
} }
timer_.Stop(__func__); timer_.Stop(__func__);
@ -128,6 +153,15 @@ void Loop::Process() {
while (true) { while (true) {
try { try {
std::unique_lock lock{mu_}; std::unique_lock lock{mu_};
// This can handle missed notification: wait(lock, predicate) is equivalent to:
//
// while (!predicate()) {
// cv.wait(lock);
// }
//
// As a result, if there's a missed notification, the queue wouldn't be empty, hence
// the predicate would already be true and the actual wait wouldn't be invoked.
// Therefore, the blocking call can never go unanswered.
cv_.wait(lock, [this] { return !this->queue_.empty() || stop_; }); cv_.wait(lock, [this] { return !this->queue_.empty() || stop_; });
if (stop_) { if (stop_) {
break; // only point where this loop can exit. break; // only point where this loop can exit.
@ -142,26 +176,27 @@ void Loop::Process() {
queue_.pop(); queue_.pop();
if (op.code == Op::kBlock) { if (op.code == Op::kBlock) {
is_blocking = true; is_blocking = true;
// Block must be the last op in the current batch since no further submit can be
// issued until the blocking call is finished.
CHECK(queue_.empty());
} else { } else {
qcopy.push(op); qcopy.push(op);
} }
} }
if (!is_blocking) { lock.unlock();
// Unblock, we can write to the global queue again. // Clear the local queue, if `is_blocking` is true, this is blocking the current
lock.unlock(); // worker thread (but not the client thread), wait until all operations are
// finished.
auto rc = this->ProcessQueue(&qcopy, is_blocking);
if (is_blocking && rc.OK()) {
CHECK(qcopy.empty());
} }
// Push back the remaining operations.
// Clear the local queue, this is blocking the current worker thread (but not the if (rc.OK()) {
// client thread), wait until all operations are finished. std::unique_lock lock{mu_};
auto rc = this->EmptyQueue(&qcopy); while (!qcopy.empty()) {
queue_.push(qcopy.front());
if (is_blocking) { qcopy.pop();
// The unlock is delayed if this is a blocking call }
lock.unlock();
} }
// Notify the client thread who called block after all error conditions are set. // Notify the client thread who called block after all error conditions are set.
@ -228,7 +263,6 @@ Result Loop::Stop() {
} }
this->Submit(Op{Op::kBlock}); this->Submit(Op{Op::kBlock});
{ {
// Wait for the block call to finish. // Wait for the block call to finish.
std::unique_lock lock{mu_}; std::unique_lock lock{mu_};
@ -243,8 +277,20 @@ Result Loop::Stop() {
} }
} }
/**
 * @brief Append an operation to the shared queue and wake up one waiter on `cv_`.
 *
 * Every operation other than `kBlock` must have a non-zero `n`.
 */
void Loop::Submit(Op op) {
  {
    // The lock covers only the validation and the push; it is released at the end of
    // this scope, before the notification is sent.
    std::unique_lock guard{mu_};
    if (op.code != Op::kBlock) {
      CHECK_NE(op.n, 0);
    }
    queue_.push(op);
  }
  cv_.notify_one();
}
Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} { Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} {
timer_.Init(__func__); timer_.Init(__func__);
worker_ = std::thread{[this] { this->Process(); }}; worker_ = std::thread{[this] {
this->Process();
}};
} }
} // namespace xgboost::collective } // namespace xgboost::collective

View File

@ -19,20 +19,27 @@ namespace xgboost::collective {
class Loop { class Loop {
public: public:
struct Op { struct Op {
enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2 } code; // kSleep is only for testing
enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2, kSleep = 4 } code;
std::int32_t rank{-1}; std::int32_t rank{-1};
std::int8_t* ptr{nullptr}; std::int8_t* ptr{nullptr};
std::size_t n{0}; std::size_t n{0};
TCPSocket* sock{nullptr}; TCPSocket* sock{nullptr};
std::size_t off{0}; std::size_t off{0};
explicit Op(Code c) : code{c} { CHECK(c == kBlock); } explicit Op(Code c) : code{c} { CHECK(c == kBlock || c == kSleep); }
Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off) Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off)
: code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {} : code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {}
Op(Op const&) = default; Op(Op const&) = default;
Op& operator=(Op const&) = default; Op& operator=(Op const&) = default;
Op(Op&&) = default; Op(Op&&) = default;
Op& operator=(Op&&) = default; Op& operator=(Op&&) = default;
// Test-only helper: build a `kSleep` operation whose `n` field carries the requested
// number of seconds.
[[nodiscard]] static Op Sleep(std::size_t seconds) {
  Op sleep_op{kSleep};
  sleep_op.n = seconds;
  return sleep_op;
}
}; };
private: private:
@ -54,7 +61,7 @@ class Loop {
std::exception_ptr curr_exce_{nullptr}; std::exception_ptr curr_exce_{nullptr};
common::Monitor mutable timer_; common::Monitor mutable timer_;
Result EmptyQueue(std::queue<Op>* p_queue) const; Result ProcessQueue(std::queue<Op>* p_queue, bool blocking) const;
// The consumer function that runs inside a worker thread. // The consumer function that runs inside a worker thread.
void Process(); void Process();
@ -64,12 +71,7 @@ class Loop {
*/ */
Result Stop(); Result Stop();
void Submit(Op op) { void Submit(Op op);
std::unique_lock lock{mu_};
queue_.push(op);
lock.unlock();
cv_.notify_one();
}
/** /**
* @brief Block the event loop until all ops are finished. In the case of failure, this * @brief Block the event loop until all ops are finished. In the case of failure, this

86
src/collective/result.cc Normal file
View File

@ -0,0 +1,86 @@
/**
* Copyright 2024, XGBoost Contributors
*/
#include "xgboost/collective/result.h"
#include <filesystem> // for path
#include <sstream> // for stringstream
#include <stack> // for stack
#include "xgboost/logging.h"
namespace xgboost::collective {
namespace detail {
[[nodiscard]] std::string ResultImpl::Report() const {
std::stringstream ss;
ss << "\n- " << this->message;
if (this->errc != std::error_code{}) {
ss << " system error:" << this->errc.message();
}
auto ptr = prev.get();
while (ptr) {
ss << "\n- ";
ss << ptr->message;
if (ptr->errc != std::error_code{}) {
ss << " " << ptr->errc.message();
}
ptr = ptr->prev.get();
}
return ss.str();
}
[[nodiscard]] std::error_code ResultImpl::Code() const {
// Find the root error.
std::stack<ResultImpl const*> stack;
auto ptr = this;
while (ptr) {
stack.push(ptr);
if (ptr->prev) {
ptr = ptr->prev.get();
} else {
break;
}
}
while (!stack.empty()) {
auto frame = stack.top();
stack.pop();
if (frame->errc != std::error_code{}) {
return frame->errc;
}
}
return std::error_code{};
}
/**
 * @brief Append another error chain to the end of this one.
 *
 * @param rhs Chain to take ownership of; it becomes the `prev` of the last entry.
 */
void ResultImpl::Concat(std::unique_ptr<ResultImpl> rhs) {
  // Find the first empty `prev` slot, starting from this entry.
  auto* tail_slot = &this->prev;
  while (*tail_slot) {
    tail_slot = &(*tail_slot)->prev;
  }
  *tail_slot = std::move(rhs);
}
#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__)
/**
 * @brief Variant used when no source location is supplied: the message is returned
 *        unchanged and the location arguments are ignored.
 */
std::string MakeMsg(std::string&& msg, char const*, std::int32_t) {
  return std::move(msg);
}
#else
/**
 * @brief Prefix an error message with its source location.
 *
 * @param msg  The error message, consumed by this call.
 * @param file Path of the source file; may be null.
 * @param line Line number, or -1 when unknown.
 *
 * @return `[<file name>:<line>]: <msg>` when both `file` and `line` are available,
 *         otherwise `msg` unchanged. Only the final path component of `file` is used.
 */
std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line) {
  // Validate `file` before using it: constructing a path from a null pointer is
  // undefined behaviour, so the check must come first.
  if (file == nullptr || line == -1) {
    return std::move(msg);
  }
  auto name = std::filesystem::path{file}.filename();
  return "[" + name.string() + ":" + std::to_string(line) +  // NOLINT
         "]: " + std::move(msg);
}
#endif
} // namespace detail
/**
 * @brief Log the full report at FATAL level when the result is a failure.
 *
 * Does nothing for a successful result.
 */
void SafeColl(Result const& rc) {
  if (rc.OK()) {
    return;
  }
  LOG(FATAL) << rc.Report();
}
} // namespace xgboost::collective

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2023, XGBoost Contributors * Copyright 2023-2024, XGBoost Contributors
*/ */
#include <gtest/gtest.h> // for ASSERT_TRUE, ASSERT_EQ #include <gtest/gtest.h> // for ASSERT_TRUE, ASSERT_EQ
#include <xgboost/collective/socket.h> // for TCPSocket, Connect, SocketFinalize, SocketStartup #include <xgboost/collective/socket.h> // for TCPSocket, Connect, SocketFinalize, SocketStartup
@ -28,18 +28,25 @@ class LoopTest : public ::testing::Test {
auto domain = SockDomain::kV4; auto domain = SockDomain::kV4;
pair_.first = TCPSocket::Create(domain); pair_.first = TCPSocket::Create(domain);
auto port = pair_.first.BindHost(); in_port_t port{0};
pair_.first.Listen(); auto rc = Success() << [&] {
port = pair_.first.BindHost();
return Success();
} << [&] {
pair_.first.Listen();
return Success();
};
SafeColl(rc);
auto const& addr = SockAddrV4::Loopback().Addr(); auto const& addr = SockAddrV4::Loopback().Addr();
auto rc = Connect(StringView{addr}, port, 1, timeout, &pair_.second); rc = Connect(StringView{addr}, port, 1, timeout, &pair_.second);
ASSERT_TRUE(rc.OK()); SafeColl(rc);
rc = pair_.second.NonBlocking(true); rc = pair_.second.NonBlocking(true);
ASSERT_TRUE(rc.OK()); SafeColl(rc);
pair_.first = pair_.first.Accept(); pair_.first = pair_.first.Accept();
rc = pair_.first.NonBlocking(true); rc = pair_.first.NonBlocking(true);
ASSERT_TRUE(rc.OK()); SafeColl(rc);
loop_ = std::shared_ptr<Loop>{new Loop{timeout}}; loop_ = std::shared_ptr<Loop>{new Loop{timeout}};
} }
@ -74,8 +81,26 @@ TEST_F(LoopTest, Op) {
loop_->Submit(rop); loop_->Submit(rop);
auto rc = loop_->Block(); auto rc = loop_->Block();
ASSERT_TRUE(rc.OK()) << rc.Report(); SafeColl(rc);
ASSERT_EQ(rbuf[0], wbuf[0]); ASSERT_EQ(rbuf[0], wbuf[0]);
} }
TEST_F(LoopTest, Block) {
  // A blocking call issued while an operation is still pending must be answered.
  auto sleep_op = Loop::Op::Sleep(2);

  common::Timer timer;
  timer.Start();
  loop_->Submit(sleep_op);
  timer.Stop();
  // Submitting only enqueues the operation, so it has to return well within a second.
  ASSERT_LT(timer.ElapsedSeconds(), 1);

  timer.Start();
  auto status = loop_->Block();
  timer.Stop();
  SafeColl(status);
  // Block() has to wait for the queued sleep to run.
  ASSERT_GE(timer.ElapsedSeconds(), 1);
}
} // namespace xgboost::collective } // namespace xgboost::collective

View File

@ -0,0 +1,31 @@
/**
* Copyright 2024, XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/collective/result.h>
namespace xgboost::collective {
TEST(Result, Concat) {
  // Merging two failures keeps both messages in the report.
  auto first = Fail("foo");
  auto second = Fail("bar");
  auto merged = std::move(first) + std::move(second);
  ASSERT_NE(merged.Report().find("foo"), std::string::npos);
  ASSERT_NE(merged.Report().find("bar"), std::string::npos);

  // Wrapping the merged failure in a new one keeps the earlier messages as well.
  auto wrapped = Fail("Another", std::move(merged));
  auto expect_all_messages = [](Result const& res) {
    ASSERT_NE(res.Report().find("Another"), std::string::npos);
    ASSERT_NE(res.Report().find("foo"), std::string::npos);
    ASSERT_NE(res.Report().find("bar"), std::string::npos);
  };
  expect_all_messages(wrapped);

  // A success on the left-hand side leaves the failure untouched.
  auto ok = Success();
  auto success_then_failure = std::move(ok) + std::move(wrapped);
  expect_all_messages(success_then_failure);

  // The same holds with the success on the right-hand side.
  ok = Success();
  auto failure_then_success = std::move(success_then_failure) + std::move(ok);
  expect_all_messages(failure_then_success);
}
} // namespace xgboost::collective