[Dask] Asyncio support. (#5862)

Jiaming Yuan 2020-07-30 06:23:58 +08:00 committed by GitHub
parent e4a273e1da
commit fa3715f584
5 changed files with 637 additions and 335 deletions


@ -12,6 +12,12 @@ algorithm. For an overview of GPU based training and internal working, see `A N
Official Dask API for XGBoost
<https://medium.com/rapids-ai/a-new-official-dask-api-for-xgboost-e8b10f3d1eb7>`_.

**Contents**

.. contents::
  :backlinks: none
  :local:

************
Requirements
************
@ -105,6 +111,60 @@ set:
XGBoost will use 8 threads in each training process.
********************
Working with asyncio
********************
.. versionadded:: 1.2.0
The XGBoost dask interface supports the new ``asyncio`` in Python and can be integrated into
asynchronous workflows.  For using dask with asynchronous operations, please refer to
`this dask example <https://examples.dask.org/applications/async-await.html>`_ and the documentation in
`distributed <https://distributed.dask.org/en/latest/asynchronous.html>`_.  As XGBoost
takes a ``Client`` object as an argument for both training and prediction, when
``asynchronous=True`` is specified while creating the ``Client``, the dask interface adapts
accordingly.  All functions provided by the functional interface return a coroutine when
called in an async function, and hence require awaiting to get the result, including
``DaskDMatrix``.
Functional interface:
.. code-block:: python

    async with Client(scheduler_address, asynchronous=True) as client:
        X, y = generate_array()
        m = await xgb.dask.DaskDMatrix(client, X, y)
        output = await xgb.dask.train(client, {}, dtrain=m)

        with_m = await xgb.dask.predict(client, output, m)
        with_X = await xgb.dask.predict(client, output, X)
        inplace = await xgb.dask.inplace_predict(client, output, X)

        # Use `client.compute` instead of the `compute` method from dask collection
        print(await client.compute(with_m))
For the Scikit-Learn interface, trivial methods like ``set_params`` and accessing class
attributes like ``evals_result_`` do not require ``await``.  Other methods involving
actual computation return a coroutine and hence require awaiting:
.. code-block:: python

    async with Client(scheduler_address, asynchronous=True) as client:
        X, y = generate_array()
        regressor = await xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
        regressor.set_params(tree_method='hist')  # trivial method, synchronous operation
        regressor.client = client  # accessing attribute, synchronous operation
        regressor = await regressor.fit(X, y, eval_set=[(X, y)])
        prediction = await regressor.predict(X)

        # Use `client.compute` instead of the `compute` method from dask collection
        print(await client.compute(prediction))
Be aware that XGBoost uses all the workers supplied by the ``client`` object.  If you
are training on a GPU cluster and have 2 GPUs, the client object passed to XGBoost should
return 2 workers.
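Below is a minimal sketch of running the same asynchronous workflow on a single machine with
2 GPUs.  It assumes the optional ``dask_cuda`` package is installed and reuses the
``generate_array`` helper from the examples above; ``LocalCUDACluster`` starts one worker per
GPU, so the client passed to XGBoost sees exactly 2 workers:

.. code-block:: python

    import asyncio

    import xgboost as xgb
    from dask_cuda import LocalCUDACluster
    from distributed import Client

    async def main():
        # One worker per GPU; on a 2-GPU machine the client sees 2 workers.
        async with LocalCUDACluster(n_workers=2, asynchronous=True) as cluster:
            async with Client(cluster, asynchronous=True) as client:
                X, y = generate_array()
                m = await xgb.dask.DaskDMatrix(client, X, y)
                output = await xgb.dask.train(
                    client, {'tree_method': 'gpu_hist'}, dtrain=m)
                return output

    asyncio.run(main())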
*****************************************************************************
Why is the initialization of ``DaskDMatrix`` so slow and throws weird errors
*****************************************************************************


@ -1,4 +1,6 @@
# pylint: disable=too-many-arguments, too-many-locals
# pylint: disable=missing-class-docstring, invalid-name
# pylint: disable=too-many-lines
"""Dask extensions for distributed training. See
https://xgboost.readthedocs.io/en/latest/tutorials/dask.html for simple
tutorial. Also xgboost/demo/dask for some examples.
@ -35,6 +37,11 @@ from .tracker import RabitTracker
from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
from .sklearn import xgboost_model_doc
try:
from distributed import Client
except ImportError:
Client = None
# Current status is considered as initial support, many features are
# not properly supported yet.
#
@ -43,6 +50,17 @@ from .sklearn import xgboost_model_doc
# - Label encoding.
# - CV
# - Ranking
#
# Note for developers:
# As of writing, asyncio is still a new feature of Python and in-depth
# documentation is rare. Best examples of various asyncio tricks are in dask
# (luckily). Classes like Client, Worker are awaitable. Some general rules
# for the implementation here:
# - Synchronous world is different from asynchronous one, and they don't
# mix well.
# - Write everything with async, then use distributed Client sync function
# to do the switch.
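#
# A rough illustration of that pattern (hypothetical names, not part of this
# module): the real work lives in a coroutine, and the public entry point
# bridges back to the synchronous world through `Client.sync`, which blocks
# and returns the result for a regular client, and returns an awaitable when
# the client was created with `asynchronous=True`:
#
#     async def _do_work_async(client, data):
#         fut = client.submit(lambda x: x + 1, data)
#         return await fut
#
#     def do_work(client, data):
#         return client.sync(_do_work_async, client, data)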
LOGGER = logging.getLogger('[xgboost.dask]')
@ -125,6 +143,12 @@ def _get_client_workers(client):
workers = client.scheduler_info()['workers']
return workers
# From the implementation point of view, DaskDMatrix complicates a lot of
# things. A large portion of the code base is about syncing and extracting
# stuff from DaskDMatrix. But having an independent data structure gives us a
# chance to perform some specialized optimizations, like building histogram
# index directly.
class DaskDMatrix:
# pylint: disable=missing-docstring, too-many-instance-attributes
@ -133,6 +157,11 @@ class DaskDMatrix:
the input data explicitly if you want to see actual computation of
constructing `DaskDMatrix`.
.. note::
DaskDMatrix does not repartition or move data between workers. It's
the caller's responsibility to balance the data.
.. versionadded:: 1.0.0

Parameters
@ -165,7 +194,7 @@ class DaskDMatrix:
feature_names=None,
feature_types=None):
_assert_dask_support()
client: Client = _xgb_get_client(client)
self.feature_names = feature_names
self.feature_types = feature_types
@ -187,7 +216,13 @@ class DaskDMatrix:
self.has_label = label is not None
self.has_weights = weight is not None
self.is_quantile = False
self._init = client.sync(self.map_local_data,
client, data, label, weight)
def __await__(self):
return self._init.__await__()
async def map_local_data(self, client, data, label=None, weights=None):
'''Obtain references to local data.'''
@ -264,29 +299,46 @@ class DaskDMatrix:
self.worker_map = worker_map
return self

def create_fn_args(self):
'''Create a dictionary of objects that can be pickled for function
arguments.
'''
return {'feature_names': self.feature_names,
'feature_types': self.feature_types,
'has_label': self.has_label,
'has_weights': self.has_weights,
'missing': self.missing,
'worker_map': self.worker_map,
'is_quantile': self.is_quantile}

def _get_worker_x_ordered(worker_map, partition_order, worker):
list_of_parts = worker_map[worker.address]
client = get_client()
list_of_parts_value = client.gather(list_of_parts)
result = []
for i, part in enumerate(list_of_parts):
result.append((list_of_parts_value[i][0],
partition_order[part.key]))
return result

def _get_worker_parts(has_label, has_weights, worker_map, worker):
'''Get mapped parts of data in each worker from DaskDMatrix.'''
list_of_parts = worker_map[worker.address]
assert list_of_parts, 'data in ' + worker.address + ' was moved.'
assert isinstance(list_of_parts, list)
# `_get_worker_parts` is launched inside worker. In dask side
# this should be equal to `worker._get_client`.
client = get_client()
list_of_parts = client.gather(list_of_parts)
if has_label:
if has_weights:
data, labels, weights = zip(*list_of_parts)
else:
data, labels = zip(*list_of_parts)
@ -297,66 +349,6 @@ class DaskDMatrix:
weights = None
return data, labels, weights
def get_worker_data(self, worker):
'''Get data that local to worker.
Parameters
----------
worker: The worker used as key to data.
Returns
-------
A DMatrix object.
'''
if worker.address not in set(self.worker_map.keys()):
msg = 'worker {address} has an empty DMatrix. ' \
'All workers associated with this DMatrix: {workers}'.format(
address=worker.address,
workers=set(self.worker_map.keys()))
LOGGER.warning(msg)
d = DMatrix(numpy.empty((0, 0)),
feature_names=self.feature_names,
feature_types=self.feature_types)
return d
data, labels, weights = self.get_worker_parts(worker)
data = concat(data)
if self.has_label:
labels = concat(labels)
else:
labels = None
if self.has_weights:
weights = concat(weights)
else:
weights = None
dmatrix = DMatrix(data,
labels,
weight=weights,
missing=self.missing,
feature_names=self.feature_names,
feature_types=self.feature_types,
nthread=worker.nthreads)
return dmatrix
def get_worker_data_shape(self, worker):
'''Get the shape of data X in each worker.'''
data, _, _ = self.get_worker_parts(worker)
shapes = [d.shape for d in data]
rows = 0
cols = 0
for shape in shapes:
rows += shape[0]
c = shape[1]
assert cols in (0, c), 'Shape between partitions are not the' \
' same. Got: {left} and {right}'.format(left=c, right=cols)
cols = c
return (rows, cols)
class DaskPartitionIter(DataIter): # pylint: disable=R0902
'''A data iterator for `DaskDeviceQuantileDMatrix`.
@ -460,6 +452,7 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
----------
max_bin: Number of bins for histogram construction.
'''
def __init__(self, client, data, label=None, weight=None,
missing=None,
@ -471,39 +464,99 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
feature_names=feature_names,
feature_types=feature_types)
self.max_bin = max_bin
self.is_quantile = True
def create_fn_args(self):
args = super().create_fn_args()
args['max_bin'] = self.max_bin
return args

def _create_device_quantile_dmatrix(feature_names, feature_types,
has_label,
has_weights, missing, worker_map,
max_bin):
worker = distributed_get_worker()
if worker.address not in set(worker_map.keys()):
msg = 'worker {address} has an empty DMatrix. ' \
'All workers associated with this DMatrix: {workers}'.format(
address=worker.address,
workers=set(worker_map.keys()))
LOGGER.warning(msg)
import cupy # pylint: disable=import-error
d = DeviceQuantileDMatrix(cupy.zeros((0, 0)),
feature_names=feature_names,
feature_types=feature_types,
max_bin=max_bin)
return d
data, labels, weights = _get_worker_parts(has_label, has_weights,
worker_map, worker)
it = DaskPartitionIter(data=data, label=labels, weight=weights)
dmatrix = DeviceQuantileDMatrix(it,
missing=missing,
feature_names=feature_names,
feature_types=feature_types,
nthread=worker.nthreads,
max_bin=max_bin)
return dmatrix

def _create_dmatrix(feature_names, feature_types, has_label,
has_weights, missing, worker_map):
'''Get data that is local to the worker from DaskDMatrix.
Returns
-------
A DMatrix object.
'''
worker = distributed_get_worker()
if worker.address not in set(worker_map.keys()):
msg = 'worker {address} has an empty DMatrix. ' \
'All workers associated with this DMatrix: {workers}'.format(
address=worker.address,
workers=set(worker_map.keys()))
LOGGER.warning(msg)
d = DMatrix(numpy.empty((0, 0)),
feature_names=feature_names,
feature_types=feature_types)
return d
data, labels, weights = _get_worker_parts(has_label, has_weights,
worker_map, worker)
data = concat(data)
if has_label:
labels = concat(labels)
else:
labels = None
if has_weights:
weights = concat(weights)
else:
weights = None
dmatrix = DMatrix(data,
labels,
weight=weights,
missing=missing,
feature_names=feature_names,
feature_types=feature_types,
nthread=worker.nthreads)
return dmatrix

def _dmatrix_from_worker_map(is_quantile, **kwargs):
if is_quantile:
return _create_device_quantile_dmatrix(**kwargs)
return _create_dmatrix(**kwargs)

async def _get_rabit_args(worker_map, client: Client):
'''Get rabit context arguments from data distribution in DaskDMatrix.'''
host = distributed_comm.get_address_host(client.scheduler.address)
env = await client.run_on_scheduler(
_start_tracker, host.strip('/:'), len(worker_map))
rabit_args = [('%s=%s' % item).encode() for item in env.items()]
return rabit_args
@ -514,6 +567,73 @@ def _get_rabit_args(worker_map, client):
# evaluation history is instead returned.
async def _train_async(client, params, dtrain: DaskDMatrix,
*args, evals=(), **kwargs):
_assert_dask_support()
client: Client = _xgb_get_client(client)
if 'evals_result' in kwargs.keys():
raise ValueError(
'evals_result is not supported in dask interface.',
'The evaluation history is returned as result of training.')
workers = list(_get_client_workers(client).keys())
rabit_args = await _get_rabit_args(workers, client)
def dispatched_train(worker_addr, dtrain_ref, evals_ref):
'''Perform training on a single worker. A local function prevents pickling.
'''
LOGGER.info('Training on %s', str(worker_addr))
worker = distributed_get_worker()
with RabitContext(rabit_args):
local_dtrain = _dmatrix_from_worker_map(**dtrain_ref)
local_evals = []
if evals_ref:
for ref, name in evals_ref:
if ref['worker_map'] == dtrain_ref['worker_map']:
local_evals.append((local_dtrain, name))
continue
local_evals.append((_dmatrix_from_worker_map(**ref), name))
local_history = {}
local_param = params.copy() # just to be consistent
msg = 'Overriding `nthreads` defined in dask worker.'
if 'nthread' in local_param.keys() and \
local_param['nthread'] is not None and \
local_param['nthread'] != worker.nthreads:
msg += '`nthread` is specified. ' + msg
LOGGER.warning(msg)
elif 'n_jobs' in local_param.keys() and \
local_param['n_jobs'] is not None and \
local_param['n_jobs'] != worker.nthreads:
msg = '`n_jobs` is specified. ' + msg
LOGGER.warning(msg)
else:
local_param['nthread'] = worker.nthreads
bst = worker_train(params=local_param,
dtrain=local_dtrain,
*args,
evals_result=local_history,
evals=local_evals,
**kwargs)
ret = {'booster': bst, 'history': local_history}
if local_dtrain.num_row() == 0:
ret = None
return ret
if evals:
evals = [(e.create_fn_args(), name) for e, name in evals]
futures = client.map(dispatched_train,
workers,
[dtrain.create_fn_args()] * len(workers),
[evals] * len(workers),
pure=False,
workers=workers)
results = await client.gather(futures)
return list(filter(lambda ret: ret is not None, results))[0]
def train(client, params, dtrain, *args, evals=(), **kwargs):
'''Train XGBoost model.
@ -544,75 +664,20 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
'''
_assert_dask_support()
client = _xgb_get_client(client)
return client.sync(_train_async, client, params,
dtrain=dtrain, *args, evals=evals, **kwargs)
'evals_result is not supported in dask interface.',
'The evaluation history is returned as result of training.')
workers = list(_get_client_workers(client).keys())
rabit_args = _get_rabit_args(workers, client)
def dispatched_train(worker_addr):
'''Perform training on a single worker.'''
LOGGER.info('Training on %s', str(worker_addr))
worker = distributed_get_worker()
with RabitContext(rabit_args):
local_dtrain = dtrain.get_worker_data(worker)
local_evals = []
if evals:
for mat, name in evals:
if mat is dtrain:
local_evals.append((local_dtrain, name))
continue
local_mat = mat.get_worker_data(worker)
local_evals.append((local_mat, name))
local_history = {}
local_param = params.copy() # just to be consistent
msg = 'Overriding `nthreads` defined in dask worker.'
if 'nthread' in local_param.keys() and \
local_param['nthread'] is not None and \
local_param['nthread'] != worker.nthreads:
msg += '`nthread` is specified. ' + msg
LOGGER.warning(msg)
elif 'n_jobs' in local_param.keys() and \
local_param['n_jobs'] is not None and \
local_param['n_jobs'] != worker.nthreads:
msg = '`n_jobs` is specified. ' + msg
LOGGER.warning(msg)
else:
local_param['nthread'] = worker.nthreads
bst = worker_train(params=local_param,
dtrain=local_dtrain,
*args,
evals_result=local_history,
evals=local_evals,
**kwargs)
ret = {'booster': bst, 'history': local_history}
if local_dtrain.num_row() == 0:
ret = None
return ret
futures = client.map(dispatched_train,
workers,
pure=False,
workers=workers)
results = client.gather(futures)
return list(filter(lambda ret: ret is not None, results))[0]
async def _direct_predict_impl(client, data, predict_fn):
if isinstance(data, da.Array):
predictions = await client.submit(
da.map_blocks,
predict_fn, data, False, drop_axis=1,
dtype=numpy.float32
).result()
return predictions
if isinstance(data, dd.DataFrame):
predictions = await client.submit(
dd.map_partitions,
predict_fn, data, True,
meta=dd.utils.make_meta({'prediction': 'f4'})
@ -622,6 +687,100 @@ def _direct_predict_impl(client, data, predict_fn):
' is not supported by direct prediction')
# pylint: disable=too-many-statements
async def _predict_async(client: Client, model, data, *args,
missing=numpy.nan):
if isinstance(model, Booster):
booster = model
elif isinstance(model, dict):
booster = model['booster']
else:
raise TypeError(_expect([Booster, dict], type(model)))
if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)):
raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame],
type(data)))
def mapped_predict(partition, is_df):
worker = distributed_get_worker()
booster.set_param({'nthread': worker.nthreads})
m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
predt = booster.predict(m, *args, validate_features=False)
if is_df:
if lazy_isinstance(partition, 'cudf', 'core.dataframe.DataFrame'):
import cudf # pylint: disable=import-error
predt = cudf.DataFrame(predt, columns=['prediction'])
else:
predt = DataFrame(predt, columns=['prediction'])
return predt
# Predict on dask collection directly.
if isinstance(data, (da.Array, dd.DataFrame)):
return await _direct_predict_impl(client, data, mapped_predict)
# Prediction on dask DMatrix.
worker_map = data.worker_map
partition_order = data.partition_order
feature_names = data.feature_names
feature_types = data.feature_types
missing = data.missing
def dispatched_predict(worker_id):
'''Perform prediction on each worker.'''
LOGGER.info('Predicting on %d', worker_id)
worker = distributed_get_worker()
list_of_parts = _get_worker_x_ordered(worker_map, partition_order,
worker)
predictions = []
booster.set_param({'nthread': worker.nthreads})
for part, order in list_of_parts:
local_x = DMatrix(part, feature_names=feature_names,
feature_types=feature_types,
missing=missing, nthread=worker.nthreads)
predt = booster.predict(data=local_x,
validate_features=local_x.num_row() != 0,
*args)
ret = (delayed(predt), order)
predictions.append(ret)
return predictions
def dispatched_get_shape(worker_id):
'''Get shape of data in each worker.'''
LOGGER.info('Get shape on %d', worker_id)
worker = distributed_get_worker()
list_of_parts = _get_worker_x_ordered(worker_map,
partition_order, worker)
shapes = [(part.shape, order) for part, order in list_of_parts]
return shapes
async def map_function(func):
'''Run function for each part of the data.'''
futures = []
for wid in range(len(worker_map)):
list_of_workers = [list(worker_map.keys())[wid]]
f = await client.submit(func, wid,
pure=False,
workers=list_of_workers)
futures.append(f)
# Get delayed objects
results = await client.gather(futures)
results = [t for l in results for t in l] # flatten into 1 dim list
# sort by order, l[0] is the delayed object, l[1] is its order
results = sorted(results, key=lambda l: l[1])
results = [predt for predt, order in results] # remove order
return results
results = await map_function(dispatched_predict)
shapes = await map_function(dispatched_get_shape)
# Constructing a dask array from list of numpy arrays
# See https://docs.dask.org/en/latest/array-creation.html
arrays = []
for i, shape in enumerate(shapes):
arrays.append(da.from_delayed(results[i], shape=(shape[0], ),
dtype=numpy.float32))
predictions = await da.concatenate(arrays, axis=0)
return predictions
def predict(client, model, data, *args, missing=numpy.nan):
'''Run prediction with a trained booster.
@ -651,94 +810,44 @@ def predict(client, model, data, *args, missing=numpy.nan):
'''
_assert_dask_support()
client = _xgb_get_client(client)
return client.sync(_predict_async, client, model, data, *args,
missing=missing)
async def _inplace_predict_async(client, model, data,
iteration_range=(0, 0),
predict_type='value',
missing=numpy.nan):
client = _xgb_get_client(client)
if isinstance(model, Booster):
booster = model
elif isinstance(model, dict):
booster = model['booster']
else:
raise TypeError(_expect([Booster, dict], type(model)))
if not isinstance(data, (da.Array, dd.DataFrame)):
raise TypeError(_expect([da.Array, dd.DataFrame], type(data)))

def mapped_predict(data, is_df):
worker = distributed_get_worker()
booster.set_param({'nthread': worker.nthreads})
prediction = booster.inplace_predict(
data,
iteration_range=iteration_range,
predict_type=predict_type,
missing=missing)
if is_df:
if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
import cudf # pylint: disable=import-error
prediction = cudf.DataFrame({'prediction': prediction},
dtype=numpy.float32)
else:
# If it's from pandas, the partition is a numpy array
prediction = DataFrame(prediction, columns=['prediction'],
dtype=numpy.float32)
return prediction

return await _direct_predict_impl(client, data, mapped_predict)
return _direct_predict_impl(client, data, mapped_predict)
# Prediction on dask DMatrix.
worker_map = data.worker_map
def dispatched_predict(worker_id):
'''Perform prediction on each worker.'''
LOGGER.info('Predicting on %d', worker_id)
worker = distributed_get_worker()
list_of_parts = data.get_worker_x_ordered(worker)
predictions = []
booster.set_param({'nthread': worker.nthreads})
for part, order in list_of_parts:
local_x = DMatrix(part,
feature_names=data.feature_names,
feature_types=data.feature_types,
missing=data.missing,
nthread=worker.nthreads)
predt = booster.predict(data=local_x,
validate_features=local_x.num_row() != 0,
*args)
ret = (delayed(predt), order)
predictions.append(ret)
return predictions
def dispatched_get_shape(worker_id):
'''Get shape of data in each worker.'''
LOGGER.info('Trying to get data shape on %d', worker_id)
worker = distributed_get_worker()
list_of_parts = data.get_worker_x_ordered(worker)
shapes = []
for part, order in list_of_parts:
shapes.append((part.shape, order))
return shapes
def map_function(func):
'''Run function for each part of the data.'''
futures = []
for wid in range(len(worker_map)):
list_of_workers = [list(worker_map.keys())[wid]]
f = client.submit(func, wid,
pure=False,
workers=list_of_workers)
futures.append(f)
# Get delayed objects
results = client.gather(futures)
results = [t for l in results for t in l] # flatten into 1 dim list
# sort by order, l[0] is the delayed object, l[1] is its order
results = sorted(results, key=lambda l: l[1])
results = [predt for predt, order in results] # remove order
return results
results = map_function(dispatched_predict)
shapes = map_function(dispatched_get_shape)
# Constructing a dask array from list of numpy arrays
# See https://docs.dask.org/en/latest/array-creation.html
arrays = []
for i, shape in enumerate(shapes):
arrays.append(da.from_delayed(results[i], shape=(shape[0], ),
dtype=numpy.float32))
predictions = da.concatenate(arrays, axis=0)
return predictions
def inplace_predict(client, model, data,
@ -770,38 +879,14 @@ def inplace_predict(client, model, data,
'''
_assert_dask_support()
client = _xgb_get_client(client)
return client.sync(_inplace_predict_async, client, model=model, data=data,
booster = model
elif isinstance(model, dict):
booster = model['booster']
else:
raise TypeError(_expect([Booster, dict], type(model)))
if not isinstance(data, (da.Array, dd.DataFrame)):
raise TypeError(_expect([da.Array, dd.DataFrame], type(data)))
def mapped_predict(data, is_df):
worker = distributed_get_worker()
booster.set_param({'nthread': worker.nthreads})
prediction = booster.inplace_predict(
data,
iteration_range=iteration_range,
predict_type=predict_type,
missing=missing)
if is_df:
if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
import cudf # pylint: disable=import-error
prediction = cudf.DataFrame({'prediction': prediction},
dtype=numpy.float32)
else:
# If it's from pandas, the partition is a numpy array
prediction = DataFrame(prediction, columns=['prediction'],
dtype=numpy.float32)
return prediction
return _direct_predict_impl(client, data, mapped_predict)
async def _evaluation_matrices(client, validation_set,
sample_weights, missing):
'''
Parameters
----------
@ -826,8 +911,8 @@ def _evaluation_matrices(client, validation_set, sample_weights, missing):
for i, e in enumerate(validation_set):
w = (sample_weights[i]
if sample_weights is not None else None)
dmat = await DaskDMatrix(client=client, data=e[0], label=e[1],
weight=w, missing=missing)
evals.append((dmat, 'validation_{}'.format(i)))
else:
evals = None
@ -840,9 +925,7 @@ class DaskScikitLearnBase(XGBModel):
_client = None

# pylint: disable=arguments-differ
def fit(self, X, y,
sample_weights=None,
eval_set=None,
sample_weight_eval_set=None,
@ -879,6 +962,12 @@ class DaskScikitLearnBase(XGBModel):
prediction : dask.array.Array'''
raise NotImplementedError
def __await__(self):
# Generate a coroutine wrapper to make this class awaitable.
async def _():
return self
return self.client.sync(_).__await__()
@property
def client(self):
'''The dask client used in this model.'''
@ -892,40 +981,51 @@ class DaskScikitLearnBase(XGBModel):
@xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""", @xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""",
['estimators', 'model']) ['estimators', 'model'])
class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase): class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
# pylint: disable=missing-docstring # pylint: disable=missing-class-docstring
def fit(self, async def _fit_async(self,
X, X,
y, y,
sample_weights=None, sample_weights=None,
eval_set=None, eval_set=None,
sample_weight_eval_set=None, sample_weight_eval_set=None,
verbose=True): verbose=True):
_assert_dask_support() dtrain = await DaskDMatrix(client=self.client,
dtrain = DaskDMatrix(client=self.client,
data=X, label=y, weight=sample_weights, data=X, label=y, weight=sample_weights,
missing=self.missing) missing=self.missing)
params = self.get_xgb_params() params = self.get_xgb_params()
evals = _evaluation_matrices(self.client, evals = await _evaluation_matrices(self.client,
eval_set, sample_weight_eval_set, eval_set, sample_weight_eval_set,
self.missing) self.missing)
results = await train(client=self.client, params=params, dtrain=dtrain,
results = train(self.client, params, dtrain,
num_boost_round=self.get_num_boosting_rounds(), num_boost_round=self.get_num_boosting_rounds(),
evals=evals, verbose_eval=verbose) evals=evals, verbose_eval=verbose)
# pylint: disable=attribute-defined-outside-init
self._Booster = results['booster'] self._Booster = results['booster']
# pylint: disable=attribute-defined-outside-init # pylint: disable=attribute-defined-outside-init
self.evals_result_ = results['history'] self.evals_result_ = results['history']
return self return self
# pylint: disable=missing-docstring
def fit(self, X, y,
sample_weights=None,
eval_set=None,
sample_weight_eval_set=None,
verbose=True):
_assert_dask_support()
return self.client.sync(self._fit_async, X, y, sample_weights,
eval_set, sample_weight_eval_set,
verbose)

async def _predict_async(self, data): # pylint: disable=arguments-differ
test_dmatrix = await DaskDMatrix(client=self.client, data=data,
missing=self.missing)
pred_probs = await predict(client=self.client,
model=self.get_booster(), data=test_dmatrix)
return pred_probs
def predict(self, data):
_assert_dask_support()
return self.client.sync(self._predict_async, data)
@xgboost_model_doc(
'Implementation of the scikit-learn API for XGBoost classification.',
@ -935,24 +1035,21 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
# pylint: disable=missing-docstring
_client = None

async def _fit_async(self, X, y,
sample_weights=None,
eval_set=None,
sample_weight_eval_set=None,
verbose=True):
dtrain = await DaskDMatrix(client=self.client,
data=X, label=y, weight=sample_weights,
missing=self.missing)
params = self.get_xgb_params()
# pylint: disable=attribute-defined-outside-init
if isinstance(y, (da.Array)):
self.classes_ = await self.client.compute(da.unique(y))
else:
self.classes_ = await self.client.compute(y.drop_duplicates())
self.n_classes_ = len(self.classes_)
if self.n_classes_ > 2:
@ -961,10 +1058,10 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
else:
params["objective"] = "binary:logistic"
evals = await _evaluation_matrices(self.client,
eval_set, sample_weight_eval_set,
self.missing)
results = await train(client=self.client, params=params, dtrain=dtrain,
num_boost_round=self.get_num_boosting_rounds(),
evals=evals, verbose_eval=verbose)
self._Booster = results['booster']
@ -972,10 +1069,22 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
self.evals_result_ = results['history']
return self

def fit(self, X, y,
sample_weights=None,
eval_set=None,
sample_weight_eval_set=None,
verbose=True):
_assert_dask_support()
return self.client.sync(self._fit_async, X, y, sample_weights,
eval_set, sample_weight_eval_set, verbose)

async def _predict_async(self, data):
test_dmatrix = await DaskDMatrix(client=self.client, data=data,
missing=self.missing)
pred_probs = await predict(client=self.client,
model=self.get_booster(), data=test_dmatrix)
return pred_probs
def predict(self, data): # pylint: disable=arguments-differ
_assert_dask_support()
return self.client.sync(self._predict_async, data)


@ -2,6 +2,7 @@ import sys
import os
import pytest
import numpy as np
import asyncio
import unittest
import xgboost
import subprocess
@ -219,7 +220,7 @@ class TestDistributedGPU(unittest.TestCase):
with LocalCUDACluster() as cluster:
with Client(cluster) as client:
workers = list(dxgb._get_client_workers(client).keys())
rabit_args = client.sync(dxgb._get_rabit_args, workers, client)
futures = client.map(runit,
workers,
pure=False,
@ -242,3 +243,39 @@ class TestDistributedGPU(unittest.TestCase):
@pytest.mark.gtest
def test_quantile_same_on_all_workers(self):
self.run_quantile('SameOnAllWorkers')
async def run_from_dask_array_asyncio(scheduler_address):
async with Client(scheduler_address, asynchronous=True) as client:
import cupy as cp
X, y = generate_array()
X = X.map_blocks(cp.array)
y = y.map_blocks(cp.array)
m = await xgboost.dask.DaskDeviceQuantileDMatrix(client, X, y)
output = await xgboost.dask.train(client, {'tree_method': 'gpu_hist'},
dtrain=m)
with_m = await xgboost.dask.predict(client, output, m)
with_X = await xgboost.dask.predict(client, output, X)
inplace = await xgboost.dask.inplace_predict(client, output, X)
assert isinstance(with_m, da.Array)
assert isinstance(with_X, da.Array)
assert isinstance(inplace, da.Array)
cp.testing.assert_allclose(await client.compute(with_m),
await client.compute(with_X))
cp.testing.assert_allclose(await client.compute(with_m),
await client.compute(inplace))
client.shutdown()
return output
def test_with_asyncio():
with LocalCUDACluster() as cluster:
with Client(cluster) as client:
address = client.scheduler.address
output = asyncio.run(run_from_dask_array_asyncio(address))
assert isinstance(output['booster'], xgboost.Booster)
assert isinstance(output['history'], dict)


@ -27,8 +27,10 @@ def run_rabit_ops(client, n_workers):
from xgboost import rabit
workers = list(_get_client_workers(client).keys())
rabit_args = client.sync(_get_rabit_args, workers, client)
assert not rabit.is_distributed()
n_workers_from_dask = len(workers)
assert n_workers == n_workers_from_dask
def local_test(worker_id):
with RabitContext(rabit_args):


@ -4,6 +4,7 @@ import xgboost as xgb
import sys
import numpy as np
import json
import asyncio
if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
@ -327,3 +328,96 @@ def test_empty_dmatrix_approx():
parameters = {'tree_method': 'approx'}
run_empty_dmatrix_reg(client, parameters)
run_empty_dmatrix_cls(client, parameters)
async def run_from_dask_array_asyncio(scheduler_address):
async with Client(scheduler_address, asynchronous=True) as client:
X, y = generate_array()
m = await DaskDMatrix(client, X, y)
output = await xgb.dask.train(client, {}, dtrain=m)
with_m = await xgb.dask.predict(client, output, m)
with_X = await xgb.dask.predict(client, output, X)
inplace = await xgb.dask.inplace_predict(client, output, X)
assert isinstance(with_m, da.Array)
assert isinstance(with_X, da.Array)
assert isinstance(inplace, da.Array)
np.testing.assert_allclose(await client.compute(with_m),
await client.compute(with_X))
np.testing.assert_allclose(await client.compute(with_m),
await client.compute(inplace))
client.shutdown()
return output
async def run_dask_regressor_asyncio(scheduler_address):
async with Client(scheduler_address, asynchronous=True) as client:
X, y = generate_array()
regressor = await xgb.dask.DaskXGBRegressor(verbosity=1,
n_estimators=2)
regressor.set_params(tree_method='hist')
regressor.client = client
await regressor.fit(X, y, eval_set=[(X, y)])
prediction = await regressor.predict(X)
assert prediction.ndim == 1
assert prediction.shape[0] == kRows
history = regressor.evals_result()
assert isinstance(prediction, da.Array)
assert isinstance(history, dict)
assert list(history['validation_0'].keys())[0] == 'rmse'
assert len(history['validation_0']['rmse']) == 2
async def run_dask_classifier_asyncio(scheduler_address):
async with Client(scheduler_address, asynchronous=True) as client:
X, y = generate_array()
y = (y * 10).astype(np.int32)
classifier = await xgb.dask.DaskXGBClassifier(
verbosity=1, n_estimators=2)
classifier.client = client
await classifier.fit(X, y, eval_set=[(X, y)])
prediction = await classifier.predict(X)
assert prediction.ndim == 1
assert prediction.shape[0] == kRows
history = classifier.evals_result()
assert isinstance(prediction, da.Array)
assert isinstance(history, dict)
assert list(history.keys())[0] == 'validation_0'
assert list(history['validation_0'].keys())[0] == 'merror'
assert len(list(history['validation_0'])) == 1
assert len(history['validation_0']['merror']) == 2
assert classifier.n_classes_ == 10
# Test with dataframe.
X_d = dd.from_dask_array(X)
y_d = dd.from_dask_array(y)
await classifier.fit(X_d, y_d)
assert classifier.n_classes_ == 10
prediction = await classifier.predict(X_d)
assert prediction.ndim == 1
assert prediction.shape[0] == kRows
def test_with_asyncio():
with LocalCluster() as cluster:
with Client(cluster) as client:
address = client.scheduler.address
output = asyncio.run(run_from_dask_array_asyncio(address))
assert isinstance(output['booster'], xgb.Booster)
assert isinstance(output['history'], dict)
asyncio.run(run_dask_regressor_asyncio(address))
asyncio.run(run_dask_classifier_asyncio(address))