[coll] Reduce the scope of lock in the event loop. (#9784)
This commit is contained in:
parent
36a552ac98
commit
ada377c57e
@ -412,19 +412,24 @@ class TCPSocket {
|
||||
return Success();
|
||||
}
|
||||
|
||||
void SetKeepAlive() {
|
||||
[[nodiscard]] Result SetKeepAlive() {
|
||||
std::int32_t keepalive = 1;
|
||||
xgboost_CHECK_SYS_CALL(setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE,
|
||||
reinterpret_cast<char *>(&keepalive), sizeof(keepalive)),
|
||||
0);
|
||||
auto rc = setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast<char *>(&keepalive),
|
||||
sizeof(keepalive));
|
||||
if (rc != 0) {
|
||||
return system::FailWithCode("Failed to set TCP keeaplive.");
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
|
||||
void SetNoDelay() {
|
||||
[[nodiscard]] Result SetNoDelay() {
|
||||
std::int32_t tcp_no_delay = 1;
|
||||
xgboost_CHECK_SYS_CALL(
|
||||
setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char *>(&tcp_no_delay),
|
||||
sizeof(tcp_no_delay)),
|
||||
0);
|
||||
auto rc = setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char *>(&tcp_no_delay),
|
||||
sizeof(tcp_no_delay));
|
||||
if (rc != 0) {
|
||||
return system::FailWithCode("Failed to set TCP no delay.");
|
||||
}
|
||||
return Success();
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -417,9 +417,9 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||
utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket");
|
||||
// set the socket to non-blocking mode, enable TCP keepalive
|
||||
CHECK(all_link.sock.NonBlocking(true).OK());
|
||||
all_link.sock.SetKeepAlive();
|
||||
CHECK(all_link.sock.SetKeepAlive().OK());
|
||||
if (rabit_enable_tcp_no_delay) {
|
||||
all_link.sock.SetNoDelay();
|
||||
CHECK(all_link.sock.SetNoDelay().OK());
|
||||
}
|
||||
if (tree_neighbors.count(all_link.rank) != 0) {
|
||||
if (all_link.rank == parent_rank) {
|
||||
|
||||
@ -6,6 +6,7 @@
|
||||
#include <algorithm> // for min
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int32_t, int8_t
|
||||
#include <utility> // for move
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../data/array_interface.h" // for Type, DispatchDType
|
||||
@ -47,7 +48,7 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
|
||||
auto seg = s_buf.subspan(0, recv_seg.size());
|
||||
|
||||
prev_ch->RecvAll(seg);
|
||||
auto rc = prev_ch->Block();
|
||||
auto rc = comm.Block();
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
@ -83,11 +84,9 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func cons
|
||||
auto prev_ch = comm.Chan(prev);
|
||||
auto next_ch = comm.Chan(next);
|
||||
|
||||
rc = RingAllgather(comm, data, n_bytes_in_seg, 1, prev_ch, next_ch);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
return comm.Block();
|
||||
return std::move(rc) << [&] {
|
||||
return RingAllgather(comm, data, n_bytes_in_seg, 1, prev_ch, next_ch);
|
||||
} << [&] { return comm.Block(); };
|
||||
});
|
||||
}
|
||||
} // namespace xgboost::collective::cpu_impl
|
||||
|
||||
@ -33,19 +33,28 @@ Comm::Comm(std::string const& host, std::int32_t port, std::chrono::seconds time
|
||||
Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, std::int32_t retry,
|
||||
std::string const& task_id, TCPSocket* out, std::int32_t rank,
|
||||
std::int32_t world) {
|
||||
// get information from tracker
|
||||
// Get information from the tracker
|
||||
CHECK(!info.host.empty());
|
||||
TCPSocket& tracker = *out;
|
||||
return Success() << [&] {
|
||||
auto rc = Connect(info.host, info.port, retry, timeout, out);
|
||||
if (!rc.OK()) {
|
||||
if (rc.OK()) {
|
||||
return rc;
|
||||
} else {
|
||||
return Fail("Failed to connect to the tracker.", std::move(rc));
|
||||
}
|
||||
|
||||
TCPSocket& tracker = *out;
|
||||
return std::move(rc)
|
||||
<< [&] { return tracker.NonBlocking(false); }
|
||||
<< [&] { return tracker.RecvTimeout(timeout); }
|
||||
<< [&] { return proto::Magic{}.Verify(&tracker); }
|
||||
<< [&] { return proto::Connect{}.WorkerSend(&tracker, world, rank, task_id); };
|
||||
} << [&] {
|
||||
return tracker.NonBlocking(false);
|
||||
} << [&] {
|
||||
return tracker.RecvTimeout(timeout);
|
||||
} << [&] {
|
||||
return proto::Magic{}.Verify(&tracker);
|
||||
} << [&] {
|
||||
return proto::Connect{}.WorkerSend(&tracker, world, rank, task_id);
|
||||
} << [&] {
|
||||
LOG(INFO) << "Task " << task_id << " connected to the tracker";
|
||||
return Success();
|
||||
};
|
||||
}
|
||||
|
||||
[[nodiscard]] Result Comm::ConnectTracker(TCPSocket* out) const {
|
||||
@ -257,8 +266,8 @@ RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::se
|
||||
CHECK(this->channels_.empty());
|
||||
for (auto& w : workers) {
|
||||
if (w) {
|
||||
w->SetNoDelay();
|
||||
rc = w->NonBlocking(true);
|
||||
rc = std::move(rc) << [&] { return w->SetNoDelay(); } << [&] { return w->NonBlocking(true); }
|
||||
<< [&] { return w->SetKeepAlive(); };
|
||||
}
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
|
||||
@ -10,21 +10,26 @@
|
||||
#include "xgboost/logging.h" // for CHECK
|
||||
|
||||
namespace xgboost::collective {
|
||||
Result Loop::EmptyQueue() {
|
||||
Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
|
||||
timer_.Start(__func__);
|
||||
auto error = [this] {
|
||||
this->stop_ = true;
|
||||
auto error = [this] { timer_.Stop(__func__); };
|
||||
|
||||
if (stop_) {
|
||||
timer_.Stop(__func__);
|
||||
};
|
||||
return Success();
|
||||
}
|
||||
|
||||
while (!queue_.empty() && !stop_) {
|
||||
std::queue<Op> qcopy;
|
||||
auto& qcopy = *p_queue;
|
||||
|
||||
// clear the copied queue
|
||||
while (!qcopy.empty()) {
|
||||
rabit::utils::PollHelper poll;
|
||||
std::size_t n_ops = qcopy.size();
|
||||
|
||||
// watch all ops
|
||||
while (!queue_.empty()) {
|
||||
auto op = queue_.front();
|
||||
queue_.pop();
|
||||
// Iterate through all the ops for poll
|
||||
for (std::size_t i = 0; i < n_ops; ++i) {
|
||||
auto op = qcopy.front();
|
||||
qcopy.pop();
|
||||
|
||||
switch (op.code) {
|
||||
case Op::kRead: {
|
||||
@ -40,6 +45,7 @@ Result Loop::EmptyQueue() {
|
||||
return Fail("Invalid socket operation.");
|
||||
}
|
||||
}
|
||||
|
||||
qcopy.push(op);
|
||||
}
|
||||
|
||||
@ -51,10 +57,12 @@ Result Loop::EmptyQueue() {
|
||||
error();
|
||||
return rc;
|
||||
}
|
||||
|
||||
// we wouldn't be here if the queue is empty.
|
||||
CHECK(!qcopy.empty());
|
||||
|
||||
while (!qcopy.empty() && !stop_) {
|
||||
// Iterate through all the ops for performing the operations
|
||||
for (std::size_t i = 0; i < n_ops; ++i) {
|
||||
auto op = qcopy.front();
|
||||
qcopy.pop();
|
||||
|
||||
@ -81,20 +89,21 @@ Result Loop::EmptyQueue() {
|
||||
}
|
||||
|
||||
if (n_bytes_done == -1 && !system::LastErrorWouldBlock()) {
|
||||
stop_ = true;
|
||||
auto rc = system::FailWithCode("Invalid socket output.");
|
||||
error();
|
||||
return rc;
|
||||
}
|
||||
|
||||
op.off += n_bytes_done;
|
||||
CHECK_LE(op.off, op.n);
|
||||
|
||||
if (op.off != op.n) {
|
||||
// not yet finished, push back to queue for next round.
|
||||
queue_.push(op);
|
||||
qcopy.push(op);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
timer_.Stop(__func__);
|
||||
return Success();
|
||||
}
|
||||
@ -107,22 +116,42 @@ void Loop::Process() {
|
||||
if (stop_) {
|
||||
break;
|
||||
}
|
||||
CHECK(!mu_.try_lock());
|
||||
|
||||
this->rc_ = this->EmptyQueue();
|
||||
if (!rc_.OK()) {
|
||||
stop_ = true;
|
||||
auto unlock_notify = [&](bool is_blocking) {
|
||||
if (!is_blocking) {
|
||||
return;
|
||||
}
|
||||
lock.unlock();
|
||||
cv_.notify_one();
|
||||
break;
|
||||
};
|
||||
|
||||
// move the queue
|
||||
std::queue<Op> qcopy;
|
||||
bool is_blocking = false;
|
||||
while (!queue_.empty()) {
|
||||
auto op = queue_.front();
|
||||
queue_.pop();
|
||||
if (op.code == Op::kBlock) {
|
||||
is_blocking = true;
|
||||
} else {
|
||||
qcopy.push(op);
|
||||
}
|
||||
}
|
||||
// unblock the queue
|
||||
if (!is_blocking) {
|
||||
lock.unlock();
|
||||
}
|
||||
// clear the queue
|
||||
auto rc = this->EmptyQueue(&qcopy);
|
||||
// Handle error
|
||||
if (!rc.OK()) {
|
||||
this->rc_ = std::move(rc);
|
||||
unlock_notify(is_blocking);
|
||||
return;
|
||||
}
|
||||
|
||||
CHECK(queue_.empty());
|
||||
CHECK(!mu_.try_lock());
|
||||
cv_.notify_one();
|
||||
}
|
||||
|
||||
if (rc_.OK()) {
|
||||
CHECK(queue_.empty());
|
||||
CHECK(qcopy.empty());
|
||||
unlock_notify(is_blocking);
|
||||
}
|
||||
}
|
||||
|
||||
@ -140,6 +169,15 @@ Result Loop::Stop() {
|
||||
return Success();
|
||||
}
|
||||
|
||||
[[nodiscard]] Result Loop::Block() {
|
||||
this->Submit(Op{Op::kBlock});
|
||||
{
|
||||
std::unique_lock lock{mu_};
|
||||
cv_.wait(lock, [this] { return (this->queue_.empty()) || stop_; });
|
||||
}
|
||||
return std::move(rc_);
|
||||
}
|
||||
|
||||
Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} {
|
||||
timer_.Init(__func__);
|
||||
worker_ = std::thread{[this] {
|
||||
|
||||
@ -20,13 +20,14 @@ namespace xgboost::collective {
|
||||
class Loop {
|
||||
public:
|
||||
struct Op {
|
||||
enum Code : std::int8_t { kRead = 0, kWrite = 1 } code;
|
||||
enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2 } code;
|
||||
std::int32_t rank{-1};
|
||||
std::int8_t* ptr{nullptr};
|
||||
std::size_t n{0};
|
||||
TCPSocket* sock{nullptr};
|
||||
std::size_t off{0};
|
||||
|
||||
explicit Op(Code c) : code{c} { CHECK(c == kBlock); }
|
||||
Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off)
|
||||
: code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {}
|
||||
Op(Op const&) = default;
|
||||
@ -44,9 +45,9 @@ class Loop {
|
||||
Result rc_;
|
||||
bool stop_{false};
|
||||
std::exception_ptr curr_exce_{nullptr};
|
||||
common::Monitor timer_;
|
||||
common::Monitor mutable timer_;
|
||||
|
||||
Result EmptyQueue();
|
||||
Result EmptyQueue(std::queue<Op>* p_queue) const;
|
||||
void Process();
|
||||
|
||||
public:
|
||||
@ -60,15 +61,7 @@ class Loop {
|
||||
cv_.notify_one();
|
||||
}
|
||||
|
||||
[[nodiscard]] Result Block() {
|
||||
{
|
||||
std::unique_lock lock{mu_};
|
||||
cv_.notify_all();
|
||||
}
|
||||
std::unique_lock lock{mu_};
|
||||
cv_.wait(lock, [this] { return this->queue_.empty() || stop_; });
|
||||
return std::move(rc_);
|
||||
}
|
||||
[[nodiscard]] Result Block();
|
||||
|
||||
explicit Loop(std::chrono::seconds timeout);
|
||||
|
||||
|
||||
@ -18,31 +18,34 @@ class AllreduceWorker : public WorkerForTest {
|
||||
void Basic() {
|
||||
{
|
||||
std::vector<double> data(13, 0.0);
|
||||
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
|
||||
auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
|
||||
for (std::size_t i = 0; i < rhs.size(); ++i) {
|
||||
rhs[i] += lhs[i];
|
||||
}
|
||||
});
|
||||
ASSERT_TRUE(rc.OK());
|
||||
ASSERT_EQ(std::accumulate(data.cbegin(), data.cend(), 0.0), 0.0);
|
||||
}
|
||||
{
|
||||
std::vector<double> data(1, 1.0);
|
||||
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
|
||||
auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
|
||||
for (std::size_t i = 0; i < rhs.size(); ++i) {
|
||||
rhs[i] += lhs[i];
|
||||
}
|
||||
});
|
||||
ASSERT_TRUE(rc.OK());
|
||||
ASSERT_EQ(data[0], static_cast<double>(comm_.World()));
|
||||
}
|
||||
}
|
||||
|
||||
void Acc() {
|
||||
std::vector<double> data(314, 1.5);
|
||||
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
|
||||
auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
|
||||
for (std::size_t i = 0; i < rhs.size(); ++i) {
|
||||
rhs[i] += lhs[i];
|
||||
}
|
||||
});
|
||||
ASSERT_TRUE(rc.OK());
|
||||
for (std::size_t i = 0; i < data.size(); ++i) {
|
||||
auto v = data[i];
|
||||
ASSERT_EQ(v, 1.5 * static_cast<double>(comm_.World())) << i;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user