From 4b102004565b16c4f08f532f71291c1dd19bda09 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 18 Apr 2024 03:29:52 +0800 Subject: [PATCH] [coll] Improve event loop. (#10199) - Add a test for blocking calls. - Do not require the queue to be empty after waking up; this frees up the thread to answer blocking calls. - Handle EOF in read. - Improve the error message in the result. Allow concatenation of multiple results. --- R-package/src/Makevars.in | 1 + R-package/src/Makevars.win | 1 + demo/dask/cpu_training.py | 2 +- doc/contrib/unit_tests.rst | 8 +++ include/xgboost/collective/result.h | 107 ++++++++++++---------------- include/xgboost/collective/socket.h | 32 ++++++--- src/collective/loop.cc | 94 +++++++++++++++++------- src/collective/loop.h | 20 +++--- src/collective/result.cc | 86 ++++++++++++++++++++++ tests/cpp/collective/test_loop.cc | 41 ++++++++--- tests/cpp/collective/test_result.cc | 31 ++++++++ 11 files changed, 312 insertions(+), 111 deletions(-) create mode 100644 src/collective/result.cc create mode 100644 tests/cpp/collective/test_result.cc diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in index 0f4b3ac6f..99241249f 100644 --- a/R-package/src/Makevars.in +++ b/R-package/src/Makevars.in @@ -99,6 +99,7 @@ OBJECTS= \ $(PKGROOT)/src/context.o \ $(PKGROOT)/src/logging.o \ $(PKGROOT)/src/global_config.o \ + $(PKGROOT)/src/collective/result.o \ $(PKGROOT)/src/collective/allgather.o \ $(PKGROOT)/src/collective/allreduce.o \ $(PKGROOT)/src/collective/broadcast.o \ diff --git a/R-package/src/Makevars.win b/R-package/src/Makevars.win index 0c2084de9..fc2cd3b9f 100644 --- a/R-package/src/Makevars.win +++ b/R-package/src/Makevars.win @@ -99,6 +99,7 @@ OBJECTS= \ $(PKGROOT)/src/context.o \ $(PKGROOT)/src/logging.o \ $(PKGROOT)/src/global_config.o \ + $(PKGROOT)/src/collective/result.o \ $(PKGROOT)/src/collective/allgather.o \ $(PKGROOT)/src/collective/allreduce.o \ $(PKGROOT)/src/collective/broadcast.o \ diff --git a/demo/dask/cpu_training.py b/demo/dask/cpu_training.py index 2bee444f7..7117eddd9 100644 --- a/demo/dask/cpu_training.py +++ b/demo/dask/cpu_training.py @@ -40,7 +40,7 @@ def main(client): # you can pass output directly into `predict` too. prediction = dxgb.predict(client, bst, dtrain) print("Evaluation history:", history) - return prediction + print("Error:", da.sqrt((prediction - y) ** 2).mean().compute()) if __name__ == "__main__": diff --git a/doc/contrib/unit_tests.rst b/doc/contrib/unit_tests.rst index 662a632e2..908e5ed99 100644 --- a/doc/contrib/unit_tests.rst +++ b/doc/contrib/unit_tests.rst @@ -144,6 +144,14 @@ which provides higher flexibility. For example: ctest --verbose +If you need to debug errors on Windows using the debugger from VS, you can append the gtest flags in `test_main.cc`: + +.. code-block:: + + ::testing::GTEST_FLAG(filter) = "Suite.Test"; + ::testing::GTEST_FLAG(repeat) = 10; + + *********************************************** Sanitizers: Detect memory errors and data races *********************************************** diff --git a/include/xgboost/collective/result.h b/include/xgboost/collective/result.h index 919d3a902..23e70a8e6 100644 --- a/include/xgboost/collective/result.h +++ b/include/xgboost/collective/result.h @@ -3,13 +3,11 @@ */ #pragma once -#include - -#include // for unique_ptr -#include // for stringstream -#include // for stack -#include // for string -#include // for move +#include // for int32_t +#include // for unique_ptr +#include // for string +#include // for error_code +#include // for move namespace xgboost::collective { namespace detail { @@ -48,48 +46,19 @@ struct ResultImpl { return cur_eq; } - [[nodiscard]] std::string Report() { - std::stringstream ss; - ss << "\n- " << this->message; - if (this->errc != std::error_code{}) { - ss << " system error:" << this->errc.message(); - } + [[nodiscard]] std::string Report() const; + [[nodiscard]] std::error_code Code() const; - auto ptr = prev.get(); - while (ptr) { - ss << "\n- "; - ss << ptr->message; - - if (ptr->errc != std::error_code{}) { - ss << " " << ptr->errc.message(); - } - ptr = ptr->prev.get(); - } - - return ss.str(); - } - [[nodiscard]] auto Code() const { - // Find the root error. - std::stack stack; - auto ptr = this; - while (ptr) { - stack.push(ptr); - if (ptr->prev) { - ptr = ptr->prev.get(); - } else { - break; - } - } - while (!stack.empty()) { - auto frame = stack.top(); - stack.pop(); - if (frame->errc != std::error_code{}) { - return frame->errc; - } - } - return std::error_code{}; - } + void Concat(std::unique_ptr rhs); }; + +#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__) +#define __builtin_FILE() nullptr +#define __builtin_LINE() (-1) +std::string MakeMsg(std::string&& msg, char const*, std::int32_t); +#else +std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line); +#endif } // namespace detail /** @@ -131,8 +100,21 @@ struct Result { } return *impl_ == *that.impl_; } + + friend Result operator+(Result&& lhs, Result&& rhs); }; +[[nodiscard]] inline Result operator+(Result&& lhs, Result&& rhs) { + if (lhs.OK()) { + return std::forward(rhs); + } + if (rhs.OK()) { + return std::forward(lhs); + } + lhs.impl_->Concat(std::move(rhs.impl_)); + return std::forward(lhs); +} + /** * @brief Return success. */ @@ -140,38 +122,43 @@ struct Result { /** * @brief Return failure. */ -[[nodiscard]] inline auto Fail(std::string msg) { return Result{std::move(msg)}; } +[[nodiscard]] inline auto Fail(std::string msg, char const* file = __builtin_FILE(), + std::int32_t line = __builtin_LINE()) { + return Result{detail::MakeMsg(std::move(msg), file, line)}; +} /** * @brief Return failure with `errno`. */ -[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc) { - return Result{std::move(msg), std::move(errc)}; +[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, + char const* file = __builtin_FILE(), + std::int32_t line = __builtin_LINE()) { + return Result{detail::MakeMsg(std::move(msg), file, line), std::move(errc)}; } /** * @brief Return failure with a previous error. */ -[[nodiscard]] inline auto Fail(std::string msg, Result&& prev) { - return Result{std::move(msg), std::forward(prev)}; +[[nodiscard]] inline auto Fail(std::string msg, Result&& prev, char const* file = __builtin_FILE(), + std::int32_t line = __builtin_LINE()) { + return Result{detail::MakeMsg(std::move(msg), file, line), std::forward(prev)}; } /** * @brief Return failure with a previous error and a new `errno`. */ -[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, Result&& prev) { - return Result{std::move(msg), std::move(errc), std::forward(prev)}; +[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, Result&& prev, + char const* file = __builtin_FILE(), + std::int32_t line = __builtin_LINE()) { + return Result{detail::MakeMsg(std::move(msg), file, line), std::move(errc), + std::forward(prev)}; } // We don't have monad, a simple helper would do. template -[[nodiscard]] Result operator<<(Result&& r, Fn&& fn) { +[[nodiscard]] std::enable_if_t, Result> operator<<(Result&& r, Fn&& fn) { if (!r.OK()) { return std::forward(r); } return fn(); } -inline void SafeColl(Result const& rc) { - if (!rc.OK()) { - LOG(FATAL) << rc.Report(); - } -} +void SafeColl(Result const& rc); } // namespace xgboost::collective diff --git a/include/xgboost/collective/socket.h b/include/xgboost/collective/socket.h index 3bc3b389c..11520eede 100644 --- a/include/xgboost/collective/socket.h +++ b/include/xgboost/collective/socket.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, XGBoost Contributors + * Copyright (c) 2022-2024, XGBoost Contributors */ #pragma once @@ -12,7 +12,6 @@ #include // std::size_t #include // std::int32_t, std::uint16_t #include // memset -#include // std::numeric_limits #include // std::string #include // std::error_code, std::system_category #include // std::swap @@ -468,19 +467,30 @@ class TCPSocket { *addr = SockAddress{SockAddrV6{caddr}}; *out = TCPSocket{newfd}; } + // On MacOS, this is automatically set to async socket if the parent socket is async + // We make sure all socket are blocking by default. + // + // On Windows, a closed socket is returned during shutdown. We guard against it when + // setting non-blocking. + if (!out->IsClosed()) { + return out->NonBlocking(false); + } return Success(); } ~TCPSocket() { if (!IsClosed()) { - Close(); + auto rc = this->Close(); + if (!rc.OK()) { + LOG(WARNING) << rc.Report(); + } } } TCPSocket(TCPSocket const &that) = delete; TCPSocket(TCPSocket &&that) noexcept(true) { std::swap(this->handle_, that.handle_); } TCPSocket &operator=(TCPSocket const &that) = delete; - TCPSocket &operator=(TCPSocket &&that) { + TCPSocket &operator=(TCPSocket &&that) noexcept(true) { std::swap(this->handle_, that.handle_); return *this; } @@ -635,22 +645,26 @@ class TCPSocket { */ std::size_t Recv(std::string *p_str); /** - * \brief Close the socket, called automatically in destructor if the socket is not closed. + * @brief Close the socket, called automatically in destructor if the socket is not closed. */ - void Close() { + Result Close() { if (InvalidSocket() != handle_) { -#if defined(_WIN32) auto rc = system::CloseSocket(handle_); +#if defined(_WIN32) // it's possible that we close TCP sockets after finalizing WSA due to detached thread. if (rc != 0 && system::LastError() != WSANOTINITIALISED) { - system::ThrowAtError("close", rc); + return system::FailWithCode("Failed to close the socket."); } #else - xgboost_CHECK_SYS_CALL(system::CloseSocket(handle_), 0); + if (rc != 0) { + return system::FailWithCode("Failed to close the socket."); + } #endif handle_ = InvalidSocket(); } + return Success(); } + /** * \brief Create a TCP socket on specified domain. */ diff --git a/src/collective/loop.cc b/src/collective/loop.cc index b51749fcd..0cd41426d 100644 --- a/src/collective/loop.cc +++ b/src/collective/loop.cc @@ -18,9 +18,11 @@ #include "xgboost/logging.h" // for CHECK namespace xgboost::collective { -Result Loop::EmptyQueue(std::queue* p_queue) const { +Result Loop::ProcessQueue(std::queue* p_queue, bool blocking) const { timer_.Start(__func__); - auto error = [this] { timer_.Stop(__func__); }; + auto error = [this] { + timer_.Stop(__func__); + }; if (stop_) { timer_.Stop(__func__); @@ -48,6 +50,9 @@ Result Loop::EmptyQueue(std::queue* p_queue) const { poll.WatchWrite(*op.sock); break; } + case Op::kSleep: { + break; + } default: { error(); return Fail("Invalid socket operation."); @@ -59,12 +64,14 @@ Result Loop::EmptyQueue(std::queue* p_queue) const { // poll, work on fds that are ready. timer_.Start("poll"); - auto rc = poll.Poll(timeout_); - timer_.Stop("poll"); - if (!rc.OK()) { - error(); - return rc; + if (!poll.fds.empty()) { + auto rc = poll.Poll(timeout_); + if (!rc.OK()) { + error(); + return rc; + } } + timer_.Stop("poll"); // we wonldn't be here if the queue is empty. CHECK(!qcopy.empty()); @@ -75,12 +82,20 @@ Result Loop::EmptyQueue(std::queue* p_queue) const { qcopy.pop(); std::int32_t n_bytes_done{0}; - CHECK(op.sock->NonBlocking()); + if (!op.sock) { + CHECK(op.code == Op::kSleep); + } else { + CHECK(op.sock->NonBlocking()); + } switch (op.code) { case Op::kRead: { if (poll.CheckRead(*op.sock)) { n_bytes_done = op.sock->Recv(op.ptr + op.off, op.n - op.off); + if (n_bytes_done == 0) { + error(); + return Fail("Encountered EOF. The other end is likely closed."); + } } break; } @@ -90,6 +105,12 @@ Result Loop::EmptyQueue(std::queue* p_queue) const { } break; } + case Op::kSleep: { + // For testing only. + std::this_thread::sleep_for(std::chrono::seconds{op.n}); + n_bytes_done = op.n; + break; + } default: { error(); return Fail("Invalid socket operation."); @@ -110,6 +131,10 @@ Result Loop::EmptyQueue(std::queue* p_queue) const { qcopy.push(op); } } + + if (!blocking) { + break; + } } timer_.Stop(__func__); @@ -128,6 +153,15 @@ void Loop::Process() { while (true) { try { std::unique_lock lock{mu_}; + // This can handle missed notification: wait(lock, predicate) is equivalent to: + // + // while (!predicate()) { + // cv.wait(lock); + // } + // + // As a result, if there's a missed notification, the queue wouldn't be empty, hence + // the predicate would be false and the actual wait wouldn't be invoked. Therefore, + // the blocking call can never go unanswered. cv_.wait(lock, [this] { return !this->queue_.empty() || stop_; }); if (stop_) { break; // only point where this loop can exit. @@ -142,26 +176,27 @@ void Loop::Process() { queue_.pop(); if (op.code == Op::kBlock) { is_blocking = true; - // Block must be the last op in the current batch since no further submit can be - // issued until the blocking call is finished. - CHECK(queue_.empty()); } else { qcopy.push(op); } } - if (!is_blocking) { - // Unblock, we can write to the global queue again. - lock.unlock(); + lock.unlock(); + // Clear the local queue, if `is_blocking` is true, this is blocking the current + // worker thread (but not the client thread), wait until all operations are + // finished. + auto rc = this->ProcessQueue(&qcopy, is_blocking); + + if (is_blocking && rc.OK()) { + CHECK(qcopy.empty()); } - - // Clear the local queue, this is blocking the current worker thread (but not the - // client thread), wait until all operations are finished. - auto rc = this->EmptyQueue(&qcopy); - - if (is_blocking) { - // The unlock is delayed if this is a blocking call - lock.unlock(); + // Push back the remaining operations. + if (rc.OK()) { + std::unique_lock lock{mu_}; + while (!qcopy.empty()) { + queue_.push(qcopy.front()); + qcopy.pop(); + } } // Notify the client thread who called block after all error conditions are set. @@ -228,7 +263,6 @@ Result Loop::Stop() { } this->Submit(Op{Op::kBlock}); - { // Wait for the block call to finish. std::unique_lock lock{mu_}; @@ -243,8 +277,20 @@ Result Loop::Stop() { } } +void Loop::Submit(Op op) { + std::unique_lock lock{mu_}; + if (op.code != Op::kBlock) { + CHECK_NE(op.n, 0); + } + queue_.push(op); + lock.unlock(); + cv_.notify_one(); +} + Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} { timer_.Init(__func__); - worker_ = std::thread{[this] { this->Process(); }}; + worker_ = std::thread{[this] { + this->Process(); + }}; } } // namespace xgboost::collective diff --git a/src/collective/loop.h b/src/collective/loop.h index 4839abfd3..a4de2a81b 100644 --- a/src/collective/loop.h +++ b/src/collective/loop.h @@ -19,20 +19,27 @@ namespace xgboost::collective { class Loop { public: struct Op { - enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2 } code; + // kSleep is only for testing + enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2, kSleep = 4 } code; std::int32_t rank{-1}; std::int8_t* ptr{nullptr}; std::size_t n{0}; TCPSocket* sock{nullptr}; std::size_t off{0}; - explicit Op(Code c) : code{c} { CHECK(c == kBlock); } + explicit Op(Code c) : code{c} { CHECK(c == kBlock || c == kSleep); } Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off) : code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {} Op(Op const&) = default; Op& operator=(Op const&) = default; Op(Op&&) = default; Op& operator=(Op&&) = default; + // For testing purpose only + [[nodiscard]] static Op Sleep(std::size_t seconds) { + Op op{kSleep}; + op.n = seconds; + return op; + } }; private: @@ -54,7 +61,7 @@ class Loop { std::exception_ptr curr_exce_{nullptr}; common::Monitor mutable timer_; - Result EmptyQueue(std::queue* p_queue) const; + Result ProcessQueue(std::queue* p_queue, bool blocking) const; // The cunsumer function that runs inside a worker thread. void Process(); @@ -64,12 +71,7 @@ class Loop { */ Result Stop(); - void Submit(Op op) { - std::unique_lock lock{mu_}; - queue_.push(op); - lock.unlock(); - cv_.notify_one(); - } + void Submit(Op op); /** * @brief Block the event loop until all ops are finished. In the case of failure, this diff --git a/src/collective/result.cc b/src/collective/result.cc new file mode 100644 index 000000000..b11710572 --- /dev/null +++ b/src/collective/result.cc @@ -0,0 +1,86 @@ +/** + * Copyright 2024, XGBoost Contributors + */ +#include "xgboost/collective/result.h" + +#include // for path +#include // for stringstream +#include // for stack + +#include "xgboost/logging.h" + +namespace xgboost::collective { +namespace detail { +[[nodiscard]] std::string ResultImpl::Report() const { + std::stringstream ss; + ss << "\n- " << this->message; + if (this->errc != std::error_code{}) { + ss << " system error:" << this->errc.message(); + } + + auto ptr = prev.get(); + while (ptr) { + ss << "\n- "; + ss << ptr->message; + + if (ptr->errc != std::error_code{}) { + ss << " " << ptr->errc.message(); + } + ptr = ptr->prev.get(); + } + + return ss.str(); +} + +[[nodiscard]] std::error_code ResultImpl::Code() const { + // Find the root error. + std::stack stack; + auto ptr = this; + while (ptr) { + stack.push(ptr); + if (ptr->prev) { + ptr = ptr->prev.get(); + } else { + break; + } + } + while (!stack.empty()) { + auto frame = stack.top(); + stack.pop(); + if (frame->errc != std::error_code{}) { + return frame->errc; + } + } + return std::error_code{}; +} + +void ResultImpl::Concat(std::unique_ptr rhs) { + auto ptr = this; + while (ptr->prev) { + ptr = ptr->prev.get(); + } + ptr->prev = std::move(rhs); +} + +#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__) +std::string MakeMsg(std::string&& msg, char const*, std::int32_t) { + return std::forward(msg); +} +#else +std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line) { + auto name = std::filesystem::path{file}.filename(); + if (file && line != -1) { + return "[" + name.string() + ":" + std::to_string(line) + // NOLINT + "]: " + std::forward(msg); + } + return std::forward(msg); +} +#endif +} // namespace detail + +void SafeColl(Result const& rc) { + if (!rc.OK()) { + LOG(FATAL) << rc.Report(); + } +} +} // namespace xgboost::collective diff --git a/tests/cpp/collective/test_loop.cc b/tests/cpp/collective/test_loop.cc index e5ef987f3..0908d9623 100644 --- a/tests/cpp/collective/test_loop.cc +++ b/tests/cpp/collective/test_loop.cc @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #include // for ASSERT_TRUE, ASSERT_EQ #include // for TCPSocket, Connect, SocketFinalize, SocketStartup @@ -28,18 +28,25 @@ class LoopTest : public ::testing::Test { auto domain = SockDomain::kV4; pair_.first = TCPSocket::Create(domain); - auto port = pair_.first.BindHost(); - pair_.first.Listen(); + in_port_t port{0}; + auto rc = Success() << [&] { + port = pair_.first.BindHost(); + return Success(); + } << [&] { + pair_.first.Listen(); + return Success(); + }; + SafeColl(rc); auto const& addr = SockAddrV4::Loopback().Addr(); - auto rc = Connect(StringView{addr}, port, 1, timeout, &pair_.second); - ASSERT_TRUE(rc.OK()); + rc = Connect(StringView{addr}, port, 1, timeout, &pair_.second); + SafeColl(rc); rc = pair_.second.NonBlocking(true); - ASSERT_TRUE(rc.OK()); + SafeColl(rc); pair_.first = pair_.first.Accept(); rc = pair_.first.NonBlocking(true); - ASSERT_TRUE(rc.OK()); + SafeColl(rc); loop_ = std::shared_ptr{new Loop{timeout}}; } @@ -74,8 +81,26 @@ TEST_F(LoopTest, Op) { loop_->Submit(rop); auto rc = loop_->Block(); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); ASSERT_EQ(rbuf[0], wbuf[0]); } + +TEST_F(LoopTest, Block) { + // We need to ensure that a blocking call doesn't go unanswered. + auto op = Loop::Op::Sleep(2); + + common::Timer t; + t.Start(); + loop_->Submit(op); + t.Stop(); + // submit is non-blocking + ASSERT_LT(t.ElapsedSeconds(), 1); + + t.Start(); + auto rc = loop_->Block(); + t.Stop(); + SafeColl(rc); + ASSERT_GE(t.ElapsedSeconds(), 1); +} } // namespace xgboost::collective diff --git a/tests/cpp/collective/test_result.cc b/tests/cpp/collective/test_result.cc new file mode 100644 index 000000000..1c7194f92 --- /dev/null +++ b/tests/cpp/collective/test_result.cc @@ -0,0 +1,31 @@ +/** + * Copyright 2024, XGBoost Contributors + */ +#include +#include + +namespace xgboost::collective { +TEST(Result, Concat) { + auto rc0 = Fail("foo"); + auto rc1 = Fail("bar"); + auto rc = std::move(rc0) + std::move(rc1); + ASSERT_NE(rc.Report().find("foo"), std::string::npos); + ASSERT_NE(rc.Report().find("bar"), std::string::npos); + + auto rc2 = Fail("Another", std::move(rc)); + auto assert_that = [](Result const& rc) { + ASSERT_NE(rc.Report().find("Another"), std::string::npos); + ASSERT_NE(rc.Report().find("foo"), std::string::npos); + ASSERT_NE(rc.Report().find("bar"), std::string::npos); + }; + assert_that(rc2); + + auto empty = Success(); + auto rc3 = std::move(empty) + std::move(rc2); + assert_that(rc3); + + empty = Success(); + auto rc4 = std::move(rc3) + std::move(empty); + assert_that(rc4); +} +} // namespace xgboost::collective