[Breaking] Switch from rabit to the collective communicator (#8257)

* Switch from rabit to the collective communicator

* fix size_t specialization

* really fix size_t

* try again

* add include

* more include

* fix lint errors

* remove rabit includes

* fix pylint error

* return dict from communicator context

* fix communicator shutdown

* fix dask test

* reset communicator mocklist

* fix distributed tests

* do not save device communicator

* fix jvm gpu tests

* add python test for federated communicator

* Update gputreeshap submodule

Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Rong Ou
2022-10-05 15:39:01 -07:00
committed by GitHub
parent e47b3a3da3
commit 668b8a0ea4
79 changed files with 805 additions and 2212 deletions

View File

@@ -1,13 +1,13 @@
import multiprocessing
import socket
import sys
import time
import numpy as np
import pytest
import xgboost as xgb
from xgboost import RabitTracker
from xgboost import collective
from xgboost import RabitTracker, build_info, federated
if sys.platform.startswith("win"):
pytest.skip("Skipping collective tests on Windows", allow_module_level=True)
@@ -37,3 +37,41 @@ def test_rabit_communicator():
for worker in workers:
worker.join()
assert worker.exitcode == 0
def run_federated_worker(port, world_size, rank):
    """Exercise the federated communicator from one worker process.

    Enters a ``CommunicatorContext`` configured with the ``federated``
    communicator and the server address ``localhost:{port}``, then checks the
    reported world size, the distributed flag and the processor name, and
    runs one broadcast and one SUM allreduce.

    Parameters
    ----------
    port :
        Port of the federated server on localhost.
    world_size :
        Total number of workers taking part in the run.
    rank :
        Rank assigned to this worker.
    """
    with xgb.collective.CommunicatorContext(
        xgboost_communicator='federated',
        federated_server_address=f'localhost:{port}',
        federated_world_size=world_size,
        federated_rank=rank,
    ):
        assert xgb.collective.get_world_size() == world_size
        assert xgb.collective.is_distributed()
        assert xgb.collective.get_processor_name() == f'rank{rank}'

        ret = xgb.collective.broadcast('test1234', 0)
        assert str(ret) == 'test1234'

        # Every rank contributes the same vector, so the SUM reduction is that
        # vector scaled by the number of ranks. Computed before the call so the
        # expectation does not depend on whether allreduce touches its input.
        expected = np.asarray([1, 2, 3]) * world_size
        ret = xgb.collective.allreduce(np.asarray([1, 2, 3]), xgb.collective.Op.SUM)
        assert np.array_equal(ret, expected)
def test_federated_communicator():
    """Run the federated worker checks against a locally spawned server.

    Skips when ``build_info()["USE_FEDERATED"]`` is falsy. Otherwise starts
    ``xgb.federated.run_federated_server`` in a child process, launches one
    ``run_federated_worker`` process per rank, and requires every worker to
    exit with code 0.

    Raises
    ------
    Exception
        If the server process is no longer alive one second after start.
    """
    if not build_info()["USE_FEDERATED"]:
        pytest.skip("XGBoost not built with federated learning enabled")

    port = 9091
    world_size = 2
    server = multiprocessing.Process(
        target=xgb.federated.run_federated_server, args=(port, world_size)
    )
    server.start()
    workers = []
    try:
        # Give the server a moment to come up before the workers connect.
        time.sleep(1)
        if not server.is_alive():
            raise Exception("Error starting Federated Learning server")

        for rank in range(world_size):
            worker = multiprocessing.Process(
                target=run_federated_worker, args=(port, world_size, rank)
            )
            worker.start()
            # Only track processes that actually started, so cleanup can
            # safely join every entry.
            workers.append(worker)
        for worker in workers:
            worker.join()
            assert worker.exitcode == 0
    finally:
        # Always reap the children: a failed assertion above must not leave
        # the server (or a still-running worker) behind.
        for proc in (*workers, server):
            if proc.is_alive():
                proc.terminate()
            proc.join()

View File

@@ -15,37 +15,33 @@ if sys.platform.startswith("win"):
def test_rabit_tracker():
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1)
tracker.start(1)
worker_env = tracker.worker_envs()
rabit_env = []
for k, v in worker_env.items():
rabit_env.append(f"{k}={v}".encode())
with xgb.rabit.RabitContext(rabit_env):
ret = xgb.rabit.broadcast("test1234", 0)
with xgb.collective.CommunicatorContext(**tracker.worker_envs()):
ret = xgb.collective.broadcast("test1234", 0)
assert str(ret) == "test1234"
def run_rabit_ops(client, n_workers):
from test_with_dask import _get_client_workers
from xgboost.dask import RabitContext, _get_dask_config, _get_rabit_args
from xgboost.dask import CommunicatorContext, _get_dask_config, _get_rabit_args
from xgboost import rabit
from xgboost import collective
workers = _get_client_workers(client)
rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client)
assert not rabit.is_distributed()
assert not collective.is_distributed()
n_workers_from_dask = len(workers)
assert n_workers == n_workers_from_dask
def local_test(worker_id):
with RabitContext(rabit_args):
with CommunicatorContext(**rabit_args):
a = 1
assert rabit.is_distributed()
assert collective.is_distributed()
a = np.array([a])
reduced = rabit.allreduce(a, rabit.Op.SUM)
reduced = collective.allreduce(a, collective.Op.SUM)
assert reduced[0] == n_workers
worker_id = np.array([worker_id])
reduced = rabit.allreduce(worker_id, rabit.Op.MAX)
reduced = collective.allreduce(worker_id, collective.Op.MAX)
assert reduced == n_workers - 1
return 1
@@ -83,14 +79,10 @@ def test_rank_assignment() -> None:
from test_with_dask import _get_client_workers
def local_test(worker_id):
with xgb.dask.RabitContext(args):
for val in args:
sval = val.decode("utf-8")
if sval.startswith("DMLC_TASK_ID"):
task_id = sval
break
with xgb.dask.CommunicatorContext(**args) as ctx:
task_id = ctx["DMLC_TASK_ID"]
matched = re.search(".*-([0-9]).*", task_id)
rank = xgb.rabit.get_rank()
rank = xgb.collective.get_rank()
# As long as the number of workers is less than 10, rank and worker id
# should be the same
assert rank == int(matched.group(1))

