[rabit] Improved connection handling. (#9531)
- Enable timeout. - Report connection error from the system. - Handle retry for both tracker connection and peer connection.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
/**
|
||||
* Copyright 2014-2023, XGBoost Contributors
|
||||
* \file allreduce_base.cc
|
||||
* \brief Basic implementation of AllReduce
|
||||
*
|
||||
@@ -9,9 +9,11 @@
|
||||
#define NOMINMAX
|
||||
#endif // !defined(NOMINMAX)
|
||||
|
||||
#include "allreduce_base.h"
|
||||
|
||||
#include "rabit/base.h"
|
||||
#include "rabit/internal/rabit-inl.h"
|
||||
#include "allreduce_base.h"
|
||||
#include "xgboost/collective/result.h"
|
||||
|
||||
#ifndef _WIN32
|
||||
#include <netinet/tcp.h>
|
||||
@@ -20,8 +22,7 @@
|
||||
#include <cstring>
|
||||
#include <map>
|
||||
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
namespace rabit::engine {
|
||||
// constructor
|
||||
AllreduceBase::AllreduceBase() {
|
||||
tracker_uri = "NULL";
|
||||
@@ -116,7 +117,12 @@ bool AllreduceBase::Init(int argc, char* argv[]) {
|
||||
utils::Assert(all_links.size() == 0, "can only call Init once");
|
||||
this->host_uri = xgboost::collective::GetHostName();
|
||||
// get information from tracker
|
||||
return this->ReConnectLinks();
|
||||
auto rc = this->ReConnectLinks();
|
||||
if (rc.OK()) {
|
||||
return true;
|
||||
}
|
||||
LOG(FATAL) << rc.Report();
|
||||
return false;
|
||||
}
|
||||
|
||||
bool AllreduceBase::Shutdown() {
|
||||
@@ -131,7 +137,11 @@ bool AllreduceBase::Shutdown() {
|
||||
|
||||
if (tracker_uri == "NULL") return true;
|
||||
// notify tracker rank i have shutdown
|
||||
xgboost::collective::TCPSocket tracker = this->ConnectTracker();
|
||||
xgboost::collective::TCPSocket tracker;
|
||||
auto rc = this->ConnectTracker(&tracker);
|
||||
if (!rc.OK()) {
|
||||
LOG(FATAL) << rc.Report();
|
||||
}
|
||||
tracker.Send(xgboost::StringView{"shutdown"});
|
||||
tracker.Close();
|
||||
xgboost::system::SocketFinalize();
|
||||
@@ -146,7 +156,12 @@ void AllreduceBase::TrackerPrint(const std::string &msg) {
|
||||
if (tracker_uri == "NULL") {
|
||||
utils::Printf("%s", msg.c_str()); return;
|
||||
}
|
||||
xgboost::collective::TCPSocket tracker = this->ConnectTracker();
|
||||
xgboost::collective::TCPSocket tracker;
|
||||
auto rc = this->ConnectTracker(&tracker);
|
||||
if (!rc.OK()) {
|
||||
LOG(FATAL) << rc.Report();
|
||||
}
|
||||
|
||||
tracker.Send(xgboost::StringView{"print"});
|
||||
tracker.Send(xgboost::StringView{msg});
|
||||
tracker.Close();
|
||||
@@ -215,64 +230,67 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief initialize connection to the tracker
|
||||
* \return a socket that initializes the connection
|
||||
*/
|
||||
xgboost::collective::TCPSocket AllreduceBase::ConnectTracker() const {
|
||||
[[nodiscard]] xgboost::collective::Result AllreduceBase::ConnectTracker(
|
||||
xgboost::collective::TCPSocket *out) const {
|
||||
int magic = kMagic;
|
||||
// get information from tracker
|
||||
xgboost::collective::TCPSocket tracker;
|
||||
xgboost::collective::TCPSocket &tracker = *out;
|
||||
|
||||
int retry = 0;
|
||||
do {
|
||||
auto rc = xgboost::collective::Connect(
|
||||
xgboost::collective::MakeSockAddress(xgboost::StringView{tracker_uri}, tracker_port),
|
||||
&tracker);
|
||||
if (rc != std::errc()) {
|
||||
if (++retry >= connect_retry) {
|
||||
LOG(FATAL) << "Connecting to (failed): [" << tracker_uri << "]\n" << rc.message();
|
||||
} else {
|
||||
LOG(WARNING) << rc.message() << "\nRetry connecting to IP(retry time: " << retry << "): ["
|
||||
<< tracker_uri << "]";
|
||||
#if defined(_MSC_VER) || defined(__MINGW32__)
|
||||
Sleep(retry << 1);
|
||||
#else
|
||||
sleep(retry << 1);
|
||||
#endif
|
||||
continue;
|
||||
}
|
||||
}
|
||||
break;
|
||||
} while (true);
|
||||
auto rc =
|
||||
Connect(xgboost::StringView{tracker_uri}, tracker_port, connect_retry, timeout_sec, &tracker);
|
||||
if (!rc.OK()) {
|
||||
return xgboost::collective::Fail("Failed to connect to the tracker.", std::move(rc));
|
||||
}
|
||||
|
||||
using utils::Assert;
|
||||
CHECK_EQ(tracker.SendAll(&magic, sizeof(magic)), sizeof(magic));
|
||||
CHECK_EQ(tracker.RecvAll(&magic, sizeof(magic)), sizeof(magic));
|
||||
utils::Check(magic == kMagic, "sync::Invalid tracker message, init failure");
|
||||
Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
|
||||
Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size),
|
||||
"ReConnectLink failure 3");
|
||||
CHECK_EQ(tracker.Send(xgboost::StringView{task_id}), task_id.size());
|
||||
return tracker;
|
||||
if (tracker.SendAll(&magic, sizeof(magic)) != sizeof(magic)) {
|
||||
return xgboost::collective::Fail("Failed to send the verification number.");
|
||||
}
|
||||
if (tracker.RecvAll(&magic, sizeof(magic)) != sizeof(magic)) {
|
||||
return xgboost::collective::Fail("Failed to recieve the verification number.");
|
||||
}
|
||||
if (magic != kMagic) {
|
||||
return xgboost::collective::Fail("Invalid verification number.");
|
||||
}
|
||||
if (tracker.SendAll(&rank, sizeof(rank)) != sizeof(rank)) {
|
||||
return xgboost::collective::Fail("Failed to send the local rank back to the tracker.");
|
||||
}
|
||||
if (tracker.SendAll(&world_size, sizeof(world_size)) != sizeof(world_size)) {
|
||||
return xgboost::collective::Fail("Failed to send the world size back to the tracker.");
|
||||
}
|
||||
if (tracker.Send(xgboost::StringView{task_id}) != task_id.size()) {
|
||||
return xgboost::collective::Fail("Failed to send the task ID back to the tracker.");
|
||||
}
|
||||
|
||||
return xgboost::collective::Success();
|
||||
}
|
||||
/*!
|
||||
* \brief connect to the tracker to fix the the missing links
|
||||
* this function is also used when the engine start up
|
||||
*/
|
||||
bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
[[nodiscard]] xgboost::collective::Result AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
// single node mode
|
||||
if (tracker_uri == "NULL") {
|
||||
rank = 0;
|
||||
world_size = 1;
|
||||
return true;
|
||||
return xgboost::collective::Success();
|
||||
}
|
||||
|
||||
try {
|
||||
xgboost::collective::TCPSocket tracker = this->ConnectTracker();
|
||||
LOG(INFO) << "task " << task_id << " connected to the tracker";
|
||||
tracker.Send(xgboost::StringView{cmd});
|
||||
xgboost::collective::TCPSocket tracker;
|
||||
auto rc = this->ConnectTracker(&tracker);
|
||||
if (!rc.OK()) {
|
||||
return xgboost::collective::Fail("Failed to connect to the tracker.", std::move(rc));
|
||||
}
|
||||
|
||||
LOG(INFO) << "task " << task_id << " connected to the tracker";
|
||||
tracker.Send(xgboost::StringView{cmd});
|
||||
|
||||
try {
|
||||
// the rank of previous link, next link in ring
|
||||
int prev_rank, next_rank;
|
||||
// the rank of neighbors
|
||||
@@ -334,10 +352,10 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
tracker.Recv(&hname);
|
||||
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9");
|
||||
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10");
|
||||
|
||||
if (xgboost::collective::Connect(
|
||||
xgboost::collective::MakeSockAddress(xgboost::StringView{hname}, hport), &r.sock) !=
|
||||
std::errc{}) {
|
||||
// connect to peer
|
||||
if (!xgboost::collective::Connect(xgboost::StringView{hname}, hport, connect_retry,
|
||||
timeout_sec, &r.sock)
|
||||
.OK()) {
|
||||
num_error += 1;
|
||||
r.sock.Close();
|
||||
continue;
|
||||
@@ -351,8 +369,7 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
bool match = false;
|
||||
for (auto & all_link : all_links) {
|
||||
if (all_link.rank == hrank) {
|
||||
Assert(all_link.sock.IsClosed(),
|
||||
"Override a link that is active");
|
||||
Assert(all_link.sock.IsClosed(), "Override a link that is active");
|
||||
all_link.sock = std::move(r.sock);
|
||||
match = true;
|
||||
break;
|
||||
@@ -364,10 +381,10 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
"ReConnectLink failure 14");
|
||||
} while (num_error != 0);
|
||||
// send back socket listening port to tracker
|
||||
Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port),
|
||||
"ReConnectLink failure 14");
|
||||
Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port), "ReConnectLink failure 14");
|
||||
// close connection to tracker
|
||||
tracker.Close();
|
||||
|
||||
// listen to incoming links
|
||||
for (int i = 0; i < num_accept; ++i) {
|
||||
LinkRecord r;
|
||||
@@ -395,7 +412,7 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
for (auto &all_link : all_links) {
|
||||
utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket");
|
||||
// set the socket to non-blocking mode, enable TCP keepalive
|
||||
all_link.sock.SetNonBlock();
|
||||
all_link.sock.SetNonBlock(true);
|
||||
all_link.sock.SetKeepAlive();
|
||||
if (rabit_enable_tcp_no_delay) {
|
||||
all_link.sock.SetNoDelay();
|
||||
@@ -415,10 +432,11 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
"cannot find prev ring in the link");
|
||||
Assert(next_rank == -1 || ring_next != nullptr,
|
||||
"cannot find next ring in the link");
|
||||
return true;
|
||||
return xgboost::collective::Success();
|
||||
} catch (const std::exception& e) {
|
||||
LOG(WARNING) << "failed in ReconnectLink " << e.what();
|
||||
return false;
|
||||
std::stringstream ss;
|
||||
ss << "Failed in ReconnectLink " << e.what();
|
||||
return xgboost::collective::Fail(ss.str());
|
||||
}
|
||||
}
|
||||
/*!
|
||||
@@ -523,9 +541,15 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
|
||||
}
|
||||
}
|
||||
// finish running allreduce
|
||||
if (finished) break;
|
||||
if (finished) {
|
||||
break;
|
||||
}
|
||||
// select must return
|
||||
watcher.Poll(timeout_sec);
|
||||
auto poll_res = watcher.Poll(timeout_sec);
|
||||
if (!poll_res.OK()) {
|
||||
LOG(FATAL) << poll_res.Report();
|
||||
}
|
||||
|
||||
// read data from childs
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i != parent_index && watcher.CheckRead(links[i].sock)) {
|
||||
@@ -698,7 +722,10 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||
// finish running
|
||||
if (finished) break;
|
||||
// select
|
||||
watcher.Poll(timeout_sec);
|
||||
auto poll_res = watcher.Poll(timeout_sec);
|
||||
if (!poll_res.OK()) {
|
||||
LOG(FATAL) << poll_res.Report();
|
||||
}
|
||||
if (in_link == -2) {
|
||||
// probe in-link
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
@@ -780,8 +807,14 @@ AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
|
||||
}
|
||||
finished = false;
|
||||
}
|
||||
if (finished) break;
|
||||
watcher.Poll(timeout_sec);
|
||||
if (finished) {
|
||||
break;
|
||||
}
|
||||
|
||||
auto poll_res = watcher.Poll(timeout_sec);
|
||||
if (!poll_res.OK()) {
|
||||
LOG(FATAL) << poll_res.Report();
|
||||
}
|
||||
if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
|
||||
size_t size = stop_read - read_ptr;
|
||||
size_t start = read_ptr % total_size;
|
||||
@@ -880,8 +913,13 @@ AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_,
|
||||
}
|
||||
finished = false;
|
||||
}
|
||||
if (finished) break;
|
||||
watcher.Poll(timeout_sec);
|
||||
if (finished) {
|
||||
break;
|
||||
}
|
||||
auto poll_res = watcher.Poll(timeout_sec);
|
||||
if (!poll_res.OK()) {
|
||||
LOG(FATAL) << poll_res.Report();
|
||||
}
|
||||
if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
|
||||
ReturnType ret = next.ReadToRingBuffer(reduce_ptr, stop_read);
|
||||
if (ret != kSuccess) {
|
||||
@@ -953,5 +991,4 @@ AllreduceBase::TryAllreduceRing(void *sendrecvbuf_,
|
||||
(std::min((prank + 1) * step, count) -
|
||||
std::min(prank * step, count)) * type_nbytes);
|
||||
}
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
} // namespace rabit::engine
|
||||
|
||||
@@ -12,14 +12,16 @@
|
||||
#ifndef RABIT_ALLREDUCE_BASE_H_
|
||||
#define RABIT_ALLREDUCE_BASE_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <future>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include "rabit/internal/utils.h"
|
||||
#include <vector>
|
||||
|
||||
#include "rabit/internal/engine.h"
|
||||
#include "rabit/internal/socket.h"
|
||||
#include "rabit/internal/utils.h"
|
||||
#include "xgboost/collective/result.h"
|
||||
|
||||
#ifdef RABIT_CXXTESTDEFS_H
|
||||
#define private public
|
||||
@@ -329,13 +331,13 @@ class AllreduceBase : public IEngine {
|
||||
* \brief initialize connection to the tracker
|
||||
* \return a socket that initializes the connection
|
||||
*/
|
||||
xgboost::collective::TCPSocket ConnectTracker() const;
|
||||
[[nodiscard]] xgboost::collective::Result ConnectTracker(xgboost::collective::TCPSocket *out) const;
|
||||
/*!
|
||||
* \brief connect to the tracker to fix the the missing links
|
||||
* this function is also used when the engine start up
|
||||
* \param cmd possible command to sent to tracker
|
||||
*/
|
||||
bool ReConnectLinks(const char *cmd = "start");
|
||||
[[nodiscard]] xgboost::collective::Result ReConnectLinks(const char *cmd = "start");
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
|
||||
*
|
||||
|
||||
Reference in New Issue
Block a user