xgboost/tests/python/test_tracker.py

import re
import sys

import numpy as np
import pytest
import testing as tm

import xgboost as xgb
from xgboost import RabitTracker, testing

if sys.platform.startswith("win"):
    pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
def test_rabit_tracker():
    tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1)
    tracker.start(1)
    with xgb.collective.CommunicatorContext(**tracker.worker_envs()):
        ret = xgb.collective.broadcast("test1234", 0)
        assert str(ret) == "test1234"
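

# A minimal illustrative sketch, assuming the same single-worker tracker
# handshake as above: with a world size of one, an allreduce is a no-op, so the
# input comes back unchanged. It reuses only APIs already exercised in this
# file (collective.allreduce and collective.Op).
def test_single_worker_allreduce():
    tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1)
    tracker.start(1)
    with xgb.collective.CommunicatorContext(**tracker.worker_envs()):
        reduced = xgb.collective.allreduce(np.array([2.0]), xgb.collective.Op.SUM)
        assert reduced[0] == 2.0


# Run a SUM and a MAX allreduce on every Dask worker and check the results
# against the worker count; each worker returns 1 so the caller can confirm
# that every worker actually participated.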
def run_rabit_ops(client, n_workers):
    from test_with_dask import _get_client_workers

    from xgboost import collective
    from xgboost.dask import CommunicatorContext, _get_dask_config, _get_rabit_args

    workers = _get_client_workers(client)
    rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client)
    assert not collective.is_distributed()
    n_workers_from_dask = len(workers)
    assert n_workers == n_workers_from_dask

    def local_test(worker_id):
        with CommunicatorContext(**rabit_args):
            a = 1
            assert collective.is_distributed()
            a = np.array([a])
            reduced = collective.allreduce(a, collective.Op.SUM)
            assert reduced[0] == n_workers

            worker_id = np.array([worker_id])
            reduced = collective.allreduce(worker_id, collective.Op.MAX)
            assert reduced == n_workers - 1
            return 1

    futures = client.map(local_test, range(len(workers)), workers=workers)
    results = client.gather(futures)
    assert sum(results) == n_workers
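

# Spin up a local Dask cluster and exercise the collective ops end to end
# through the Dask integration.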
@pytest.mark.skipif(**tm.no_dask())
def test_rabit_ops():
    from distributed import Client, LocalCluster

    n_workers = 3
    with LocalCluster(n_workers=n_workers) as cluster:
        with Client(cluster) as client:
            run_rabit_ops(client, n_workers)
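

# Same as test_rabit_ops, but with the scheduler and workers bound to the IPv6
# loopback address; skipped where IPv6 is unavailable.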
@pytest.mark.skipif(**testing.skip_ipv6())
@pytest.mark.skipif(**tm.no_dask())
def test_rabit_ops_ipv6():
    import dask
    from distributed import Client, LocalCluster

    n_workers = 3
    with dask.config.set({"xgboost.scheduler_address": "[::1]"}):
        with LocalCluster(n_workers=n_workers, host="[::1]") as cluster:
            with Client(cluster) as client:
                run_rabit_ops(client, n_workers)
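

# Each worker's assigned rank should match the index embedded in its
# DMLC_TASK_ID, i.e. rank assignment is deterministic in worker order.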
def test_rank_assignment() -> None:
    from distributed import Client, LocalCluster
    from test_with_dask import _get_client_workers

    def local_test(worker_id):
        with xgb.dask.CommunicatorContext(**args) as ctx:
            task_id = ctx["DMLC_TASK_ID"]
            matched = re.search(".*-([0-9]).*", task_id)
            rank = xgb.collective.get_rank()

            # As long as the number of workers is less than 10, rank and worker
            # id should be the same.
            assert rank == int(matched.group(1))

    with LocalCluster(n_workers=8) as cluster:
        with Client(cluster) as client:
            workers = _get_client_workers(client)
            args = client.sync(
                xgb.dask._get_rabit_args,
                len(workers),
                None,
                client,
            )

            futures = client.map(local_test, range(len(workers)), workers=workers)
            client.gather(futures)