[coll] Implement shutdown for tracker and comm. (#10208)
- Force shutdown the tracker. - Implement shutdown notice for error handling thread in comm.
This commit is contained in:
parent
8fb05c8c95
commit
3fbb221fec
@ -1555,7 +1555,7 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle);
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Get the arguments needed for running workers. This should be called after
|
* @brief Get the arguments needed for running workers. This should be called after
|
||||||
* XGTrackerRun() and XGTrackerWait()
|
* XGTrackerRun().
|
||||||
*
|
*
|
||||||
* @param handle The handle to the tracker.
|
* @param handle The handle to the tracker.
|
||||||
* @param args The arguments returned as a JSON document.
|
* @param args The arguments returned as a JSON document.
|
||||||
@ -1565,16 +1565,19 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle);
|
|||||||
XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args);
|
XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Run the tracker.
|
* @brief Start the tracker. The tracker runs in the background and this function returns
|
||||||
|
* once the tracker is started.
|
||||||
*
|
*
|
||||||
* @param handle The handle to the tracker.
|
* @param handle The handle to the tracker.
|
||||||
|
* @param config Unused at the moment, preserved for the future.
|
||||||
*
|
*
|
||||||
* @return 0 for success, -1 for failure.
|
* @return 0 for success, -1 for failure.
|
||||||
*/
|
*/
|
||||||
XGB_DLL int XGTrackerRun(TrackerHandle handle);
|
XGB_DLL int XGTrackerRun(TrackerHandle handle, char const *config);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Wait for the tracker to finish, should be called after XGTrackerRun().
|
* @brief Wait for the tracker to finish, should be called after XGTrackerRun(). This
|
||||||
|
* function will block until the tracker task is finished or the timeout is reached.
|
||||||
*
|
*
|
||||||
* @param handle The handle to the tracker.
|
* @param handle The handle to the tracker.
|
||||||
* @param config JSON encoded configuration. No argument is required yet, preserved for
|
* @param config JSON encoded configuration. No argument is required yet, preserved for
|
||||||
@ -1582,11 +1585,12 @@ XGB_DLL int XGTrackerRun(TrackerHandle handle);
|
|||||||
*
|
*
|
||||||
* @return 0 for success, -1 for failure.
|
* @return 0 for success, -1 for failure.
|
||||||
*/
|
*/
|
||||||
XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config);
|
XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Free a tracker instance. XGTrackerWait() is called internally. If the tracker
|
* @brief Free a tracker instance. This should be called after XGTrackerWaitFor(). If the
|
||||||
* cannot close properly, manual interruption is required.
|
* tracker is not properly waited for, this function will shut down all connections with
|
||||||
|
* the tracker, potentially leading to undefined behavior.
|
||||||
*
|
*
|
||||||
* @param handle The handle to the tracker.
|
* @param handle The handle to the tracker.
|
||||||
*
|
*
|
||||||
|
|||||||
@ -124,6 +124,21 @@ inline std::int32_t CloseSocket(SocketT fd) {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline std::int32_t ShutdownSocket(SocketT fd) {
|
||||||
|
#if defined(_WIN32)
|
||||||
|
auto rc = shutdown(fd, SD_BOTH);
|
||||||
|
if (rc != 0 && LastError() == WSANOTINITIALISED) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
auto rc = shutdown(fd, SHUT_RDWR);
|
||||||
|
if (rc != 0 && LastError() == ENOTCONN) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
return rc;
|
||||||
|
}
|
||||||
|
|
||||||
inline bool ErrorWouldBlock(std::int32_t errsv) noexcept(true) {
|
inline bool ErrorWouldBlock(std::int32_t errsv) noexcept(true) {
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
return errsv == WSAEWOULDBLOCK;
|
return errsv == WSAEWOULDBLOCK;
|
||||||
@ -499,36 +514,49 @@ class TCPSocket {
|
|||||||
*/
|
*/
|
||||||
[[nodiscard]] HandleT const &Handle() const { return handle_; }
|
[[nodiscard]] HandleT const &Handle() const { return handle_; }
|
||||||
/**
|
/**
|
||||||
* \brief Listen to incoming requests. Should be called after bind.
|
* @brief Listen to incoming requests. Should be called after bind.
|
||||||
*/
|
*/
|
||||||
void Listen(std::int32_t backlog = 16) { xgboost_CHECK_SYS_CALL(listen(handle_, backlog), 0); }
|
[[nodiscard]] Result Listen(std::int32_t backlog = 16) {
|
||||||
|
if (listen(handle_, backlog) != 0) {
|
||||||
|
return system::FailWithCode("Failed to listen.");
|
||||||
|
}
|
||||||
|
return Success();
|
||||||
|
}
|
||||||
/**
|
/**
|
||||||
* \brief Bind socket to INADDR_ANY, return the port selected by the OS.
|
* @brief Bind socket to INADDR_ANY, return the port selected by the OS.
|
||||||
*/
|
*/
|
||||||
[[nodiscard]] in_port_t BindHost() {
|
[[nodiscard]] Result BindHost(std::int32_t* p_out) {
|
||||||
|
// Use int32 instead of in_port_t for consistency. We take port as parameter from
|
||||||
|
// users using other languages, the port is usually stored and passed around as int.
|
||||||
if (Domain() == SockDomain::kV6) {
|
if (Domain() == SockDomain::kV6) {
|
||||||
auto addr = SockAddrV6::InaddrAny();
|
auto addr = SockAddrV6::InaddrAny();
|
||||||
auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
|
auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
|
||||||
xgboost_CHECK_SYS_CALL(
|
if (bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
|
||||||
bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);
|
return system::FailWithCode("bind failed.");
|
||||||
|
}
|
||||||
|
|
||||||
sockaddr_in6 res_addr;
|
sockaddr_in6 res_addr;
|
||||||
socklen_t addrlen = sizeof(res_addr);
|
socklen_t addrlen = sizeof(res_addr);
|
||||||
xgboost_CHECK_SYS_CALL(
|
if (getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen) != 0) {
|
||||||
getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen), 0);
|
return system::FailWithCode("getsockname failed.");
|
||||||
return ntohs(res_addr.sin6_port);
|
}
|
||||||
|
*p_out = ntohs(res_addr.sin6_port);
|
||||||
} else {
|
} else {
|
||||||
auto addr = SockAddrV4::InaddrAny();
|
auto addr = SockAddrV4::InaddrAny();
|
||||||
auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
|
auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
|
||||||
xgboost_CHECK_SYS_CALL(
|
if (bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
|
||||||
bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);
|
return system::FailWithCode("bind failed.");
|
||||||
|
}
|
||||||
|
|
||||||
sockaddr_in res_addr;
|
sockaddr_in res_addr;
|
||||||
socklen_t addrlen = sizeof(res_addr);
|
socklen_t addrlen = sizeof(res_addr);
|
||||||
xgboost_CHECK_SYS_CALL(
|
if (getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen) != 0) {
|
||||||
getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen), 0);
|
return system::FailWithCode("getsockname failed.");
|
||||||
return ntohs(res_addr.sin_port);
|
}
|
||||||
|
*p_out = ntohs(res_addr.sin_port);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return Success();
|
||||||
}
|
}
|
||||||
|
|
||||||
[[nodiscard]] auto Port() const {
|
[[nodiscard]] auto Port() const {
|
||||||
@ -641,13 +669,13 @@ class TCPSocket {
|
|||||||
*/
|
*/
|
||||||
std::size_t Send(StringView str);
|
std::size_t Send(StringView str);
|
||||||
/**
|
/**
|
||||||
* \brief Receive string, format is matched with the Python socket wrapper in RABIT.
|
* @brief Receive string, format is matched with the Python socket wrapper in RABIT.
|
||||||
*/
|
*/
|
||||||
std::size_t Recv(std::string *p_str);
|
[[nodiscard]] Result Recv(std::string *p_str);
|
||||||
/**
|
/**
|
||||||
* @brief Close the socket, called automatically in destructor if the socket is not closed.
|
* @brief Close the socket, called automatically in destructor if the socket is not closed.
|
||||||
*/
|
*/
|
||||||
Result Close() {
|
[[nodiscard]] Result Close() {
|
||||||
if (InvalidSocket() != handle_) {
|
if (InvalidSocket() != handle_) {
|
||||||
auto rc = system::CloseSocket(handle_);
|
auto rc = system::CloseSocket(handle_);
|
||||||
#if defined(_WIN32)
|
#if defined(_WIN32)
|
||||||
@ -664,6 +692,25 @@ class TCPSocket {
|
|||||||
}
|
}
|
||||||
return Success();
|
return Success();
|
||||||
}
|
}
|
||||||
|
/**
|
||||||
|
* @brief Call shutdown on the socket.
|
||||||
|
*/
|
||||||
|
[[nodiscard]] Result Shutdown() {
|
||||||
|
if (this->IsClosed()) {
|
||||||
|
return Success();
|
||||||
|
}
|
||||||
|
auto rc = system::ShutdownSocket(this->Handle());
|
||||||
|
#if defined(_WIN32)
|
||||||
|
// Windows cannot shut down a socket if it's not connected.
|
||||||
|
if (rc == -1 && system::LastError() == WSAENOTCONN) {
|
||||||
|
return Success();
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
if (rc != 0) {
|
||||||
|
return system::FailWithCode("Failed to shutdown socket.");
|
||||||
|
}
|
||||||
|
return Success();
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Create a TCP socket on specified domain.
|
* \brief Create a TCP socket on specified domain.
|
||||||
|
|||||||
@ -125,14 +125,14 @@ Result FederatedTracker::Shutdown() {
|
|||||||
|
|
||||||
[[nodiscard]] Json FederatedTracker::WorkerArgs() const {
|
[[nodiscard]] Json FederatedTracker::WorkerArgs() const {
|
||||||
auto rc = this->WaitUntilReady();
|
auto rc = this->WaitUntilReady();
|
||||||
CHECK(rc.OK()) << rc.Report();
|
SafeColl(rc);
|
||||||
|
|
||||||
std::string host;
|
std::string host;
|
||||||
rc = GetHostAddress(&host);
|
rc = GetHostAddress(&host);
|
||||||
CHECK(rc.OK());
|
CHECK(rc.OK());
|
||||||
Json args{Object{}};
|
Json args{Object{}};
|
||||||
args["DMLC_TRACKER_URI"] = String{host};
|
args["dmlc_tracker_uri"] = String{host};
|
||||||
args["DMLC_TRACKER_PORT"] = this->Port();
|
args["dmlc_tracker_port"] = this->Port();
|
||||||
return args;
|
return args;
|
||||||
}
|
}
|
||||||
} // namespace xgboost::collective
|
} // namespace xgboost::collective
|
||||||
|
|||||||
@ -100,6 +100,24 @@ std::enable_if_t<std::is_integral_v<E>, xgboost::collective::Result> PollError(E
|
|||||||
if ((revents & POLLNVAL) != 0) {
|
if ((revents & POLLNVAL) != 0) {
|
||||||
return xgboost::system::FailWithCode("Invalid polling request.");
|
return xgboost::system::FailWithCode("Invalid polling request.");
|
||||||
}
|
}
|
||||||
|
if ((revents & POLLHUP) != 0) {
|
||||||
|
// Excerpt from the Linux manual:
|
||||||
|
//
|
||||||
|
// Note that when reading from a channel such as a pipe or a stream socket, this event
|
||||||
|
// merely indicates that the peer closed its end of the channel. Subsequent reads from
|
||||||
|
// the channel will return 0 (end of file) only after all outstanding data in the
|
||||||
|
// channel has been consumed.
|
||||||
|
//
|
||||||
|
// We don't usually have a barrier for exiting workers, it's normal to have one end
|
||||||
|
// exit while the other is still reading data.
|
||||||
|
return xgboost::collective::Success();
|
||||||
|
}
|
||||||
|
#if defined(POLLRDHUP)
|
||||||
|
// Linux only flag
|
||||||
|
if ((revents & POLLRDHUP) != 0) {
|
||||||
|
return xgboost::system::FailWithCode("Poll hung up on the other end.");
|
||||||
|
}
|
||||||
|
#endif // defined(POLLRDHUP)
|
||||||
return xgboost::collective::Success();
|
return xgboost::collective::Success();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -179,9 +197,11 @@ struct PollHelper {
|
|||||||
}
|
}
|
||||||
std::int32_t ret = PollImpl(fdset.data(), fdset.size(), timeout);
|
std::int32_t ret = PollImpl(fdset.data(), fdset.size(), timeout);
|
||||||
if (ret == 0) {
|
if (ret == 0) {
|
||||||
return xgboost::collective::Fail("Poll timeout.", std::make_error_code(std::errc::timed_out));
|
return xgboost::collective::Fail(
|
||||||
|
"Poll timeout:" + std::to_string(timeout.count()) + " seconds.",
|
||||||
|
std::make_error_code(std::errc::timed_out));
|
||||||
} else if (ret < 0) {
|
} else if (ret < 0) {
|
||||||
return xgboost::system::FailWithCode("Poll failed.");
|
return xgboost::system::FailWithCode("Poll failed, nfds:" + std::to_string(fdset.size()));
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto& pfd : fdset) {
|
for (auto& pfd : fdset) {
|
||||||
|
|||||||
@ -132,7 +132,7 @@ bool AllreduceBase::Shutdown() {
|
|||||||
try {
|
try {
|
||||||
for (auto &all_link : all_links) {
|
for (auto &all_link : all_links) {
|
||||||
if (!all_link.sock.IsClosed()) {
|
if (!all_link.sock.IsClosed()) {
|
||||||
all_link.sock.Close();
|
SafeColl(all_link.sock.Close());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
all_links.clear();
|
all_links.clear();
|
||||||
@ -146,7 +146,7 @@ bool AllreduceBase::Shutdown() {
|
|||||||
LOG(FATAL) << rc.Report();
|
LOG(FATAL) << rc.Report();
|
||||||
}
|
}
|
||||||
tracker.Send(xgboost::StringView{"shutdown"});
|
tracker.Send(xgboost::StringView{"shutdown"});
|
||||||
tracker.Close();
|
SafeColl(tracker.Close());
|
||||||
xgboost::system::SocketFinalize();
|
xgboost::system::SocketFinalize();
|
||||||
return true;
|
return true;
|
||||||
} catch (std::exception const &e) {
|
} catch (std::exception const &e) {
|
||||||
@ -167,7 +167,7 @@ void AllreduceBase::TrackerPrint(const std::string &msg) {
|
|||||||
|
|
||||||
tracker.Send(xgboost::StringView{"print"});
|
tracker.Send(xgboost::StringView{"print"});
|
||||||
tracker.Send(xgboost::StringView{msg});
|
tracker.Send(xgboost::StringView{msg});
|
||||||
tracker.Close();
|
SafeColl(tracker.Close());
|
||||||
}
|
}
|
||||||
|
|
||||||
// util to parse data with unit suffix
|
// util to parse data with unit suffix
|
||||||
@ -332,15 +332,15 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
|||||||
|
|
||||||
auto sock_listen{xgboost::collective::TCPSocket::Create(tracker.Domain())};
|
auto sock_listen{xgboost::collective::TCPSocket::Create(tracker.Domain())};
|
||||||
// create listening socket
|
// create listening socket
|
||||||
int port = sock_listen.BindHost();
|
std::int32_t port{0};
|
||||||
utils::Check(port != -1, "ReConnectLink fail to bind the ports specified");
|
SafeColl(sock_listen.BindHost(&port));
|
||||||
sock_listen.Listen();
|
SafeColl(sock_listen.Listen());
|
||||||
|
|
||||||
// get number of to connect and number of to accept nodes from tracker
|
// get number of to connect and number of to accept nodes from tracker
|
||||||
int num_conn, num_accept, num_error = 1;
|
int num_conn, num_accept, num_error = 1;
|
||||||
do {
|
do {
|
||||||
for (auto & all_link : all_links) {
|
for (auto & all_link : all_links) {
|
||||||
all_link.sock.Close();
|
SafeColl(all_link.sock.Close());
|
||||||
}
|
}
|
||||||
// tracker construct goodset
|
// tracker construct goodset
|
||||||
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
|
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
|
||||||
@ -352,7 +352,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
|||||||
LinkRecord r;
|
LinkRecord r;
|
||||||
int hport, hrank;
|
int hport, hrank;
|
||||||
std::string hname;
|
std::string hname;
|
||||||
tracker.Recv(&hname);
|
SafeColl(tracker.Recv(&hname));
|
||||||
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9");
|
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9");
|
||||||
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10");
|
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10");
|
||||||
// connect to peer
|
// connect to peer
|
||||||
@ -360,7 +360,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
|||||||
timeout_sec, &r.sock)
|
timeout_sec, &r.sock)
|
||||||
.OK()) {
|
.OK()) {
|
||||||
num_error += 1;
|
num_error += 1;
|
||||||
r.sock.Close();
|
SafeColl(r.sock.Close());
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
|
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
|
||||||
@ -386,7 +386,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
|||||||
// send back socket listening port to tracker
|
// send back socket listening port to tracker
|
||||||
Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port), "ReConnectLink failure 14");
|
Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port), "ReConnectLink failure 14");
|
||||||
// close connection to tracker
|
// close connection to tracker
|
||||||
tracker.Close();
|
SafeColl(tracker.Close());
|
||||||
|
|
||||||
// listen to incoming links
|
// listen to incoming links
|
||||||
for (int i = 0; i < num_accept; ++i) {
|
for (int i = 0; i < num_accept; ++i) {
|
||||||
@ -408,7 +408,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
|||||||
}
|
}
|
||||||
if (!match) all_links.emplace_back(std::move(r));
|
if (!match) all_links.emplace_back(std::move(r));
|
||||||
}
|
}
|
||||||
sock_listen.Close();
|
SafeColl(sock_listen.Close());
|
||||||
|
|
||||||
this->parent_index = -1;
|
this->parent_index = -1;
|
||||||
// setup tree links and ring structure
|
// setup tree links and ring structure
|
||||||
@ -635,7 +635,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
|
|||||||
Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
|
Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
|
||||||
|
|
||||||
if (len == 0) {
|
if (len == 0) {
|
||||||
links[parent_index].sock.Close();
|
SafeColl(links[parent_index].sock.Close());
|
||||||
return ReportError(&links[parent_index], kRecvZeroLen);
|
return ReportError(&links[parent_index], kRecvZeroLen);
|
||||||
}
|
}
|
||||||
if (len != -1) {
|
if (len != -1) {
|
||||||
|
|||||||
@ -270,7 +270,7 @@ class AllreduceBase : public IEngine {
|
|||||||
ssize_t len = sock.Recv(buffer_head + offset, nmax);
|
ssize_t len = sock.Recv(buffer_head + offset, nmax);
|
||||||
// length equals 0, remote disconnected
|
// length equals 0, remote disconnected
|
||||||
if (len == 0) {
|
if (len == 0) {
|
||||||
sock.Close(); return kRecvZeroLen;
|
SafeColl(sock.Close()); return kRecvZeroLen;
|
||||||
}
|
}
|
||||||
if (len == -1) return Errno2Return();
|
if (len == -1) return Errno2Return();
|
||||||
size_read += static_cast<size_t>(len);
|
size_read += static_cast<size_t>(len);
|
||||||
@ -289,7 +289,7 @@ class AllreduceBase : public IEngine {
|
|||||||
ssize_t len = sock.Recv(p + size_read, max_size - size_read);
|
ssize_t len = sock.Recv(p + size_read, max_size - size_read);
|
||||||
// length equals 0, remote disconnected
|
// length equals 0, remote disconnected
|
||||||
if (len == 0) {
|
if (len == 0) {
|
||||||
sock.Close(); return kRecvZeroLen;
|
SafeColl(sock.Close()); return kRecvZeroLen;
|
||||||
}
|
}
|
||||||
if (len == -1) return Errno2Return();
|
if (len == -1) return Errno2Return();
|
||||||
size_read += static_cast<size_t>(len);
|
size_read += static_cast<size_t>(len);
|
||||||
|
|||||||
@ -5,9 +5,11 @@
|
|||||||
#include <future> // for future
|
#include <future> // for future
|
||||||
#include <memory> // for unique_ptr
|
#include <memory> // for unique_ptr
|
||||||
#include <string> // for string
|
#include <string> // for string
|
||||||
|
#include <thread> // for sleep_for
|
||||||
#include <type_traits> // for is_same_v, remove_pointer_t
|
#include <type_traits> // for is_same_v, remove_pointer_t
|
||||||
#include <utility> // for pair
|
#include <utility> // for pair
|
||||||
|
|
||||||
|
#include "../collective/comm.h" // for DefaultTimeoutSec
|
||||||
#include "../collective/tracker.h" // for RabitTracker
|
#include "../collective/tracker.h" // for RabitTracker
|
||||||
#include "../common/timer.h" // for Timer
|
#include "../common/timer.h" // for Timer
|
||||||
#include "c_api_error.h" // for API_BEGIN
|
#include "c_api_error.h" // for API_BEGIN
|
||||||
@ -26,7 +28,7 @@ using namespace xgboost; // NOLINT
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
using TrackerHandleT =
|
using TrackerHandleT =
|
||||||
std::pair<std::unique_ptr<collective::Tracker>, std::shared_future<collective::Result>>;
|
std::pair<std::shared_ptr<collective::Tracker>, std::shared_future<collective::Result>>;
|
||||||
|
|
||||||
TrackerHandleT *GetTrackerHandle(TrackerHandle handle) {
|
TrackerHandleT *GetTrackerHandle(TrackerHandle handle) {
|
||||||
xgboost_CHECK_C_ARG_PTR(handle);
|
xgboost_CHECK_C_ARG_PTR(handle);
|
||||||
@ -41,12 +43,14 @@ struct CollAPIEntry {
|
|||||||
using CollAPIThreadLocalStore = dmlc::ThreadLocalStore<CollAPIEntry>;
|
using CollAPIThreadLocalStore = dmlc::ThreadLocalStore<CollAPIEntry>;
|
||||||
|
|
||||||
void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) {
|
void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) {
|
||||||
constexpr std::int64_t kDft{60};
|
constexpr std::int64_t kDft{collective::DefaultTimeoutSec()};
|
||||||
std::chrono::seconds wait_for{timeout.count() != 0 ? std::min(kDft, timeout.count()) : kDft};
|
std::chrono::seconds wait_for{timeout.count() != 0 ? std::min(kDft, timeout.count()) : kDft};
|
||||||
|
|
||||||
common::Timer timer;
|
common::Timer timer;
|
||||||
timer.Start();
|
timer.Start();
|
||||||
|
|
||||||
|
auto ref = ptr->first; // hold a reference so that XGTrackerFree() doesn't delete the tracker while we are waiting.
|
||||||
|
|
||||||
auto fut = ptr->second;
|
auto fut = ptr->second;
|
||||||
while (fut.valid()) {
|
while (fut.valid()) {
|
||||||
auto res = fut.wait_for(wait_for);
|
auto res = fut.wait_for(wait_for);
|
||||||
@ -72,15 +76,15 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle) {
|
|||||||
Json jconfig = Json::Load(config);
|
Json jconfig = Json::Load(config);
|
||||||
|
|
||||||
auto type = RequiredArg<String>(jconfig, "dmlc_communicator", __func__);
|
auto type = RequiredArg<String>(jconfig, "dmlc_communicator", __func__);
|
||||||
std::unique_ptr<collective::Tracker> tptr;
|
std::shared_ptr<collective::Tracker> tptr;
|
||||||
if (type == "federated") {
|
if (type == "federated") {
|
||||||
#if defined(XGBOOST_USE_FEDERATED)
|
#if defined(XGBOOST_USE_FEDERATED)
|
||||||
tptr = std::make_unique<collective::FederatedTracker>(jconfig);
|
tptr = std::make_shared<collective::FederatedTracker>(jconfig);
|
||||||
#else
|
#else
|
||||||
LOG(FATAL) << error::NoFederated();
|
LOG(FATAL) << error::NoFederated();
|
||||||
#endif // defined(XGBOOST_USE_FEDERATED)
|
#endif // defined(XGBOOST_USE_FEDERATED)
|
||||||
} else if (type == "rabit") {
|
} else if (type == "rabit") {
|
||||||
tptr = std::make_unique<collective::RabitTracker>(jconfig);
|
tptr = std::make_shared<collective::RabitTracker>(jconfig);
|
||||||
} else {
|
} else {
|
||||||
LOG(FATAL) << "Unknown communicator:" << type;
|
LOG(FATAL) << "Unknown communicator:" << type;
|
||||||
}
|
}
|
||||||
@ -103,7 +107,7 @@ XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args) {
|
|||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
XGB_DLL int XGTrackerRun(TrackerHandle handle) {
|
XGB_DLL int XGTrackerRun(TrackerHandle handle, char const *) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
auto *ptr = GetTrackerHandle(handle);
|
auto *ptr = GetTrackerHandle(handle);
|
||||||
CHECK(!ptr->second.valid()) << "Tracker is already running.";
|
CHECK(!ptr->second.valid()) << "Tracker is already running.";
|
||||||
@ -111,13 +115,14 @@ XGB_DLL int XGTrackerRun(TrackerHandle handle) {
|
|||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config) {
|
XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
auto *ptr = GetTrackerHandle(handle);
|
auto *ptr = GetTrackerHandle(handle);
|
||||||
xgboost_CHECK_C_ARG_PTR(config);
|
xgboost_CHECK_C_ARG_PTR(config);
|
||||||
auto jconfig = Json::Load(StringView{config});
|
auto jconfig = Json::Load(StringView{config});
|
||||||
// Internally, 0 indicates no timeout, which is the default since we don't want to
|
// Internally, 0 indicates no timeout, which is the default since we don't want to
|
||||||
// interrupt the model training.
|
// interrupt the model training.
|
||||||
|
xgboost_CHECK_C_ARG_PTR(config);
|
||||||
auto timeout = OptionalArg<Integer>(jconfig, "timeout", std::int64_t{0});
|
auto timeout = OptionalArg<Integer>(jconfig, "timeout", std::int64_t{0});
|
||||||
WaitImpl(ptr, std::chrono::seconds{timeout});
|
WaitImpl(ptr, std::chrono::seconds{timeout});
|
||||||
API_END();
|
API_END();
|
||||||
@ -125,8 +130,24 @@ XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config) {
|
|||||||
|
|
||||||
XGB_DLL int XGTrackerFree(TrackerHandle handle) {
|
XGB_DLL int XGTrackerFree(TrackerHandle handle) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
using namespace std::chrono_literals; // NOLINT
|
||||||
auto *ptr = GetTrackerHandle(handle);
|
auto *ptr = GetTrackerHandle(handle);
|
||||||
|
ptr->first->Stop();
|
||||||
|
// The wait is not necessary since we just called stop, just reusing the function to do
|
||||||
|
// any potential cleanups.
|
||||||
WaitImpl(ptr, ptr->first->Timeout());
|
WaitImpl(ptr, ptr->first->Timeout());
|
||||||
|
common::Timer timer;
|
||||||
|
timer.Start();
|
||||||
|
// Make sure no one else is waiting on the tracker.
|
||||||
|
while (!ptr->first.unique()) {
|
||||||
|
auto ela = timer.Duration().count();
|
||||||
|
if (ela > ptr->first->Timeout().count()) {
|
||||||
|
LOG(WARNING) << "Time out " << ptr->first->Timeout().count()
|
||||||
|
<< " seconds reached for TrackerFree, killing the tracker.";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
std::this_thread::sleep_for(64ms);
|
||||||
|
}
|
||||||
delete ptr;
|
delete ptr;
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|||||||
@ -38,6 +38,10 @@ bool constexpr IsFloatingPointV() {
|
|||||||
auto redop_fn = [](auto lhs, auto out, auto elem_op) {
|
auto redop_fn = [](auto lhs, auto out, auto elem_op) {
|
||||||
auto p_lhs = lhs.data();
|
auto p_lhs = lhs.data();
|
||||||
auto p_out = out.data();
|
auto p_out = out.data();
|
||||||
|
#if defined(__GNUC__) || defined(__clang__)
|
||||||
|
// For the sum op, one can verify the simd by: addps %xmm15, %xmm14
|
||||||
|
#pragma omp simd
|
||||||
|
#endif
|
||||||
for (std::size_t i = 0; i < lhs.size(); ++i) {
|
for (std::size_t i = 0; i < lhs.size(); ++i) {
|
||||||
p_out[i] = elem_op(p_lhs[i], p_out[i]);
|
p_out[i] = elem_op(p_lhs[i], p_out[i]);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -5,9 +5,11 @@
|
|||||||
|
|
||||||
#include <algorithm> // for copy
|
#include <algorithm> // for copy
|
||||||
#include <chrono> // for seconds
|
#include <chrono> // for seconds
|
||||||
|
#include <cstdint> // for int32_t
|
||||||
#include <cstdlib> // for exit
|
#include <cstdlib> // for exit
|
||||||
#include <memory> // for shared_ptr
|
#include <memory> // for shared_ptr
|
||||||
#include <string> // for string
|
#include <string> // for string
|
||||||
|
#include <thread> // for thread
|
||||||
#include <utility> // for move, forward
|
#include <utility> // for move, forward
|
||||||
#if !defined(XGBOOST_USE_NCCL)
|
#if !defined(XGBOOST_USE_NCCL)
|
||||||
#include "../common/common.h" // for AssertNCCLSupport
|
#include "../common/common.h" // for AssertNCCLSupport
|
||||||
@ -184,13 +186,30 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
|
|||||||
return Success();
|
return Success();
|
||||||
}
|
}
|
||||||
|
|
||||||
RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
|
namespace {
|
||||||
std::int32_t retry, std::string task_id, StringView nccl_path)
|
std::string InitLog(std::string task_id, std::int32_t rank) {
|
||||||
: HostComm{std::move(host), port, timeout, retry, std::move(task_id)},
|
if (task_id.empty()) {
|
||||||
|
return "Rank " + std::to_string(rank);
|
||||||
|
}
|
||||||
|
return "Task " + task_id + " got rank " + std::to_string(rank);
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
RabitComm::RabitComm(std::string const& tracker_host, std::int32_t tracker_port,
|
||||||
|
std::chrono::seconds timeout, std::int32_t retry, std::string task_id,
|
||||||
|
StringView nccl_path)
|
||||||
|
: HostComm{tracker_host, tracker_port, timeout, retry, std::move(task_id)},
|
||||||
nccl_path_{std::move(nccl_path)} {
|
nccl_path_{std::move(nccl_path)} {
|
||||||
|
if (this->TrackerInfo().host.empty()) {
|
||||||
|
// Not in a distributed environment.
|
||||||
|
LOG(CONSOLE) << InitLog(task_id_, rank_);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
loop_.reset(new Loop{std::chrono::seconds{timeout_}}); // NOLINT
|
loop_.reset(new Loop{std::chrono::seconds{timeout_}}); // NOLINT
|
||||||
auto rc = this->Bootstrap(timeout_, retry_, task_id_);
|
auto rc = this->Bootstrap(timeout_, retry_, task_id_);
|
||||||
if (!rc.OK()) {
|
if (!rc.OK()) {
|
||||||
|
this->ResetState();
|
||||||
SafeColl(Fail("Failed to bootstrap the communication group.", std::move(rc)));
|
SafeColl(Fail("Failed to bootstrap the communication group.", std::move(rc)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -217,20 +236,54 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
|
|||||||
|
|
||||||
// Start command
|
// Start command
|
||||||
TCPSocket listener = TCPSocket::Create(tracker.Domain());
|
TCPSocket listener = TCPSocket::Create(tracker.Domain());
|
||||||
std::int32_t lport = listener.BindHost();
|
std::int32_t lport{0};
|
||||||
listener.Listen();
|
rc = std::move(rc) << [&] {
|
||||||
|
return listener.BindHost(&lport);
|
||||||
|
} << [&] {
|
||||||
|
return listener.Listen();
|
||||||
|
};
|
||||||
|
if (!rc.OK()) {
|
||||||
|
return rc;
|
||||||
|
}
|
||||||
|
|
||||||
// create worker for listening to error notice.
|
// create worker for listening to error notice.
|
||||||
auto domain = tracker.Domain();
|
auto domain = tracker.Domain();
|
||||||
std::shared_ptr<TCPSocket> error_sock{TCPSocket::CreatePtr(domain)};
|
std::shared_ptr<TCPSocket> error_sock{TCPSocket::CreatePtr(domain)};
|
||||||
auto eport = error_sock->BindHost();
|
std::int32_t eport{0};
|
||||||
error_sock->Listen();
|
rc = std::move(rc) << [&] {
|
||||||
|
return error_sock->BindHost(&eport);
|
||||||
|
} << [&] {
|
||||||
|
return error_sock->Listen();
|
||||||
|
};
|
||||||
|
if (!rc.OK()) {
|
||||||
|
return rc;
|
||||||
|
}
|
||||||
|
error_port_ = eport;
|
||||||
|
|
||||||
error_worker_ = std::thread{[error_sock = std::move(error_sock)] {
|
error_worker_ = std::thread{[error_sock = std::move(error_sock)] {
|
||||||
auto conn = error_sock->Accept();
|
TCPSocket conn;
|
||||||
|
SockAddress addr;
|
||||||
|
auto rc = error_sock->Accept(&conn, &addr);
|
||||||
|
// On Linux, a shutdown causes an invalid argument error.
|
||||||
|
if (rc.Code() == std::errc::invalid_argument) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
// On Windows, accept returns a closed socket after finalize.
|
// On Windows, accept returns a closed socket after finalize.
|
||||||
if (conn.IsClosed()) {
|
if (conn.IsClosed()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
// The error signal is from the tracker, while the shutdown signal is from the shutdown method
|
||||||
|
// of the RabitComm class (this).
|
||||||
|
bool is_error{false};
|
||||||
|
rc = proto::Error{}.RecvSignal(&conn, &is_error);
|
||||||
|
if (!rc.OK()) {
|
||||||
|
LOG(WARNING) << rc.Report();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!is_error) {
|
||||||
|
return; // shutdown
|
||||||
|
}
|
||||||
|
|
||||||
LOG(WARNING) << "Another worker is running into error.";
|
LOG(WARNING) << "Another worker is running into error.";
|
||||||
#if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0
|
#if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0
|
||||||
// exit is nicer than abort as the former performs cleanups.
|
// exit is nicer than abort as the former performs cleanups.
|
||||||
@ -239,6 +292,9 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
|
|||||||
LOG(FATAL) << "abort";
|
LOG(FATAL) << "abort";
|
||||||
#endif
|
#endif
|
||||||
}};
|
}};
|
||||||
|
// The worker thread is detached here to avoid the need to handle it later during
|
||||||
|
// destruction. In C++, destroying a std::thread that is still joinable (neither joined
|
||||||
|
// nor detached) calls std::terminate.
|
||||||
error_worker_.detach();
|
error_worker_.detach();
|
||||||
|
|
||||||
proto::Start start;
|
proto::Start start;
|
||||||
@ -251,7 +307,10 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
|
|||||||
|
|
||||||
// get ring neighbors
|
// get ring neighbors
|
||||||
std::string snext;
|
std::string snext;
|
||||||
tracker.Recv(&snext);
|
rc = tracker.Recv(&snext);
|
||||||
|
if (!rc.OK()) {
|
||||||
|
return Fail("Failed to receive the rank for the next worker.", std::move(rc));
|
||||||
|
}
|
||||||
auto jnext = Json::Load(StringView{snext});
|
auto jnext = Json::Load(StringView{snext});
|
||||||
|
|
||||||
proto::PeerInfo ninfo{jnext};
|
proto::PeerInfo ninfo{jnext};
|
||||||
@ -268,14 +327,21 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
|
|||||||
CHECK(this->channels_.empty());
|
CHECK(this->channels_.empty());
|
||||||
for (auto& w : workers) {
|
for (auto& w : workers) {
|
||||||
if (w) {
|
if (w) {
|
||||||
rc = std::move(rc) << [&] { return w->SetNoDelay(); } << [&] { return w->NonBlocking(true); }
|
rc = std::move(rc) << [&] {
|
||||||
<< [&] { return w->SetKeepAlive(); };
|
return w->SetNoDelay();
|
||||||
|
} << [&] {
|
||||||
|
return w->NonBlocking(true);
|
||||||
|
} << [&] {
|
||||||
|
return w->SetKeepAlive();
|
||||||
|
};
|
||||||
}
|
}
|
||||||
if (!rc.OK()) {
|
if (!rc.OK()) {
|
||||||
return rc;
|
return rc;
|
||||||
}
|
}
|
||||||
this->channels_.emplace_back(std::make_shared<Channel>(*this, w));
|
this->channels_.emplace_back(std::make_shared<Channel>(*this, w));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LOG(CONSOLE) << InitLog(task_id_, rank_);
|
||||||
return rc;
|
return rc;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -283,6 +349,8 @@ RabitComm::~RabitComm() noexcept(false) {
|
|||||||
if (!this->IsDistributed()) {
|
if (!this->IsDistributed()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
LOG(WARNING) << "The communicator is being destroyed without a call to shutdown first. This can "
|
||||||
|
"lead to undefined behaviour.";
|
||||||
auto rc = this->Shutdown();
|
auto rc = this->Shutdown();
|
||||||
if (!rc.OK()) {
|
if (!rc.OK()) {
|
||||||
LOG(WARNING) << rc.Report();
|
LOG(WARNING) << rc.Report();
|
||||||
@ -293,30 +361,49 @@ RabitComm::~RabitComm() noexcept(false) {
|
|||||||
if (!this->IsDistributed()) {
|
if (!this->IsDistributed()) {
|
||||||
return Success();
|
return Success();
|
||||||
}
|
}
|
||||||
|
// Tell the tracker that this worker is shutting down.
|
||||||
TCPSocket tracker;
|
TCPSocket tracker;
|
||||||
|
// Tell the error handling thread that we are shutting down.
|
||||||
|
TCPSocket err_client;
|
||||||
|
|
||||||
return Success() << [&] {
|
return Success() << [&] {
|
||||||
return ConnectTrackerImpl(tracker_, timeout_, retry_, task_id_, &tracker, Rank(), World());
|
return ConnectTrackerImpl(tracker_, timeout_, retry_, task_id_, &tracker, Rank(), World());
|
||||||
} << [&] {
|
} << [&] {
|
||||||
return this->Block();
|
return this->Block();
|
||||||
} << [&] {
|
} << [&] {
|
||||||
Json jcmd{Object{}};
|
return proto::ShutdownCMD{}.Send(&tracker);
|
||||||
jcmd["cmd"] = Integer{static_cast<std::int32_t>(proto::CMD::kShutdown)};
|
|
||||||
auto scmd = Json::Dump(jcmd);
|
|
||||||
auto n_bytes = tracker.Send(scmd);
|
|
||||||
if (n_bytes != scmd.size()) {
|
|
||||||
return Fail("Faled to send cmd.");
|
|
||||||
}
|
|
||||||
|
|
||||||
this->ResetState();
|
|
||||||
return Success();
|
|
||||||
} << [&] {
|
} << [&] {
|
||||||
this->channels_.clear();
|
this->channels_.clear();
|
||||||
return Success();
|
return Success();
|
||||||
|
} << [&] {
|
||||||
|
// Use tracker address to determine whether we want to use IPv6.
|
||||||
|
auto taddr = MakeSockAddress(xgboost::StringView{this->tracker_.host}, this->tracker_.port);
|
||||||
|
// Shutdown the error handling thread. We signal the thread through socket,
|
||||||
|
// alternatively, we can get the native handle and use pthread_cancel. But using a
|
||||||
|
// socket seems to be clearer as we know what's happening.
|
||||||
|
auto const& addr = taddr.IsV4() ? SockAddrV4::Loopback().Addr() : SockAddrV6::Loopback().Addr();
|
||||||
|
// We use hardcoded 10 seconds and 1 retry here since we are just connecting to a
|
||||||
|
// local socket. For a normal OS, this should be enough time to schedule the
|
||||||
|
// connection.
|
||||||
|
auto rc = Connect(StringView{addr}, this->error_port_, 1,
|
||||||
|
std::min(std::chrono::seconds{10}, timeout_), &err_client);
|
||||||
|
this->ResetState();
|
||||||
|
if (!rc.OK()) {
|
||||||
|
return Fail("Failed to connect to the error socket.", std::move(rc));
|
||||||
|
}
|
||||||
|
return rc;
|
||||||
|
} << [&] {
|
||||||
|
// We put error thread shutdown at the end so that we have a better chance to finish
|
||||||
|
// the previous more important steps.
|
||||||
|
return proto::Error{}.SignalShutdown(&err_client);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
[[nodiscard]] Result RabitComm::LogTracker(std::string msg) const {
|
[[nodiscard]] Result RabitComm::LogTracker(std::string msg) const {
|
||||||
|
if (!this->IsDistributed()) {
|
||||||
|
LOG(CONSOLE) << msg;
|
||||||
|
return Success();
|
||||||
|
}
|
||||||
TCPSocket out;
|
TCPSocket out;
|
||||||
proto::Print print;
|
proto::Print print;
|
||||||
return Success() << [&] { return this->ConnectTracker(&out); }
|
return Success() << [&] { return this->ConnectTracker(&out); }
|
||||||
@ -324,8 +411,11 @@ RabitComm::~RabitComm() noexcept(false) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
[[nodiscard]] Result RabitComm::SignalError(Result const& res) {
|
[[nodiscard]] Result RabitComm::SignalError(Result const& res) {
|
||||||
TCPSocket out;
|
TCPSocket tracker;
|
||||||
return Success() << [&] { return this->ConnectTracker(&out); }
|
return Success() << [&] {
|
||||||
<< [&] { return proto::ErrorCMD{}.WorkerSend(&out, res); };
|
return this->ConnectTracker(&tracker);
|
||||||
|
} << [&] {
|
||||||
|
return proto::ErrorCMD{}.WorkerSend(&tracker, res);
|
||||||
|
};
|
||||||
}
|
}
|
||||||
} // namespace xgboost::collective
|
} // namespace xgboost::collective
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2023, XGBoost Contributors
|
* Copyright 2023-2024, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#pragma once
|
#pragma once
|
||||||
#include <chrono> // for seconds
|
#include <chrono> // for seconds
|
||||||
#include <cstddef> // for size_t
|
#include <cstddef> // for size_t
|
||||||
#include <cstdint> // for int32_t
|
#include <cstdint> // for int32_t, int64_t
|
||||||
#include <memory> // for shared_ptr
|
#include <memory> // for shared_ptr
|
||||||
#include <string> // for string
|
#include <string> // for string
|
||||||
#include <thread> // for thread
|
#include <thread> // for thread
|
||||||
@ -20,7 +20,7 @@
|
|||||||
|
|
||||||
namespace xgboost::collective {
|
namespace xgboost::collective {
|
||||||
|
|
||||||
inline constexpr std::int32_t DefaultTimeoutSec() { return 300; } // 5min
|
inline constexpr std::int64_t DefaultTimeoutSec() { return 300; } // 5min
|
||||||
inline constexpr std::int32_t DefaultRetry() { return 3; }
|
inline constexpr std::int32_t DefaultRetry() { return 3; }
|
||||||
|
|
||||||
// indexing into the ring
|
// indexing into the ring
|
||||||
@ -51,7 +51,10 @@ class Comm : public std::enable_shared_from_this<Comm> {
|
|||||||
|
|
||||||
proto::PeerInfo tracker_;
|
proto::PeerInfo tracker_;
|
||||||
SockDomain domain_{SockDomain::kV4};
|
SockDomain domain_{SockDomain::kV4};
|
||||||
|
|
||||||
std::thread error_worker_;
|
std::thread error_worker_;
|
||||||
|
std::int32_t error_port_;
|
||||||
|
|
||||||
std::string task_id_;
|
std::string task_id_;
|
||||||
std::vector<std::shared_ptr<Channel>> channels_;
|
std::vector<std::shared_ptr<Channel>> channels_;
|
||||||
std::shared_ptr<Loop> loop_{nullptr}; // fixme: require federated comm to have a timeout
|
std::shared_ptr<Loop> loop_{nullptr}; // fixme: require federated comm to have a timeout
|
||||||
@ -59,6 +62,13 @@ class Comm : public std::enable_shared_from_this<Comm> {
|
|||||||
void ResetState() {
|
void ResetState() {
|
||||||
this->world_ = -1;
|
this->world_ = -1;
|
||||||
this->rank_ = 0;
|
this->rank_ = 0;
|
||||||
|
this->timeout_ = std::chrono::seconds{DefaultTimeoutSec()};
|
||||||
|
|
||||||
|
tracker_ = proto::PeerInfo{};
|
||||||
|
this->task_id_.clear();
|
||||||
|
channels_.clear();
|
||||||
|
|
||||||
|
loop_.reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
@ -79,9 +89,9 @@ class Comm : public std::enable_shared_from_this<Comm> {
|
|||||||
[[nodiscard]] auto Retry() const { return retry_; }
|
[[nodiscard]] auto Retry() const { return retry_; }
|
||||||
[[nodiscard]] auto TaskID() const { return task_id_; }
|
[[nodiscard]] auto TaskID() const { return task_id_; }
|
||||||
|
|
||||||
[[nodiscard]] auto Rank() const { return rank_; }
|
[[nodiscard]] auto Rank() const noexcept { return rank_; }
|
||||||
[[nodiscard]] auto World() const { return IsDistributed() ? world_ : 1; }
|
[[nodiscard]] auto World() const noexcept { return IsDistributed() ? world_ : 1; }
|
||||||
[[nodiscard]] bool IsDistributed() const { return world_ != -1; }
|
[[nodiscard]] bool IsDistributed() const noexcept { return world_ != -1; }
|
||||||
void Submit(Loop::Op op) const {
|
void Submit(Loop::Op op) const {
|
||||||
CHECK(loop_);
|
CHECK(loop_);
|
||||||
loop_->Submit(op);
|
loop_->Submit(op);
|
||||||
@ -120,20 +130,20 @@ class RabitComm : public HostComm {
|
|||||||
|
|
||||||
[[nodiscard]] Result Bootstrap(std::chrono::seconds timeout, std::int32_t retry,
|
[[nodiscard]] Result Bootstrap(std::chrono::seconds timeout, std::int32_t retry,
|
||||||
std::string task_id);
|
std::string task_id);
|
||||||
[[nodiscard]] Result Shutdown() final;
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
// bootstrapping construction.
|
// bootstrapping construction.
|
||||||
RabitComm() = default;
|
RabitComm() = default;
|
||||||
// ctor for testing where environment is known.
|
RabitComm(std::string const& tracker_host, std::int32_t tracker_port,
|
||||||
RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
|
std::chrono::seconds timeout, std::int32_t retry, std::string task_id,
|
||||||
std::int32_t retry, std::string task_id, StringView nccl_path);
|
StringView nccl_path);
|
||||||
~RabitComm() noexcept(false) override;
|
~RabitComm() noexcept(false) override;
|
||||||
|
|
||||||
[[nodiscard]] bool IsFederated() const override { return false; }
|
[[nodiscard]] bool IsFederated() const override { return false; }
|
||||||
[[nodiscard]] Result LogTracker(std::string msg) const override;
|
[[nodiscard]] Result LogTracker(std::string msg) const override;
|
||||||
|
|
||||||
[[nodiscard]] Result SignalError(Result const&) override;
|
[[nodiscard]] Result SignalError(Result const&) override;
|
||||||
|
[[nodiscard]] Result Shutdown() final;
|
||||||
|
|
||||||
[[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
|
[[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
|
||||||
};
|
};
|
||||||
|
|||||||
@ -64,6 +64,9 @@ CommGroup::CommGroup()
|
|||||||
|
|
||||||
auto const& obj = get<Object const>(config);
|
auto const& obj = get<Object const>(config);
|
||||||
auto it = obj.find(upper);
|
auto it = obj.find(upper);
|
||||||
|
if (it != obj.cend() && obj.find(name) != obj.cend()) {
|
||||||
|
LOG(FATAL) << "Duplicated parameter:" << name;
|
||||||
|
}
|
||||||
if (it != obj.cend()) {
|
if (it != obj.cend()) {
|
||||||
return OptionalArg<decltype(t)>(config, upper, dft);
|
return OptionalArg<decltype(t)>(config, upper, dft);
|
||||||
} else {
|
} else {
|
||||||
@ -77,14 +80,14 @@ CommGroup::CommGroup()
|
|||||||
auto task_id = get_param("dmlc_task_id", std::string{}, String{});
|
auto task_id = get_param("dmlc_task_id", std::string{}, String{});
|
||||||
|
|
||||||
if (type == "rabit") {
|
if (type == "rabit") {
|
||||||
auto host = get_param("dmlc_tracker_uri", std::string{}, String{});
|
auto tracker_host = get_param("dmlc_tracker_uri", std::string{}, String{});
|
||||||
auto port = get_param("dmlc_tracker_port", static_cast<std::int64_t>(0), Integer{});
|
auto tracker_port = get_param("dmlc_tracker_port", static_cast<std::int64_t>(0), Integer{});
|
||||||
auto nccl = get_param("dmlc_nccl_path", std::string{DefaultNcclName()}, String{});
|
auto nccl = get_param("dmlc_nccl_path", std::string{DefaultNcclName()}, String{});
|
||||||
auto ptr =
|
auto ptr = new CommGroup{
|
||||||
new CommGroup{std::shared_ptr<RabitComm>{new RabitComm{ // NOLINT
|
std::shared_ptr<RabitComm>{new RabitComm{ // NOLINT
|
||||||
host, static_cast<std::int32_t>(port), std::chrono::seconds{timeout},
|
tracker_host, static_cast<std::int32_t>(tracker_port), std::chrono::seconds{timeout},
|
||||||
static_cast<std::int32_t>(retry), task_id, nccl}},
|
static_cast<std::int32_t>(retry), task_id, nccl}},
|
||||||
std::shared_ptr<Coll>(new Coll{})}; // NOLINT
|
std::shared_ptr<Coll>(new Coll{})}; // NOLINT
|
||||||
return ptr;
|
return ptr;
|
||||||
} else if (type == "federated") {
|
} else if (type == "federated") {
|
||||||
#if defined(XGBOOST_USE_FEDERATED)
|
#if defined(XGBOOST_USE_FEDERATED)
|
||||||
|
|||||||
@ -30,9 +30,9 @@ class CommGroup {
|
|||||||
public:
|
public:
|
||||||
CommGroup();
|
CommGroup();
|
||||||
|
|
||||||
[[nodiscard]] auto World() const { return comm_->World(); }
|
[[nodiscard]] auto World() const noexcept { return comm_->World(); }
|
||||||
[[nodiscard]] auto Rank() const { return comm_->Rank(); }
|
[[nodiscard]] auto Rank() const noexcept { return comm_->Rank(); }
|
||||||
[[nodiscard]] bool IsDistributed() const { return comm_->IsDistributed(); }
|
[[nodiscard]] bool IsDistributed() const noexcept { return comm_->IsDistributed(); }
|
||||||
|
|
||||||
[[nodiscard]] Result Finalize() const {
|
[[nodiscard]] Result Finalize() const {
|
||||||
return Success() << [this] {
|
return Success() << [this] {
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2023, XGBoost Contributors
|
* Copyright 2023-2024, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#pragma once
|
#pragma once
|
||||||
#include <cstdint> // for int32_t
|
#include <cstdint> // for int32_t
|
||||||
@ -58,6 +58,7 @@ struct Magic {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Basic commands for communication between workers and the tracker.
|
||||||
enum class CMD : std::int32_t {
|
enum class CMD : std::int32_t {
|
||||||
kInvalid = 0,
|
kInvalid = 0,
|
||||||
kStart = 1,
|
kStart = 1,
|
||||||
@ -84,7 +85,10 @@ struct Connect {
|
|||||||
[[nodiscard]] Result TrackerRecv(TCPSocket* sock, std::int32_t* world, std::int32_t* rank,
|
[[nodiscard]] Result TrackerRecv(TCPSocket* sock, std::int32_t* world, std::int32_t* rank,
|
||||||
std::string* task_id) const {
|
std::string* task_id) const {
|
||||||
std::string init;
|
std::string init;
|
||||||
sock->Recv(&init);
|
auto rc = sock->Recv(&init);
|
||||||
|
if (!rc.OK()) {
|
||||||
|
return Fail("Connect protocol failed.", std::move(rc));
|
||||||
|
}
|
||||||
auto jinit = Json::Load(StringView{init});
|
auto jinit = Json::Load(StringView{init});
|
||||||
*world = get<Integer const>(jinit["world_size"]);
|
*world = get<Integer const>(jinit["world_size"]);
|
||||||
*rank = get<Integer const>(jinit["rank"]);
|
*rank = get<Integer const>(jinit["rank"]);
|
||||||
@ -122,9 +126,9 @@ class Start {
|
|||||||
}
|
}
|
||||||
[[nodiscard]] Result WorkerRecv(TCPSocket* tracker, std::int32_t* p_world) const {
|
[[nodiscard]] Result WorkerRecv(TCPSocket* tracker, std::int32_t* p_world) const {
|
||||||
std::string scmd;
|
std::string scmd;
|
||||||
auto n_bytes = tracker->Recv(&scmd);
|
auto rc = tracker->Recv(&scmd);
|
||||||
if (n_bytes <= 0) {
|
if (!rc.OK()) {
|
||||||
return Fail("Failed to recv init command from tracker.");
|
return Fail("Failed to recv init command from tracker.", std::move(rc));
|
||||||
}
|
}
|
||||||
auto jcmd = Json::Load(scmd);
|
auto jcmd = Json::Load(scmd);
|
||||||
auto world = get<Integer const>(jcmd["world_size"]);
|
auto world = get<Integer const>(jcmd["world_size"]);
|
||||||
@ -132,7 +136,7 @@ class Start {
|
|||||||
return Fail("Invalid world size.");
|
return Fail("Invalid world size.");
|
||||||
}
|
}
|
||||||
*p_world = world;
|
*p_world = world;
|
||||||
return Success();
|
return rc;
|
||||||
}
|
}
|
||||||
[[nodiscard]] Result TrackerHandle(Json jcmd, std::int32_t* recv_world, std::int32_t world,
|
[[nodiscard]] Result TrackerHandle(Json jcmd, std::int32_t* recv_world, std::int32_t world,
|
||||||
std::int32_t* p_port, TCPSocket* p_sock,
|
std::int32_t* p_port, TCPSocket* p_sock,
|
||||||
@ -150,6 +154,7 @@ class Start {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Protocol for communicating with the tracker for printing message.
|
||||||
struct Print {
|
struct Print {
|
||||||
[[nodiscard]] Result WorkerSend(TCPSocket* tracker, std::string msg) const {
|
[[nodiscard]] Result WorkerSend(TCPSocket* tracker, std::string msg) const {
|
||||||
Json jcmd{Object{}};
|
Json jcmd{Object{}};
|
||||||
@ -172,6 +177,7 @@ struct Print {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Protocol for communicating with the tracker during error.
|
||||||
struct ErrorCMD {
|
struct ErrorCMD {
|
||||||
[[nodiscard]] Result WorkerSend(TCPSocket* tracker, Result const& res) const {
|
[[nodiscard]] Result WorkerSend(TCPSocket* tracker, Result const& res) const {
|
||||||
auto msg = res.Report();
|
auto msg = res.Report();
|
||||||
@ -199,6 +205,7 @@ struct ErrorCMD {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Protocol for communicating with the tracker during shutdown.
|
||||||
struct ShutdownCMD {
|
struct ShutdownCMD {
|
||||||
[[nodiscard]] Result Send(TCPSocket* peer) const {
|
[[nodiscard]] Result Send(TCPSocket* peer) const {
|
||||||
Json jcmd{Object{}};
|
Json jcmd{Object{}};
|
||||||
@ -211,4 +218,40 @@ struct ShutdownCMD {
|
|||||||
return Success();
|
return Success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Protocol for communicating with the local error handler during error or shutdown. Only
|
||||||
|
// one protocol that doesn't have the tracker involved.
|
||||||
|
struct Error {
|
||||||
|
constexpr static std::int32_t ShutdownSignal() { return 0; }
|
||||||
|
constexpr static std::int32_t ErrorSignal() { return -1; }
|
||||||
|
|
||||||
|
[[nodiscard]] Result SignalError(TCPSocket* worker) const {
|
||||||
|
std::int32_t err{ErrorSignal()};
|
||||||
|
auto n_sent = worker->SendAll(&err, sizeof(err));
|
||||||
|
if (n_sent == sizeof(err)) {
|
||||||
|
return Success();
|
||||||
|
}
|
||||||
|
return Fail("Failed to send error signal");
|
||||||
|
}
|
||||||
|
// self is localhost, we are sending the signal to the error handling thread for it to
|
||||||
|
// close.
|
||||||
|
[[nodiscard]] Result SignalShutdown(TCPSocket* self) const {
|
||||||
|
std::int32_t err{ShutdownSignal()};
|
||||||
|
auto n_sent = self->SendAll(&err, sizeof(err));
|
||||||
|
if (n_sent == sizeof(err)) {
|
||||||
|
return Success();
|
||||||
|
}
|
||||||
|
return Fail("Failed to send shutdown signal");
|
||||||
|
}
|
||||||
|
// get signal, either for error or for shutdown.
|
||||||
|
[[nodiscard]] Result RecvSignal(TCPSocket* peer, bool* p_is_error) const {
|
||||||
|
std::int32_t err{ShutdownSignal()};
|
||||||
|
auto n_recv = peer->RecvAll(&err, sizeof(err));
|
||||||
|
if (n_recv == sizeof(err)) {
|
||||||
|
*p_is_error = err == 1;
|
||||||
|
return Success();
|
||||||
|
}
|
||||||
|
return Fail("Failed to receive error signal.");
|
||||||
|
}
|
||||||
|
};
|
||||||
} // namespace xgboost::collective::proto
|
} // namespace xgboost::collective::proto
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2022-2023 by XGBoost Contributors
|
* Copyright 2022-2024, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include "xgboost/collective/socket.h"
|
#include "xgboost/collective/socket.h"
|
||||||
|
|
||||||
@ -8,7 +8,8 @@
|
|||||||
#include <cstdint> // std::int32_t
|
#include <cstdint> // std::int32_t
|
||||||
#include <cstring> // std::memcpy, std::memset
|
#include <cstring> // std::memcpy, std::memset
|
||||||
#include <filesystem> // for path
|
#include <filesystem> // for path
|
||||||
#include <system_error> // std::error_code, std::system_category
|
#include <system_error> // for error_code, system_category
|
||||||
|
#include <thread> // for sleep_for
|
||||||
|
|
||||||
#include "rabit/internal/socket.h" // for PollHelper
|
#include "rabit/internal/socket.h" // for PollHelper
|
||||||
#include "xgboost/collective/result.h" // for Result
|
#include "xgboost/collective/result.h" // for Result
|
||||||
@ -65,14 +66,18 @@ std::size_t TCPSocket::Send(StringView str) {
|
|||||||
return bytes;
|
return bytes;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::size_t TCPSocket::Recv(std::string *p_str) {
|
[[nodiscard]] Result TCPSocket::Recv(std::string *p_str) {
|
||||||
CHECK(!this->IsClosed());
|
CHECK(!this->IsClosed());
|
||||||
std::int32_t len;
|
std::int32_t len;
|
||||||
CHECK_EQ(this->RecvAll(&len, sizeof(len)), sizeof(len)) << "Failed to recv string length.";
|
if (this->RecvAll(&len, sizeof(len)) != sizeof(len)) {
|
||||||
|
return Fail("Failed to recv string length.");
|
||||||
|
}
|
||||||
p_str->resize(len);
|
p_str->resize(len);
|
||||||
auto bytes = this->RecvAll(&(*p_str)[0], len);
|
auto bytes = this->RecvAll(&(*p_str)[0], len);
|
||||||
CHECK_EQ(bytes, len) << "Failed to recv string.";
|
if (static_cast<decltype(len)>(bytes) != len) {
|
||||||
return bytes;
|
return Fail("Failed to recv string.");
|
||||||
|
}
|
||||||
|
return Success();
|
||||||
}
|
}
|
||||||
|
|
||||||
[[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry,
|
[[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry,
|
||||||
@ -110,11 +115,7 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
|
|||||||
for (std::int32_t attempt = 0; attempt < std::max(retry, 1); ++attempt) {
|
for (std::int32_t attempt = 0; attempt < std::max(retry, 1); ++attempt) {
|
||||||
if (attempt > 0) {
|
if (attempt > 0) {
|
||||||
LOG(WARNING) << "Retrying connection to " << host << " for the " << attempt << " time.";
|
LOG(WARNING) << "Retrying connection to " << host << " for the " << attempt << " time.";
|
||||||
#if defined(_MSC_VER) || defined(__MINGW32__)
|
std::this_thread::sleep_for(std::chrono::seconds{attempt << 1});
|
||||||
Sleep(attempt << 1);
|
|
||||||
#else
|
|
||||||
sleep(attempt << 1);
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auto rc = connect(conn.Handle(), addr_handle, addr_len);
|
auto rc = connect(conn.Handle(), addr_handle, addr_len);
|
||||||
@ -158,8 +159,8 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
|
|||||||
|
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "Failed to connect to " << host << ":" << port;
|
ss << "Failed to connect to " << host << ":" << port;
|
||||||
conn.Close();
|
auto close_rc = conn.Close();
|
||||||
return Fail(ss.str(), std::move(last_error));
|
return Fail(ss.str(), std::move(close_rc) + std::move(last_error));
|
||||||
}
|
}
|
||||||
|
|
||||||
[[nodiscard]] Result GetHostName(std::string *p_out) {
|
[[nodiscard]] Result GetHostName(std::string *p_out) {
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2023-2024, XGBoost Contributors
|
* Copyright 2023-2024, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
|
#include "rabit/internal/socket.h"
|
||||||
#if defined(__unix__) || defined(__APPLE__)
|
#if defined(__unix__) || defined(__APPLE__)
|
||||||
#include <netdb.h> // gethostbyname
|
#include <netdb.h> // gethostbyname
|
||||||
#include <sys/socket.h> // socket, AF_INET6, AF_INET, connect, getsockname
|
#include <sys/socket.h> // socket, AF_INET6, AF_INET, connect, getsockname
|
||||||
@ -70,10 +71,13 @@ RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockA
|
|||||||
return proto::Connect{}.TrackerRecv(&sock_, &world_, &rank, &task_id_);
|
return proto::Connect{}.TrackerRecv(&sock_, &world_, &rank, &task_id_);
|
||||||
} << [&] {
|
} << [&] {
|
||||||
std::string cmd;
|
std::string cmd;
|
||||||
sock_.Recv(&cmd);
|
auto rc = sock_.Recv(&cmd);
|
||||||
|
if (!rc.OK()) {
|
||||||
|
return rc;
|
||||||
|
}
|
||||||
jcmd = Json::Load(StringView{cmd});
|
jcmd = Json::Load(StringView{cmd});
|
||||||
cmd_ = static_cast<proto::CMD>(get<Integer const>(jcmd["cmd"]));
|
cmd_ = static_cast<proto::CMD>(get<Integer const>(jcmd["cmd"]));
|
||||||
return Success();
|
return rc;
|
||||||
} << [&] {
|
} << [&] {
|
||||||
if (cmd_ == proto::CMD::kStart) {
|
if (cmd_ == proto::CMD::kStart) {
|
||||||
proto::Start start;
|
proto::Start start;
|
||||||
@ -100,14 +104,18 @@ RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockA
|
|||||||
|
|
||||||
RabitTracker::RabitTracker(Json const& config) : Tracker{config} {
|
RabitTracker::RabitTracker(Json const& config) : Tracker{config} {
|
||||||
std::string self;
|
std::string self;
|
||||||
auto rc = collective::GetHostAddress(&self);
|
auto rc = Success() << [&] {
|
||||||
host_ = OptionalArg<String>(config, "host", self);
|
return collective::GetHostAddress(&self);
|
||||||
|
} << [&] {
|
||||||
|
host_ = OptionalArg<String>(config, "host", self);
|
||||||
|
|
||||||
auto addr = MakeSockAddress(xgboost::StringView{host_}, 0);
|
auto addr = MakeSockAddress(xgboost::StringView{host_}, 0);
|
||||||
listener_ = TCPSocket::Create(addr.IsV4() ? SockDomain::kV4 : SockDomain::kV6);
|
listener_ = TCPSocket::Create(addr.IsV4() ? SockDomain::kV4 : SockDomain::kV6);
|
||||||
rc = listener_.Bind(host_, &this->port_);
|
return listener_.Bind(host_, &this->port_);
|
||||||
|
} << [&] {
|
||||||
|
return listener_.Listen();
|
||||||
|
};
|
||||||
SafeColl(rc);
|
SafeColl(rc);
|
||||||
listener_.Listen();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
|
Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
|
||||||
@ -220,9 +228,13 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
|
|||||||
//
|
//
|
||||||
// retry is set to 1, just let the worker timeout or error. Otherwise the
|
// retry is set to 1, just let the worker timeout or error. Otherwise the
|
||||||
// tracker and the worker might be waiting for each other.
|
// tracker and the worker might be waiting for each other.
|
||||||
auto rc = Connect(w.first, w.second, 1, timeout_, &out);
|
auto rc = Success() << [&] {
|
||||||
|
return Connect(w.first, w.second, 1, timeout_, &out);
|
||||||
|
} << [&] {
|
||||||
|
return proto::Error{}.SignalError(&out);
|
||||||
|
};
|
||||||
if (!rc.OK()) {
|
if (!rc.OK()) {
|
||||||
return Fail("Failed to inform workers to stop.");
|
return Fail("Failed to inform worker:" + w.first + " for error.", std::move(rc));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Success();
|
return Success();
|
||||||
@ -231,13 +243,37 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
|
|||||||
return std::async(std::launch::async, [this, handle_error] {
|
return std::async(std::launch::async, [this, handle_error] {
|
||||||
State state{this->n_workers_};
|
State state{this->n_workers_};
|
||||||
|
|
||||||
|
auto select_accept = [&](TCPSocket* sock, auto* addr) {
|
||||||
|
// accept with poll so that we can enable timeout and interruption.
|
||||||
|
rabit::utils::PollHelper poll;
|
||||||
|
auto rc = Success() << [&] {
|
||||||
|
std::lock_guard lock{listener_mu_};
|
||||||
|
return listener_.NonBlocking(true);
|
||||||
|
} << [&] {
|
||||||
|
std::lock_guard lock{listener_mu_};
|
||||||
|
poll.WatchRead(listener_);
|
||||||
|
if (state.running) {
|
||||||
|
// Don't timeout if the communicator group is up and running.
|
||||||
|
return poll.Poll(std::chrono::seconds{-1});
|
||||||
|
} else {
|
||||||
|
// Have timeout for workers to bootstrap.
|
||||||
|
return poll.Poll(timeout_);
|
||||||
|
}
|
||||||
|
} << [&] {
|
||||||
|
// this->Stop() closes the socket with a lock. Therefore, when the accept returns
|
||||||
|
// due to shutdown, the state is still valid (closed).
|
||||||
|
return listener_.Accept(sock, addr);
|
||||||
|
};
|
||||||
|
return rc;
|
||||||
|
};
|
||||||
|
|
||||||
while (state.ShouldContinue()) {
|
while (state.ShouldContinue()) {
|
||||||
TCPSocket sock;
|
TCPSocket sock;
|
||||||
SockAddress addr;
|
SockAddress addr;
|
||||||
this->ready_ = true;
|
this->ready_ = true;
|
||||||
auto rc = listener_.Accept(&sock, &addr);
|
auto rc = select_accept(&sock, &addr);
|
||||||
if (!rc.OK()) {
|
if (!rc.OK()) {
|
||||||
return Fail("Failed to accept connection.", std::move(rc));
|
return Fail("Failed to accept connection.", this->Stop() + std::move(rc));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto worker = WorkerProxy{n_workers_, std::move(sock), std::move(addr)};
|
auto worker = WorkerProxy{n_workers_, std::move(sock), std::move(addr)};
|
||||||
@ -252,7 +288,7 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
|
|||||||
state.Error();
|
state.Error();
|
||||||
rc = handle_error(worker);
|
rc = handle_error(worker);
|
||||||
if (!rc.OK()) {
|
if (!rc.OK()) {
|
||||||
return Fail("Failed to handle abort.", std::move(rc));
|
return Fail("Failed to handle abort.", this->Stop() + std::move(rc));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -262,7 +298,7 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
|
|||||||
state.Bootstrap();
|
state.Bootstrap();
|
||||||
}
|
}
|
||||||
if (!rc.OK()) {
|
if (!rc.OK()) {
|
||||||
return rc;
|
return this->Stop() + std::move(rc);
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -289,12 +325,11 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
|
|||||||
}
|
}
|
||||||
case proto::CMD::kInvalid:
|
case proto::CMD::kInvalid:
|
||||||
default: {
|
default: {
|
||||||
return Fail("Invalid command received.");
|
return Fail("Invalid command received.", this->Stop());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ready_ = false;
|
return this->Stop();
|
||||||
return Success();
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -303,11 +338,30 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
|
|||||||
SafeColl(rc);
|
SafeColl(rc);
|
||||||
|
|
||||||
Json args{Object{}};
|
Json args{Object{}};
|
||||||
args["DMLC_TRACKER_URI"] = String{host_};
|
args["dmlc_tracker_uri"] = String{host_};
|
||||||
args["DMLC_TRACKER_PORT"] = this->Port();
|
args["dmlc_tracker_port"] = this->Port();
|
||||||
return args;
|
return args;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stop the tracker by shutting down and then closing the listening socket. This is a no-op
// when the tracker is not ready or when the listener has already been closed.
[[nodiscard]] Result RabitTracker::Stop() {
  if (!this->Ready()) {
    return Success();
  }
  ready_ = false;

  // The listener is also used by the accept loop, hold the lock while touching it.
  std::lock_guard lock{listener_mu_};
  if (listener_.IsClosed()) {
    return Success();
  }

  // Shutting down the socket should have the effect of stopping the `accept` call, the
  // socket is closed only after that.
  auto rc = Success() << [&] {
    return listener_.Shutdown();
  } << [&] {
    return listener_.Close();
  };
  return rc;
}
|
||||||
|
|
||||||
[[nodiscard]] Result GetHostAddress(std::string* out) {
|
[[nodiscard]] Result GetHostAddress(std::string* out) {
|
||||||
auto rc = GetHostName(out);
|
auto rc = GetHostName(out);
|
||||||
if (!rc.OK()) {
|
if (!rc.OK()) {
|
||||||
|
|||||||
@ -36,15 +36,18 @@ namespace xgboost::collective {
|
|||||||
* signal an error to the tracker and the tracker will notify other workers.
|
* signal an error to the tracker and the tracker will notify other workers.
|
||||||
*/
|
*/
|
||||||
class Tracker {
|
class Tracker {
|
||||||
|
public:
|
||||||
|
enum class SortBy : std::int8_t {
|
||||||
|
kHost = 0,
|
||||||
|
kTask = 1,
|
||||||
|
};
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// How to sort the workers, either by host name or by task ID. When using a multi-GPU
|
// How to sort the workers, either by host name or by task ID. When using a multi-GPU
|
||||||
// setting, multiple workers can occupy the same host, in which case one should sort
|
// setting, multiple workers can occupy the same host, in which case one should sort
|
||||||
// workers by task. For compatibility reasons, the task ID is not always available, so
|
// workers by task. For compatibility reasons, the task ID is not always available, so
|
||||||
// we use host as the default.
|
// we use host as the default.
|
||||||
enum class SortBy : std::int8_t {
|
SortBy sortby_;
|
||||||
kHost = 0,
|
|
||||||
kTask = 1,
|
|
||||||
} sortby_;
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
std::int32_t n_workers_{0};
|
std::int32_t n_workers_{0};
|
||||||
@ -54,10 +57,7 @@ class Tracker {
|
|||||||
|
|
||||||
public:
|
public:
|
||||||
explicit Tracker(Json const& config);
|
explicit Tracker(Json const& config);
|
||||||
Tracker(std::int32_t n_worders, std::int32_t port, std::chrono::seconds timeout)
|
virtual ~Tracker() = default;
|
||||||
: n_workers_{n_worders}, port_{port}, timeout_{timeout} {}
|
|
||||||
|
|
||||||
virtual ~Tracker() noexcept(false){}; // NOLINT
|
|
||||||
|
|
||||||
[[nodiscard]] Result WaitUntilReady() const;
|
[[nodiscard]] Result WaitUntilReady() const;
|
||||||
|
|
||||||
@ -69,6 +69,11 @@ class Tracker {
|
|||||||
* @brief Flag to indicate whether the server is running.
|
* @brief Flag to indicate whether the server is running.
|
||||||
*/
|
*/
|
||||||
[[nodiscard]] bool Ready() const { return ready_; }
|
[[nodiscard]] bool Ready() const { return ready_; }
|
||||||
|
/**
|
||||||
|
* @brief Shut down the tracker; it cannot be restarted afterwards. Useful when the tracker hangs while
|
||||||
|
* calling accept.
|
||||||
|
*/
|
||||||
|
virtual Result Stop() { return Success(); }
|
||||||
};
|
};
|
||||||
|
|
||||||
class RabitTracker : public Tracker {
|
class RabitTracker : public Tracker {
|
||||||
@ -127,28 +132,22 @@ class RabitTracker : public Tracker {
|
|||||||
// record for how to reach out to workers if error happens.
|
// record for how to reach out to workers if error happens.
|
||||||
std::vector<std::pair<std::string, std::int32_t>> worker_error_handles_;
|
std::vector<std::pair<std::string, std::int32_t>> worker_error_handles_;
|
||||||
// listening socket for incoming workers.
|
// listening socket for incoming workers.
|
||||||
//
|
|
||||||
// At the moment, the listener calls accept without first polling. We can add an
|
|
||||||
// additional unix domain socket to allow cancelling the accept.
|
|
||||||
TCPSocket listener_;
|
TCPSocket listener_;
|
||||||
|
// mutex for protecting the listener, used to prevent race when it's listening while
|
||||||
|
// another thread tries to shut it down.
|
||||||
|
std::mutex listener_mu_;
|
||||||
|
|
||||||
Result Bootstrap(std::vector<WorkerProxy>* p_workers);
|
Result Bootstrap(std::vector<WorkerProxy>* p_workers);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit RabitTracker(StringView host, std::int32_t n_worders, std::int32_t port,
|
|
||||||
std::chrono::seconds timeout)
|
|
||||||
: Tracker{n_worders, port, timeout}, host_{host.c_str(), host.size()} {
|
|
||||||
listener_ = TCPSocket::Create(SockDomain::kV4);
|
|
||||||
auto rc = listener_.Bind(host, &this->port_);
|
|
||||||
CHECK(rc.OK()) << rc.Report();
|
|
||||||
listener_.Listen();
|
|
||||||
}
|
|
||||||
|
|
||||||
explicit RabitTracker(Json const& config);
|
explicit RabitTracker(Json const& config);
|
||||||
~RabitTracker() noexcept(false) override = default;
|
~RabitTracker() override = default;
|
||||||
|
|
||||||
std::future<Result> Run() override;
|
std::future<Result> Run() override;
|
||||||
[[nodiscard]] Json WorkerArgs() const override;
|
[[nodiscard]] Json WorkerArgs() const override;
|
||||||
|
// Stop the tracker without waiting. This is to prevent the tracker from hanging when
|
||||||
|
// one of the workers fails to start.
|
||||||
|
[[nodiscard]] Result Stop() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Probe the public IP address of the host. A better method is needed.
|
// Probe the public IP address of the host. A better method is needed.
|
||||||
|
|||||||
@ -25,13 +25,13 @@ TEST_F(TrackerAPITest, CAPI) {
|
|||||||
auto config_str = Json::Dump(config);
|
auto config_str = Json::Dump(config);
|
||||||
auto rc = XGTrackerCreate(config_str.c_str(), &handle);
|
auto rc = XGTrackerCreate(config_str.c_str(), &handle);
|
||||||
ASSERT_EQ(rc, 0);
|
ASSERT_EQ(rc, 0);
|
||||||
rc = XGTrackerRun(handle);
|
rc = XGTrackerRun(handle, nullptr);
|
||||||
ASSERT_EQ(rc, 0);
|
ASSERT_EQ(rc, 0);
|
||||||
|
|
||||||
std::thread bg_wait{[&] {
|
std::thread bg_wait{[&] {
|
||||||
Json config{Object{}};
|
Json config{Object{}};
|
||||||
auto config_str = Json::Dump(config);
|
auto config_str = Json::Dump(config);
|
||||||
auto rc = XGTrackerWait(handle, config_str.c_str());
|
auto rc = XGTrackerWaitFor(handle, config_str.c_str());
|
||||||
ASSERT_EQ(rc, 0);
|
ASSERT_EQ(rc, 0);
|
||||||
}};
|
}};
|
||||||
|
|
||||||
@ -42,8 +42,8 @@ TEST_F(TrackerAPITest, CAPI) {
|
|||||||
|
|
||||||
std::string host;
|
std::string host;
|
||||||
ASSERT_TRUE(GetHostAddress(&host).OK());
|
ASSERT_TRUE(GetHostAddress(&host).OK());
|
||||||
ASSERT_EQ(host, get<String const>(args["DMLC_TRACKER_URI"]));
|
ASSERT_EQ(host, get<String const>(args["dmlc_tracker_uri"]));
|
||||||
auto port = get<Integer const>(args["DMLC_TRACKER_PORT"]);
|
auto port = get<Integer const>(args["dmlc_tracker_port"]);
|
||||||
ASSERT_NE(port, 0);
|
ASSERT_NE(port, 0);
|
||||||
|
|
||||||
std::vector<std::thread> workers;
|
std::vector<std::thread> workers;
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2023, XGBoost Contributors
|
* Copyright 2023-2024, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
@ -14,7 +14,7 @@ class CommTest : public TrackerTest {};
|
|||||||
|
|
||||||
TEST_F(CommTest, Channel) {
|
TEST_F(CommTest, Channel) {
|
||||||
auto n_workers = 4;
|
auto n_workers = 4;
|
||||||
RabitTracker tracker{host, n_workers, 0, timeout};
|
RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)};
|
||||||
auto fut = tracker.Run();
|
auto fut = tracker.Run();
|
||||||
|
|
||||||
std::vector<std::thread> workers;
|
std::vector<std::thread> workers;
|
||||||
@ -29,7 +29,7 @@ TEST_F(CommTest, Channel) {
|
|||||||
return p_chan->SendAll(
|
return p_chan->SendAll(
|
||||||
EraseType(common::Span<std::int32_t const>{&i, static_cast<std::size_t>(1)}));
|
EraseType(common::Span<std::int32_t const>{&i, static_cast<std::size_t>(1)}));
|
||||||
} << [&] { return p_chan->Block(); };
|
} << [&] { return p_chan->Block(); };
|
||||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
SafeColl(rc);
|
||||||
} else {
|
} else {
|
||||||
auto p_chan = worker.Comm().Chan(i - 1);
|
auto p_chan = worker.Comm().Chan(i - 1);
|
||||||
std::int32_t r{-1};
|
std::int32_t r{-1};
|
||||||
@ -37,7 +37,7 @@ TEST_F(CommTest, Channel) {
|
|||||||
return p_chan->RecvAll(
|
return p_chan->RecvAll(
|
||||||
EraseType(common::Span<std::int32_t>{&r, static_cast<std::size_t>(1)}));
|
EraseType(common::Span<std::int32_t>{&r, static_cast<std::size_t>(1)}));
|
||||||
} << [&] { return p_chan->Block(); };
|
} << [&] { return p_chan->Block(); };
|
||||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
SafeColl(rc);
|
||||||
ASSERT_EQ(r, i - 1);
|
ASSERT_EQ(r, i - 1);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|||||||
@ -17,17 +17,6 @@
|
|||||||
|
|
||||||
namespace xgboost::collective {
|
namespace xgboost::collective {
|
||||||
namespace {
|
namespace {
|
||||||
auto MakeConfig(std::string host, std::int32_t port, std::chrono::seconds timeout, std::int32_t r) {
|
|
||||||
Json config{Object{}};
|
|
||||||
config["dmlc_communicator"] = std::string{"rabit"};
|
|
||||||
config["DMLC_TRACKER_URI"] = host;
|
|
||||||
config["DMLC_TRACKER_PORT"] = port;
|
|
||||||
config["dmlc_timeout_sec"] = static_cast<std::int64_t>(timeout.count());
|
|
||||||
config["DMLC_TASK_ID"] = std::to_string(r);
|
|
||||||
config["dmlc_retry"] = 2;
|
|
||||||
return config;
|
|
||||||
}
|
|
||||||
|
|
||||||
class CommGroupTest : public SocketTest {};
|
class CommGroupTest : public SocketTest {};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
@ -36,7 +25,7 @@ TEST_F(CommGroupTest, Basic) {
|
|||||||
TestDistributed(n_workers, [&](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
TestDistributed(n_workers, [&](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||||
std::int32_t r) {
|
std::int32_t r) {
|
||||||
Context ctx;
|
Context ctx;
|
||||||
auto config = MakeConfig(host, port, timeout, r);
|
auto config = MakeDistributedTestConfig(host, port, timeout, r);
|
||||||
std::unique_ptr<CommGroup> ptr{CommGroup::Create(config)};
|
std::unique_ptr<CommGroup> ptr{CommGroup::Create(config)};
|
||||||
ASSERT_TRUE(ptr->IsDistributed());
|
ASSERT_TRUE(ptr->IsDistributed());
|
||||||
ASSERT_EQ(ptr->World(), n_workers);
|
ASSERT_EQ(ptr->World(), n_workers);
|
||||||
@ -52,7 +41,7 @@ TEST_F(CommGroupTest, BasicGPU) {
|
|||||||
TestDistributed(n_workers, [&](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
TestDistributed(n_workers, [&](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||||
std::int32_t r) {
|
std::int32_t r) {
|
||||||
auto ctx = MakeCUDACtx(r);
|
auto ctx = MakeCUDACtx(r);
|
||||||
auto config = MakeConfig(host, port, timeout, r);
|
auto config = MakeDistributedTestConfig(host, port, timeout, r);
|
||||||
std::unique_ptr<CommGroup> ptr{CommGroup::Create(config)};
|
std::unique_ptr<CommGroup> ptr{CommGroup::Create(config)};
|
||||||
auto const& comm = ptr->Ctx(&ctx, DeviceOrd::CUDA(0));
|
auto const& comm = ptr->Ctx(&ctx, DeviceOrd::CUDA(0));
|
||||||
ASSERT_EQ(comm.TaskID(), std::to_string(r));
|
ASSERT_EQ(comm.TaskID(), std::to_string(r));
|
||||||
|
|||||||
@ -28,13 +28,11 @@ class LoopTest : public ::testing::Test {
|
|||||||
|
|
||||||
auto domain = SockDomain::kV4;
|
auto domain = SockDomain::kV4;
|
||||||
pair_.first = TCPSocket::Create(domain);
|
pair_.first = TCPSocket::Create(domain);
|
||||||
in_port_t port{0};
|
std::int32_t port{0};
|
||||||
auto rc = Success() << [&] {
|
auto rc = Success() << [&] {
|
||||||
port = pair_.first.BindHost();
|
return pair_.first.BindHost(&port);
|
||||||
return Success();
|
|
||||||
} << [&] {
|
} << [&] {
|
||||||
pair_.first.Listen();
|
return pair_.first.Listen();
|
||||||
return Success();
|
|
||||||
};
|
};
|
||||||
SafeColl(rc);
|
SafeColl(rc);
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2022-2023, XGBoost Contributors
|
* Copyright 2022-2024, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <xgboost/collective/socket.h>
|
#include <xgboost/collective/socket.h>
|
||||||
@ -21,14 +21,19 @@ TEST_F(SocketTest, Basic) {
|
|||||||
auto run_test = [msg](SockDomain domain) {
|
auto run_test = [msg](SockDomain domain) {
|
||||||
auto server = TCPSocket::Create(domain);
|
auto server = TCPSocket::Create(domain);
|
||||||
ASSERT_EQ(server.Domain(), domain);
|
ASSERT_EQ(server.Domain(), domain);
|
||||||
auto port = server.BindHost();
|
std::int32_t port{0};
|
||||||
server.Listen();
|
auto rc = Success() << [&] {
|
||||||
|
return server.BindHost(&port);
|
||||||
|
} << [&] {
|
||||||
|
return server.Listen();
|
||||||
|
};
|
||||||
|
SafeColl(rc);
|
||||||
|
|
||||||
TCPSocket client;
|
TCPSocket client;
|
||||||
if (domain == SockDomain::kV4) {
|
if (domain == SockDomain::kV4) {
|
||||||
auto const& addr = SockAddrV4::Loopback().Addr();
|
auto const& addr = SockAddrV4::Loopback().Addr();
|
||||||
auto rc = Connect(StringView{addr}, port, 1, std::chrono::seconds{3}, &client);
|
auto rc = Connect(StringView{addr}, port, 1, std::chrono::seconds{3}, &client);
|
||||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
SafeColl(rc);
|
||||||
} else {
|
} else {
|
||||||
auto const& addr = SockAddrV6::Loopback().Addr();
|
auto const& addr = SockAddrV6::Loopback().Addr();
|
||||||
auto rc = Connect(StringView{addr}, port, 1, std::chrono::seconds{3}, &client);
|
auto rc = Connect(StringView{addr}, port, 1, std::chrono::seconds{3}, &client);
|
||||||
@ -45,7 +50,8 @@ TEST_F(SocketTest, Basic) {
|
|||||||
accepted.Send(msg);
|
accepted.Send(msg);
|
||||||
|
|
||||||
std::string str;
|
std::string str;
|
||||||
client.Recv(&str);
|
rc = client.Recv(&str);
|
||||||
|
SafeColl(rc);
|
||||||
ASSERT_EQ(StringView{str}, msg);
|
ASSERT_EQ(StringView{str}, msg);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2023, XGBoost Contributors
|
* Copyright 2023-2024, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
|
#include <gmock/gmock.h>
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
#include <chrono> // for seconds
|
#include <chrono> // for seconds
|
||||||
@ -10,6 +11,7 @@
|
|||||||
#include <vector> // for vector
|
#include <vector> // for vector
|
||||||
|
|
||||||
#include "../../../src/collective/comm.h"
|
#include "../../../src/collective/comm.h"
|
||||||
|
#include "../helpers.h" // for GMockThrow
|
||||||
#include "test_worker.h"
|
#include "test_worker.h"
|
||||||
|
|
||||||
namespace xgboost::collective {
|
namespace xgboost::collective {
|
||||||
@ -20,13 +22,13 @@ class PrintWorker : public WorkerForTest {
|
|||||||
|
|
||||||
void Print() {
|
void Print() {
|
||||||
auto rc = comm_.LogTracker("ack:" + std::to_string(this->comm_.Rank()));
|
auto rc = comm_.LogTracker("ack:" + std::to_string(this->comm_.Rank()));
|
||||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
SafeColl(rc);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
TEST_F(TrackerTest, Bootstrap) {
|
TEST_F(TrackerTest, Bootstrap) {
|
||||||
RabitTracker tracker{host, n_workers, 0, timeout};
|
RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)};
|
||||||
ASSERT_FALSE(tracker.Ready());
|
ASSERT_FALSE(tracker.Ready());
|
||||||
auto fut = tracker.Run();
|
auto fut = tracker.Run();
|
||||||
|
|
||||||
@ -34,7 +36,7 @@ TEST_F(TrackerTest, Bootstrap) {
|
|||||||
|
|
||||||
auto args = tracker.WorkerArgs();
|
auto args = tracker.WorkerArgs();
|
||||||
ASSERT_TRUE(tracker.Ready());
|
ASSERT_TRUE(tracker.Ready());
|
||||||
ASSERT_EQ(get<String const>(args["DMLC_TRACKER_URI"]), host);
|
ASSERT_EQ(get<String const>(args["dmlc_tracker_uri"]), host);
|
||||||
|
|
||||||
std::int32_t port = tracker.Port();
|
std::int32_t port = tracker.Port();
|
||||||
|
|
||||||
@ -44,12 +46,11 @@ TEST_F(TrackerTest, Bootstrap) {
|
|||||||
for (auto &w : workers) {
|
for (auto &w : workers) {
|
||||||
w.join();
|
w.join();
|
||||||
}
|
}
|
||||||
|
SafeColl(fut.get());
|
||||||
ASSERT_TRUE(fut.get().OK());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TrackerTest, Print) {
|
TEST_F(TrackerTest, Print) {
|
||||||
RabitTracker tracker{host, n_workers, 0, timeout};
|
RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)};
|
||||||
auto fut = tracker.Run();
|
auto fut = tracker.Run();
|
||||||
|
|
||||||
std::vector<std::thread> workers;
|
std::vector<std::thread> workers;
|
||||||
@ -73,4 +74,47 @@ TEST_F(TrackerTest, Print) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TrackerTest, GetHostAddress) { ASSERT_TRUE(host.find("127.") == std::string::npos); }
|
TEST_F(TrackerTest, GetHostAddress) { ASSERT_TRUE(host.find("127.") == std::string::npos); }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Test connecting to the tracker after it has finished. This should not hang the workers.
|
||||||
|
*/
|
||||||
|
TEST_F(TrackerTest, AfterShutdown) {
|
||||||
|
RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)};
|
||||||
|
auto fut = tracker.Run();
|
||||||
|
|
||||||
|
std::vector<std::thread> workers;
|
||||||
|
auto rc = tracker.WaitUntilReady();
|
||||||
|
ASSERT_TRUE(rc.OK());
|
||||||
|
|
||||||
|
std::int32_t port = tracker.Port();
|
||||||
|
|
||||||
|
// Launch no-op workers to cause the tracker to shutdown.
|
||||||
|
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||||
|
workers.emplace_back([=] { WorkerForTest worker{host, port, timeout, n_workers, i}; });
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto &w : workers) {
|
||||||
|
w.join();
|
||||||
|
}
|
||||||
|
|
||||||
|
ASSERT_TRUE(fut.get().OK());
|
||||||
|
|
||||||
|
// Launch workers again, they should fail.
|
||||||
|
workers.clear();
|
||||||
|
for (std::int32_t i = 0; i < n_workers; ++i) {
|
||||||
|
auto assert_that = [=] {
|
||||||
|
WorkerForTest worker{host, port, timeout, n_workers, i};
|
||||||
|
};
|
||||||
|
// On Linux, the connection is refused; on Apple platforms, this gets an
|
||||||
|
// "operation now in progress" poll failure; on Windows, it's a timeout error.
|
||||||
|
#if defined(__linux__)
|
||||||
|
workers.emplace_back([=] { ASSERT_THAT(assert_that, GMockThrow("Connection refused")); });
|
||||||
|
#else
|
||||||
|
workers.emplace_back([=] { ASSERT_THAT(assert_that, GMockThrow("Failed to connect to")); });
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
for (auto &w : workers) {
|
||||||
|
w.join();
|
||||||
|
}
|
||||||
|
}
|
||||||
} // namespace xgboost::collective
|
} // namespace xgboost::collective
|
||||||
|
|||||||
@ -37,7 +37,7 @@ class WorkerForTest {
|
|||||||
comm_{tracker_host_, tracker_port_, timeout, retry_, task_id_, DefaultNcclName()} {
|
comm_{tracker_host_, tracker_port_, timeout, retry_, task_id_, DefaultNcclName()} {
|
||||||
CHECK_EQ(world_size_, comm_.World());
|
CHECK_EQ(world_size_, comm_.World());
|
||||||
}
|
}
|
||||||
virtual ~WorkerForTest() = default;
|
virtual ~WorkerForTest() noexcept(false) { SafeColl(comm_.Shutdown()); }
|
||||||
auto& Comm() { return comm_; }
|
auto& Comm() { return comm_; }
|
||||||
|
|
||||||
void LimitSockBuf(std::int32_t n_bytes) {
|
void LimitSockBuf(std::int32_t n_bytes) {
|
||||||
@ -87,19 +87,30 @@ class TrackerTest : public SocketTest {
|
|||||||
void SetUp() override {
|
void SetUp() override {
|
||||||
SocketTest::SetUp();
|
SocketTest::SetUp();
|
||||||
auto rc = GetHostAddress(&host);
|
auto rc = GetHostAddress(&host);
|
||||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
SafeColl(rc);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
inline Json MakeTrackerConfig(std::string host, std::int32_t n_workers,
|
||||||
|
std::chrono::seconds timeout) {
|
||||||
|
Json config{Object{}};
|
||||||
|
config["host"] = host;
|
||||||
|
config["port"] = Integer{0};
|
||||||
|
config["n_workers"] = Integer{n_workers};
|
||||||
|
config["sortby"] = Integer{static_cast<std::int32_t>(Tracker::SortBy::kHost)};
|
||||||
|
config["timeout"] = timeout.count();
|
||||||
|
return config;
|
||||||
|
}
|
||||||
|
|
||||||
template <typename WorkerFn>
|
template <typename WorkerFn>
|
||||||
void TestDistributed(std::int32_t n_workers, WorkerFn worker_fn) {
|
void TestDistributed(std::int32_t n_workers, WorkerFn worker_fn) {
|
||||||
std::chrono::seconds timeout{2};
|
std::chrono::seconds timeout{2};
|
||||||
|
|
||||||
std::string host;
|
std::string host;
|
||||||
auto rc = GetHostAddress(&host);
|
auto rc = GetHostAddress(&host);
|
||||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
SafeColl(rc);
|
||||||
LOG(INFO) << "Using " << n_workers << " workers for test.";
|
LOG(INFO) << "Using " << n_workers << " workers for test.";
|
||||||
RabitTracker tracker{StringView{host}, n_workers, 0, timeout};
|
RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)};
|
||||||
auto fut = tracker.Run();
|
auto fut = tracker.Run();
|
||||||
|
|
||||||
std::vector<std::thread> workers;
|
std::vector<std::thread> workers;
|
||||||
@ -115,4 +126,15 @@ void TestDistributed(std::int32_t n_workers, WorkerFn worker_fn) {
|
|||||||
|
|
||||||
ASSERT_TRUE(fut.get().OK());
|
ASSERT_TRUE(fut.get().OK());
|
||||||
}
|
}
|
||||||
|
inline auto MakeDistributedTestConfig(std::string host, std::int32_t port,
|
||||||
|
std::chrono::seconds timeout, std::int32_t r) {
|
||||||
|
Json config{Object{}};
|
||||||
|
config["dmlc_communicator"] = std::string{"rabit"};
|
||||||
|
config["dmlc_tracker_uri"] = host;
|
||||||
|
config["dmlc_tracker_port"] = port;
|
||||||
|
config["dmlc_timeout_sec"] = static_cast<std::int64_t>(timeout.count());
|
||||||
|
config["dmlc_task_id"] = std::to_string(r);
|
||||||
|
config["dmlc_retry"] = 2;
|
||||||
|
return config;
|
||||||
|
}
|
||||||
} // namespace xgboost::collective
|
} // namespace xgboost::collective
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2023, XGBoost Contributors
|
* Copyright 2023-2024, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
@ -8,7 +8,6 @@
|
|||||||
|
|
||||||
#include "../../../../src/collective/tracker.h" // for GetHostAddress
|
#include "../../../../src/collective/tracker.h" // for GetHostAddress
|
||||||
#include "federated_tracker.h"
|
#include "federated_tracker.h"
|
||||||
#include "test_worker.h"
|
|
||||||
#include "xgboost/json.h" // for Json
|
#include "xgboost/json.h" // for Json
|
||||||
|
|
||||||
namespace xgboost::collective {
|
namespace xgboost::collective {
|
||||||
@ -26,7 +25,7 @@ TEST(FederatedTrackerTest, Basic) {
|
|||||||
ASSERT_GE(tracker->Port(), 1);
|
ASSERT_GE(tracker->Port(), 1);
|
||||||
std::string host;
|
std::string host;
|
||||||
auto rc = GetHostAddress(&host);
|
auto rc = GetHostAddress(&host);
|
||||||
ASSERT_EQ(get<String const>(args["DMLC_TRACKER_URI"]), host);
|
ASSERT_EQ(get<String const>(args["dmlc_tracker_uri"]), host);
|
||||||
|
|
||||||
rc = tracker->Shutdown();
|
rc = tracker->Shutdown();
|
||||||
ASSERT_TRUE(rc.OK());
|
ASSERT_TRUE(rc.OK());
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user