[dask] Supoort running on GKE. (#6343)
* Avoid accessing `scheduler_info()['workers']`. * Avoid calling `client.gather` inside task. * Avoid using `client.scheduler_address`.
This commit is contained in:
@@ -23,11 +23,12 @@ def test_rabit_tracker():
|
||||
|
||||
|
||||
def run_rabit_ops(client, n_workers):
|
||||
from xgboost.dask import RabitContext, _get_rabit_args, _get_client_workers
|
||||
from test_with_dask import _get_client_workers
|
||||
from xgboost.dask import RabitContext, _get_rabit_args
|
||||
from xgboost import rabit
|
||||
|
||||
workers = list(_get_client_workers(client).keys())
|
||||
rabit_args = client.sync(_get_rabit_args, workers, client)
|
||||
rabit_args = client.sync(_get_rabit_args, len(workers), client)
|
||||
assert not rabit.is_distributed()
|
||||
n_workers_from_dask = len(workers)
|
||||
assert n_workers == n_workers_from_dask
|
||||
|
||||
@@ -41,6 +41,11 @@ kCols = 10
|
||||
kWorkers = 5
|
||||
|
||||
|
||||
def _get_client_workers(client):
|
||||
workers = client.scheduler_info()['workers']
|
||||
return workers
|
||||
|
||||
|
||||
def generate_array(with_weights=False):
|
||||
partition_size = 20
|
||||
X = da.random.random((kRows, kCols), partition_size)
|
||||
@@ -704,9 +709,9 @@ class TestWithDask:
|
||||
|
||||
with LocalCluster(n_workers=4) as cluster:
|
||||
with Client(cluster) as client:
|
||||
workers = list(xgb.dask._get_client_workers(client).keys())
|
||||
workers = list(_get_client_workers(client).keys())
|
||||
rabit_args = client.sync(
|
||||
xgb.dask._get_rabit_args, workers, client)
|
||||
xgb.dask._get_rabit_args, len(workers), client)
|
||||
futures = client.map(runit,
|
||||
workers,
|
||||
pure=False,
|
||||
@@ -750,7 +755,6 @@ class TestDaskCallbacks:
|
||||
num_boost_round=1000,
|
||||
early_stopping_rounds=early_stopping_rounds)['booster']
|
||||
assert hasattr(booster, 'best_score')
|
||||
assert booster.best_iteration == 10
|
||||
dump = booster.get_dump(dump_format='json')
|
||||
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
|
||||
|
||||
@@ -783,20 +787,22 @@ class TestDaskCallbacks:
|
||||
X, y = generate_array()
|
||||
n_partitions = X.npartitions
|
||||
m = xgb.dask.DaskDMatrix(client, X, y)
|
||||
workers = list(xgb.dask._get_client_workers(client).keys())
|
||||
rabit_args = client.sync(xgb.dask._get_rabit_args, workers, client)
|
||||
workers = list(_get_client_workers(client).keys())
|
||||
rabit_args = client.sync(xgb.dask._get_rabit_args, len(workers), client)
|
||||
n_workers = len(workers)
|
||||
|
||||
def worker_fn(worker_addr, data_ref):
|
||||
with xgb.dask.RabitContext(rabit_args):
|
||||
local_dtrain = xgb.dask._dmatrix_from_worker_map(**data_ref)
|
||||
local_dtrain = xgb.dask._dmatrix_from_list_of_parts(**data_ref)
|
||||
total = np.array([local_dtrain.num_row()])
|
||||
total = xgb.rabit.allreduce(total, xgb.rabit.Op.SUM)
|
||||
assert total[0] == kRows
|
||||
|
||||
futures = client.map(
|
||||
worker_fn, workers, [m.create_fn_args()] * len(workers),
|
||||
pure=False, workers=workers)
|
||||
futures = []
|
||||
for i in range(len(workers)):
|
||||
futures.append(client.submit(worker_fn, workers[i],
|
||||
m.create_fn_args(workers[i]), pure=False,
|
||||
workers=[workers[i]]))
|
||||
client.gather(futures)
|
||||
|
||||
has_what = client.has_what()
|
||||
|
||||
Reference in New Issue
Block a user