Build a simple event loop for collective. (#9593)

This commit is contained in:
Jiaming Yuan 2023-09-20 02:09:07 +08:00 committed by GitHub
parent 259d80c0cf
commit 38ac52dd87
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 402 additions and 47 deletions

View File

@ -16,8 +16,8 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "rabit/internal/utils.h" #include "dmlc/io.h"
#include "rabit/serializable.h" #include "xgboost/logging.h"
namespace rabit::utils { namespace rabit::utils {
/*! \brief re-use definition of dmlc::SeekStream */ /*! \brief re-use definition of dmlc::SeekStream */
@ -84,8 +84,7 @@ struct MemoryBufferStream : public SeekStream {
} }
~MemoryBufferStream() override = default; ~MemoryBufferStream() override = default;
size_t Read(void *ptr, size_t size) override { size_t Read(void *ptr, size_t size) override {
utils::Assert(curr_ptr_ <= p_buffer_->length(), CHECK_LE(curr_ptr_, p_buffer_->length()) << "read can not have position excceed buffer length";
"read can not have position excceed buffer length");
size_t nread = std::min(p_buffer_->length() - curr_ptr_, size); size_t nread = std::min(p_buffer_->length() - curr_ptr_, size);
if (nread != 0) std::memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread); if (nread != 0) std::memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread);
curr_ptr_ += nread; curr_ptr_ += nread;

View File

@ -29,11 +29,10 @@
#include <chrono> #include <chrono>
#include <cstring> #include <cstring>
#include <string> #include <string>
#include <system_error> // make_error_code, errc
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "utils.h"
#if !defined(_WIN32) #if !defined(_WIN32)
#include <sys/poll.h> #include <sys/poll.h>
@ -93,6 +92,20 @@ int PollImpl(PollFD* pfd, int nfds, std::chrono::seconds timeout) noexcept(true)
#endif // IS_MINGW() #endif // IS_MINGW()
} }
/*!
 * \brief Translate the error bits of a poll result into a Result.
 *
 * The flags are examined in the order POLLERR, POLLNVAL, POLLHUP; the first flag found set
 * selects the failure message. If none of them is set the function reports success.
 *
 * \param revents Event mask to inspect (any integral type).
 *
 * \return A failed Result built by FailWithCode for the first matching flag, otherwise
 *         Success().
 */
template <typename E>
std::enable_if_t<std::is_integral_v<E>, xgboost::collective::Result> PollError(E const& revents) {
  // Small predicate so each flag test reads the same way.
  auto const is_set = [&revents](auto flag) { return (revents & flag) != 0; };

  if (is_set(POLLERR)) {
    return xgboost::system::FailWithCode("Poll error condition.");
  }
  if (is_set(POLLNVAL)) {
    return xgboost::system::FailWithCode("Invalid polling request.");
  }
  if (is_set(POLLHUP)) {
    return xgboost::system::FailWithCode("Poll hung up.");
  }
  return xgboost::collective::Success();
}
/*! \brief helper data structure to perform poll */ /*! \brief helper data structure to perform poll */
struct PollHelper { struct PollHelper {
public: public:
@ -160,27 +173,34 @@ struct PollHelper {
* *
* @param timeout specify timeout in seconds. Block if negative. * @param timeout specify timeout in seconds. Block if negative.
*/ */
[[nodiscard]] xgboost::collective::Result Poll(std::chrono::seconds timeout) { [[nodiscard]] xgboost::collective::Result Poll(std::chrono::seconds timeout,
bool check_error = true) {
std::vector<pollfd> fdset; std::vector<pollfd> fdset;
fdset.reserve(fds.size()); fdset.reserve(fds.size());
for (auto kv : fds) { for (auto kv : fds) {
fdset.push_back(kv.second); fdset.push_back(kv.second);
} }
int ret = PollImpl(fdset.data(), fdset.size(), timeout); std::int32_t ret = PollImpl(fdset.data(), fdset.size(), timeout);
if (ret == 0) { if (ret == 0) {
return xgboost::collective::Fail("Poll timeout."); return xgboost::collective::Fail("Poll timeout.", std::make_error_code(std::errc::timed_out));
} else if (ret < 0) { } else if (ret < 0) {
return xgboost::system::FailWithCode("Poll failed."); return xgboost::system::FailWithCode("Poll failed.");
} else { }
for (auto& pfd : fdset) { for (auto& pfd : fdset) {
auto result = PollError(pfd.revents);
if (check_error && !result.OK()) {
return result;
}
auto revents = pfd.revents & pfd.events; auto revents = pfd.revents & pfd.events;
if (!revents) { if (!revents) {
// FIXME(jiamingy): remove this once rabit is replaced.
fds.erase(pfd.fd); fds.erase(pfd.fd);
} else { } else {
fds[pfd.fd].events = revents; fds[pfd.fd].events = revents;
} }
} }
}
return xgboost::collective::Success(); return xgboost::collective::Success();
} }

View File

@ -721,12 +721,11 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
} }
finished = false; finished = false;
} }
watcher.WatchException(links[i].sock);
} }
// finish running // finish running
if (finished) break; if (finished) break;
// select // select
auto poll_res = watcher.Poll(timeout_sec); auto poll_res = watcher.Poll(timeout_sec, false); // fail on macos
if (!poll_res.OK()) { if (!poll_res.OK()) {
LOG(FATAL) << poll_res.Report(); LOG(FATAL) << poll_res.Report();
} }

