[Breaking] Switch from rabit to the collective communicator (#8257)
* Switch from rabit to the collective communicator * fix size_t specialization * really fix size_t * try again * add include * more include * fix lint errors * remove rabit includes * fix pylint error * return dict from communicator context * fix communicator shutdown * fix dask test * reset communicator mocklist * fix distributed tests * do not save device communicator * fix jvm gpu tests * add python test for federated communicator * Update gputreeshap submodule Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -1,13 +1,13 @@
|
||||
import multiprocessing
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import xgboost as xgb
|
||||
from xgboost import RabitTracker
|
||||
from xgboost import collective
|
||||
from xgboost import RabitTracker, build_info, federated
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
pytest.skip("Skipping collective tests on Windows", allow_module_level=True)
|
||||
@@ -37,3 +37,41 @@ def test_rabit_communicator():
|
||||
for worker in workers:
|
||||
worker.join()
|
||||
assert worker.exitcode == 0
|
||||
|
||||
|
||||
def run_federated_worker(port, world_size, rank):
|
||||
with xgb.collective.CommunicatorContext(xgboost_communicator='federated',
|
||||
federated_server_address=f'localhost:{port}',
|
||||
federated_world_size=world_size,
|
||||
federated_rank=rank):
|
||||
assert xgb.collective.get_world_size() == world_size
|
||||
assert xgb.collective.is_distributed()
|
||||
assert xgb.collective.get_processor_name() == f'rank{rank}'
|
||||
ret = xgb.collective.broadcast('test1234', 0)
|
||||
assert str(ret) == 'test1234'
|
||||
ret = xgb.collective.allreduce(np.asarray([1, 2, 3]), xgb.collective.Op.SUM)
|
||||
assert np.array_equal(ret, np.asarray([2, 4, 6]))
|
||||
|
||||
|
||||
def test_federated_communicator():
|
||||
if not build_info()["USE_FEDERATED"]:
|
||||
pytest.skip("XGBoost not built with federated learning enabled")
|
||||
|
||||
port = 9091
|
||||
world_size = 2
|
||||
server = multiprocessing.Process(target=xgb.federated.run_federated_server, args=(port, world_size))
|
||||
server.start()
|
||||
time.sleep(1)
|
||||
if not server.is_alive():
|
||||
raise Exception("Error starting Federated Learning server")
|
||||
|
||||
workers = []
|
||||
for rank in range(world_size):
|
||||
worker = multiprocessing.Process(target=run_federated_worker,
|
||||
args=(port, world_size, rank))
|
||||
workers.append(worker)
|
||||
worker.start()
|
||||
for worker in workers:
|
||||
worker.join()
|
||||
assert worker.exitcode == 0
|
||||
server.terminate()
|
||||
|
||||
@@ -15,37 +15,33 @@ if sys.platform.startswith("win"):
|
||||
def test_rabit_tracker():
|
||||
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1)
|
||||
tracker.start(1)
|
||||
worker_env = tracker.worker_envs()
|
||||
rabit_env = []
|
||||
for k, v in worker_env.items():
|
||||
rabit_env.append(f"{k}={v}".encode())
|
||||
with xgb.rabit.RabitContext(rabit_env):
|
||||
ret = xgb.rabit.broadcast("test1234", 0)
|
||||
with xgb.collective.CommunicatorContext(**tracker.worker_envs()):
|
||||
ret = xgb.collective.broadcast("test1234", 0)
|
||||
assert str(ret) == "test1234"
|
||||
|
||||
|
||||
def run_rabit_ops(client, n_workers):
|
||||
from test_with_dask import _get_client_workers
|
||||
from xgboost.dask import RabitContext, _get_dask_config, _get_rabit_args
|
||||
from xgboost.dask import CommunicatorContext, _get_dask_config, _get_rabit_args
|
||||
|
||||
from xgboost import rabit
|
||||
from xgboost import collective
|
||||
|
||||
workers = _get_client_workers(client)
|
||||
rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client)
|
||||
assert not rabit.is_distributed()
|
||||
assert not collective.is_distributed()
|
||||
n_workers_from_dask = len(workers)
|
||||
assert n_workers == n_workers_from_dask
|
||||
|
||||
def local_test(worker_id):
|
||||
with RabitContext(rabit_args):
|
||||
with CommunicatorContext(**rabit_args):
|
||||
a = 1
|
||||
assert rabit.is_distributed()
|
||||
assert collective.is_distributed()
|
||||
a = np.array([a])
|
||||
reduced = rabit.allreduce(a, rabit.Op.SUM)
|
||||
reduced = collective.allreduce(a, collective.Op.SUM)
|
||||
assert reduced[0] == n_workers
|
||||
|
||||
worker_id = np.array([worker_id])
|
||||
reduced = rabit.allreduce(worker_id, rabit.Op.MAX)
|
||||
reduced = collective.allreduce(worker_id, collective.Op.MAX)
|
||||
assert reduced == n_workers - 1
|
||||
|
||||
return 1
|
||||
@@ -83,14 +79,10 @@ def test_rank_assignment() -> None:
|
||||
from test_with_dask import _get_client_workers
|
||||
|
||||
def local_test(worker_id):
|
||||
with xgb.dask.RabitContext(args):
|
||||
for val in args:
|
||||
sval = val.decode("utf-8")
|
||||
if sval.startswith("DMLC_TASK_ID"):
|
||||
task_id = sval
|
||||
break
|
||||
with xgb.dask.CommunicatorContext(**args) as ctx:
|
||||
task_id = ctx["DMLC_TASK_ID"]
|
||||
matched = re.search(".*-([0-9]).*", task_id)
|
||||
rank = xgb.rabit.get_rank()
|
||||
rank = xgb.collective.get_rank()
|
||||
# As long as the number of workers is lesser than 10, rank and worker id
|
||||
# should be the same
|
||||
assert rank == int(matched.group(1))
|
||||
|
||||
@@ -1267,17 +1267,17 @@ def test_dask_iteration_range(client: "Client"):
|
||||
|
||||
class TestWithDask:
|
||||
def test_dmatrix_binary(self, client: "Client") -> None:
|
||||
def save_dmatrix(rabit_args: List[bytes], tmpdir: str) -> None:
|
||||
with xgb.dask.RabitContext(rabit_args):
|
||||
rank = xgb.rabit.get_rank()
|
||||
def save_dmatrix(rabit_args: Dict[str, Union[int, str]], tmpdir: str) -> None:
|
||||
with xgb.dask.CommunicatorContext(**rabit_args):
|
||||
rank = xgb.collective.get_rank()
|
||||
X, y = tm.make_categorical(100, 4, 4, False)
|
||||
Xy = xgb.DMatrix(X, y, enable_categorical=True)
|
||||
path = os.path.join(tmpdir, f"{rank}.bin")
|
||||
Xy.save_binary(path)
|
||||
|
||||
def load_dmatrix(rabit_args: List[bytes], tmpdir: str) -> None:
|
||||
with xgb.dask.RabitContext(rabit_args):
|
||||
rank = xgb.rabit.get_rank()
|
||||
def load_dmatrix(rabit_args: Dict[str, Union[int,str]], tmpdir: str) -> None:
|
||||
with xgb.dask.CommunicatorContext(**rabit_args):
|
||||
rank = xgb.collective.get_rank()
|
||||
path = os.path.join(tmpdir, f"{rank}.bin")
|
||||
Xy = xgb.DMatrix(path)
|
||||
assert Xy.num_row() == 100
|
||||
@@ -1488,20 +1488,13 @@ class TestWithDask:
|
||||
test = "--gtest_filter=Quantile." + name
|
||||
|
||||
def runit(
|
||||
worker_addr: str, rabit_args: List[bytes]
|
||||
worker_addr: str, rabit_args: Dict[str, Union[int, str]]
|
||||
) -> subprocess.CompletedProcess:
|
||||
port_env = ''
|
||||
# setup environment for running the c++ part.
|
||||
for arg in rabit_args:
|
||||
if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'):
|
||||
port_env = arg.decode('utf-8')
|
||||
if arg.decode("utf-8").startswith("DMLC_TRACKER_URI"):
|
||||
uri_env = arg.decode("utf-8")
|
||||
port = port_env.split('=')
|
||||
env = os.environ.copy()
|
||||
env[port[0]] = port[1]
|
||||
uri = uri_env.split("=")
|
||||
env["DMLC_TRACKER_URI"] = uri[1]
|
||||
env['DMLC_TRACKER_PORT'] = str(rabit_args['DMLC_TRACKER_PORT'])
|
||||
env["DMLC_TRACKER_URI"] = str(rabit_args["DMLC_TRACKER_URI"])
|
||||
return subprocess.run([str(exe), test], env=env, capture_output=True)
|
||||
|
||||
with LocalCluster(n_workers=4, dashboard_address=":0") as cluster:
|
||||
@@ -1543,8 +1536,8 @@ class TestWithDask:
|
||||
def get_score(config: Dict) -> float:
|
||||
return float(config["learner"]["learner_model_param"]["base_score"])
|
||||
|
||||
def local_test(rabit_args: List[bytes], worker_id: int) -> bool:
|
||||
with xgb.dask.RabitContext(rabit_args):
|
||||
def local_test(rabit_args: Dict[str, Union[int, str]], worker_id: int) -> bool:
|
||||
with xgb.dask.CommunicatorContext(**rabit_args):
|
||||
if worker_id == 0:
|
||||
y = np.array([0.0, 0.0, 0.0])
|
||||
x = np.array([[0.0]] * 3)
|
||||
@@ -1686,12 +1679,12 @@ class TestWithDask:
|
||||
n_workers = len(workers)
|
||||
|
||||
def worker_fn(worker_addr: str, data_ref: Dict) -> None:
|
||||
with xgb.dask.RabitContext(rabit_args):
|
||||
with xgb.dask.CommunicatorContext(**rabit_args):
|
||||
local_dtrain = xgb.dask._dmatrix_from_list_of_parts(
|
||||
**data_ref, nthread=7
|
||||
)
|
||||
total = np.array([local_dtrain.num_row()])
|
||||
total = xgb.rabit.allreduce(total, xgb.rabit.Op.SUM)
|
||||
total = xgb.collective.allreduce(total, xgb.collective.Op.SUM)
|
||||
assert total[0] == kRows
|
||||
|
||||
futures = []
|
||||
|
||||
Reference in New Issue
Block a user