Utilities and cleanups for socket. (#9576)

- Use c++-17 nodiscard and nested ns.
- Add bind method to socket.
- Remove rabit parameters.
This commit is contained in:
Jiaming Yuan
2023-09-14 01:41:42 +08:00
committed by GitHub
parent 5abe50ff8c
commit b438d684d2
12 changed files with 187 additions and 75 deletions

View File

@@ -11,9 +11,7 @@
#include "../../plugin/federated/federated_communicator.h"
#endif
namespace xgboost {
namespace collective {
namespace xgboost::collective {
thread_local std::unique_ptr<Communicator> Communicator::communicator_{new NoOpCommunicator()};
thread_local CommunicatorType Communicator::type_{};
@@ -57,6 +55,4 @@ void Communicator::Finalize() {
communicator_.reset(new NoOpCommunicator());
}
#endif
} // namespace collective
} // namespace xgboost
} // namespace xgboost::collective

View File

@@ -3,6 +3,7 @@
*/
#include "xgboost/collective/socket.h"
#include <array> // for array
#include <cstddef> // std::size_t
#include <cstdint> // std::int32_t
#include <cstring> // std::memcpy, std::memset
@@ -92,13 +93,18 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
conn = TCPSocket::Create(addr.Domain());
CHECK_EQ(static_cast<std::int32_t>(conn.Domain()), static_cast<std::int32_t>(addr.Domain()));
conn.SetNonBlock(true);
auto non_blocking = conn.NonBlocking();
auto rc = conn.NonBlocking(true);
if (!rc.OK()) {
return Fail("Failed to set socket option.", std::move(rc));
}
Result last_error;
auto log_failure = [&host, &last_error](Result err, char const *file, std::int32_t line) {
auto log_failure = [&host, &last_error, port](Result err, char const *file, std::int32_t line) {
last_error = std::move(err);
LOG(WARNING) << std::filesystem::path{file}.filename().string() << "(" << line
<< "): Failed to connect to:" << host << " Error:" << last_error.Report();
<< "): Failed to connect to:" << host << ":" << port
<< " Error:" << last_error.Report();
};
for (std::int32_t attempt = 0; attempt < std::max(retry, 1); ++attempt) {
@@ -138,12 +144,9 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
continue;
}
conn.SetNonBlock(false);
return Success();
return conn.NonBlocking(non_blocking);
} else {
conn.SetNonBlock(false);
return Success();
return conn.NonBlocking(non_blocking);
}
}
@@ -152,4 +155,13 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
conn.Close();
return Fail(ss.str(), std::move(last_error));
}
[[nodiscard]] Result GetHostName(std::string *p_out) {
std::array<char, HOST_NAME_MAX> buf;
if (gethostname(&buf[0], HOST_NAME_MAX) != 0) {
return system::FailWithCode("Failed to get host name.");
}
*p_out = buf.data();
return Success();
}
} // namespace xgboost::collective