167
src/collective/loop.cc Normal file
View File

@ -0,0 +1,167 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include "loop.h"
#include <queue> // for queue
#include "rabit/internal/socket.h" // for PollHelper
#include "xgboost/collective/socket.h" // for FailWithCode
#include "xgboost/logging.h" // for CHECK
namespace xgboost::collective {
/**
 * @brief Drain the pending operation queue with poll-driven socket I/O.
 *
 * Every round moves all queued ops into a local queue while registering their sockets with a
 * poll helper, waits for readiness (bounded by `timeout_`), and then performs at most one
 * Recv/Send per op whose socket is ready. An op that has not yet transferred all `n` bytes is
 * pushed back onto `queue_` for the next round. The loop ends when `queue_` is empty or
 * `stop_` is set.
 *
 * Process() calls this while holding `mu_`.
 *
 * @return Success() when every op completed. On an invalid op code, a poll failure (including
 *         timeout) or a socket error, `stop_` is set and the failing Result is returned.
 */
Result Loop::EmptyQueue() {
  // Inside a lambda body `__func__` names the closure's call operator rather than this
  // function, so remember the label once to keep the timer's Start/Stop keys matched.
  char const* const fn_label = __func__;
  timer_.Start(fn_label);
  auto error = [this, fn_label] {
    this->stop_ = true;
    timer_.Stop(fn_label);
  };

  while (!queue_.empty() && !stop_) {
    std::queue<Op> qcopy;
    rabit::utils::PollHelper poll;

    // watch all ops
    while (!queue_.empty()) {
      auto op = queue_.front();
      queue_.pop();

      switch (op.code) {
        case Op::kRead: {
          poll.WatchRead(*op.sock);
          break;
        }
        case Op::kWrite: {
          poll.WatchWrite(*op.sock);
          break;
        }
        default: {
          error();
          return Fail("Invalid socket operation.");
        }
      }
      qcopy.push(op);
    }

    // poll, work on fds that are ready.
    timer_.Start("poll");
    auto rc = poll.Poll(timeout_);
    timer_.Stop("poll");
    if (!rc.OK()) {
      error();
      return rc;
    }

    // we wouldn't be here if the queue is empty.
    CHECK(!qcopy.empty());

    while (!qcopy.empty() && !stop_) {
      auto op = qcopy.front();
      qcopy.pop();

      // Wide signed type so that neither a -1 error marker nor a large transfer count is
      // mangled when stored.
      std::int64_t n_bytes_done{0};
      CHECK(op.sock->NonBlocking());

      switch (op.code) {
        case Op::kRead: {
          if (poll.CheckRead(*op.sock)) {
            n_bytes_done = op.sock->Recv(op.ptr + op.off, op.n - op.off);
          }
          break;
        }
        case Op::kWrite: {
          if (poll.CheckWrite(*op.sock)) {
            n_bytes_done = op.sock->Send(op.ptr + op.off, op.n - op.off);
          }
          break;
        }
        default: {
          error();
          return Fail("Invalid socket operation.");
        }
      }

      if (n_bytes_done == -1) {
        if (!system::LastErrorWouldBlock()) {
          // Build the Result first so the error code is read before anything else runs.
          auto sock_rc = system::FailWithCode("Invalid socket output.");
          error();
          return sock_rc;
        }
        // The call would have blocked: nothing was transferred. Do not let the -1 marker
        // leak into the offset arithmetic below.
        n_bytes_done = 0;
      }
      // NOTE(review): a Recv that returns 0 is treated as "no progress" and the op is retried;
      // if 0 means the peer closed the connection this could retry forever -- confirm.

      op.off += n_bytes_done;
      CHECK_LE(op.off, op.n);

      if (op.off != op.n) {
        // not yet finished, push back to queue for next round.
        queue_.push(op);
      }
    }
  }

  timer_.Stop(fn_label);
  return Success();
}
/**
 * @brief Body of the worker thread (the consumer side of the queue).
 *
 * Repeatedly waits on `cv_` until there is queued work or `stop_` is set, drains the queue
 * through EmptyQueue() while holding `mu_`, stores the outcome in `rc_` and notifies a waiter.
 * The loop exits when `stop_` is observed or when EmptyQueue() reports a failure.
 */
void Loop::Process() {
  // consumer
  while (true) {
    std::unique_lock lock{mu_};
    cv_.wait(lock, [this] { return !this->queue_.empty() || stop_; });
    if (stop_) {
      break;
    }

    // The queue must be drained with the mutex held. Ask the guard whether it owns the lock
    // instead of calling try_lock() on a std::mutex this thread already owns, which is
    // undefined behaviour.
    CHECK(lock.owns_lock());

    this->rc_ = this->EmptyQueue();
    if (!rc_.OK()) {
      // Failure: stop the loop and wake a waiter so it can pick up the error.
      stop_ = true;
      cv_.notify_one();
      break;
    }

    CHECK(queue_.empty());
    CHECK(lock.owns_lock());
    cv_.notify_one();
  }

  // NOTE(review): `lock` is scoped to the loop body, so this final check runs without `mu_`.
  if (rc_.OK()) {
    CHECK(queue_.empty());
  }
}
/**
 * @brief Request the worker to stop and surface any exception it raised.
 *
 * Sets `stop_`, then calls Block() which wakes the worker and waits for the stop condition.
 * If the worker thread stored an exception in `curr_exce_`, it is rethrown here.
 *
 * @return Always Success() when no exception is rethrown; the Result obtained from Block()
 *         is only used for a consistency check.
 */
Result Loop::Stop() {
  bool stored_ok{true};
  {
    std::unique_lock lock{mu_};
    stop_ = true;
    // Block() hands out `rc_` via std::move, so the state of `rc_` after that call is not a
    // meaningful comparison operand. Snapshot it first so the check below does not depend on
    // the order in which the operands of CHECK_EQ are evaluated.
    stored_ok = this->rc_.OK();
  }

  auto rc = this->Block();
  CHECK_EQ(rc.OK(), stored_ok);

  if (curr_exce_) {
    std::rethrow_exception(curr_exce_);
  }

  return Success();
}
/**
 * @brief Create the loop and launch its worker thread.
 *
 * The worker runs Process(). Anything thrown out of Process() is captured instead of escaping
 * the thread: the first exception is kept in `curr_exce_`, `rc_` is set to a failure, `stop_`
 * is raised and all waiters are notified.
 *
 * @param timeout Stored in `timeout_` and used as the poll timeout.
 */
Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} {
  timer_.Init(__func__);

  worker_ = std::thread{[this] {
    try {
      this->Process();
    } catch (...) {
      // A single catch-all covers std::exception and everything else; both are recorded the
      // same way.
      std::lock_guard<std::mutex> guard{mu_};
      if (!curr_exce_) {
        // Only the first exception is kept.
        curr_exce_ = std::current_exception();
        rc_ = Fail("Exception was thrown");
      }
      stop_ = true;
      cv_.notify_all();
    }
  }};
}
} // namespace xgboost::collective

