[dask] Fix union of workers. (#6375)

This commit is contained in:
Jiaming Yuan 2020-11-13 16:55:05 +08:00 committed by GitHub
parent fcfeb4959c
commit 4ccf92ea34
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 1 deletion

View File

@@ -627,7 +627,7 @@ def _get_workers_from_data(dtrain: DaskDMatrix, evals=()):
         assert len(e) == 2
         assert isinstance(e[0], DaskDMatrix) and isinstance(e[1], str)
         worker_map = set(e[0].worker_map.keys())
-        X_worker_map.union(worker_map)
+        X_worker_map = X_worker_map.union(worker_map)
     return X_worker_map

View File

@@ -776,6 +776,26 @@ class TestDaskCallbacks:
         dump = booster.get_dump(dump_format='json')
         assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
def test_n_workers(self):
    """Workers from the training matrix and every eval matrix must be
    merged into a single set: data pinned to two different workers
    yields a merged worker set of size 2."""
    # Flattened nested `with` statements into one context line.
    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        from sklearn.datasets import load_breast_cancer
        worker_ids = list(_get_client_workers(client).keys())
        X, y = load_breast_cancer(return_X_y=True)

        # Pin the training arrays to the first worker only.
        train_X = client.submit(da.from_array, X, workers=[worker_ids[0]]).result()
        train_y = client.submit(da.from_array, y, workers=[worker_ids[0]]).result()
        train = xgb.dask.DaskDMatrix(client, train_X, train_y)

        # Pin the validation dataframes to the second worker only.
        valid_X = dd.from_array(X)
        valid_X = client.persist(valid_X, workers={valid_X: worker_ids[1]})
        valid_y = dd.from_array(y)
        valid_y = client.persist(valid_y, workers={valid_y: worker_ids[1]})
        valid = xgb.dask.DaskDMatrix(client, valid_X, valid_y)

        # The union of both matrices' workers must cover the whole cluster.
        merged = xgb.dask._get_workers_from_data(train, evals=[(valid, 'Valid')])
        assert len(merged) == 2
     def test_data_initialization(self):
         '''Assert each worker has the correct amount of data, and DMatrix initialization doesn't
         generate unnecessary copies of data.