Revamp the rabit implementation. (#10112)
This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features: - Federated learning for both CPU and GPU. - NCCL. - More data types. - A unified interface for all the underlying implementations. - Improved timeout handling for both tracker and workers. - Exhausted tests with metrics (fixed a couple of bugs along the way). - A reusable tracker for Python and JVM packages.
This commit is contained in:
@@ -1,16 +1,14 @@
|
||||
import multiprocessing
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
from threading import Thread
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import xgboost as xgb
|
||||
from xgboost import RabitTracker, build_info, federated
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
pytest.skip("Skipping collective tests on Windows", allow_module_level=True)
|
||||
from xgboost import testing as tm
|
||||
|
||||
|
||||
def run_rabit_worker(rabit_env, world_size):
|
||||
@@ -18,20 +16,21 @@ def run_rabit_worker(rabit_env, world_size):
|
||||
assert xgb.collective.get_world_size() == world_size
|
||||
assert xgb.collective.is_distributed()
|
||||
assert xgb.collective.get_processor_name() == socket.gethostname()
|
||||
ret = xgb.collective.broadcast('test1234', 0)
|
||||
assert str(ret) == 'test1234'
|
||||
ret = xgb.collective.broadcast("test1234", 0)
|
||||
assert str(ret) == "test1234"
|
||||
ret = xgb.collective.allreduce(np.asarray([1, 2, 3]), xgb.collective.Op.SUM)
|
||||
assert np.array_equal(ret, np.asarray([2, 4, 6]))
|
||||
|
||||
|
||||
def test_rabit_communicator():
|
||||
def test_rabit_communicator() -> None:
|
||||
world_size = 2
|
||||
tracker = RabitTracker(host_ip='127.0.0.1', n_workers=world_size)
|
||||
tracker.start(world_size)
|
||||
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=world_size)
|
||||
tracker.start()
|
||||
workers = []
|
||||
for _ in range(world_size):
|
||||
worker = multiprocessing.Process(target=run_rabit_worker,
|
||||
args=(tracker.worker_envs(), world_size))
|
||||
worker = multiprocessing.Process(
|
||||
target=run_rabit_worker, args=(tracker.worker_args(), world_size)
|
||||
)
|
||||
workers.append(worker)
|
||||
worker.start()
|
||||
for worker in workers:
|
||||
@@ -39,39 +38,44 @@ def test_rabit_communicator():
|
||||
assert worker.exitcode == 0
|
||||
|
||||
|
||||
def run_federated_worker(port, world_size, rank):
|
||||
with xgb.collective.CommunicatorContext(xgboost_communicator='federated',
|
||||
federated_server_address=f'localhost:{port}',
|
||||
federated_world_size=world_size,
|
||||
federated_rank=rank):
|
||||
def run_federated_worker(port: int, world_size: int, rank: int) -> None:
|
||||
with xgb.collective.CommunicatorContext(
|
||||
dmlc_communicator="federated",
|
||||
federated_server_address=f"localhost:{port}",
|
||||
federated_world_size=world_size,
|
||||
federated_rank=rank,
|
||||
):
|
||||
assert xgb.collective.get_world_size() == world_size
|
||||
assert xgb.collective.is_distributed()
|
||||
assert xgb.collective.get_processor_name() == f'rank{rank}'
|
||||
ret = xgb.collective.broadcast('test1234', 0)
|
||||
assert str(ret) == 'test1234'
|
||||
ret = xgb.collective.allreduce(np.asarray([1, 2, 3]), xgb.collective.Op.SUM)
|
||||
assert np.array_equal(ret, np.asarray([2, 4, 6]))
|
||||
assert xgb.collective.get_processor_name() == f"rank:{rank}"
|
||||
bret = xgb.collective.broadcast("test1234", 0)
|
||||
assert str(bret) == "test1234"
|
||||
aret = xgb.collective.allreduce(np.asarray([1, 2, 3]), xgb.collective.Op.SUM)
|
||||
assert np.array_equal(aret, np.asarray([2, 4, 6]))
|
||||
|
||||
|
||||
@pytest.mark.skipif(**tm.skip_win())
|
||||
def test_federated_communicator():
|
||||
if not build_info()["USE_FEDERATED"]:
|
||||
pytest.skip("XGBoost not built with federated learning enabled")
|
||||
|
||||
port = 9091
|
||||
world_size = 2
|
||||
server = multiprocessing.Process(target=xgb.federated.run_federated_server, args=(port, world_size))
|
||||
server.start()
|
||||
time.sleep(1)
|
||||
if not server.is_alive():
|
||||
tracker = multiprocessing.Process(
|
||||
target=federated.run_federated_server,
|
||||
kwargs={"port": port, "n_workers": world_size},
|
||||
)
|
||||
tracker.start()
|
||||
if not tracker.is_alive():
|
||||
raise Exception("Error starting Federated Learning server")
|
||||
|
||||
workers = []
|
||||
for rank in range(world_size):
|
||||
worker = multiprocessing.Process(target=run_federated_worker,
|
||||
args=(port, world_size, rank))
|
||||
worker = multiprocessing.Process(
|
||||
target=run_federated_worker, args=(port, world_size, rank)
|
||||
)
|
||||
workers.append(worker)
|
||||
worker.start()
|
||||
for worker in workers:
|
||||
worker.join()
|
||||
assert worker.exitcode == 0
|
||||
server.terminate()
|
||||
|
||||
@@ -3,33 +3,33 @@ import sys
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from hypothesis import HealthCheck, given, settings, strategies
|
||||
|
||||
import xgboost as xgb
|
||||
from xgboost import RabitTracker, collective
|
||||
from xgboost import testing as tm
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
|
||||
|
||||
|
||||
def test_rabit_tracker():
|
||||
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1)
|
||||
tracker.start(1)
|
||||
with xgb.collective.CommunicatorContext(**tracker.worker_envs()):
|
||||
tracker.start()
|
||||
with xgb.collective.CommunicatorContext(**tracker.worker_args()):
|
||||
ret = xgb.collective.broadcast("test1234", 0)
|
||||
assert str(ret) == "test1234"
|
||||
|
||||
|
||||
@pytest.mark.skipif(**tm.not_linux())
|
||||
def test_socket_error():
|
||||
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1)
|
||||
tracker.start(1)
|
||||
env = tracker.worker_envs()
|
||||
env["DMLC_TRACKER_PORT"] = 0
|
||||
env["DMLC_WORKER_CONNECT_RETRY"] = 1
|
||||
with pytest.raises(ValueError, match="127.0.0.1:0\n.*refused"):
|
||||
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=2)
|
||||
tracker.start()
|
||||
env = tracker.worker_args()
|
||||
env["dmlc_tracker_port"] = 0
|
||||
env["dmlc_retry"] = 1
|
||||
with pytest.raises(ValueError, match="Failed to bootstrap the communication."):
|
||||
with xgb.collective.CommunicatorContext(**env):
|
||||
pass
|
||||
with pytest.raises(ValueError):
|
||||
tracker.free()
|
||||
|
||||
|
||||
def run_rabit_ops(client, n_workers):
|
||||
@@ -70,6 +70,40 @@ def test_rabit_ops():
|
||||
run_rabit_ops(client, n_workers)
|
||||
|
||||
|
||||
def run_allreduce(client) -> None:
|
||||
from xgboost.dask import CommunicatorContext, _get_dask_config, _get_rabit_args
|
||||
|
||||
workers = tm.get_client_workers(client)
|
||||
rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client)
|
||||
n_workers = len(workers)
|
||||
|
||||
def local_test(worker_id: int) -> None:
|
||||
x = np.full(shape=(1024 * 1024 * 32), fill_value=1.0)
|
||||
with CommunicatorContext(**rabit_args):
|
||||
k = np.asarray([1.0])
|
||||
for i in range(128):
|
||||
m = collective.allreduce(k, collective.Op.SUM)
|
||||
assert m == n_workers
|
||||
|
||||
y = collective.allreduce(x, collective.Op.SUM)
|
||||
np.testing.assert_allclose(y, np.full_like(y, fill_value=float(n_workers)))
|
||||
|
||||
futures = client.map(local_test, range(len(workers)), workers=workers)
|
||||
results = client.gather(futures)
|
||||
|
||||
|
||||
@pytest.mark.skipif(**tm.no_dask())
|
||||
def test_allreduce() -> None:
|
||||
from distributed import Client, LocalCluster
|
||||
|
||||
n_workers = 4
|
||||
for i in range(2):
|
||||
with LocalCluster(n_workers=n_workers) as cluster:
|
||||
with Client(cluster) as client:
|
||||
for i in range(2):
|
||||
run_allreduce(client)
|
||||
|
||||
|
||||
def run_broadcast(client):
|
||||
from xgboost.dask import _get_dask_config, _get_rabit_args
|
||||
|
||||
@@ -109,6 +143,7 @@ def test_rabit_ops_ipv6():
|
||||
run_rabit_ops(client, n_workers)
|
||||
|
||||
|
||||
@pytest.mark.skipif(**tm.no_dask())
|
||||
def test_rank_assignment() -> None:
|
||||
from distributed import Client, LocalCluster
|
||||
|
||||
@@ -133,3 +168,107 @@ def test_rank_assignment() -> None:
|
||||
|
||||
futures = client.map(local_test, range(len(workers)), workers=workers)
|
||||
client.gather(futures)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def local_cluster():
|
||||
from distributed import LocalCluster
|
||||
|
||||
n_workers = 8
|
||||
with LocalCluster(n_workers=n_workers, dashboard_address=":0") as cluster:
|
||||
yield cluster
|
||||
|
||||
|
||||
ops_strategy = strategies.lists(
|
||||
strategies.sampled_from(["broadcast", "allreduce_max", "allreduce_sum"])
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(**tm.no_dask())
|
||||
@given(ops=ops_strategy, size=strategies.integers(2**4, 2**16))
|
||||
@settings(
|
||||
deadline=None,
|
||||
print_blob=True,
|
||||
max_examples=10,
|
||||
suppress_health_check=[HealthCheck.function_scoped_fixture],
|
||||
)
|
||||
def test_ops_restart_comm(local_cluster, ops, size) -> None:
|
||||
from distributed import Client
|
||||
|
||||
def local_test(w: int, n_workers: int) -> None:
|
||||
a = np.arange(0, n_workers)
|
||||
with xgb.dask.CommunicatorContext(**args):
|
||||
for op in ops:
|
||||
if op == "broadcast":
|
||||
b = collective.broadcast(a, root=1)
|
||||
np.testing.assert_allclose(b, a)
|
||||
elif op == "allreduce_max":
|
||||
b = collective.allreduce(a, collective.Op.MAX)
|
||||
np.testing.assert_allclose(b, a)
|
||||
elif op == "allreduce_sum":
|
||||
b = collective.allreduce(a, collective.Op.SUM)
|
||||
np.testing.assert_allclose(a * n_workers, b)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
with Client(local_cluster) as client:
|
||||
workers = tm.get_client_workers(client)
|
||||
args = client.sync(
|
||||
xgb.dask._get_rabit_args,
|
||||
len(workers),
|
||||
None,
|
||||
client,
|
||||
)
|
||||
|
||||
workers = tm.get_client_workers(client)
|
||||
n_workers = len(workers)
|
||||
|
||||
futures = client.map(
|
||||
local_test, range(len(workers)), workers=workers, n_workers=n_workers
|
||||
)
|
||||
client.gather(futures)
|
||||
|
||||
|
||||
@pytest.mark.skipif(**tm.no_dask())
|
||||
def test_ops_reuse_comm(local_cluster) -> None:
|
||||
from distributed import Client
|
||||
|
||||
rng = np.random.default_rng(1994)
|
||||
n_examples = 10
|
||||
ops = rng.choice(
|
||||
["broadcast", "allreduce_sum", "allreduce_max"], size=n_examples
|
||||
).tolist()
|
||||
|
||||
def local_test(w: int, n_workers: int) -> None:
|
||||
a = np.arange(0, n_workers)
|
||||
|
||||
with xgb.dask.CommunicatorContext(**args):
|
||||
for op in ops:
|
||||
if op == "broadcast":
|
||||
b = collective.broadcast(a, root=1)
|
||||
assert np.allclose(b, a)
|
||||
elif op == "allreduce_max":
|
||||
c = np.full_like(a, collective.get_rank())
|
||||
b = collective.allreduce(c, collective.Op.MAX)
|
||||
assert np.allclose(b, n_workers - 1), b
|
||||
elif op == "allreduce_sum":
|
||||
b = collective.allreduce(a, collective.Op.SUM)
|
||||
assert np.allclose(a * 8, b)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
with Client(local_cluster) as client:
|
||||
workers = tm.get_client_workers(client)
|
||||
args = client.sync(
|
||||
xgb.dask._get_rabit_args,
|
||||
len(workers),
|
||||
None,
|
||||
client,
|
||||
)
|
||||
|
||||
n_workers = len(workers)
|
||||
|
||||
futures = client.map(
|
||||
local_test, range(len(workers)), workers=workers, n_workers=n_workers
|
||||
)
|
||||
client.gather(futures)
|
||||
|
||||
@@ -8,19 +8,14 @@ import xgboost as xgb
|
||||
from xgboost import testing as tm
|
||||
from xgboost.core import DataSplitMode
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pyarrow.csv as pc
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
tm.no_arrow()["condition"] or tm.no_pandas()["condition"],
|
||||
reason=tm.no_arrow()["reason"] + " or " + tm.no_pandas()["reason"],
|
||||
)
|
||||
|
||||
dpath = "demo/data/"
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pyarrow.csv as pc
|
||||
|
||||
|
||||
class TestArrowTable:
|
||||
|
||||
@@ -1098,9 +1098,10 @@ def test_pandas_input():
|
||||
np.testing.assert_equal(model.feature_names_in_, np.array(feature_names))
|
||||
|
||||
columns = list(train.columns)
|
||||
random.shuffle(columns)
|
||||
rng.shuffle(columns)
|
||||
df_incorrect = df[columns]
|
||||
with pytest.raises(ValueError):
|
||||
|
||||
with pytest.raises(ValueError, match="feature_names mismatch"):
|
||||
model.predict(df_incorrect)
|
||||
|
||||
clf_isotonic = CalibratedClassifierCV(model, cv="prefit", method="isotonic")
|
||||
|
||||
Reference in New Issue
Block a user