[coll] Use loky for rabit op tests. (#10828)
parent 15c6172e09
commit d5e1c41b69
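The diff below drops the dask LocalCluster from the basic rabit op tests: the test now starts a RabitTracker itself and runs the workers in a loky reusable process pool. A minimal, self-contained sketch of that pattern (not part of the diff; the host address, worker count, and the __main__ guard are illustrative assumptions):

from functools import partial, update_wrapper

import numpy as np
from loky import get_reusable_executor

from xgboost import collective
from xgboost.tracker import RabitTracker


def worker(worker_id: int, rabit_args: dict) -> int:
    # Each pool process joins the collective with the tracker's connection args.
    with collective.CommunicatorContext(**rabit_args):
        reduced = collective.allreduce(np.array([1]), collective.Op.SUM)
        # Every worker contributes 1, so the sum equals the world size.
        assert reduced[0] == collective.get_world_size()
    return 1


if __name__ == "__main__":
    n_workers = 2
    tracker = RabitTracker(host_ip="127.0.0.1", n_workers=n_workers)
    tracker.start()
    args = tracker.worker_args()

    # Pre-bind the tracker args so pool.map only varies the worker id.
    fn = update_wrapper(partial(worker, rabit_args=args), worker)
    with get_reusable_executor(max_workers=n_workers) as pool:
        results = pool.map(fn, range(n_workers))
        assert sum(results) == n_workers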
@@ -218,8 +218,13 @@ def check_extmem_qdm(
     )
     booster_it = xgb.train({"device": device}, Xy_it, num_boost_round=8)
-    X, y, w = it.as_arrays()
-    Xy = xgb.QuantileDMatrix(X, y, weight=w)
+    it = tm.IteratorForTest(
+        *tm.make_batches(
+            n_samples_per_batch, n_features, n_batches, use_cupy=device != "cpu"
+        ),
+        cache=None,
+    )
+    Xy = xgb.QuantileDMatrix(it)
     booster = xgb.train({"device": device}, Xy, num_boost_round=8)

     if device == "cpu":
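The hunk above stops materializing the iterator into arrays and instead rebuilds the batch iterator so the QuantileDMatrix is constructed directly from it. For reference, a minimal sketch of the public iterator interface that pattern relies on (xgb.DataIter; the InMemoryIter class, batch shapes, and training parameters are illustrative, not from the diff):

import numpy as np
import xgboost as xgb


class InMemoryIter(xgb.DataIter):
    # Feed pre-generated in-memory batches to XGBoost one at a time.
    def __init__(self, batches):
        self._batches = batches
        self._it = 0
        super().__init__()

    def next(self, input_data):
        if self._it == len(self._batches):
            return 0  # no more batches
        X, y = self._batches[self._it]
        input_data(data=X, label=y)
        self._it += 1
        return 1

    def reset(self):
        self._it = 0


rng = np.random.default_rng(0)
batches = [(rng.normal(size=(256, 4)), rng.normal(size=256)) for _ in range(3)]
Xy = xgb.QuantileDMatrix(InMemoryIter(batches))
booster = xgb.train({"device": "cpu"}, Xy, num_boost_round=8)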
@@ -34,44 +34,48 @@ def test_socket_error():
     tracker.free()


-def run_rabit_ops(client, n_workers):
-    from xgboost.dask import CommunicatorContext, _get_dask_config, _get_rabit_args
+def run_rabit_ops(pool, n_workers: int, address: str) -> None:
+    tracker = RabitTracker(host_ip=address, n_workers=n_workers)
+    tracker.start()
+    args = tracker.worker_args()

-    workers = tm.get_client_workers(client)
-    rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client)
-    assert not collective.is_distributed()
-    n_workers_from_dask = len(workers)
-    assert n_workers == n_workers_from_dask
-
-    def local_test(worker_id):
-        with CommunicatorContext(**rabit_args):
+    def local_test(worker_id: int, rabit_args: dict) -> int:
+        with collective.CommunicatorContext(**rabit_args):
             a = 1
             assert collective.is_distributed()
-            a = np.array([a])
-            reduced = collective.allreduce(a, collective.Op.SUM)
+            arr = np.array([a])
+            reduced = collective.allreduce(arr, collective.Op.SUM)
             assert reduced[0] == n_workers

-            worker_id = np.array([worker_id])
-            reduced = collective.allreduce(worker_id, collective.Op.MAX)
+            arr = np.array([worker_id])
+            reduced = collective.allreduce(arr, collective.Op.MAX)
             assert reduced == n_workers - 1

             return 1

-    futures = client.map(local_test, range(len(workers)), workers=workers)
-    results = client.gather(futures)
+    fn = update_wrapper(partial(local_test, rabit_args=args), local_test)
+    results = pool.map(fn, range(n_workers))
     assert sum(results) == n_workers


-@pytest.mark.skipif(**tm.no_dask())
+@pytest.mark.skipif(**tm.no_loky())
 def test_rabit_ops():
-    from distributed import Client, LocalCluster
+    from loky import get_reusable_executor

-    n_workers = 3
-    with LocalCluster(n_workers=n_workers) as cluster:
-        with Client(cluster) as client:
-            run_rabit_ops(client, n_workers)
+    n_workers = 4
+    with get_reusable_executor(max_workers=n_workers) as pool:
+        run_rabit_ops(pool, n_workers, "127.0.0.1")
+
+
+@pytest.mark.skipif(**tm.no_ipv6())
+@pytest.mark.skipif(**tm.no_loky())
+def test_rabit_ops_ipv6():
+    from loky import get_reusable_executor
+
+    n_workers = 4
+    with get_reusable_executor(max_workers=n_workers) as pool:
+        run_rabit_ops(pool, n_workers, "::1")


 def run_allreduce(pool, n_workers: int) -> None:
     tracker = RabitTracker(host_ip="127.0.0.1", n_workers=n_workers)
@@ -133,19 +137,6 @@ def test_broadcast():
         run_broadcast(pool, n_workers)


-@pytest.mark.skipif(**tm.no_ipv6())
-@pytest.mark.skipif(**tm.no_dask())
-def test_rabit_ops_ipv6():
-    import dask
-    from distributed import Client, LocalCluster
-
-    n_workers = 3
-    with dask.config.set({"xgboost.scheduler_address": "[::1]"}):
-        with LocalCluster(n_workers=n_workers, host="[::1]") as cluster:
-            with Client(cluster) as client:
-                run_rabit_ops(client, n_workers)
-
-
 @pytest.mark.skipif(**tm.no_dask())
 def test_rank_assignment() -> None:
     from distributed import Client, LocalCluster
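A note on the dispatch idiom in the new run_rabit_ops: loky pickles the mapped callable with cloudpickle, and pool.map only iterates over worker ids, so the tracker's connection args are pre-bound with functools.partial. update_wrapper then copies local_test's metadata (__name__, __qualname__, __doc__) onto the partial object, presumably so that pool diagnostics and tracebacks still report the original test function rather than an anonymous partial.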