[dask] Test for data initializaton. (#6226)
This commit is contained in:
@@ -326,10 +326,10 @@ class DaskDMatrix:
|
||||
self.partition_order[part.key] = i
|
||||
|
||||
key_to_partition = {part.key: part for part in parts}
|
||||
who_has = await client.scheduler.who_has(
|
||||
keys=[part.key for part in parts])
|
||||
who_has = await client.scheduler.who_has(keys=[part.key for part in parts])
|
||||
|
||||
worker_map = defaultdict(list)
|
||||
|
||||
for key, workers in who_has.items():
|
||||
worker_map[next(iter(workers))].append(key_to_partition[key])
|
||||
|
||||
@@ -651,9 +651,9 @@ async def _train_async(client, params, dtrain: DaskDMatrix, *args, evals=(),
|
||||
'The evaluation history is returned as result of training.')
|
||||
|
||||
workers = list(_get_client_workers(client).keys())
|
||||
rabit_args = await _get_rabit_args(workers, client)
|
||||
_rabit_args = await _get_rabit_args(workers, client)
|
||||
|
||||
def dispatched_train(worker_addr, dtrain_ref, evals_ref):
|
||||
def dispatched_train(worker_addr, rabit_args, dtrain_ref, evals_ref):
|
||||
'''Perform training on a single worker. A local function prevents pickling.
|
||||
|
||||
'''
|
||||
@@ -699,8 +699,13 @@ async def _train_async(client, params, dtrain: DaskDMatrix, *args, evals=(),
|
||||
if evals:
|
||||
evals = [(e.create_fn_args(), name) for e, name in evals]
|
||||
|
||||
# Note for function purity:
|
||||
# XGBoost is deterministic in most of the cases, which means train function is
|
||||
# supposed to be idempotent. One known exception is gblinear with shotgun updater.
|
||||
# We haven't been able to do a full verification so here we keep pure to be False.
|
||||
futures = client.map(dispatched_train,
|
||||
workers,
|
||||
[_rabit_args] * len(workers),
|
||||
[dtrain.create_fn_args()] * len(workers),
|
||||
[evals] * len(workers),
|
||||
pure=False,
|
||||
|
||||
Reference in New Issue
Block a user