- Add user configuration. - Bring back the logic of using the scheduler address from dask. This was removed when we were trying to support GKE; now we bring it back and let xgboost try it if the direct guess or the host IP from the user config fails.
62 lines
1.9 KiB
Python
62 lines
1.9 KiB
Python
from xgboost import RabitTracker
|
|
import xgboost as xgb
|
|
import pytest
|
|
import testing as tm
|
|
import numpy as np
|
|
import sys
|
|
|
|
# Skip every test in this module when running on Windows.
_running_on_windows = sys.platform.startswith("win")
if _running_on_windows:
    pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
|
|
|
|
|
|
def test_rabit_tracker():
    """Broadcast a string through rabit backed by a single-worker tracker.

    Starts a ``RabitTracker`` on 127.0.0.1 for one worker, initialises rabit
    with environment entries pointing at that tracker, broadcasts
    ``'test1234'`` from rank 0 and checks the value comes back unchanged.
    """
    tracker = RabitTracker(hostIP='127.0.0.1', n_workers=1)
    tracker.start(1)
    # NOTE(review): the port is hard-coded; presumably 9091 is the port the
    # tracker binds to by default -- confirm against RabitTracker.
    rabit_env = [
        b'DMLC_TRACKER_URI=127.0.0.1',
        b'DMLC_TRACKER_PORT=9091',
        b'DMLC_TASK_ID=0',
    ]
    xgb.rabit.init(rabit_env)
    try:
        ret = xgb.rabit.broadcast('test1234', 0)
        assert str(ret) == 'test1234'
    finally:
        # Always finalize, even when the broadcast or the assertion fails, so
        # rabit is not left initialised for the tests that run afterwards.
        xgb.rabit.finalize()
|
|
|
|
|
|
def run_rabit_ops(client, n_workers):
    """Run rabit allreduce checks on every worker of a dask cluster.

    Parameters
    ----------
    client :
        Dask client connected to the cluster under test.
    n_workers :
        Number of workers the cluster is expected to have.

    The rabit arguments are obtained through ``_get_rabit_args`` and each
    worker then enters a ``RabitContext`` to perform a SUM and a MAX
    allreduce, asserting on the reduced values.
    """
    from test_with_dask import _get_client_workers
    from xgboost.dask import RabitContext, _get_rabit_args
    from xgboost import rabit

    workers = _get_client_workers(client)
    worker_count = len(workers)
    rabit_args = client.sync(_get_rabit_args, worker_count, None, client)

    # The calling process itself must not be inside a rabit session.
    assert not rabit.is_distributed()
    assert n_workers == worker_count

    def local_test(worker_id):
        """Executed on a worker: check SUM and MAX allreduce results."""
        with RabitContext(rabit_args):
            assert rabit.is_distributed()

            # Every worker contributes 1, so the sum equals the worker count.
            ones = np.array([1])
            total = rabit.allreduce(ones, rabit.Op.SUM)
            assert total[0] == n_workers

            # Ids are 0 .. n_workers - 1, so the max is n_workers - 1.
            ids = np.array([worker_id])
            highest = rabit.allreduce(ids, rabit.Op.MAX)
            assert highest == n_workers - 1

            return 1

    futures = client.map(local_test, range(worker_count), workers=workers)
    completed = client.gather(futures)
    assert sum(completed) == n_workers
|
|
|
|
|
|
@pytest.mark.skipif(**tm.no_dask())
def test_rabit_ops():
    """Run the rabit operation checks against a three-worker LocalCluster."""
    from distributed import Client, LocalCluster

    n_workers = 3
    # The cluster is entered first and the client second, so they are torn
    # down in the reverse order on exit.
    with LocalCluster(n_workers=n_workers) as cluster, Client(cluster) as client:
        run_rabit_ops(client, n_workers)
|