xgboost/tests/python/test_tracker.py


import re
import sys

import numpy as np
import pytest

import xgboost as xgb
from xgboost import RabitTracker
from xgboost import testing as tm
from xgboost import collective

if sys.platform.startswith("win"):
    pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
def test_rabit_tracker():
    tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1)
    tracker.start(1)
    with xgb.collective.CommunicatorContext(**tracker.worker_envs()):
        ret = xgb.collective.broadcast("test1234", 0)
        assert str(ret) == "test1234"
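

# Helper shared by the Dask tests below: every worker joins a communicator
# context and runs a SUM and a MAX allreduce across the cluster.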
def run_rabit_ops(client, n_workers):
    from xgboost.dask import CommunicatorContext, _get_dask_config, _get_rabit_args

    workers = tm.get_client_workers(client)
    rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client)
    assert not collective.is_distributed()
    n_workers_from_dask = len(workers)
    assert n_workers == n_workers_from_dask

    def local_test(worker_id):
        with CommunicatorContext(**rabit_args):
            a = 1
            assert collective.is_distributed()
            a = np.array([a])
            reduced = collective.allreduce(a, collective.Op.SUM)
            assert reduced[0] == n_workers

            worker_id = np.array([worker_id])
            reduced = collective.allreduce(worker_id, collective.Op.MAX)
            assert reduced == n_workers - 1

        return 1

    futures = client.map(local_test, range(len(workers)), workers=workers)
    results = client.gather(futures)
    assert sum(results) == n_workers
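

# Exercise the collective ops on a local three-worker Dask cluster.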
@pytest.mark.skipif(**tm.no_dask())
def test_rabit_ops():
    from distributed import Client, LocalCluster

    n_workers = 3
    with LocalCluster(n_workers=n_workers) as cluster:
        with Client(cluster) as client:
            run_rabit_ops(client, n_workers)
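

# Same as above, but with the cluster bound to the IPv6 loopback address.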
@pytest.mark.skipif(**tm.no_ipv6())
@pytest.mark.skipif(**tm.no_dask())
def test_rabit_ops_ipv6():
    import dask
    from distributed import Client, LocalCluster

    n_workers = 3
    with dask.config.set({"xgboost.scheduler_address": "[::1]"}):
        with LocalCluster(n_workers=n_workers, host="[::1]") as cluster:
            with Client(cluster) as client:
                run_rabit_ops(client, n_workers)
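

# Check that each worker's collective rank matches the worker index encoded in
# its DMLC_TASK_ID.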
def test_rank_assignment() -> None:
    from distributed import Client, LocalCluster

    def local_test(worker_id):
        with xgb.dask.CommunicatorContext(**args) as ctx:
            task_id = ctx["DMLC_TASK_ID"]
            matched = re.search(".*-([0-9]).*", task_id)
            rank = xgb.collective.get_rank()
            # As long as the number of workers is less than 10, rank and worker id
            # should be the same.
            assert rank == int(matched.group(1))

    with LocalCluster(n_workers=8) as cluster:
        with Client(cluster) as client:
            workers = tm.get_client_workers(client)
            args = client.sync(
                xgb.dask._get_rabit_args,
                len(workers),
                None,
                client,
            )
            futures = client.map(local_test, range(len(workers)), workers=workers)
            client.gather(futures)