83
src/collective/loop.h Normal file
View File

@ -0,0 +1,83 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <chrono> // for seconds
#include <condition_variable> // for condition_variable
#include <cstddef> // for size_t
#include <cstdint> // for int8_t, int32_t
#include <exception> // for exception_ptr
#include <mutex> // for unique_lock, mutex
#include <queue> // for queue
#include <thread> // for thread
#include <utility> // for move
#include "../common/timer.h" // for Monitor
#include "xgboost/collective/result.h" // for Result
#include "xgboost/collective/socket.h" // for TCPSocket
namespace xgboost::collective {
/**
 * @brief A small event loop that performs socket reads and writes on a worker thread.
 *
 * Callers enqueue operations with Submit() and wait for them with Block(). The worker thread
 * is started by the constructor and joined by the destructor.
 */
class Loop {
 public:
  /**
   * @brief Description of a single socket transfer.
   */
  struct Op {
    // Direction of the transfer.
    enum Code : std::int8_t { kRead = 0, kWrite = 1 } code;
    // Rank associated with the op; not read by the loop itself.
    std::int32_t rank{-1};
    // Start of the buffer to read into or write from.
    std::int8_t* ptr{nullptr};
    // Total number of bytes to transfer.
    std::size_t n{0};
    // Socket used for the transfer; not owned by the op.
    TCPSocket* sock{nullptr};
    // Number of bytes already transferred.
    std::size_t off{0};

    Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off)
        : code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {}
    Op(Op const&) = default;
    Op& operator=(Op const&) = default;
    Op(Op&&) = default;
    Op& operator=(Op&&) = default;
  };

