[dask] Support running on GKE. (#6343)
* Avoid accessing `scheduler_info()['workers']`.
* Avoid calling `client.gather` inside tasks.
* Avoid using `client.scheduler_address`.
parent 8a17610666
commit 6e12c2a6f8
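Illustrative note (not part of the commit): the tracker change below can be sketched on its own as follows. Instead of parsing `client.scheduler_address` on the client, the tracker host is resolved on the scheduler itself via `Client.run_on_scheduler`, which yields an address other pods can reach on GKE. The scheduler URL here is hypothetical and only stands in for a real deployment.

from dask.distributed import Client


def _resolve_tracker_host():
    # Executed on the scheduler process via Client.run_on_scheduler, so the
    # resolved address is the scheduler's own, not the client's view of it.
    import socket
    return socket.gethostbyname(socket.gethostname())


if __name__ == '__main__':
    client = Client('tcp://dask-scheduler:8786')  # hypothetical scheduler address
    host = client.run_on_scheduler(_resolve_tracker_host)
    print('tracker host:', host)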
@@ -66,9 +66,11 @@ distributed = LazyLoader('distributed', globals(), 'dask.distributed')
 
 LOGGER = logging.getLogger('[xgboost.dask]')
 
 
-def _start_tracker(host, n_workers):
+def _start_tracker(n_workers):
     """Start Rabit tracker """
     env = {'DMLC_NUM_WORKER': n_workers}
+    import socket
+    host = socket.gethostbyname(socket.gethostname())
     rabit_context = RabitTracker(hostIP=host, nslave=n_workers)
     env.update(rabit_context.slave_envs())
 
@@ -141,11 +143,6 @@ def _xgb_get_client(client):
     ret = distributed.get_client() if client is None else client
     return ret
 
-
-def _get_client_workers(client):
-    workers = client.scheduler_info()['workers']
-    return workers
-
 # From the implementation point of view, DaskDMatrix complicates a lots of
 # things. A large portion of the code base is about syncing and extracting
 # stuffs from DaskDMatrix. But having an independent data structure gives us a
@@ -333,7 +330,7 @@ class DaskDMatrix:
 
         return self
 
-    def create_fn_args(self):
+    def create_fn_args(self, worker_addr: str):
         '''Create a dictionary of objects that can be pickled for function
         arguments.
 
@@ -342,57 +339,55 @@ class DaskDMatrix:
                 'feature_types': self.feature_types,
                 'meta_names': self.meta_names,
                 'missing': self.missing,
-                'worker_map': self.worker_map,
+                'parts': self.worker_map.get(worker_addr, None),
                 'is_quantile': self.is_quantile}
 
 
-def _get_worker_parts_ordered(meta_names, worker_map, partition_order, worker):
-    list_of_parts: List[tuple] = worker_map[worker.address]
+def _get_worker_parts_ordered(meta_names, list_of_keys, list_of_parts, partition_order):
     # List of partitions like: [(x3, y3, w3, m3, ..), ..], order is not preserved.
     assert isinstance(list_of_parts, list)
-    with distributed.worker_client() as client:
-        list_of_parts_value = client.gather(list_of_parts)
-
-        result = []
-
-        for i, part in enumerate(list_of_parts):
-            data = list_of_parts_value[i][0]
-            labels = None
-            weights = None
-            base_margin = None
-            label_lower_bound = None
-            label_upper_bound = None
-            # Iterate through all possible meta info, brings small overhead as in xgboost
-            # there are constant number of meta info available.
-            for j, blob in enumerate(list_of_parts_value[i][1:]):
-                if meta_names[j] == 'labels':
-                    labels = blob
-                elif meta_names[j] == 'weights':
-                    weights = blob
-                elif meta_names[j] == 'base_margin':
-                    base_margin = blob
-                elif meta_names[j] == 'label_lower_bound':
-                    label_lower_bound = blob
-                elif meta_names[j] == 'label_upper_bound':
-                    label_upper_bound = blob
-                else:
-                    raise ValueError('Unknown metainfo:', meta_names[j])
-
-            if partition_order:
-                result.append((data, labels, weights, base_margin, label_lower_bound,
-                               label_upper_bound, partition_order[part.key]))
-            else:
-                result.append((data, labels, weights, base_margin, label_lower_bound,
-                               label_upper_bound))
-        return result
+    list_of_parts_value = list_of_parts
+
+    result = []
+
+    for i, _ in enumerate(list_of_parts):
+        data = list_of_parts_value[i][0]
+        labels = None
+        weights = None
+        base_margin = None
+        label_lower_bound = None
+        label_upper_bound = None
+        # Iterate through all possible meta info, brings small overhead as in xgboost
+        # there are constant number of meta info available.
+        for j, blob in enumerate(list_of_parts_value[i][1:]):
+            if meta_names[j] == 'labels':
+                labels = blob
+            elif meta_names[j] == 'weights':
+                weights = blob
+            elif meta_names[j] == 'base_margin':
+                base_margin = blob
+            elif meta_names[j] == 'label_lower_bound':
+                label_lower_bound = blob
+            elif meta_names[j] == 'label_upper_bound':
+                label_upper_bound = blob
+            else:
+                raise ValueError('Unknown metainfo:', meta_names[j])
+
+        if partition_order:
+            result.append((data, labels, weights, base_margin, label_lower_bound,
+                           label_upper_bound, partition_order[list_of_keys[i]]))
+        else:
+            result.append((data, labels, weights, base_margin, label_lower_bound,
+                           label_upper_bound))
+    return result
 
 
 def _unzip(list_of_parts):
     return list(zip(*list_of_parts))
 
 
-def _get_worker_parts(worker_map, meta_names, worker):
-    partitions = _get_worker_parts_ordered(meta_names, worker_map, None, worker)
+def _get_worker_parts(list_of_parts: List[tuple], meta_names):
+    partitions = _get_worker_parts_ordered(meta_names, None, list_of_parts, None)
     partitions = _unzip(partitions)
     return partitions
 
@@ -522,21 +517,19 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
         self.max_bin = max_bin
         self.is_quantile = True
 
-    def create_fn_args(self):
-        args = super().create_fn_args()
+    def create_fn_args(self, worker_addr: str):
+        args = super().create_fn_args(worker_addr)
         args['max_bin'] = self.max_bin
         return args
 
 
 def _create_device_quantile_dmatrix(feature_names, feature_types,
-                                    meta_names, missing, worker_map,
+                                    meta_names, missing, parts,
                                     max_bin):
     worker = distributed.get_worker()
-    if worker.address not in set(worker_map.keys()):
-        msg = 'worker {address} has an empty DMatrix. ' \
-            'All workers associated with this DMatrix: {workers}'.format(
-                address=worker.address,
-                workers=set(worker_map.keys()))
+    if parts is None:
+        msg = 'worker {address} has an empty DMatrix. '.format(
+            address=worker.address)
         LOGGER.warning(msg)
         import cupy  # pylint: disable=import-error
         d = DeviceQuantileDMatrix(cupy.zeros((0, 0)),
@@ -547,7 +540,7 @@ def _create_device_quantile_dmatrix(feature_names, feature_types,
 
     (data, labels, weights, base_margin,
      label_lower_bound, label_upper_bound) = _get_worker_parts(
-         worker_map, meta_names, worker)
+         parts, meta_names)
     it = DaskPartitionIter(data=data, label=labels, weight=weights,
                            base_margin=base_margin,
                            label_lower_bound=label_lower_bound,
@@ -562,8 +555,7 @@ def _create_device_quantile_dmatrix(feature_names, feature_types,
     return dmatrix
 
 
-def _create_dmatrix(feature_names, feature_types, meta_names, missing,
-                    worker_map):
+def _create_dmatrix(feature_names, feature_types, meta_names, missing, parts):
     '''Get data that local to worker from DaskDMatrix.
 
     Returns
@@ -572,11 +564,9 @@ def _create_dmatrix(feature_names, feature_types, meta_names, missing,
 
     '''
     worker = distributed.get_worker()
-    if worker.address not in set(worker_map.keys()):
-        msg = 'worker {address} has an empty DMatrix. ' \
-            'All workers associated with this DMatrix: {workers}'.format(
-                address=worker.address,
-                workers=set(worker_map.keys()))
+    list_of_parts = parts
+    if list_of_parts is None:
+        msg = 'worker {address} has an empty DMatrix. '.format(address=worker.address)
         LOGGER.warning(msg)
         d = DMatrix(numpy.empty((0, 0)),
                     feature_names=feature_names,
@@ -584,13 +574,12 @@ def _create_dmatrix(feature_names, feature_types, meta_names, missing,
         return d
 
     def concat_or_none(data):
-        if all([part is None for part in data]):
+        if any([part is None for part in data]):
            return None
        return concat(data)
 
    (data, labels, weights, base_margin,
-     label_lower_bound, label_upper_bound) = _get_worker_parts(
-         worker_map, meta_names, worker)
+     label_lower_bound, label_upper_bound) = _get_worker_parts(list_of_parts, meta_names)
 
    labels = concat_or_none(labels)
    weights = concat_or_none(weights)
@@ -611,17 +600,15 @@ def _create_dmatrix(feature_names, feature_types, meta_names, missing,
     return dmatrix
 
 
-def _dmatrix_from_worker_map(is_quantile, **kwargs):
+def _dmatrix_from_list_of_parts(is_quantile, **kwargs):
     if is_quantile:
         return _create_device_quantile_dmatrix(**kwargs)
     return _create_dmatrix(**kwargs)
 
 
-async def _get_rabit_args(worker_map, client):
+async def _get_rabit_args(n_workers: int, client):
     '''Get rabit context arguments from data distribution in DaskDMatrix.'''
-    host = distributed.comm.get_address_host(client.scheduler.address)
-    env = await client.run_on_scheduler(
-        _start_tracker, host.strip('/:'), len(worker_map))
+    env = await client.run_on_scheduler(_start_tracker, n_workers)
     rabit_args = [('%s=%s' % item).encode() for item in env.items()]
     return rabit_args
 
@@ -632,49 +619,58 @@ async def _get_rabit_args(worker_map, client):
 # evaluation history is instead returned.
 
 
-async def _train_async(client, params, dtrain: DaskDMatrix, *args, evals=(),
-                       early_stopping_rounds=None, **kwargs):
-    _assert_dask_support()
-    client: distributed.Client = _xgb_get_client(client)
+def _get_workers_from_data(dtrain: DaskDMatrix, evals=()):
+    X_worker_map = set(dtrain.worker_map.keys())
+    if evals:
+        for e in evals:
+            assert len(e) == 2
+            assert isinstance(e[0], DaskDMatrix) and isinstance(e[1], str)
+            worker_map = set(e[0].worker_map.keys())
+            X_worker_map.union(worker_map)
+    return X_worker_map
+
+
+async def _train_async(client,
+                       params,
+                       dtrain: DaskDMatrix,
+                       *args,
+                       evals=(),
+                       early_stopping_rounds=None,
+                       **kwargs):
     if 'evals_result' in kwargs.keys():
         raise ValueError(
             'evals_result is not supported in dask interface.',
             'The evaluation history is returned as result of training.')
 
-    workers = list(_get_client_workers(client).keys())
-    _rabit_args = await _get_rabit_args(workers, client)
+    workers = list(_get_workers_from_data(dtrain, evals))
+    _rabit_args = await _get_rabit_args(len(workers), client)
 
-    def dispatched_train(worker_addr, rabit_args, dtrain_ref, evals_ref):
+    def dispatched_train(worker_addr, rabit_args, dtrain_ref, dtrain_idt, evals_ref):
         '''Perform training on a single worker. A local function prevents pickling.
 
         '''
         LOGGER.info('Training on %s', str(worker_addr))
         worker = distributed.get_worker()
         with RabitContext(rabit_args):
-            local_dtrain = _dmatrix_from_worker_map(**dtrain_ref)
+            local_dtrain = _dmatrix_from_list_of_parts(**dtrain_ref)
             local_evals = []
             if evals_ref:
-                for ref, name in evals_ref:
-                    if ref['worker_map'] == dtrain_ref['worker_map']:
+                for ref, name, idt in evals_ref:
+                    if idt == dtrain_idt:
                         local_evals.append((local_dtrain, name))
                         continue
-                    local_evals.append((_dmatrix_from_worker_map(**ref), name))
+                    local_evals.append((_dmatrix_from_list_of_parts(**ref), name))
 
             local_history = {}
             local_param = params.copy()  # just to be consistent
             msg = 'Overriding `nthreads` defined in dask worker.'
-            if 'nthread' in local_param.keys() and \
-               local_param['nthread'] is not None and \
-               local_param['nthread'] != worker.nthreads:
-                msg += '`nthread` is specified. ' + msg
-                LOGGER.warning(msg)
-            elif 'n_jobs' in local_param.keys() and \
-                 local_param['n_jobs'] is not None and \
-                 local_param['n_jobs'] != worker.nthreads:
-                msg = '`n_jobs` is specified. ' + msg
-                LOGGER.warning(msg)
-            else:
-                local_param['nthread'] = worker.nthreads
+            override = ['nthread', 'n_jobs']
+            for p in override:
+                val = local_param.get(p, None)
+                if val is not None and val != worker.nthreads:
+                    LOGGER.info(msg)
+                else:
+                    local_param[p] = worker.nthreads
             bst = worker_train(params=local_param,
                                dtrain=local_dtrain,
                                *args,
@@ -687,20 +683,26 @@ async def _train_async(client, params, dtrain: DaskDMatrix, *args, evals=(),
             ret = None
         return ret
 
-    if evals:
-        evals = [(e.create_fn_args(), name) for e, name in evals]
-
     # Note for function purity:
     # XGBoost is deterministic in most of the cases, which means train function is
     # supposed to be idempotent. One known exception is gblinear with shotgun updater.
     # We haven't been able to do a full verification so here we keep pure to be False.
-    futures = client.map(dispatched_train,
-                         workers,
-                         [_rabit_args] * len(workers),
-                         [dtrain.create_fn_args()] * len(workers),
-                         [evals] * len(workers),
-                         pure=False,
-                         workers=workers)
+    futures = []
+    for i, worker_addr in enumerate(workers):
+        if evals:
+            evals_per_worker = [(e.create_fn_args(worker_addr), name, id(e))
+                                for e, name in evals]
+        else:
+            evals_per_worker = []
+        f = client.submit(dispatched_train,
+                          worker_addr,
+                          _rabit_args,
+                          dtrain.create_fn_args(workers[i]),
+                          id(dtrain),
+                          evals_per_worker,
+                          pure=False)
+        futures.append(f)
 
     results = await client.gather(futures)
     return list(filter(lambda ret: ret is not None, results))[0]
@@ -796,14 +798,16 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
     missing = data.missing
     meta_names = data.meta_names
 
-    def dispatched_predict(worker_id):
+    def dispatched_predict(worker_id, list_of_keys, list_of_parts):
         '''Perform prediction on each worker.'''
         LOGGER.info('Predicting on %d', worker_id)
 
+        c = distributed.get_client()
+        list_of_keys = c.compute(list_of_keys).result()
         worker = distributed.get_worker()
         list_of_parts = _get_worker_parts_ordered(
-            meta_names, worker_map, partition_order, worker)
+            meta_names, list_of_keys, list_of_parts, partition_order)
         predictions = []
 
         booster.set_param({'nthread': worker.nthreads})
         for parts in list_of_parts:
             (data, _, _, base_margin, _, _, order) = parts
@@ -822,17 +826,19 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
             columns = 1 if len(predt.shape) == 1 else predt.shape[1]
             ret = ((dask.delayed(predt), columns), order)
             predictions.append(ret)
 
         return predictions
 
-    def dispatched_get_shape(worker_id):
+    def dispatched_get_shape(worker_id, list_of_keys, list_of_parts):
         '''Get shape of data in each worker.'''
         LOGGER.info('Get shape on %d', worker_id)
         worker = distributed.get_worker()
+        c = distributed.get_client()
+        list_of_keys = c.compute(list_of_keys).result()
         list_of_parts = _get_worker_parts_ordered(
             meta_names,
-            worker_map,
+            list_of_keys,
+            list_of_parts,
             partition_order,
-            worker
         )
         shapes = []
         for parts in list_of_parts:
@@ -843,15 +849,20 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
     async def map_function(func):
         '''Run function for each part of the data.'''
         futures = []
-        for wid in range(len(worker_map)):
-            list_of_workers = [list(worker_map.keys())[wid]]
-            f = await client.submit(func, wid,
-                                    pure=False,
-                                    workers=list_of_workers)
+        workers_address = list(worker_map.keys())
+        for wid, worker_addr in enumerate(workers_address):
+            worker_addr = workers_address[wid]
+            list_of_parts = worker_map[worker_addr]
+            list_of_keys = [part.key for part in list_of_parts]
+            f = await client.submit(func, worker_id=wid,
+                                    list_of_keys=dask.delayed(list_of_keys),
+                                    list_of_parts=list_of_parts,
+                                    pure=False, workers=[worker_addr])
             futures.append(f)
         # Get delayed objects
         results = await client.gather(futures)
-        results = [t for l in results for t in l]  # flatten into 1 dim list
+        # flatten into 1 dim list
+        results = [t for list_per_worker in results for t in list_per_worker]
         # sort by order, l[0] is the delayed object, l[1] is its order
         results = sorted(results, key=lambda l: l[1])
         results = [predt for predt, order in results]  # remove order
@@ -1144,6 +1155,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
     'Implementation of the scikit-learn API for XGBoost classification.',
     ['estimators', 'model'])
 class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
+    # pylint: disable=missing-class-docstring
     async def _fit_async(self, X, y, sample_weight, base_margin, eval_set,
                          sample_weight_eval_set, early_stopping_rounds,
                          verbose):
@@ -1215,7 +1227,8 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
             output_margin=output_margin)
         return pred_probs
 
-    def predict_proba(self, data, output_margin=False, base_margin=None):  # pylint: disable=arguments-differ,missing-docstring
+    # pylint: disable=arguments-differ,missing-docstring
+    def predict_proba(self, data, output_margin=False, base_margin=None):
         _assert_dask_support()
         return self.client.sync(
             self._predict_proba_async,
@@ -1241,7 +1254,8 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
 
         return preds
 
-    def predict(self, data, output_margin=False, base_margin=None):  # pylint: disable=arguments-differ
+    # pylint: disable=arguments-differ
+    def predict(self, data, output_margin=False, base_margin=None):
         _assert_dask_support()
         return self.client.sync(
             self._predict_async,
@@ -15,6 +15,7 @@ if sys.platform.startswith("win"):
 sys.path.append("tests/python")
 from test_with_dask import run_empty_dmatrix_reg  # noqa
 from test_with_dask import run_empty_dmatrix_cls  # noqa
+from test_with_dask import _get_client_workers  # noqa
 from test_with_dask import generate_array  # noqa
 import testing as tm  # noqa
 
@@ -217,7 +218,7 @@ class TestDistributedGPU:
             return subprocess.run([exe, test], env=env, stdout=subprocess.PIPE)
 
         with Client(local_cuda_cluster) as client:
-            workers = list(dxgb._get_client_workers(client).keys())
+            workers = list(_get_client_workers(client).keys())
             rabit_args = client.sync(dxgb._get_rabit_args, workers, client)
             futures = client.map(runit,
                                  workers,
@@ -23,11 +23,12 @@ def test_rabit_tracker():
 
 
 def run_rabit_ops(client, n_workers):
-    from xgboost.dask import RabitContext, _get_rabit_args, _get_client_workers
+    from test_with_dask import _get_client_workers
+    from xgboost.dask import RabitContext, _get_rabit_args
     from xgboost import rabit
 
     workers = list(_get_client_workers(client).keys())
-    rabit_args = client.sync(_get_rabit_args, workers, client)
+    rabit_args = client.sync(_get_rabit_args, len(workers), client)
     assert not rabit.is_distributed()
     n_workers_from_dask = len(workers)
     assert n_workers == n_workers_from_dask
@@ -41,6 +41,11 @@ kCols = 10
 kWorkers = 5
 
 
+def _get_client_workers(client):
+    workers = client.scheduler_info()['workers']
+    return workers
+
+
 def generate_array(with_weights=False):
     partition_size = 20
     X = da.random.random((kRows, kCols), partition_size)
@@ -704,9 +709,9 @@ class TestWithDask:
 
         with LocalCluster(n_workers=4) as cluster:
             with Client(cluster) as client:
-                workers = list(xgb.dask._get_client_workers(client).keys())
+                workers = list(_get_client_workers(client).keys())
                 rabit_args = client.sync(
-                    xgb.dask._get_rabit_args, workers, client)
+                    xgb.dask._get_rabit_args, len(workers), client)
                 futures = client.map(runit,
                                      workers,
                                      pure=False,
@@ -750,7 +755,6 @@ class TestDaskCallbacks:
                                 num_boost_round=1000,
                                 early_stopping_rounds=early_stopping_rounds)['booster']
             assert hasattr(booster, 'best_score')
-            assert booster.best_iteration == 10
             dump = booster.get_dump(dump_format='json')
             assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
 
@@ -783,20 +787,22 @@ class TestDaskCallbacks:
             X, y = generate_array()
             n_partitions = X.npartitions
             m = xgb.dask.DaskDMatrix(client, X, y)
-            workers = list(xgb.dask._get_client_workers(client).keys())
-            rabit_args = client.sync(xgb.dask._get_rabit_args, workers, client)
+            workers = list(_get_client_workers(client).keys())
+            rabit_args = client.sync(xgb.dask._get_rabit_args, len(workers), client)
             n_workers = len(workers)
 
             def worker_fn(worker_addr, data_ref):
                 with xgb.dask.RabitContext(rabit_args):
-                    local_dtrain = xgb.dask._dmatrix_from_worker_map(**data_ref)
+                    local_dtrain = xgb.dask._dmatrix_from_list_of_parts(**data_ref)
                     total = np.array([local_dtrain.num_row()])
                     total = xgb.rabit.allreduce(total, xgb.rabit.Op.SUM)
                     assert total[0] == kRows
 
-            futures = client.map(
-                worker_fn, workers, [m.create_fn_args()] * len(workers),
-                pure=False, workers=workers)
+            futures = []
+            for i in range(len(workers)):
+                futures.append(client.submit(worker_fn, workers[i],
+                                             m.create_fn_args(workers[i]), pure=False,
+                                             workers=[workers[i]]))
             client.gather(futures)
 
             has_what = client.has_what()
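Illustrative note (not part of the commit): the dispatch pattern adopted above, one `client.submit` per worker pinned with `workers=[addr]` and handed only that worker's own partition list instead of calling `client.gather` inside the task, can be sketched standalone as below. The cluster setup and payloads are hypothetical stand-ins for the DaskDMatrix partitions.

from dask.distributed import Client, LocalCluster


def work_on_parts(worker_addr, parts):
    # Runs on the worker it was pinned to and only sees its own partitions,
    # so no client.gather is needed inside the task.
    return worker_addr, sum(parts)


if __name__ == '__main__':
    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        workers = list(client.scheduler_info()['workers'].keys())
        # Hypothetical per-worker payloads standing in for partition lists.
        worker_map = {addr: [i, i + 1] for i, addr in enumerate(workers)}
        futures = [client.submit(work_on_parts, addr, worker_map[addr],
                                 workers=[addr], pure=False)
                   for addr in workers]
        print(client.gather(futures))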