Build a simple event loop for collective. (#9593)

This commit is contained in:
Jiaming Yuan
2023-09-20 02:09:07 +08:00
committed by GitHub
parent 259d80c0cf
commit 38ac52dd87
7 changed files with 402 additions and 47 deletions

View File

@@ -16,8 +16,8 @@
#include <string>
#include <vector>
#include "rabit/internal/utils.h"
#include "rabit/serializable.h"
#include "dmlc/io.h"
#include "xgboost/logging.h"
namespace rabit::utils {
/*! \brief re-use definition of dmlc::SeekStream */
@@ -84,8 +84,7 @@ struct MemoryBufferStream : public SeekStream {
}
~MemoryBufferStream() override = default;
size_t Read(void *ptr, size_t size) override {
utils::Assert(curr_ptr_ <= p_buffer_->length(),
"read can not have position excceed buffer length");
CHECK_LE(curr_ptr_, p_buffer_->length()) << "read can not have position excceed buffer length";
size_t nread = std::min(p_buffer_->length() - curr_ptr_, size);
if (nread != 0) std::memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread);
curr_ptr_ += nread;

View File

@@ -29,11 +29,10 @@
#include <chrono>
#include <cstring>
#include <string>
#include <system_error> // make_error_code, errc
#include <unordered_map>
#include <vector>
#include "utils.h"
#if !defined(_WIN32)
#include <sys/poll.h>
@@ -93,6 +92,20 @@ int PollImpl(PollFD* pfd, int nfds, std::chrono::seconds timeout) noexcept(true)
#endif // IS_MINGW()
}
template <typename E>
std::enable_if_t<std::is_integral_v<E>, xgboost::collective::Result> PollError(E const& revents) {
if ((revents & POLLERR) != 0) {
return xgboost::system::FailWithCode("Poll error condition.");
}
if ((revents & POLLNVAL) != 0) {
return xgboost::system::FailWithCode("Invalid polling request.");
}
if ((revents & POLLHUP) != 0) {
return xgboost::system::FailWithCode("Poll hung up.");
}
return xgboost::collective::Success();
}
/*! \brief helper data structure to perform poll */
struct PollHelper {
public:
@@ -160,25 +173,32 @@ struct PollHelper {
*
* @param timeout specify timeout in seconds. Block if negative.
*/
[[nodiscard]] xgboost::collective::Result Poll(std::chrono::seconds timeout) {
[[nodiscard]] xgboost::collective::Result Poll(std::chrono::seconds timeout,
bool check_error = true) {
std::vector<pollfd> fdset;
fdset.reserve(fds.size());
for (auto kv : fds) {
fdset.push_back(kv.second);
}
int ret = PollImpl(fdset.data(), fdset.size(), timeout);
std::int32_t ret = PollImpl(fdset.data(), fdset.size(), timeout);
if (ret == 0) {
return xgboost::collective::Fail("Poll timeout.");
return xgboost::collective::Fail("Poll timeout.", std::make_error_code(std::errc::timed_out));
} else if (ret < 0) {
return xgboost::system::FailWithCode("Poll failed.");
} else {
for (auto& pfd : fdset) {
auto revents = pfd.revents & pfd.events;
if (!revents) {
fds.erase(pfd.fd);
} else {
fds[pfd.fd].events = revents;
}
}
for (auto& pfd : fdset) {
auto result = PollError(pfd.revents);
if (check_error && !result.OK()) {
return result;
}
auto revents = pfd.revents & pfd.events;
if (!revents) {
// FIXME(jiamingy): remove this once rabit is replaced.
fds.erase(pfd.fd);
} else {
fds[pfd.fd].events = revents;
}
}
return xgboost::collective::Success();

View File

@@ -721,12 +721,11 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
}
finished = false;
}
watcher.WatchException(links[i].sock);
}
// finish running
if (finished) break;
// select
auto poll_res = watcher.Poll(timeout_sec);
auto poll_res = watcher.Poll(timeout_sec, false); // fail on macos
if (!poll_res.OK()) {
LOG(FATAL) << poll_res.Report();
}