diff --git a/src/collective/loop.cc b/src/collective/loop.cc index 5cfb0034d..b51749fcd 100644 --- a/src/collective/loop.cc +++ b/src/collective/loop.cc @@ -1,11 +1,19 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #include "loop.h" -#include // for queue +#include // for size_t +#include // for int32_t +#include // for exception, current_exception, rethrow_exception +#include // for lock_guard, unique_lock +#include // for queue +#include // for string +#include // for thread +#include // for move #include "rabit/internal/socket.h" // for PollHelper +#include "xgboost/collective/result.h" // for Fail, Success #include "xgboost/collective/socket.h" // for FailWithCode #include "xgboost/logging.h" // for CHECK @@ -109,62 +117,94 @@ Result Loop::EmptyQueue(std::queue* p_queue) const { } void Loop::Process() { - // consumer - while (true) { - std::unique_lock lock{mu_}; - cv_.wait(lock, [this] { return !this->queue_.empty() || stop_; }); - if (stop_) { - break; - } + auto set_rc = [this](Result&& rc) { + std::lock_guard lock{rc_lock_}; + rc_ = std::forward(rc); + }; + + // This loop cannot exit unless `stop_` is set to true. There must always be a thread to + // answer the blocking call even if there are errors, otherwise the blocking will wait + // forever. + while (true) { + try { + std::unique_lock lock{mu_}; + cv_.wait(lock, [this] { return !this->queue_.empty() || stop_; }); + if (stop_) { + break; // only point where this loop can exit. + } + + // Move the global queue into a local variable to unblock it. + std::queue qcopy; + + bool is_blocking = false; + while (!queue_.empty()) { + auto op = queue_.front(); + queue_.pop(); + if (op.code == Op::kBlock) { + is_blocking = true; + // Block must be the last op in the current batch since no further submit can be + // issued until the blocking call is finished. 
+ CHECK(queue_.empty()); + } else { + qcopy.push(op); + } + } - auto unlock_notify = [&](bool is_blocking, bool stop) { if (!is_blocking) { - std::lock_guard guard{mu_}; - stop_ = stop; - } else { - stop_ = stop; + // Unblock, we can write to the global queue again. lock.unlock(); } - cv_.notify_one(); - }; - // move the queue - std::queue qcopy; - bool is_blocking = false; - while (!queue_.empty()) { - auto op = queue_.front(); - queue_.pop(); - if (op.code == Op::kBlock) { - is_blocking = true; - } else { - qcopy.push(op); + // Clear the local queue, this is blocking the current worker thread (but not the + // client thread), wait until all operations are finished. + auto rc = this->EmptyQueue(&qcopy); + + if (is_blocking) { + // The unlock is delayed if this is a blocking call + lock.unlock(); } - } - // unblock the queue - if (!is_blocking) { - lock.unlock(); - } - // clear the queue - auto rc = this->EmptyQueue(&qcopy); - // Handle error - if (!rc.OK()) { - unlock_notify(is_blocking, true); - std::lock_guard guard{rc_lock_}; - this->rc_ = std::move(rc); - return; - } - CHECK(qcopy.empty()); - unlock_notify(is_blocking, false); + // Notify the client thread who called block after all error conditions are set. + auto notify_if_block = [&] { + if (is_blocking) { + std::unique_lock lock{mu_}; + block_done_ = true; + lock.unlock(); + block_cv_.notify_one(); + } + }; + + // Handle error + if (!rc.OK()) { + set_rc(std::move(rc)); + } else { + CHECK(qcopy.empty()); + } + + notify_if_block(); + } catch (std::exception const& e) { + curr_exce_ = std::current_exception(); + set_rc(Fail("Exception inside the event loop:" + std::string{e.what()})); + } catch (...) 
{ + curr_exce_ = std::current_exception(); + set_rc(Fail("Unknown exception inside the event loop.")); + } } } Result Loop::Stop() { + // Finish all remaining tasks + CHECK_EQ(this->Block().OK(), this->rc_.OK()); + + // Notify the loop to stop std::unique_lock lock{mu_}; stop_ = true; lock.unlock(); + this->cv_.notify_one(); - CHECK_EQ(this->Block().OK(), this->rc_.OK()); + if (this->worker_.joinable()) { + this->worker_.join(); + } if (curr_exce_) { std::rethrow_exception(curr_exce_); @@ -175,17 +215,29 @@ Result Loop::Stop() { [[nodiscard]] Result Loop::Block() { { + // Check whether the last op was successful, stop if not. std::lock_guard guard{rc_lock_}; if (!rc_.OK()) { - return std::move(rc_); + stop_ = true; } } - this->Submit(Op{Op::kBlock}); - { - std::unique_lock lock{mu_}; - cv_.wait(lock, [this] { return (this->queue_.empty()) || stop_; }); + + if (!this->worker_.joinable()) { + std::lock_guard guard{rc_lock_}; + return Fail("Worker has stopped.", std::move(rc_)); } + + this->Submit(Op{Op::kBlock}); + { + // Wait for the block call to finish. + std::unique_lock lock{mu_}; + block_cv_.wait(lock, [this] { return block_done_ || stop_; }); + block_done_ = false; + } + + { + // Transfer the rc. std::lock_guard lock{rc_lock_}; return std::move(rc_); } @@ -193,26 +245,6 @@ Result Loop::Stop() { Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} { timer_.Init(__func__); - worker_ = std::thread{[this] { - try { - this->Process(); - } catch (std::exception const& e) { - std::lock_guard guard{mu_}; - if (!curr_exce_) { - curr_exce_ = std::current_exception(); - rc_ = Fail("Exception was thrown"); - } - stop_ = true; - cv_.notify_all(); - } catch (...) 
{ - std::lock_guard guard{mu_}; - if (!curr_exce_) { - curr_exce_ = std::current_exception(); - rc_ = Fail("Exception was thrown"); - } - stop_ = true; - cv_.notify_all(); - } - }}; + worker_ = std::thread{[this] { this->Process(); }}; } } // namespace xgboost::collective diff --git a/src/collective/loop.h b/src/collective/loop.h index 0c1fdcbfe..4839abfd3 100644 --- a/src/collective/loop.h +++ b/src/collective/loop.h @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #pragma once #include // for seconds @@ -10,7 +10,6 @@ #include // for unique_lock, mutex #include // for queue #include // for thread -#include // for move #include "../common/timer.h" // for Monitor #include "xgboost/collective/result.h" // for Result @@ -37,10 +36,15 @@ class Loop { }; private: - std::thread worker_; - std::condition_variable cv_; - std::mutex mu_; - std::queue queue_; + std::thread worker_; // thread worker to execute the tasks + + std::condition_variable cv_; // CV used to notify a new submit call + std::condition_variable block_cv_; // CV used to notify the blocking call + bool block_done_{false}; // Flag to indicate whether the blocking call has finished. + + std::queue queue_; // event queue + std::mutex mu_; // mutex to protect the queue, cv, and block_done + std::chrono::seconds timeout_; Result rc_; @@ -51,29 +55,33 @@ class Loop { common::Monitor mutable timer_; Result EmptyQueue(std::queue* p_queue) const; + // The consumer function that runs inside a worker thread. void Process(); public: + /** + * @brief Stop the worker thread. + */ Result Stop(); void Submit(Op op) { - // producer std::unique_lock lock{mu_}; queue_.push(op); lock.unlock(); cv_.notify_one(); } + /** + * @brief Block the event loop until all ops are finished. In the case of failure, this + * loop should not be used for new operations. 
+ */ [[nodiscard]] Result Block(); explicit Loop(std::chrono::seconds timeout); ~Loop() noexcept(false) { + // The worker will be joined in the stop function. this->Stop(); - - if (worker_.joinable()) { - worker_.join(); - } } }; } // namespace xgboost::collective