 private:
  std::thread worker_;
  std::condition_variable cv_;
  std::mutex mu_;
  std::queue<Op> queue_;
  std::chrono::seconds timeout_;
  // Outcome of the most recent queue drain; handed out (moved) by Block().
  Result rc_;
  bool stop_{false};
  // First exception raised on the worker thread, if any.
  std::exception_ptr curr_exce_{nullptr};
  common::Monitor timer_;

  Result EmptyQueue();
  void Process();

 public:
  Result Stop();

  /**
   * @brief Enqueue an operation and wake the worker.
   */
  void Submit(Op op) {
    // producer
    std::unique_lock lock{mu_};
    queue_.push(op);
    lock.unlock();
    cv_.notify_one();
  }

  /**
   * @brief Wait until the queue is empty or the loop has stopped.
   *
   * @return The stored result, moved out of the loop.
   */
  [[nodiscard]] Result Block() {
    {
      std::unique_lock lock{mu_};
      cv_.notify_all();
    }
    std::unique_lock lock{mu_};
    cv_.wait(lock, [this] { return this->queue_.empty() || stop_; });
    return std::move(rc_);
  }

  explicit Loop(std::chrono::seconds timeout);

  /**
   * @brief Stop the loop and join the worker thread.
   *
   * May throw: Stop() rethrows an exception captured on the worker thread.
   */
  ~Loop() noexcept(false) {
    try {
      this->Stop();
    } catch (...) {
      // Join before propagating. Otherwise `worker_` would be destroyed while still joinable
      // during unwinding, which calls std::terminate.
      if (worker_.joinable()) {
        worker_.join();
      }
      throw;
    }

    if (worker_.joinable()) {
      worker_.join();
    }
  }
};
} // namespace xgboost::collective

View File

