Revamp the rabit implementation. (#10112)

This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features:
- Federated learning for both CPU and GPU.
- NCCL.
- More data types.
- A unified interface for all the underlying implementations.
- Improved timeout handling for both tracker and workers.
- Exhaustive tests with metrics (fixed a couple of bugs along the way).
- A reusable tracker for Python and JVM packages.
This commit is contained in:
Jiaming Yuan
2024-05-20 11:56:23 +08:00
committed by GitHub
parent ba9b4cb1ee
commit a5a58102e5
195 changed files with 2768 additions and 9234 deletions

View File

@@ -1,16 +1,14 @@
import multiprocessing
import socket
import sys
import time
from threading import Thread
import numpy as np
import pytest
import xgboost as xgb
from xgboost import RabitTracker, build_info, federated
if sys.platform.startswith("win"):
pytest.skip("Skipping collective tests on Windows", allow_module_level=True)
from xgboost import testing as tm
# NOTE(review): rendered-diff residue — the old and new lines of this worker
# body are interleaved with no +/- markers, and the hunk header below is not
# Python, so this span is not runnable as-is.
def run_rabit_worker(rabit_env, world_size):
@@ -18,20 +16,21 @@ def run_rabit_worker(rabit_env, world_size):
# Worker body: sanity-check the collective environment, then exercise
# broadcast and an allreduce SUM.
assert xgb.collective.get_world_size() == world_size
assert xgb.collective.is_distributed()
assert xgb.collective.get_processor_name() == socket.gethostname()
# Old (single-quoted) broadcast check:
ret = xgb.collective.broadcast('test1234', 0)
assert str(ret) == 'test1234'
# New (double-quoted) replacement for the same two lines:
ret = xgb.collective.broadcast("test1234", 0)
assert str(ret) == "test1234"
# Each worker contributes [1, 2, 3]; the expected [2, 4, 6] implies two
# participating workers (the caller uses world_size == 2).
ret = xgb.collective.allreduce(np.asarray([1, 2, 3]), xgb.collective.Op.SUM)
assert np.array_equal(ret, np.asarray([2, 4, 6]))
# NOTE(review): diff residue — both the old and the new (annotated) def lines
# are present, and old/new tracker calls are interleaved without markers.
def test_rabit_communicator():
def test_rabit_communicator() -> None:
world_size = 2
# Old API: start() took the worker count; env came from worker_envs().
tracker = RabitTracker(host_ip='127.0.0.1', n_workers=world_size)
tracker.start(world_size)
# New API: n_workers fixed at construction; start() takes no argument.
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=world_size)
tracker.start()
workers = []
for _ in range(world_size):
# Old two-line Process construction (worker_envs):
worker = multiprocessing.Process(target=run_rabit_worker,
args=(tracker.worker_envs(), world_size))
# New construction (worker_args replaces worker_envs):
worker = multiprocessing.Process(
target=run_rabit_worker, args=(tracker.worker_args(), world_size)
)
workers.append(worker)
worker.start()
# Join every worker and require a clean exit from each.
for worker in workers:
@@ -39,39 +38,44 @@ def test_rabit_communicator():
assert worker.exitcode == 0
# NOTE(review): diff residue — two complete versions of run_federated_worker
# are interleaved below without +/- markers; this span is not runnable as-is.
# Old version: parameter `xgboost_communicator`, single-quoted strings.
def run_federated_worker(port, world_size, rank):
with xgb.collective.CommunicatorContext(xgboost_communicator='federated',
federated_server_address=f'localhost:{port}',
federated_world_size=world_size,
federated_rank=rank):
# New version: `dmlc_communicator`, type annotations, black formatting.
def run_federated_worker(port: int, world_size: int, rank: int) -> None:
with xgb.collective.CommunicatorContext(
dmlc_communicator="federated",
federated_server_address=f"localhost:{port}",
federated_world_size=world_size,
federated_rank=rank,
):
# Shared body: confirm the federated communicator is live for this rank.
assert xgb.collective.get_world_size() == world_size
assert xgb.collective.is_distributed()
# Old checks: processor name "rank{rank}", results bound to `ret`.
assert xgb.collective.get_processor_name() == f'rank{rank}'
ret = xgb.collective.broadcast('test1234', 0)
assert str(ret) == 'test1234'
ret = xgb.collective.allreduce(np.asarray([1, 2, 3]), xgb.collective.Op.SUM)
assert np.array_equal(ret, np.asarray([2, 4, 6]))
# New checks: processor name "rank:{rank}", distinct bret/aret bindings.
assert xgb.collective.get_processor_name() == f"rank:{rank}"
bret = xgb.collective.broadcast("test1234", 0)
assert str(bret) == "test1234"
aret = xgb.collective.allreduce(np.asarray([1, 2, 3]), xgb.collective.Op.SUM)
assert np.array_equal(aret, np.asarray([2, 4, 6]))
@pytest.mark.skipif(**tm.skip_win())
# NOTE(review): diff residue — the old (server + sleep) and new (tracker)
# bootstrap paths are interleaved below without +/- markers.
def test_federated_communicator():
# Only meaningful when XGBoost was built with the federated plugin.
if not build_info()["USE_FEDERATED"]:
pytest.skip("XGBoost not built with federated learning enabled")
port = 9091
world_size = 2
# Old bootstrap: positional args plus a fixed 1s sleep before probing.
server = multiprocessing.Process(target=xgb.federated.run_federated_server, args=(port, world_size))
server.start()
time.sleep(1)
if not server.is_alive():
# New bootstrap: keyword args (port/n_workers), no sleep before the probe.
tracker = multiprocessing.Process(
target=federated.run_federated_server,
kwargs={"port": port, "n_workers": world_size},
)
tracker.start()
if not tracker.is_alive():
raise Exception("Error starting Federated Learning server")
# Launch one worker process per rank and require clean exits from all.
workers = []
for rank in range(world_size):
# Old two-line Process construction:
worker = multiprocessing.Process(target=run_federated_worker,
args=(port, world_size, rank))
# New black-formatted construction:
worker = multiprocessing.Process(
target=run_federated_worker, args=(port, world_size, rank)
)
workers.append(worker)
worker.start()
for worker in workers:
worker.join()
assert worker.exitcode == 0
# Old cleanup of the server process (the new variant appears to drop it).
server.terminate()

