Initial support for IPv6 (#8225)

- Merge rabit socket into XGBoost.
- Dask interface support.
- Add tests for the socket.
This commit is contained in:
Jiaming Yuan
2022-09-21 18:06:50 +08:00
committed by GitHub
parent 7d43e74e71
commit b791446623
17 changed files with 924 additions and 595 deletions

View File

@@ -121,6 +121,7 @@ if __name__ == "__main__":
"python-package/xgboost/sklearn.py",
"python-package/xgboost/spark",
"python-package/xgboost/federated.py",
"python-package/xgboost/testing.py",
# tests
"tests/python/test_config.py",
"tests/python/test_spark/",

View File

@@ -0,0 +1,77 @@
/*!
* Copyright (c) 2022 by XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/collective/socket.h>
#include <cerrno> // EADDRNOTAVAIL
#include <fstream> // ifstream
#include <system_error> // std::error_code, std::system_category
#include "../helpers.h"
namespace xgboost {
namespace collective {
/**
 * Round-trip a short message over a loopback TCP connection, once for IPv4 and, when the
 * host appears to have IPv6 enabled, once for IPv6.
 */
TEST(Socket, Basic) {
  // Pair SocketStartup() with SocketFinalize() on every exit path.  GTEST_SKIP_ and the
  // fatal ASSERT_* macros return from the test body early, which would bypass a trailing
  // SocketFinalize() call.
  struct SocketSystemGuard {
    SocketSystemGuard() { system::SocketStartup(); }
    ~SocketSystemGuard() { system::SocketFinalize(); }
    SocketSystemGuard(SocketSystemGuard const&) = delete;
    SocketSystemGuard& operator=(SocketSystemGuard const&) = delete;
  } socket_system_guard;

  SockAddress addr{SockAddrV6::Loopback()};
  ASSERT_TRUE(addr.IsV6());
  addr = SockAddress{SockAddrV4::Loopback()};
  ASSERT_TRUE(addr.IsV4());

  std::string const skip_msg{"Skipping IPv6 test"};

  // Bind a server on the given domain, connect a client through the loopback address and
  // check that a message sent by the accepted socket is received unchanged by the client.
  // Note that a skip or fatal assertion inside the lambda only returns from the lambda.
  auto run_test = [&skip_msg](SockDomain domain) {
    auto server = TCPSocket::Create(domain);
    ASSERT_EQ(server.Domain(), domain);
    auto port = server.BindHost();
    server.Listen();

    TCPSocket client;
    if (domain == SockDomain::kV4) {
      // Held by value so the string stays valid whether Addr() returns a value or a
      // reference into the temporary produced by Loopback().
      auto const host = SockAddrV4::Loopback().Addr();
      ASSERT_EQ(Connect(MakeSockAddress(StringView{host}, port), &client), std::errc{});
    } else {
      auto const host = SockAddrV6::Loopback().Addr();
      auto rc = Connect(MakeSockAddress(StringView{host}, port), &client);
      // Some environments (e.g. docker) have a restricted network configuration in which
      // the IPv6 loopback address is not available.
      if (rc == std::error_code{EADDRNOTAVAIL, std::system_category()}) {
        GTEST_SKIP_(skip_msg.c_str());
      }
      ASSERT_EQ(rc, std::errc{});
    }
    ASSERT_EQ(client.Domain(), domain);

    auto accepted = server.Accept();
    StringView payload{"Hello world."};
    accepted.Send(payload);

    std::string received;
    client.Recv(&received);
    ASSERT_EQ(StringView{received}, payload);
  };

  run_test(SockDomain::kV4);

  // The IPv6 run only happens when this file exists, is readable and holds the integer 0;
  // anything else (missing file, unreadable, unparsable or non-zero content) skips it.
  std::string const path{"/sys/module/ipv6/parameters/disable"};
  if (!FileExists(path)) {
    GTEST_SKIP_(skip_msg.c_str());
  }
  std::ifstream fin(path);
  if (!fin) {
    GTEST_SKIP_(skip_msg.c_str());
  }
  // Checked extraction instead of std::stoi: malformed or empty content results in a skip
  // rather than an exception escaping the test.
  int disabled{0};
  if (!(fin >> disabled) || disabled != 0) {
    GTEST_SKIP_(skip_msg.c_str());
  }

  run_test(SockDomain::kV6);
}
} // namespace collective
} // namespace xgboost

