[dask] Fix union of workers. (#6375)
This commit is contained in:
@@ -776,6 +776,26 @@ class TestDaskCallbacks:
|
||||
dump = booster.get_dump(dump_format='json')
|
||||
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
|
||||
|
||||
def test_n_workers(self):
    '''Check that the workers gathered from a training DMatrix and an evaluation
    DMatrix cover both workers when each data set is pinned to a different one.

    '''
    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        addresses = list(_get_client_workers(client).keys())
        from sklearn.datasets import load_breast_cancer
        X, y = load_breast_cancer(return_X_y=True)

        # Training set: collections created by tasks submitted with
        # ``workers=[addresses[0]]``.
        train_X = client.submit(da.from_array, X, workers=[addresses[0]]).result()
        train_y = client.submit(da.from_array, y, workers=[addresses[0]]).result()
        train = xgb.dask.DaskDMatrix(client, train_X, train_y)

        # Validation set: collections persisted with a worker restriction
        # pointing at ``addresses[1]``.
        lazy_X = dd.from_array(X)
        valid_X = client.persist(lazy_X, workers={lazy_X: addresses[1]})
        lazy_y = dd.from_array(y)
        valid_y = client.persist(lazy_y, workers={lazy_y: addresses[1]})
        valid = xgb.dask.DaskDMatrix(client, valid_X, valid_y)

        evals = [(valid, 'Valid')]
        merged = xgb.dask._get_workers_from_data(train, evals=evals)
        # One entry per worker is expected once both data sets are combined.
        assert len(merged) == 2
|
||||
|
||||
|
||||
def test_data_initialization(self):
|
||||
'''Assert each worker has the correct amount of data, and DMatrix initialization doesn't
|
||||
generate unnecessary copies of data.
|
||||
|
||||
Reference in New Issue
Block a user