diff --git a/src/collective/comm.cc b/src/collective/comm.cc index 565443fbc..32631442b 100644 --- a/src/collective/comm.cc +++ b/src/collective/comm.cc @@ -141,7 +141,7 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st for (std::int32_t r = (comm.Rank() + 1); r < comm.World(); ++r) { auto const& peer = peers[r]; - std::shared_ptr worker{TCPSocket::CreatePtr(comm.Domain())}; + auto worker = std::make_shared(); rc = std::move(rc) << [&] { return Connect(peer.host, peer.port, retry, timeout, worker.get()); } << [&] { return worker->RecvTimeout(timeout); }; @@ -161,7 +161,7 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st } for (std::int32_t r = 0; r < comm.Rank(); ++r) { - auto peer = std::shared_ptr(TCPSocket::CreatePtr(comm.Domain())); + auto peer = std::make_shared(); rc = std::move(rc) << [&] { SockAddress addr; return listener->Accept(peer.get(), &addr); diff --git a/src/collective/socket.cc b/src/collective/socket.cc index c8629069f..aedddbcfb 100644 --- a/src/collective/socket.cc +++ b/src/collective/socket.cc @@ -118,7 +118,9 @@ std::size_t TCPSocket::Send(StringView str) { addr_len = sizeof(addr.V6().Handle()); } - conn = TCPSocket::Create(addr.Domain()); + if (conn.IsClosed()) { + conn = TCPSocket::Create(addr.Domain()); + } CHECK_EQ(static_cast(conn.Domain()), static_cast(addr.Domain())); auto non_blocking = conn.NonBlocking(); auto rc = conn.NonBlocking(true); diff --git a/tests/cpp/collective/test_coll_c_api.cc b/tests/cpp/collective/test_coll_c_api.cc index d73191361..c24ddd275 100644 --- a/tests/cpp/collective/test_coll_c_api.cc +++ b/tests/cpp/collective/test_coll_c_api.cc @@ -7,6 +7,7 @@ #include // for ""s #include // for thread +#include "../../../src/collective/allgather.h" // for RingAllgather #include "../../../src/collective/tracker.h" #include "test_worker.h" // for SocketTest #include "xgboost/json.h" // for Json @@ -19,8 +20,9 @@ class TrackerAPITest : public SocketTest {}; TEST_F(TrackerAPITest, CAPI) { TrackerHandle handle; Json config{Object{}}; + std::int32_t n_workers{2}; config["dmlc_communicator"] = String{"rabit"}; - config["n_workers"] = 2; + config["n_workers"] = n_workers; config["timeout"] = 1; auto config_str = Json::Dump(config); auto rc = XGTrackerCreate(config_str.c_str(), &handle); @@ -47,9 +49,21 @@ TEST_F(TrackerAPITest, CAPI) { ASSERT_NE(port, 0); std::vector workers; - using namespace std::chrono_literals; // NOLINT - for (std::int32_t r = 0; r < 2; ++r) { - workers.emplace_back([=] { WorkerForTest w{host, static_cast(port), 1s, 2, r}; }); + using std::chrono_literals::operator""s; + for (std::int32_t r = 0; r < n_workers; ++r) { + workers.emplace_back([=] { + WorkerForTest w{host, static_cast(port), 8s, n_workers, r}; + // basic test + std::vector data(w.Comm().World(), 0); + data[w.Comm().Rank()] = w.Comm().Rank(); + + auto rc = RingAllgather(w.Comm(), common::Span{data.data(), data.size()}); + SafeColl(rc); + + for (std::int32_t r = 0; r < w.Comm().World(); ++r) { + ASSERT_EQ(data[r], r); + } + }); } for (auto& w : workers) { w.join();