[Breaking] Switch from rabit to the collective communicator (#8257)
* Switch from rabit to the collective communicator
* fix size_t specialization
* really fix size_t
* try again
* add include
* more include
* fix lint errors
* remove rabit includes
* fix pylint error
* return dict from communicator context
* fix communicator shutdown
* fix dask test
* reset communicator mocklist
* fix distributed tests
* do not save device communicator
* fix jvm gpu tests
* add python test for federated communicator
* Update gputreeshap submodule

Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
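A minimal sketch of the new usage pattern, based on the updated test in the diff below. The RabitTracker import path is an assumption (the hunk uses RabitTracker without showing its import); everything else mirrors the test code.

    import xgboost as xgb
    from xgboost.tracker import RabitTracker  # import path assumed, not shown in the hunk

    # Start a tracker for a single worker, as in the updated test.
    tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1)
    tracker.start(1)

    # worker_envs() returns a plain dict of tracker settings, which is passed
    # straight to CommunicatorContext instead of being encoded into a rabit_env list.
    with xgb.collective.CommunicatorContext(**tracker.worker_envs()):
        # Broadcast from rank 0; with a single worker this simply echoes the payload.
        result = xgb.collective.broadcast("test1234", 0)
        assert str(result) == "test1234"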
@@ -15,37 +15,33 @@ if sys.platform.startswith("win"):
 def test_rabit_tracker():
     tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1)
     tracker.start(1)
-    worker_env = tracker.worker_envs()
-    rabit_env = []
-    for k, v in worker_env.items():
-        rabit_env.append(f"{k}={v}".encode())
-    with xgb.rabit.RabitContext(rabit_env):
-        ret = xgb.rabit.broadcast("test1234", 0)
+    with xgb.collective.CommunicatorContext(**tracker.worker_envs()):
+        ret = xgb.collective.broadcast("test1234", 0)
         assert str(ret) == "test1234"
 
 
 def run_rabit_ops(client, n_workers):
     from test_with_dask import _get_client_workers
-    from xgboost.dask import RabitContext, _get_dask_config, _get_rabit_args
+    from xgboost.dask import CommunicatorContext, _get_dask_config, _get_rabit_args
 
-    from xgboost import rabit
+    from xgboost import collective
 
     workers = _get_client_workers(client)
     rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client)
-    assert not rabit.is_distributed()
+    assert not collective.is_distributed()
     n_workers_from_dask = len(workers)
     assert n_workers == n_workers_from_dask
 
     def local_test(worker_id):
-        with RabitContext(rabit_args):
+        with CommunicatorContext(**rabit_args):
             a = 1
-            assert rabit.is_distributed()
+            assert collective.is_distributed()
             a = np.array([a])
-            reduced = rabit.allreduce(a, rabit.Op.SUM)
+            reduced = collective.allreduce(a, collective.Op.SUM)
             assert reduced[0] == n_workers
 
             worker_id = np.array([worker_id])
-            reduced = rabit.allreduce(worker_id, rabit.Op.MAX)
+            reduced = collective.allreduce(worker_id, collective.Op.MAX)
             assert reduced == n_workers - 1
 
             return 1
@@ -83,14 +79,10 @@ def test_rank_assignment() -> None:
     from test_with_dask import _get_client_workers
 
     def local_test(worker_id):
-        with xgb.dask.RabitContext(args):
-            for val in args:
-                sval = val.decode("utf-8")
-                if sval.startswith("DMLC_TASK_ID"):
-                    task_id = sval
-                    break
+        with xgb.dask.CommunicatorContext(**args) as ctx:
+            task_id = ctx["DMLC_TASK_ID"]
             matched = re.search(".*-([0-9]).*", task_id)
-            rank = xgb.rabit.get_rank()
+            rank = xgb.collective.get_rank()
             # As long as the number of workers is lesser than 10, rank and worker id
             # should be the same
             assert rank == int(matched.group(1))
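This hunk also reflects the "return dict from communicator context" item in the commit message: entering xgb.dask.CommunicatorContext yields the worker arguments as a dict, so tracker-provided values can be read directly. A minimal sketch of that pattern, assuming args is the dict of collective args gathered from the Dask client (as in run_rabit_ops above):

    import xgboost as xgb

    def local_test(args):
        # Entering the context yields the worker args as a dict, so values such as
        # DMLC_TASK_ID can be looked up directly instead of decoding an encoded
        # argument list as the old RabitContext path did.
        with xgb.dask.CommunicatorContext(**args) as ctx:
            task_id = ctx["DMLC_TASK_ID"]
            rank = xgb.collective.get_rank()
            return task_id, rank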