[rabit] Improved connection handling. (#9531)
- Enable timeout. - Report connection error from the system. - Handle retry for both tracker connection and peer connection.
This commit is contained in:
parent
2462e22cd4
commit
ccfc90e4c6
@ -32,4 +32,3 @@ formats:
|
|||||||
python:
|
python:
|
||||||
install:
|
install:
|
||||||
- requirements: doc/requirements.txt
|
- requirements: doc/requirements.txt
|
||||||
system_packages: true
|
|
||||||
|
|||||||
160
include/xgboost/collective/result.h
Normal file
160
include/xgboost/collective/result.h
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <memory> // for unique_ptr
|
||||||
|
#include <sstream> // for stringstream
|
||||||
|
#include <stack> // for stack
|
||||||
|
#include <string> // for string
|
||||||
|
#include <utility> // for move
|
||||||
|
|
||||||
|
namespace xgboost::collective {
|
||||||
|
namespace detail {
|
||||||
|
struct ResultImpl {
|
||||||
|
std::string message;
|
||||||
|
std::error_code errc{}; // optional for system error.
|
||||||
|
|
||||||
|
std::unique_ptr<ResultImpl> prev{nullptr};
|
||||||
|
|
||||||
|
ResultImpl() = delete; // must initialize.
|
||||||
|
ResultImpl(ResultImpl const& that) = delete;
|
||||||
|
ResultImpl(ResultImpl&& that) = default;
|
||||||
|
ResultImpl& operator=(ResultImpl const& that) = delete;
|
||||||
|
ResultImpl& operator=(ResultImpl&& that) = default;
|
||||||
|
|
||||||
|
explicit ResultImpl(std::string msg) : message{std::move(msg)} {}
|
||||||
|
explicit ResultImpl(std::string msg, std::error_code errc)
|
||||||
|
: message{std::move(msg)}, errc{std::move(errc)} {}
|
||||||
|
explicit ResultImpl(std::string msg, std::unique_ptr<ResultImpl> prev)
|
||||||
|
: message{std::move(msg)}, prev{std::move(prev)} {}
|
||||||
|
explicit ResultImpl(std::string msg, std::error_code errc, std::unique_ptr<ResultImpl> prev)
|
||||||
|
: message{std::move(msg)}, errc{std::move(errc)}, prev{std::move(prev)} {}
|
||||||
|
|
||||||
|
[[nodiscard]] bool operator==(ResultImpl const& that) const noexcept(true) {
|
||||||
|
if ((prev && !that.prev) || (!prev && that.prev)) {
|
||||||
|
// one of them doesn't have prev
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto cur_eq = message == that.message && errc == that.errc;
|
||||||
|
if (prev && that.prev) {
|
||||||
|
// recursive comparison
|
||||||
|
auto prev_eq = *prev == *that.prev;
|
||||||
|
return cur_eq && prev_eq;
|
||||||
|
}
|
||||||
|
return cur_eq;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] std::string Report() {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "\n- " << this->message;
|
||||||
|
if (this->errc != std::error_code{}) {
|
||||||
|
ss << " system error:" << this->errc.message();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ptr = prev.get();
|
||||||
|
while (ptr) {
|
||||||
|
ss << "\n- ";
|
||||||
|
ss << ptr->message;
|
||||||
|
|
||||||
|
if (ptr->errc != std::error_code{}) {
|
||||||
|
ss << " " << ptr->errc.message();
|
||||||
|
}
|
||||||
|
ptr = ptr->prev.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
[[nodiscard]] auto Code() const {
|
||||||
|
// Find the root error.
|
||||||
|
std::stack<ResultImpl const*> stack;
|
||||||
|
auto ptr = this;
|
||||||
|
while (ptr) {
|
||||||
|
stack.push(ptr);
|
||||||
|
if (ptr->prev) {
|
||||||
|
ptr = ptr->prev.get();
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
while (!stack.empty()) {
|
||||||
|
auto frame = stack.top();
|
||||||
|
stack.pop();
|
||||||
|
if (frame->errc != std::error_code{}) {
|
||||||
|
return frame->errc;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return std::error_code{};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief An error type that's easier to handle than throwing dmlc exception. We can
|
||||||
|
* record and propagate the system error code.
|
||||||
|
*/
|
||||||
|
struct Result {
|
||||||
|
private:
|
||||||
|
std::unique_ptr<detail::ResultImpl> impl_{nullptr};
|
||||||
|
|
||||||
|
public:
|
||||||
|
Result() noexcept(true) = default;
|
||||||
|
explicit Result(std::string msg) : impl_{std::make_unique<detail::ResultImpl>(std::move(msg))} {}
|
||||||
|
explicit Result(std::string msg, std::error_code errc)
|
||||||
|
: impl_{std::make_unique<detail::ResultImpl>(std::move(msg), std::move(errc))} {}
|
||||||
|
Result(std::string msg, Result&& prev)
|
||||||
|
: impl_{std::make_unique<detail::ResultImpl>(std::move(msg), std::move(prev.impl_))} {}
|
||||||
|
Result(std::string msg, std::error_code errc, Result&& prev)
|
||||||
|
: impl_{std::make_unique<detail::ResultImpl>(std::move(msg), std::move(errc),
|
||||||
|
std::move(prev.impl_))} {}
|
||||||
|
|
||||||
|
Result(Result const& that) = delete;
|
||||||
|
Result& operator=(Result const& that) = delete;
|
||||||
|
Result(Result&& that) = default;
|
||||||
|
Result& operator=(Result&& that) = default;
|
||||||
|
|
||||||
|
[[nodiscard]] bool OK() const noexcept(true) { return !impl_; }
|
||||||
|
[[nodiscard]] std::string Report() const { return OK() ? "" : impl_->Report(); }
|
||||||
|
/**
|
||||||
|
* @brief Return the root system error. This might return success if there's no system error.
|
||||||
|
*/
|
||||||
|
[[nodiscard]] auto Code() const { return OK() ? std::error_code{} : impl_->Code(); }
|
||||||
|
[[nodiscard]] bool operator==(Result const& that) const noexcept(true) {
|
||||||
|
if (OK() && that.OK()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if ((OK() && !that.OK()) || (!OK() && that.OK())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return *impl_ == *that.impl_;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Return success.
|
||||||
|
*/
|
||||||
|
[[nodiscard]] inline auto Success() noexcept(true) { return Result{}; }
|
||||||
|
/**
|
||||||
|
* @brief Return failure.
|
||||||
|
*/
|
||||||
|
[[nodiscard]] inline auto Fail(std::string msg) { return Result{std::move(msg)}; }
|
||||||
|
/**
|
||||||
|
* @brief Return failure with `errno`.
|
||||||
|
*/
|
||||||
|
[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc) {
|
||||||
|
return Result{std::move(msg), std::move(errc)};
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* @brief Return failure with a previous error.
|
||||||
|
*/
|
||||||
|
[[nodiscard]] inline auto Fail(std::string msg, Result&& prev) {
|
||||||
|
return Result{std::move(msg), std::forward<Result>(prev)};
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* @brief Return failure with a previous error and a new `errno`.
|
||||||
|
*/
|
||||||
|
[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, Result&& prev) {
|
||||||
|
return Result{std::move(msg), std::move(errc), std::forward<Result>(prev)};
|
||||||
|
}
|
||||||
|
} // namespace xgboost::collective
|
||||||
@ -57,6 +57,7 @@ using ssize_t = int;
|
|||||||
#endif // defined(_WIN32)
|
#endif // defined(_WIN32)
|
||||||
|
|
||||||
#include "xgboost/base.h" // XGBOOST_EXPECT
|
#include "xgboost/base.h" // XGBOOST_EXPECT
|
||||||
|
#include "xgboost/collective/result.h" // for Result
|
||||||
#include "xgboost/logging.h" // LOG
|
#include "xgboost/logging.h" // LOG
|
||||||
#include "xgboost/string_view.h" // StringView
|
#include "xgboost/string_view.h" // StringView
|
||||||
|
|
||||||
@ -81,6 +82,10 @@ inline std::int32_t LastError() {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] inline collective::Result FailWithCode(std::string msg) {
|
||||||
|
return collective::Fail(std::move(msg), std::error_code{LastError(), std::system_category()});
|
||||||
|
}
|
||||||
|
|
||||||
#if defined(__GLIBC__)
|
#if defined(__GLIBC__)
|
||||||
inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError(),
|
inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError(),
|
||||||
std::int32_t line = __builtin_LINE(),
|
std::int32_t line = __builtin_LINE(),
|
||||||
@ -120,15 +125,19 @@ inline std::int32_t CloseSocket(SocketT fd) {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
inline bool LastErrorWouldBlock() {
|
inline bool ErrorWouldBlock(std::int32_t errsv) noexcept(true) {
|
||||||
int errsv = LastError();
|
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
return errsv == WSAEWOULDBLOCK;
|
return errsv == WSAEWOULDBLOCK;
|
||||||
#else
|
#else
|
||||||
return errsv == EAGAIN || errsv == EWOULDBLOCK;
|
return errsv == EAGAIN || errsv == EWOULDBLOCK || errsv == EINPROGRESS;
|
||||||
#endif // _WIN32
|
#endif // _WIN32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline bool LastErrorWouldBlock() {
|
||||||
|
int errsv = LastError();
|
||||||
|
return ErrorWouldBlock(errsv);
|
||||||
|
}
|
||||||
|
|
||||||
inline void SocketStartup() {
|
inline void SocketStartup() {
|
||||||
#if defined(_WIN32)
|
#if defined(_WIN32)
|
||||||
WSADATA wsa_data;
|
WSADATA wsa_data;
|
||||||
@ -315,23 +324,35 @@ class TCPSocket {
|
|||||||
bool IsClosed() const { return handle_ == InvalidSocket(); }
|
bool IsClosed() const { return handle_ == InvalidSocket(); }
|
||||||
|
|
||||||
/** \brief get last error code if any */
|
/** \brief get last error code if any */
|
||||||
std::int32_t GetSockError() const {
|
Result GetSockError() const {
|
||||||
std::int32_t error = 0;
|
std::int32_t optval = 0;
|
||||||
socklen_t len = sizeof(error);
|
socklen_t len = sizeof(optval);
|
||||||
xgboost_CHECK_SYS_CALL(
|
auto ret = getsockopt(handle_, SOL_SOCKET, SO_ERROR, reinterpret_cast<char *>(&optval), &len);
|
||||||
getsockopt(handle_, SOL_SOCKET, SO_ERROR, reinterpret_cast<char *>(&error), &len), 0);
|
if (ret != 0) {
|
||||||
return error;
|
auto errc = std::error_code{system::LastError(), std::system_category()};
|
||||||
|
return Fail("Failed to retrieve socket error.", std::move(errc));
|
||||||
}
|
}
|
||||||
|
if (optval != 0) {
|
||||||
|
auto errc = std::error_code{optval, std::system_category()};
|
||||||
|
return Fail("Socket error.", std::move(errc));
|
||||||
|
}
|
||||||
|
return Success();
|
||||||
|
}
|
||||||
|
|
||||||
/** \brief check if anything bad happens */
|
/** \brief check if anything bad happens */
|
||||||
bool BadSocket() const {
|
bool BadSocket() const {
|
||||||
if (IsClosed()) return true;
|
if (IsClosed()) {
|
||||||
std::int32_t err = GetSockError();
|
return true;
|
||||||
if (err == EBADF || err == EINTR) return true;
|
}
|
||||||
|
auto err = GetSockError();
|
||||||
|
if (err.Code() == std::error_code{EBADF, std::system_category()} || // NOLINT
|
||||||
|
err.Code() == std::error_code{EINTR, std::system_category()}) { // NOLINT
|
||||||
|
return true;
|
||||||
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetNonBlock() {
|
void SetNonBlock(bool non_block) {
|
||||||
bool non_block{true};
|
|
||||||
#if defined(_WIN32)
|
#if defined(_WIN32)
|
||||||
u_long mode = non_block ? 1 : 0;
|
u_long mode = non_block ? 1 : 0;
|
||||||
xgboost_CHECK_SYS_CALL(ioctlsocket(handle_, FIONBIO, &mode), NO_ERROR);
|
xgboost_CHECK_SYS_CALL(ioctlsocket(handle_, FIONBIO, &mode), NO_ERROR);
|
||||||
@ -530,10 +551,20 @@ class TCPSocket {
|
|||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Connect to remote address, returns the error code if failed (no exception is
|
* @brief Connect to remote address, returns the error code if failed.
|
||||||
* raised so that we can retry).
|
*
|
||||||
|
* @param host Host IP address.
|
||||||
|
* @param port Connection port.
|
||||||
|
* @param retry Number of retries to attempt.
|
||||||
|
* @param timeout Timeout of each connection attempt.
|
||||||
|
* @param out_conn Output socket if the connection is successful. Value is invalid and undefined if
|
||||||
|
* the connection failed.
|
||||||
|
*
|
||||||
|
* @return Connection status.
|
||||||
*/
|
*/
|
||||||
std::error_code Connect(SockAddress const &addr, TCPSocket *out);
|
[[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry,
|
||||||
|
std::chrono::seconds timeout,
|
||||||
|
xgboost::collective::TCPSocket *out_conn);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Get the local host name.
|
* \brief Get the local host name.
|
||||||
|
|||||||
@ -94,6 +94,10 @@ def no_ipv6() -> PytestSkip:
|
|||||||
return {"condition": not has_ipv6(), "reason": "IPv6 is required to be enabled."}
|
return {"condition": not has_ipv6(), "reason": "IPv6 is required to be enabled."}
|
||||||
|
|
||||||
|
|
||||||
|
def not_linux() -> PytestSkip:
|
||||||
|
return {"condition": system() != "Linux", "reason": "Linux is required."}
|
||||||
|
|
||||||
|
|
||||||
def no_ubjson() -> PytestSkip:
|
def no_ubjson() -> PytestSkip:
|
||||||
return no_mod("ubjson")
|
return no_mod("ubjson")
|
||||||
|
|
||||||
|
|||||||
@ -1,10 +1,11 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright (c) 2014-2022 by XGBoost Contributors
|
* Copyright 2014-2023, XGBoost Contributors
|
||||||
* \file socket.h
|
* \file socket.h
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
*/
|
*/
|
||||||
#ifndef RABIT_INTERNAL_SOCKET_H_
|
#ifndef RABIT_INTERNAL_SOCKET_H_
|
||||||
#define RABIT_INTERNAL_SOCKET_H_
|
#define RABIT_INTERNAL_SOCKET_H_
|
||||||
|
#include "xgboost/collective/result.h"
|
||||||
#include "xgboost/collective/socket.h"
|
#include "xgboost/collective/socket.h"
|
||||||
|
|
||||||
#if defined(_WIN32)
|
#if defined(_WIN32)
|
||||||
@ -77,7 +78,7 @@ namespace rabit {
|
|||||||
namespace utils {
|
namespace utils {
|
||||||
|
|
||||||
template <typename PollFD>
|
template <typename PollFD>
|
||||||
int PollImpl(PollFD *pfd, int nfds, std::chrono::seconds timeout) {
|
int PollImpl(PollFD* pfd, int nfds, std::chrono::seconds timeout) noexcept(true) {
|
||||||
#if defined(_WIN32)
|
#if defined(_WIN32)
|
||||||
|
|
||||||
#if IS_MINGW()
|
#if IS_MINGW()
|
||||||
@ -135,11 +136,11 @@ struct PollHelper {
|
|||||||
* \brief Check if the descriptor is ready for read
|
* \brief Check if the descriptor is ready for read
|
||||||
* \param fd file descriptor to check status
|
* \param fd file descriptor to check status
|
||||||
*/
|
*/
|
||||||
inline bool CheckRead(SOCKET fd) const {
|
[[nodiscard]] bool CheckRead(SOCKET fd) const {
|
||||||
const auto& pfd = fds.find(fd);
|
const auto& pfd = fds.find(fd);
|
||||||
return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0);
|
return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0);
|
||||||
}
|
}
|
||||||
bool CheckRead(xgboost::collective::TCPSocket const &socket) const {
|
[[nodiscard]] bool CheckRead(xgboost::collective::TCPSocket const& socket) const {
|
||||||
return this->CheckRead(socket.Handle());
|
return this->CheckRead(socket.Handle());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -147,19 +148,19 @@ struct PollHelper {
|
|||||||
* \brief Check if the descriptor is ready for write
|
* \brief Check if the descriptor is ready for write
|
||||||
* \param fd file descriptor to check status
|
* \param fd file descriptor to check status
|
||||||
*/
|
*/
|
||||||
inline bool CheckWrite(SOCKET fd) const {
|
[[nodiscard]] bool CheckWrite(SOCKET fd) const {
|
||||||
const auto& pfd = fds.find(fd);
|
const auto& pfd = fds.find(fd);
|
||||||
return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0);
|
return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0);
|
||||||
}
|
}
|
||||||
bool CheckWrite(xgboost::collective::TCPSocket const &socket) const {
|
[[nodiscard]] bool CheckWrite(xgboost::collective::TCPSocket const& socket) const {
|
||||||
return this->CheckWrite(socket.Handle());
|
return this->CheckWrite(socket.Handle());
|
||||||
}
|
}
|
||||||
/*!
|
/**
|
||||||
* \brief perform poll on the set defined, read, write, exception
|
* @brief perform poll on the set defined, read, write, exception
|
||||||
* \param timeout specify timeout in milliseconds(ms) if negative, means poll will block
|
*
|
||||||
* \return
|
* @param timeout specify timeout in seconds. Block if negative.
|
||||||
*/
|
*/
|
||||||
inline void Poll(std::chrono::seconds timeout) { // NOLINT(*)
|
[[nodiscard]] xgboost::collective::Result Poll(std::chrono::seconds timeout) {
|
||||||
std::vector<pollfd> fdset;
|
std::vector<pollfd> fdset;
|
||||||
fdset.reserve(fds.size());
|
fdset.reserve(fds.size());
|
||||||
for (auto kv : fds) {
|
for (auto kv : fds) {
|
||||||
@ -167,9 +168,9 @@ struct PollHelper {
|
|||||||
}
|
}
|
||||||
int ret = PollImpl(fdset.data(), fdset.size(), timeout);
|
int ret = PollImpl(fdset.data(), fdset.size(), timeout);
|
||||||
if (ret == 0) {
|
if (ret == 0) {
|
||||||
LOG(FATAL) << "Poll timeout";
|
return xgboost::collective::Fail("Poll timeout.");
|
||||||
} else if (ret < 0) {
|
} else if (ret < 0) {
|
||||||
LOG(FATAL) << "Failed to poll.";
|
return xgboost::system::FailWithCode("Poll failed.");
|
||||||
} else {
|
} else {
|
||||||
for (auto& pfd : fdset) {
|
for (auto& pfd : fdset) {
|
||||||
auto revents = pfd.revents & pfd.events;
|
auto revents = pfd.revents & pfd.events;
|
||||||
@ -180,6 +181,7 @@ struct PollHelper {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return xgboost::collective::Success();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unordered_map<SOCKET, pollfd> fds;
|
std::unordered_map<SOCKET, pollfd> fds;
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright (c) 2014 by Contributors
|
* Copyright 2014-2023, XGBoost Contributors
|
||||||
* \file allreduce_base.cc
|
* \file allreduce_base.cc
|
||||||
* \brief Basic implementation of AllReduce
|
* \brief Basic implementation of AllReduce
|
||||||
*
|
*
|
||||||
@ -9,9 +9,11 @@
|
|||||||
#define NOMINMAX
|
#define NOMINMAX
|
||||||
#endif // !defined(NOMINMAX)
|
#endif // !defined(NOMINMAX)
|
||||||
|
|
||||||
|
#include "allreduce_base.h"
|
||||||
|
|
||||||
#include "rabit/base.h"
|
#include "rabit/base.h"
|
||||||
#include "rabit/internal/rabit-inl.h"
|
#include "rabit/internal/rabit-inl.h"
|
||||||
#include "allreduce_base.h"
|
#include "xgboost/collective/result.h"
|
||||||
|
|
||||||
#ifndef _WIN32
|
#ifndef _WIN32
|
||||||
#include <netinet/tcp.h>
|
#include <netinet/tcp.h>
|
||||||
@ -20,8 +22,7 @@
|
|||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
|
||||||
namespace rabit {
|
namespace rabit::engine {
|
||||||
namespace engine {
|
|
||||||
// constructor
|
// constructor
|
||||||
AllreduceBase::AllreduceBase() {
|
AllreduceBase::AllreduceBase() {
|
||||||
tracker_uri = "NULL";
|
tracker_uri = "NULL";
|
||||||
@ -116,7 +117,12 @@ bool AllreduceBase::Init(int argc, char* argv[]) {
|
|||||||
utils::Assert(all_links.size() == 0, "can only call Init once");
|
utils::Assert(all_links.size() == 0, "can only call Init once");
|
||||||
this->host_uri = xgboost::collective::GetHostName();
|
this->host_uri = xgboost::collective::GetHostName();
|
||||||
// get information from tracker
|
// get information from tracker
|
||||||
return this->ReConnectLinks();
|
auto rc = this->ReConnectLinks();
|
||||||
|
if (rc.OK()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
LOG(FATAL) << rc.Report();
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AllreduceBase::Shutdown() {
|
bool AllreduceBase::Shutdown() {
|
||||||
@ -131,7 +137,11 @@ bool AllreduceBase::Shutdown() {
|
|||||||
|
|
||||||
if (tracker_uri == "NULL") return true;
|
if (tracker_uri == "NULL") return true;
|
||||||
// notify tracker rank i have shutdown
|
// notify tracker rank i have shutdown
|
||||||
xgboost::collective::TCPSocket tracker = this->ConnectTracker();
|
xgboost::collective::TCPSocket tracker;
|
||||||
|
auto rc = this->ConnectTracker(&tracker);
|
||||||
|
if (!rc.OK()) {
|
||||||
|
LOG(FATAL) << rc.Report();
|
||||||
|
}
|
||||||
tracker.Send(xgboost::StringView{"shutdown"});
|
tracker.Send(xgboost::StringView{"shutdown"});
|
||||||
tracker.Close();
|
tracker.Close();
|
||||||
xgboost::system::SocketFinalize();
|
xgboost::system::SocketFinalize();
|
||||||
@ -146,7 +156,12 @@ void AllreduceBase::TrackerPrint(const std::string &msg) {
|
|||||||
if (tracker_uri == "NULL") {
|
if (tracker_uri == "NULL") {
|
||||||
utils::Printf("%s", msg.c_str()); return;
|
utils::Printf("%s", msg.c_str()); return;
|
||||||
}
|
}
|
||||||
xgboost::collective::TCPSocket tracker = this->ConnectTracker();
|
xgboost::collective::TCPSocket tracker;
|
||||||
|
auto rc = this->ConnectTracker(&tracker);
|
||||||
|
if (!rc.OK()) {
|
||||||
|
LOG(FATAL) << rc.Report();
|
||||||
|
}
|
||||||
|
|
||||||
tracker.Send(xgboost::StringView{"print"});
|
tracker.Send(xgboost::StringView{"print"});
|
||||||
tracker.Send(xgboost::StringView{msg});
|
tracker.Send(xgboost::StringView{msg});
|
||||||
tracker.Close();
|
tracker.Close();
|
||||||
@ -215,64 +230,67 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief initialize connection to the tracker
|
* \brief initialize connection to the tracker
|
||||||
* \return a socket that initializes the connection
|
* \return a socket that initializes the connection
|
||||||
*/
|
*/
|
||||||
xgboost::collective::TCPSocket AllreduceBase::ConnectTracker() const {
|
[[nodiscard]] xgboost::collective::Result AllreduceBase::ConnectTracker(
|
||||||
|
xgboost::collective::TCPSocket *out) const {
|
||||||
int magic = kMagic;
|
int magic = kMagic;
|
||||||
// get information from tracker
|
// get information from tracker
|
||||||
xgboost::collective::TCPSocket tracker;
|
xgboost::collective::TCPSocket &tracker = *out;
|
||||||
|
|
||||||
int retry = 0;
|
auto rc =
|
||||||
do {
|
Connect(xgboost::StringView{tracker_uri}, tracker_port, connect_retry, timeout_sec, &tracker);
|
||||||
auto rc = xgboost::collective::Connect(
|
if (!rc.OK()) {
|
||||||
xgboost::collective::MakeSockAddress(xgboost::StringView{tracker_uri}, tracker_port),
|
return xgboost::collective::Fail("Failed to connect to the tracker.", std::move(rc));
|
||||||
&tracker);
|
|
||||||
if (rc != std::errc()) {
|
|
||||||
if (++retry >= connect_retry) {
|
|
||||||
LOG(FATAL) << "Connecting to (failed): [" << tracker_uri << "]\n" << rc.message();
|
|
||||||
} else {
|
|
||||||
LOG(WARNING) << rc.message() << "\nRetry connecting to IP(retry time: " << retry << "): ["
|
|
||||||
<< tracker_uri << "]";
|
|
||||||
#if defined(_MSC_VER) || defined(__MINGW32__)
|
|
||||||
Sleep(retry << 1);
|
|
||||||
#else
|
|
||||||
sleep(retry << 1);
|
|
||||||
#endif
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
break;
|
|
||||||
} while (true);
|
|
||||||
|
|
||||||
using utils::Assert;
|
using utils::Assert;
|
||||||
CHECK_EQ(tracker.SendAll(&magic, sizeof(magic)), sizeof(magic));
|
if (tracker.SendAll(&magic, sizeof(magic)) != sizeof(magic)) {
|
||||||
CHECK_EQ(tracker.RecvAll(&magic, sizeof(magic)), sizeof(magic));
|
return xgboost::collective::Fail("Failed to send the verification number.");
|
||||||
utils::Check(magic == kMagic, "sync::Invalid tracker message, init failure");
|
}
|
||||||
Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
|
if (tracker.RecvAll(&magic, sizeof(magic)) != sizeof(magic)) {
|
||||||
Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size),
|
return xgboost::collective::Fail("Failed to recieve the verification number.");
|
||||||
"ReConnectLink failure 3");
|
}
|
||||||
CHECK_EQ(tracker.Send(xgboost::StringView{task_id}), task_id.size());
|
if (magic != kMagic) {
|
||||||
return tracker;
|
return xgboost::collective::Fail("Invalid verification number.");
|
||||||
|
}
|
||||||
|
if (tracker.SendAll(&rank, sizeof(rank)) != sizeof(rank)) {
|
||||||
|
return xgboost::collective::Fail("Failed to send the local rank back to the tracker.");
|
||||||
|
}
|
||||||
|
if (tracker.SendAll(&world_size, sizeof(world_size)) != sizeof(world_size)) {
|
||||||
|
return xgboost::collective::Fail("Failed to send the world size back to the tracker.");
|
||||||
|
}
|
||||||
|
if (tracker.Send(xgboost::StringView{task_id}) != task_id.size()) {
|
||||||
|
return xgboost::collective::Fail("Failed to send the task ID back to the tracker.");
|
||||||
|
}
|
||||||
|
|
||||||
|
return xgboost::collective::Success();
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief connect to the tracker to fix the the missing links
|
* \brief connect to the tracker to fix the the missing links
|
||||||
* this function is also used when the engine start up
|
* this function is also used when the engine start up
|
||||||
*/
|
*/
|
||||||
bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
[[nodiscard]] xgboost::collective::Result AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||||
// single node mode
|
// single node mode
|
||||||
if (tracker_uri == "NULL") {
|
if (tracker_uri == "NULL") {
|
||||||
rank = 0;
|
rank = 0;
|
||||||
world_size = 1;
|
world_size = 1;
|
||||||
return true;
|
return xgboost::collective::Success();
|
||||||
|
}
|
||||||
|
|
||||||
|
xgboost::collective::TCPSocket tracker;
|
||||||
|
auto rc = this->ConnectTracker(&tracker);
|
||||||
|
if (!rc.OK()) {
|
||||||
|
return xgboost::collective::Fail("Failed to connect to the tracker.", std::move(rc));
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
|
||||||
xgboost::collective::TCPSocket tracker = this->ConnectTracker();
|
|
||||||
LOG(INFO) << "task " << task_id << " connected to the tracker";
|
LOG(INFO) << "task " << task_id << " connected to the tracker";
|
||||||
tracker.Send(xgboost::StringView{cmd});
|
tracker.Send(xgboost::StringView{cmd});
|
||||||
|
|
||||||
|
try {
|
||||||
// the rank of previous link, next link in ring
|
// the rank of previous link, next link in ring
|
||||||
int prev_rank, next_rank;
|
int prev_rank, next_rank;
|
||||||
// the rank of neighbors
|
// the rank of neighbors
|
||||||
@ -334,10 +352,10 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
|||||||
tracker.Recv(&hname);
|
tracker.Recv(&hname);
|
||||||
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9");
|
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9");
|
||||||
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10");
|
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10");
|
||||||
|
// connect to peer
|
||||||
if (xgboost::collective::Connect(
|
if (!xgboost::collective::Connect(xgboost::StringView{hname}, hport, connect_retry,
|
||||||
xgboost::collective::MakeSockAddress(xgboost::StringView{hname}, hport), &r.sock) !=
|
timeout_sec, &r.sock)
|
||||||
std::errc{}) {
|
.OK()) {
|
||||||
num_error += 1;
|
num_error += 1;
|
||||||
r.sock.Close();
|
r.sock.Close();
|
||||||
continue;
|
continue;
|
||||||
@ -351,8 +369,7 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
|||||||
bool match = false;
|
bool match = false;
|
||||||
for (auto & all_link : all_links) {
|
for (auto & all_link : all_links) {
|
||||||
if (all_link.rank == hrank) {
|
if (all_link.rank == hrank) {
|
||||||
Assert(all_link.sock.IsClosed(),
|
Assert(all_link.sock.IsClosed(), "Override a link that is active");
|
||||||
"Override a link that is active");
|
|
||||||
all_link.sock = std::move(r.sock);
|
all_link.sock = std::move(r.sock);
|
||||||
match = true;
|
match = true;
|
||||||
break;
|
break;
|
||||||
@ -364,10 +381,10 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
|||||||
"ReConnectLink failure 14");
|
"ReConnectLink failure 14");
|
||||||
} while (num_error != 0);
|
} while (num_error != 0);
|
||||||
// send back socket listening port to tracker
|
// send back socket listening port to tracker
|
||||||
Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port),
|
Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port), "ReConnectLink failure 14");
|
||||||
"ReConnectLink failure 14");
|
|
||||||
// close connection to tracker
|
// close connection to tracker
|
||||||
tracker.Close();
|
tracker.Close();
|
||||||
|
|
||||||
// listen to incoming links
|
// listen to incoming links
|
||||||
for (int i = 0; i < num_accept; ++i) {
|
for (int i = 0; i < num_accept; ++i) {
|
||||||
LinkRecord r;
|
LinkRecord r;
|
||||||
@ -395,7 +412,7 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
|||||||
for (auto &all_link : all_links) {
|
for (auto &all_link : all_links) {
|
||||||
utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket");
|
utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket");
|
||||||
// set the socket to non-blocking mode, enable TCP keepalive
|
// set the socket to non-blocking mode, enable TCP keepalive
|
||||||
all_link.sock.SetNonBlock();
|
all_link.sock.SetNonBlock(true);
|
||||||
all_link.sock.SetKeepAlive();
|
all_link.sock.SetKeepAlive();
|
||||||
if (rabit_enable_tcp_no_delay) {
|
if (rabit_enable_tcp_no_delay) {
|
||||||
all_link.sock.SetNoDelay();
|
all_link.sock.SetNoDelay();
|
||||||
@ -415,10 +432,11 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
|||||||
"cannot find prev ring in the link");
|
"cannot find prev ring in the link");
|
||||||
Assert(next_rank == -1 || ring_next != nullptr,
|
Assert(next_rank == -1 || ring_next != nullptr,
|
||||||
"cannot find next ring in the link");
|
"cannot find next ring in the link");
|
||||||
return true;
|
return xgboost::collective::Success();
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception& e) {
|
||||||
LOG(WARNING) << "failed in ReconnectLink " << e.what();
|
std::stringstream ss;
|
||||||
return false;
|
ss << "Failed in ReconnectLink " << e.what();
|
||||||
|
return xgboost::collective::Fail(ss.str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
@ -523,9 +541,15 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// finish running allreduce
|
// finish running allreduce
|
||||||
if (finished) break;
|
if (finished) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
// select must return
|
// select must return
|
||||||
watcher.Poll(timeout_sec);
|
auto poll_res = watcher.Poll(timeout_sec);
|
||||||
|
if (!poll_res.OK()) {
|
||||||
|
LOG(FATAL) << poll_res.Report();
|
||||||
|
}
|
||||||
|
|
||||||
// read data from childs
|
// read data from childs
|
||||||
for (int i = 0; i < nlink; ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
if (i != parent_index && watcher.CheckRead(links[i].sock)) {
|
if (i != parent_index && watcher.CheckRead(links[i].sock)) {
|
||||||
@ -698,7 +722,10 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
|||||||
// finish running
|
// finish running
|
||||||
if (finished) break;
|
if (finished) break;
|
||||||
// select
|
// select
|
||||||
watcher.Poll(timeout_sec);
|
auto poll_res = watcher.Poll(timeout_sec);
|
||||||
|
if (!poll_res.OK()) {
|
||||||
|
LOG(FATAL) << poll_res.Report();
|
||||||
|
}
|
||||||
if (in_link == -2) {
|
if (in_link == -2) {
|
||||||
// probe in-link
|
// probe in-link
|
||||||
for (int i = 0; i < nlink; ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
@ -780,8 +807,14 @@ AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
|
|||||||
}
|
}
|
||||||
finished = false;
|
finished = false;
|
||||||
}
|
}
|
||||||
if (finished) break;
|
if (finished) {
|
||||||
watcher.Poll(timeout_sec);
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto poll_res = watcher.Poll(timeout_sec);
|
||||||
|
if (!poll_res.OK()) {
|
||||||
|
LOG(FATAL) << poll_res.Report();
|
||||||
|
}
|
||||||
if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
|
if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
|
||||||
size_t size = stop_read - read_ptr;
|
size_t size = stop_read - read_ptr;
|
||||||
size_t start = read_ptr % total_size;
|
size_t start = read_ptr % total_size;
|
||||||
@ -880,8 +913,13 @@ AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_,
|
|||||||
}
|
}
|
||||||
finished = false;
|
finished = false;
|
||||||
}
|
}
|
||||||
if (finished) break;
|
if (finished) {
|
||||||
watcher.Poll(timeout_sec);
|
break;
|
||||||
|
}
|
||||||
|
auto poll_res = watcher.Poll(timeout_sec);
|
||||||
|
if (!poll_res.OK()) {
|
||||||
|
LOG(FATAL) << poll_res.Report();
|
||||||
|
}
|
||||||
if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
|
if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
|
||||||
ReturnType ret = next.ReadToRingBuffer(reduce_ptr, stop_read);
|
ReturnType ret = next.ReadToRingBuffer(reduce_ptr, stop_read);
|
||||||
if (ret != kSuccess) {
|
if (ret != kSuccess) {
|
||||||
@ -953,5 +991,4 @@ AllreduceBase::TryAllreduceRing(void *sendrecvbuf_,
|
|||||||
(std::min((prank + 1) * step, count) -
|
(std::min((prank + 1) * step, count) -
|
||||||
std::min(prank * step, count)) * type_nbytes);
|
std::min(prank * step, count)) * type_nbytes);
|
||||||
}
|
}
|
||||||
} // namespace engine
|
} // namespace rabit::engine
|
||||||
} // namespace rabit
|
|
||||||
|
|||||||
@ -12,14 +12,16 @@
|
|||||||
#ifndef RABIT_ALLREDUCE_BASE_H_
|
#ifndef RABIT_ALLREDUCE_BASE_H_
|
||||||
#define RABIT_ALLREDUCE_BASE_H_
|
#define RABIT_ALLREDUCE_BASE_H_
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <future>
|
#include <future>
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <algorithm>
|
#include <vector>
|
||||||
#include "rabit/internal/utils.h"
|
|
||||||
#include "rabit/internal/engine.h"
|
#include "rabit/internal/engine.h"
|
||||||
#include "rabit/internal/socket.h"
|
#include "rabit/internal/socket.h"
|
||||||
|
#include "rabit/internal/utils.h"
|
||||||
|
#include "xgboost/collective/result.h"
|
||||||
|
|
||||||
#ifdef RABIT_CXXTESTDEFS_H
|
#ifdef RABIT_CXXTESTDEFS_H
|
||||||
#define private public
|
#define private public
|
||||||
@ -329,13 +331,13 @@ class AllreduceBase : public IEngine {
|
|||||||
* \brief initialize connection to the tracker
|
* \brief initialize connection to the tracker
|
||||||
* \return a socket that initializes the connection
|
* \return a socket that initializes the connection
|
||||||
*/
|
*/
|
||||||
xgboost::collective::TCPSocket ConnectTracker() const;
|
[[nodiscard]] xgboost::collective::Result ConnectTracker(xgboost::collective::TCPSocket *out) const;
|
||||||
/*!
|
/*!
|
||||||
* \brief connect to the tracker to fix the the missing links
|
* \brief connect to the tracker to fix the the missing links
|
||||||
* this function is also used when the engine start up
|
* this function is also used when the engine start up
|
||||||
* \param cmd possible command to sent to tracker
|
* \param cmd possible command to sent to tracker
|
||||||
*/
|
*/
|
||||||
bool ReConnectLinks(const char *cmd = "start");
|
[[nodiscard]] xgboost::collective::Result ReConnectLinks(const char *cmd = "start");
|
||||||
/*!
|
/*!
|
||||||
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
|
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
|
||||||
*
|
*
|
||||||
|
|||||||
@ -1,19 +1,22 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright (c) 2022 by XGBoost Contributors
|
* Copyright 2022-2023 by XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include "xgboost/collective/socket.h"
|
#include "xgboost/collective/socket.h"
|
||||||
|
|
||||||
#include <cstddef> // std::size_t
|
#include <cstddef> // std::size_t
|
||||||
#include <cstdint> // std::int32_t
|
#include <cstdint> // std::int32_t
|
||||||
#include <cstring> // std::memcpy, std::memset
|
#include <cstring> // std::memcpy, std::memset
|
||||||
|
#include <filesystem> // for path
|
||||||
#include <system_error> // std::error_code, std::system_category
|
#include <system_error> // std::error_code, std::system_category
|
||||||
|
|
||||||
|
#include "rabit/internal/socket.h" // for PollHelper
|
||||||
|
#include "xgboost/collective/result.h" // for Result
|
||||||
|
|
||||||
#if defined(__unix__) || defined(__APPLE__)
|
#if defined(__unix__) || defined(__APPLE__)
|
||||||
#include <netdb.h> // getaddrinfo, freeaddrinfo
|
#include <netdb.h> // getaddrinfo, freeaddrinfo
|
||||||
#endif // defined(__unix__) || defined(__APPLE__)
|
#endif // defined(__unix__) || defined(__APPLE__)
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::collective {
|
||||||
namespace collective {
|
|
||||||
SockAddress MakeSockAddress(StringView host, in_port_t port) {
|
SockAddress MakeSockAddress(StringView host, in_port_t port) {
|
||||||
struct addrinfo hints;
|
struct addrinfo hints;
|
||||||
std::memset(&hints, 0, sizeof(hints));
|
std::memset(&hints, 0, sizeof(hints));
|
||||||
@ -71,7 +74,12 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
|
|||||||
return bytes;
|
return bytes;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::error_code Connect(SockAddress const &addr, TCPSocket *out) {
|
[[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry,
|
||||||
|
std::chrono::seconds timeout,
|
||||||
|
xgboost::collective::TCPSocket *out_conn) {
|
||||||
|
auto addr = MakeSockAddress(xgboost::StringView{host}, port);
|
||||||
|
auto &conn = *out_conn;
|
||||||
|
|
||||||
sockaddr const *addr_handle{nullptr};
|
sockaddr const *addr_handle{nullptr};
|
||||||
socklen_t addr_len{0};
|
socklen_t addr_len{0};
|
||||||
if (addr.IsV4()) {
|
if (addr.IsV4()) {
|
||||||
@ -81,14 +89,67 @@ std::error_code Connect(SockAddress const &addr, TCPSocket *out) {
|
|||||||
addr_handle = reinterpret_cast<const sockaddr *>(&addr.V6().Handle());
|
addr_handle = reinterpret_cast<const sockaddr *>(&addr.V6().Handle());
|
||||||
addr_len = sizeof(addr.V6().Handle());
|
addr_len = sizeof(addr.V6().Handle());
|
||||||
}
|
}
|
||||||
auto socket = TCPSocket::Create(addr.Domain());
|
|
||||||
CHECK_EQ(static_cast<std::int32_t>(socket.Domain()), static_cast<std::int32_t>(addr.Domain()));
|
conn = TCPSocket::Create(addr.Domain());
|
||||||
auto rc = connect(socket.Handle(), addr_handle, addr_len);
|
CHECK_EQ(static_cast<std::int32_t>(conn.Domain()), static_cast<std::int32_t>(addr.Domain()));
|
||||||
if (rc != 0) {
|
conn.SetNonBlock(true);
|
||||||
return std::error_code{errno, std::system_category()};
|
|
||||||
|
Result last_error;
|
||||||
|
auto log_failure = [&host, &last_error](Result err, char const *file, std::int32_t line) {
|
||||||
|
last_error = std::move(err);
|
||||||
|
LOG(WARNING) << std::filesystem::path{file}.filename().string() << "(" << line
|
||||||
|
<< "): Failed to connect to:" << host << " Error:" << last_error.Report();
|
||||||
|
};
|
||||||
|
|
||||||
|
for (std::int32_t attempt = 0; attempt < std::max(retry, 1); ++attempt) {
|
||||||
|
if (attempt > 0) {
|
||||||
|
LOG(WARNING) << "Retrying connection to " << host << " for the " << attempt << " time.";
|
||||||
|
#if defined(_MSC_VER) || defined(__MINGW32__)
|
||||||
|
Sleep(attempt << 1);
|
||||||
|
#else
|
||||||
|
sleep(attempt << 1);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
*out = std::move(socket);
|
|
||||||
return std::make_error_code(std::errc{});
|
auto rc = connect(conn.Handle(), addr_handle, addr_len);
|
||||||
|
if (rc != 0) {
|
||||||
|
auto errcode = system::LastError();
|
||||||
|
if (!system::ErrorWouldBlock(errcode)) {
|
||||||
|
log_failure(Fail("connect failed.", std::error_code{errcode, std::system_category()}),
|
||||||
|
__FILE__, __LINE__);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
rabit::utils::PollHelper poll;
|
||||||
|
poll.WatchWrite(conn);
|
||||||
|
auto result = poll.Poll(timeout);
|
||||||
|
if (!result.OK()) {
|
||||||
|
log_failure(std::move(result), __FILE__, __LINE__);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (!poll.CheckWrite(conn)) {
|
||||||
|
log_failure(Fail("poll failed.", std::error_code{errcode, std::system_category()}),
|
||||||
|
__FILE__, __LINE__);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
result = conn.GetSockError();
|
||||||
|
if (!result.OK()) {
|
||||||
|
log_failure(std::move(result), __FILE__, __LINE__);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.SetNonBlock(false);
|
||||||
|
return Success();
|
||||||
|
|
||||||
|
} else {
|
||||||
|
conn.SetNonBlock(false);
|
||||||
|
return Success();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "Failed to connect to " << host << ":" << port;
|
||||||
|
conn.Close();
|
||||||
|
return Fail(ss.str(), std::move(last_error));
|
||||||
}
|
}
|
||||||
} // namespace collective
|
} // namespace xgboost::collective
|
||||||
} // namespace xgboost
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright (c) 2022 by XGBoost Contributors
|
* Copyright 2022-2023 by XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <xgboost/collective/socket.h>
|
#include <xgboost/collective/socket.h>
|
||||||
@ -10,8 +10,7 @@
|
|||||||
|
|
||||||
#include "../helpers.h"
|
#include "../helpers.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::collective {
|
||||||
namespace collective {
|
|
||||||
TEST(Socket, Basic) {
|
TEST(Socket, Basic) {
|
||||||
system::SocketStartup();
|
system::SocketStartup();
|
||||||
|
|
||||||
@ -31,15 +30,16 @@ TEST(Socket, Basic) {
|
|||||||
TCPSocket client;
|
TCPSocket client;
|
||||||
if (domain == SockDomain::kV4) {
|
if (domain == SockDomain::kV4) {
|
||||||
auto const& addr = SockAddrV4::Loopback().Addr();
|
auto const& addr = SockAddrV4::Loopback().Addr();
|
||||||
ASSERT_EQ(Connect(MakeSockAddress(StringView{addr}, port), &client), std::errc{});
|
auto rc = Connect(StringView{addr}, port, 1, std::chrono::seconds{3}, &client);
|
||||||
|
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||||
} else {
|
} else {
|
||||||
auto const& addr = SockAddrV6::Loopback().Addr();
|
auto const& addr = SockAddrV6::Loopback().Addr();
|
||||||
auto rc = Connect(MakeSockAddress(StringView{addr}, port), &client);
|
auto rc = Connect(StringView{addr}, port, 1, std::chrono::seconds{3}, &client);
|
||||||
// some environment (docker) has restricted network configuration.
|
// some environment (docker) has restricted network configuration.
|
||||||
if (rc == std::error_code{EADDRNOTAVAIL, std::system_category()}) {
|
if (!rc.OK() && rc.Code() == std::error_code{EADDRNOTAVAIL, std::system_category()}) {
|
||||||
GTEST_SKIP_(msg.c_str());
|
GTEST_SKIP_(msg.c_str());
|
||||||
}
|
}
|
||||||
ASSERT_EQ(rc, std::errc{});
|
ASSERT_EQ(rc, Success()) << rc.Report();
|
||||||
}
|
}
|
||||||
ASSERT_EQ(client.Domain(), domain);
|
ASSERT_EQ(client.Domain(), domain);
|
||||||
|
|
||||||
@ -73,5 +73,4 @@ TEST(Socket, Basic) {
|
|||||||
|
|
||||||
system::SocketFinalize();
|
system::SocketFinalize();
|
||||||
}
|
}
|
||||||
} // namespace collective
|
} // namespace xgboost::collective
|
||||||
} // namespace xgboost
|
|
||||||
|
|||||||
@ -20,6 +20,18 @@ def test_rabit_tracker():
|
|||||||
assert str(ret) == "test1234"
|
assert str(ret) == "test1234"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(**tm.not_linux())
|
||||||
|
def test_socket_error():
|
||||||
|
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1)
|
||||||
|
tracker.start(1)
|
||||||
|
env = tracker.worker_envs()
|
||||||
|
env["DMLC_TRACKER_PORT"] = 0
|
||||||
|
env["DMLC_WORKER_CONNECT_RETRY"] = 1
|
||||||
|
with pytest.raises(ValueError, match="127.0.0.1:0\n.*refused"):
|
||||||
|
with xgb.collective.CommunicatorContext(**env):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def run_rabit_ops(client, n_workers):
|
def run_rabit_ops(client, n_workers):
|
||||||
from xgboost.dask import CommunicatorContext, _get_dask_config, _get_rabit_args
|
from xgboost.dask import CommunicatorContext, _get_dask_config, _get_rabit_args
|
||||||
|
|
||||||
@ -58,6 +70,32 @@ def test_rabit_ops():
|
|||||||
run_rabit_ops(client, n_workers)
|
run_rabit_ops(client, n_workers)
|
||||||
|
|
||||||
|
|
||||||
|
def run_broadcast(client):
|
||||||
|
from xgboost.dask import _get_dask_config, _get_rabit_args
|
||||||
|
|
||||||
|
workers = tm.get_client_workers(client)
|
||||||
|
rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client)
|
||||||
|
|
||||||
|
def local_test(worker_id):
|
||||||
|
with collective.CommunicatorContext(**rabit_args):
|
||||||
|
res = collective.broadcast(17, 0)
|
||||||
|
return res
|
||||||
|
|
||||||
|
futures = client.map(local_test, range(len(workers)), workers=workers)
|
||||||
|
results = client.gather(futures)
|
||||||
|
np.testing.assert_allclose(np.array(results), 17)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(**tm.no_dask())
|
||||||
|
def test_broadcast():
|
||||||
|
from distributed import Client, LocalCluster
|
||||||
|
|
||||||
|
n_workers = 3
|
||||||
|
with LocalCluster(n_workers=n_workers) as cluster:
|
||||||
|
with Client(cluster) as client:
|
||||||
|
run_broadcast(client)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_ipv6())
|
@pytest.mark.skipif(**tm.no_ipv6())
|
||||||
@pytest.mark.skipif(**tm.no_dask())
|
@pytest.mark.skipif(**tm.no_dask())
|
||||||
def test_rabit_ops_ipv6():
|
def test_rabit_ops_ipv6():
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user