[coll] Reduce the scope of lock in the event loop. (#9784)

This commit is contained in:
Jiaming Yuan 2023-11-15 14:16:19 +08:00 committed by GitHub
parent 36a552ac98
commit ada377c57e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 117 additions and 70 deletions

View File

@ -412,19 +412,24 @@ class TCPSocket {
return Success(); return Success();
} }
void SetKeepAlive() { [[nodiscard]] Result SetKeepAlive() {
std::int32_t keepalive = 1; std::int32_t keepalive = 1;
xgboost_CHECK_SYS_CALL(setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE, auto rc = setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast<char *>(&keepalive),
reinterpret_cast<char *>(&keepalive), sizeof(keepalive)), sizeof(keepalive));
0); if (rc != 0) {
return system::FailWithCode("Failed to set TCP keeaplive.");
}
return Success();
} }
void SetNoDelay() { [[nodiscard]] Result SetNoDelay() {
std::int32_t tcp_no_delay = 1; std::int32_t tcp_no_delay = 1;
xgboost_CHECK_SYS_CALL( auto rc = setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char *>(&tcp_no_delay),
setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char *>(&tcp_no_delay), sizeof(tcp_no_delay));
sizeof(tcp_no_delay)), if (rc != 0) {
0); return system::FailWithCode("Failed to set TCP no delay.");
}
return Success();
} }
/** /**

View File

@ -417,9 +417,9 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket"); utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket");
// set the socket to non-blocking mode, enable TCP keepalive // set the socket to non-blocking mode, enable TCP keepalive
CHECK(all_link.sock.NonBlocking(true).OK()); CHECK(all_link.sock.NonBlocking(true).OK());
all_link.sock.SetKeepAlive(); CHECK(all_link.sock.SetKeepAlive().OK());
if (rabit_enable_tcp_no_delay) { if (rabit_enable_tcp_no_delay) {
all_link.sock.SetNoDelay(); CHECK(all_link.sock.SetNoDelay().OK());
} }
if (tree_neighbors.count(all_link.rank) != 0) { if (tree_neighbors.count(all_link.rank) != 0) {
if (all_link.rank == parent_rank) { if (all_link.rank == parent_rank) {

View File

@ -6,6 +6,7 @@
#include <algorithm> // for min #include <algorithm> // for min
#include <cstddef> // for size_t #include <cstddef> // for size_t
#include <cstdint> // for int32_t, int8_t #include <cstdint> // for int32_t, int8_t
#include <utility> // for move
#include <vector> // for vector #include <vector> // for vector
#include "../data/array_interface.h" // for Type, DispatchDType #include "../data/array_interface.h" // for Type, DispatchDType
@ -47,7 +48,7 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
auto seg = s_buf.subspan(0, recv_seg.size()); auto seg = s_buf.subspan(0, recv_seg.size());
prev_ch->RecvAll(seg); prev_ch->RecvAll(seg);
auto rc = prev_ch->Block(); auto rc = comm.Block();
if (!rc.OK()) { if (!rc.OK()) {
return rc; return rc;
} }
@ -83,11 +84,9 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func cons
auto prev_ch = comm.Chan(prev); auto prev_ch = comm.Chan(prev);
auto next_ch = comm.Chan(next); auto next_ch = comm.Chan(next);
rc = RingAllgather(comm, data, n_bytes_in_seg, 1, prev_ch, next_ch); return std::move(rc) << [&] {
if (!rc.OK()) { return RingAllgather(comm, data, n_bytes_in_seg, 1, prev_ch, next_ch);
return rc; } << [&] { return comm.Block(); };
}
return comm.Block();
}); });
} }
} // namespace xgboost::collective::cpu_impl } // namespace xgboost::collective::cpu_impl

View File

@ -33,19 +33,28 @@ Comm::Comm(std::string const& host, std::int32_t port, std::chrono::seconds time
Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, std::int32_t retry, Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, std::int32_t retry,
std::string const& task_id, TCPSocket* out, std::int32_t rank, std::string const& task_id, TCPSocket* out, std::int32_t rank,
std::int32_t world) { std::int32_t world) {
// get information from tracker // Get information from the tracker
CHECK(!info.host.empty()); CHECK(!info.host.empty());
auto rc = Connect(info.host, info.port, retry, timeout, out);
if (!rc.OK()) {
return Fail("Failed to connect to the tracker.", std::move(rc));
}
TCPSocket& tracker = *out; TCPSocket& tracker = *out;
return std::move(rc) return Success() << [&] {
<< [&] { return tracker.NonBlocking(false); } auto rc = Connect(info.host, info.port, retry, timeout, out);
<< [&] { return tracker.RecvTimeout(timeout); } if (rc.OK()) {
<< [&] { return proto::Magic{}.Verify(&tracker); } return rc;
<< [&] { return proto::Connect{}.WorkerSend(&tracker, world, rank, task_id); }; } else {
return Fail("Failed to connect to the tracker.", std::move(rc));
}
} << [&] {
return tracker.NonBlocking(false);
} << [&] {
return tracker.RecvTimeout(timeout);
} << [&] {
return proto::Magic{}.Verify(&tracker);
} << [&] {
return proto::Connect{}.WorkerSend(&tracker, world, rank, task_id);
} << [&] {
LOG(INFO) << "Task " << task_id << " connected to the tracker";
return Success();
};
} }
[[nodiscard]] Result Comm::ConnectTracker(TCPSocket* out) const { [[nodiscard]] Result Comm::ConnectTracker(TCPSocket* out) const {
@ -257,8 +266,8 @@ RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::se
CHECK(this->channels_.empty()); CHECK(this->channels_.empty());
for (auto& w : workers) { for (auto& w : workers) {
if (w) { if (w) {
w->SetNoDelay(); rc = std::move(rc) << [&] { return w->SetNoDelay(); } << [&] { return w->NonBlocking(true); }
rc = w->NonBlocking(true); << [&] { return w->SetKeepAlive(); };
} }
if (!rc.OK()) { if (!rc.OK()) {
return rc; return rc;

View File

@ -10,21 +10,26 @@
#include "xgboost/logging.h" // for CHECK #include "xgboost/logging.h" // for CHECK
namespace xgboost::collective { namespace xgboost::collective {
Result Loop::EmptyQueue() { Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
timer_.Start(__func__); timer_.Start(__func__);
auto error = [this] { auto error = [this] { timer_.Stop(__func__); };
this->stop_ = true;
if (stop_) {
timer_.Stop(__func__); timer_.Stop(__func__);
}; return Success();
}
while (!queue_.empty() && !stop_) { auto& qcopy = *p_queue;
std::queue<Op> qcopy;
// clear the copied queue
while (!qcopy.empty()) {
rabit::utils::PollHelper poll; rabit::utils::PollHelper poll;
std::size_t n_ops = qcopy.size();
// watch all ops // Iterate through all the ops for poll
while (!queue_.empty()) { for (std::size_t i = 0; i < n_ops; ++i) {
auto op = queue_.front(); auto op = qcopy.front();
queue_.pop(); qcopy.pop();
switch (op.code) { switch (op.code) {
case Op::kRead: { case Op::kRead: {
@ -40,6 +45,7 @@ Result Loop::EmptyQueue() {
return Fail("Invalid socket operation."); return Fail("Invalid socket operation.");
} }
} }
qcopy.push(op); qcopy.push(op);
} }
@ -51,10 +57,12 @@ Result Loop::EmptyQueue() {
error(); error();
return rc; return rc;
} }
// we wonldn't be here if the queue is empty. // we wonldn't be here if the queue is empty.
CHECK(!qcopy.empty()); CHECK(!qcopy.empty());
while (!qcopy.empty() && !stop_) { // Iterate through all the ops for performing the operations
for (std::size_t i = 0; i < n_ops; ++i) {
auto op = qcopy.front(); auto op = qcopy.front();
qcopy.pop(); qcopy.pop();
@ -81,20 +89,21 @@ Result Loop::EmptyQueue() {
} }
if (n_bytes_done == -1 && !system::LastErrorWouldBlock()) { if (n_bytes_done == -1 && !system::LastErrorWouldBlock()) {
stop_ = true;
auto rc = system::FailWithCode("Invalid socket output."); auto rc = system::FailWithCode("Invalid socket output.");
error(); error();
return rc; return rc;
} }
op.off += n_bytes_done; op.off += n_bytes_done;
CHECK_LE(op.off, op.n); CHECK_LE(op.off, op.n);
if (op.off != op.n) { if (op.off != op.n) {
// not yet finished, push back to queue for next round. // not yet finished, push back to queue for next round.
queue_.push(op); qcopy.push(op);
} }
} }
} }
timer_.Stop(__func__); timer_.Stop(__func__);
return Success(); return Success();
} }
@ -107,22 +116,42 @@ void Loop::Process() {
if (stop_) { if (stop_) {
break; break;
} }
CHECK(!mu_.try_lock());
this->rc_ = this->EmptyQueue(); auto unlock_notify = [&](bool is_blocking) {
if (!rc_.OK()) { if (!is_blocking) {
stop_ = true; return;
}
lock.unlock();
cv_.notify_one(); cv_.notify_one();
break; };
// move the queue
std::queue<Op> qcopy;
bool is_blocking = false;
while (!queue_.empty()) {
auto op = queue_.front();
queue_.pop();
if (op.code == Op::kBlock) {
is_blocking = true;
} else {
qcopy.push(op);
}
}
// unblock the queue
if (!is_blocking) {
lock.unlock();
}
// clear the queue
auto rc = this->EmptyQueue(&qcopy);
// Handle error
if (!rc.OK()) {
this->rc_ = std::move(rc);
unlock_notify(is_blocking);
return;
} }
CHECK(queue_.empty()); CHECK(qcopy.empty());
CHECK(!mu_.try_lock()); unlock_notify(is_blocking);
cv_.notify_one();
}
if (rc_.OK()) {
CHECK(queue_.empty());
} }
} }
@ -140,6 +169,15 @@ Result Loop::Stop() {
return Success(); return Success();
} }
[[nodiscard]] Result Loop::Block() {
this->Submit(Op{Op::kBlock});
{
std::unique_lock lock{mu_};
cv_.wait(lock, [this] { return (this->queue_.empty()) || stop_; });
}
return std::move(rc_);
}
Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} { Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} {
timer_.Init(__func__); timer_.Init(__func__);
worker_ = std::thread{[this] { worker_ = std::thread{[this] {

View File

@ -20,13 +20,14 @@ namespace xgboost::collective {
class Loop { class Loop {
public: public:
struct Op { struct Op {
enum Code : std::int8_t { kRead = 0, kWrite = 1 } code; enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2 } code;
std::int32_t rank{-1}; std::int32_t rank{-1};
std::int8_t* ptr{nullptr}; std::int8_t* ptr{nullptr};
std::size_t n{0}; std::size_t n{0};
TCPSocket* sock{nullptr}; TCPSocket* sock{nullptr};
std::size_t off{0}; std::size_t off{0};
explicit Op(Code c) : code{c} { CHECK(c == kBlock); }
Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off) Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off)
: code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {} : code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {}
Op(Op const&) = default; Op(Op const&) = default;
@ -44,9 +45,9 @@ class Loop {
Result rc_; Result rc_;
bool stop_{false}; bool stop_{false};
std::exception_ptr curr_exce_{nullptr}; std::exception_ptr curr_exce_{nullptr};
common::Monitor timer_; common::Monitor mutable timer_;
Result EmptyQueue(); Result EmptyQueue(std::queue<Op>* p_queue) const;
void Process(); void Process();
public: public:
@ -60,15 +61,7 @@ class Loop {
cv_.notify_one(); cv_.notify_one();
} }
[[nodiscard]] Result Block() { [[nodiscard]] Result Block();
{
std::unique_lock lock{mu_};
cv_.notify_all();
}
std::unique_lock lock{mu_};
cv_.wait(lock, [this] { return this->queue_.empty() || stop_; });
return std::move(rc_);
}
explicit Loop(std::chrono::seconds timeout); explicit Loop(std::chrono::seconds timeout);

View File

@ -18,31 +18,34 @@ class AllreduceWorker : public WorkerForTest {
void Basic() { void Basic() {
{ {
std::vector<double> data(13, 0.0); std::vector<double> data(13, 0.0);
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) { auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
for (std::size_t i = 0; i < rhs.size(); ++i) { for (std::size_t i = 0; i < rhs.size(); ++i) {
rhs[i] += lhs[i]; rhs[i] += lhs[i];
} }
}); });
ASSERT_TRUE(rc.OK());
ASSERT_EQ(std::accumulate(data.cbegin(), data.cend(), 0.0), 0.0); ASSERT_EQ(std::accumulate(data.cbegin(), data.cend(), 0.0), 0.0);
} }
{ {
std::vector<double> data(1, 1.0); std::vector<double> data(1, 1.0);
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) { auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
for (std::size_t i = 0; i < rhs.size(); ++i) { for (std::size_t i = 0; i < rhs.size(); ++i) {
rhs[i] += lhs[i]; rhs[i] += lhs[i];
} }
}); });
ASSERT_TRUE(rc.OK());
ASSERT_EQ(data[0], static_cast<double>(comm_.World())); ASSERT_EQ(data[0], static_cast<double>(comm_.World()));
} }
} }
void Acc() { void Acc() {
std::vector<double> data(314, 1.5); std::vector<double> data(314, 1.5);
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) { auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
for (std::size_t i = 0; i < rhs.size(); ++i) { for (std::size_t i = 0; i < rhs.size(); ++i) {
rhs[i] += lhs[i]; rhs[i] += lhs[i];
} }
}); });
ASSERT_TRUE(rc.OK());
for (std::size_t i = 0; i < data.size(); ++i) { for (std::size_t i = 0; i < data.size(); ++i) {
auto v = data[i]; auto v = data[i];
ASSERT_EQ(v, 1.5 * static_cast<double>(comm_.World())) << i; ASSERT_EQ(v, 1.5 * static_cast<double>(comm_.World())) << i;