Initial support for IPv6 (#8225)
- Merge rabit socket into XGBoost. - Dask interface support. - Add test to the socket.
This commit is contained in:
@@ -121,6 +121,7 @@ if __name__ == "__main__":
|
||||
"python-package/xgboost/sklearn.py",
|
||||
"python-package/xgboost/spark",
|
||||
"python-package/xgboost/federated.py",
|
||||
"python-package/xgboost/testing.py",
|
||||
# tests
|
||||
"tests/python/test_config.py",
|
||||
"tests/python/test_spark/",
|
||||
|
||||
77
tests/cpp/collective/test_socket.cc
Normal file
77
tests/cpp/collective/test_socket.cc
Normal file
@@ -0,0 +1,77 @@
|
||||
/*!
|
||||
* Copyright (c) 2022 by XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/collective/socket.h>
|
||||
|
||||
#include <cerrno> // EADDRNOTAVAIL
|
||||
#include <fstream> // ifstream
|
||||
#include <system_error> // std::error_code, std::system_category
|
||||
|
||||
#include "../helpers.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
/**
 * Loopback round-trip test for TCPSocket.
 *
 * For each address family a server socket is bound and put into listening
 * mode, a client connects to it through the loopback address, the accepted
 * connection sends a short string and the client reads it back.
 *
 * IPv4 is always exercised.  IPv6 is exercised only when
 * /sys/module/ipv6/parameters/disable can be read and contains 0; otherwise
 * the test is reported as skipped.  The skip is issued only after
 * system::SocketFinalize() has been called, so the SocketStartup() /
 * SocketFinalize() pair stays balanced on every skip path.
 */
TEST(Socket, Basic) {
  system::SocketStartup();

  // A SockAddress reports the family of the concrete address it was built from.
  SockAddress addr{SockAddrV6::Loopback()};
  ASSERT_TRUE(addr.IsV6());
  addr = SockAddress{SockAddrV4::Loopback()};
  ASSERT_TRUE(addr.IsV4());

  std::string const skip_msg{"Skipping IPv6 test"};

  // Note: ASSERT_* and GTEST_SKIP_ inside this lambda return from the lambda
  // only; control then continues in the test body below.
  auto run_test = [&skip_msg](SockDomain domain) {
    auto server = TCPSocket::Create(domain);
    ASSERT_EQ(server.Domain(), domain);
    auto port = server.BindHost();
    server.Listen();

    TCPSocket client;
    if (domain == SockDomain::kV4) {
      auto const& host = SockAddrV4::Loopback().Addr();
      ASSERT_EQ(Connect(MakeSockAddress(StringView{host}, port), &client), std::errc{});
    } else {
      auto const& host = SockAddrV6::Loopback().Addr();
      auto rc = Connect(MakeSockAddress(StringView{host}, port), &client);
      // some environment (docker) has restricted network configuration.
      if (rc == std::error_code{EADDRNOTAVAIL, std::system_category()}) {
        GTEST_SKIP_(skip_msg.c_str());
      }
      ASSERT_EQ(rc, std::errc{});
    }
    ASSERT_EQ(client.Domain(), domain);

    auto accepted = server.Accept();
    // Named `payload` so it no longer shadows the captured skip message.
    StringView payload{"Hello world."};
    accepted.Send(payload);

    std::string received;
    client.Recv(&received);
    ASSERT_EQ(StringView{received}, payload);
  };

  // True only when the sysfs switch can be read and holds 0.  A missing,
  // unreadable, empty or non-numeric file counts as "IPv6 not usable" instead
  // of raising an exception the way std::stoi would.
  auto ipv6_enabled = [] {
    std::ifstream fin{"/sys/module/ipv6/parameters/disable"};
    int disabled{0};
    return (fin >> disabled) && disabled == 0;
  };

  run_test(SockDomain::kV4);

  bool const has_ipv6 = ipv6_enabled();
  if (has_ipv6) {
    run_test(SockDomain::kV6);
  }

  // Tear down before a possible skip: GTEST_SKIP_ returns from the test body.
  system::SocketFinalize();

  if (!has_ipv6) {
    GTEST_SKIP_(skip_msg.c_str());
  }
}
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@@ -1,7 +1,6 @@
|
||||
/*!
|
||||
* Copyright (c) 2022 by XGBoost Contributors
|
||||
*/
|
||||
|
||||
#ifndef XGBOOST_TESTS_CPP_FILESYSTEM_H
|
||||
#define XGBOOST_TESTS_CPP_FILESYSTEM_H
|
||||
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
/*!
|
||||
* Copyright 2021-2022, XGBoost contributors.
|
||||
*/
|
||||
#ifndef XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_
|
||||
#define XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_
|
||||
#include <xgboost/tree_model.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "../../../src/tree/hist/expand_entry.h"
|
||||
|
||||
namespace xgboost {
|
||||
@@ -19,3 +23,4 @@ inline void GetSplit(RegTree *tree, float split_value, std::vector<CPUExpandEntr
|
||||
}
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_
|
||||
|
||||
@@ -1,34 +1,37 @@
|
||||
from xgboost import RabitTracker
|
||||
import xgboost as xgb
|
||||
import re
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import testing as tm
|
||||
import numpy as np
|
||||
import sys
|
||||
import re
|
||||
|
||||
import xgboost as xgb
|
||||
from xgboost import RabitTracker, testing
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
|
||||
|
||||
|
||||
def test_rabit_tracker():
    """Broadcast a string through a single-worker rabit tracker on 127.0.0.1.

    The tracker's worker environment is turned into ``key=value`` byte
    strings and handed to ``RabitContext``; inside that context a broadcast
    from rank 0 must return the value unchanged.
    """
    # Exactly one tracker is created; the page rendering had left both the
    # pre- and post-change copies of this statement in place.
    tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1)
    tracker.start(1)
    worker_env = tracker.worker_envs()
    rabit_env = [f"{k}={v}".encode() for k, v in worker_env.items()]
    with xgb.rabit.RabitContext(rabit_env):
        ret = xgb.rabit.broadcast("test1234", 0)
        assert str(ret) == "test1234"
|
||||
|
||||
|
||||
def run_rabit_ops(client, n_workers):
|
||||
from test_with_dask import _get_client_workers
|
||||
from xgboost.dask import RabitContext, _get_rabit_args
|
||||
from xgboost.dask import RabitContext, _get_dask_config, _get_rabit_args
|
||||
|
||||
from xgboost import rabit
|
||||
|
||||
workers = _get_client_workers(client)
|
||||
rabit_args = client.sync(_get_rabit_args, len(workers), None, client)
|
||||
rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client)
|
||||
assert not rabit.is_distributed()
|
||||
n_workers_from_dask = len(workers)
|
||||
assert n_workers == n_workers_from_dask
|
||||
@@ -55,12 +58,26 @@ def run_rabit_ops(client, n_workers):
|
||||
@pytest.mark.skipif(**tm.no_dask())
def test_rabit_ops():
    """Run ``run_rabit_ops`` against a three-worker ``LocalCluster``."""
    from distributed import Client, LocalCluster

    worker_count = 3
    # Both context managers are entered left to right and exited in reverse.
    with LocalCluster(n_workers=worker_count) as cluster, Client(cluster) as client:
        run_rabit_ops(client, worker_count)
|
||||
|
||||
|
||||
@pytest.mark.skipif(**testing.skip_ipv6())
@pytest.mark.skipif(**tm.no_dask())
def test_rabit_ops_ipv6():
    """Run ``run_rabit_ops`` with the cluster and scheduler address on ``[::1]``.

    The ``xgboost.scheduler_address`` dask config entry is overridden only
    for the duration of the test.
    """
    import dask
    from distributed import Client, LocalCluster

    worker_count = 3
    with dask.config.set({"xgboost.scheduler_address": "[::1]"}):
        cluster_ctx = LocalCluster(n_workers=worker_count, host="[::1]")
        with cluster_ctx as cluster, Client(cluster) as client:
            run_rabit_ops(client, worker_count)
|
||||
|
||||
|
||||
def test_rank_assignment() -> None:
|
||||
from distributed import Client, LocalCluster
|
||||
from test_with_dask import _get_client_workers
|
||||
|
||||
Reference in New Issue
Block a user