[coll] Use loky for tests. (#10676)

This makes the tests easier to run and debug. In addition, they can now work on Windows as
well.
This commit is contained in:
Jiaming Yuan 2024-08-03 07:33:42 +08:00 committed by GitHub
parent a185b693dc
commit a269055b2b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,5 +1,7 @@
import re import re
import sys import sys
from functools import partial, update_wrapper
from typing import Dict, Union
import numpy as np import numpy as np
import pytest import pytest
@ -13,8 +15,8 @@ from xgboost import testing as tm
def test_rabit_tracker(): def test_rabit_tracker():
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1) tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1)
tracker.start() tracker.start()
with xgb.collective.CommunicatorContext(**tracker.worker_args()): with collective.CommunicatorContext(**tracker.worker_args()):
ret = xgb.collective.broadcast("test1234", 0) ret = collective.broadcast("test1234", 0)
assert str(ret) == "test1234" assert str(ret) == "test1234"
@ -26,7 +28,7 @@ def test_socket_error():
env["dmlc_tracker_port"] = 0 env["dmlc_tracker_port"] = 0
env["dmlc_retry"] = 1 env["dmlc_retry"] = 1
with pytest.raises(ValueError, match="Failed to bootstrap the communication."): with pytest.raises(ValueError, match="Failed to bootstrap the communication."):
with xgb.collective.CommunicatorContext(**env): with collective.CommunicatorContext(**env):
pass pass
with pytest.raises(ValueError): with pytest.raises(ValueError):
tracker.free() tracker.free()
@ -70,16 +72,15 @@ def test_rabit_ops():
run_rabit_ops(client, n_workers) run_rabit_ops(client, n_workers)
def run_allreduce(client) -> None:
from xgboost.dask import CommunicatorContext, _get_dask_config, _get_rabit_args
workers = tm.get_client_workers(client) def run_allreduce(pool, n_workers: int) -> None:
rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client) tracker = RabitTracker(host_ip="127.0.0.1", n_workers=n_workers)
n_workers = len(workers) tracker.start()
args = tracker.worker_args()
def local_test(worker_id: int) -> None: def local_test(worker_id: int, rabit_args: Dict[str, Union[str, int]]) -> None:
x = np.full(shape=(1024 * 1024 * 32), fill_value=1.0) x = np.full(shape=(1024 * 1024 * 32), fill_value=1.0)
with CommunicatorContext(**rabit_args): with collective.CommunicatorContext(**rabit_args):
k = np.asarray([1.0]) k = np.asarray([1.0])
for i in range(128): for i in range(128):
m = collective.allreduce(k, collective.Op.SUM) m = collective.allreduce(k, collective.Op.SUM)
@ -88,46 +89,48 @@ def run_allreduce(client) -> None:
y = collective.allreduce(x, collective.Op.SUM) y = collective.allreduce(x, collective.Op.SUM)
np.testing.assert_allclose(y, np.full_like(y, fill_value=float(n_workers))) np.testing.assert_allclose(y, np.full_like(y, fill_value=float(n_workers)))
futures = client.map(local_test, range(len(workers)), workers=workers) fn = update_wrapper(partial(local_test, rabit_args=args), local_test)
results = client.gather(futures) results = pool.map(fn, range(n_workers))
for r in results:
assert r is None
@pytest.mark.skipif(**tm.no_dask()) @pytest.mark.skipif(**tm.no_loky())
def test_allreduce() -> None: def test_allreduce() -> None:
from distributed import Client, LocalCluster from loky import get_reusable_executor
n_workers = 4 n_workers = 4
for i in range(2): n_trials = 2
with LocalCluster(n_workers=n_workers) as cluster: for _ in range(n_trials):
with Client(cluster) as client: with get_reusable_executor(max_workers=n_workers) as pool:
for i in range(2): run_allreduce(pool, n_workers)
run_allreduce(client)
def run_broadcast(client): def run_broadcast(pool, n_workers: int) -> None:
from xgboost.dask import _get_dask_config, _get_rabit_args tracker = RabitTracker(host_ip="127.0.0.1", n_workers=n_workers)
tracker.start()
args = tracker.worker_args()
workers = tm.get_client_workers(client) def local_test(worker_id: int, rabit_args: Dict[str, Union[str, int]]):
rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client)
def local_test(worker_id):
with collective.CommunicatorContext(**rabit_args): with collective.CommunicatorContext(**rabit_args):
res = collective.broadcast(17, 0) res = collective.broadcast(17, 0)
return res return res
futures = client.map(local_test, range(len(workers)), workers=workers) fn = update_wrapper(partial(local_test, rabit_args=args), local_test)
results = client.gather(futures) results = pool.map(fn, range(n_workers))
np.testing.assert_allclose(np.array(results), 17) np.testing.assert_allclose(np.array(list(results)), 17)
@pytest.mark.skipif(**tm.no_dask()) @pytest.mark.skipif(**tm.no_loky())
def test_broadcast(): def test_broadcast():
from distributed import Client, LocalCluster from loky import get_reusable_executor
n_workers = 3 n_workers = 4
with LocalCluster(n_workers=n_workers) as cluster: n_trials = 2
with Client(cluster) as client:
run_broadcast(client) for _ in range(n_trials):
with get_reusable_executor(max_workers=n_workers) as pool:
run_broadcast(pool, n_workers)
@pytest.mark.skipif(**tm.no_ipv6()) @pytest.mark.skipif(**tm.no_ipv6())
@ -151,7 +154,7 @@ def test_rank_assignment() -> None:
with xgb.dask.CommunicatorContext(**args) as ctx: with xgb.dask.CommunicatorContext(**args) as ctx:
task_id = ctx["DMLC_TASK_ID"] task_id = ctx["DMLC_TASK_ID"]
matched = re.search(".*-([0-9]).*", task_id) matched = re.search(".*-([0-9]).*", task_id)
rank = xgb.collective.get_rank() rank = collective.get_rank()
# As long as the number of workers is lesser than 10, rank and worker id # As long as the number of workers is lesser than 10, rank and worker id
# should be the same # should be the same
assert rank == int(matched.group(1)) assert rank == int(matched.group(1))
@ -170,21 +173,12 @@ def test_rank_assignment() -> None:
client.gather(futures) client.gather(futures)
@pytest.fixture
def local_cluster():
from distributed import LocalCluster
n_workers = 8
with LocalCluster(n_workers=n_workers, dashboard_address=":0") as cluster:
yield cluster
ops_strategy = strategies.lists( ops_strategy = strategies.lists(
strategies.sampled_from(["broadcast", "allreduce_max", "allreduce_sum"]) strategies.sampled_from(["broadcast", "allreduce_max", "allreduce_sum"])
) )
@pytest.mark.skipif(**tm.no_dask()) @pytest.mark.skipif(**tm.no_loky())
@given(ops=ops_strategy, size=strategies.integers(2**4, 2**16)) @given(ops=ops_strategy, size=strategies.integers(2**4, 2**16))
@settings( @settings(
deadline=None, deadline=None,
@ -192,12 +186,14 @@ ops_strategy = strategies.lists(
max_examples=10, max_examples=10,
suppress_health_check=[HealthCheck.function_scoped_fixture], suppress_health_check=[HealthCheck.function_scoped_fixture],
) )
def test_ops_restart_comm(local_cluster, ops, size) -> None: def test_ops_restart_comm(ops, size) -> None:
from distributed import Client from loky import get_reusable_executor
def local_test(w: int, n_workers: int) -> None: n_workers = 8
def local_test(w: int, rabit_args: Dict[str, Union[str, int]]) -> None:
a = np.arange(0, n_workers) a = np.arange(0, n_workers)
with xgb.dask.CommunicatorContext(**args): with collective.CommunicatorContext(**rabit_args):
for op in ops: for op in ops:
if op == "broadcast": if op == "broadcast":
b = collective.broadcast(a, root=1) b = collective.broadcast(a, root=1)
@ -211,27 +207,21 @@ def test_ops_restart_comm(local_cluster, ops, size) -> None:
else: else:
raise ValueError() raise ValueError()
with Client(local_cluster) as client: with get_reusable_executor(max_workers=n_workers) as pool:
workers = tm.get_client_workers(client) tracker = RabitTracker(host_ip="127.0.0.1", n_workers=n_workers)
args = client.sync( tracker.start()
xgb.dask._get_rabit_args, args = tracker.worker_args()
len(workers),
None,
client,
)
workers = tm.get_client_workers(client) fn = update_wrapper(partial(local_test, rabit_args=args), local_test)
n_workers = len(workers) results = pool.map(fn, range(n_workers))
futures = client.map( for r in results:
local_test, range(len(workers)), workers=workers, n_workers=n_workers assert r is None
)
client.gather(futures)
@pytest.mark.skipif(**tm.no_dask()) @pytest.mark.skipif(**tm.no_loky())
def test_ops_reuse_comm(local_cluster) -> None: def test_ops_reuse_comm() -> None:
from distributed import Client from loky import get_reusable_executor
rng = np.random.default_rng(1994) rng = np.random.default_rng(1994)
n_examples = 10 n_examples = 10
@ -239,10 +229,13 @@ def test_ops_reuse_comm(local_cluster) -> None:
["broadcast", "allreduce_sum", "allreduce_max"], size=n_examples ["broadcast", "allreduce_sum", "allreduce_max"], size=n_examples
).tolist() ).tolist()
def local_test(w: int, n_workers: int) -> None: n_workers = 8
n_trials = 8
def local_test(w: int, rabit_args: Dict[str, Union[str, int]]) -> None:
a = np.arange(0, n_workers) a = np.arange(0, n_workers)
with xgb.dask.CommunicatorContext(**args): with collective.CommunicatorContext(**rabit_args):
for op in ops: for op in ops:
if op == "broadcast": if op == "broadcast":
b = collective.broadcast(a, root=1) b = collective.broadcast(a, root=1)
@ -257,18 +250,13 @@ def test_ops_reuse_comm(local_cluster) -> None:
else: else:
raise ValueError() raise ValueError()
with Client(local_cluster) as client: with get_reusable_executor(max_workers=n_workers) as pool:
workers = tm.get_client_workers(client) for _ in range(n_trials):
args = client.sync( tracker = RabitTracker(host_ip="127.0.0.1", n_workers=n_workers)
xgb.dask._get_rabit_args, tracker.start()
len(workers), args = tracker.worker_args()
None,
client,
)
n_workers = len(workers) fn = update_wrapper(partial(local_test, rabit_args=args), local_test)
results = pool.map(fn, range(n_workers))
futures = client.map( for r in results:
local_test, range(len(workers)), workers=workers, n_workers=n_workers assert r is None
)
client.gather(futures)