[dask] Fix union of workers. (#6375)
This commit is contained in:
parent
fcfeb4959c
commit
4ccf92ea34
@ -627,7 +627,7 @@ def _get_workers_from_data(dtrain: DaskDMatrix, evals=()):
|
||||
assert len(e) == 2
|
||||
assert isinstance(e[0], DaskDMatrix) and isinstance(e[1], str)
|
||||
worker_map = set(e[0].worker_map.keys())
|
||||
X_worker_map.union(worker_map)
|
||||
X_worker_map = X_worker_map.union(worker_map)
|
||||
return X_worker_map
|
||||
|
||||
|
||||
|
||||
@ -776,6 +776,26 @@ class TestDaskCallbacks:
|
||||
dump = booster.get_dump(dump_format='json')
|
||||
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
|
||||
|
||||
def test_n_workers(self):
|
||||
with LocalCluster(n_workers=2) as cluster:
|
||||
with Client(cluster) as client:
|
||||
workers = list(_get_client_workers(client).keys())
|
||||
from sklearn.datasets import load_breast_cancer
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
dX = client.submit(da.from_array, X, workers=[workers[0]]).result()
|
||||
dy = client.submit(da.from_array, y, workers=[workers[0]]).result()
|
||||
train = xgb.dask.DaskDMatrix(client, dX, dy)
|
||||
|
||||
dX = dd.from_array(X)
|
||||
dX = client.persist(dX, workers={dX: workers[1]})
|
||||
dy = dd.from_array(y)
|
||||
dy = client.persist(dy, workers={dy: workers[1]})
|
||||
valid = xgb.dask.DaskDMatrix(client, dX, dy)
|
||||
|
||||
merged = xgb.dask._get_workers_from_data(train, evals=[(valid, 'Valid')])
|
||||
assert len(merged) == 2
|
||||
|
||||
|
||||
def test_data_initialization(self):
|
||||
'''Assert each worker has the correct amount of data, and DMatrix initialization doesn't
|
||||
generate unnecessary copies of data.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user