[coll] Improve event loop. (#10199)
- Add a test for blocking calls. - Do not require the queue to be empty after waking up; this frees up the thread to answer blocking calls. - Handle EOF in read. - Improve the error message in the result. Allow concatenation of multiple results.
This commit is contained in:
parent
7c0c9677a9
commit
4b10200456
@ -99,6 +99,7 @@ OBJECTS= \
|
||||
$(PKGROOT)/src/context.o \
|
||||
$(PKGROOT)/src/logging.o \
|
||||
$(PKGROOT)/src/global_config.o \
|
||||
$(PKGROOT)/src/collective/result.o \
|
||||
$(PKGROOT)/src/collective/allgather.o \
|
||||
$(PKGROOT)/src/collective/allreduce.o \
|
||||
$(PKGROOT)/src/collective/broadcast.o \
|
||||
|
||||
@ -99,6 +99,7 @@ OBJECTS= \
|
||||
$(PKGROOT)/src/context.o \
|
||||
$(PKGROOT)/src/logging.o \
|
||||
$(PKGROOT)/src/global_config.o \
|
||||
$(PKGROOT)/src/collective/result.o \
|
||||
$(PKGROOT)/src/collective/allgather.o \
|
||||
$(PKGROOT)/src/collective/allreduce.o \
|
||||
$(PKGROOT)/src/collective/broadcast.o \
|
||||
|
||||
@ -40,7 +40,7 @@ def main(client):
|
||||
# you can pass output directly into `predict` too.
|
||||
prediction = dxgb.predict(client, bst, dtrain)
|
||||
print("Evaluation history:", history)
|
||||
return prediction
|
||||
print("Error:", da.sqrt((prediction - y) ** 2).mean().compute())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -144,6 +144,14 @@ which provides higher flexibility. For example:
|
||||
|
||||
ctest --verbose
|
||||
|
||||
If you need to debug errors on Windows using the debugger from VS, you can append the gtest flags in `test_main.cc`:
|
||||
|
||||
.. code-block::
|
||||
|
||||
::testing::GTEST_FLAG(filter) = "Suite.Test";
|
||||
::testing::GTEST_FLAG(repeat) = 10;
|
||||
|
||||
|
||||
***********************************************
|
||||
Sanitizers: Detect memory errors and data races
|
||||
***********************************************
|
||||
|
||||
@ -3,12 +3,10 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <xgboost/logging.h>
|
||||
|
||||
#include <cstdint> // for int32_t
|
||||
#include <memory> // for unique_ptr
|
||||
#include <sstream> // for stringstream
|
||||
#include <stack> // for stack
|
||||
#include <string> // for string
|
||||
#include <system_error> // for error_code
|
||||
#include <utility> // for move
|
||||
|
||||
namespace xgboost::collective {
|
||||
@ -48,48 +46,19 @@ struct ResultImpl {
|
||||
return cur_eq;
|
||||
}
|
||||
|
||||
[[nodiscard]] std::string Report() {
|
||||
std::stringstream ss;
|
||||
ss << "\n- " << this->message;
|
||||
if (this->errc != std::error_code{}) {
|
||||
ss << " system error:" << this->errc.message();
|
||||
}
|
||||
[[nodiscard]] std::string Report() const;
|
||||
[[nodiscard]] std::error_code Code() const;
|
||||
|
||||
auto ptr = prev.get();
|
||||
while (ptr) {
|
||||
ss << "\n- ";
|
||||
ss << ptr->message;
|
||||
|
||||
if (ptr->errc != std::error_code{}) {
|
||||
ss << " " << ptr->errc.message();
|
||||
}
|
||||
ptr = ptr->prev.get();
|
||||
}
|
||||
|
||||
return ss.str();
|
||||
}
|
||||
[[nodiscard]] auto Code() const {
|
||||
// Find the root error.
|
||||
std::stack<ResultImpl const*> stack;
|
||||
auto ptr = this;
|
||||
while (ptr) {
|
||||
stack.push(ptr);
|
||||
if (ptr->prev) {
|
||||
ptr = ptr->prev.get();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
while (!stack.empty()) {
|
||||
auto frame = stack.top();
|
||||
stack.pop();
|
||||
if (frame->errc != std::error_code{}) {
|
||||
return frame->errc;
|
||||
}
|
||||
}
|
||||
return std::error_code{};
|
||||
}
|
||||
void Concat(std::unique_ptr<ResultImpl> rhs);
|
||||
};
|
||||
|
||||
#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__)
|
||||
#define __builtin_FILE() nullptr
|
||||
#define __builtin_LINE() (-1)
|
||||
std::string MakeMsg(std::string&& msg, char const*, std::int32_t);
|
||||
#else
|
||||
std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line);
|
||||
#endif
|
||||
} // namespace detail
|
||||
|
||||
/**
|
||||
@ -131,8 +100,21 @@ struct Result {
|
||||
}
|
||||
return *impl_ == *that.impl_;
|
||||
}
|
||||
|
||||
friend Result operator+(Result&& lhs, Result&& rhs);
|
||||
};
|
||||
|
||||
[[nodiscard]] inline Result operator+(Result&& lhs, Result&& rhs) {
|
||||
if (lhs.OK()) {
|
||||
return std::forward<Result>(rhs);
|
||||
}
|
||||
if (rhs.OK()) {
|
||||
return std::forward<Result>(lhs);
|
||||
}
|
||||
lhs.impl_->Concat(std::move(rhs.impl_));
|
||||
return std::forward<Result>(lhs);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return success.
|
||||
*/
|
||||
@ -140,38 +122,43 @@ struct Result {
|
||||
/**
|
||||
* @brief Return failure.
|
||||
*/
|
||||
[[nodiscard]] inline auto Fail(std::string msg) { return Result{std::move(msg)}; }
|
||||
[[nodiscard]] inline auto Fail(std::string msg, char const* file = __builtin_FILE(),
|
||||
std::int32_t line = __builtin_LINE()) {
|
||||
return Result{detail::MakeMsg(std::move(msg), file, line)};
|
||||
}
|
||||
/**
|
||||
* @brief Return failure with `errno`.
|
||||
*/
|
||||
[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc) {
|
||||
return Result{std::move(msg), std::move(errc)};
|
||||
[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc,
|
||||
char const* file = __builtin_FILE(),
|
||||
std::int32_t line = __builtin_LINE()) {
|
||||
return Result{detail::MakeMsg(std::move(msg), file, line), std::move(errc)};
|
||||
}
|
||||
/**
|
||||
* @brief Return failure with a previous error.
|
||||
*/
|
||||
[[nodiscard]] inline auto Fail(std::string msg, Result&& prev) {
|
||||
return Result{std::move(msg), std::forward<Result>(prev)};
|
||||
[[nodiscard]] inline auto Fail(std::string msg, Result&& prev, char const* file = __builtin_FILE(),
|
||||
std::int32_t line = __builtin_LINE()) {
|
||||
return Result{detail::MakeMsg(std::move(msg), file, line), std::forward<Result>(prev)};
|
||||
}
|
||||
/**
|
||||
* @brief Return failure with a previous error and a new `errno`.
|
||||
*/
|
||||
[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, Result&& prev) {
|
||||
return Result{std::move(msg), std::move(errc), std::forward<Result>(prev)};
|
||||
[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, Result&& prev,
|
||||
char const* file = __builtin_FILE(),
|
||||
std::int32_t line = __builtin_LINE()) {
|
||||
return Result{detail::MakeMsg(std::move(msg), file, line), std::move(errc),
|
||||
std::forward<Result>(prev)};
|
||||
}
|
||||
|
||||
// We don't have monad, a simple helper would do.
|
||||
template <typename Fn>
|
||||
[[nodiscard]] Result operator<<(Result&& r, Fn&& fn) {
|
||||
[[nodiscard]] std::enable_if_t<std::is_invocable_v<Fn>, Result> operator<<(Result&& r, Fn&& fn) {
|
||||
if (!r.OK()) {
|
||||
return std::forward<Result>(r);
|
||||
}
|
||||
return fn();
|
||||
}
|
||||
|
||||
inline void SafeColl(Result const& rc) {
|
||||
if (!rc.OK()) {
|
||||
LOG(FATAL) << rc.Report();
|
||||
}
|
||||
}
|
||||
void SafeColl(Result const& rc);
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright (c) 2022-2023, XGBoost Contributors
|
||||
* Copyright (c) 2022-2024, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
@ -12,7 +12,6 @@
|
||||
#include <cstddef> // std::size_t
|
||||
#include <cstdint> // std::int32_t, std::uint16_t
|
||||
#include <cstring> // memset
|
||||
#include <limits> // std::numeric_limits
|
||||
#include <string> // std::string
|
||||
#include <system_error> // std::error_code, std::system_category
|
||||
#include <utility> // std::swap
|
||||
@ -468,19 +467,30 @@ class TCPSocket {
|
||||
*addr = SockAddress{SockAddrV6{caddr}};
|
||||
*out = TCPSocket{newfd};
|
||||
}
|
||||
// On MacOS, this is automatically set to async socket if the parent socket is async
|
||||
// We make sure all socket are blocking by default.
|
||||
//
|
||||
// On Windows, a closed socket is returned during shutdown. We guard against it when
|
||||
// setting non-blocking.
|
||||
if (!out->IsClosed()) {
|
||||
return out->NonBlocking(false);
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
|
||||
~TCPSocket() {
|
||||
if (!IsClosed()) {
|
||||
Close();
|
||||
auto rc = this->Close();
|
||||
if (!rc.OK()) {
|
||||
LOG(WARNING) << rc.Report();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TCPSocket(TCPSocket const &that) = delete;
|
||||
TCPSocket(TCPSocket &&that) noexcept(true) { std::swap(this->handle_, that.handle_); }
|
||||
TCPSocket &operator=(TCPSocket const &that) = delete;
|
||||
TCPSocket &operator=(TCPSocket &&that) {
|
||||
TCPSocket &operator=(TCPSocket &&that) noexcept(true) {
|
||||
std::swap(this->handle_, that.handle_);
|
||||
return *this;
|
||||
}
|
||||
@ -635,22 +645,26 @@ class TCPSocket {
|
||||
*/
|
||||
std::size_t Recv(std::string *p_str);
|
||||
/**
|
||||
* \brief Close the socket, called automatically in destructor if the socket is not closed.
|
||||
* @brief Close the socket, called automatically in destructor if the socket is not closed.
|
||||
*/
|
||||
void Close() {
|
||||
Result Close() {
|
||||
if (InvalidSocket() != handle_) {
|
||||
#if defined(_WIN32)
|
||||
auto rc = system::CloseSocket(handle_);
|
||||
#if defined(_WIN32)
|
||||
// it's possible that we close TCP sockets after finalizing WSA due to detached thread.
|
||||
if (rc != 0 && system::LastError() != WSANOTINITIALISED) {
|
||||
system::ThrowAtError("close", rc);
|
||||
return system::FailWithCode("Failed to close the socket.");
|
||||
}
|
||||
#else
|
||||
xgboost_CHECK_SYS_CALL(system::CloseSocket(handle_), 0);
|
||||
if (rc != 0) {
|
||||
return system::FailWithCode("Failed to close the socket.");
|
||||
}
|
||||
#endif
|
||||
handle_ = InvalidSocket();
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Create a TCP socket on specified domain.
|
||||
*/
|
||||
|
||||
@ -18,9 +18,11 @@
|
||||
#include "xgboost/logging.h" // for CHECK
|
||||
|
||||
namespace xgboost::collective {
|
||||
Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
|
||||
Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
|
||||
timer_.Start(__func__);
|
||||
auto error = [this] { timer_.Stop(__func__); };
|
||||
auto error = [this] {
|
||||
timer_.Stop(__func__);
|
||||
};
|
||||
|
||||
if (stop_) {
|
||||
timer_.Stop(__func__);
|
||||
@ -48,6 +50,9 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
|
||||
poll.WatchWrite(*op.sock);
|
||||
break;
|
||||
}
|
||||
case Op::kSleep: {
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
error();
|
||||
return Fail("Invalid socket operation.");
|
||||
@ -59,12 +64,14 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
|
||||
|
||||
// poll, work on fds that are ready.
|
||||
timer_.Start("poll");
|
||||
if (!poll.fds.empty()) {
|
||||
auto rc = poll.Poll(timeout_);
|
||||
timer_.Stop("poll");
|
||||
if (!rc.OK()) {
|
||||
error();
|
||||
return rc;
|
||||
}
|
||||
}
|
||||
timer_.Stop("poll");
|
||||
|
||||
// we wonldn't be here if the queue is empty.
|
||||
CHECK(!qcopy.empty());
|
||||
@ -75,12 +82,20 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
|
||||
qcopy.pop();
|
||||
|
||||
std::int32_t n_bytes_done{0};
|
||||
if (!op.sock) {
|
||||
CHECK(op.code == Op::kSleep);
|
||||
} else {
|
||||
CHECK(op.sock->NonBlocking());
|
||||
}
|
||||
|
||||
switch (op.code) {
|
||||
case Op::kRead: {
|
||||
if (poll.CheckRead(*op.sock)) {
|
||||
n_bytes_done = op.sock->Recv(op.ptr + op.off, op.n - op.off);
|
||||
if (n_bytes_done == 0) {
|
||||
error();
|
||||
return Fail("Encountered EOF. The other end is likely closed.");
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
@ -90,6 +105,12 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Op::kSleep: {
|
||||
// For testing only.
|
||||
std::this_thread::sleep_for(std::chrono::seconds{op.n});
|
||||
n_bytes_done = op.n;
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
error();
|
||||
return Fail("Invalid socket operation.");
|
||||
@ -110,6 +131,10 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
|
||||
qcopy.push(op);
|
||||
}
|
||||
}
|
||||
|
||||
if (!blocking) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
timer_.Stop(__func__);
|
||||
@ -128,6 +153,15 @@ void Loop::Process() {
|
||||
while (true) {
|
||||
try {
|
||||
std::unique_lock lock{mu_};
|
||||
// This can handle missed notification: wait(lock, predicate) is equivalent to:
|
||||
//
|
||||
// while (!predicate()) {
|
||||
// cv.wait(lock);
|
||||
// }
|
||||
//
|
||||
// As a result, if there's a missed notification, the queue wouldn't be empty, hence
|
||||
// the predicate would be false and the actual wait wouldn't be invoked. Therefore,
|
||||
// the blocking call can never go unanswered.
|
||||
cv_.wait(lock, [this] { return !this->queue_.empty() || stop_; });
|
||||
if (stop_) {
|
||||
break; // only point where this loop can exit.
|
||||
@ -142,26 +176,27 @@ void Loop::Process() {
|
||||
queue_.pop();
|
||||
if (op.code == Op::kBlock) {
|
||||
is_blocking = true;
|
||||
// Block must be the last op in the current batch since no further submit can be
|
||||
// issued until the blocking call is finished.
|
||||
CHECK(queue_.empty());
|
||||
} else {
|
||||
qcopy.push(op);
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_blocking) {
|
||||
// Unblock, we can write to the global queue again.
|
||||
lock.unlock();
|
||||
// Clear the local queue, if `is_blocking` is true, this is blocking the current
|
||||
// worker thread (but not the client thread), wait until all operations are
|
||||
// finished.
|
||||
auto rc = this->ProcessQueue(&qcopy, is_blocking);
|
||||
|
||||
if (is_blocking && rc.OK()) {
|
||||
CHECK(qcopy.empty());
|
||||
}
|
||||
// Push back the remaining operations.
|
||||
if (rc.OK()) {
|
||||
std::unique_lock lock{mu_};
|
||||
while (!qcopy.empty()) {
|
||||
queue_.push(qcopy.front());
|
||||
qcopy.pop();
|
||||
}
|
||||
|
||||
// Clear the local queue, this is blocking the current worker thread (but not the
|
||||
// client thread), wait until all operations are finished.
|
||||
auto rc = this->EmptyQueue(&qcopy);
|
||||
|
||||
if (is_blocking) {
|
||||
// The unlock is delayed if this is a blocking call
|
||||
lock.unlock();
|
||||
}
|
||||
|
||||
// Notify the client thread who called block after all error conditions are set.
|
||||
@ -228,7 +263,6 @@ Result Loop::Stop() {
|
||||
}
|
||||
|
||||
this->Submit(Op{Op::kBlock});
|
||||
|
||||
{
|
||||
// Wait for the block call to finish.
|
||||
std::unique_lock lock{mu_};
|
||||
@ -243,8 +277,20 @@ Result Loop::Stop() {
|
||||
}
|
||||
}
|
||||
|
||||
void Loop::Submit(Op op) {
|
||||
std::unique_lock lock{mu_};
|
||||
if (op.code != Op::kBlock) {
|
||||
CHECK_NE(op.n, 0);
|
||||
}
|
||||
queue_.push(op);
|
||||
lock.unlock();
|
||||
cv_.notify_one();
|
||||
}
|
||||
|
||||
Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} {
|
||||
timer_.Init(__func__);
|
||||
worker_ = std::thread{[this] { this->Process(); }};
|
||||
worker_ = std::thread{[this] {
|
||||
this->Process();
|
||||
}};
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@ -19,20 +19,27 @@ namespace xgboost::collective {
|
||||
class Loop {
|
||||
public:
|
||||
struct Op {
|
||||
enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2 } code;
|
||||
// kSleep is only for testing
|
||||
enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2, kSleep = 4 } code;
|
||||
std::int32_t rank{-1};
|
||||
std::int8_t* ptr{nullptr};
|
||||
std::size_t n{0};
|
||||
TCPSocket* sock{nullptr};
|
||||
std::size_t off{0};
|
||||
|
||||
explicit Op(Code c) : code{c} { CHECK(c == kBlock); }
|
||||
explicit Op(Code c) : code{c} { CHECK(c == kBlock || c == kSleep); }
|
||||
Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off)
|
||||
: code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {}
|
||||
Op(Op const&) = default;
|
||||
Op& operator=(Op const&) = default;
|
||||
Op(Op&&) = default;
|
||||
Op& operator=(Op&&) = default;
|
||||
// For testing purpose only
|
||||
[[nodiscard]] static Op Sleep(std::size_t seconds) {
|
||||
Op op{kSleep};
|
||||
op.n = seconds;
|
||||
return op;
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
@ -54,7 +61,7 @@ class Loop {
|
||||
std::exception_ptr curr_exce_{nullptr};
|
||||
common::Monitor mutable timer_;
|
||||
|
||||
Result EmptyQueue(std::queue<Op>* p_queue) const;
|
||||
Result ProcessQueue(std::queue<Op>* p_queue, bool blocking) const;
|
||||
// The cunsumer function that runs inside a worker thread.
|
||||
void Process();
|
||||
|
||||
@ -64,12 +71,7 @@ class Loop {
|
||||
*/
|
||||
Result Stop();
|
||||
|
||||
void Submit(Op op) {
|
||||
std::unique_lock lock{mu_};
|
||||
queue_.push(op);
|
||||
lock.unlock();
|
||||
cv_.notify_one();
|
||||
}
|
||||
void Submit(Op op);
|
||||
|
||||
/**
|
||||
* @brief Block the event loop until all ops are finished. In the case of failure, this
|
||||
|
||||
86
src/collective/result.cc
Normal file
86
src/collective/result.cc
Normal file
@ -0,0 +1,86 @@
|
||||
/**
|
||||
* Copyright 2024, XGBoost Contributors
|
||||
*/
|
||||
#include "xgboost/collective/result.h"
|
||||
|
||||
#include <filesystem> // for path
|
||||
#include <sstream> // for stringstream
|
||||
#include <stack> // for stack
|
||||
|
||||
#include "xgboost/logging.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace detail {
|
||||
[[nodiscard]] std::string ResultImpl::Report() const {
|
||||
std::stringstream ss;
|
||||
ss << "\n- " << this->message;
|
||||
if (this->errc != std::error_code{}) {
|
||||
ss << " system error:" << this->errc.message();
|
||||
}
|
||||
|
||||
auto ptr = prev.get();
|
||||
while (ptr) {
|
||||
ss << "\n- ";
|
||||
ss << ptr->message;
|
||||
|
||||
if (ptr->errc != std::error_code{}) {
|
||||
ss << " " << ptr->errc.message();
|
||||
}
|
||||
ptr = ptr->prev.get();
|
||||
}
|
||||
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
[[nodiscard]] std::error_code ResultImpl::Code() const {
|
||||
// Find the root error.
|
||||
std::stack<ResultImpl const*> stack;
|
||||
auto ptr = this;
|
||||
while (ptr) {
|
||||
stack.push(ptr);
|
||||
if (ptr->prev) {
|
||||
ptr = ptr->prev.get();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
while (!stack.empty()) {
|
||||
auto frame = stack.top();
|
||||
stack.pop();
|
||||
if (frame->errc != std::error_code{}) {
|
||||
return frame->errc;
|
||||
}
|
||||
}
|
||||
return std::error_code{};
|
||||
}
|
||||
|
||||
void ResultImpl::Concat(std::unique_ptr<ResultImpl> rhs) {
|
||||
auto ptr = this;
|
||||
while (ptr->prev) {
|
||||
ptr = ptr->prev.get();
|
||||
}
|
||||
ptr->prev = std::move(rhs);
|
||||
}
|
||||
|
||||
#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__)
|
||||
std::string MakeMsg(std::string&& msg, char const*, std::int32_t) {
|
||||
return std::forward<std::string>(msg);
|
||||
}
|
||||
#else
|
||||
std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line) {
|
||||
auto name = std::filesystem::path{file}.filename();
|
||||
if (file && line != -1) {
|
||||
return "[" + name.string() + ":" + std::to_string(line) + // NOLINT
|
||||
"]: " + std::forward<std::string>(msg);
|
||||
}
|
||||
return std::forward<std::string>(msg);
|
||||
}
|
||||
#endif
|
||||
} // namespace detail
|
||||
|
||||
void SafeColl(Result const& rc) {
|
||||
if (!rc.OK()) {
|
||||
LOG(FATAL) << rc.Report();
|
||||
}
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright 2023-2024, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h> // for ASSERT_TRUE, ASSERT_EQ
|
||||
#include <xgboost/collective/socket.h> // for TCPSocket, Connect, SocketFinalize, SocketStartup
|
||||
@ -28,18 +28,25 @@ class LoopTest : public ::testing::Test {
|
||||
|
||||
auto domain = SockDomain::kV4;
|
||||
pair_.first = TCPSocket::Create(domain);
|
||||
auto port = pair_.first.BindHost();
|
||||
in_port_t port{0};
|
||||
auto rc = Success() << [&] {
|
||||
port = pair_.first.BindHost();
|
||||
return Success();
|
||||
} << [&] {
|
||||
pair_.first.Listen();
|
||||
return Success();
|
||||
};
|
||||
SafeColl(rc);
|
||||
|
||||
auto const& addr = SockAddrV4::Loopback().Addr();
|
||||
auto rc = Connect(StringView{addr}, port, 1, timeout, &pair_.second);
|
||||
ASSERT_TRUE(rc.OK());
|
||||
rc = Connect(StringView{addr}, port, 1, timeout, &pair_.second);
|
||||
SafeColl(rc);
|
||||
rc = pair_.second.NonBlocking(true);
|
||||
ASSERT_TRUE(rc.OK());
|
||||
SafeColl(rc);
|
||||
|
||||
pair_.first = pair_.first.Accept();
|
||||
rc = pair_.first.NonBlocking(true);
|
||||
ASSERT_TRUE(rc.OK());
|
||||
SafeColl(rc);
|
||||
|
||||
loop_ = std::shared_ptr<Loop>{new Loop{timeout}};
|
||||
}
|
||||
@ -74,8 +81,26 @@ TEST_F(LoopTest, Op) {
|
||||
loop_->Submit(rop);
|
||||
|
||||
auto rc = loop_->Block();
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
SafeColl(rc);
|
||||
|
||||
ASSERT_EQ(rbuf[0], wbuf[0]);
|
||||
}
|
||||
|
||||
TEST_F(LoopTest, Block) {
|
||||
// We need to ensure that a blocking call doesn't go unanswered.
|
||||
auto op = Loop::Op::Sleep(2);
|
||||
|
||||
common::Timer t;
|
||||
t.Start();
|
||||
loop_->Submit(op);
|
||||
t.Stop();
|
||||
// submit is non-blocking
|
||||
ASSERT_LT(t.ElapsedSeconds(), 1);
|
||||
|
||||
t.Start();
|
||||
auto rc = loop_->Block();
|
||||
t.Stop();
|
||||
SafeColl(rc);
|
||||
ASSERT_GE(t.ElapsedSeconds(), 1);
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
31
tests/cpp/collective/test_result.cc
Normal file
31
tests/cpp/collective/test_result.cc
Normal file
@ -0,0 +1,31 @@
|
||||
/**
|
||||
* Copyright 2024, XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/collective/result.h>
|
||||
|
||||
namespace xgboost::collective {
|
||||
TEST(Result, Concat) {
|
||||
auto rc0 = Fail("foo");
|
||||
auto rc1 = Fail("bar");
|
||||
auto rc = std::move(rc0) + std::move(rc1);
|
||||
ASSERT_NE(rc.Report().find("foo"), std::string::npos);
|
||||
ASSERT_NE(rc.Report().find("bar"), std::string::npos);
|
||||
|
||||
auto rc2 = Fail("Another", std::move(rc));
|
||||
auto assert_that = [](Result const& rc) {
|
||||
ASSERT_NE(rc.Report().find("Another"), std::string::npos);
|
||||
ASSERT_NE(rc.Report().find("foo"), std::string::npos);
|
||||
ASSERT_NE(rc.Report().find("bar"), std::string::npos);
|
||||
};
|
||||
assert_that(rc2);
|
||||
|
||||
auto empty = Success();
|
||||
auto rc3 = std::move(empty) + std::move(rc2);
|
||||
assert_that(rc3);
|
||||
|
||||
empty = Success();
|
||||
auto rc4 = std::move(rc3) + std::move(empty);
|
||||
assert_that(rc4);
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
Loading…
x
Reference in New Issue
Block a user