diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index afc1f47fd..9bce616ef 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -1554,29 +1554,19 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *config, * \param config JSON encoded configuration. Accepted JSON keys are: * - xgboost_communicator: The type of the communicator. Can be set as an environment variable. * * rabit: Use Rabit. This is the default if the type is unspecified. - * * mpi: Use MPI. * * federated: Use the gRPC interface for Federated Learning. * Only applicable to the Rabit communicator (these are case-sensitive): * - rabit_tracker_uri: Hostname of the tracker. * - rabit_tracker_port: Port number of the tracker. * - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment. * - rabit_world_size: Total number of workers. - * - rabit_hadoop_mode: Enable Hadoop support. - * - rabit_tree_reduce_minsize: Minimal size for tree reduce. - * - rabit_reduce_ring_mincount: Minimal count to perform ring reduce. - * - rabit_reduce_buffer: Size of the reduce buffer. - * - rabit_bootstrap_cache: Size of the bootstrap cache. - * - rabit_debug: Enable debugging. * - rabit_timeout: Enable timeout. * - rabit_timeout_sec: Timeout in seconds. - * - rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms. * Only applicable to the Rabit communicator (these are case-sensitive, and can be set as * environment variables): * - DMLC_TRACKER_URI: Hostname of the tracker. * - DMLC_TRACKER_PORT: Port number of the tracker. * - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment. - * - DMLC_ROLE: Role of the current task, "worker" or "server". - * - DMLC_NUM_ATTEMPT: Number of attempts after task failure. * - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker. * Only applicable to the Federated communicator (use upper case for environment variables, use * lower case for runtime configuration): diff --git a/include/xgboost/collective/socket.h b/include/xgboost/collective/socket.h index 5bff2204e..f36cdccb2 100644 --- a/include/xgboost/collective/socket.h +++ b/include/xgboost/collective/socket.h @@ -215,9 +215,9 @@ class SockAddrV4 { static SockAddrV4 Loopback(); static SockAddrV4 InaddrAny(); - in_port_t Port() const { return ntohs(addr_.sin_port); } + [[nodiscard]] in_port_t Port() const { return ntohs(addr_.sin_port); } - std::string Addr() const { + [[nodiscard]] std::string Addr() const { char buf[INET_ADDRSTRLEN]; auto const *s = system::inet_ntop(static_cast(SockDomain::kV4), &addr_.sin_addr, buf, INET_ADDRSTRLEN); @@ -226,7 +226,7 @@ class SockAddrV4 { } return {buf}; } - sockaddr_in const &Handle() const { return addr_; } + [[nodiscard]] sockaddr_in const &Handle() const { return addr_; } }; /** @@ -243,13 +243,13 @@ class SockAddress { explicit SockAddress(SockAddrV6 const &addr) : v6_{addr}, domain_{SockDomain::kV6} {} explicit SockAddress(SockAddrV4 const &addr) : v4_{addr} {} - auto Domain() const { return domain_; } + [[nodiscard]] auto Domain() const { return domain_; } - bool IsV4() const { return Domain() == SockDomain::kV4; } - bool IsV6() const { return !IsV4(); } + [[nodiscard]] bool IsV4() const { return Domain() == SockDomain::kV4; } + [[nodiscard]] bool IsV6() const { return !IsV4(); } - auto const &V4() const { return v4_; } - auto const &V6() const { return v6_; } + [[nodiscard]] auto const &V4() const { return v4_; } + [[nodiscard]] auto const &V6() const { return v6_; } }; /** @@ -261,6 +261,7 @@ class TCPSocket { private: HandleT handle_{InvalidSocket()}; + bool non_blocking_{false}; // There's reliable no way to extract domain from a socket without first binding that // socket on macos. #if defined(__APPLE__) @@ -276,7 +277,7 @@ class TCPSocket { /** * \brief Return the socket domain. */ - auto Domain() const -> SockDomain { + [[nodiscard]] auto Domain() const -> SockDomain { auto ret_iafamily = [](std::int32_t domain) { switch (domain) { case AF_INET: @@ -321,10 +322,10 @@ class TCPSocket { #endif // platforms } - bool IsClosed() const { return handle_ == InvalidSocket(); } + [[nodiscard]] bool IsClosed() const { return handle_ == InvalidSocket(); } - /** \brief get last error code if any */ - Result GetSockError() const { + /** @brief get last error code if any */ + [[nodiscard]] Result GetSockError() const { std::int32_t optval = 0; socklen_t len = sizeof(optval); auto ret = getsockopt(handle_, SOL_SOCKET, SO_ERROR, reinterpret_cast(&optval), &len); @@ -340,7 +341,7 @@ class TCPSocket { } /** \brief check if anything bad happens */ - bool BadSocket() const { + [[nodiscard]] bool BadSocket() const { if (IsClosed()) { return true; } @@ -352,24 +353,56 @@ class TCPSocket { return false; } - void SetNonBlock(bool non_block) { + [[nodiscard]] Result NonBlocking(bool non_block) { #if defined(_WIN32) u_long mode = non_block ? 1 : 0; - xgboost_CHECK_SYS_CALL(ioctlsocket(handle_, FIONBIO, &mode), NO_ERROR); + if (ioctlsocket(handle_, FIONBIO, &mode) != NO_ERROR) { + return system::FailWithCode("Failed to set socket to non-blocking."); + } #else std::int32_t flag = fcntl(handle_, F_GETFL, 0); - if (flag == -1) { - system::ThrowAtError("fcntl"); + auto rc = flag; + if (rc == -1) { + return system::FailWithCode("Failed to get socket flag."); } if (non_block) { flag |= O_NONBLOCK; } else { flag &= ~O_NONBLOCK; } - if (fcntl(handle_, F_SETFL, flag) == -1) { - system::ThrowAtError("fcntl"); + rc = fcntl(handle_, F_SETFL, flag); + if (rc == -1) { + return system::FailWithCode("Failed to set socket to non-blocking."); } #endif // _WIN32 + non_blocking_ = non_block; + return Success(); + } + [[nodiscard]] bool NonBlocking() const { return non_blocking_; } + [[nodiscard]] Result RecvTimeout(std::chrono::seconds timeout) { + timeval tv; + tv.tv_sec = timeout.count(); + tv.tv_usec = 0; + auto rc = setsockopt(Handle(), SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&tv), + sizeof(tv)); + if (rc != 0) { + return system::FailWithCode("Failed to set timeout on recv."); + } + return Success(); + } + + [[nodiscard]] Result SetBufSize(std::int32_t n_bytes) { + auto rc = setsockopt(this->Handle(), SOL_SOCKET, SO_SNDBUF, reinterpret_cast(&n_bytes), + sizeof(n_bytes)); + if (rc != 0) { + return system::FailWithCode("Failed to set send buffer size."); + } + rc = setsockopt(this->Handle(), SOL_SOCKET, SO_RCVBUF, reinterpret_cast(&n_bytes), + sizeof(n_bytes)); + if (rc != 0) { + return system::FailWithCode("Failed to set recv buffer size."); + } + return Success(); } void SetKeepAlive() { @@ -391,7 +424,7 @@ class TCPSocket { * \brief Accept new connection, returns a new TCP socket for the new connection. */ TCPSocket Accept() { - HandleT newfd = accept(handle_, nullptr, nullptr); + HandleT newfd = accept(Handle(), nullptr, nullptr); if (newfd == InvalidSocket()) { system::ThrowAtError("accept"); } @@ -399,6 +432,18 @@ class TCPSocket { return newsock; } + [[nodiscard]] Result Accept(TCPSocket *out, SockAddrV4 *addr) { + struct sockaddr_in caddr; + socklen_t caddr_len = sizeof(caddr); + HandleT newfd = accept(Handle(), reinterpret_cast(&caddr), &caddr_len); + if (newfd == InvalidSocket()) { + return system::FailWithCode("Failed to accept."); + } + *addr = SockAddrV4{caddr}; + *out = TCPSocket{newfd}; + return Success(); + } + ~TCPSocket() { if (!IsClosed()) { Close(); @@ -413,9 +458,9 @@ class TCPSocket { return *this; } /** - * \brief Return the native socket file descriptor. + * @brief Return the native socket file descriptor. */ - HandleT const &Handle() const { return handle_; } + [[nodiscard]] HandleT const &Handle() const { return handle_; } /** * \brief Listen to incoming requests. Should be called after bind. */ @@ -448,6 +493,49 @@ class TCPSocket { return ntohs(res_addr.sin_port); } } + + [[nodiscard]] auto Port() const { + if (this->Domain() == SockDomain::kV4) { + sockaddr_in res_addr; + socklen_t addrlen = sizeof(res_addr); + auto code = getsockname(handle_, reinterpret_cast(&res_addr), &addrlen); + if (code != 0) { + return std::make_pair(system::FailWithCode("getsockname"), std::int32_t{0}); + } + return std::make_pair(Success(), std::int32_t{ntohs(res_addr.sin_port)}); + } else { + sockaddr_in6 res_addr; + socklen_t addrlen = sizeof(res_addr); + auto code = getsockname(handle_, reinterpret_cast(&res_addr), &addrlen); + if (code != 0) { + return std::make_pair(system::FailWithCode("getsockname"), std::int32_t{0}); + } + return std::make_pair(Success(), std::int32_t{ntohs(res_addr.sin6_port)}); + } + } + + [[nodiscard]] Result Bind(StringView ip, std::int32_t *port) { + // bind socket handle_ to ip + auto addr = MakeSockAddress(ip, 0); + std::int32_t errc{0}; + if (addr.IsV4()) { + auto handle = reinterpret_cast(&addr.V4().Handle()); + errc = bind(handle_, handle, sizeof(std::remove_reference_t)); + } else { + auto handle = reinterpret_cast(&addr.V6().Handle()); + errc = bind(handle_, handle, sizeof(std::remove_reference_t)); + } + if (errc != 0) { + return system::FailWithCode("Failed to bind socket."); + } + auto [rc, new_port] = this->Port(); + if (!rc.OK()) { + return std::move(rc); + } + *port = new_port; + return Success(); + } + /** * \brief Send data, without error then all data should be sent. */ @@ -567,13 +655,9 @@ class TCPSocket { xgboost::collective::TCPSocket *out_conn); /** - * \brief Get the local host name. + * @brief Get the local host name. */ -inline std::string GetHostName() { - char buf[HOST_NAME_MAX]; - xgboost_CHECK_SYS_CALL(gethostname(&buf[0], HOST_NAME_MAX), 0); - return buf; -} +[[nodiscard]] Result GetHostName(std::string *p_out); } // namespace collective } // namespace xgboost diff --git a/include/xgboost/json.h b/include/xgboost/json.h index cb22e120e..b099d1c47 100644 --- a/include/xgboost/json.h +++ b/include/xgboost/json.h @@ -372,6 +372,19 @@ class Json { /*! \brief Use your own JsonWriter. */ static void Dump(Json json, JsonWriter* writer); + template + static Container Dump(Json json) { + if constexpr (std::is_same_v) { + std::string str; + Dump(json, &str); + return str; + } else { + std::vector str; + Dump(json, &str); + return str; + } + } + Json() = default; // number diff --git a/include/xgboost/string_view.h b/include/xgboost/string_view.h index 8b5bff7f6..ba0d9f368 100644 --- a/include/xgboost/string_view.h +++ b/include/xgboost/string_view.h @@ -29,7 +29,7 @@ struct StringView { public: constexpr StringView() = default; constexpr StringView(CharT const* str, std::size_t size) : str_{str}, size_{size} {} - explicit StringView(std::string const& str) : str_{str.c_str()}, size_{str.size()} {} + StringView(std::string const& str) : str_{str.c_str()}, size_{str.size()} {} // NOLINT constexpr StringView(CharT const* str) // NOLINT : str_{str}, size_{str == nullptr ? 0ul : Traits::length(str)} {} diff --git a/plugin/federated/federated_client.h b/plugin/federated/federated_client.h index 2b4637339..d104cb231 100644 --- a/plugin/federated/federated_client.h +++ b/plugin/federated/federated_client.h @@ -11,9 +11,7 @@ #include #include -namespace xgboost { -namespace federated { - +namespace xgboost::federated { /** * @brief A wrapper around the gRPC client. */ @@ -112,6 +110,4 @@ class FederatedClient { int const rank_; uint64_t sequence_number_{}; }; - -} // namespace federated -} // namespace xgboost +} // namespace xgboost::federated diff --git a/plugin/federated/federated_communicator.h b/plugin/federated/federated_communicator.h index 7acd8a829..996b433cb 100644 --- a/plugin/federated/federated_communicator.h +++ b/plugin/federated/federated_communicator.h @@ -9,9 +9,7 @@ #include "../../src/common/io.h" #include "federated_client.h" -namespace xgboost { -namespace collective { - +namespace xgboost::collective { /** * @brief A Federated Learning communicator class that handles collective communication. */ @@ -118,13 +116,13 @@ class FederatedCommunicator : public Communicator { * \brief Get if the communicator is distributed. * \return True. */ - bool IsDistributed() const override { return true; } + [[nodiscard]] bool IsDistributed() const override { return true; } /** * \brief Get if the communicator is federated. * \return True. */ - bool IsFederated() const override { return true; } + [[nodiscard]] bool IsFederated() const override { return true; } /** * \brief Perform in-place allgather. @@ -189,5 +187,4 @@ class FederatedCommunicator : public Communicator { private: std::unique_ptr client_{}; }; -} // namespace collective -} // namespace xgboost +} // namespace xgboost::collective diff --git a/plugin/federated/federated_server.cc b/plugin/federated/federated_server.cc index c50bf1f35..ae42f6d28 100644 --- a/plugin/federated/federated_server.cc +++ b/plugin/federated/federated_server.cc @@ -11,9 +11,7 @@ #include "../../src/common/io.h" -namespace xgboost { -namespace federated { - +namespace xgboost::federated { grpc::Status FederatedService::Allgather(grpc::ServerContext*, AllgatherRequest const* request, AllgatherReply* reply) { handler_.Allgather(request->send_buffer().data(), request->send_buffer().size(), @@ -75,6 +73,4 @@ void RunInsecureServer(int port, int world_size) { server->Wait(); } - -} // namespace federated -} // namespace xgboost +} // namespace xgboost::federated diff --git a/rabit/src/allreduce_base.cc b/rabit/src/allreduce_base.cc index bd48d3599..6480adf03 100644 --- a/rabit/src/allreduce_base.cc +++ b/rabit/src/allreduce_base.cc @@ -115,9 +115,12 @@ bool AllreduceBase::Init(int argc, char* argv[]) { // start socket xgboost::system::SocketStartup(); utils::Assert(all_links.size() == 0, "can only call Init once"); - this->host_uri = xgboost::collective::GetHostName(); + auto rc = xgboost::collective::GetHostName(&this->host_uri); + if (!rc.OK()) { + LOG(FATAL) << rc.Report(); + } // get information from tracker - auto rc = this->ReConnectLinks(); + rc = this->ReConnectLinks(); if (rc.OK()) { return true; } @@ -406,13 +409,14 @@ void AllreduceBase::SetParam(const char *name, const char *val) { if (!match) all_links.emplace_back(std::move(r)); } sock_listen.Close(); + this->parent_index = -1; // setup tree links and ring structure tree_links.plinks.clear(); for (auto &all_link : all_links) { utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket"); // set the socket to non-blocking mode, enable TCP keepalive - all_link.sock.SetNonBlock(true); + CHECK(all_link.sock.NonBlocking(true).OK()); all_link.sock.SetKeepAlive(); if (rabit_enable_tcp_no_delay) { all_link.sock.SetNoDelay(); diff --git a/src/collective/communicator.cc b/src/collective/communicator.cc index e4c491c2b..6ac9ff58e 100644 --- a/src/collective/communicator.cc +++ b/src/collective/communicator.cc @@ -11,9 +11,7 @@ #include "../../plugin/federated/federated_communicator.h" #endif -namespace xgboost { -namespace collective { - +namespace xgboost::collective { thread_local std::unique_ptr Communicator::communicator_{new NoOpCommunicator()}; thread_local CommunicatorType Communicator::type_{}; @@ -57,6 +55,4 @@ void Communicator::Finalize() { communicator_.reset(new NoOpCommunicator()); } #endif - -} // namespace collective -} // namespace xgboost +} // namespace xgboost::collective diff --git a/src/collective/socket.cc b/src/collective/socket.cc index 78dc3d79b..8ca936ff3 100644 --- a/src/collective/socket.cc +++ b/src/collective/socket.cc @@ -3,6 +3,7 @@ */ #include "xgboost/collective/socket.h" +#include // for array #include // std::size_t #include // std::int32_t #include // std::memcpy, std::memset @@ -92,13 +93,18 @@ std::size_t TCPSocket::Recv(std::string *p_str) { conn = TCPSocket::Create(addr.Domain()); CHECK_EQ(static_cast(conn.Domain()), static_cast(addr.Domain())); - conn.SetNonBlock(true); + auto non_blocking = conn.NonBlocking(); + auto rc = conn.NonBlocking(true); + if (!rc.OK()) { + return Fail("Failed to set socket option.", std::move(rc)); + } Result last_error; - auto log_failure = [&host, &last_error](Result err, char const *file, std::int32_t line) { + auto log_failure = [&host, &last_error, port](Result err, char const *file, std::int32_t line) { last_error = std::move(err); LOG(WARNING) << std::filesystem::path{file}.filename().string() << "(" << line - << "): Failed to connect to:" << host << " Error:" << last_error.Report(); + << "): Failed to connect to:" << host << ":" << port + << " Error:" << last_error.Report(); }; for (std::int32_t attempt = 0; attempt < std::max(retry, 1); ++attempt) { @@ -138,12 +144,9 @@ std::size_t TCPSocket::Recv(std::string *p_str) { continue; } - conn.SetNonBlock(false); - return Success(); - + return conn.NonBlocking(non_blocking); } else { - conn.SetNonBlock(false); - return Success(); + return conn.NonBlocking(non_blocking); } } @@ -152,4 +155,13 @@ std::size_t TCPSocket::Recv(std::string *p_str) { conn.Close(); return Fail(ss.str(), std::move(last_error)); } + +[[nodiscard]] Result GetHostName(std::string *p_out) { + std::array buf; + if (gethostname(&buf[0], HOST_NAME_MAX) != 0) { + return system::FailWithCode("Failed to get host name."); + } + *p_out = buf.data(); + return Success(); +} } // namespace xgboost::collective diff --git a/tests/cpp/collective/test_socket.cc b/tests/cpp/collective/test_socket.cc index ddc73d1f2..07a7f52d0 100644 --- a/tests/cpp/collective/test_socket.cc +++ b/tests/cpp/collective/test_socket.cc @@ -73,4 +73,15 @@ TEST(Socket, Basic) { system::SocketFinalize(); } + +TEST(Socket, Bind) { + system::SocketStartup(); + auto any = SockAddrV4::InaddrAny().Addr(); + auto sock = TCPSocket::Create(SockDomain::kV4); + std::int32_t port{0}; + auto rc = sock.Bind(any, &port); + ASSERT_TRUE(rc.OK()); + ASSERT_NE(port, 0); + system::SocketFinalize(); +} } // namespace xgboost::collective diff --git a/tests/cpp/common/test_json.cc b/tests/cpp/common/test_json.cc index 4d498ffd5..1d1319274 100644 --- a/tests/cpp/common/test_json.cc +++ b/tests/cpp/common/test_json.cc @@ -4,6 +4,7 @@ #include #include +#include // for back_inserter #include #include "../../../src/common/charconv.h" @@ -691,4 +692,16 @@ TEST(Json, TypeCheck) { ASSERT_NE(err.find("foo"), std::string::npos); } } + +TEST(Json, Dump) { + auto str = GetModelStr(); + auto jobj = Json::Load(str); + std::string result_s = Json::Dump(jobj); + + std::vector result_v = Json::Dump>(jobj); + ASSERT_EQ(result_s.size(), result_v.size()); + for (std::size_t i = 0; i < result_s.size(); ++i) { + ASSERT_EQ(result_s[i], result_v[i]); + } +} } // namespace xgboost