View File

@@ -1,7 +1,6 @@
/*!
* Copyright (c) 2022 by XGBoost Contributors
*/
#ifndef XGBOOST_TESTS_CPP_FILESYSTEM_H
#define XGBOOST_TESTS_CPP_FILESYSTEM_H

View File

@@ -1,8 +1,12 @@
/*!
* Copyright 2021-2022, XGBoost contributors.
*/
#ifndef XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_
#define XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_
#include <xgboost/tree_model.h>
#include <vector>
#include "../../../src/tree/hist/expand_entry.h"
namespace xgboost {
@@ -19,3 +23,4 @@ inline void GetSplit(RegTree *tree, float split_value, std::vector<CPUExpandEntr
}
} // namespace tree
} // namespace xgboost
#endif // XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_

View File

@@ -1,34 +1,37 @@
from xgboost import RabitTracker
import xgboost as xgb
import re
import sys
import numpy as np
import pytest
import testing as tm
import numpy as np
import sys
import re
import xgboost as xgb
from xgboost import RabitTracker, testing
# Skip this whole module on Windows. allow_module_level=True is what permits calling
# pytest.skip outside of a test function.
if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
# Start a one-worker RabitTracker on 127.0.0.1, turn its worker environment into
# encoded "key=value" entries, and check inside a RabitContext that
# broadcast("test1234", 0) returns the same string.
# NOTE(review): the tracker construction and the broadcast/assert lines each appear twice,
# differing only in quote style. This looks like the old and new sides of a diff shown
# together; confirm which revision is intended.
def test_rabit_tracker():
tracker = RabitTracker(host_ip='127.0.0.1', n_workers=1)
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1)
tracker.start(1)
worker_env = tracker.worker_envs()
rabit_env = []
for k, v in worker_env.items():
rabit_env.append(f"{k}={v}".encode())
with xgb.rabit.RabitContext(rabit_env):
ret = xgb.rabit.broadcast('test1234', 0)
assert str(ret) == 'test1234'
ret = xgb.rabit.broadcast("test1234", 0)
assert str(ret) == "test1234"
def run_rabit_ops(client, n_workers):
from test_with_dask import _get_client_workers
from xgboost.dask import RabitContext, _get_rabit_args
from xgboost.dask import RabitContext, _get_dask_config, _get_rabit_args
from xgboost import rabit
workers = _get_client_workers(client)
rabit_args = client.sync(_get_rabit_args, len(workers), None, client)
rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client)
assert not rabit.is_distributed()
n_workers_from_dask = len(workers)
assert n_workers == n_workers_from_dask
@@ -55,12 +58,26 @@ def run_rabit_ops(client, n_workers):
@pytest.mark.skipif(**tm.no_dask())
def test_rabit_ops():
    """Run ``run_rabit_ops`` against a local cluster with three workers."""
    from distributed import Client, LocalCluster

    worker_count = 3
    # The cluster is entered first and exited last, exactly as with nested ``with`` blocks.
    with LocalCluster(n_workers=worker_count) as cluster, Client(cluster) as client:
        run_rabit_ops(client, worker_count)
@pytest.mark.skipif(**testing.skip_ipv6())
@pytest.mark.skipif(**tm.no_dask())
def test_rabit_ops_ipv6():
    """Run ``run_rabit_ops`` on a three-worker local cluster hosted on ``[::1]``.

    The ``xgboost.scheduler_address`` dask config entry is set to the same address for
    the duration of the test.
    """
    import dask
    from distributed import Client, LocalCluster

    worker_count = 3
    loopback = "[::1]"
    with dask.config.set({"xgboost.scheduler_address": loopback}):
        # Cluster is entered before the client and exited after it.
        with LocalCluster(n_workers=worker_count, host=loopback) as cluster, \
                Client(cluster) as client:
            run_rabit_ops(client, worker_count)
def test_rank_assignment() -> None:
from distributed import Client, LocalCluster
from test_with_dask import _get_client_workers