View File

@@ -1267,17 +1267,17 @@ def test_dask_iteration_range(client: "Client"):
class TestWithDask:
def test_dmatrix_binary(self, client: "Client") -> None:
def save_dmatrix(rabit_args: List[bytes], tmpdir: str) -> None:
with xgb.dask.RabitContext(rabit_args):
rank = xgb.rabit.get_rank()
def save_dmatrix(rabit_args: Dict[str, Union[int, str]], tmpdir: str) -> None:
with xgb.dask.CommunicatorContext(**rabit_args):
rank = xgb.collective.get_rank()
X, y = tm.make_categorical(100, 4, 4, False)
Xy = xgb.DMatrix(X, y, enable_categorical=True)
path = os.path.join(tmpdir, f"{rank}.bin")
Xy.save_binary(path)
def load_dmatrix(rabit_args: List[bytes], tmpdir: str) -> None:
with xgb.dask.RabitContext(rabit_args):
rank = xgb.rabit.get_rank()
def load_dmatrix(rabit_args: Dict[str, Union[int,str]], tmpdir: str) -> None:
with xgb.dask.CommunicatorContext(**rabit_args):
rank = xgb.collective.get_rank()
path = os.path.join(tmpdir, f"{rank}.bin")
Xy = xgb.DMatrix(path)
assert Xy.num_row() == 100
@@ -1488,20 +1488,13 @@ class TestWithDask:
test = "--gtest_filter=Quantile." + name
def runit(
worker_addr: str, rabit_args: List[bytes]
worker_addr: str, rabit_args: Dict[str, Union[int, str]]
) -> subprocess.CompletedProcess:
port_env = ''
# setup environment for running the c++ part.
for arg in rabit_args:
if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'):
port_env = arg.decode('utf-8')
if arg.decode("utf-8").startswith("DMLC_TRACKER_URI"):
uri_env = arg.decode("utf-8")
port = port_env.split('=')
env = os.environ.copy()
env[port[0]] = port[1]
uri = uri_env.split("=")
env["DMLC_TRACKER_URI"] = uri[1]
env['DMLC_TRACKER_PORT'] = str(rabit_args['DMLC_TRACKER_PORT'])
env["DMLC_TRACKER_URI"] = str(rabit_args["DMLC_TRACKER_URI"])
return subprocess.run([str(exe), test], env=env, capture_output=True)
with LocalCluster(n_workers=4, dashboard_address=":0") as cluster:
@@ -1543,8 +1536,8 @@ class TestWithDask:
def get_score(config: Dict) -> float:
return float(config["learner"]["learner_model_param"]["base_score"])
def local_test(rabit_args: List[bytes], worker_id: int) -> bool:
with xgb.dask.RabitContext(rabit_args):
def local_test(rabit_args: Dict[str, Union[int, str]], worker_id: int) -> bool:
with xgb.dask.CommunicatorContext(**rabit_args):
if worker_id == 0:
y = np.array([0.0, 0.0, 0.0])
x = np.array([[0.0]] * 3)
@@ -1686,12 +1679,12 @@ class TestWithDask:
n_workers = len(workers)
def worker_fn(worker_addr: str, data_ref: Dict) -> None:
with xgb.dask.RabitContext(rabit_args):
with xgb.dask.CommunicatorContext(**rabit_args):
local_dtrain = xgb.dask._dmatrix_from_list_of_parts(
**data_ref, nthread=7
)
total = np.array([local_dtrain.num_row()])
total = xgb.rabit.allreduce(total, xgb.rabit.Op.SUM)
total = xgb.collective.allreduce(total, xgb.collective.Op.SUM)
assert total[0] == kRows
futures = []