[dask] Fix union of workers. (#6375)

This commit is contained in:
Jiaming Yuan 2020-11-13 16:55:05 +08:00 committed by GitHub
parent fcfeb4959c
commit 4ccf92ea34
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 1 deletion

View File

@ -627,7 +627,7 @@ def _get_workers_from_data(dtrain: DaskDMatrix, evals=()):
assert len(e) == 2
assert isinstance(e[0], DaskDMatrix) and isinstance(e[1], str)
worker_map = set(e[0].worker_map.keys())
X_worker_map.union(worker_map)
X_worker_map = X_worker_map.union(worker_map)
return X_worker_map

View File

@ -776,6 +776,26 @@ class TestDaskCallbacks:
dump = booster.get_dump(dump_format='json')
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
def test_n_workers(self):
    """Check that workers of the training and validation data are merged.

    The training matrix is built from collections submitted to the first
    worker, the validation matrix from collections persisted on the second
    one.  ``_get_workers_from_data`` is then expected to report two workers.
    """
    from sklearn.datasets import load_breast_cancer

    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        worker_addrs = list(_get_client_workers(client).keys())
        features, labels = load_breast_cancer(return_X_y=True)

        # Training data: `da.from_array` is submitted to the first worker.
        train_X = client.submit(
            da.from_array, features, workers=[worker_addrs[0]]
        ).result()
        train_y = client.submit(
            da.from_array, labels, workers=[worker_addrs[0]]
        ).result()
        train = xgb.dask.DaskDMatrix(client, train_X, train_y)

        # Validation data: collections persisted on the second worker.
        valid_X = dd.from_array(features)
        valid_X = client.persist(valid_X, workers={valid_X: worker_addrs[1]})
        valid_y = dd.from_array(labels)
        valid_y = client.persist(valid_y, workers={valid_y: worker_addrs[1]})
        valid = xgb.dask.DaskDMatrix(client, valid_X, valid_y)

        merged = xgb.dask._get_workers_from_data(
            train, evals=[(valid, 'Valid')]
        )
        assert len(merged) == 2
def test_data_initialization(self):
'''Assert each worker has the correct amount of data, and DMatrix initialization doesn't
generate unnecessary copies of data.