[Dask] Asyncio support. (#5862)
This commit is contained in:
parent e4a273e1da
commit fa3715f584
@@ -12,6 +12,12 @@ algorithm. For an overview of GPU based training and internal working, see `A N
 Official Dask API for XGBoost
 <https://medium.com/rapids-ai/a-new-official-dask-api-for-xgboost-e8b10f3d1eb7>`_.
 
+**Contents**
+
+.. contents::
+  :backlinks: none
+  :local:
+
 ************
 Requirements
 ************
@@ -105,6 +111,60 @@ set:
 
 XGBoost will use 8 threads in each training process.
 
+********************
+Working with asyncio
+********************
+
+.. versionadded:: 1.2.0
+
+The XGBoost dask interface supports the new ``asyncio`` in Python and can be integrated into
+asynchronous workflows.  For using dask with asynchronous operations, please refer to
+`this dask example <https://examples.dask.org/applications/async-await.html>`_ and the document on
+`distributed <https://distributed.dask.org/en/latest/asynchronous.html>`_.  Since XGBoost
+takes a ``Client`` object as an argument for both training and prediction, the dask
+interface adapts automatically when ``asynchronous=True`` is specified while creating the
+``Client``.  All functions provided by the functional interface return a coroutine when
+called in an async function, and hence require awaiting to get the result, including
+``DaskDMatrix``.
+
+Functional interface:
+
+.. code-block:: python
+
+    async with Client(scheduler_address, asynchronous=True) as client:
+        X, y = generate_array()
+        m = await xgb.dask.DaskDMatrix(client, X, y)
+        output = await xgb.dask.train(client, {}, dtrain=m)
+
+        with_m = await xgb.dask.predict(client, output, m)
+        with_X = await xgb.dask.predict(client, output, X)
+        inplace = await xgb.dask.inplace_predict(client, output, X)
+
+        # Use `client.compute` instead of the `compute` method from dask collection
+        print(await client.compute(with_m))
+
+For the Scikit-Learn interface, trivial methods like ``set_params`` and accessing class
+attributes like ``evals_result_`` do not require ``await``.  Other methods involving
+actual computation return a coroutine and hence require awaiting:
+
+.. code-block:: python
+
+    async with Client(scheduler_address, asynchronous=True) as client:
+        X, y = generate_array()
+        regressor = await xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
+        regressor.set_params(tree_method='hist')  # trivial method, synchronous operation
+        regressor.client = client  # accessing attribute, synchronous operation
+        regressor = await regressor.fit(X, y, eval_set=[(X, y)])
+        prediction = await regressor.predict(X)
+
+        # Use `client.compute` instead of the `compute` method from dask collection
+        print(await client.compute(prediction))
+
+Be aware that XGBoost uses all the workers supplied by the ``client`` object.  If you
+are training on a GPU cluster and have 2 GPUs, the client object passed to XGBoost should
+return 2 workers.
+
 *****************************************************************************
 Why is the initialization of ``DaskDMatrix`` so slow and throws weird errors
 *****************************************************************************
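(An aside, not part of the diff: the snippets above are written as ``async with`` blocks and therefore have to run inside a coroutine. A rough, self-contained sketch of driving the functional interface with ``asyncio.run`` follows; the scheduler address and the toy dask arrays are assumptions made purely for illustration.)

.. code-block:: python

    import asyncio

    import dask.array as da
    import xgboost as xgb
    from dask.distributed import Client


    async def main(scheduler_address):
        # `asynchronous=True` turns every XGBoost dask entry point into a coroutine.
        async with Client(scheduler_address, asynchronous=True) as client:
            X = da.random.random((1000, 10), chunks=(100, 10))  # hypothetical toy data
            y = da.random.random(1000, chunks=100)

            m = await xgb.dask.DaskDMatrix(client, X, y)
            output = await xgb.dask.train(client, {'tree_method': 'hist'}, dtrain=m)
            prediction = await xgb.dask.predict(client, output, m)
            # Use `client.compute` instead of the `compute` method from the dask collection.
            print(await client.compute(prediction))


    # Assumes a dask scheduler is already running at this (made-up) address.
    # asyncio.run(main('tcp://127.0.0.1:8786'))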
@@ -1,4 +1,6 @@
 # pylint: disable=too-many-arguments, too-many-locals
+# pylint: disable=missing-class-docstring, invalid-name
+# pylint: disable=too-many-lines
 """Dask extensions for distributed training. See
 https://xgboost.readthedocs.io/en/latest/tutorials/dask.html for simple
 tutorial. Also xgboost/demo/dask for some examples.
@@ -35,6 +37,11 @@ from .tracker import RabitTracker
 from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
 from .sklearn import xgboost_model_doc
 
+try:
+    from distributed import Client
+except ImportError:
+    Client = None
+
 # Current status is considered as initial support, many features are
 # not properly supported yet.
 #
@@ -43,6 +50,17 @@ from .sklearn import xgboost_model_doc
 # - Label encoding.
 # - CV
 # - Ranking
+#
+# Note for developers:
+
+# As of writing, asyncio is still a new feature of Python and in-depth
+# documentation is rare.  The best examples of various asyncio tricks are in
+# dask (luckily).  Classes like Client and Worker are awaitable.  Some general
+# rules for the implementation here:
+# - The synchronous world is different from the asynchronous one, and they
+#   don't mix well.
+# - Write everything with async, then use the distributed Client sync function
+#   to do the switch.
 
 
 LOGGER = logging.getLogger('[xgboost.dask]')
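(Aside: the last rule in the note above, implement everything as coroutines and let ``distributed.Client.sync`` do the switching, is the pattern the rest of this change follows. A minimal sketch of the idea with invented function names, not code from this commit:)

.. code-block:: python

    from dask.distributed import Client


    async def _double_async(client, value):
        # The real work lives in the asynchronous world only.
        future = client.submit(lambda v: v * 2, value)
        return await future


    def double(client: Client, value):
        # Synchronous entry point.  With a regular client, `sync` blocks and
        # returns the result; with a client created using `asynchronous=True`
        # it returns the coroutine, so callers can `await double(client, 21)`.
        return client.sync(_double_async, client, value)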
@@ -125,6 +143,12 @@ def _get_client_workers(client):
     workers = client.scheduler_info()['workers']
     return workers
 
 
+# From the implementation point of view, DaskDMatrix complicates a lot of
+# things.  A large portion of the code base is about syncing and extracting
+# stuff from DaskDMatrix.  But having an independent data structure gives us a
+# chance to perform some specialized optimizations, like building the histogram
+# index directly.
+
 class DaskDMatrix:
     # pylint: disable=missing-docstring, too-many-instance-attributes
@@ -133,6 +157,11 @@ class DaskDMatrix:
     the input data explicitly if you want to see actual computation of
     constructing `DaskDMatrix`.
 
+    .. note::
+
+        DaskDMatrix does not repartition or move data between workers.  It's
+        the caller's responsibility to balance the data.
+
     .. versionadded:: 1.0.0
 
     Parameters
@@ -165,7 +194,7 @@ class DaskDMatrix:
                  feature_names=None,
                  feature_types=None):
         _assert_dask_support()
-        client = _xgb_get_client(client)
+        client: Client = _xgb_get_client(client)
 
         self.feature_names = feature_names
         self.feature_types = feature_types
@@ -187,7 +216,13 @@ class DaskDMatrix:
         self.has_label = label is not None
         self.has_weights = weight is not None
 
-        client.sync(self.map_local_data, client, data, label, weight)
+        self.is_quantile = False
+
+        self._init = client.sync(self.map_local_data,
+                                 client, data, label, weight)
+
+    def __await__(self):
+        return self._init.__await__()
 
     async def map_local_data(self, client, data, label=None, weights=None):
         '''Obtain references to local data.'''
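(Aside: storing the result of ``client.sync`` and forwarding ``__await__`` is what makes the constructor usable from both worlds. With a synchronous client, ``map_local_data`` runs to completion inside ``__init__``; with an asynchronous client the pending coroutine is kept and resumed by ``await DaskDMatrix(...)``. A stripped-down sketch of the same trick, with invented names:)

.. code-block:: python

    class AwaitableResource:
        def __init__(self, client):
            # Blocks with a synchronous client; stores the pending coroutine
            # when the client was created with `asynchronous=True`.
            self._init = client.sync(self._setup)

        def __await__(self):
            # `await AwaitableResource(client)` resumes the stored coroutine.
            return self._init.__await__()

        async def _setup(self):
            # ... gather futures, map partitions to workers, and so on ...
            return self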
@@ -264,98 +299,55 @@ class DaskDMatrix:
 
         self.worker_map = worker_map
 
-    def get_worker_x_ordered(self, worker):
-        list_of_parts = self.worker_map[worker.address]
-        client = get_client()
-        list_of_parts_value = client.gather(list_of_parts)
-        result = []
-        for i, part in enumerate(list_of_parts):
-            result.append((list_of_parts_value[i][0],
-                           self.partition_order[part.key]))
-        return result
-
-    def get_worker_parts(self, worker):
-        '''Get mapped parts of data in each worker.'''
-        list_of_parts = self.worker_map[worker.address]
-        assert list_of_parts, 'data in ' + worker.address + ' was moved.'
-        assert isinstance(list_of_parts, list)
-
-        # `get_worker_parts` is launched inside worker.  In dask side
-        # this should be equal to `worker._get_client`.
-        client = get_client()
-        list_of_parts = client.gather(list_of_parts)
-
-        if self.has_label:
-            if self.has_weights:
-                data, labels, weights = zip(*list_of_parts)
-            else:
-                data, labels = zip(*list_of_parts)
-                weights = None
-        else:
-            data = [d[0] for d in list_of_parts]
-            labels = None
-            weights = None
-        return data, labels, weights
-
-    def get_worker_data(self, worker):
-        '''Get data that local to worker.
-
-        Parameters
-        ----------
-        worker: The worker used as key to data.
-
-        Returns
-        -------
-        A DMatrix object.
-
-        '''
-        if worker.address not in set(self.worker_map.keys()):
-            msg = 'worker {address} has an empty DMatrix. ' \
-                'All workers associated with this DMatrix: {workers}'.format(
-                    address=worker.address,
-                    workers=set(self.worker_map.keys()))
-            LOGGER.warning(msg)
-            d = DMatrix(numpy.empty((0, 0)),
-                        feature_names=self.feature_names,
-                        feature_types=self.feature_types)
-            return d
-
-        data, labels, weights = self.get_worker_parts(worker)
-
-        data = concat(data)
-
-        if self.has_label:
-            labels = concat(labels)
-        else:
-            labels = None
-        if self.has_weights:
-            weights = concat(weights)
-        else:
-            weights = None
-        dmatrix = DMatrix(data,
-                          labels,
-                          weight=weights,
-                          missing=self.missing,
-                          feature_names=self.feature_names,
-                          feature_types=self.feature_types,
-                          nthread=worker.nthreads)
-        return dmatrix
-
-    def get_worker_data_shape(self, worker):
-        '''Get the shape of data X in each worker.'''
-        data, _, _ = self.get_worker_parts(worker)
-
-        shapes = [d.shape for d in data]
-        rows = 0
-        cols = 0
-        for shape in shapes:
-            rows += shape[0]
-
-            c = shape[1]
-            assert cols in (0, c), 'Shape between partitions are not the' \
-                ' same. Got: {left} and {right}'.format(left=c, right=cols)
-            cols = c
-        return (rows, cols)
+        return self
+
+    def create_fn_args(self):
+        '''Create a dictionary of objects that can be pickled for function
+        arguments.
+
+        '''
+        return {'feature_names': self.feature_names,
+                'feature_types': self.feature_types,
+                'has_label': self.has_label,
+                'has_weights': self.has_weights,
+                'missing': self.missing,
+                'worker_map': self.worker_map,
+                'is_quantile': self.is_quantile}
+
+
+def _get_worker_x_ordered(worker_map, partition_order, worker):
+    list_of_parts = worker_map[worker.address]
+    client = get_client()
+    list_of_parts_value = client.gather(list_of_parts)
+    result = []
+    for i, part in enumerate(list_of_parts):
+        result.append((list_of_parts_value[i][0],
+                       partition_order[part.key]))
+    return result
+
+
+def _get_worker_parts(has_label, has_weights, worker_map, worker):
+    '''Get mapped parts of data in each worker from DaskDMatrix.'''
+    list_of_parts = worker_map[worker.address]
+    assert list_of_parts, 'data in ' + worker.address + ' was moved.'
+    assert isinstance(list_of_parts, list)
+
+    # `_get_worker_parts` is launched inside worker.  In dask side
+    # this should be equal to `worker._get_client`.
+    client = get_client()
+    list_of_parts = client.gather(list_of_parts)
+
+    if has_label:
+        if has_weights:
+            data, labels, weights = zip(*list_of_parts)
+        else:
+            data, labels = zip(*list_of_parts)
+            weights = None
+    else:
+        data = [d[0] for d in list_of_parts]
+        labels = None
+        weights = None
+    return data, labels, weights
 
 
 class DaskPartitionIter(DataIter):  # pylint: disable=R0902
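(Aside: a consequence of the refactor above is that ``DaskDMatrix`` itself no longer travels to the workers; training ships only the plain dictionary returned by ``create_fn_args`` and rebuilds the worker-local data with module-level helpers. A toy version of the pattern, all names invented for illustration:)

.. code-block:: python

    from dask.distributed import Client, get_worker


    class Dataset:
        def __init__(self, parts_per_worker):
            # {worker_address: [list of partitions]}
            self.parts_per_worker = parts_per_worker

        def create_fn_args(self):
            # Only picklable building blocks cross the wire, never `self`.
            return {'parts_per_worker': self.parts_per_worker}


    def _local_parts(parts_per_worker):
        # Runs on a worker and reconstructs the worker-local view.
        worker = get_worker()
        return parts_per_worker.get(worker.address, [])


    def collect(client: Client, dataset: Dataset):
        args = dataset.create_fn_args()
        addresses = list(dataset.parts_per_worker)
        futures = client.map(lambda _addr: _local_parts(**args),
                             addresses, workers=addresses, pure=False)
        return client.gather(futures)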
@@ -460,6 +452,7 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
     ----------
     max_bin: Number of bins for histogram construction.
 
+
     '''
     def __init__(self, client, data, label=None, weight=None,
                  missing=None,
@@ -471,39 +464,99 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
                          feature_names=feature_names,
                          feature_types=feature_types)
         self.max_bin = max_bin
+        self.is_quantile = True
 
-    def get_worker_data(self, worker):
-        if worker.address not in set(self.worker_map.keys()):
-            msg = 'worker {address} has an empty DMatrix. ' \
-                'All workers associated with this DMatrix: {workers}'.format(
-                    address=worker.address,
-                    workers=set(self.worker_map.keys()))
-            LOGGER.warning(msg)
-            import cupy  # pylint: disable=import-error
-            d = DeviceQuantileDMatrix(cupy.zeros((0, 0)),
-                                      feature_names=self.feature_names,
-                                      feature_types=self.feature_types,
-                                      max_bin=self.max_bin)
-            return d
-
-        data, labels, weights = self.get_worker_parts(worker)
-        it = DaskPartitionIter(data=data, label=labels, weight=weights)
-
-        dmatrix = DeviceQuantileDMatrix(it,
-                                        missing=self.missing,
-                                        feature_names=self.feature_names,
-                                        feature_types=self.feature_types,
-                                        nthread=worker.nthreads,
-                                        max_bin=self.max_bin)
-        return dmatrix
-
-
-def _get_rabit_args(worker_map, client):
+    def create_fn_args(self):
+        args = super().create_fn_args()
+        args['max_bin'] = self.max_bin
+        return args
+
+
+def _create_device_quantile_dmatrix(feature_names, feature_types,
+                                    has_label,
+                                    has_weights, missing, worker_map,
+                                    max_bin):
+    worker = distributed_get_worker()
+    if worker.address not in set(worker_map.keys()):
+        msg = 'worker {address} has an empty DMatrix. ' \
+            'All workers associated with this DMatrix: {workers}'.format(
+                address=worker.address,
+                workers=set(worker_map.keys()))
+        LOGGER.warning(msg)
+        import cupy  # pylint: disable=import-error
+        d = DeviceQuantileDMatrix(cupy.zeros((0, 0)),
+                                  feature_names=feature_names,
+                                  feature_types=feature_types,
+                                  max_bin=max_bin)
+        return d
+
+    data, labels, weights = _get_worker_parts(has_label, has_weights,
+                                              worker_map, worker)
+    it = DaskPartitionIter(data=data, label=labels, weight=weights)
+
+    dmatrix = DeviceQuantileDMatrix(it,
+                                    missing=missing,
+                                    feature_names=feature_names,
+                                    feature_types=feature_types,
+                                    nthread=worker.nthreads,
+                                    max_bin=max_bin)
+    return dmatrix
+
+
+def _create_dmatrix(feature_names, feature_types, has_label,
+                    has_weights, missing, worker_map):
+    '''Get data that local to worker from DaskDMatrix.
+
+    Returns
+    -------
+    A DMatrix object.
+
+    '''
+    worker = distributed_get_worker()
+    if worker.address not in set(worker_map.keys()):
+        msg = 'worker {address} has an empty DMatrix. ' \
+            'All workers associated with this DMatrix: {workers}'.format(
+                address=worker.address,
+                workers=set(worker_map.keys()))
+        LOGGER.warning(msg)
+        d = DMatrix(numpy.empty((0, 0)),
+                    feature_names=feature_names,
+                    feature_types=feature_types)
+        return d
+
+    data, labels, weights = _get_worker_parts(has_label, has_weights,
+                                              worker_map, worker)
+    data = concat(data)
+
+    if has_label:
+        labels = concat(labels)
+    else:
+        labels = None
+    if has_weights:
+        weights = concat(weights)
+    else:
+        weights = None
+    dmatrix = DMatrix(data,
+                      labels,
+                      weight=weights,
+                      missing=missing,
+                      feature_names=feature_names,
+                      feature_types=feature_types,
+                      nthread=worker.nthreads)
+    return dmatrix
+
+
+def _dmatrix_from_worker_map(is_quantile, **kwargs):
+    if is_quantile:
+        return _create_device_quantile_dmatrix(**kwargs)
+    return _create_dmatrix(**kwargs)
+
+
+async def _get_rabit_args(worker_map, client: Client):
     '''Get rabit context arguments from data distribution in DaskDMatrix.'''
     host = distributed_comm.get_address_host(client.scheduler.address)
-    env = client.run_on_scheduler(_start_tracker, host.strip('/:'),
-                                  len(worker_map))
+    env = await client.run_on_scheduler(
+        _start_tracker, host.strip('/:'), len(worker_map))
     rabit_args = [('%s=%s' % item).encode() for item in env.items()]
     return rabit_args
 
@@ -514,6 +567,73 @@ def _get_rabit_args(worker_map, client):
 # evaluation history is instead returned.
 
 
+async def _train_async(client, params, dtrain: DaskDMatrix,
+                       *args, evals=(), **kwargs):
+    _assert_dask_support()
+    client: Client = _xgb_get_client(client)
+    if 'evals_result' in kwargs.keys():
+        raise ValueError(
+            'evals_result is not supported in dask interface.',
+            'The evaluation history is returned as result of training.')
+
+    workers = list(_get_client_workers(client).keys())
+    rabit_args = await _get_rabit_args(workers, client)
+
+    def dispatched_train(worker_addr, dtrain_ref, evals_ref):
+        '''Perform training on a single worker.  A local function prevents pickling.
+
+        '''
+        LOGGER.info('Training on %s', str(worker_addr))
+        worker = distributed_get_worker()
+        with RabitContext(rabit_args):
+            local_dtrain = _dmatrix_from_worker_map(**dtrain_ref)
+            local_evals = []
+            if evals_ref:
+                for ref, name in evals_ref:
+                    if ref['worker_map'] == dtrain_ref['worker_map']:
+                        local_evals.append((local_dtrain, name))
+                        continue
+                    local_evals.append((_dmatrix_from_worker_map(**ref), name))
+
+            local_history = {}
+            local_param = params.copy()  # just to be consistent
+            msg = 'Overriding `nthreads` defined in dask worker.'
+            if 'nthread' in local_param.keys() and \
+               local_param['nthread'] is not None and \
+               local_param['nthread'] != worker.nthreads:
+                msg += '`nthread` is specified. ' + msg
+                LOGGER.warning(msg)
+            elif 'n_jobs' in local_param.keys() and \
+                 local_param['n_jobs'] is not None and \
+                 local_param['n_jobs'] != worker.nthreads:
+                msg = '`n_jobs` is specified. ' + msg
+                LOGGER.warning(msg)
+            else:
+                local_param['nthread'] = worker.nthreads
+            bst = worker_train(params=local_param,
+                               dtrain=local_dtrain,
+                               *args,
+                               evals_result=local_history,
+                               evals=local_evals,
+                               **kwargs)
+            ret = {'booster': bst, 'history': local_history}
+            if local_dtrain.num_row() == 0:
+                ret = None
+            return ret
+
+    if evals:
+        evals = [(e.create_fn_args(), name) for e, name in evals]
+
+    futures = client.map(dispatched_train,
+                         workers,
+                         [dtrain.create_fn_args()] * len(workers),
+                         [evals] * len(workers),
+                         pure=False,
+                         workers=workers)
+    results = await client.gather(futures)
+    return list(filter(lambda ret: ret is not None, results))[0]
+
+
 def train(client, params, dtrain, *args, evals=(), **kwargs):
     '''Train XGBoost model.
 
@@ -544,75 +664,20 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
     '''
     _assert_dask_support()
     client = _xgb_get_client(client)
-    if 'evals_result' in kwargs.keys():
-        raise ValueError(
-            'evals_result is not supported in dask interface.',
-            'The evaluation history is returned as result of training.')
-
-    workers = list(_get_client_workers(client).keys())
-
-    rabit_args = _get_rabit_args(workers, client)
-
-    def dispatched_train(worker_addr):
-        '''Perform training on a single worker.'''
-        LOGGER.info('Training on %s', str(worker_addr))
-        worker = distributed_get_worker()
-        with RabitContext(rabit_args):
-            local_dtrain = dtrain.get_worker_data(worker)
-
-            local_evals = []
-            if evals:
-                for mat, name in evals:
-                    if mat is dtrain:
-                        local_evals.append((local_dtrain, name))
-                        continue
-                    local_mat = mat.get_worker_data(worker)
-                    local_evals.append((local_mat, name))
-
-            local_history = {}
-            local_param = params.copy()  # just to be consistent
-            msg = 'Overriding `nthreads` defined in dask worker.'
-            if 'nthread' in local_param.keys() and \
-               local_param['nthread'] is not None and \
-               local_param['nthread'] != worker.nthreads:
-                msg += '`nthread` is specified. ' + msg
-                LOGGER.warning(msg)
-            elif 'n_jobs' in local_param.keys() and \
-                 local_param['n_jobs'] is not None and \
-                 local_param['n_jobs'] != worker.nthreads:
-                msg = '`n_jobs` is specified. ' + msg
-                LOGGER.warning(msg)
-            else:
-                local_param['nthread'] = worker.nthreads
-            bst = worker_train(params=local_param,
-                               dtrain=local_dtrain,
-                               *args,
-                               evals_result=local_history,
-                               evals=local_evals,
-                               **kwargs)
-            ret = {'booster': bst, 'history': local_history}
-            if local_dtrain.num_row() == 0:
-                ret = None
-            return ret
-
-    futures = client.map(dispatched_train,
-                         workers,
-                         pure=False,
-                         workers=workers)
-    results = client.gather(futures)
-    return list(filter(lambda ret: ret is not None, results))[0]
+    return client.sync(_train_async, client, params,
+                       dtrain=dtrain, *args, evals=evals, **kwargs)
 
 
-def _direct_predict_impl(client, data, predict_fn):
+async def _direct_predict_impl(client, data, predict_fn):
     if isinstance(data, da.Array):
-        predictions = client.submit(
+        predictions = await client.submit(
             da.map_blocks,
             predict_fn, data, False, drop_axis=1,
             dtype=numpy.float32
         ).result()
         return predictions
     if isinstance(data, dd.DataFrame):
-        predictions = client.submit(
+        predictions = await client.submit(
             dd.map_partitions,
             predict_fn, data, True,
             meta=dd.utils.make_meta({'prediction': 'f4'})
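(Aside: note how ``_direct_predict_impl`` now awaits ``client.submit(...).result()``. With an asynchronous client, ``submit`` still returns a ``Future`` immediately, but ``result()`` hands back a coroutine instead of blocking. A small sketch, assuming a scheduler is already running at a made-up address:)

.. code-block:: python

    import asyncio
    from dask.distributed import Client


    async def double_remotely(scheduler_address, x):
        async with Client(scheduler_address, asynchronous=True) as client:
            future = client.submit(lambda v: v * 2, x)  # returns right away
            return await future.result()                # awaitable, not blocking


    # print(asyncio.run(double_remotely('tcp://127.0.0.1:8786', 21)))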
@@ -622,6 +687,100 @@ def _direct_predict_impl(client, data, predict_fn):
                              ' is not supported by direct prediction')
 
 
+# pylint: disable=too-many-statements
+async def _predict_async(client: Client, model, data, *args,
+                         missing=numpy.nan):
+    if isinstance(model, Booster):
+        booster = model
+    elif isinstance(model, dict):
+        booster = model['booster']
+    else:
+        raise TypeError(_expect([Booster, dict], type(model)))
+    if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)):
+        raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame],
+                                type(data)))
+
+    def mapped_predict(partition, is_df):
+        worker = distributed_get_worker()
+        booster.set_param({'nthread': worker.nthreads})
+        m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
+        predt = booster.predict(m, *args, validate_features=False)
+        if is_df:
+            if lazy_isinstance(partition, 'cudf', 'core.dataframe.DataFrame'):
+                import cudf  # pylint: disable=import-error
+                predt = cudf.DataFrame(predt, columns=['prediction'])
+            else:
+                predt = DataFrame(predt, columns=['prediction'])
+        return predt
+    # Predict on dask collection directly.
+    if isinstance(data, (da.Array, dd.DataFrame)):
+        return await _direct_predict_impl(client, data, mapped_predict)
+
+    # Prediction on dask DMatrix.
+    worker_map = data.worker_map
+    partition_order = data.partition_order
+    feature_names = data.feature_names
+    feature_types = data.feature_types
+    missing = data.missing
+
+    def dispatched_predict(worker_id):
+        '''Perform prediction on each worker.'''
+        LOGGER.info('Predicting on %d', worker_id)
+        worker = distributed_get_worker()
+        list_of_parts = _get_worker_x_ordered(worker_map, partition_order,
+                                              worker)
+        predictions = []
+        booster.set_param({'nthread': worker.nthreads})
+        for part, order in list_of_parts:
+            local_x = DMatrix(part, feature_names=feature_names,
+                              feature_types=feature_types,
+                              missing=missing, nthread=worker.nthreads)
+            predt = booster.predict(data=local_x,
+                                    validate_features=local_x.num_row() != 0,
+                                    *args)
+            ret = (delayed(predt), order)
+            predictions.append(ret)
+        return predictions
+
+    def dispatched_get_shape(worker_id):
+        '''Get shape of data in each worker.'''
+        LOGGER.info('Get shape on %d', worker_id)
+        worker = distributed_get_worker()
+        list_of_parts = _get_worker_x_ordered(worker_map,
+                                              partition_order, worker)
+        shapes = [(part.shape, order) for part, order in list_of_parts]
+        return shapes
+
+    async def map_function(func):
+        '''Run function for each part of the data.'''
+        futures = []
+        for wid in range(len(worker_map)):
+            list_of_workers = [list(worker_map.keys())[wid]]
+            f = await client.submit(func, wid,
+                                    pure=False,
+                                    workers=list_of_workers)
+            futures.append(f)
+        # Get delayed objects
+        results = await client.gather(futures)
+        results = [t for l in results for t in l]  # flatten into 1 dim list
+        # sort by order, l[0] is the delayed object, l[1] is its order
+        results = sorted(results, key=lambda l: l[1])
+        results = [predt for predt, order in results]  # remove order
+        return results
+
+    results = await map_function(dispatched_predict)
+    shapes = await map_function(dispatched_get_shape)
+
+    # Constructing a dask array from list of numpy arrays
+    # See https://docs.dask.org/en/latest/array-creation.html
+    arrays = []
+    for i, shape in enumerate(shapes):
+        arrays.append(da.from_delayed(results[i], shape=(shape[0], ),
+                                      dtype=numpy.float32))
+    predictions = await da.concatenate(arrays, axis=0)
+    return predictions
+
+
 def predict(client, model, data, *args, missing=numpy.nan):
     '''Run prediction with a trained booster.
 
@@ -651,94 +810,44 @@ def predict(client, model, data, *args, missing=numpy.nan):
     '''
     _assert_dask_support()
     client = _xgb_get_client(client)
+    return client.sync(_predict_async, client, model, data, *args,
+                       missing=missing)
+
+
+async def _inplace_predict_async(client, model, data,
+                                 iteration_range=(0, 0),
+                                 predict_type='value',
+                                 missing=numpy.nan):
+    client = _xgb_get_client(client)
     if isinstance(model, Booster):
         booster = model
     elif isinstance(model, dict):
         booster = model['booster']
     else:
         raise TypeError(_expect([Booster, dict], type(model)))
-    if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)):
-        raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame],
-                                type(data)))
+    if not isinstance(data, (da.Array, dd.DataFrame)):
+        raise TypeError(_expect([da.Array, dd.DataFrame], type(data)))
 
-    def mapped_predict(partition, is_df):
+    def mapped_predict(data, is_df):
         worker = distributed_get_worker()
         booster.set_param({'nthread': worker.nthreads})
-        m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
-        predt = booster.predict(m, *args, validate_features=False)
+        prediction = booster.inplace_predict(
+            data,
+            iteration_range=iteration_range,
+            predict_type=predict_type,
+            missing=missing)
         if is_df:
-            if lazy_isinstance(partition, 'cudf', 'core.dataframe.DataFrame'):
+            if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
                 import cudf  # pylint: disable=import-error
-                predt = cudf.DataFrame(predt, columns=['prediction'])
+                prediction = cudf.DataFrame({'prediction': prediction},
+                                            dtype=numpy.float32)
             else:
-                predt = DataFrame(predt, columns=['prediction'])
-        return predt
+                # If it's from pandas, the partition is a numpy array
+                prediction = DataFrame(prediction, columns=['prediction'],
+                                       dtype=numpy.float32)
+        return prediction
 
-    if isinstance(data, (da.Array, dd.DataFrame)):
-        return _direct_predict_impl(client, data, mapped_predict)
-
-    # Prediction on dask DMatrix.
-    worker_map = data.worker_map
-
-    def dispatched_predict(worker_id):
-        '''Perform prediction on each worker.'''
-        LOGGER.info('Predicting on %d', worker_id)
-        worker = distributed_get_worker()
-        list_of_parts = data.get_worker_x_ordered(worker)
-        predictions = []
-        booster.set_param({'nthread': worker.nthreads})
-        for part, order in list_of_parts:
-            local_x = DMatrix(part,
-                              feature_names=data.feature_names,
-                              feature_types=data.feature_types,
-                              missing=data.missing,
-                              nthread=worker.nthreads)
-            predt = booster.predict(data=local_x,
-                                    validate_features=local_x.num_row() != 0,
-                                    *args)
-            ret = (delayed(predt), order)
-            predictions.append(ret)
-        return predictions
-
-    def dispatched_get_shape(worker_id):
-        '''Get shape of data in each worker.'''
-        LOGGER.info('Trying to get data shape on %d', worker_id)
-        worker = distributed_get_worker()
-        list_of_parts = data.get_worker_x_ordered(worker)
-        shapes = []
-        for part, order in list_of_parts:
-            shapes.append((part.shape, order))
-        return shapes
-
-    def map_function(func):
-        '''Run function for each part of the data.'''
-        futures = []
-        for wid in range(len(worker_map)):
-            list_of_workers = [list(worker_map.keys())[wid]]
-            f = client.submit(func, wid,
-                              pure=False,
-                              workers=list_of_workers)
-            futures.append(f)
-
-        # Get delayed objects
-        results = client.gather(futures)
-        results = [t for l in results for t in l]  # flatten into 1 dim list
-        # sort by order, l[0] is the delayed object, l[1] is its order
-        results = sorted(results, key=lambda l: l[1])
-        results = [predt for predt, order in results]  # remove order
-        return results
-
-    results = map_function(dispatched_predict)
-    shapes = map_function(dispatched_get_shape)
-
-    # Constructing a dask array from list of numpy arrays
-    # See https://docs.dask.org/en/latest/array-creation.html
-    arrays = []
-    for i, shape in enumerate(shapes):
-        arrays.append(da.from_delayed(results[i], shape=(shape[0], ),
-                                      dtype=numpy.float32))
-    predictions = da.concatenate(arrays, axis=0)
-    return predictions
+    return await _direct_predict_impl(client, data, mapped_predict)
 
 
 def inplace_predict(client, model, data,
@@ -770,38 +879,14 @@ def inplace_predict(client, model, data,
     '''
     _assert_dask_support()
     client = _xgb_get_client(client)
-    if isinstance(model, Booster):
-        booster = model
-    elif isinstance(model, dict):
-        booster = model['booster']
-    else:
-        raise TypeError(_expect([Booster, dict], type(model)))
-    if not isinstance(data, (da.Array, dd.DataFrame)):
-        raise TypeError(_expect([da.Array, dd.DataFrame], type(data)))
-
-    def mapped_predict(data, is_df):
-        worker = distributed_get_worker()
-        booster.set_param({'nthread': worker.nthreads})
-        prediction = booster.inplace_predict(
-            data,
-            iteration_range=iteration_range,
-            predict_type=predict_type,
-            missing=missing)
-        if is_df:
-            if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
-                import cudf  # pylint: disable=import-error
-                prediction = cudf.DataFrame({'prediction': prediction},
-                                            dtype=numpy.float32)
-            else:
-                # If it's from pandas, the partition is a numpy array
-                prediction = DataFrame(prediction, columns=['prediction'],
-                                       dtype=numpy.float32)
-        return prediction
-
-    return _direct_predict_impl(client, data, mapped_predict)
+    return client.sync(_inplace_predict_async, client, model=model, data=data,
+                       iteration_range=iteration_range,
+                       predict_type=predict_type,
+                       missing=missing)
 
 
-def _evaluation_matrices(client, validation_set, sample_weights, missing):
+async def _evaluation_matrices(client, validation_set,
+                               sample_weights, missing):
     '''
     Parameters
     ----------
@@ -826,8 +911,8 @@ def _evaluation_matrices(client, validation_set, sample_weights, missing):
         for i, e in enumerate(validation_set):
             w = (sample_weights[i]
                  if sample_weights is not None else None)
-            dmat = DaskDMatrix(client=client, data=e[0], label=e[1], weight=w,
-                               missing=missing)
+            dmat = await DaskDMatrix(client=client, data=e[0], label=e[1],
+                                     weight=w, missing=missing)
             evals.append((dmat, 'validation_{}'.format(i)))
     else:
         evals = None
@@ -840,9 +925,7 @@ class DaskScikitLearnBase(XGBModel):
     _client = None
 
     # pylint: disable=arguments-differ
-    def fit(self,
-            X,
-            y,
+    def fit(self, X, y,
             sample_weights=None,
             eval_set=None,
             sample_weight_eval_set=None,
@@ -879,6 +962,12 @@ class DaskScikitLearnBase(XGBModel):
         prediction : dask.array.Array'''
         raise NotImplementedError
 
+    def __await__(self):
+        # Generate a coroutine wrapper to make this class awaitable.
+        async def _():
+            return self
+        return self.client.sync(_).__await__()
+
     @property
     def client(self):
         '''The dask client used in this model.'''
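(Aside: the estimator has no asynchronous set-up of its own, so ``__await__`` wraps ``self`` in a throwaway coroutine and relies on ``client.sync`` returning that coroutine when the client is asynchronous. In isolation the trick looks like the sketch below; the class name is invented, and it only makes sense together with a client created using ``asynchronous=True``:)

.. code-block:: python

    class AwaitableHandle:
        def __init__(self, client):
            self.client = client

        def __await__(self):
            # No real work to wait for: `await handle` simply yields the
            # handle back once the wrapper coroutine completes.
            async def _():
                return self
            return self.client.sync(_).__await__()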
@@ -892,40 +981,51 @@ class DaskScikitLearnBase(XGBModel):
 @xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""",
                    ['estimators', 'model'])
 class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
-    # pylint: disable=missing-docstring
-    def fit(self,
-            X,
-            y,
-            sample_weights=None,
-            eval_set=None,
-            sample_weight_eval_set=None,
-            verbose=True):
-        _assert_dask_support()
-        dtrain = DaskDMatrix(client=self.client,
-                             data=X, label=y, weight=sample_weights,
-                             missing=self.missing)
-        params = self.get_xgb_params()
-        evals = _evaluation_matrices(self.client,
-                                     eval_set, sample_weight_eval_set,
-                                     self.missing)
-
-        results = train(self.client, params, dtrain,
-                        num_boost_round=self.get_num_boosting_rounds(),
-                        evals=evals, verbose_eval=verbose)
-        # pylint: disable=attribute-defined-outside-init
+    # pylint: disable=missing-class-docstring
+    async def _fit_async(self,
+                         X,
+                         y,
+                         sample_weights=None,
+                         eval_set=None,
+                         sample_weight_eval_set=None,
+                         verbose=True):
+        dtrain = await DaskDMatrix(client=self.client,
+                                   data=X, label=y, weight=sample_weights,
+                                   missing=self.missing)
+        params = self.get_xgb_params()
+        evals = await _evaluation_matrices(self.client,
+                                           eval_set, sample_weight_eval_set,
+                                           self.missing)
+        results = await train(client=self.client, params=params, dtrain=dtrain,
+                              num_boost_round=self.get_num_boosting_rounds(),
+                              evals=evals, verbose_eval=verbose)
         self._Booster = results['booster']
         # pylint: disable=attribute-defined-outside-init
         self.evals_result_ = results['history']
         return self
 
-    def predict(self, data):  # pylint: disable=arguments-differ
+    # pylint: disable=missing-docstring
+    def fit(self, X, y,
+            sample_weights=None,
+            eval_set=None,
+            sample_weight_eval_set=None,
+            verbose=True):
         _assert_dask_support()
-        test_dmatrix = DaskDMatrix(client=self.client, data=data,
-                                   missing=self.missing)
-        pred_probs = predict(client=self.client,
-                             model=self.get_booster(), data=test_dmatrix)
+        return self.client.sync(self._fit_async, X, y, sample_weights,
+                                eval_set, sample_weight_eval_set,
+                                verbose)
+
+    async def _predict_async(self, data):  # pylint: disable=arguments-differ
+        test_dmatrix = await DaskDMatrix(client=self.client, data=data,
+                                         missing=self.missing)
+        pred_probs = await predict(client=self.client,
+                                   model=self.get_booster(), data=test_dmatrix)
         return pred_probs
 
+    def predict(self, data):
+        _assert_dask_support()
+        return self.client.sync(self._predict_async, data)
+
 
 @xgboost_model_doc(
     'Implementation of the scikit-learn API for XGBoost classification.',
@@ -935,24 +1035,21 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
     # pylint: disable=missing-docstring
     _client = None
 
-    def fit(self,
-            X,
-            y,
-            sample_weights=None,
-            eval_set=None,
-            sample_weight_eval_set=None,
-            verbose=True):
-        _assert_dask_support()
-        dtrain = DaskDMatrix(client=self.client,
-                             data=X, label=y, weight=sample_weights,
-                             missing=self.missing)
+    async def _fit_async(self, X, y,
+                         sample_weights=None,
+                         eval_set=None,
+                         sample_weight_eval_set=None,
+                         verbose=True):
+        dtrain = await DaskDMatrix(client=self.client,
+                                   data=X, label=y, weight=sample_weights,
+                                   missing=self.missing)
         params = self.get_xgb_params()
 
         # pylint: disable=attribute-defined-outside-init
         if isinstance(y, (da.Array)):
-            self.classes_ = da.unique(y).compute()
+            self.classes_ = await self.client.compute(da.unique(y))
         else:
-            self.classes_ = y.drop_duplicates().compute()
+            self.classes_ = await self.client.compute(y.drop_duplicates())
         self.n_classes_ = len(self.classes_)
 
         if self.n_classes_ > 2:
@@ -961,21 +1058,33 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         else:
             params["objective"] = "binary:logistic"
 
-        evals = _evaluation_matrices(self.client,
-                                     eval_set, sample_weight_eval_set,
-                                     self.missing)
-        results = train(self.client, params, dtrain,
-                        num_boost_round=self.get_num_boosting_rounds(),
-                        evals=evals, verbose_eval=verbose)
+        evals = await _evaluation_matrices(self.client,
+                                           eval_set, sample_weight_eval_set,
+                                           self.missing)
+        results = await train(client=self.client, params=params, dtrain=dtrain,
+                              num_boost_round=self.get_num_boosting_rounds(),
+                              evals=evals, verbose_eval=verbose)
         self._Booster = results['booster']
         # pylint: disable=attribute-defined-outside-init
         self.evals_result_ = results['history']
         return self
 
+    def fit(self, X, y,
+            sample_weights=None,
+            eval_set=None,
+            sample_weight_eval_set=None,
+            verbose=True):
+        _assert_dask_support()
+        return self.client.sync(self._fit_async, X, y, sample_weights,
+                                eval_set, sample_weight_eval_set, verbose)
+
+    async def _predict_async(self, data):
+        test_dmatrix = await DaskDMatrix(client=self.client, data=data,
+                                         missing=self.missing)
+        pred_probs = await predict(client=self.client,
+                                   model=self.get_booster(), data=test_dmatrix)
+        return pred_probs
+
     def predict(self, data):  # pylint: disable=arguments-differ
         _assert_dask_support()
-        test_dmatrix = DaskDMatrix(client=self.client, data=data,
-                                   missing=self.missing)
-        pred_probs = predict(client=self.client,
-                             model=self.get_booster(), data=test_dmatrix)
-        return pred_probs
+        return self.client.sync(self._predict_async, data)
@@ -2,6 +2,7 @@ import sys
 import os
 import pytest
 import numpy as np
+import asyncio
 import unittest
 import xgboost
 import subprocess
@@ -219,7 +220,7 @@ class TestDistributedGPU(unittest.TestCase):
         with LocalCUDACluster() as cluster:
             with Client(cluster) as client:
                 workers = list(dxgb._get_client_workers(client).keys())
-                rabit_args = dxgb._get_rabit_args(workers, client)
+                rabit_args = client.sync(dxgb._get_rabit_args, workers, client)
                 futures = client.map(runit,
                                      workers,
                                      pure=False,
@@ -242,3 +243,39 @@ class TestDistributedGPU(unittest.TestCase):
     @pytest.mark.gtest
     def test_quantile_same_on_all_workers(self):
         self.run_quantile('SameOnAllWorkers')
+
+
+async def run_from_dask_array_asyncio(scheduler_address):
+    async with Client(scheduler_address, asynchronous=True) as client:
+        import cupy as cp
+        X, y = generate_array()
+        X = X.map_blocks(cp.array)
+        y = y.map_blocks(cp.array)
+
+        m = await xgboost.dask.DaskDeviceQuantileDMatrix(client, X, y)
+        output = await xgboost.dask.train(client, {'tree_method': 'gpu_hist'},
+                                          dtrain=m)
+
+        with_m = await xgboost.dask.predict(client, output, m)
+        with_X = await xgboost.dask.predict(client, output, X)
+        inplace = await xgboost.dask.inplace_predict(client, output, X)
+        assert isinstance(with_m, da.Array)
+        assert isinstance(with_X, da.Array)
+        assert isinstance(inplace, da.Array)
+
+        cp.testing.assert_allclose(await client.compute(with_m),
+                                   await client.compute(with_X))
+        cp.testing.assert_allclose(await client.compute(with_m),
+                                   await client.compute(inplace))
+
+        client.shutdown()
+        return output
+
+
+def test_with_asyncio():
+    with LocalCUDACluster() as cluster:
+        with Client(cluster) as client:
+            address = client.scheduler.address
+            output = asyncio.run(run_from_dask_array_asyncio(address))
+            assert isinstance(output['booster'], xgboost.Booster)
+            assert isinstance(output['history'], dict)
@@ -27,8 +27,10 @@ def run_rabit_ops(client, n_workers):
     from xgboost import rabit
 
     workers = list(_get_client_workers(client).keys())
-    rabit_args = _get_rabit_args(workers, client)
+    rabit_args = client.sync(_get_rabit_args, workers, client)
     assert not rabit.is_distributed()
+    n_workers_from_dask = len(workers)
+    assert n_workers == n_workers_from_dask
 
     def local_test(worker_id):
         with RabitContext(rabit_args):
@@ -4,6 +4,7 @@ import xgboost as xgb
 import sys
 import numpy as np
 import json
+import asyncio
 
 if sys.platform.startswith("win"):
     pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
@@ -327,3 +328,96 @@ def test_empty_dmatrix_approx():
             parameters = {'tree_method': 'approx'}
             run_empty_dmatrix_reg(client, parameters)
             run_empty_dmatrix_cls(client, parameters)
+
+
+async def run_from_dask_array_asyncio(scheduler_address):
+    async with Client(scheduler_address, asynchronous=True) as client:
+        X, y = generate_array()
+        m = await DaskDMatrix(client, X, y)
+        output = await xgb.dask.train(client, {}, dtrain=m)
+
+        with_m = await xgb.dask.predict(client, output, m)
+        with_X = await xgb.dask.predict(client, output, X)
+        inplace = await xgb.dask.inplace_predict(client, output, X)
+        assert isinstance(with_m, da.Array)
+        assert isinstance(with_X, da.Array)
+        assert isinstance(inplace, da.Array)
+
+        np.testing.assert_allclose(await client.compute(with_m),
+                                   await client.compute(with_X))
+        np.testing.assert_allclose(await client.compute(with_m),
+                                   await client.compute(inplace))
+
+        client.shutdown()
+        return output
+
+
+async def run_dask_regressor_asyncio(scheduler_address):
+    async with Client(scheduler_address, asynchronous=True) as client:
+        X, y = generate_array()
+        regressor = await xgb.dask.DaskXGBRegressor(verbosity=1,
+                                                    n_estimators=2)
+        regressor.set_params(tree_method='hist')
+        regressor.client = client
+        await regressor.fit(X, y, eval_set=[(X, y)])
+        prediction = await regressor.predict(X)
+
+        assert prediction.ndim == 1
+        assert prediction.shape[0] == kRows
+
+        history = regressor.evals_result()
+
+        assert isinstance(prediction, da.Array)
+        assert isinstance(history, dict)
+
+        assert list(history['validation_0'].keys())[0] == 'rmse'
+        assert len(history['validation_0']['rmse']) == 2
+
+
+async def run_dask_classifier_asyncio(scheduler_address):
+    async with Client(scheduler_address, asynchronous=True) as client:
+        X, y = generate_array()
+        y = (y * 10).astype(np.int32)
+        classifier = await xgb.dask.DaskXGBClassifier(
+            verbosity=1, n_estimators=2)
+        classifier.client = client
+        await classifier.fit(X, y, eval_set=[(X, y)])
+        prediction = await classifier.predict(X)
+
+        assert prediction.ndim == 1
+        assert prediction.shape[0] == kRows
+
+        history = classifier.evals_result()
+
+        assert isinstance(prediction, da.Array)
+        assert isinstance(history, dict)
+
+        assert list(history.keys())[0] == 'validation_0'
+        assert list(history['validation_0'].keys())[0] == 'merror'
+        assert len(list(history['validation_0'])) == 1
+        assert len(history['validation_0']['merror']) == 2
+
+        assert classifier.n_classes_ == 10
+
+        # Test with dataframe.
+        X_d = dd.from_dask_array(X)
+        y_d = dd.from_dask_array(y)
+        await classifier.fit(X_d, y_d)
+
+        assert classifier.n_classes_ == 10
+        prediction = await classifier.predict(X_d)
+
+        assert prediction.ndim == 1
+        assert prediction.shape[0] == kRows
+
+
+def test_with_asyncio():
+    with LocalCluster() as cluster:
+        with Client(cluster) as client:
+            address = client.scheduler.address
+            output = asyncio.run(run_from_dask_array_asyncio(address))
+            assert isinstance(output['booster'], xgb.Booster)
+            assert isinstance(output['history'], dict)
+
+            asyncio.run(run_dask_regressor_asyncio(address))
+            asyncio.run(run_dask_classifier_asyncio(address))