Initial support for IPv6 (#8225)

- Merge rabit socket into XGBoost.
- Dask interface support.
- Add test to the socket.
This commit is contained in:
Jiaming Yuan
2022-09-21 18:06:50 +08:00
committed by GitHub
parent 7d43e74e71
commit b791446623
17 changed files with 924 additions and 595 deletions

View File

@@ -27,8 +27,6 @@ AllreduceBase::AllreduceBase() {
tracker_uri = "NULL";
tracker_port = 9000;
host_uri = "";
slave_port = 9010;
nport_trial = 1000;
rank = 0;
world_size = -1;
connect_retry = 5;
@@ -114,16 +112,16 @@ bool AllreduceBase::Init(int argc, char* argv[]) {
this->rank = -1;
//---------------------
// start socket
utils::Socket::Startup();
xgboost::system::SocketStartup();
utils::Assert(all_links.size() == 0, "can only call Init once");
this->host_uri = utils::SockAddr::GetHostName();
this->host_uri = xgboost::collective::GetHostName();
// get information from tracker
return this->ReConnectLinks();
}
bool AllreduceBase::Shutdown() {
try {
for (auto & all_link : all_links) {
for (auto &all_link : all_links) {
if (!all_link.sock.IsClosed()) {
all_link.sock.Close();
}
@@ -133,12 +131,12 @@ bool AllreduceBase::Shutdown() {
if (tracker_uri == "NULL") return true;
// notify tracker rank i have shutdown
utils::TCPSocket tracker = this->ConnectTracker();
tracker.SendStr(std::string("shutdown"));
xgboost::collective::TCPSocket tracker = this->ConnectTracker();
tracker.Send(xgboost::StringView{"shutdown"});
tracker.Close();
utils::TCPSocket::Finalize();
xgboost::system::SocketFinalize();
return true;
} catch (const std::exception& e) {
} catch (std::exception const &e) {
LOG(WARNING) << "Failed to shutdown due to" << e.what();
return false;
}
@@ -148,9 +146,9 @@ void AllreduceBase::TrackerPrint(const std::string &msg) {
if (tracker_uri == "NULL") {
utils::Printf("%s", msg.c_str()); return;
}
utils::TCPSocket tracker = this->ConnectTracker();
tracker.SendStr(std::string("print"));
tracker.SendStr(msg);
xgboost::collective::TCPSocket tracker = this->ConnectTracker();
tracker.Send(xgboost::StringView{"print"});
tracker.Send(xgboost::StringView{msg});
tracker.Close();
}
@@ -227,21 +225,23 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
* \brief initialize connection to the tracker
* \return a socket that initializes the connection
*/
utils::TCPSocket AllreduceBase::ConnectTracker() const {
xgboost::collective::TCPSocket AllreduceBase::ConnectTracker() const {
int magic = kMagic;
// get information from tracker
utils::TCPSocket tracker;
tracker.Create();
xgboost::collective::TCPSocket tracker;
int retry = 0;
do {
if (!tracker.Connect(utils::SockAddr(tracker_uri.c_str(), tracker_port))) {
auto rc = xgboost::collective::Connect(
xgboost::collective::MakeSockAddress(xgboost::StringView{tracker_uri}, tracker_port),
&tracker);
if (rc != std::errc()) {
if (++retry >= connect_retry) {
LOG(WARNING) << "Connect to (failed): [" << tracker_uri << "]\n";
utils::Socket::Error("Connect");
LOG(FATAL) << "Connecting to (failed): [" << tracker_uri << "]\n" << rc.message();
} else {
LOG(WARNING) << "Retry connect to ip(retry time " << retry << "): [" << tracker_uri << "]\n";
#if defined(_MSC_VER) || defined (__MINGW32__)
LOG(WARNING) << rc.message() << "\nRetry connecting to IP(retry time: " << retry << "): ["
<< tracker_uri << "]";
#if defined(_MSC_VER) || defined(__MINGW32__)
Sleep(retry << 1);
#else
sleep(retry << 1);
@@ -253,16 +253,13 @@ utils::TCPSocket AllreduceBase::ConnectTracker() const {
} while (true);
using utils::Assert;
Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic),
"ReConnectLink failure 1");
Assert(tracker.RecvAll(&magic, sizeof(magic)) == sizeof(magic),
"ReConnectLink failure 2");
CHECK_EQ(tracker.SendAll(&magic, sizeof(magic)), sizeof(magic));
CHECK_EQ(tracker.RecvAll(&magic, sizeof(magic)), sizeof(magic));
utils::Check(magic == kMagic, "sync::Invalid tracker message, init failure");
Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank),
"ReConnectLink failure 3");
Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size),
"ReConnectLink failure 3");
tracker.SendStr(task_id);
CHECK_EQ(tracker.Send(xgboost::StringView{task_id}), task_id.size());
return tracker;
}
/*!
@@ -272,12 +269,15 @@ utils::TCPSocket AllreduceBase::ConnectTracker() const {
bool AllreduceBase::ReConnectLinks(const char *cmd) {
// single node mode
if (tracker_uri == "NULL") {
rank = 0; world_size = 1; return true;
rank = 0;
world_size = 1;
return true;
}
try {
utils::TCPSocket tracker = this->ConnectTracker();
xgboost::collective::TCPSocket tracker = this->ConnectTracker();
LOG(INFO) << "task " << task_id << " connected to the tracker";
tracker.SendStr(std::string(cmd));
tracker.Send(xgboost::StringView{cmd});
// the rank of previous link, next link in ring
int prev_rank, next_rank;
@@ -315,13 +315,9 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank),
"ReConnectLink failure 4");
utils::TCPSocket sock_listen;
if (!sock_listen.IsClosed()) {
sock_listen.Close();
}
auto sock_listen{xgboost::collective::TCPSocket::Create(tracker.Domain())};
// create listening socket
sock_listen.Create();
int port = sock_listen.TryBindHost(slave_port, slave_port + nport_trial);
int port = sock_listen.BindHost();
utils::Check(port != -1, "ReConnectLink fail to bind the ports specified");
sock_listen.Listen();
@@ -338,29 +334,27 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
}
}
int ngood = static_cast<int>(good_link.size());
Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
"ReConnectLink failure 5");
for (int & i : good_link) {
Assert(tracker.SendAll(&i, sizeof(i)) == \
sizeof(i), "ReConnectLink failure 6");
// tracker construct goodset
Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood), "ReConnectLink failure 5");
for (int &i : good_link) {
Assert(tracker.SendAll(&i, sizeof(i)) == sizeof(i), "ReConnectLink failure 6");
}
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
"ReConnectLink failure 7");
Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == \
sizeof(num_accept), "ReConnectLink failure 8");
Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == sizeof(num_accept),
"ReConnectLink failure 8");
num_error = 0;
for (int i = 0; i < num_conn; ++i) {
LinkRecord r;
int hport, hrank;
std::string hname;
tracker.RecvStr(&hname);
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport),
"ReConnectLink failure 9");
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank),
"ReConnectLink failure 10");
tracker.Recv(&hname);
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9");
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10");
r.sock.Create();
if (!r.sock.Connect(utils::SockAddr(hname.c_str(), hport))) {
if (xgboost::collective::Connect(
xgboost::collective::MakeSockAddress(xgboost::StringView{hname}, hport), &r.sock) !=
std::errc{}) {
num_error += 1;
r.sock.Close();
continue;
@@ -376,12 +370,12 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
if (all_link.rank == hrank) {
Assert(all_link.sock.IsClosed(),
"Override a link that is active");
all_link.sock = r.sock;
all_link.sock = std::move(r.sock);
match = true;
break;
}
}
if (!match) all_links.push_back(r);
if (!match) all_links.emplace_back(std::move(r));
}
Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error),
"ReConnectLink failure 14");
@@ -404,30 +398,24 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
if (all_link.rank == r.rank) {
utils::Assert(all_link.sock.IsClosed(),
"Override a link that is active");
all_link.sock = r.sock;
all_link.sock = std::move(r.sock);
match = true;
break;
}
}
if (!match) all_links.push_back(r);
if (!match) all_links.emplace_back(std::move(r));
}
sock_listen.Close();
this->parent_index = -1;
// setup tree links and ring structure
tree_links.plinks.clear();
int tcpNoDelay = 1;
for (auto & all_link : all_links) {
for (auto &all_link : all_links) {
utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket");
// set the socket to non-blocking mode, enable TCP keepalive
all_link.sock.SetNonBlock(true);
all_link.sock.SetKeepAlive(true);
all_link.sock.SetNonBlock();
all_link.sock.SetKeepAlive();
if (rabit_enable_tcp_no_delay) {
#if defined(__unix__)
setsockopt(all_link.sock, IPPROTO_TCP,
TCP_NODELAY, reinterpret_cast<void *>(&tcpNoDelay), sizeof(tcpNoDelay));
#else
LOG(WARNING) << "tcp no delay is not implemented on non unix platforms";
#endif
all_link.sock.SetNoDelay();
}
if (tree_neighbors.count(all_link.rank) != 0) {
if (all_link.rank == parent_rank) {

View File

@@ -201,8 +201,8 @@ class AllreduceBase : public IEngine {
}
};
/*! \brief translate errno to return type */
inline static ReturnType Errno2Return() {
int errsv = utils::Socket::GetLastError();
static ReturnType Errno2Return() {
int errsv = xgboost::system::LastError();
if (errsv == EAGAIN || errsv == EWOULDBLOCK || errsv == 0) return kSuccess;
#ifdef _WIN32
if (errsv == WSAEWOULDBLOCK) return kSuccess;
@@ -215,7 +215,7 @@ class AllreduceBase : public IEngine {
struct LinkRecord {
public:
// socket to get data from/to link
utils::TCPSocket sock;
xgboost::collective::TCPSocket sock;
// rank of the node in this link
int rank;
// size of data readed from link
@@ -329,7 +329,7 @@ class AllreduceBase : public IEngine {
* \brief initialize connection to the tracker
* \return a socket that initializes the connection
*/
utils::TCPSocket ConnectTracker() const;
xgboost::collective::TCPSocket ConnectTracker() const;
/*!
* \brief connect to the tracker to fix the the missing links
* this function is also used when the engine start up
@@ -473,8 +473,6 @@ class AllreduceBase : public IEngine {
std::string dmlc_role; // NOLINT
// port of tracker address
int tracker_port; // NOLINT
// port of slave process
int slave_port, nport_trial; // NOLINT
// reduce buffer size
size_t reduce_buffer_size; // NOLINT
// reduction method