[coll] Improve event loop. (#10199)
- Add a test for blocking calls. - Do not require the queue to be empty after waking up; this frees up the thread to answer blocking calls. - Handle EOF in read. - Improve the error message in the result. Allow concatenation of multiple results.
This commit is contained in:
@@ -18,9 +18,11 @@
|
||||
#include "xgboost/logging.h" // for CHECK
|
||||
|
||||
namespace xgboost::collective {
|
||||
Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
|
||||
Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
|
||||
timer_.Start(__func__);
|
||||
auto error = [this] { timer_.Stop(__func__); };
|
||||
auto error = [this] {
|
||||
timer_.Stop(__func__);
|
||||
};
|
||||
|
||||
if (stop_) {
|
||||
timer_.Stop(__func__);
|
||||
@@ -48,6 +50,9 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
|
||||
poll.WatchWrite(*op.sock);
|
||||
break;
|
||||
}
|
||||
case Op::kSleep: {
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
error();
|
||||
return Fail("Invalid socket operation.");
|
||||
@@ -59,12 +64,14 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
|
||||
|
||||
// poll, work on fds that are ready.
|
||||
timer_.Start("poll");
|
||||
auto rc = poll.Poll(timeout_);
|
||||
timer_.Stop("poll");
|
||||
if (!rc.OK()) {
|
||||
error();
|
||||
return rc;
|
||||
if (!poll.fds.empty()) {
|
||||
auto rc = poll.Poll(timeout_);
|
||||
if (!rc.OK()) {
|
||||
error();
|
||||
return rc;
|
||||
}
|
||||
}
|
||||
timer_.Stop("poll");
|
||||
|
||||
// we wonldn't be here if the queue is empty.
|
||||
CHECK(!qcopy.empty());
|
||||
@@ -75,12 +82,20 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
|
||||
qcopy.pop();
|
||||
|
||||
std::int32_t n_bytes_done{0};
|
||||
CHECK(op.sock->NonBlocking());
|
||||
if (!op.sock) {
|
||||
CHECK(op.code == Op::kSleep);
|
||||
} else {
|
||||
CHECK(op.sock->NonBlocking());
|
||||
}
|
||||
|
||||
switch (op.code) {
|
||||
case Op::kRead: {
|
||||
if (poll.CheckRead(*op.sock)) {
|
||||
n_bytes_done = op.sock->Recv(op.ptr + op.off, op.n - op.off);
|
||||
if (n_bytes_done == 0) {
|
||||
error();
|
||||
return Fail("Encountered EOF. The other end is likely closed.");
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
@@ -90,6 +105,12 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Op::kSleep: {
|
||||
// For testing only.
|
||||
std::this_thread::sleep_for(std::chrono::seconds{op.n});
|
||||
n_bytes_done = op.n;
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
error();
|
||||
return Fail("Invalid socket operation.");
|
||||
@@ -110,6 +131,10 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
|
||||
qcopy.push(op);
|
||||
}
|
||||
}
|
||||
|
||||
if (!blocking) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
timer_.Stop(__func__);
|
||||
@@ -128,6 +153,15 @@ void Loop::Process() {
|
||||
while (true) {
|
||||
try {
|
||||
std::unique_lock lock{mu_};
|
||||
// This can handle missed notification: wait(lock, predicate) is equivalent to:
|
||||
//
|
||||
// while (!predicate()) {
|
||||
// cv.wait(lock);
|
||||
// }
|
||||
//
|
||||
// As a result, if there's a missed notification, the queue wouldn't be empty, hence
|
||||
// the predicate would be false and the actual wait wouldn't be invoked. Therefore,
|
||||
// the blocking call can never go unanswered.
|
||||
cv_.wait(lock, [this] { return !this->queue_.empty() || stop_; });
|
||||
if (stop_) {
|
||||
break; // only point where this loop can exit.
|
||||
@@ -142,26 +176,27 @@ void Loop::Process() {
|
||||
queue_.pop();
|
||||
if (op.code == Op::kBlock) {
|
||||
is_blocking = true;
|
||||
// Block must be the last op in the current batch since no further submit can be
|
||||
// issued until the blocking call is finished.
|
||||
CHECK(queue_.empty());
|
||||
} else {
|
||||
qcopy.push(op);
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_blocking) {
|
||||
// Unblock, we can write to the global queue again.
|
||||
lock.unlock();
|
||||
lock.unlock();
|
||||
// Clear the local queue, if `is_blocking` is true, this is blocking the current
|
||||
// worker thread (but not the client thread), wait until all operations are
|
||||
// finished.
|
||||
auto rc = this->ProcessQueue(&qcopy, is_blocking);
|
||||
|
||||
if (is_blocking && rc.OK()) {
|
||||
CHECK(qcopy.empty());
|
||||
}
|
||||
|
||||
// Clear the local queue, this is blocking the current worker thread (but not the
|
||||
// client thread), wait until all operations are finished.
|
||||
auto rc = this->EmptyQueue(&qcopy);
|
||||
|
||||
if (is_blocking) {
|
||||
// The unlock is delayed if this is a blocking call
|
||||
lock.unlock();
|
||||
// Push back the remaining operations.
|
||||
if (rc.OK()) {
|
||||
std::unique_lock lock{mu_};
|
||||
while (!qcopy.empty()) {
|
||||
queue_.push(qcopy.front());
|
||||
qcopy.pop();
|
||||
}
|
||||
}
|
||||
|
||||
// Notify the client thread who called block after all error conditions are set.
|
||||
@@ -228,7 +263,6 @@ Result Loop::Stop() {
|
||||
}
|
||||
|
||||
this->Submit(Op{Op::kBlock});
|
||||
|
||||
{
|
||||
// Wait for the block call to finish.
|
||||
std::unique_lock lock{mu_};
|
||||
@@ -243,8 +277,20 @@ Result Loop::Stop() {
|
||||
}
|
||||
}
|
||||
|
||||
void Loop::Submit(Op op) {
|
||||
std::unique_lock lock{mu_};
|
||||
if (op.code != Op::kBlock) {
|
||||
CHECK_NE(op.n, 0);
|
||||
}
|
||||
queue_.push(op);
|
||||
lock.unlock();
|
||||
cv_.notify_one();
|
||||
}
|
||||
|
||||
Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} {
|
||||
timer_.Init(__func__);
|
||||
worker_ = std::thread{[this] { this->Process(); }};
|
||||
worker_ = std::thread{[this] {
|
||||
this->Process();
|
||||
}};
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -19,20 +19,27 @@ namespace xgboost::collective {
|
||||
class Loop {
|
||||
public:
|
||||
struct Op {
|
||||
enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2 } code;
|
||||
// kSleep is only for testing
|
||||
enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2, kSleep = 4 } code;
|
||||
std::int32_t rank{-1};
|
||||
std::int8_t* ptr{nullptr};
|
||||
std::size_t n{0};
|
||||
TCPSocket* sock{nullptr};
|
||||
std::size_t off{0};
|
||||
|
||||
explicit Op(Code c) : code{c} { CHECK(c == kBlock); }
|
||||
explicit Op(Code c) : code{c} { CHECK(c == kBlock || c == kSleep); }
|
||||
Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off)
|
||||
: code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {}
|
||||
Op(Op const&) = default;
|
||||
Op& operator=(Op const&) = default;
|
||||
Op(Op&&) = default;
|
||||
Op& operator=(Op&&) = default;
|
||||
// For testing purpose only
|
||||
[[nodiscard]] static Op Sleep(std::size_t seconds) {
|
||||
Op op{kSleep};
|
||||
op.n = seconds;
|
||||
return op;
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
@@ -54,7 +61,7 @@ class Loop {
|
||||
std::exception_ptr curr_exce_{nullptr};
|
||||
common::Monitor mutable timer_;
|
||||
|
||||
Result EmptyQueue(std::queue<Op>* p_queue) const;
|
||||
Result ProcessQueue(std::queue<Op>* p_queue, bool blocking) const;
|
||||
// The cunsumer function that runs inside a worker thread.
|
||||
void Process();
|
||||
|
||||
@@ -64,12 +71,7 @@ class Loop {
|
||||
*/
|
||||
Result Stop();
|
||||
|
||||
void Submit(Op op) {
|
||||
std::unique_lock lock{mu_};
|
||||
queue_.push(op);
|
||||
lock.unlock();
|
||||
cv_.notify_one();
|
||||
}
|
||||
void Submit(Op op);
|
||||
|
||||
/**
|
||||
* @brief Block the event loop until all ops are finished. In the case of failure, this
|
||||
|
||||
86
src/collective/result.cc
Normal file
86
src/collective/result.cc
Normal file
@@ -0,0 +1,86 @@
|
||||
/**
|
||||
* Copyright 2024, XGBoost Contributors
|
||||
*/
|
||||
#include "xgboost/collective/result.h"
|
||||
|
||||
#include <filesystem> // for path
|
||||
#include <sstream> // for stringstream
|
||||
#include <stack> // for stack
|
||||
|
||||
#include "xgboost/logging.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace detail {
|
||||
[[nodiscard]] std::string ResultImpl::Report() const {
|
||||
std::stringstream ss;
|
||||
ss << "\n- " << this->message;
|
||||
if (this->errc != std::error_code{}) {
|
||||
ss << " system error:" << this->errc.message();
|
||||
}
|
||||
|
||||
auto ptr = prev.get();
|
||||
while (ptr) {
|
||||
ss << "\n- ";
|
||||
ss << ptr->message;
|
||||
|
||||
if (ptr->errc != std::error_code{}) {
|
||||
ss << " " << ptr->errc.message();
|
||||
}
|
||||
ptr = ptr->prev.get();
|
||||
}
|
||||
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
[[nodiscard]] std::error_code ResultImpl::Code() const {
|
||||
// Find the root error.
|
||||
std::stack<ResultImpl const*> stack;
|
||||
auto ptr = this;
|
||||
while (ptr) {
|
||||
stack.push(ptr);
|
||||
if (ptr->prev) {
|
||||
ptr = ptr->prev.get();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
while (!stack.empty()) {
|
||||
auto frame = stack.top();
|
||||
stack.pop();
|
||||
if (frame->errc != std::error_code{}) {
|
||||
return frame->errc;
|
||||
}
|
||||
}
|
||||
return std::error_code{};
|
||||
}
|
||||
|
||||
void ResultImpl::Concat(std::unique_ptr<ResultImpl> rhs) {
|
||||
auto ptr = this;
|
||||
while (ptr->prev) {
|
||||
ptr = ptr->prev.get();
|
||||
}
|
||||
ptr->prev = std::move(rhs);
|
||||
}
|
||||
|
||||
#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__)
|
||||
std::string MakeMsg(std::string&& msg, char const*, std::int32_t) {
|
||||
return std::forward<std::string>(msg);
|
||||
}
|
||||
#else
|
||||
std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line) {
|
||||
auto name = std::filesystem::path{file}.filename();
|
||||
if (file && line != -1) {
|
||||
return "[" + name.string() + ":" + std::to_string(line) + // NOLINT
|
||||
"]: " + std::forward<std::string>(msg);
|
||||
}
|
||||
return std::forward<std::string>(msg);
|
||||
}
|
||||
#endif
|
||||
} // namespace detail
|
||||
|
||||
void SafeColl(Result const& rc) {
|
||||
if (!rc.OK()) {
|
||||
LOG(FATAL) << rc.Report();
|
||||
}
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
Reference in New Issue
Block a user