[rabit] Improved connection handling. (#9531)

- Enable timeout.
- Report connection error from the system.
- Handle retry for both tracker connection and peer connection.
This commit is contained in:
Jiaming Yuan
2023-08-30 13:00:04 +08:00
committed by GitHub
parent 2462e22cd4
commit ccfc90e4c6
10 changed files with 463 additions and 130 deletions

View File

@@ -1,5 +1,5 @@
/*!
* Copyright (c) 2014 by Contributors
/**
* Copyright 2014-2023, XGBoost Contributors
* \file allreduce_base.cc
* \brief Basic implementation of AllReduce
*
@@ -9,9 +9,11 @@
#define NOMINMAX
#endif // !defined(NOMINMAX)
#include "allreduce_base.h"
#include "rabit/base.h"
#include "rabit/internal/rabit-inl.h"
#include "allreduce_base.h"
#include "xgboost/collective/result.h"
#ifndef _WIN32
#include <netinet/tcp.h>
@@ -20,8 +22,7 @@
#include <cstring>
#include <map>
namespace rabit {
namespace engine {
namespace rabit::engine {
// constructor
AllreduceBase::AllreduceBase() {
tracker_uri = "NULL";
@@ -116,7 +117,12 @@ bool AllreduceBase::Init(int argc, char* argv[]) {
utils::Assert(all_links.size() == 0, "can only call Init once");
this->host_uri = xgboost::collective::GetHostName();
// get information from tracker
return this->ReConnectLinks();
auto rc = this->ReConnectLinks();
if (rc.OK()) {
return true;
}
LOG(FATAL) << rc.Report();
return false;
}
bool AllreduceBase::Shutdown() {
@@ -131,7 +137,11 @@ bool AllreduceBase::Shutdown() {
if (tracker_uri == "NULL") return true;
// notify tracker rank i have shutdown
xgboost::collective::TCPSocket tracker = this->ConnectTracker();
xgboost::collective::TCPSocket tracker;
auto rc = this->ConnectTracker(&tracker);
if (!rc.OK()) {
LOG(FATAL) << rc.Report();
}
tracker.Send(xgboost::StringView{"shutdown"});
tracker.Close();
xgboost::system::SocketFinalize();
@@ -146,7 +156,12 @@ void AllreduceBase::TrackerPrint(const std::string &msg) {
if (tracker_uri == "NULL") {
utils::Printf("%s", msg.c_str()); return;
}
xgboost::collective::TCPSocket tracker = this->ConnectTracker();
xgboost::collective::TCPSocket tracker;
auto rc = this->ConnectTracker(&tracker);
if (!rc.OK()) {
LOG(FATAL) << rc.Report();
}
tracker.Send(xgboost::StringView{"print"});
tracker.Send(xgboost::StringView{msg});
tracker.Close();
@@ -215,64 +230,67 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
}
}
}
/*!
* \brief initialize connection to the tracker
* \return a socket that initializes the connection
*/
xgboost::collective::TCPSocket AllreduceBase::ConnectTracker() const {
[[nodiscard]] xgboost::collective::Result AllreduceBase::ConnectTracker(
xgboost::collective::TCPSocket *out) const {
int magic = kMagic;
// get information from tracker
xgboost::collective::TCPSocket tracker;
xgboost::collective::TCPSocket &tracker = *out;
int retry = 0;
do {
auto rc = xgboost::collective::Connect(
xgboost::collective::MakeSockAddress(xgboost::StringView{tracker_uri}, tracker_port),
&tracker);
if (rc != std::errc()) {
if (++retry >= connect_retry) {
LOG(FATAL) << "Connecting to (failed): [" << tracker_uri << "]\n" << rc.message();
} else {
LOG(WARNING) << rc.message() << "\nRetry connecting to IP(retry time: " << retry << "): ["
<< tracker_uri << "]";
#if defined(_MSC_VER) || defined(__MINGW32__)
Sleep(retry << 1);
#else
sleep(retry << 1);
#endif
continue;
}
}
break;
} while (true);
auto rc =
Connect(xgboost::StringView{tracker_uri}, tracker_port, connect_retry, timeout_sec, &tracker);
if (!rc.OK()) {
return xgboost::collective::Fail("Failed to connect to the tracker.", std::move(rc));
}
using utils::Assert;
CHECK_EQ(tracker.SendAll(&magic, sizeof(magic)), sizeof(magic));
CHECK_EQ(tracker.RecvAll(&magic, sizeof(magic)), sizeof(magic));
utils::Check(magic == kMagic, "sync::Invalid tracker message, init failure");
Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size),
"ReConnectLink failure 3");
CHECK_EQ(tracker.Send(xgboost::StringView{task_id}), task_id.size());
return tracker;
if (tracker.SendAll(&magic, sizeof(magic)) != sizeof(magic)) {
return xgboost::collective::Fail("Failed to send the verification number.");
}
if (tracker.RecvAll(&magic, sizeof(magic)) != sizeof(magic)) {
return xgboost::collective::Fail("Failed to recieve the verification number.");
}
if (magic != kMagic) {
return xgboost::collective::Fail("Invalid verification number.");
}
if (tracker.SendAll(&rank, sizeof(rank)) != sizeof(rank)) {
return xgboost::collective::Fail("Failed to send the local rank back to the tracker.");
}
if (tracker.SendAll(&world_size, sizeof(world_size)) != sizeof(world_size)) {
return xgboost::collective::Fail("Failed to send the world size back to the tracker.");
}
if (tracker.Send(xgboost::StringView{task_id}) != task_id.size()) {
return xgboost::collective::Fail("Failed to send the task ID back to the tracker.");
}
return xgboost::collective::Success();
}
/*!
* \brief connect to the tracker to fix the the missing links
* this function is also used when the engine start up
*/
bool AllreduceBase::ReConnectLinks(const char *cmd) {
[[nodiscard]] xgboost::collective::Result AllreduceBase::ReConnectLinks(const char *cmd) {
// single node mode
if (tracker_uri == "NULL") {
rank = 0;
world_size = 1;
return true;
return xgboost::collective::Success();
}
try {
xgboost::collective::TCPSocket tracker = this->ConnectTracker();
LOG(INFO) << "task " << task_id << " connected to the tracker";
tracker.Send(xgboost::StringView{cmd});
xgboost::collective::TCPSocket tracker;
auto rc = this->ConnectTracker(&tracker);
if (!rc.OK()) {
return xgboost::collective::Fail("Failed to connect to the tracker.", std::move(rc));
}
LOG(INFO) << "task " << task_id << " connected to the tracker";
tracker.Send(xgboost::StringView{cmd});
try {
// the rank of previous link, next link in ring
int prev_rank, next_rank;
// the rank of neighbors
@@ -334,10 +352,10 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
tracker.Recv(&hname);
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9");
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10");
if (xgboost::collective::Connect(
xgboost::collective::MakeSockAddress(xgboost::StringView{hname}, hport), &r.sock) !=
std::errc{}) {
// connect to peer
if (!xgboost::collective::Connect(xgboost::StringView{hname}, hport, connect_retry,
timeout_sec, &r.sock)
.OK()) {
num_error += 1;
r.sock.Close();
continue;
@@ -351,8 +369,7 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
bool match = false;
for (auto & all_link : all_links) {
if (all_link.rank == hrank) {
Assert(all_link.sock.IsClosed(),
"Override a link that is active");
Assert(all_link.sock.IsClosed(), "Override a link that is active");
all_link.sock = std::move(r.sock);
match = true;
break;
@@ -364,10 +381,10 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
"ReConnectLink failure 14");
} while (num_error != 0);
// send back socket listening port to tracker
Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port),
"ReConnectLink failure 14");
Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port), "ReConnectLink failure 14");
// close connection to tracker
tracker.Close();
// listen to incoming links
for (int i = 0; i < num_accept; ++i) {
LinkRecord r;
@@ -395,7 +412,7 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
for (auto &all_link : all_links) {
utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket");
// set the socket to non-blocking mode, enable TCP keepalive
all_link.sock.SetNonBlock();
all_link.sock.SetNonBlock(true);
all_link.sock.SetKeepAlive();
if (rabit_enable_tcp_no_delay) {
all_link.sock.SetNoDelay();
@@ -415,10 +432,11 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
"cannot find prev ring in the link");
Assert(next_rank == -1 || ring_next != nullptr,
"cannot find next ring in the link");
return true;
return xgboost::collective::Success();
} catch (const std::exception& e) {
LOG(WARNING) << "failed in ReconnectLink " << e.what();
return false;
std::stringstream ss;
ss << "Failed in ReconnectLink " << e.what();
return xgboost::collective::Fail(ss.str());
}
}
/*!
@@ -523,9 +541,15 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
}
}
// finish running allreduce
if (finished) break;
if (finished) {
break;
}
// select must return
watcher.Poll(timeout_sec);
auto poll_res = watcher.Poll(timeout_sec);
if (!poll_res.OK()) {
LOG(FATAL) << poll_res.Report();
}
// read data from childs
for (int i = 0; i < nlink; ++i) {
if (i != parent_index && watcher.CheckRead(links[i].sock)) {
@@ -698,7 +722,10 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
// finish running
if (finished) break;
// select
watcher.Poll(timeout_sec);
auto poll_res = watcher.Poll(timeout_sec);
if (!poll_res.OK()) {
LOG(FATAL) << poll_res.Report();
}
if (in_link == -2) {
// probe in-link
for (int i = 0; i < nlink; ++i) {
@@ -780,8 +807,14 @@ AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
}
finished = false;
}
if (finished) break;
watcher.Poll(timeout_sec);
if (finished) {
break;
}
auto poll_res = watcher.Poll(timeout_sec);
if (!poll_res.OK()) {
LOG(FATAL) << poll_res.Report();
}
if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
size_t size = stop_read - read_ptr;
size_t start = read_ptr % total_size;
@@ -880,8 +913,13 @@ AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_,
}
finished = false;
}
if (finished) break;
watcher.Poll(timeout_sec);
if (finished) {
break;
}
auto poll_res = watcher.Poll(timeout_sec);
if (!poll_res.OK()) {
LOG(FATAL) << poll_res.Report();
}
if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
ReturnType ret = next.ReadToRingBuffer(reduce_ptr, stop_read);
if (ret != kSuccess) {
@@ -953,5 +991,4 @@ AllreduceBase::TryAllreduceRing(void *sendrecvbuf_,
(std::min((prank + 1) * step, count) -
std::min(prank * step, count)) * type_nbytes);
}
} // namespace engine
} // namespace rabit
} // namespace rabit::engine

View File

@@ -12,14 +12,16 @@
#ifndef RABIT_ALLREDUCE_BASE_H_
#define RABIT_ALLREDUCE_BASE_H_
#include <algorithm>
#include <functional>
#include <future>
#include <vector>
#include <string>
#include <algorithm>
#include "rabit/internal/utils.h"
#include <vector>
#include "rabit/internal/engine.h"
#include "rabit/internal/socket.h"
#include "rabit/internal/utils.h"
#include "xgboost/collective/result.h"
#ifdef RABIT_CXXTESTDEFS_H
#define private public
@@ -329,13 +331,13 @@ class AllreduceBase : public IEngine {
* \brief initialize connection to the tracker
* \return a socket that initializes the connection
*/
xgboost::collective::TCPSocket ConnectTracker() const;
[[nodiscard]] xgboost::collective::Result ConnectTracker(xgboost::collective::TCPSocket *out) const;
/*!
* \brief connect to the tracker to fix the the missing links
* this function is also used when the engine start up
* \param cmd possible command to sent to tracker
*/
bool ReConnectLinks(const char *cmd = "start");
[[nodiscard]] xgboost::collective::Result ReConnectLinks(const char *cmd = "start");
/*!
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
*