[coll] Reduce the scope of lock in the event loop. (#9784)
This commit is contained in:
parent
36a552ac98
commit
ada377c57e
@ -412,19 +412,24 @@ class TCPSocket {
|
|||||||
return Success();
|
return Success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetKeepAlive() {
|
[[nodiscard]] Result SetKeepAlive() {
|
||||||
std::int32_t keepalive = 1;
|
std::int32_t keepalive = 1;
|
||||||
xgboost_CHECK_SYS_CALL(setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE,
|
auto rc = setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast<char *>(&keepalive),
|
||||||
reinterpret_cast<char *>(&keepalive), sizeof(keepalive)),
|
sizeof(keepalive));
|
||||||
0);
|
if (rc != 0) {
|
||||||
|
return system::FailWithCode("Failed to set TCP keeaplive.");
|
||||||
|
}
|
||||||
|
return Success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetNoDelay() {
|
[[nodiscard]] Result SetNoDelay() {
|
||||||
std::int32_t tcp_no_delay = 1;
|
std::int32_t tcp_no_delay = 1;
|
||||||
xgboost_CHECK_SYS_CALL(
|
auto rc = setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char *>(&tcp_no_delay),
|
||||||
setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char *>(&tcp_no_delay),
|
sizeof(tcp_no_delay));
|
||||||
sizeof(tcp_no_delay)),
|
if (rc != 0) {
|
||||||
0);
|
return system::FailWithCode("Failed to set TCP no delay.");
|
||||||
|
}
|
||||||
|
return Success();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -417,9 +417,9 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
|||||||
utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket");
|
utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket");
|
||||||
// set the socket to non-blocking mode, enable TCP keepalive
|
// set the socket to non-blocking mode, enable TCP keepalive
|
||||||
CHECK(all_link.sock.NonBlocking(true).OK());
|
CHECK(all_link.sock.NonBlocking(true).OK());
|
||||||
all_link.sock.SetKeepAlive();
|
CHECK(all_link.sock.SetKeepAlive().OK());
|
||||||
if (rabit_enable_tcp_no_delay) {
|
if (rabit_enable_tcp_no_delay) {
|
||||||
all_link.sock.SetNoDelay();
|
CHECK(all_link.sock.SetNoDelay().OK());
|
||||||
}
|
}
|
||||||
if (tree_neighbors.count(all_link.rank) != 0) {
|
if (tree_neighbors.count(all_link.rank) != 0) {
|
||||||
if (all_link.rank == parent_rank) {
|
if (all_link.rank == parent_rank) {
|
||||||
|
|||||||
@ -6,6 +6,7 @@
|
|||||||
#include <algorithm> // for min
|
#include <algorithm> // for min
|
||||||
#include <cstddef> // for size_t
|
#include <cstddef> // for size_t
|
||||||
#include <cstdint> // for int32_t, int8_t
|
#include <cstdint> // for int32_t, int8_t
|
||||||
|
#include <utility> // for move
|
||||||
#include <vector> // for vector
|
#include <vector> // for vector
|
||||||
|
|
||||||
#include "../data/array_interface.h" // for Type, DispatchDType
|
#include "../data/array_interface.h" // for Type, DispatchDType
|
||||||
@ -47,7 +48,7 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
|
|||||||
auto seg = s_buf.subspan(0, recv_seg.size());
|
auto seg = s_buf.subspan(0, recv_seg.size());
|
||||||
|
|
||||||
prev_ch->RecvAll(seg);
|
prev_ch->RecvAll(seg);
|
||||||
auto rc = prev_ch->Block();
|
auto rc = comm.Block();
|
||||||
if (!rc.OK()) {
|
if (!rc.OK()) {
|
||||||
return rc;
|
return rc;
|
||||||
}
|
}
|
||||||
@ -83,11 +84,9 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func cons
|
|||||||
auto prev_ch = comm.Chan(prev);
|
auto prev_ch = comm.Chan(prev);
|
||||||
auto next_ch = comm.Chan(next);
|
auto next_ch = comm.Chan(next);
|
||||||
|
|
||||||
rc = RingAllgather(comm, data, n_bytes_in_seg, 1, prev_ch, next_ch);
|
return std::move(rc) << [&] {
|
||||||
if (!rc.OK()) {
|
return RingAllgather(comm, data, n_bytes_in_seg, 1, prev_ch, next_ch);
|
||||||
return rc;
|
} << [&] { return comm.Block(); };
|
||||||
}
|
|
||||||
return comm.Block();
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
} // namespace xgboost::collective::cpu_impl
|
} // namespace xgboost::collective::cpu_impl
|
||||||
|
|||||||
@ -33,19 +33,28 @@ Comm::Comm(std::string const& host, std::int32_t port, std::chrono::seconds time
|
|||||||
Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, std::int32_t retry,
|
Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, std::int32_t retry,
|
||||||
std::string const& task_id, TCPSocket* out, std::int32_t rank,
|
std::string const& task_id, TCPSocket* out, std::int32_t rank,
|
||||||
std::int32_t world) {
|
std::int32_t world) {
|
||||||
// get information from tracker
|
// Get information from the tracker
|
||||||
CHECK(!info.host.empty());
|
CHECK(!info.host.empty());
|
||||||
auto rc = Connect(info.host, info.port, retry, timeout, out);
|
|
||||||
if (!rc.OK()) {
|
|
||||||
return Fail("Failed to connect to the tracker.", std::move(rc));
|
|
||||||
}
|
|
||||||
|
|
||||||
TCPSocket& tracker = *out;
|
TCPSocket& tracker = *out;
|
||||||
return std::move(rc)
|
return Success() << [&] {
|
||||||
<< [&] { return tracker.NonBlocking(false); }
|
auto rc = Connect(info.host, info.port, retry, timeout, out);
|
||||||
<< [&] { return tracker.RecvTimeout(timeout); }
|
if (rc.OK()) {
|
||||||
<< [&] { return proto::Magic{}.Verify(&tracker); }
|
return rc;
|
||||||
<< [&] { return proto::Connect{}.WorkerSend(&tracker, world, rank, task_id); };
|
} else {
|
||||||
|
return Fail("Failed to connect to the tracker.", std::move(rc));
|
||||||
|
}
|
||||||
|
} << [&] {
|
||||||
|
return tracker.NonBlocking(false);
|
||||||
|
} << [&] {
|
||||||
|
return tracker.RecvTimeout(timeout);
|
||||||
|
} << [&] {
|
||||||
|
return proto::Magic{}.Verify(&tracker);
|
||||||
|
} << [&] {
|
||||||
|
return proto::Connect{}.WorkerSend(&tracker, world, rank, task_id);
|
||||||
|
} << [&] {
|
||||||
|
LOG(INFO) << "Task " << task_id << " connected to the tracker";
|
||||||
|
return Success();
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
[[nodiscard]] Result Comm::ConnectTracker(TCPSocket* out) const {
|
[[nodiscard]] Result Comm::ConnectTracker(TCPSocket* out) const {
|
||||||
@ -257,8 +266,8 @@ RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::se
|
|||||||
CHECK(this->channels_.empty());
|
CHECK(this->channels_.empty());
|
||||||
for (auto& w : workers) {
|
for (auto& w : workers) {
|
||||||
if (w) {
|
if (w) {
|
||||||
w->SetNoDelay();
|
rc = std::move(rc) << [&] { return w->SetNoDelay(); } << [&] { return w->NonBlocking(true); }
|
||||||
rc = w->NonBlocking(true);
|
<< [&] { return w->SetKeepAlive(); };
|
||||||
}
|
}
|
||||||
if (!rc.OK()) {
|
if (!rc.OK()) {
|
||||||
return rc;
|
return rc;
|
||||||
|
|||||||
@ -10,21 +10,26 @@
|
|||||||
#include "xgboost/logging.h" // for CHECK
|
#include "xgboost/logging.h" // for CHECK
|
||||||
|
|
||||||
namespace xgboost::collective {
|
namespace xgboost::collective {
|
||||||
Result Loop::EmptyQueue() {
|
Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
|
||||||
timer_.Start(__func__);
|
timer_.Start(__func__);
|
||||||
auto error = [this] {
|
auto error = [this] { timer_.Stop(__func__); };
|
||||||
this->stop_ = true;
|
|
||||||
|
if (stop_) {
|
||||||
timer_.Stop(__func__);
|
timer_.Stop(__func__);
|
||||||
};
|
return Success();
|
||||||
|
}
|
||||||
|
|
||||||
while (!queue_.empty() && !stop_) {
|
auto& qcopy = *p_queue;
|
||||||
std::queue<Op> qcopy;
|
|
||||||
|
// clear the copied queue
|
||||||
|
while (!qcopy.empty()) {
|
||||||
rabit::utils::PollHelper poll;
|
rabit::utils::PollHelper poll;
|
||||||
|
std::size_t n_ops = qcopy.size();
|
||||||
|
|
||||||
// watch all ops
|
// Iterate through all the ops for poll
|
||||||
while (!queue_.empty()) {
|
for (std::size_t i = 0; i < n_ops; ++i) {
|
||||||
auto op = queue_.front();
|
auto op = qcopy.front();
|
||||||
queue_.pop();
|
qcopy.pop();
|
||||||
|
|
||||||
switch (op.code) {
|
switch (op.code) {
|
||||||
case Op::kRead: {
|
case Op::kRead: {
|
||||||
@ -40,6 +45,7 @@ Result Loop::EmptyQueue() {
|
|||||||
return Fail("Invalid socket operation.");
|
return Fail("Invalid socket operation.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
qcopy.push(op);
|
qcopy.push(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -51,10 +57,12 @@ Result Loop::EmptyQueue() {
|
|||||||
error();
|
error();
|
||||||
return rc;
|
return rc;
|
||||||
}
|
}
|
||||||
|
|
||||||
// we wonldn't be here if the queue is empty.
|
// we wonldn't be here if the queue is empty.
|
||||||
CHECK(!qcopy.empty());
|
CHECK(!qcopy.empty());
|
||||||
|
|
||||||
while (!qcopy.empty() && !stop_) {
|
// Iterate through all the ops for performing the operations
|
||||||
|
for (std::size_t i = 0; i < n_ops; ++i) {
|
||||||
auto op = qcopy.front();
|
auto op = qcopy.front();
|
||||||
qcopy.pop();
|
qcopy.pop();
|
||||||
|
|
||||||
@ -81,20 +89,21 @@ Result Loop::EmptyQueue() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (n_bytes_done == -1 && !system::LastErrorWouldBlock()) {
|
if (n_bytes_done == -1 && !system::LastErrorWouldBlock()) {
|
||||||
stop_ = true;
|
|
||||||
auto rc = system::FailWithCode("Invalid socket output.");
|
auto rc = system::FailWithCode("Invalid socket output.");
|
||||||
error();
|
error();
|
||||||
return rc;
|
return rc;
|
||||||
}
|
}
|
||||||
|
|
||||||
op.off += n_bytes_done;
|
op.off += n_bytes_done;
|
||||||
CHECK_LE(op.off, op.n);
|
CHECK_LE(op.off, op.n);
|
||||||
|
|
||||||
if (op.off != op.n) {
|
if (op.off != op.n) {
|
||||||
// not yet finished, push back to queue for next round.
|
// not yet finished, push back to queue for next round.
|
||||||
queue_.push(op);
|
qcopy.push(op);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
timer_.Stop(__func__);
|
timer_.Stop(__func__);
|
||||||
return Success();
|
return Success();
|
||||||
}
|
}
|
||||||
@ -107,22 +116,42 @@ void Loop::Process() {
|
|||||||
if (stop_) {
|
if (stop_) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
CHECK(!mu_.try_lock());
|
|
||||||
|
|
||||||
this->rc_ = this->EmptyQueue();
|
auto unlock_notify = [&](bool is_blocking) {
|
||||||
if (!rc_.OK()) {
|
if (!is_blocking) {
|
||||||
stop_ = true;
|
return;
|
||||||
|
}
|
||||||
|
lock.unlock();
|
||||||
cv_.notify_one();
|
cv_.notify_one();
|
||||||
break;
|
};
|
||||||
|
|
||||||
|
// move the queue
|
||||||
|
std::queue<Op> qcopy;
|
||||||
|
bool is_blocking = false;
|
||||||
|
while (!queue_.empty()) {
|
||||||
|
auto op = queue_.front();
|
||||||
|
queue_.pop();
|
||||||
|
if (op.code == Op::kBlock) {
|
||||||
|
is_blocking = true;
|
||||||
|
} else {
|
||||||
|
qcopy.push(op);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// unblock the queue
|
||||||
|
if (!is_blocking) {
|
||||||
|
lock.unlock();
|
||||||
|
}
|
||||||
|
// clear the queue
|
||||||
|
auto rc = this->EmptyQueue(&qcopy);
|
||||||
|
// Handle error
|
||||||
|
if (!rc.OK()) {
|
||||||
|
this->rc_ = std::move(rc);
|
||||||
|
unlock_notify(is_blocking);
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
CHECK(queue_.empty());
|
CHECK(qcopy.empty());
|
||||||
CHECK(!mu_.try_lock());
|
unlock_notify(is_blocking);
|
||||||
cv_.notify_one();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (rc_.OK()) {
|
|
||||||
CHECK(queue_.empty());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -140,6 +169,15 @@ Result Loop::Stop() {
|
|||||||
return Success();
|
return Success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] Result Loop::Block() {
|
||||||
|
this->Submit(Op{Op::kBlock});
|
||||||
|
{
|
||||||
|
std::unique_lock lock{mu_};
|
||||||
|
cv_.wait(lock, [this] { return (this->queue_.empty()) || stop_; });
|
||||||
|
}
|
||||||
|
return std::move(rc_);
|
||||||
|
}
|
||||||
|
|
||||||
Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} {
|
Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} {
|
||||||
timer_.Init(__func__);
|
timer_.Init(__func__);
|
||||||
worker_ = std::thread{[this] {
|
worker_ = std::thread{[this] {
|
||||||
|
|||||||
@ -20,13 +20,14 @@ namespace xgboost::collective {
|
|||||||
class Loop {
|
class Loop {
|
||||||
public:
|
public:
|
||||||
struct Op {
|
struct Op {
|
||||||
enum Code : std::int8_t { kRead = 0, kWrite = 1 } code;
|
enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2 } code;
|
||||||
std::int32_t rank{-1};
|
std::int32_t rank{-1};
|
||||||
std::int8_t* ptr{nullptr};
|
std::int8_t* ptr{nullptr};
|
||||||
std::size_t n{0};
|
std::size_t n{0};
|
||||||
TCPSocket* sock{nullptr};
|
TCPSocket* sock{nullptr};
|
||||||
std::size_t off{0};
|
std::size_t off{0};
|
||||||
|
|
||||||
|
explicit Op(Code c) : code{c} { CHECK(c == kBlock); }
|
||||||
Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off)
|
Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off)
|
||||||
: code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {}
|
: code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {}
|
||||||
Op(Op const&) = default;
|
Op(Op const&) = default;
|
||||||
@ -44,9 +45,9 @@ class Loop {
|
|||||||
Result rc_;
|
Result rc_;
|
||||||
bool stop_{false};
|
bool stop_{false};
|
||||||
std::exception_ptr curr_exce_{nullptr};
|
std::exception_ptr curr_exce_{nullptr};
|
||||||
common::Monitor timer_;
|
common::Monitor mutable timer_;
|
||||||
|
|
||||||
Result EmptyQueue();
|
Result EmptyQueue(std::queue<Op>* p_queue) const;
|
||||||
void Process();
|
void Process();
|
||||||
|
|
||||||
public:
|
public:
|
||||||
@ -60,15 +61,7 @@ class Loop {
|
|||||||
cv_.notify_one();
|
cv_.notify_one();
|
||||||
}
|
}
|
||||||
|
|
||||||
[[nodiscard]] Result Block() {
|
[[nodiscard]] Result Block();
|
||||||
{
|
|
||||||
std::unique_lock lock{mu_};
|
|
||||||
cv_.notify_all();
|
|
||||||
}
|
|
||||||
std::unique_lock lock{mu_};
|
|
||||||
cv_.wait(lock, [this] { return this->queue_.empty() || stop_; });
|
|
||||||
return std::move(rc_);
|
|
||||||
}
|
|
||||||
|
|
||||||
explicit Loop(std::chrono::seconds timeout);
|
explicit Loop(std::chrono::seconds timeout);
|
||||||
|
|
||||||
|
|||||||
@ -18,31 +18,34 @@ class AllreduceWorker : public WorkerForTest {
|
|||||||
void Basic() {
|
void Basic() {
|
||||||
{
|
{
|
||||||
std::vector<double> data(13, 0.0);
|
std::vector<double> data(13, 0.0);
|
||||||
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
|
auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
|
||||||
for (std::size_t i = 0; i < rhs.size(); ++i) {
|
for (std::size_t i = 0; i < rhs.size(); ++i) {
|
||||||
rhs[i] += lhs[i];
|
rhs[i] += lhs[i];
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
ASSERT_TRUE(rc.OK());
|
||||||
ASSERT_EQ(std::accumulate(data.cbegin(), data.cend(), 0.0), 0.0);
|
ASSERT_EQ(std::accumulate(data.cbegin(), data.cend(), 0.0), 0.0);
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
std::vector<double> data(1, 1.0);
|
std::vector<double> data(1, 1.0);
|
||||||
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
|
auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
|
||||||
for (std::size_t i = 0; i < rhs.size(); ++i) {
|
for (std::size_t i = 0; i < rhs.size(); ++i) {
|
||||||
rhs[i] += lhs[i];
|
rhs[i] += lhs[i];
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
ASSERT_TRUE(rc.OK());
|
||||||
ASSERT_EQ(data[0], static_cast<double>(comm_.World()));
|
ASSERT_EQ(data[0], static_cast<double>(comm_.World()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Acc() {
|
void Acc() {
|
||||||
std::vector<double> data(314, 1.5);
|
std::vector<double> data(314, 1.5);
|
||||||
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
|
auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
|
||||||
for (std::size_t i = 0; i < rhs.size(); ++i) {
|
for (std::size_t i = 0; i < rhs.size(); ++i) {
|
||||||
rhs[i] += lhs[i];
|
rhs[i] += lhs[i];
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
ASSERT_TRUE(rc.OK());
|
||||||
for (std::size_t i = 0; i < data.size(); ++i) {
|
for (std::size_t i = 0; i < data.size(); ++i) {
|
||||||
auto v = data[i];
|
auto v = data[i];
|
||||||
ASSERT_EQ(v, 1.5 * static_cast<double>(comm_.World())) << i;
|
ASSERT_EQ(v, 1.5 * static_cast<double>(comm_.World())) << i;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user