Build a simple event loop for collective. (#9593)
This commit is contained in:
@@ -16,8 +16,8 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "rabit/internal/utils.h"
|
||||
#include "rabit/serializable.h"
|
||||
#include "dmlc/io.h"
|
||||
#include "xgboost/logging.h"
|
||||
|
||||
namespace rabit::utils {
|
||||
/*! \brief re-use definition of dmlc::SeekStream */
|
||||
@@ -84,8 +84,7 @@ struct MemoryBufferStream : public SeekStream {
|
||||
}
|
||||
~MemoryBufferStream() override = default;
|
||||
size_t Read(void *ptr, size_t size) override {
|
||||
utils::Assert(curr_ptr_ <= p_buffer_->length(),
|
||||
"read can not have position excceed buffer length");
|
||||
CHECK_LE(curr_ptr_, p_buffer_->length()) << "read can not have position excceed buffer length";
|
||||
size_t nread = std::min(p_buffer_->length() - curr_ptr_, size);
|
||||
if (nread != 0) std::memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread);
|
||||
curr_ptr_ += nread;
|
||||
|
||||
@@ -29,11 +29,10 @@
|
||||
#include <chrono>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <system_error> // make_error_code, errc
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
#if !defined(_WIN32)
|
||||
|
||||
#include <sys/poll.h>
|
||||
@@ -93,6 +92,20 @@ int PollImpl(PollFD* pfd, int nfds, std::chrono::seconds timeout) noexcept(true)
|
||||
#endif // IS_MINGW()
|
||||
}
|
||||
|
||||
template <typename E>
|
||||
std::enable_if_t<std::is_integral_v<E>, xgboost::collective::Result> PollError(E const& revents) {
|
||||
if ((revents & POLLERR) != 0) {
|
||||
return xgboost::system::FailWithCode("Poll error condition.");
|
||||
}
|
||||
if ((revents & POLLNVAL) != 0) {
|
||||
return xgboost::system::FailWithCode("Invalid polling request.");
|
||||
}
|
||||
if ((revents & POLLHUP) != 0) {
|
||||
return xgboost::system::FailWithCode("Poll hung up.");
|
||||
}
|
||||
return xgboost::collective::Success();
|
||||
}
|
||||
|
||||
/*! \brief helper data structure to perform poll */
|
||||
struct PollHelper {
|
||||
public:
|
||||
@@ -160,25 +173,32 @@ struct PollHelper {
|
||||
*
|
||||
* @param timeout specify timeout in seconds. Block if negative.
|
||||
*/
|
||||
[[nodiscard]] xgboost::collective::Result Poll(std::chrono::seconds timeout) {
|
||||
[[nodiscard]] xgboost::collective::Result Poll(std::chrono::seconds timeout,
|
||||
bool check_error = true) {
|
||||
std::vector<pollfd> fdset;
|
||||
fdset.reserve(fds.size());
|
||||
for (auto kv : fds) {
|
||||
fdset.push_back(kv.second);
|
||||
}
|
||||
int ret = PollImpl(fdset.data(), fdset.size(), timeout);
|
||||
std::int32_t ret = PollImpl(fdset.data(), fdset.size(), timeout);
|
||||
if (ret == 0) {
|
||||
return xgboost::collective::Fail("Poll timeout.");
|
||||
return xgboost::collective::Fail("Poll timeout.", std::make_error_code(std::errc::timed_out));
|
||||
} else if (ret < 0) {
|
||||
return xgboost::system::FailWithCode("Poll failed.");
|
||||
} else {
|
||||
for (auto& pfd : fdset) {
|
||||
auto revents = pfd.revents & pfd.events;
|
||||
if (!revents) {
|
||||
fds.erase(pfd.fd);
|
||||
} else {
|
||||
fds[pfd.fd].events = revents;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& pfd : fdset) {
|
||||
auto result = PollError(pfd.revents);
|
||||
if (check_error && !result.OK()) {
|
||||
return result;
|
||||
}
|
||||
|
||||
auto revents = pfd.revents & pfd.events;
|
||||
if (!revents) {
|
||||
// FIXME(jiamingy): remove this once rabit is replaced.
|
||||
fds.erase(pfd.fd);
|
||||
} else {
|
||||
fds[pfd.fd].events = revents;
|
||||
}
|
||||
}
|
||||
return xgboost::collective::Success();
|
||||
|
||||
Reference in New Issue
Block a user