[rabit] Improved connection handling. (#9531)
- Enable timeout. - Report connection error from the system. - Handle retry for both tracker connection and peer connection.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright (c) 2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2022-2023 by XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/collective/socket.h>
|
||||
@@ -10,8 +10,7 @@
|
||||
|
||||
#include "../helpers.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
namespace xgboost::collective {
|
||||
TEST(Socket, Basic) {
|
||||
system::SocketStartup();
|
||||
|
||||
@@ -31,15 +30,16 @@ TEST(Socket, Basic) {
|
||||
TCPSocket client;
|
||||
if (domain == SockDomain::kV4) {
|
||||
auto const& addr = SockAddrV4::Loopback().Addr();
|
||||
ASSERT_EQ(Connect(MakeSockAddress(StringView{addr}, port), &client), std::errc{});
|
||||
auto rc = Connect(StringView{addr}, port, 1, std::chrono::seconds{3}, &client);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
} else {
|
||||
auto const& addr = SockAddrV6::Loopback().Addr();
|
||||
auto rc = Connect(MakeSockAddress(StringView{addr}, port), &client);
|
||||
auto rc = Connect(StringView{addr}, port, 1, std::chrono::seconds{3}, &client);
|
||||
// some environment (docker) has restricted network configuration.
|
||||
if (rc == std::error_code{EADDRNOTAVAIL, std::system_category()}) {
|
||||
if (!rc.OK() && rc.Code() == std::error_code{EADDRNOTAVAIL, std::system_category()}) {
|
||||
GTEST_SKIP_(msg.c_str());
|
||||
}
|
||||
ASSERT_EQ(rc, std::errc{});
|
||||
ASSERT_EQ(rc, Success()) << rc.Report();
|
||||
}
|
||||
ASSERT_EQ(client.Domain(), domain);
|
||||
|
||||
@@ -73,5 +73,4 @@ TEST(Socket, Basic) {
|
||||
|
||||
system::SocketFinalize();
|
||||
}
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -20,6 +20,18 @@ def test_rabit_tracker():
|
||||
assert str(ret) == "test1234"
|
||||
|
||||
|
||||
@pytest.mark.skipif(**tm.not_linux())
|
||||
def test_socket_error():
|
||||
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1)
|
||||
tracker.start(1)
|
||||
env = tracker.worker_envs()
|
||||
env["DMLC_TRACKER_PORT"] = 0
|
||||
env["DMLC_WORKER_CONNECT_RETRY"] = 1
|
||||
with pytest.raises(ValueError, match="127.0.0.1:0\n.*refused"):
|
||||
with xgb.collective.CommunicatorContext(**env):
|
||||
pass
|
||||
|
||||
|
||||
def run_rabit_ops(client, n_workers):
|
||||
from xgboost.dask import CommunicatorContext, _get_dask_config, _get_rabit_args
|
||||
|
||||
@@ -58,6 +70,32 @@ def test_rabit_ops():
|
||||
run_rabit_ops(client, n_workers)
|
||||
|
||||
|
||||
def run_broadcast(client):
|
||||
from xgboost.dask import _get_dask_config, _get_rabit_args
|
||||
|
||||
workers = tm.get_client_workers(client)
|
||||
rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client)
|
||||
|
||||
def local_test(worker_id):
|
||||
with collective.CommunicatorContext(**rabit_args):
|
||||
res = collective.broadcast(17, 0)
|
||||
return res
|
||||
|
||||
futures = client.map(local_test, range(len(workers)), workers=workers)
|
||||
results = client.gather(futures)
|
||||
np.testing.assert_allclose(np.array(results), 17)
|
||||
|
||||
|
||||
@pytest.mark.skipif(**tm.no_dask())
|
||||
def test_broadcast():
|
||||
from distributed import Client, LocalCluster
|
||||
|
||||
n_workers = 3
|
||||
with LocalCluster(n_workers=n_workers) as cluster:
|
||||
with Client(cluster) as client:
|
||||
run_broadcast(client)
|
||||
|
||||
|
||||
@pytest.mark.skipif(**tm.no_ipv6())
|
||||
@pytest.mark.skipif(**tm.no_dask())
|
||||
def test_rabit_ops_ipv6():
|
||||
|
||||
Reference in New Issue
Block a user