From 38ac52dd87bb47f4f61e0249831f377181eff83a Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 20 Sep 2023 02:09:07 +0800 Subject: [PATCH] Build a simple event loop for collective. (#9593) --- rabit/include/rabit/internal/io.h | 7 +- rabit/include/rabit/internal/socket.h | 46 +++++-- rabit/src/allreduce_base.cc | 3 +- src/collective/loop.cc | 167 ++++++++++++++++++++++++++ src/collective/loop.h | 83 +++++++++++++ src/collective/socket.cc | 62 +++++----- tests/cpp/collective/test_loop.cc | 81 +++++++++++++ 7 files changed, 402 insertions(+), 47 deletions(-) create mode 100644 src/collective/loop.cc create mode 100644 src/collective/loop.h create mode 100644 tests/cpp/collective/test_loop.cc diff --git a/rabit/include/rabit/internal/io.h b/rabit/include/rabit/internal/io.h index d93f32ff9..d5d0fee4d 100644 --- a/rabit/include/rabit/internal/io.h +++ b/rabit/include/rabit/internal/io.h @@ -16,8 +16,8 @@ #include #include -#include "rabit/internal/utils.h" -#include "rabit/serializable.h" +#include "dmlc/io.h" +#include "xgboost/logging.h" namespace rabit::utils { /*! \brief re-use definition of dmlc::SeekStream */ @@ -84,8 +84,7 @@ struct MemoryBufferStream : public SeekStream { } ~MemoryBufferStream() override = default; size_t Read(void *ptr, size_t size) override { - utils::Assert(curr_ptr_ <= p_buffer_->length(), - "read can not have position excceed buffer length"); + CHECK_LE(curr_ptr_, p_buffer_->length()) << "read can not have position excceed buffer length"; size_t nread = std::min(p_buffer_->length() - curr_ptr_, size); if (nread != 0) std::memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread); curr_ptr_ += nread; diff --git a/rabit/include/rabit/internal/socket.h b/rabit/include/rabit/internal/socket.h index 6fb7fe725..f1a6699fb 100644 --- a/rabit/include/rabit/internal/socket.h +++ b/rabit/include/rabit/internal/socket.h @@ -29,11 +29,10 @@ #include #include #include +#include // make_error_code, errc #include #include -#include "utils.h" - #if !defined(_WIN32) #include @@ -93,6 +92,20 @@ int PollImpl(PollFD* pfd, int nfds, std::chrono::seconds timeout) noexcept(true) #endif // IS_MINGW() } +template +std::enable_if_t, xgboost::collective::Result> PollError(E const& revents) { + if ((revents & POLLERR) != 0) { + return xgboost::system::FailWithCode("Poll error condition."); + } + if ((revents & POLLNVAL) != 0) { + return xgboost::system::FailWithCode("Invalid polling request."); + } + if ((revents & POLLHUP) != 0) { + return xgboost::system::FailWithCode("Poll hung up."); + } + return xgboost::collective::Success(); +} + /*! \brief helper data structure to perform poll */ struct PollHelper { public: @@ -160,25 +173,32 @@ struct PollHelper { * * @param timeout specify timeout in seconds. Block if negative. */ - [[nodiscard]] xgboost::collective::Result Poll(std::chrono::seconds timeout) { + [[nodiscard]] xgboost::collective::Result Poll(std::chrono::seconds timeout, + bool check_error = true) { std::vector fdset; fdset.reserve(fds.size()); for (auto kv : fds) { fdset.push_back(kv.second); } - int ret = PollImpl(fdset.data(), fdset.size(), timeout); + std::int32_t ret = PollImpl(fdset.data(), fdset.size(), timeout); if (ret == 0) { - return xgboost::collective::Fail("Poll timeout."); + return xgboost::collective::Fail("Poll timeout.", std::make_error_code(std::errc::timed_out)); } else if (ret < 0) { return xgboost::system::FailWithCode("Poll failed."); - } else { - for (auto& pfd : fdset) { - auto revents = pfd.revents & pfd.events; - if (!revents) { - fds.erase(pfd.fd); - } else { - fds[pfd.fd].events = revents; - } + } + + for (auto& pfd : fdset) { + auto result = PollError(pfd.revents); + if (check_error && !result.OK()) { + return result; + } + + auto revents = pfd.revents & pfd.events; + if (!revents) { + // FIXME(jiamingy): remove this once rabit is replaced. + fds.erase(pfd.fd); + } else { + fds[pfd.fd].events = revents; } } return xgboost::collective::Success(); diff --git a/rabit/src/allreduce_base.cc b/rabit/src/allreduce_base.cc index 6480adf03..04246b5a1 100644 --- a/rabit/src/allreduce_base.cc +++ b/rabit/src/allreduce_base.cc @@ -721,12 +721,11 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) { } finished = false; } - watcher.WatchException(links[i].sock); } // finish running if (finished) break; // select - auto poll_res = watcher.Poll(timeout_sec); + auto poll_res = watcher.Poll(timeout_sec, false); // fail on macos if (!poll_res.OK()) { LOG(FATAL) << poll_res.Report(); } diff --git a/src/collective/loop.cc b/src/collective/loop.cc new file mode 100644 index 000000000..95a1019ac --- /dev/null +++ b/src/collective/loop.cc @@ -0,0 +1,167 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#include "loop.h" + +#include // for queue + +#include "rabit/internal/socket.h" // for PollHelper +#include "xgboost/collective/socket.h" // for FailWithCode +#include "xgboost/logging.h" // for CHECK + +namespace xgboost::collective { +Result Loop::EmptyQueue() { + timer_.Start(__func__); + auto error = [this] { + this->stop_ = true; + timer_.Stop(__func__); + }; + + while (!queue_.empty() && !stop_) { + std::queue qcopy; + rabit::utils::PollHelper poll; + + // watch all ops + while (!queue_.empty()) { + auto op = queue_.front(); + queue_.pop(); + + switch (op.code) { + case Op::kRead: { + poll.WatchRead(*op.sock); + break; + } + case Op::kWrite: { + poll.WatchWrite(*op.sock); + break; + } + default: { + error(); + return Fail("Invalid socket operation."); + } + } + qcopy.push(op); + } + + // poll, work on fds that are ready. + timer_.Start("poll"); + auto rc = poll.Poll(timeout_); + timer_.Stop("poll"); + if (!rc.OK()) { + error(); + return rc; + } + // we wonldn't be here if the queue is empty. + CHECK(!qcopy.empty()); + + while (!qcopy.empty() && !stop_) { + auto op = qcopy.front(); + qcopy.pop(); + + std::int32_t n_bytes_done{0}; + CHECK(op.sock->NonBlocking()); + + switch (op.code) { + case Op::kRead: { + if (poll.CheckRead(*op.sock)) { + n_bytes_done = op.sock->Recv(op.ptr + op.off, op.n - op.off); + } + break; + } + case Op::kWrite: { + if (poll.CheckWrite(*op.sock)) { + n_bytes_done = op.sock->Send(op.ptr + op.off, op.n - op.off); + } + break; + } + default: { + error(); + return Fail("Invalid socket operation."); + } + } + + if (n_bytes_done == -1 && !system::LastErrorWouldBlock()) { + stop_ = true; + auto rc = system::FailWithCode("Invalid socket output."); + error(); + return rc; + } + op.off += n_bytes_done; + CHECK_LE(op.off, op.n); + + if (op.off != op.n) { + // not yet finished, push back to queue for next round. + queue_.push(op); + } + } + } + timer_.Stop(__func__); + return Success(); +} + +void Loop::Process() { + // consumer + while (true) { + std::unique_lock lock{mu_}; + cv_.wait(lock, [this] { return !this->queue_.empty() || stop_; }); + if (stop_) { + break; + } + CHECK(!mu_.try_lock()); + + this->rc_ = this->EmptyQueue(); + if (!rc_.OK()) { + stop_ = true; + cv_.notify_one(); + break; + } + + CHECK(queue_.empty()); + CHECK(!mu_.try_lock()); + cv_.notify_one(); + } + + if (rc_.OK()) { + CHECK(queue_.empty()); + } +} + +Result Loop::Stop() { + std::unique_lock lock{mu_}; + stop_ = true; + lock.unlock(); + + CHECK_EQ(this->Block().OK(), this->rc_.OK()); + + if (curr_exce_) { + std::rethrow_exception(curr_exce_); + } + + return Success(); +} + +Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} { + timer_.Init(__func__); + worker_ = std::thread{[this] { + try { + this->Process(); + } catch (std::exception const& e) { + std::lock_guard guard{mu_}; + if (!curr_exce_) { + curr_exce_ = std::current_exception(); + rc_ = Fail("Exception was thrown"); + } + stop_ = true; + cv_.notify_all(); + } catch (...) { + std::lock_guard guard{mu_}; + if (!curr_exce_) { + curr_exce_ = std::current_exception(); + rc_ = Fail("Exception was thrown"); + } + stop_ = true; + cv_.notify_all(); + } + }}; +} +} // namespace xgboost::collective diff --git a/src/collective/loop.h b/src/collective/loop.h new file mode 100644 index 000000000..0bccbc0d0 --- /dev/null +++ b/src/collective/loop.h @@ -0,0 +1,83 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#pragma once +#include // for seconds +#include // for condition_variable +#include // for size_t +#include // for int8_t, int32_t +#include // for exception_ptr +#include // for unique_lock, mutex +#include // for queue +#include // for thread +#include // for move + +#include "../common/timer.h" // for Monitor +#include "xgboost/collective/result.h" // for Result +#include "xgboost/collective/socket.h" // for TCPSocket + +namespace xgboost::collective { +class Loop { + public: + struct Op { + enum Code : std::int8_t { kRead = 0, kWrite = 1 } code; + std::int32_t rank{-1}; + std::int8_t* ptr{nullptr}; + std::size_t n{0}; + TCPSocket* sock{nullptr}; + std::size_t off{0}; + + Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off) + : code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {} + Op(Op const&) = default; + Op& operator=(Op const&) = default; + Op(Op&&) = default; + Op& operator=(Op&&) = default; + }; + + private: + std::thread worker_; + std::condition_variable cv_; + std::mutex mu_; + std::queue queue_; + std::chrono::seconds timeout_; + Result rc_; + bool stop_{false}; + std::exception_ptr curr_exce_{nullptr}; + common::Monitor timer_; + + Result EmptyQueue(); + void Process(); + + public: + Result Stop(); + + void Submit(Op op) { + // producer + std::unique_lock lock{mu_}; + queue_.push(op); + lock.unlock(); + cv_.notify_one(); + } + + [[nodiscard]] Result Block() { + { + std::unique_lock lock{mu_}; + cv_.notify_all(); + } + std::unique_lock lock{mu_}; + cv_.wait(lock, [this] { return this->queue_.empty() || stop_; }); + return std::move(rc_); + } + + explicit Loop(std::chrono::seconds timeout); + + ~Loop() noexcept(false) { + this->Stop(); + + if (worker_.joinable()) { + worker_.join(); + } + } +}; +} // namespace xgboost::collective diff --git a/src/collective/socket.cc b/src/collective/socket.cc index 8ca936ff3..43da366bd 100644 --- a/src/collective/socket.cc +++ b/src/collective/socket.cc @@ -118,36 +118,42 @@ std::size_t TCPSocket::Recv(std::string *p_str) { } auto rc = connect(conn.Handle(), addr_handle, addr_len); - if (rc != 0) { - auto errcode = system::LastError(); - if (!system::ErrorWouldBlock(errcode)) { - log_failure(Fail("connect failed.", std::error_code{errcode, std::system_category()}), - __FILE__, __LINE__); - continue; - } - - rabit::utils::PollHelper poll; - poll.WatchWrite(conn); - auto result = poll.Poll(timeout); - if (!result.OK()) { - log_failure(std::move(result), __FILE__, __LINE__); - continue; - } - if (!poll.CheckWrite(conn)) { - log_failure(Fail("poll failed.", std::error_code{errcode, std::system_category()}), - __FILE__, __LINE__); - continue; - } - result = conn.GetSockError(); - if (!result.OK()) { - log_failure(std::move(result), __FILE__, __LINE__); - continue; - } - - return conn.NonBlocking(non_blocking); - } else { + if (rc == 0) { return conn.NonBlocking(non_blocking); } + + auto errcode = system::LastError(); + if (!system::ErrorWouldBlock(errcode)) { + log_failure(Fail("connect failed.", std::error_code{errcode, std::system_category()}), + __FILE__, __LINE__); + continue; + } + + rabit::utils::PollHelper poll; + poll.WatchWrite(conn); + auto result = poll.Poll(timeout); + if (!result.OK()) { + // poll would fail if there's a socket error, we log the root cause instead of the + // poll failure. + auto sockerr = conn.GetSockError(); + if (!sockerr.OK()) { + result = std::move(sockerr); + } + log_failure(std::move(result), __FILE__, __LINE__); + continue; + } + if (!poll.CheckWrite(conn)) { + log_failure(Fail("poll failed.", std::error_code{errcode, std::system_category()}), __FILE__, + __LINE__); + continue; + } + result = conn.GetSockError(); + if (!result.OK()) { + log_failure(std::move(result), __FILE__, __LINE__); + continue; + } + + return conn.NonBlocking(non_blocking); } std::stringstream ss; diff --git a/tests/cpp/collective/test_loop.cc b/tests/cpp/collective/test_loop.cc new file mode 100644 index 000000000..4686060ce --- /dev/null +++ b/tests/cpp/collective/test_loop.cc @@ -0,0 +1,81 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#include // for ASSERT_TRUE, ASSERT_EQ +#include // for TCPSocket, Connect, SocketFinalize, SocketStartup +#include // for StringView + +#include // for seconds +#include // for int8_t +#include // for make_shared, shared_ptr +#include // for make_error_code, errc +#include // for pair +#include // for vector + +#include "../../../src/collective/loop.h" // for Loop + +namespace xgboost::collective { +namespace { +class LoopTest : public ::testing::Test { + protected: + std::pair pair_; + std::shared_ptr loop_; + + protected: + void SetUp() override { + system::SocketStartup(); + std::chrono::seconds timeout{1}; + + auto domain = SockDomain::kV4; + pair_.first = TCPSocket::Create(domain); + auto port = pair_.first.BindHost(); + pair_.first.Listen(); + + auto const& addr = SockAddrV4::Loopback().Addr(); + auto rc = Connect(StringView{addr}, port, 1, timeout, &pair_.second); + ASSERT_TRUE(rc.OK()); + rc = pair_.second.NonBlocking(true); + ASSERT_TRUE(rc.OK()); + + pair_.first = pair_.first.Accept(); + rc = pair_.first.NonBlocking(true); + ASSERT_TRUE(rc.OK()); + + loop_ = std::make_shared(timeout); + } + + void TearDown() override { + pair_ = decltype(pair_){}; + system::SocketFinalize(); + } +}; +} // namespace + +TEST_F(LoopTest, Timeout) { + std::vector data(1); + Loop::Op op{Loop::Op::kRead, 0, data.data(), data.size(), &pair_.second, 0}; + loop_->Submit(op); + auto rc = loop_->Block(); + ASSERT_FALSE(rc.OK()); + ASSERT_EQ(rc.Code(), std::make_error_code(std::errc::timed_out)) << rc.Report(); +} + +TEST_F(LoopTest, Op) { + TCPSocket& send = pair_.first; + TCPSocket& recv = pair_.second; + + std::vector wbuf(1, 1); + std::vector rbuf(1, 0); + + Loop::Op wop{Loop::Op::kWrite, 0, wbuf.data(), wbuf.size(), &send, 0}; + Loop::Op rop{Loop::Op::kRead, 0, rbuf.data(), rbuf.size(), &recv, 0}; + + loop_->Submit(wop); + loop_->Submit(rop); + + auto rc = loop_->Block(); + ASSERT_TRUE(rc.OK()) << rc.Report(); + + ASSERT_EQ(rbuf[0], wbuf[0]); +} +} // namespace xgboost::collective