View File

@@ -3,33 +3,33 @@ import sys
import numpy as np
import pytest
from hypothesis import HealthCheck, given, settings, strategies
import xgboost as xgb
from xgboost import RabitTracker, collective
from xgboost import testing as tm
if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
def test_rabit_tracker():
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1)
# NOTE(review): diff residue — the old call pair (start(1)/worker_envs) and
# the new pair (start()/worker_args) are both present below.
tracker.start(1)
with xgb.collective.CommunicatorContext(**tracker.worker_envs()):
tracker.start()
with xgb.collective.CommunicatorContext(**tracker.worker_args()):
# With a single worker, a broadcast from rank 0 returns the payload.
ret = xgb.collective.broadcast("test1234", 0)
assert str(ret) == "test1234"
@pytest.mark.skipif(**tm.not_linux())
# NOTE(review): diff residue — the old and new versions of this error test
# are interleaved below without +/- markers.
def test_socket_error():
# Old version: one worker, start(1), uppercase DMLC_* env keys, and a
# match on the connection-refused text.
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1)
tracker.start(1)
env = tracker.worker_envs()
env["DMLC_TRACKER_PORT"] = 0
env["DMLC_WORKER_CONNECT_RETRY"] = 1
with pytest.raises(ValueError, match="127.0.0.1:0\n.*refused"):
# New version: two workers, start(), lowercase worker_args keys, and a
# match on the bootstrap-failure message.
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=2)
tracker.start()
env = tracker.worker_args()
env["dmlc_tracker_port"] = 0
env["dmlc_retry"] = 1
with pytest.raises(ValueError, match="Failed to bootstrap the communication."):
with xgb.collective.CommunicatorContext(**env):
pass
# New: tracker.free() is expected to raise ValueError here — presumably
# because bootstrap never completed; confirm against the tracker impl.
with pytest.raises(ValueError):
tracker.free()
def run_rabit_ops(client, n_workers):
@@ -70,6 +70,40 @@ def test_rabit_ops():
run_rabit_ops(client, n_workers)
def run_allreduce(client) -> None:
    """Drive an allreduce smoke test on every worker of a Dask *client*.

    Each worker SUM-reduces a scalar 128 times (result must equal the
    worker count) and then reduces one large array to exercise the
    multi-chunk code path.
    """
    from xgboost.dask import CommunicatorContext, _get_dask_config, _get_rabit_args

    workers = tm.get_client_workers(client)
    rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client)
    n_workers = len(workers)

    def local_test(worker_id: int) -> None:
        # Large payload (32M doubles) to stress the chunked reduction path.
        x = np.full(shape=(1024 * 1024 * 32), fill_value=1.0)
        with CommunicatorContext(**rabit_args):
            k = np.asarray([1.0])
            # Many small reductions in a row stress communicator reuse.
            # (Loop index was unused in the original; renamed to `_`.)
            for _ in range(128):
                m = collective.allreduce(k, collective.Op.SUM)
                assert m == n_workers
            y = collective.allreduce(x, collective.Op.SUM)
            np.testing.assert_allclose(y, np.full_like(y, fill_value=float(n_workers)))

    futures = client.map(local_test, range(len(workers)), workers=workers)
    # Gather only to surface worker-side assertion failures; the original
    # bound the result to an unused `results` variable.
    client.gather(futures)
