[dask] Fix union of workers. (#6375)
This commit is contained in:
parent
fcfeb4959c
commit
4ccf92ea34
@ -627,7 +627,7 @@ def _get_workers_from_data(dtrain: DaskDMatrix, evals=()):
|
|||||||
assert len(e) == 2
|
assert len(e) == 2
|
||||||
assert isinstance(e[0], DaskDMatrix) and isinstance(e[1], str)
|
assert isinstance(e[0], DaskDMatrix) and isinstance(e[1], str)
|
||||||
worker_map = set(e[0].worker_map.keys())
|
worker_map = set(e[0].worker_map.keys())
|
||||||
X_worker_map.union(worker_map)
|
X_worker_map = X_worker_map.union(worker_map)
|
||||||
return X_worker_map
|
return X_worker_map
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -776,6 +776,26 @@ class TestDaskCallbacks:
|
|||||||
dump = booster.get_dump(dump_format='json')
|
dump = booster.get_dump(dump_format='json')
|
||||||
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
|
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
|
||||||
|
|
||||||
|
def test_n_workers(self):
|
||||||
|
with LocalCluster(n_workers=2) as cluster:
|
||||||
|
with Client(cluster) as client:
|
||||||
|
workers = list(_get_client_workers(client).keys())
|
||||||
|
from sklearn.datasets import load_breast_cancer
|
||||||
|
X, y = load_breast_cancer(return_X_y=True)
|
||||||
|
dX = client.submit(da.from_array, X, workers=[workers[0]]).result()
|
||||||
|
dy = client.submit(da.from_array, y, workers=[workers[0]]).result()
|
||||||
|
train = xgb.dask.DaskDMatrix(client, dX, dy)
|
||||||
|
|
||||||
|
dX = dd.from_array(X)
|
||||||
|
dX = client.persist(dX, workers={dX: workers[1]})
|
||||||
|
dy = dd.from_array(y)
|
||||||
|
dy = client.persist(dy, workers={dy: workers[1]})
|
||||||
|
valid = xgb.dask.DaskDMatrix(client, dX, dy)
|
||||||
|
|
||||||
|
merged = xgb.dask._get_workers_from_data(train, evals=[(valid, 'Valid')])
|
||||||
|
assert len(merged) == 2
|
||||||
|
|
||||||
|
|
||||||
def test_data_initialization(self):
|
def test_data_initialization(self):
|
||||||
'''Assert each worker has the correct amount of data, and DMatrix initialization doesn't
|
'''Assert each worker has the correct amount of data, and DMatrix initialization doesn't
|
||||||
generate unnecessary copies of data.
|
generate unnecessary copies of data.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user