Utilities and cleanups for socket. (#9576)
- Use c++-17 nodiscard and nested ns. - Add bind method to socket. - Remove rabit parameters.
This commit is contained in:
@@ -11,9 +11,7 @@
|
||||
#include "../../plugin/federated/federated_communicator.h"
|
||||
#endif
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
namespace xgboost::collective {
|
||||
thread_local std::unique_ptr<Communicator> Communicator::communicator_{new NoOpCommunicator()};
|
||||
thread_local CommunicatorType Communicator::type_{};
|
||||
|
||||
@@ -57,6 +55,4 @@ void Communicator::Finalize() {
|
||||
communicator_.reset(new NoOpCommunicator());
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
*/
|
||||
#include "xgboost/collective/socket.h"
|
||||
|
||||
#include <array> // for array
|
||||
#include <cstddef> // std::size_t
|
||||
#include <cstdint> // std::int32_t
|
||||
#include <cstring> // std::memcpy, std::memset
|
||||
@@ -92,13 +93,18 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
|
||||
|
||||
conn = TCPSocket::Create(addr.Domain());
|
||||
CHECK_EQ(static_cast<std::int32_t>(conn.Domain()), static_cast<std::int32_t>(addr.Domain()));
|
||||
conn.SetNonBlock(true);
|
||||
auto non_blocking = conn.NonBlocking();
|
||||
auto rc = conn.NonBlocking(true);
|
||||
if (!rc.OK()) {
|
||||
return Fail("Failed to set socket option.", std::move(rc));
|
||||
}
|
||||
|
||||
Result last_error;
|
||||
auto log_failure = [&host, &last_error](Result err, char const *file, std::int32_t line) {
|
||||
auto log_failure = [&host, &last_error, port](Result err, char const *file, std::int32_t line) {
|
||||
last_error = std::move(err);
|
||||
LOG(WARNING) << std::filesystem::path{file}.filename().string() << "(" << line
|
||||
<< "): Failed to connect to:" << host << " Error:" << last_error.Report();
|
||||
<< "): Failed to connect to:" << host << ":" << port
|
||||
<< " Error:" << last_error.Report();
|
||||
};
|
||||
|
||||
for (std::int32_t attempt = 0; attempt < std::max(retry, 1); ++attempt) {
|
||||
@@ -138,12 +144,9 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
|
||||
continue;
|
||||
}
|
||||
|
||||
conn.SetNonBlock(false);
|
||||
return Success();
|
||||
|
||||
return conn.NonBlocking(non_blocking);
|
||||
} else {
|
||||
conn.SetNonBlock(false);
|
||||
return Success();
|
||||
return conn.NonBlocking(non_blocking);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -152,4 +155,13 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
|
||||
conn.Close();
|
||||
return Fail(ss.str(), std::move(last_error));
|
||||
}
|
||||
|
||||
[[nodiscard]] Result GetHostName(std::string *p_out) {
|
||||
std::array<char, HOST_NAME_MAX> buf;
|
||||
if (gethostname(&buf[0], HOST_NAME_MAX) != 0) {
|
||||
return system::FailWithCode("Failed to get host name.");
|
||||
}
|
||||
*p_out = buf.data();
|
||||
return Success();
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
Reference in New Issue
Block a user