@pytest.mark.skipif(**tm.no_dask())
def test_allreduce() -> None:
    """Run the allreduce smoke test across fresh and reused clusters.

    The outer loop recreates the cluster/client, the inner loop reuses the
    same client, so both communicator bootstrap paths are covered.
    """
    from distributed import Client, LocalCluster

    n_workers = 4
    # Cleanup: the original used `i` for both loops (inner shadowing outer)
    # and never read either value; both are now `_`.
    for _ in range(2):
        with LocalCluster(n_workers=n_workers) as cluster:
            with Client(cluster) as client:
                for _ in range(2):
                    run_allreduce(client)
def run_broadcast(client):
from xgboost.dask import _get_dask_config, _get_rabit_args
@@ -109,6 +143,7 @@ def test_rabit_ops_ipv6():
run_rabit_ops(client, n_workers)
@pytest.mark.skipif(**tm.no_dask())
def test_rank_assignment() -> None:
from distributed import Client, LocalCluster
@@ -133,3 +168,107 @@ def test_rank_assignment() -> None:
futures = client.map(local_test, range(len(workers)), workers=workers)
client.gather(futures)
@pytest.fixture
def local_cluster():
    """Yield a Dask ``LocalCluster`` of eight workers for collective tests."""
    from distributed import LocalCluster

    with LocalCluster(n_workers=8, dashboard_address=":0") as managed_cluster:
        yield managed_cluster
