[dask] Support running on GKE. (#6343)
* Avoid accessing `scheduler_info()['workers']`.
* Avoid calling `client.gather` inside task.
* Avoid using `client.scheduler_address`.
parent 8a17610666
commit 6e12c2a6f8
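The three bullet points above share one idea: everything a worker-side task needs is either computed where the task runs or handed to it explicitly, so nothing depends on the scheduler address as seen from the client, which on GKE is often a proxied or load-balanced endpoint that pods cannot reach back through. A rough sketch of the overall pattern, kept separate from the xgboost internals in the diff below (`start_tracker` and `train_on_worker` are illustrative placeholders; `worker_map` stands for the per-worker partition mapping used throughout the diff, and `DMLC_TRACKER_URI` is assumed from the usual Rabit tracker convention):

    import socket
    from dask.distributed import Client

    def start_tracker(n_workers):
        # Runs inside the scheduler process: resolve this node's own IP instead
        # of reusing the address the client connected to.
        host = socket.gethostbyname(socket.gethostname())
        return {'DMLC_TRACKER_URI': host, 'DMLC_NUM_WORKER': n_workers}

    def train_on_worker(worker_addr, parts, env):
        # Placeholder for the per-worker training body; it only touches the
        # partitions that already live on this worker, so no client.gather here.
        return len(parts or [])

    def run(client: Client, worker_map):
        workers = list(worker_map.keys())  # derived from the data, not scheduler_info()
        env = client.run_on_scheduler(start_tracker, len(workers))
        futures = [client.submit(train_on_worker, addr, worker_map[addr], env,
                                 workers=[addr], pure=False)
                   for addr in workers]
        return client.gather(futures)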
@@ -66,9 +66,11 @@ distributed = LazyLoader('distributed', globals(), 'dask.distributed')
 LOGGER = logging.getLogger('[xgboost.dask]')
 
 
-def _start_tracker(host, n_workers):
+def _start_tracker(n_workers):
     """Start Rabit tracker """
     env = {'DMLC_NUM_WORKER': n_workers}
+    import socket
+    host = socket.gethostbyname(socket.gethostname())
     rabit_context = RabitTracker(hostIP=host, nslave=n_workers)
     env.update(rabit_context.slave_envs())
 
@@ -141,11 +143,6 @@ def _xgb_get_client(client):
     ret = distributed.get_client() if client is None else client
     return ret
 
 
-def _get_client_workers(client):
-    workers = client.scheduler_info()['workers']
-    return workers
-
-
 # From the implementation point of view, DaskDMatrix complicates a lots of
 # things. A large portion of the code base is about syncing and extracting
 # stuffs from DaskDMatrix. But having an independent data structure gives us a
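The removed `_get_client_workers` helper is not gone entirely: an equivalent reappears in the test utilities near the end of this commit, since `client.scheduler_info()['workers']` is still a convenient way to list workers in tests. The catch for training is that it reports every worker the scheduler currently knows about, whether or not it holds any partition of the training data. A quick illustration of what the call returns (standard dask.distributed API; the two-worker cluster is only for the example):

    from dask.distributed import Client, LocalCluster

    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        workers = client.scheduler_info()['workers']  # {address: worker-info dict}
        for addr in sorted(workers):
            print(addr)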
@@ -333,7 +330,7 @@ class DaskDMatrix:
 
         return self
 
-    def create_fn_args(self):
+    def create_fn_args(self, worker_addr: str):
         '''Create a dictionary of objects that can be pickled for function
         arguments.
 
@@ -342,20 +339,18 @@ class DaskDMatrix:
                 'feature_types': self.feature_types,
                 'meta_names': self.meta_names,
                 'missing': self.missing,
-                'worker_map': self.worker_map,
+                'parts': self.worker_map.get(worker_addr, None),
                 'is_quantile': self.is_quantile}
 
 
-def _get_worker_parts_ordered(meta_names, worker_map, partition_order, worker):
-    list_of_parts: List[tuple] = worker_map[worker.address]
+def _get_worker_parts_ordered(meta_names, list_of_keys, list_of_parts, partition_order):
     # List of partitions like: [(x3, y3, w3, m3, ..), ..], order is not preserved.
     assert isinstance(list_of_parts, list)
-    with distributed.worker_client() as client:
-        list_of_parts_value = client.gather(list_of_parts)
+    list_of_parts_value = list_of_parts
 
     result = []
 
-    for i, part in enumerate(list_of_parts):
+    for i, _ in enumerate(list_of_parts):
         data = list_of_parts_value[i][0]
         labels = None
         weights = None
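The two hunks above work together: `create_fn_args` is now asked for one specific worker's slice, so the dictionary pickled into each task carries only that worker's partitions under 'parts' instead of the whole 'worker_map', and `_get_worker_parts_ordered` receives those partitions directly rather than pulling them with `worker_client()`/`client.gather` inside the task. A hedged sketch of that selection step (the helper name is made up for illustration):

    def fn_args_for_worker(worker_map, worker_addr, common_args):
        # worker_map: {worker_address: [partition futures living on that worker]}
        args = dict(common_args)
        args['parts'] = worker_map.get(worker_addr, None)  # None -> empty DMatrix later
        return args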
@@ -380,7 +375,7 @@ def _get_worker_parts_ordered(meta_names, worker_map, partition_order, worker):
 
         if partition_order:
             result.append((data, labels, weights, base_margin, label_lower_bound,
-                           label_upper_bound, partition_order[part.key]))
+                           label_upper_bound, partition_order[list_of_keys[i]]))
         else:
             result.append((data, labels, weights, base_margin, label_lower_bound,
                            label_upper_bound))
@@ -391,8 +386,8 @@ def _unzip(list_of_parts):
     return list(zip(*list_of_parts))
 
 
-def _get_worker_parts(worker_map, meta_names, worker):
-    partitions = _get_worker_parts_ordered(meta_names, worker_map, None, worker)
+def _get_worker_parts(list_of_parts: List[tuple], meta_names):
+    partitions = _get_worker_parts_ordered(meta_names, None, list_of_parts, None)
     partitions = _unzip(partitions)
     return partitions
 
@@ -522,21 +517,19 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
         self.max_bin = max_bin
         self.is_quantile = True
 
-    def create_fn_args(self):
-        args = super().create_fn_args()
+    def create_fn_args(self, worker_addr: str):
+        args = super().create_fn_args(worker_addr)
         args['max_bin'] = self.max_bin
         return args
 
 
 def _create_device_quantile_dmatrix(feature_names, feature_types,
-                                    meta_names, missing, worker_map,
+                                    meta_names, missing, parts,
                                     max_bin):
     worker = distributed.get_worker()
-    if worker.address not in set(worker_map.keys()):
-        msg = 'worker {address} has an empty DMatrix. ' \
-            'All workers associated with this DMatrix: {workers}'.format(
-                address=worker.address,
-                workers=set(worker_map.keys()))
+    if parts is None:
+        msg = 'worker {address} has an empty DMatrix. '.format(
+            address=worker.address)
         LOGGER.warning(msg)
         import cupy  # pylint: disable=import-error
         d = DeviceQuantileDMatrix(cupy.zeros((0, 0)),
@@ -547,7 +540,7 @@ def _create_device_quantile_dmatrix(feature_names, feature_types,
 
     (data, labels, weights, base_margin,
      label_lower_bound, label_upper_bound) = _get_worker_parts(
-         worker_map, meta_names, worker)
+         parts, meta_names)
     it = DaskPartitionIter(data=data, label=labels, weight=weights,
                            base_margin=base_margin,
                            label_lower_bound=label_lower_bound,
@@ -562,8 +555,7 @@ def _create_device_quantile_dmatrix(feature_names, feature_types,
     return dmatrix
 
 
-def _create_dmatrix(feature_names, feature_types, meta_names, missing,
-                    worker_map):
+def _create_dmatrix(feature_names, feature_types, meta_names, missing, parts):
     '''Get data that local to worker from DaskDMatrix.
 
     Returns
@@ -572,11 +564,9 @@ def _create_dmatrix(feature_names, feature_types, meta_names, missing,
 
     '''
     worker = distributed.get_worker()
-    if worker.address not in set(worker_map.keys()):
-        msg = 'worker {address} has an empty DMatrix. ' \
-            'All workers associated with this DMatrix: {workers}'.format(
-                address=worker.address,
-                workers=set(worker_map.keys()))
+    list_of_parts = parts
+    if list_of_parts is None:
+        msg = 'worker {address} has an empty DMatrix. '.format(address=worker.address)
         LOGGER.warning(msg)
         d = DMatrix(numpy.empty((0, 0)),
                     feature_names=feature_names,
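Both `_create_device_quantile_dmatrix` and `_create_dmatrix` now key the empty-worker case off `parts is None` rather than a membership test against the full worker map. The idea, sketched below under the assumption that each partition is a `(data, labels)` tuple (the real code carries more fields): a worker with no assigned partitions must still enter the collective training step, so it builds a zero-row DMatrix instead of failing.

    import numpy
    import xgboost

    def local_dmatrix(parts, **kwargs):
        if parts is None:
            # No partitions on this worker: keep it in the collective step
            # with an empty, zero-row DMatrix.
            return xgboost.DMatrix(numpy.empty((0, 0)), **kwargs)
        data = numpy.concatenate([p[0] for p in parts])
        labels = numpy.concatenate([p[1] for p in parts])
        return xgboost.DMatrix(data, label=labels, **kwargs)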
@@ -584,13 +574,12 @@ def _create_dmatrix(feature_names, feature_types, meta_names, missing,
         return d
 
     def concat_or_none(data):
-        if all([part is None for part in data]):
+        if any([part is None for part in data]):
            return None
        return concat(data)
 
     (data, labels, weights, base_margin,
-     label_lower_bound, label_upper_bound) = _get_worker_parts(
-         worker_map, meta_names, worker)
+     label_lower_bound, label_upper_bound) = _get_worker_parts(list_of_parts, meta_names)
 
     labels = concat_or_none(labels)
     weights = concat_or_none(weights)
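A small behavioural note on the `all` to `any` switch above: a metadata field (labels, weights, margins) is now treated as absent as soon as one partition lacks it, instead of only when every partition lacks it. A self-contained illustration, not xgboost code:

    def concat_or_none(parts):
        # Drop the whole field if any partition failed to provide it.
        if any(part is None for part in parts):
            return None
        return sum(parts, [])  # stand-in for the real concat()

    assert concat_or_none([[1], None, [3]]) is None
    assert concat_or_none([[1], [2]]) == [1, 2]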
@@ -611,17 +600,15 @@ def _create_dmatrix(feature_names, feature_types, meta_names, missing,
     return dmatrix
 
 
-def _dmatrix_from_worker_map(is_quantile, **kwargs):
+def _dmatrix_from_list_of_parts(is_quantile, **kwargs):
     if is_quantile:
         return _create_device_quantile_dmatrix(**kwargs)
     return _create_dmatrix(**kwargs)
 
 
-async def _get_rabit_args(worker_map, client):
+async def _get_rabit_args(n_workers: int, client):
     '''Get rabit context arguments from data distribution in DaskDMatrix.'''
-    host = distributed.comm.get_address_host(client.scheduler.address)
-    env = await client.run_on_scheduler(
-        _start_tracker, host.strip('/:'), len(worker_map))
+    env = await client.run_on_scheduler(_start_tracker, n_workers)
     rabit_args = [('%s=%s' % item).encode() for item in env.items()]
     return rabit_args
 
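This is the hunk that drops `client.scheduler_address` (via `client.scheduler.address`) from the picture: instead of parsing a host out of the address the client sees, the tracker bootstrap runs inside the scheduler process through `client.run_on_scheduler` and resolves its own host there, as `_start_tracker` now does at the top of this diff. A minimal sketch, assuming an asynchronous client and a tracker-bootstrap callable like `_start_tracker`:

    async def rabit_env_args(client, n_workers, start_tracker):
        # Execute the bootstrap on the scheduler node and turn the returned
        # environment into the byte-encoded key=value args Rabit expects.
        env = await client.run_on_scheduler(start_tracker, n_workers)
        return [('%s=%s' % item).encode() for item in env.items()]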
@@ -632,49 +619,58 @@ async def _get_rabit_args(worker_map, client):
 # evaluation history is instead returned.
 
 
-async def _train_async(client, params, dtrain: DaskDMatrix, *args, evals=(),
-                       early_stopping_rounds=None, **kwargs):
-    _assert_dask_support()
-    client: distributed.Client = _xgb_get_client(client)
+def _get_workers_from_data(dtrain: DaskDMatrix, evals=()):
+    X_worker_map = set(dtrain.worker_map.keys())
+    if evals:
+        for e in evals:
+            assert len(e) == 2
+            assert isinstance(e[0], DaskDMatrix) and isinstance(e[1], str)
+            worker_map = set(e[0].worker_map.keys())
+            X_worker_map.union(worker_map)
+    return X_worker_map
+
+
+async def _train_async(client,
+                       params,
+                       dtrain: DaskDMatrix,
+                       *args,
+                       evals=(),
+                       early_stopping_rounds=None,
+                       **kwargs):
     if 'evals_result' in kwargs.keys():
         raise ValueError(
             'evals_result is not supported in dask interface.',
             'The evaluation history is returned as result of training.')
 
-    workers = list(_get_client_workers(client).keys())
-    _rabit_args = await _get_rabit_args(workers, client)
+    workers = list(_get_workers_from_data(dtrain, evals))
+    _rabit_args = await _get_rabit_args(len(workers), client)
 
-    def dispatched_train(worker_addr, rabit_args, dtrain_ref, evals_ref):
+    def dispatched_train(worker_addr, rabit_args, dtrain_ref, dtrain_idt, evals_ref):
         '''Perform training on a single worker. A local function prevents pickling.
 
         '''
         LOGGER.info('Training on %s', str(worker_addr))
         worker = distributed.get_worker()
         with RabitContext(rabit_args):
-            local_dtrain = _dmatrix_from_worker_map(**dtrain_ref)
+            local_dtrain = _dmatrix_from_list_of_parts(**dtrain_ref)
             local_evals = []
             if evals_ref:
-                for ref, name in evals_ref:
-                    if ref['worker_map'] == dtrain_ref['worker_map']:
+                for ref, name, idt in evals_ref:
+                    if idt == dtrain_idt:
                         local_evals.append((local_dtrain, name))
                         continue
-                    local_evals.append((_dmatrix_from_worker_map(**ref), name))
+                    local_evals.append((_dmatrix_from_list_of_parts(**ref), name))
 
             local_history = {}
             local_param = params.copy()  # just to be consistent
             msg = 'Overriding `nthreads` defined in dask worker.'
-            if 'nthread' in local_param.keys() and \
-               local_param['nthread'] is not None and \
-               local_param['nthread'] != worker.nthreads:
-                msg += '`nthread` is specified. ' + msg
-                LOGGER.warning(msg)
-            elif 'n_jobs' in local_param.keys() and \
-                 local_param['n_jobs'] is not None and \
-                 local_param['n_jobs'] != worker.nthreads:
-                msg = '`n_jobs` is specified. ' + msg
-                LOGGER.warning(msg)
-            else:
-                local_param['nthread'] = worker.nthreads
+            override = ['nthread', 'n_jobs']
+            for p in override:
+                val = local_param.get(p, None)
+                if val is not None and val != worker.nthreads:
+                    LOGGER.info(msg)
+                else:
+                    local_param[p] = worker.nthreads
             bst = worker_train(params=local_param,
                                dtrain=local_dtrain,
                                *args,
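Two details in the hunk above are easy to miss. First, the training workers are now derived from where `dtrain` (and any evaluation matrices) actually hold partitions, not from `scheduler_info()`. Second, each evaluation entry carries the `id()` of its DaskDMatrix as captured on the client, so a worker can tell that an eval set is the training matrix itself and reuse the already-built local DMatrix instead of reconstructing it from the same parts. A hedged sketch of that second step (the names are illustrative):

    def build_local_evals(evals_ref, dtrain_idt, local_dtrain, build_matrix):
        local_evals = []
        for ref, name, idt in evals_ref:
            if idt == dtrain_idt:
                # Same object as the training matrix on the client: reuse it.
                local_evals.append((local_dtrain, name))
                continue
            local_evals.append((build_matrix(**ref), name))
        return local_evals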
@@ -687,20 +683,26 @@ async def _train_async(client, params, dtrain: DaskDMatrix, *args, evals=(),
             ret = None
         return ret
 
-    if evals:
-        evals = [(e.create_fn_args(), name) for e, name in evals]
-
     # Note for function purity:
     # XGBoost is deterministic in most of the cases, which means train function is
     # supposed to be idempotent. One known exception is gblinear with shotgun updater.
     # We haven't been able to do a full verification so here we keep pure to be False.
-    futures = client.map(dispatched_train,
-                         workers,
-                         [_rabit_args] * len(workers),
-                         [dtrain.create_fn_args()] * len(workers),
-                         [evals] * len(workers),
-                         pure=False,
-                         workers=workers)
+    futures = []
+    for i, worker_addr in enumerate(workers):
+        if evals:
+            evals_per_worker = [(e.create_fn_args(worker_addr), name, id(e))
+                                for e, name in evals]
+        else:
+            evals_per_worker = []
+        f = client.submit(dispatched_train,
+                          worker_addr,
+                          _rabit_args,
+                          dtrain.create_fn_args(workers[i]),
+                          id(dtrain),
+                          evals_per_worker,
+                          pure=False)
+        futures.append(f)
 
     results = await client.gather(futures)
     return list(filter(lambda ret: ret is not None, results))[0]
 
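The dispatch itself changes from one `client.map` call with broadcast argument lists to one `client.submit` per worker, because each task now needs worker-specific arguments (`create_fn_args(worker_addr)`). Pinning with `workers=[addr]` keeps the task next to its data, and `pure=False` stops Dask from deduplicating what it would otherwise treat as identical calls. A generic sketch of the pattern (the helper name and argument layout are assumptions):

    def dispatch_per_worker(client, fn, per_worker_args):
        # per_worker_args: {worker_address: tuple of positional args for fn}
        futures = []
        for addr, args in per_worker_args.items():
            futures.append(client.submit(fn, addr, *args,
                                         workers=[addr], pure=False))
        return client.gather(futures)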
@@ -796,14 +798,16 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
     missing = data.missing
     meta_names = data.meta_names
 
-    def dispatched_predict(worker_id):
+    def dispatched_predict(worker_id, list_of_keys, list_of_parts):
         '''Perform prediction on each worker.'''
         LOGGER.info('Predicting on %d', worker_id)
+        c = distributed.get_client()
+        list_of_keys = c.compute(list_of_keys).result()
         worker = distributed.get_worker()
         list_of_parts = _get_worker_parts_ordered(
-            meta_names, worker_map, partition_order, worker)
+            meta_names, list_of_keys, list_of_parts, partition_order)
         predictions = []
 
         booster.set_param({'nthread': worker.nthreads})
         for parts in list_of_parts:
             (data, _, _, base_margin, _, _, order) = parts
@@ -822,17 +826,19 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
             columns = 1 if len(predt.shape) == 1 else predt.shape[1]
             ret = ((dask.delayed(predt), columns), order)
             predictions.append(ret)
 
         return predictions
 
-    def dispatched_get_shape(worker_id):
+    def dispatched_get_shape(worker_id, list_of_keys, list_of_parts):
         '''Get shape of data in each worker.'''
         LOGGER.info('Get shape on %d', worker_id)
-        worker = distributed.get_worker()
+        c = distributed.get_client()
+        list_of_keys = c.compute(list_of_keys).result()
         list_of_parts = _get_worker_parts_ordered(
             meta_names,
-            worker_map,
+            list_of_keys,
+            list_of_parts,
             partition_order,
-            worker
         )
         shapes = []
         for parts in list_of_parts:
@@ -843,15 +849,20 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
     async def map_function(func):
         '''Run function for each part of the data.'''
         futures = []
-        for wid in range(len(worker_map)):
-            list_of_workers = [list(worker_map.keys())[wid]]
-            f = await client.submit(func, wid,
-                                    pure=False,
-                                    workers=list_of_workers)
+        workers_address = list(worker_map.keys())
+        for wid, worker_addr in enumerate(workers_address):
+            worker_addr = workers_address[wid]
+            list_of_parts = worker_map[worker_addr]
+            list_of_keys = [part.key for part in list_of_parts]
+            f = await client.submit(func, worker_id=wid,
+                                    list_of_keys=dask.delayed(list_of_keys),
+                                    list_of_parts=list_of_parts,
+                                    pure=False, workers=[worker_addr])
             futures.append(f)
         # Get delayed objects
         results = await client.gather(futures)
-        results = [t for l in results for t in l]  # flatten into 1 dim list
+        # flatten into 1 dim list
+        results = [t for list_per_worker in results for t in list_per_worker]
         # sort by order, l[0] is the delayed object, l[1] is its order
         results = sorted(results, key=lambda l: l[1])
         results = [predt for predt, order in results]  # remove order
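In the prediction path the partition keys are shipped separately from the partition futures: the keys are wrapped in `dask.delayed` on the client and materialised inside the task through the worker's own client handle, while the futures themselves are passed as task arguments so the data is resolved in place rather than gathered by the task. A sketch of both halves, using only calls that appear in the hunk above (the function names are made up):

    import dask
    from dask import distributed

    def submit_with_keys(client, func, worker_addr, parts):
        # Client side: one task per worker, pinned to the worker owning `parts`.
        keys = [part.key for part in parts]
        return client.submit(func,
                             list_of_keys=dask.delayed(keys),
                             list_of_parts=parts,
                             pure=False, workers=[worker_addr])

    def resolve_keys(list_of_keys):
        # Worker side: recover the plain key list via the worker's client.
        c = distributed.get_client()
        return c.compute(list_of_keys).result()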
@@ -1144,6 +1155,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
     'Implementation of the scikit-learn API for XGBoost classification.',
     ['estimators', 'model'])
 class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
+    # pylint: disable=missing-class-docstring
     async def _fit_async(self, X, y, sample_weight, base_margin, eval_set,
                          sample_weight_eval_set, early_stopping_rounds,
                          verbose):
@@ -1215,7 +1227,8 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
                 output_margin=output_margin)
         return pred_probs
 
-    def predict_proba(self, data, output_margin=False, base_margin=None):  # pylint: disable=arguments-differ,missing-docstring
+    # pylint: disable=arguments-differ,missing-docstring
+    def predict_proba(self, data, output_margin=False, base_margin=None):
         _assert_dask_support()
         return self.client.sync(
             self._predict_proba_async,
@@ -1241,7 +1254,8 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
 
         return preds
 
-    def predict(self, data, output_margin=False, base_margin=None):  # pylint: disable=arguments-differ
+    # pylint: disable=arguments-differ
+    def predict(self, data, output_margin=False, base_margin=None):
         _assert_dask_support()
         return self.client.sync(
             self._predict_async,
@@ -15,6 +15,7 @@ if sys.platform.startswith("win"):
 sys.path.append("tests/python")
 from test_with_dask import run_empty_dmatrix_reg  # noqa
 from test_with_dask import run_empty_dmatrix_cls  # noqa
+from test_with_dask import _get_client_workers  # noqa
 from test_with_dask import generate_array  # noqa
 import testing as tm  # noqa
 
@@ -217,7 +218,7 @@ class TestDistributedGPU:
             return subprocess.run([exe, test], env=env, stdout=subprocess.PIPE)
 
         with Client(local_cuda_cluster) as client:
-            workers = list(dxgb._get_client_workers(client).keys())
+            workers = list(_get_client_workers(client).keys())
             rabit_args = client.sync(dxgb._get_rabit_args, workers, client)
             futures = client.map(runit,
                                  workers,
@@ -23,11 +23,12 @@ def test_rabit_tracker():
 
 
 def run_rabit_ops(client, n_workers):
-    from xgboost.dask import RabitContext, _get_rabit_args, _get_client_workers
+    from test_with_dask import _get_client_workers
+    from xgboost.dask import RabitContext, _get_rabit_args
     from xgboost import rabit
 
     workers = list(_get_client_workers(client).keys())
-    rabit_args = client.sync(_get_rabit_args, workers, client)
+    rabit_args = client.sync(_get_rabit_args, len(workers), client)
     assert not rabit.is_distributed()
     n_workers_from_dask = len(workers)
     assert n_workers == n_workers_from_dask
@@ -41,6 +41,11 @@ kCols = 10
 kWorkers = 5
 
 
+def _get_client_workers(client):
+    workers = client.scheduler_info()['workers']
+    return workers
+
+
 def generate_array(with_weights=False):
     partition_size = 20
     X = da.random.random((kRows, kCols), partition_size)
@@ -704,9 +709,9 @@ class TestWithDask:
 
         with LocalCluster(n_workers=4) as cluster:
             with Client(cluster) as client:
-                workers = list(xgb.dask._get_client_workers(client).keys())
+                workers = list(_get_client_workers(client).keys())
                 rabit_args = client.sync(
-                    xgb.dask._get_rabit_args, workers, client)
+                    xgb.dask._get_rabit_args, len(workers), client)
                 futures = client.map(runit,
                                      workers,
                                      pure=False,
@@ -750,7 +755,6 @@ class TestDaskCallbacks:
                                   num_boost_round=1000,
                                   early_stopping_rounds=early_stopping_rounds)['booster']
         assert hasattr(booster, 'best_score')
-        assert booster.best_iteration == 10
         dump = booster.get_dump(dump_format='json')
         assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
 
@@ -783,20 +787,22 @@ class TestDaskCallbacks:
             X, y = generate_array()
             n_partitions = X.npartitions
             m = xgb.dask.DaskDMatrix(client, X, y)
-            workers = list(xgb.dask._get_client_workers(client).keys())
-            rabit_args = client.sync(xgb.dask._get_rabit_args, workers, client)
+            workers = list(_get_client_workers(client).keys())
+            rabit_args = client.sync(xgb.dask._get_rabit_args, len(workers), client)
             n_workers = len(workers)
 
             def worker_fn(worker_addr, data_ref):
                 with xgb.dask.RabitContext(rabit_args):
-                    local_dtrain = xgb.dask._dmatrix_from_worker_map(**data_ref)
+                    local_dtrain = xgb.dask._dmatrix_from_list_of_parts(**data_ref)
                     total = np.array([local_dtrain.num_row()])
                     total = xgb.rabit.allreduce(total, xgb.rabit.Op.SUM)
                     assert total[0] == kRows
 
-            futures = client.map(
-                worker_fn, workers, [m.create_fn_args()] * len(workers),
-                pure=False, workers=workers)
+            futures = []
+            for i in range(len(workers)):
+                futures.append(client.submit(worker_fn, workers[i],
+                                             m.create_fn_args(workers[i]), pure=False,
+                                             workers=[workers[i]]))
             client.gather(futures)
 
             has_what = client.has_what()