This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new implementation features:

- Federated learning for both CPU and GPU.
- NCCL support.
- More data types.
- A unified interface for all the underlying implementations.
- Improved timeout handling for both the tracker and the workers.
- Exhaustive tests with metrics (fixed a couple of bugs along the way).
- A reusable tracker for the Python and JVM packages.
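As a quick orientation for the test file below, here is a minimal sketch of the tracker/worker handshake through the unified collective interface. It is illustrative rather than part of the change itself, and it only uses calls that the attached tests exercise (`RabitTracker`, `CommunicatorContext`, `broadcast`, `allreduce`).

```python
import numpy as np

from xgboost import RabitTracker, collective

# One tracker coordinating a single worker in the same process.
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1)
tracker.start()

# A worker joins by passing the tracker's connection arguments to the unified
# communicator context, then issues collective operations inside it.
with collective.CommunicatorContext(**tracker.worker_args()):
    msg = collective.broadcast("hello", 0)  # broadcast from rank 0
    total = collective.allreduce(np.asarray([1.0]), collective.Op.SUM)
```

The tests below exercise this interface end to end, including the Dask integration.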
import re
import sys

import numpy as np
import pytest
from hypothesis import HealthCheck, given, settings, strategies

import xgboost as xgb
from xgboost import RabitTracker, collective
from xgboost import testing as tm


def test_rabit_tracker():
    tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1)
    tracker.start()
    with xgb.collective.CommunicatorContext(**tracker.worker_args()):
        ret = xgb.collective.broadcast("test1234", 0)
        assert str(ret) == "test1234"


@pytest.mark.skipif(**tm.not_linux())
def test_socket_error():
    tracker = RabitTracker(host_ip="127.0.0.1", n_workers=2)
    tracker.start()
    env = tracker.worker_args()
    # Point the worker at an invalid tracker port so that bootstrap fails quickly.
    env["dmlc_tracker_port"] = 0
    env["dmlc_retry"] = 1
    with pytest.raises(ValueError, match="Failed to bootstrap the communication."):
        with xgb.collective.CommunicatorContext(**env):
            pass
    # Freeing the tracker after the failed bootstrap is expected to raise as well.
    with pytest.raises(ValueError):
        tracker.free()


def run_rabit_ops(client, n_workers):
    from xgboost.dask import CommunicatorContext, _get_dask_config, _get_rabit_args

    workers = tm.get_client_workers(client)
    rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client)
    assert not collective.is_distributed()
    n_workers_from_dask = len(workers)
    assert n_workers == n_workers_from_dask

    def local_test(worker_id):
        with CommunicatorContext(**rabit_args):
            a = 1
            assert collective.is_distributed()
            # Summing 1 across all workers yields the world size.
            a = np.array([a])
            reduced = collective.allreduce(a, collective.Op.SUM)
            assert reduced[0] == n_workers

            # The maximum worker id equals the highest rank.
            worker_id = np.array([worker_id])
            reduced = collective.allreduce(worker_id, collective.Op.MAX)
            assert reduced == n_workers - 1

            return 1

    futures = client.map(local_test, range(len(workers)), workers=workers)
    results = client.gather(futures)
    assert sum(results) == n_workers


@pytest.mark.skipif(**tm.no_dask())
def test_rabit_ops():
    from distributed import Client, LocalCluster

    n_workers = 3
    with LocalCluster(n_workers=n_workers) as cluster:
        with Client(cluster) as client:
            run_rabit_ops(client, n_workers)


def run_allreduce(client) -> None:
    from xgboost.dask import CommunicatorContext, _get_dask_config, _get_rabit_args

    workers = tm.get_client_workers(client)
    rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client)
    n_workers = len(workers)

    def local_test(worker_id: int) -> None:
        x = np.full(shape=(1024 * 1024 * 32), fill_value=1.0)
        with CommunicatorContext(**rabit_args):
            # Many small allreduce calls, followed by one large one.
            k = np.asarray([1.0])
            for i in range(128):
                m = collective.allreduce(k, collective.Op.SUM)
                assert m == n_workers

            y = collective.allreduce(x, collective.Op.SUM)
            np.testing.assert_allclose(y, np.full_like(y, fill_value=float(n_workers)))

    futures = client.map(local_test, range(len(workers)), workers=workers)
    results = client.gather(futures)


@pytest.mark.skipif(**tm.no_dask())
def test_allreduce() -> None:
    from distributed import Client, LocalCluster

    n_workers = 4
    for i in range(2):
        with LocalCluster(n_workers=n_workers) as cluster:
            with Client(cluster) as client:
                for i in range(2):
                    run_allreduce(client)


def run_broadcast(client):
    from xgboost.dask import _get_dask_config, _get_rabit_args

    workers = tm.get_client_workers(client)
    rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client)

    def local_test(worker_id):
        with collective.CommunicatorContext(**rabit_args):
            res = collective.broadcast(17, 0)
            return res

    futures = client.map(local_test, range(len(workers)), workers=workers)
    results = client.gather(futures)
    np.testing.assert_allclose(np.array(results), 17)


@pytest.mark.skipif(**tm.no_dask())
def test_broadcast():
    from distributed import Client, LocalCluster

    n_workers = 3
    with LocalCluster(n_workers=n_workers) as cluster:
        with Client(cluster) as client:
            run_broadcast(client)


@pytest.mark.skipif(**tm.no_ipv6())
@pytest.mark.skipif(**tm.no_dask())
def test_rabit_ops_ipv6():
    import dask
    from distributed import Client, LocalCluster

    n_workers = 3
    with dask.config.set({"xgboost.scheduler_address": "[::1]"}):
        with LocalCluster(n_workers=n_workers, host="[::1]") as cluster:
            with Client(cluster) as client:
                run_rabit_ops(client, n_workers)


@pytest.mark.skipif(**tm.no_dask())
def test_rank_assignment() -> None:
    from distributed import Client, LocalCluster

    def local_test(worker_id):
        with xgb.dask.CommunicatorContext(**args) as ctx:
            task_id = ctx["DMLC_TASK_ID"]
            matched = re.search(".*-([0-9]).*", task_id)
            rank = xgb.collective.get_rank()
            # As long as the number of workers is less than 10, the rank and the
            # worker id should be the same.
            assert rank == int(matched.group(1))

    with LocalCluster(n_workers=8) as cluster:
        with Client(cluster) as client:
            workers = tm.get_client_workers(client)
            args = client.sync(
                xgb.dask._get_rabit_args,
                len(workers),
                None,
                client,
            )

            futures = client.map(local_test, range(len(workers)), workers=workers)
            client.gather(futures)


@pytest.fixture
def local_cluster():
    from distributed import LocalCluster

    n_workers = 8
    with LocalCluster(n_workers=n_workers, dashboard_address=":0") as cluster:
        yield cluster


ops_strategy = strategies.lists(
    strategies.sampled_from(["broadcast", "allreduce_max", "allreduce_sum"])
)


@pytest.mark.skipif(**tm.no_dask())
@given(ops=ops_strategy, size=strategies.integers(2**4, 2**16))
@settings(
    deadline=None,
    print_blob=True,
    max_examples=10,
    suppress_health_check=[HealthCheck.function_scoped_fixture],
)
def test_ops_restart_comm(local_cluster, ops, size) -> None:
    from distributed import Client

    def local_test(w: int, n_workers: int) -> None:
        a = np.arange(0, n_workers)
        # Each hypothesis example bootstraps a fresh communicator and runs the
        # generated sequence of collective operations inside it.
        with xgb.dask.CommunicatorContext(**args):
            for op in ops:
                if op == "broadcast":
                    b = collective.broadcast(a, root=1)
                    np.testing.assert_allclose(b, a)
                elif op == "allreduce_max":
                    b = collective.allreduce(a, collective.Op.MAX)
                    np.testing.assert_allclose(b, a)
                elif op == "allreduce_sum":
                    b = collective.allreduce(a, collective.Op.SUM)
                    np.testing.assert_allclose(a * n_workers, b)
                else:
                    raise ValueError()

    with Client(local_cluster) as client:
        workers = tm.get_client_workers(client)
        args = client.sync(
            xgb.dask._get_rabit_args,
            len(workers),
            None,
            client,
        )

        workers = tm.get_client_workers(client)
        n_workers = len(workers)

        futures = client.map(
            local_test, range(len(workers)), workers=workers, n_workers=n_workers
        )
        client.gather(futures)


@pytest.mark.skipif(**tm.no_dask())
def test_ops_reuse_comm(local_cluster) -> None:
    from distributed import Client

    rng = np.random.default_rng(1994)
    n_examples = 10
    ops = rng.choice(
        ["broadcast", "allreduce_sum", "allreduce_max"], size=n_examples
    ).tolist()

    def local_test(w: int, n_workers: int) -> None:
        a = np.arange(0, n_workers)

        # A single communicator is bootstrapped once and reused for the whole
        # sequence of operations.
        with xgb.dask.CommunicatorContext(**args):
            for op in ops:
                if op == "broadcast":
                    b = collective.broadcast(a, root=1)
                    assert np.allclose(b, a)
                elif op == "allreduce_max":
                    c = np.full_like(a, collective.get_rank())
                    b = collective.allreduce(c, collective.Op.MAX)
                    assert np.allclose(b, n_workers - 1), b
                elif op == "allreduce_sum":
                    # The local_cluster fixture starts 8 workers.
                    b = collective.allreduce(a, collective.Op.SUM)
                    assert np.allclose(a * 8, b)
                else:
                    raise ValueError()

    with Client(local_cluster) as client:
        workers = tm.get_client_workers(client)
        args = client.sync(
            xgb.dask._get_rabit_args,
            len(workers),
            None,
            client,
        )

        n_workers = len(workers)

        futures = client.map(
            local_test, range(len(workers)), workers=workers, n_workers=n_workers
        )
        client.gather(futures)