@ -118,7 +118,10 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
} }
auto rc = connect(conn.Handle(), addr_handle, addr_len); auto rc = connect(conn.Handle(), addr_handle, addr_len);
if (rc != 0) { if (rc == 0) {
return conn.NonBlocking(non_blocking);
}
auto errcode = system::LastError(); auto errcode = system::LastError();
if (!system::ErrorWouldBlock(errcode)) { if (!system::ErrorWouldBlock(errcode)) {
log_failure(Fail("connect failed.", std::error_code{errcode, std::system_category()}), log_failure(Fail("connect failed.", std::error_code{errcode, std::system_category()}),
@ -130,12 +133,18 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
poll.WatchWrite(conn); poll.WatchWrite(conn);
auto result = poll.Poll(timeout); auto result = poll.Poll(timeout);
if (!result.OK()) { if (!result.OK()) {
// poll would fail if there's a socket error, we log the root cause instead of the
// poll failure.
auto sockerr = conn.GetSockError();
if (!sockerr.OK()) {
result = std::move(sockerr);
}
log_failure(std::move(result), __FILE__, __LINE__); log_failure(std::move(result), __FILE__, __LINE__);
continue; continue;
} }
if (!poll.CheckWrite(conn)) { if (!poll.CheckWrite(conn)) {
log_failure(Fail("poll failed.", std::error_code{errcode, std::system_category()}), log_failure(Fail("poll failed.", std::error_code{errcode, std::system_category()}), __FILE__,
__FILE__, __LINE__); __LINE__);
continue; continue;
} }
result = conn.GetSockError(); result = conn.GetSockError();
@ -145,9 +154,6 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
} }
return conn.NonBlocking(non_blocking); return conn.NonBlocking(non_blocking);
} else {
return conn.NonBlocking(non_blocking);
}
} }
std::stringstream ss; std::stringstream ss;

View File

@ -0,0 +1,81 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include <gtest/gtest.h> // for ASSERT_TRUE, ASSERT_EQ
#include <xgboost/collective/socket.h> // for TCPSocket, Connect, SocketFinalize, SocketStartup
#include <xgboost/string_view.h> // for StringView
#include <chrono> // for seconds
#include <cstdint> // for int8_t
#include <memory> // for make_shared, shared_ptr
#include <system_error> // for make_error_code, errc
#include <utility> // for pair
#include <vector> // for vector
#include "../../../src/collective/loop.h" // for Loop
namespace xgboost::collective {
namespace {
class LoopTest : public ::testing::Test {
protected:
std::pair<TCPSocket, TCPSocket> pair_;
std::shared_ptr<Loop> loop_;
protected:
void SetUp() override {
system::SocketStartup();
std::chrono::seconds timeout{1};
auto domain = SockDomain::kV4;
pair_.first = TCPSocket::Create(domain);
auto port = pair_.first.BindHost();
pair_.first.Listen();
auto const& addr = SockAddrV4::Loopback().Addr();
auto rc = Connect(StringView{addr}, port, 1, timeout, &pair_.second);
ASSERT_TRUE(rc.OK());
rc = pair_.second.NonBlocking(true);
ASSERT_TRUE(rc.OK());
pair_.first = pair_.first.Accept();
rc = pair_.first.NonBlocking(true);
ASSERT_TRUE(rc.OK());
loop_ = std::make_shared<Loop>(timeout);
}
void TearDown() override {
pair_ = decltype(pair_){};
system::SocketFinalize();
}
};
} // namespace
// A read is submitted while nothing is ever written to the peer; the loop is expected to fail
// with the timed_out error code.
TEST_F(LoopTest, Timeout) {
  std::vector<std::int8_t> buffer(1);
  loop_->Submit(Loop::Op{Loop::Op::kRead, 0, buffer.data(), buffer.size(), &pair_.second, 0});

  auto result = loop_->Block();
  ASSERT_FALSE(result.OK());

  auto const expected = std::make_error_code(std::errc::timed_out);
  ASSERT_EQ(result.Code(), expected) << result.Report();
}
// Send one byte through the loop on one end of the pair and receive it on the other.
TEST_F(LoopTest, Op) {
  TCPSocket& sender = pair_.first;
  TCPSocket& receiver = pair_.second;

  std::vector<std::int8_t> out_buf(1, 1);
  std::vector<std::int8_t> in_buf(1, 0);

  // Write first, then read, both handled by the same loop.
  loop_->Submit(Loop::Op{Loop::Op::kWrite, 0, out_buf.data(), out_buf.size(), &sender, 0});
  loop_->Submit(Loop::Op{Loop::Op::kRead, 0, in_buf.data(), in_buf.size(), &receiver, 0});

  auto result = loop_->Block();
  ASSERT_TRUE(result.OK()) << result.Report();

  ASSERT_EQ(in_buf[0], out_buf[0]);
}
} // namespace xgboost::collective