Build a simple event loop for collective. (#9593)
This commit is contained in:
167
src/collective/loop.cc
Normal file
167
src/collective/loop.cc
Normal file
@@ -0,0 +1,167 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include "loop.h"
|
||||
|
||||
#include <queue> // for queue
|
||||
|
||||
#include "rabit/internal/socket.h" // for PollHelper
|
||||
#include "xgboost/collective/socket.h" // for FailWithCode
|
||||
#include "xgboost/logging.h" // for CHECK
|
||||
|
||||
namespace xgboost::collective {
|
||||
/**
 * @brief Work through the pending operations until each one has transferred all of
 *        its bytes, an error occurs, or a stop is requested.
 *
 * Every round registers all queued sockets with a poll helper, polls once using
 * `timeout_`, and then attempts a single Recv/Send on each socket that is reported
 * ready. Operations that are still incomplete are pushed back to `queue_` for the
 * next round.
 *
 * @return Success() when the loop exits normally. Otherwise the failing Result from
 *         the poll, from an invalid operation code, or from a socket error; in all
 *         of these cases `stop_` is set before returning.
 */
Result Loop::EmptyQueue() {
  // `__func__` evaluated inside a lambda names the closure's call operator rather than
  // this function, so the label is captured once here to keep Start/Stop paired.
  char const* const timer_label = __func__;
  timer_.Start(timer_label);
  auto error = [this, timer_label] {
    this->stop_ = true;
    timer_.Stop(timer_label);
  };

  while (!queue_.empty() && !stop_) {
    std::queue<Op> qcopy;
    rabit::utils::PollHelper poll;

    // watch all ops
    while (!queue_.empty()) {
      auto op = queue_.front();
      queue_.pop();

      switch (op.code) {
        case Op::kRead: {
          poll.WatchRead(*op.sock);
          break;
        }
        case Op::kWrite: {
          poll.WatchWrite(*op.sock);
          break;
        }
        default: {
          error();
          return Fail("Invalid socket operation.");
        }
      }
      qcopy.push(op);
    }

    // poll, work on fds that are ready.
    timer_.Start("poll");
    auto rc = poll.Poll(timeout_);
    timer_.Stop("poll");
    if (!rc.OK()) {
      error();
      return rc;
    }
    // we wouldn't be here if the queue is empty.
    CHECK(!qcopy.empty());

    while (!qcopy.empty() && !stop_) {
      auto op = qcopy.front();
      qcopy.pop();

      // Stays 0 when the socket was not reported ready in this round.
      // NOTE(review): a result of 0 from a ready socket is also treated as "no progress"
      // and the op is retried; confirm how a closed peer connection should be handled.
      std::int32_t n_bytes_done{0};
      CHECK(op.sock->NonBlocking());

      switch (op.code) {
        case Op::kRead: {
          if (poll.CheckRead(*op.sock)) {
            n_bytes_done = op.sock->Recv(op.ptr + op.off, op.n - op.off);
          }
          break;
        }
        case Op::kWrite: {
          if (poll.CheckWrite(*op.sock)) {
            n_bytes_done = op.sock->Send(op.ptr + op.off, op.n - op.off);
          }
          break;
        }
        default: {
          error();
          return Fail("Invalid socket operation.");
        }
      }

      if (n_bytes_done == -1) {
        if (!system::LastErrorWouldBlock()) {
          // Build the failure before calling error(); presumably FailWithCode reads the
          // last system error, which later calls could overwrite.
          auto failure = system::FailWithCode("Invalid socket output.");
          error();
          return failure;
        }
        // The call would have blocked: nothing was transferred. The -1 must not reach the
        // offset below, otherwise the offset moves backwards (or wraps around at zero).
        n_bytes_done = 0;
      }

      op.off += n_bytes_done;
      CHECK_LE(op.off, op.n);

      if (op.off != op.n) {
        // not yet finished, push back to queue for next round.
        queue_.push(op);
      }
    }
  }

  timer_.Stop(timer_label);
  return Success();
}
|
||||
|
||||
void Loop::Process() {
|
||||
// consumer
|
||||
while (true) {
|
||||
std::unique_lock lock{mu_};
|
||||
cv_.wait(lock, [this] { return !this->queue_.empty() || stop_; });
|
||||
if (stop_) {
|
||||
break;
|
||||
}
|
||||
CHECK(!mu_.try_lock());
|
||||
|
||||
this->rc_ = this->EmptyQueue();
|
||||
if (!rc_.OK()) {
|
||||
stop_ = true;
|
||||
cv_.notify_one();
|
||||
break;
|
||||
}
|
||||
|
||||
CHECK(queue_.empty());
|
||||
CHECK(!mu_.try_lock());
|
||||
cv_.notify_one();
|
||||
}
|
||||
|
||||
if (rc_.OK()) {
|
||||
CHECK(queue_.empty());
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * @brief Request the worker to stop and surface any exception it recorded.
 *
 * Sets `stop_`, then calls Block(), which notifies the worker and returns the stored
 * result. The worker thread itself is not joined here.
 *
 * @return Success() when no exception was recorded.
 * @throws Rethrows the exception stored in `curr_exce_`, if there is one.
 */
Result Loop::Stop() {
  bool stored_ok{true};
  {
    std::lock_guard<std::mutex> guard{mu_};
    stop_ = true;
    // Read the stored status before Block() moves the result out. Doing both inside a
    // single expression leaves the order of the read and the move unspecified.
    stored_ok = rc_.OK();
  }

  // With `stop_` set, the wait inside Block() is satisfied immediately.
  auto rc = this->Block();
  CHECK_EQ(rc.OK(), stored_ok);

  // `curr_exce_` is written by the worker under `mu_`, so take a copy under the same lock.
  std::exception_ptr pending{nullptr};
  {
    std::lock_guard<std::mutex> guard{mu_};
    pending = curr_exce_;
  }
  if (pending) {
    std::rethrow_exception(pending);
  }

  return Success();
}
|
||||
|
||||
/**
 * @brief Create the loop and launch the worker thread running Process().
 *
 * @param timeout Stored in `timeout_` and passed to every poll call.
 */
Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} {
  timer_.Init(__func__);

  // Thread entry point. Nothing may escape from it, so any exception raised by Process()
  // is recorded (only the first one is kept), the loop is flagged as stopped and every
  // waiter is woken up.
  auto run_worker = [this] {
    try {
      this->Process();
    } catch (...) {
      std::lock_guard<std::mutex> guard{mu_};
      if (!curr_exce_) {
        curr_exce_ = std::current_exception();
        rc_ = Fail("Exception was thrown");
      }
      stop_ = true;
      cv_.notify_all();
    }
  };
  worker_ = std::thread{run_worker};
}
|
||||
} // namespace xgboost::collective
|
||||
83
src/collective/loop.h
Normal file
83
src/collective/loop.h
Normal file
@@ -0,0 +1,83 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <chrono> // for seconds
|
||||
#include <condition_variable> // for condition_variable
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int8_t, int32_t
|
||||
#include <exception> // for exception_ptr
|
||||
#include <mutex> // for unique_lock, mutex
|
||||
#include <queue> // for queue
|
||||
#include <thread> // for thread
|
||||
#include <utility> // for move
|
||||
|
||||
#include "../common/timer.h" // for Monitor
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/collective/socket.h" // for TCPSocket
|
||||
|
||||
namespace xgboost::collective {
|
||||
class Loop {
|
||||
public:
|
||||
struct Op {
|
||||
enum Code : std::int8_t { kRead = 0, kWrite = 1 } code;
|
||||
std::int32_t rank{-1};
|
||||
std::int8_t* ptr{nullptr};
|
||||
std::size_t n{0};
|
||||
TCPSocket* sock{nullptr};
|
||||
std::size_t off{0};
|
||||
|
||||
Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off)
|
||||
: code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {}
|
||||
Op(Op const&) = default;
|
||||
Op& operator=(Op const&) = default;
|
||||
Op(Op&&) = default;
|
||||
Op& operator=(Op&&) = default;
|
||||
};
|
||||
|
||||
private:
|
||||
std::thread worker_;
|
||||
std::condition_variable cv_;
|
||||
std::mutex mu_;
|
||||
std::queue<Op> queue_;
|
||||
std::chrono::seconds timeout_;
|
||||
Result rc_;
|
||||
bool stop_{false};
|
||||
std::exception_ptr curr_exce_{nullptr};
|
||||
common::Monitor timer_;
|
||||
|
||||
Result EmptyQueue();
|
||||
void Process();
|
||||
|
||||
public:
|
||||
Result Stop();
|
||||
|
||||
void Submit(Op op) {
|
||||
// producer
|
||||
std::unique_lock lock{mu_};
|
||||
queue_.push(op);
|
||||
lock.unlock();
|
||||
cv_.notify_one();
|
||||
}
|
||||
|
||||
[[nodiscard]] Result Block() {
|
||||
{
|
||||
std::unique_lock lock{mu_};
|
||||
cv_.notify_all();
|
||||
}
|
||||
std::unique_lock lock{mu_};
|
||||
cv_.wait(lock, [this] { return this->queue_.empty() || stop_; });
|
||||
return std::move(rc_);
|
||||
}
|
||||
|
||||
explicit Loop(std::chrono::seconds timeout);
|
||||
|
||||
~Loop() noexcept(false) {
|
||||
this->Stop();
|
||||
|
||||
if (worker_.joinable()) {
|
||||
worker_.join();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
@@ -118,36 +118,42 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
|
||||
}
|
||||
|
||||
auto rc = connect(conn.Handle(), addr_handle, addr_len);
|
||||
if (rc != 0) {
|
||||
auto errcode = system::LastError();
|
||||
if (!system::ErrorWouldBlock(errcode)) {
|
||||
log_failure(Fail("connect failed.", std::error_code{errcode, std::system_category()}),
|
||||
__FILE__, __LINE__);
|
||||
continue;
|
||||
}
|
||||
|
||||
rabit::utils::PollHelper poll;
|
||||
poll.WatchWrite(conn);
|
||||
auto result = poll.Poll(timeout);
|
||||
if (!result.OK()) {
|
||||
log_failure(std::move(result), __FILE__, __LINE__);
|
||||
continue;
|
||||
}
|
||||
if (!poll.CheckWrite(conn)) {
|
||||
log_failure(Fail("poll failed.", std::error_code{errcode, std::system_category()}),
|
||||
__FILE__, __LINE__);
|
||||
continue;
|
||||
}
|
||||
result = conn.GetSockError();
|
||||
if (!result.OK()) {
|
||||
log_failure(std::move(result), __FILE__, __LINE__);
|
||||
continue;
|
||||
}
|
||||
|
||||
return conn.NonBlocking(non_blocking);
|
||||
} else {
|
||||
if (rc == 0) {
|
||||
return conn.NonBlocking(non_blocking);
|
||||
}
|
||||
|
||||
auto errcode = system::LastError();
|
||||
if (!system::ErrorWouldBlock(errcode)) {
|
||||
log_failure(Fail("connect failed.", std::error_code{errcode, std::system_category()}),
|
||||
__FILE__, __LINE__);
|
||||
continue;
|
||||
}
|
||||
|
||||
rabit::utils::PollHelper poll;
|
||||
poll.WatchWrite(conn);
|
||||
auto result = poll.Poll(timeout);
|
||||
if (!result.OK()) {
|
||||
// poll would fail if there's a socket error, we log the root cause instead of the
|
||||
// poll failure.
|
||||
auto sockerr = conn.GetSockError();
|
||||
if (!sockerr.OK()) {
|
||||
result = std::move(sockerr);
|
||||
}
|
||||
log_failure(std::move(result), __FILE__, __LINE__);
|
||||
continue;
|
||||
}
|
||||
if (!poll.CheckWrite(conn)) {
|
||||
log_failure(Fail("poll failed.", std::error_code{errcode, std::system_category()}), __FILE__,
|
||||
__LINE__);
|
||||
continue;
|
||||
}
|
||||
result = conn.GetSockError();
|
||||
if (!result.OK()) {
|
||||
log_failure(std::move(result), __FILE__, __LINE__);
|
||||
continue;
|
||||
}
|
||||
|
||||
return conn.NonBlocking(non_blocking);
|
||||
}
|
||||
|
||||
std::stringstream ss;
|
||||
|
||||
Reference in New Issue
Block a user