[rabit] Improved connection handling. (#9531)

- Enable timeout.
- Report connection error from the system.
- Handle retry for both tracker connection and peer connection.
This commit is contained in:
Jiaming Yuan 2023-08-30 13:00:04 +08:00 committed by GitHub
parent 2462e22cd4
commit ccfc90e4c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 463 additions and 130 deletions

View File

@ -32,4 +32,3 @@ formats:
python:
install:
- requirements: doc/requirements.txt
system_packages: true

View File

@ -0,0 +1,160 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <memory>        // for unique_ptr
#include <sstream>       // for stringstream
#include <stack>         // for stack
#include <string>        // for string
#include <system_error>  // for error_code
#include <utility>       // for move
namespace xgboost::collective {
namespace detail {
/**
 * @brief One entry in a chain of errors.
 *
 * Each entry stores a message, an optional system error code and an owning pointer to
 * the error that preceded it (`prev`). An entry whose `errc` equals a default
 * constructed `std::error_code` is treated as carrying no system error.
 */
struct ResultImpl {
  std::string message;
  std::error_code errc{};  // optional for system error.
  std::unique_ptr<ResultImpl> prev{nullptr};

  ResultImpl() = delete;  // must initialize.
  ResultImpl(ResultImpl const& that) = delete;
  ResultImpl(ResultImpl&& that) = default;
  ResultImpl& operator=(ResultImpl const& that) = delete;
  ResultImpl& operator=(ResultImpl&& that) = default;

  explicit ResultImpl(std::string msg) : message{std::move(msg)} {}
  explicit ResultImpl(std::string msg, std::error_code errc)
      : message{std::move(msg)}, errc{errc} {}
  explicit ResultImpl(std::string msg, std::unique_ptr<ResultImpl> prev)
      : message{std::move(msg)}, prev{std::move(prev)} {}
  explicit ResultImpl(std::string msg, std::error_code errc, std::unique_ptr<ResultImpl> prev)
      : message{std::move(msg)}, errc{errc}, prev{std::move(prev)} {}

  /**
   * @brief Two chains are equal when they have the same length and every pair of
   *        entries at the same depth has equal message and error code.
   */
  [[nodiscard]] bool operator==(ResultImpl const& that) const noexcept(true) {
    // Walk both chains in lock step instead of recursing once per entry.
    ResultImpl const* lhs = this;
    ResultImpl const* rhs = &that;
    while (lhs != nullptr && rhs != nullptr) {
      if (lhs->message != rhs->message || lhs->errc != rhs->errc) {
        return false;
      }
      lhs = lhs->prev.get();
      rhs = rhs->prev.get();
    }
    // Equal only if both chains ended at the same depth.
    return lhs == nullptr && rhs == nullptr;
  }

  /**
   * @brief Format the whole chain, one "\n- " prefixed line per entry, starting with
   *        this entry and following `prev`.
   *
   * Const so that it can be used through a reference to const.
   */
  [[nodiscard]] std::string Report() const {
    std::stringstream ss;
    ss << "\n- " << this->message;
    if (this->errc != std::error_code{}) {
      ss << " system error:" << this->errc.message();
    }

    for (auto ptr = prev.get(); ptr != nullptr; ptr = ptr->prev.get()) {
      ss << "\n- " << ptr->message;
      if (ptr->errc != std::error_code{}) {
        ss << " " << ptr->errc.message();
      }
    }
    return ss.str();
  }

  /**
   * @brief Return the error code of the entry furthest down the `prev` chain that
   *        carries one, or a default constructed code if no entry does.
   */
  [[nodiscard]] auto Code() const {
    // Later assignments come from entries deeper in the chain, so the last one wins.
    std::error_code root{};
    for (auto ptr = this; ptr != nullptr; ptr = ptr->prev.get()) {
      if (ptr->errc != std::error_code{}) {
        root = ptr->errc;
      }
    }
    return root;
  }
};
} // namespace detail
/**
 * @brief Error value used in place of throwing a dmlc exception. It can record a
 *        message, a system error code and the error that preceded it.
 *
 * A default constructed Result holds no error and reports OK(). The type is move-only.
 */
struct Result {
 private:
  // Null when the result is a success.
  std::unique_ptr<detail::ResultImpl> impl_{nullptr};

 public:
  Result() noexcept(true) = default;
  explicit Result(std::string msg) : impl_{std::make_unique<detail::ResultImpl>(std::move(msg))} {}
  explicit Result(std::string msg, std::error_code errc)
      : impl_{std::make_unique<detail::ResultImpl>(std::move(msg), std::move(errc))} {}
  // The following two constructors take over the error held by `prev`.
  Result(std::string msg, Result&& prev)
      : impl_{std::make_unique<detail::ResultImpl>(std::move(msg), std::move(prev.impl_))} {}
  Result(std::string msg, std::error_code errc, Result&& prev)
      : impl_{std::make_unique<detail::ResultImpl>(std::move(msg), std::move(errc),
                                                   std::move(prev.impl_))} {}

  Result(Result const& that) = delete;
  Result& operator=(Result const& that) = delete;
  Result(Result&& that) = default;
  Result& operator=(Result&& that) = default;

  /** @brief True when no error is stored. */
  [[nodiscard]] bool OK() const noexcept(true) { return impl_ == nullptr; }

  /** @brief Formatted error chain, empty string for a success. */
  [[nodiscard]] std::string Report() const {
    if (OK()) {
      return "";
    }
    return impl_->Report();
  }

  /**
   * @brief Return the root system error. This might return success if there's no system error.
   */
  [[nodiscard]] auto Code() const {
    if (OK()) {
      return std::error_code{};
    }
    return impl_->Code();
  }

  /** @brief Two successes are equal; otherwise the stored error chains are compared. */
  [[nodiscard]] bool operator==(Result const& that) const noexcept(true) {
    if (OK() || that.OK()) {
      // At least one side is a success, equal only if both are.
      return OK() == that.OK();
    }
    return *impl_ == *that.impl_;
  }
};
/**
 * @brief Return success, i.e. a default constructed Result.
 */
[[nodiscard]] inline auto Success() noexcept(true) {
  return Result();
}
/**
* @brief Return failure.
*/
[[nodiscard]] inline auto Fail(std::string msg) { return Result{std::move(msg)}; }
/**
* @brief Return failure with `errno`.
*/
[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc) {
return Result{std::move(msg), std::move(errc)};
}
/**
* @brief Return failure with a previous error.
*/
[[nodiscard]] inline auto Fail(std::string msg, Result&& prev) {
return Result{std::move(msg), std::forward<Result>(prev)};
}
/**
* @brief Return failure with a previous error and a new `errno`.
*/
[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, Result&& prev) {
return Result{std::move(msg), std::move(errc), std::forward<Result>(prev)};
}
} // namespace xgboost::collective

View File

@ -56,9 +56,10 @@ using ssize_t = int;
#endif // defined(_WIN32)
#include "xgboost/base.h" // XGBOOST_EXPECT
#include "xgboost/logging.h" // LOG
#include "xgboost/string_view.h" // StringView
#include "xgboost/base.h" // XGBOOST_EXPECT
#include "xgboost/collective/result.h" // for Result
#include "xgboost/logging.h" // LOG
#include "xgboost/string_view.h" // StringView
#if !defined(HOST_NAME_MAX)
#define HOST_NAME_MAX 256 // macos
@ -81,6 +82,10 @@ inline std::int32_t LastError() {
#endif
}
/**
 * @brief Build a failed Result from a message and the value returned by LastError(),
 *        interpreted in the system error category.
 */
[[nodiscard]] inline collective::Result FailWithCode(std::string msg) {
  std::error_code errc{LastError(), std::system_category()};
  return collective::Fail(std::move(msg), errc);
}
#if defined(__GLIBC__)
inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError(),
std::int32_t line = __builtin_LINE(),
@ -120,15 +125,19 @@ inline std::int32_t CloseSocket(SocketT fd) {
#endif
}
inline bool LastErrorWouldBlock() {
int errsv = LastError();
inline bool ErrorWouldBlock(std::int32_t errsv) noexcept(true) {
#ifdef _WIN32
return errsv == WSAEWOULDBLOCK;
#else
return errsv == EAGAIN || errsv == EWOULDBLOCK;
return errsv == EAGAIN || errsv == EWOULDBLOCK || errsv == EINPROGRESS;
#endif // _WIN32
}
/** @brief Apply ErrorWouldBlock to the value returned by LastError(). */
inline bool LastErrorWouldBlock() {
  return ErrorWouldBlock(LastError());
}
inline void SocketStartup() {
#if defined(_WIN32)
WSADATA wsa_data;
@ -315,23 +324,35 @@ class TCPSocket {
/** \brief whether the handle equals the value returned by InvalidSocket() */
bool IsClosed() const {
  return InvalidSocket() == handle_;
}
/** \brief get last error code if any */
std::int32_t GetSockError() const {
std::int32_t error = 0;
socklen_t len = sizeof(error);
xgboost_CHECK_SYS_CALL(
getsockopt(handle_, SOL_SOCKET, SO_ERROR, reinterpret_cast<char *>(&error), &len), 0);
return error;
Result GetSockError() const {
  // Read the SO_ERROR option of the socket into `pending`.
  std::int32_t pending = 0;
  socklen_t size = sizeof(pending);
  auto status =
      getsockopt(handle_, SOL_SOCKET, SO_ERROR, reinterpret_cast<char *>(&pending), &size);
  if (status != 0) {
    // The query itself failed, report the error of the getsockopt call.
    return Fail("Failed to retrieve socket error.",
                std::error_code{system::LastError(), std::system_category()});
  }
  if (pending == 0) {
    return Success();
  }
  // A non-zero option value is reported as a system error.
  return Fail("Socket error.", std::error_code{pending, std::system_category()});
}
/** \brief check if anything bad happens */
bool BadSocket() const {
if (IsClosed()) return true;
std::int32_t err = GetSockError();
if (err == EBADF || err == EINTR) return true;
if (IsClosed()) {
return true;
}
auto err = GetSockError();
if (err.Code() == std::error_code{EBADF, std::system_category()} || // NOLINT
err.Code() == std::error_code{EINTR, std::system_category()}) { // NOLINT
return true;
}
return false;
}
void SetNonBlock() {
bool non_block{true};
void SetNonBlock(bool non_block) {
#if defined(_WIN32)
u_long mode = non_block ? 1 : 0;
xgboost_CHECK_SYS_CALL(ioctlsocket(handle_, FIONBIO, &mode), NO_ERROR);
@ -530,10 +551,20 @@ class TCPSocket {
};
/**
* \brief Connect to remote address, returns the error code if failed (no exception is
* raised so that we can retry).
* @brief Connect to remote address, returns the error code if failed.
*
* @param host Host IP address.
* @param port Connection port.
* @param retry Number of retries to attempt.
* @param timeout Timeout of each connection attempt.
* @param out_conn Output socket if the connection is successful. Value is invalid and undefined if
* the connection failed.
*
* @return Connection status.
*/
std::error_code Connect(SockAddress const &addr, TCPSocket *out);
[[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry,
std::chrono::seconds timeout,
xgboost::collective::TCPSocket *out_conn);
/**
* \brief Get the local host name.

View File

@ -94,6 +94,10 @@ def no_ipv6() -> PytestSkip:
return {"condition": not has_ipv6(), "reason": "IPv6 is required to be enabled."}
def not_linux() -> PytestSkip:
    """Skip arguments for tests that require Linux."""
    on_linux = system() == "Linux"
    return {"condition": not on_linux, "reason": "Linux is required."}
def no_ubjson() -> PytestSkip:
    """Skip arguments produced by `no_mod` for the `ubjson` module."""
    skip_args = no_mod("ubjson")
    return skip_args

View File

@ -1,10 +1,11 @@
/*!
* Copyright (c) 2014-2022 by XGBoost Contributors
/**
* Copyright 2014-2023, XGBoost Contributors
* \file socket.h
* \author Tianqi Chen
*/
#ifndef RABIT_INTERNAL_SOCKET_H_
#define RABIT_INTERNAL_SOCKET_H_
#include "xgboost/collective/result.h"
#include "xgboost/collective/socket.h"
#if defined(_WIN32)
@ -77,7 +78,7 @@ namespace rabit {
namespace utils {
template <typename PollFD>
int PollImpl(PollFD *pfd, int nfds, std::chrono::seconds timeout) {
int PollImpl(PollFD* pfd, int nfds, std::chrono::seconds timeout) noexcept(true) {
#if defined(_WIN32)
#if IS_MINGW()
@ -135,11 +136,11 @@ struct PollHelper {
* \brief Check if the descriptor is ready for read
* \param fd file descriptor to check status
*/
inline bool CheckRead(SOCKET fd) const {
[[nodiscard]] bool CheckRead(SOCKET fd) const {
const auto& pfd = fds.find(fd);
return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0);
}
bool CheckRead(xgboost::collective::TCPSocket const &socket) const {
[[nodiscard]] bool CheckRead(xgboost::collective::TCPSocket const& socket) const {
return this->CheckRead(socket.Handle());
}
@ -147,19 +148,19 @@ struct PollHelper {
* \brief Check if the descriptor is ready for write
* \param fd file descriptor to check status
*/
inline bool CheckWrite(SOCKET fd) const {
[[nodiscard]] bool CheckWrite(SOCKET fd) const {
const auto& pfd = fds.find(fd);
return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0);
}
bool CheckWrite(xgboost::collective::TCPSocket const &socket) const {
[[nodiscard]] bool CheckWrite(xgboost::collective::TCPSocket const& socket) const {
return this->CheckWrite(socket.Handle());
}
/*!
* \brief perform poll on the set defined, read, write, exception
* \param timeout specify timeout in milliseconds(ms) if negative, means poll will block
* \return
/**
* @brief perform poll on the set defined, read, write, exception
*
* @param timeout specify timeout in seconds. Block if negative.
*/
inline void Poll(std::chrono::seconds timeout) { // NOLINT(*)
[[nodiscard]] xgboost::collective::Result Poll(std::chrono::seconds timeout) {
std::vector<pollfd> fdset;
fdset.reserve(fds.size());
for (auto kv : fds) {
@ -167,9 +168,9 @@ struct PollHelper {
}
int ret = PollImpl(fdset.data(), fdset.size(), timeout);
if (ret == 0) {
LOG(FATAL) << "Poll timeout";
return xgboost::collective::Fail("Poll timeout.");
} else if (ret < 0) {
LOG(FATAL) << "Failed to poll.";
return xgboost::system::FailWithCode("Poll failed.");
} else {
for (auto& pfd : fdset) {
auto revents = pfd.revents & pfd.events;
@ -180,6 +181,7 @@ struct PollHelper {
}
}
}
return xgboost::collective::Success();
}
std::unordered_map<SOCKET, pollfd> fds;

View File

@ -1,5 +1,5 @@
/*!
* Copyright (c) 2014 by Contributors
/**
* Copyright 2014-2023, XGBoost Contributors
* \file allreduce_base.cc
* \brief Basic implementation of AllReduce
*
@ -9,9 +9,11 @@
#define NOMINMAX
#endif // !defined(NOMINMAX)
#include "allreduce_base.h"
#include "rabit/base.h"
#include "rabit/internal/rabit-inl.h"
#include "allreduce_base.h"
#include "xgboost/collective/result.h"
#ifndef _WIN32
#include <netinet/tcp.h>
@ -20,8 +22,7 @@
#include <cstring>
#include <map>
namespace rabit {
namespace engine {
namespace rabit::engine {
// constructor
AllreduceBase::AllreduceBase() {
tracker_uri = "NULL";
@ -116,7 +117,12 @@ bool AllreduceBase::Init(int argc, char* argv[]) {
utils::Assert(all_links.size() == 0, "can only call Init once");
this->host_uri = xgboost::collective::GetHostName();
// get information from tracker
return this->ReConnectLinks();
auto rc = this->ReConnectLinks();
if (rc.OK()) {
return true;
}
LOG(FATAL) << rc.Report();
return false;
}
bool AllreduceBase::Shutdown() {
@ -131,7 +137,11 @@ bool AllreduceBase::Shutdown() {
if (tracker_uri == "NULL") return true;
// notify tracker rank i have shutdown
xgboost::collective::TCPSocket tracker = this->ConnectTracker();
xgboost::collective::TCPSocket tracker;
auto rc = this->ConnectTracker(&tracker);
if (!rc.OK()) {
LOG(FATAL) << rc.Report();
}
tracker.Send(xgboost::StringView{"shutdown"});
tracker.Close();
xgboost::system::SocketFinalize();
@ -146,7 +156,12 @@ void AllreduceBase::TrackerPrint(const std::string &msg) {
if (tracker_uri == "NULL") {
utils::Printf("%s", msg.c_str()); return;
}
xgboost::collective::TCPSocket tracker = this->ConnectTracker();
xgboost::collective::TCPSocket tracker;
auto rc = this->ConnectTracker(&tracker);
if (!rc.OK()) {
LOG(FATAL) << rc.Report();
}
tracker.Send(xgboost::StringView{"print"});
tracker.Send(xgboost::StringView{msg});
tracker.Close();
@ -215,64 +230,67 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
}
}
}
/*!
* \brief initialize connection to the tracker
* \return a socket that initializes the connection
*/
xgboost::collective::TCPSocket AllreduceBase::ConnectTracker() const {
[[nodiscard]] xgboost::collective::Result AllreduceBase::ConnectTracker(
xgboost::collective::TCPSocket *out) const {
int magic = kMagic;
// get information from tracker
xgboost::collective::TCPSocket tracker;
xgboost::collective::TCPSocket &tracker = *out;
int retry = 0;
do {
auto rc = xgboost::collective::Connect(
xgboost::collective::MakeSockAddress(xgboost::StringView{tracker_uri}, tracker_port),
&tracker);
if (rc != std::errc()) {
if (++retry >= connect_retry) {
LOG(FATAL) << "Connecting to (failed): [" << tracker_uri << "]\n" << rc.message();
} else {
LOG(WARNING) << rc.message() << "\nRetry connecting to IP(retry time: " << retry << "): ["
<< tracker_uri << "]";
#if defined(_MSC_VER) || defined(__MINGW32__)
Sleep(retry << 1);
#else
sleep(retry << 1);
#endif
continue;
}
}
break;
} while (true);
auto rc =
Connect(xgboost::StringView{tracker_uri}, tracker_port, connect_retry, timeout_sec, &tracker);
if (!rc.OK()) {
return xgboost::collective::Fail("Failed to connect to the tracker.", std::move(rc));
}
using utils::Assert;
CHECK_EQ(tracker.SendAll(&magic, sizeof(magic)), sizeof(magic));
CHECK_EQ(tracker.RecvAll(&magic, sizeof(magic)), sizeof(magic));
utils::Check(magic == kMagic, "sync::Invalid tracker message, init failure");
Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size),
"ReConnectLink failure 3");
CHECK_EQ(tracker.Send(xgboost::StringView{task_id}), task_id.size());
return tracker;
if (tracker.SendAll(&magic, sizeof(magic)) != sizeof(magic)) {
return xgboost::collective::Fail("Failed to send the verification number.");
}
// Receiving fewer bytes than requested is treated as a failed handshake.
if (tracker.RecvAll(&magic, sizeof(magic)) != sizeof(magic)) {
  return xgboost::collective::Fail("Failed to receive the verification number.");
}
if (magic != kMagic) {
return xgboost::collective::Fail("Invalid verification number.");
}
if (tracker.SendAll(&rank, sizeof(rank)) != sizeof(rank)) {
return xgboost::collective::Fail("Failed to send the local rank back to the tracker.");
}
if (tracker.SendAll(&world_size, sizeof(world_size)) != sizeof(world_size)) {
return xgboost::collective::Fail("Failed to send the world size back to the tracker.");
}
if (tracker.Send(xgboost::StringView{task_id}) != task_id.size()) {
return xgboost::collective::Fail("Failed to send the task ID back to the tracker.");
}
return xgboost::collective::Success();
}
/*!
* \brief connect to the tracker to fix the missing links
* this function is also used when the engine start up
*/
bool AllreduceBase::ReConnectLinks(const char *cmd) {
[[nodiscard]] xgboost::collective::Result AllreduceBase::ReConnectLinks(const char *cmd) {
// single node mode
if (tracker_uri == "NULL") {
rank = 0;
world_size = 1;
return true;
return xgboost::collective::Success();
}
try {
xgboost::collective::TCPSocket tracker = this->ConnectTracker();
LOG(INFO) << "task " << task_id << " connected to the tracker";
tracker.Send(xgboost::StringView{cmd});
xgboost::collective::TCPSocket tracker;
auto rc = this->ConnectTracker(&tracker);
if (!rc.OK()) {
return xgboost::collective::Fail("Failed to connect to the tracker.", std::move(rc));
}
LOG(INFO) << "task " << task_id << " connected to the tracker";
tracker.Send(xgboost::StringView{cmd});
try {
// the rank of previous link, next link in ring
int prev_rank, next_rank;
// the rank of neighbors
@ -334,10 +352,10 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
tracker.Recv(&hname);
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9");
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10");
if (xgboost::collective::Connect(
xgboost::collective::MakeSockAddress(xgboost::StringView{hname}, hport), &r.sock) !=
std::errc{}) {
// connect to peer
if (!xgboost::collective::Connect(xgboost::StringView{hname}, hport, connect_retry,
timeout_sec, &r.sock)
.OK()) {
num_error += 1;
r.sock.Close();
continue;
@ -351,8 +369,7 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
bool match = false;
for (auto & all_link : all_links) {
if (all_link.rank == hrank) {
Assert(all_link.sock.IsClosed(),
"Override a link that is active");
Assert(all_link.sock.IsClosed(), "Override a link that is active");
all_link.sock = std::move(r.sock);
match = true;
break;
@ -364,10 +381,10 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
"ReConnectLink failure 14");
} while (num_error != 0);
// send back socket listening port to tracker
Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port),
"ReConnectLink failure 14");
Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port), "ReConnectLink failure 14");
// close connection to tracker
tracker.Close();
// listen to incoming links
for (int i = 0; i < num_accept; ++i) {
LinkRecord r;
@ -395,7 +412,7 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
for (auto &all_link : all_links) {
utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket");
// set the socket to non-blocking mode, enable TCP keepalive
all_link.sock.SetNonBlock();
all_link.sock.SetNonBlock(true);
all_link.sock.SetKeepAlive();
if (rabit_enable_tcp_no_delay) {
all_link.sock.SetNoDelay();
@ -415,10 +432,11 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
"cannot find prev ring in the link");
Assert(next_rank == -1 || ring_next != nullptr,
"cannot find next ring in the link");
return true;
return xgboost::collective::Success();
} catch (const std::exception& e) {
LOG(WARNING) << "failed in ReconnectLink " << e.what();
return false;
std::stringstream ss;
ss << "Failed in ReconnectLink " << e.what();
return xgboost::collective::Fail(ss.str());
}
}
/*!
@ -523,9 +541,15 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
}
}
// finish running allreduce
if (finished) break;
if (finished) {
break;
}
// select must return
watcher.Poll(timeout_sec);
auto poll_res = watcher.Poll(timeout_sec);
if (!poll_res.OK()) {
LOG(FATAL) << poll_res.Report();
}
// read data from children
for (int i = 0; i < nlink; ++i) {
if (i != parent_index && watcher.CheckRead(links[i].sock)) {
@ -698,7 +722,10 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
// finish running
if (finished) break;
// select
watcher.Poll(timeout_sec);
auto poll_res = watcher.Poll(timeout_sec);
if (!poll_res.OK()) {
LOG(FATAL) << poll_res.Report();
}
if (in_link == -2) {
// probe in-link
for (int i = 0; i < nlink; ++i) {
@ -780,8 +807,14 @@ AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
}
finished = false;
}
if (finished) break;
watcher.Poll(timeout_sec);
if (finished) {
break;
}
auto poll_res = watcher.Poll(timeout_sec);
if (!poll_res.OK()) {
LOG(FATAL) << poll_res.Report();
}
if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
size_t size = stop_read - read_ptr;
size_t start = read_ptr % total_size;
@ -880,8 +913,13 @@ AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_,
}
finished = false;
}
if (finished) break;
watcher.Poll(timeout_sec);
if (finished) {
break;
}
auto poll_res = watcher.Poll(timeout_sec);
if (!poll_res.OK()) {
LOG(FATAL) << poll_res.Report();
}
if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
ReturnType ret = next.ReadToRingBuffer(reduce_ptr, stop_read);
if (ret != kSuccess) {
@ -953,5 +991,4 @@ AllreduceBase::TryAllreduceRing(void *sendrecvbuf_,
(std::min((prank + 1) * step, count) -
std::min(prank * step, count)) * type_nbytes);
}
} // namespace engine
} // namespace rabit
} // namespace rabit::engine

View File

@ -12,14 +12,16 @@
#ifndef RABIT_ALLREDUCE_BASE_H_
#define RABIT_ALLREDUCE_BASE_H_
#include <algorithm>
#include <functional>
#include <future>
#include <vector>
#include <string>
#include <algorithm>
#include "rabit/internal/utils.h"
#include <vector>
#include "rabit/internal/engine.h"
#include "rabit/internal/socket.h"
#include "rabit/internal/utils.h"
#include "xgboost/collective/result.h"
#ifdef RABIT_CXXTESTDEFS_H
#define private public
@ -329,13 +331,13 @@ class AllreduceBase : public IEngine {
* \brief initialize connection to the tracker
* \param out socket that gets connected to the tracker
* \return Success() when the connection and the handshake succeed, a failed Result otherwise
*/
xgboost::collective::TCPSocket ConnectTracker() const;
[[nodiscard]] xgboost::collective::Result ConnectTracker(xgboost::collective::TCPSocket *out) const;
/*!
* \brief connect to the tracker to fix the missing links
* this function is also used when the engine start up
* \param cmd possible command to sent to tracker
*/
bool ReConnectLinks(const char *cmd = "start");
[[nodiscard]] xgboost::collective::Result ReConnectLinks(const char *cmd = "start");
/*!
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
*

View File

@ -1,19 +1,22 @@
/*!
* Copyright (c) 2022 by XGBoost Contributors
/**
* Copyright 2022-2023 by XGBoost Contributors
*/
#include "xgboost/collective/socket.h"
#include <cstddef> // std::size_t
#include <cstdint> // std::int32_t
#include <cstring> // std::memcpy, std::memset
#include <filesystem> // for path
#include <system_error> // std::error_code, std::system_category
#include "rabit/internal/socket.h" // for PollHelper
#include "xgboost/collective/result.h" // for Result
#if defined(__unix__) || defined(__APPLE__)
#include <netdb.h> // getaddrinfo, freeaddrinfo
#endif // defined(__unix__) || defined(__APPLE__)
namespace xgboost {
namespace collective {
namespace xgboost::collective {
SockAddress MakeSockAddress(StringView host, in_port_t port) {
struct addrinfo hints;
std::memset(&hints, 0, sizeof(hints));
@ -71,7 +74,12 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
return bytes;
}
std::error_code Connect(SockAddress const &addr, TCPSocket *out) {
[[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry,
std::chrono::seconds timeout,
xgboost::collective::TCPSocket *out_conn) {
auto addr = MakeSockAddress(xgboost::StringView{host}, port);
auto &conn = *out_conn;
sockaddr const *addr_handle{nullptr};
socklen_t addr_len{0};
if (addr.IsV4()) {
@ -81,14 +89,67 @@ std::error_code Connect(SockAddress const &addr, TCPSocket *out) {
addr_handle = reinterpret_cast<const sockaddr *>(&addr.V6().Handle());
addr_len = sizeof(addr.V6().Handle());
}
auto socket = TCPSocket::Create(addr.Domain());
CHECK_EQ(static_cast<std::int32_t>(socket.Domain()), static_cast<std::int32_t>(addr.Domain()));
auto rc = connect(socket.Handle(), addr_handle, addr_len);
if (rc != 0) {
return std::error_code{errno, std::system_category()};
conn = TCPSocket::Create(addr.Domain());
CHECK_EQ(static_cast<std::int32_t>(conn.Domain()), static_cast<std::int32_t>(addr.Domain()));
conn.SetNonBlock(true);
Result last_error;
auto log_failure = [&host, &last_error](Result err, char const *file, std::int32_t line) {
last_error = std::move(err);
LOG(WARNING) << std::filesystem::path{file}.filename().string() << "(" << line
<< "): Failed to connect to:" << host << " Error:" << last_error.Report();
};
for (std::int32_t attempt = 0; attempt < std::max(retry, 1); ++attempt) {
if (attempt > 0) {
LOG(WARNING) << "Retrying connection to " << host << " for the " << attempt << " time.";
#if defined(_MSC_VER) || defined(__MINGW32__)
Sleep(attempt << 1);
#else
sleep(attempt << 1);
#endif
}
auto rc = connect(conn.Handle(), addr_handle, addr_len);
if (rc != 0) {
auto errcode = system::LastError();
if (!system::ErrorWouldBlock(errcode)) {
log_failure(Fail("connect failed.", std::error_code{errcode, std::system_category()}),
__FILE__, __LINE__);
continue;
}
rabit::utils::PollHelper poll;
poll.WatchWrite(conn);
auto result = poll.Poll(timeout);
if (!result.OK()) {
log_failure(std::move(result), __FILE__, __LINE__);
continue;
}
if (!poll.CheckWrite(conn)) {
log_failure(Fail("poll failed.", std::error_code{errcode, std::system_category()}),
__FILE__, __LINE__);
continue;
}
result = conn.GetSockError();
if (!result.OK()) {
log_failure(std::move(result), __FILE__, __LINE__);
continue;
}
conn.SetNonBlock(false);
return Success();
} else {
conn.SetNonBlock(false);
return Success();
}
}
*out = std::move(socket);
return std::make_error_code(std::errc{});
std::stringstream ss;
ss << "Failed to connect to " << host << ":" << port;
conn.Close();
return Fail(ss.str(), std::move(last_error));
}
} // namespace collective
} // namespace xgboost
} // namespace xgboost::collective

View File

@ -1,5 +1,5 @@
/*!
* Copyright (c) 2022 by XGBoost Contributors
/**
* Copyright 2022-2023 by XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/collective/socket.h>
@ -10,8 +10,7 @@
#include "../helpers.h"
namespace xgboost {
namespace collective {
namespace xgboost::collective {
TEST(Socket, Basic) {
system::SocketStartup();
@ -31,15 +30,16 @@ TEST(Socket, Basic) {
TCPSocket client;
if (domain == SockDomain::kV4) {
auto const& addr = SockAddrV4::Loopback().Addr();
ASSERT_EQ(Connect(MakeSockAddress(StringView{addr}, port), &client), std::errc{});
auto rc = Connect(StringView{addr}, port, 1, std::chrono::seconds{3}, &client);
ASSERT_TRUE(rc.OK()) << rc.Report();
} else {
auto const& addr = SockAddrV6::Loopback().Addr();
auto rc = Connect(MakeSockAddress(StringView{addr}, port), &client);
auto rc = Connect(StringView{addr}, port, 1, std::chrono::seconds{3}, &client);
// some environment (docker) has restricted network configuration.
if (rc == std::error_code{EADDRNOTAVAIL, std::system_category()}) {
if (!rc.OK() && rc.Code() == std::error_code{EADDRNOTAVAIL, std::system_category()}) {
GTEST_SKIP_(msg.c_str());
}
ASSERT_EQ(rc, std::errc{});
ASSERT_EQ(rc, Success()) << rc.Report();
}
ASSERT_EQ(client.Domain(), domain);
@ -73,5 +73,4 @@ TEST(Socket, Basic) {
system::SocketFinalize();
}
} // namespace collective
} // namespace xgboost
} // namespace xgboost::collective

View File

@ -20,6 +20,18 @@ def test_rabit_tracker():
assert str(ret) == "test1234"
@pytest.mark.skipif(**tm.not_linux())
def test_socket_error():
    """Entering the communicator with tracker port 0 is expected to raise a ValueError
    whose message matches the pattern below."""
    tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1)
    tracker.start(1)

    worker_env = tracker.worker_envs()
    # Override the tracker port and limit the worker to a single connection attempt.
    worker_env["DMLC_TRACKER_PORT"] = 0
    worker_env["DMLC_WORKER_CONNECT_RETRY"] = 1

    expected = "127.0.0.1:0\n.*refused"
    with pytest.raises(ValueError, match=expected), xgb.collective.CommunicatorContext(**worker_env):
        pass
def run_rabit_ops(client, n_workers):
from xgboost.dask import CommunicatorContext, _get_dask_config, _get_rabit_args
@ -58,6 +70,32 @@ def test_rabit_ops():
run_rabit_ops(client, n_workers)
def run_broadcast(client):
    """Run `collective.broadcast(17, 0)` on every worker of `client` and check that
    each of them gets 17 back."""
    from xgboost.dask import _get_dask_config, _get_rabit_args

    worker_addrs = tm.get_client_workers(client)
    n_workers = len(worker_addrs)
    rabit_args = client.sync(_get_rabit_args, n_workers, _get_dask_config(), client)

    def local_test(worker_id):
        # worker_id is not used inside the task.
        with collective.CommunicatorContext(**rabit_args):
            received = collective.broadcast(17, 0)
        return received

    futures = client.map(local_test, range(n_workers), workers=worker_addrs)
    gathered = client.gather(futures)
    np.testing.assert_allclose(np.array(gathered), 17)
@pytest.mark.skipif(**tm.no_dask())
def test_broadcast():
    """Run the broadcast check on a local cluster with three workers."""
    from distributed import Client, LocalCluster

    with LocalCluster(n_workers=3) as cluster, Client(cluster) as client:
        run_broadcast(client)
@pytest.mark.skipif(**tm.no_ipv6())
@pytest.mark.skipif(**tm.no_dask())
def test_rabit_ops_ipv6():