[coll] Reduce the number of open files (sockets). (#10693)
Reduce the chance of hitting the error ``Failed to call `socket`: Too many open files``.
This commit is contained in:
parent
d414fdf2e7
commit
43704549a2
@ -141,7 +141,7 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
|
|||||||
|
|
||||||
for (std::int32_t r = (comm.Rank() + 1); r < comm.World(); ++r) {
|
for (std::int32_t r = (comm.Rank() + 1); r < comm.World(); ++r) {
|
||||||
auto const& peer = peers[r];
|
auto const& peer = peers[r];
|
||||||
std::shared_ptr<TCPSocket> worker{TCPSocket::CreatePtr(comm.Domain())};
|
auto worker = std::make_shared<TCPSocket>();
|
||||||
rc = std::move(rc)
|
rc = std::move(rc)
|
||||||
<< [&] { return Connect(peer.host, peer.port, retry, timeout, worker.get()); }
|
<< [&] { return Connect(peer.host, peer.port, retry, timeout, worker.get()); }
|
||||||
<< [&] { return worker->RecvTimeout(timeout); };
|
<< [&] { return worker->RecvTimeout(timeout); };
|
||||||
@ -161,7 +161,7 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (std::int32_t r = 0; r < comm.Rank(); ++r) {
|
for (std::int32_t r = 0; r < comm.Rank(); ++r) {
|
||||||
auto peer = std::shared_ptr<TCPSocket>(TCPSocket::CreatePtr(comm.Domain()));
|
auto peer = std::make_shared<TCPSocket>();
|
||||||
rc = std::move(rc) << [&] {
|
rc = std::move(rc) << [&] {
|
||||||
SockAddress addr;
|
SockAddress addr;
|
||||||
return listener->Accept(peer.get(), &addr);
|
return listener->Accept(peer.get(), &addr);
|
||||||
|
|||||||
@ -118,7 +118,9 @@ std::size_t TCPSocket::Send(StringView str) {
|
|||||||
addr_len = sizeof(addr.V6().Handle());
|
addr_len = sizeof(addr.V6().Handle());
|
||||||
}
|
}
|
||||||
|
|
||||||
conn = TCPSocket::Create(addr.Domain());
|
if (conn.IsClosed()) {
|
||||||
|
conn = TCPSocket::Create(addr.Domain());
|
||||||
|
}
|
||||||
CHECK_EQ(static_cast<std::int32_t>(conn.Domain()), static_cast<std::int32_t>(addr.Domain()));
|
CHECK_EQ(static_cast<std::int32_t>(conn.Domain()), static_cast<std::int32_t>(addr.Domain()));
|
||||||
auto non_blocking = conn.NonBlocking();
|
auto non_blocking = conn.NonBlocking();
|
||||||
auto rc = conn.NonBlocking(true);
|
auto rc = conn.NonBlocking(true);
|
||||||
|
|||||||
@ -7,6 +7,7 @@
|
|||||||
#include <chrono> // for ""s
|
#include <chrono> // for ""s
|
||||||
#include <thread> // for thread
|
#include <thread> // for thread
|
||||||
|
|
||||||
|
#include "../../../src/collective/allgather.h" // for RingAllgather
|
||||||
#include "../../../src/collective/tracker.h"
|
#include "../../../src/collective/tracker.h"
|
||||||
#include "test_worker.h" // for SocketTest
|
#include "test_worker.h" // for SocketTest
|
||||||
#include "xgboost/json.h" // for Json
|
#include "xgboost/json.h" // for Json
|
||||||
@ -19,8 +20,9 @@ class TrackerAPITest : public SocketTest {};
|
|||||||
TEST_F(TrackerAPITest, CAPI) {
|
TEST_F(TrackerAPITest, CAPI) {
|
||||||
TrackerHandle handle;
|
TrackerHandle handle;
|
||||||
Json config{Object{}};
|
Json config{Object{}};
|
||||||
|
std::int32_t n_workers{2};
|
||||||
config["dmlc_communicator"] = String{"rabit"};
|
config["dmlc_communicator"] = String{"rabit"};
|
||||||
config["n_workers"] = 2;
|
config["n_workers"] = n_workers;
|
||||||
config["timeout"] = 1;
|
config["timeout"] = 1;
|
||||||
auto config_str = Json::Dump(config);
|
auto config_str = Json::Dump(config);
|
||||||
auto rc = XGTrackerCreate(config_str.c_str(), &handle);
|
auto rc = XGTrackerCreate(config_str.c_str(), &handle);
|
||||||
@ -47,9 +49,21 @@ TEST_F(TrackerAPITest, CAPI) {
|
|||||||
ASSERT_NE(port, 0);
|
ASSERT_NE(port, 0);
|
||||||
|
|
||||||
std::vector<std::thread> workers;
|
std::vector<std::thread> workers;
|
||||||
using namespace std::chrono_literals; // NOLINT
|
using std::chrono_literals::operator""s;
|
||||||
for (std::int32_t r = 0; r < 2; ++r) {
|
for (std::int32_t r = 0; r < n_workers; ++r) {
|
||||||
workers.emplace_back([=] { WorkerForTest w{host, static_cast<std::int32_t>(port), 1s, 2, r}; });
|
workers.emplace_back([=] {
|
||||||
|
WorkerForTest w{host, static_cast<std::int32_t>(port), 8s, n_workers, r};
|
||||||
|
// basic test
|
||||||
|
std::vector<std::int32_t> data(w.Comm().World(), 0);
|
||||||
|
data[w.Comm().Rank()] = w.Comm().Rank();
|
||||||
|
|
||||||
|
auto rc = RingAllgather(w.Comm(), common::Span{data.data(), data.size()});
|
||||||
|
SafeColl(rc);
|
||||||
|
|
||||||
|
for (std::int32_t r = 0; r < w.Comm().World(); ++r) {
|
||||||
|
ASSERT_EQ(data[r], r);
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
for (auto& w : workers) {
|
for (auto& w : workers) {
|
||||||
w.join();
|
w.join();
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user