[dask] Test for data initialization. (#6226)

This commit is contained in:
Jiaming Yuan
2020-10-13 11:08:35 +08:00
committed by GitHub
parent 2443275891
commit b05073bda5
3 changed files with 51 additions and 6 deletions

View File

@@ -326,10 +326,10 @@ class DaskDMatrix:
self.partition_order[part.key] = i
key_to_partition = {part.key: part for part in parts}
who_has = await client.scheduler.who_has(
keys=[part.key for part in parts])
who_has = await client.scheduler.who_has(keys=[part.key for part in parts])
worker_map = defaultdict(list)
for key, workers in who_has.items():
worker_map[next(iter(workers))].append(key_to_partition[key])
@@ -651,9 +651,9 @@ async def _train_async(client, params, dtrain: DaskDMatrix, *args, evals=(),
'The evaluation history is returned as result of training.')
workers = list(_get_client_workers(client).keys())
rabit_args = await _get_rabit_args(workers, client)
_rabit_args = await _get_rabit_args(workers, client)
def dispatched_train(worker_addr, dtrain_ref, evals_ref):
def dispatched_train(worker_addr, rabit_args, dtrain_ref, evals_ref):
'''Perform training on a single worker. A local function prevents pickling.
'''
@@ -699,8 +699,13 @@ async def _train_async(client, params, dtrain: DaskDMatrix, *args, evals=(),
if evals:
evals = [(e.create_fn_args(), name) for e, name in evals]
# Note on function purity:
# XGBoost is deterministic in most cases, which means the train function is
# supposed to be idempotent. One known exception is gblinear with the shotgun
# updater. We haven't been able to do a full verification, so we keep pure=False here.
futures = client.map(dispatched_train,
workers,
[_rabit_args] * len(workers),
[dtrain.create_fn_args()] * len(workers),
[evals] * len(workers),
pure=False,