# Hypothesis strategy: a list (any length, repetition allowed) of collective
# op names drawn from the three operations the tests below dispatch on.
ops_strategy = strategies.lists(
strategies.sampled_from(["broadcast", "allreduce_max", "allreduce_sum"])
)
@pytest.mark.skipif(**tm.no_dask())
@given(ops=ops_strategy, size=strategies.integers(2**4, 2**16))
@settings(
    deadline=None,
    print_blob=True,
    max_examples=10,
    suppress_health_check=[HealthCheck.function_scoped_fixture],
)
def test_ops_restart_comm(local_cluster, ops, size) -> None:
    """Run a hypothesis-chosen op sequence, restarting the communicator
    (a fresh ``CommunicatorContext``) for every generated example."""
    from distributed import Client

    def local_test(w: int, n_workers: int) -> None:
        a = np.arange(0, n_workers)
        # `args` is captured from the enclosing test body (late binding);
        # it is assigned below before client.map schedules this function.
        with xgb.dask.CommunicatorContext(**args):
            for op in ops:
                if op == "broadcast":
                    b = collective.broadcast(a, root=1)
                    np.testing.assert_allclose(b, a)
                elif op == "allreduce_max":
                    # Every worker holds an identical `a`, so MAX == a.
                    b = collective.allreduce(a, collective.Op.MAX)
                    np.testing.assert_allclose(b, a)
                elif op == "allreduce_sum":
                    b = collective.allreduce(a, collective.Op.SUM)
                    np.testing.assert_allclose(a * n_workers, b)
                else:
                    raise ValueError()

    with Client(local_cluster) as client:
        workers = tm.get_client_workers(client)
        args = client.sync(
            xgb.dask._get_rabit_args,
            len(workers),
            None,
            client,
        )
        # Fix: the original re-fetched the worker list a second time here
        # for no reason; reuse the list obtained above.
        n_workers = len(workers)
        futures = client.map(
            local_test, range(len(workers)), workers=workers, n_workers=n_workers
        )
        client.gather(futures)
@pytest.mark.skipif(**tm.no_dask())
def test_ops_reuse_comm(local_cluster) -> None:
    """Run a random op sequence inside one communicator context per worker,
    so consecutive operations reuse the same communicator."""
    from distributed import Client

    rng = np.random.default_rng(1994)  # fixed seed keeps the test deterministic
    n_examples = 10
    ops = rng.choice(
        ["broadcast", "allreduce_sum", "allreduce_max"], size=n_examples
    ).tolist()

    def local_test(w: int, n_workers: int) -> None:
        a = np.arange(0, n_workers)
        # `args` is captured from the enclosing test body (assigned below,
        # before client.map schedules this function).
        with xgb.dask.CommunicatorContext(**args):
            for op in ops:
                if op == "broadcast":
                    b = collective.broadcast(a, root=1)
                    assert np.allclose(b, a)
                elif op == "allreduce_max":
                    # Each worker contributes its rank; MAX is the top rank.
                    c = np.full_like(a, collective.get_rank())
                    b = collective.allreduce(c, collective.Op.MAX)
                    assert np.allclose(b, n_workers - 1), b
                elif op == "allreduce_sum":
                    b = collective.allreduce(a, collective.Op.SUM)
                    # Fix: the expected factor is the worker count, not the
                    # literal 8 the fixture happens to use today.
                    assert np.allclose(a * n_workers, b)
                else:
                    raise ValueError()

    with Client(local_cluster) as client:
        workers = tm.get_client_workers(client)
        args = client.sync(
            xgb.dask._get_rabit_args,
            len(workers),
            None,
            client,
        )
        n_workers = len(workers)
        futures = client.map(
            local_test, range(len(workers)), workers=workers, n_workers=n_workers
        )
        client.gather(futures)

View File

@@ -8,19 +8,14 @@ import xgboost as xgb
from xgboost import testing as tm
from xgboost.core import DataSplitMode
try:
import pandas as pd
import pyarrow as pa
import pyarrow.csv as pc
except ImportError:
pass
pytestmark = pytest.mark.skipif(
tm.no_arrow()["condition"] or tm.no_pandas()["condition"],
reason=tm.no_arrow()["reason"] + " or " + tm.no_pandas()["reason"],
)
dpath = "demo/data/"
import pandas as pd
import pyarrow as pa
import pyarrow.csv as pc
class TestArrowTable:

View File

@@ -1098,9 +1098,10 @@ def test_pandas_input():
np.testing.assert_equal(model.feature_names_in_, np.array(feature_names))
columns = list(train.columns)
random.shuffle(columns)
rng.shuffle(columns)
df_incorrect = df[columns]
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="feature_names mismatch"):
model.predict(df_incorrect)
clf_isotonic = CalibratedClassifierCV(model, cv="prefit", method="isotonic")