diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py
index 1de2e4773..8eb026283 100644
--- a/python-package/xgboost/dask.py
+++ b/python-package/xgboost/dask.py
@@ -627,7 +627,7 @@ def _get_workers_from_data(dtrain: DaskDMatrix, evals=()):
             assert len(e) == 2
             assert isinstance(e[0], DaskDMatrix) and isinstance(e[1], str)
             worker_map = set(e[0].worker_map.keys())
-            X_worker_map.union(worker_map)
+            X_worker_map = X_worker_map.union(worker_map)
     return X_worker_map


diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py
index ca1da7042..ba697ab4d 100644
--- a/tests/python/test_with_dask.py
+++ b/tests/python/test_with_dask.py
@@ -776,6 +776,26 @@ class TestDaskCallbacks:
         dump = booster.get_dump(dump_format='json')
         assert len(dump) - booster.best_iteration == early_stopping_rounds + 1

+    def test_n_workers(self):
+        with LocalCluster(n_workers=2) as cluster:
+            with Client(cluster) as client:
+                workers = list(_get_client_workers(client).keys())
+                from sklearn.datasets import load_breast_cancer
+                X, y = load_breast_cancer(return_X_y=True)
+                dX = client.submit(da.from_array, X, workers=[workers[0]]).result()
+                dy = client.submit(da.from_array, y, workers=[workers[0]]).result()
+                train = xgb.dask.DaskDMatrix(client, dX, dy)
+
+                dX = dd.from_array(X)
+                dX = client.persist(dX, workers={dX: workers[1]})
+                dy = dd.from_array(y)
+                dy = client.persist(dy, workers={dy: workers[1]})
+                valid = xgb.dask.DaskDMatrix(client, dX, dy)
+
+                merged = xgb.dask._get_workers_from_data(train, evals=[(valid, 'Valid')])
+                assert len(merged) == 2
+
+
     def test_data_initialization(self):
         '''Assert each worker has the correct amount of data, and DMatrix
         initialization doesn't generate unnecessary copies of data.
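
Background on the one-line change in dask.py: Python's set.union() returns a new set rather than mutating its receiver, so the previous call silently discarded the workers contributed by the evaluation matrices. A minimal sketch of that behaviour (plain Python, with made-up worker addresses; not part of the patch):

    # set.union() is non-mutating: it returns a new set and leaves the
    # receiver unchanged, so the result must be assigned back.
    train_workers = {'tcp://10.0.0.1:34567'}       # hypothetical worker addresses
    eval_workers = {'tcp://10.0.0.2:34567'}

    train_workers.union(eval_workers)              # return value discarded -> no effect
    assert train_workers == {'tcp://10.0.0.1:34567'}

    merged = train_workers.union(eval_workers)     # keep the returned set, as the fix does
    assert merged == {'tcp://10.0.0.1:34567', 'tcp://10.0.0.2:34567'}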