[Dask] Asyncio support. (#5862)
parent e4a273e1da
commit fa3715f584
@@ -12,6 +12,12 @@ algorithm. For an overview of GPU based training and internal working, see `A N
Official Dask API for XGBoost
<https://medium.com/rapids-ai/a-new-official-dask-api-for-xgboost-e8b10f3d1eb7>`_.

**Contents**

.. contents::
  :backlinks: none
  :local:

************
Requirements
************
@@ -105,6 +111,60 @@ set:

XGBoost will use 8 threads in each training process.

********************
Working with asyncio
********************

.. versionadded:: 1.2.0

The XGBoost dask interface supports Python's ``asyncio`` and can be integrated into
asynchronous workflows. For using dask with asynchronous operations, please refer to
this `dask example <https://examples.dask.org/applications/async-await.html>`_ and the
documentation in `distributed <https://distributed.dask.org/en/latest/asynchronous.html>`_.
XGBoost takes a ``Client`` object as an argument for both training and prediction, so
when ``asynchronous=True`` is specified while creating the ``Client``, the dask interface
adapts accordingly. All functions provided by the functional interface, including
``DaskDMatrix``, return a coroutine when called in an async function and hence require
awaiting to obtain the result.

Functional interface:

.. code-block:: python

    async with Client(scheduler_address, asynchronous=True) as client:
        X, y = generate_array()
        m = await xgb.dask.DaskDMatrix(client, X, y)
        output = await xgb.dask.train(client, {}, dtrain=m)

        with_m = await xgb.dask.predict(client, output, m)
        with_X = await xgb.dask.predict(client, output, X)
        inplace = await xgb.dask.inplace_predict(client, output, X)

        # Use `client.compute` instead of the `compute` method from dask collection
        print(await client.compute(with_m))


For the Scikit-Learn interface, trivial methods like ``set_params`` and accessing class
attributes like ``evals_result_`` do not require ``await``. Other methods involving
actual computation return a coroutine and hence require awaiting:

.. code-block:: python

    async with Client(scheduler_address, asynchronous=True) as client:
        X, y = generate_array()
        regressor = await xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
        regressor.set_params(tree_method='hist')  # trivial method, synchronous operation
        regressor.client = client  # accessing attribute, synchronous operation
        regressor = await regressor.fit(X, y, eval_set=[(X, y)])
        prediction = await regressor.predict(X)

        # Use `client.compute` instead of the `compute` method from dask collection
        print(await client.compute(prediction))

Be aware that XGBoost uses all the workers supplied by the ``client`` object. If you are
training on a GPU cluster with 2 GPUs, the client object passed to XGBoost should return
2 workers; a minimal sketch of such a setup is shown below.

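The sketch assumes a single machine with 2 GPUs and that the optional ``dask_cuda``
package is installed; treat it as an illustration rather than the only possible setup:

.. code-block:: python

    from dask_cuda import LocalCUDACluster
    from distributed import Client

    # One worker per GPU, so the client passed to XGBoost exposes exactly 2 workers.
    with LocalCUDACluster(n_workers=2) as cluster:
        with Client(cluster) as client:
            ...  # construct DaskDMatrix and call xgb.dask.train with this client
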
*****************************************************************************
Why is the initialization of ``DaskDMatrix`` so slow and throws weird errors
*****************************************************************************

@@ -1,4 +1,6 @@
# pylint: disable=too-many-arguments, too-many-locals
# pylint: disable=missing-class-docstring, invalid-name
# pylint: disable=too-many-lines
"""Dask extensions for distributed training. See
https://xgboost.readthedocs.io/en/latest/tutorials/dask.html for simple
tutorial. Also xgboost/demo/dask for some examples.
@@ -35,6 +37,11 @@ from .tracker import RabitTracker
from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
from .sklearn import xgboost_model_doc

try:
    from distributed import Client
except ImportError:
    Client = None

# Current status is considered as initial support, many features are
# not properly supported yet.
#
@@ -43,6 +50,17 @@ from .sklearn import xgboost_model_doc
# - Label encoding.
# - CV
# - Ranking
#
# Note for developers:

# As of writing asyncio is still a new feature of Python and in depth
# documentation is rare. Best examples of various asyncio tricks are in dask
# (luckily). Classes like Client, Worker are awaitable. Some general rules
# for the implementation here:
# - Synchronous world is different from asynchronous one, and they don't
#   mix well.
# - Write everything with async, then use distributed Client sync function
#   to do the switch.
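#
# For example, a hypothetical sketch of that pattern (the names below are
# illustrative only, not functions defined in this module):
#
#     async def _do_work_async(client, ...):
#         ...                      # all of the real logic lives in the async version
#
#     def do_work(client, ...):
#         # synchronous entry point; Client.sync bridges the two worlds
#         return client.sync(_do_work_async, client, ...)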


LOGGER = logging.getLogger('[xgboost.dask]')
@@ -125,6 +143,12 @@ def _get_client_workers(client):
    workers = client.scheduler_info()['workers']
    return workers

# From the implementation point of view, DaskDMatrix complicates a lot of
# things. A large portion of the code base is about syncing and extracting
# stuff from DaskDMatrix. But having an independent data structure gives us a
# chance to perform some specialized optimizations, like building histogram
# index directly.


class DaskDMatrix:
    # pylint: disable=missing-docstring, too-many-instance-attributes
@@ -133,6 +157,11 @@ class DaskDMatrix:
    the input data explicitly if you want to see actual computation of
    constructing `DaskDMatrix`.

    .. note::

        DaskDMatrix does not repartition or move data between workers. It's
        the caller's responsibility to balance the data.

    .. versionadded:: 1.0.0

    Parameters
@@ -165,7 +194,7 @@ class DaskDMatrix:
                 feature_names=None,
                 feature_types=None):
        _assert_dask_support()
        client = _xgb_get_client(client)
        client: Client = _xgb_get_client(client)

        self.feature_names = feature_names
        self.feature_types = feature_types
@@ -187,7 +216,13 @@ class DaskDMatrix:
        self.has_label = label is not None
        self.has_weights = weight is not None

        client.sync(self.map_local_data, client, data, label, weight)
        self.is_quantile = False

        self._init = client.sync(self.map_local_data,
                                 client, data, label, weight)

    def __await__(self):
        return self._init.__await__()

    async def map_local_data(self, client, data, label=None, weights=None):
        '''Obtain references to local data.'''
@@ -264,29 +299,46 @@ class DaskDMatrix:

        self.worker_map = worker_map

    def get_worker_x_ordered(self, worker):
        list_of_parts = self.worker_map[worker.address]
        return self

    def create_fn_args(self):
        '''Create a dictionary of objects that can be pickled for function
        arguments.

        '''
        return {'feature_names': self.feature_names,
                'feature_types': self.feature_types,
                'has_label': self.has_label,
                'has_weights': self.has_weights,
                'missing': self.missing,
                'worker_map': self.worker_map,
                'is_quantile': self.is_quantile}


def _get_worker_x_ordered(worker_map, partition_order, worker):
    list_of_parts = worker_map[worker.address]
    client = get_client()
    list_of_parts_value = client.gather(list_of_parts)
    result = []
    for i, part in enumerate(list_of_parts):
        result.append((list_of_parts_value[i][0],
                       self.partition_order[part.key]))
                       partition_order[part.key]))
    return result

    def get_worker_parts(self, worker):
        '''Get mapped parts of data in each worker.'''
        list_of_parts = self.worker_map[worker.address]

def _get_worker_parts(has_label, has_weights, worker_map, worker):
    '''Get mapped parts of data in each worker from DaskDMatrix.'''
    list_of_parts = worker_map[worker.address]
    assert list_of_parts, 'data in ' + worker.address + ' was moved.'
    assert isinstance(list_of_parts, list)

    # `get_worker_parts` is launched inside worker. In dask side
    # `_get_worker_parts` is launched inside worker. In dask side
    # this should be equal to `worker._get_client`.
    client = get_client()
    list_of_parts = client.gather(list_of_parts)

    if self.has_label:
        if self.has_weights:
    if has_label:
        if has_weights:
            data, labels, weights = zip(*list_of_parts)
        else:
            data, labels = zip(*list_of_parts)
@@ -297,66 +349,6 @@ class DaskDMatrix:
        weights = None
    return data, labels, weights

    def get_worker_data(self, worker):
        '''Get data that local to worker.

        Parameters
        ----------
        worker: The worker used as key to data.

        Returns
        -------
        A DMatrix object.

        '''
        if worker.address not in set(self.worker_map.keys()):
            msg = 'worker {address} has an empty DMatrix. ' \
                'All workers associated with this DMatrix: {workers}'.format(
                    address=worker.address,
                    workers=set(self.worker_map.keys()))
            LOGGER.warning(msg)
            d = DMatrix(numpy.empty((0, 0)),
                        feature_names=self.feature_names,
                        feature_types=self.feature_types)
            return d

        data, labels, weights = self.get_worker_parts(worker)

        data = concat(data)

        if self.has_label:
            labels = concat(labels)
        else:
            labels = None
        if self.has_weights:
            weights = concat(weights)
        else:
            weights = None
        dmatrix = DMatrix(data,
                          labels,
                          weight=weights,
                          missing=self.missing,
                          feature_names=self.feature_names,
                          feature_types=self.feature_types,
                          nthread=worker.nthreads)
        return dmatrix

    def get_worker_data_shape(self, worker):
        '''Get the shape of data X in each worker.'''
        data, _, _ = self.get_worker_parts(worker)

        shapes = [d.shape for d in data]
        rows = 0
        cols = 0
        for shape in shapes:
            rows += shape[0]

            c = shape[1]
            assert cols in (0, c), 'Shape between partitions are not the' \
                ' same. Got: {left} and {right}'.format(left=c, right=cols)
            cols = c
        return (rows, cols)


class DaskPartitionIter(DataIter):  # pylint: disable=R0902
    '''A data iterator for `DaskDeviceQuantileDMatrix`.
@@ -460,6 +452,7 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
    ----------
    max_bin: Number of bins for histogram construction.


    '''
    def __init__(self, client, data, label=None, weight=None,
                 missing=None,
@@ -471,39 +464,99 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
                         feature_names=feature_names,
                         feature_types=feature_types)
        self.max_bin = max_bin
        self.is_quantile = True

    def get_worker_data(self, worker):
        if worker.address not in set(self.worker_map.keys()):
    def create_fn_args(self):
        args = super().create_fn_args()
        args['max_bin'] = self.max_bin
        return args


def _create_device_quantile_dmatrix(feature_names, feature_types,
                                    has_label,
                                    has_weights, missing, worker_map,
                                    max_bin):
    worker = distributed_get_worker()
    if worker.address not in set(worker_map.keys()):
        msg = 'worker {address} has an empty DMatrix. ' \
            'All workers associated with this DMatrix: {workers}'.format(
                address=worker.address,
                workers=set(self.worker_map.keys()))
                workers=set(worker_map.keys()))
        LOGGER.warning(msg)
        import cupy  # pylint: disable=import-error
        d = DeviceQuantileDMatrix(cupy.zeros((0, 0)),
                                  feature_names=self.feature_names,
                                  feature_types=self.feature_types,
                                  max_bin=self.max_bin)
                                  feature_names=feature_names,
                                  feature_types=feature_types,
                                  max_bin=max_bin)
        return d

        data, labels, weights = self.get_worker_parts(worker)
    data, labels, weights = _get_worker_parts(has_label, has_weights,
                                              worker_map, worker)
    it = DaskPartitionIter(data=data, label=labels, weight=weights)

    dmatrix = DeviceQuantileDMatrix(it,
                                    missing=self.missing,
                                    feature_names=self.feature_names,
                                    feature_types=self.feature_types,
                                    missing=missing,
                                    feature_names=feature_names,
                                    feature_types=feature_types,
                                    nthread=worker.nthreads,
                                    max_bin=self.max_bin)
                                    max_bin=max_bin)
    return dmatrix


def _get_rabit_args(worker_map, client):
def _create_dmatrix(feature_names, feature_types, has_label,
                    has_weights, missing, worker_map):
    '''Get data that local to worker from DaskDMatrix.

    Returns
    -------
    A DMatrix object.

    '''
    worker = distributed_get_worker()
    if worker.address not in set(worker_map.keys()):
        msg = 'worker {address} has an empty DMatrix. ' \
            'All workers associated with this DMatrix: {workers}'.format(
                address=worker.address,
                workers=set(worker_map.keys()))
        LOGGER.warning(msg)
        d = DMatrix(numpy.empty((0, 0)),
                    feature_names=feature_names,
                    feature_types=feature_types)
        return d

    data, labels, weights = _get_worker_parts(has_label, has_weights,
                                              worker_map, worker)
    data = concat(data)

    if has_label:
        labels = concat(labels)
    else:
        labels = None
    if has_weights:
        weights = concat(weights)
    else:
        weights = None
    dmatrix = DMatrix(data,
                      labels,
                      weight=weights,
                      missing=missing,
                      feature_names=feature_names,
                      feature_types=feature_types,
                      nthread=worker.nthreads)
    return dmatrix


def _dmatrix_from_worker_map(is_quantile, **kwargs):
    if is_quantile:
        return _create_device_quantile_dmatrix(**kwargs)
    return _create_dmatrix(**kwargs)


async def _get_rabit_args(worker_map, client: Client):
    '''Get rabit context arguments from data distribution in DaskDMatrix.'''
    host = distributed_comm.get_address_host(client.scheduler.address)

    env = client.run_on_scheduler(_start_tracker, host.strip('/:'),
                                  len(worker_map))
    env = await client.run_on_scheduler(
        _start_tracker, host.strip('/:'), len(worker_map))
    rabit_args = [('%s=%s' % item).encode() for item in env.items()]
    return rabit_args

@@ -514,6 +567,73 @@ def _get_rabit_args(worker_map, client):
# evaluation history is instead returned.


async def _train_async(client, params, dtrain: DaskDMatrix,
                       *args, evals=(), **kwargs):
    _assert_dask_support()
    client: Client = _xgb_get_client(client)
    if 'evals_result' in kwargs.keys():
        raise ValueError(
            'evals_result is not supported in dask interface.',
            'The evaluation history is returned as result of training.')

    workers = list(_get_client_workers(client).keys())
    rabit_args = await _get_rabit_args(workers, client)

    def dispatched_train(worker_addr, dtrain_ref, evals_ref):
        '''Perform training on a single worker. A local function prevents pickling.

        '''
        LOGGER.info('Training on %s', str(worker_addr))
        worker = distributed_get_worker()
        with RabitContext(rabit_args):
            local_dtrain = _dmatrix_from_worker_map(**dtrain_ref)
            local_evals = []
            if evals_ref:
                for ref, name in evals_ref:
                    if ref['worker_map'] == dtrain_ref['worker_map']:
                        local_evals.append((local_dtrain, name))
                        continue
                    local_evals.append((_dmatrix_from_worker_map(**ref), name))

            local_history = {}
            local_param = params.copy()  # just to be consistent
            msg = 'Overriding `nthreads` defined in dask worker.'
            if 'nthread' in local_param.keys() and \
               local_param['nthread'] is not None and \
               local_param['nthread'] != worker.nthreads:
                msg += '`nthread` is specified. ' + msg
                LOGGER.warning(msg)
            elif 'n_jobs' in local_param.keys() and \
                 local_param['n_jobs'] is not None and \
                 local_param['n_jobs'] != worker.nthreads:
                msg = '`n_jobs` is specified. ' + msg
                LOGGER.warning(msg)
            else:
                local_param['nthread'] = worker.nthreads
            bst = worker_train(params=local_param,
                               dtrain=local_dtrain,
                               *args,
                               evals_result=local_history,
                               evals=local_evals,
                               **kwargs)
            ret = {'booster': bst, 'history': local_history}
            if local_dtrain.num_row() == 0:
                ret = None
            return ret

    if evals:
        evals = [(e.create_fn_args(), name) for e, name in evals]

    futures = client.map(dispatched_train,
                         workers,
                         [dtrain.create_fn_args()] * len(workers),
                         [evals] * len(workers),
                         pure=False,
                         workers=workers)
    results = await client.gather(futures)
    return list(filter(lambda ret: ret is not None, results))[0]


def train(client, params, dtrain, *args, evals=(), **kwargs):
    '''Train XGBoost model.

@@ -544,75 +664,20 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
    '''
    _assert_dask_support()
    client = _xgb_get_client(client)
    if 'evals_result' in kwargs.keys():
        raise ValueError(
            'evals_result is not supported in dask interface.',
            'The evaluation history is returned as result of training.')

    workers = list(_get_client_workers(client).keys())

    rabit_args = _get_rabit_args(workers, client)

    def dispatched_train(worker_addr):
        '''Perform training on a single worker.'''
        LOGGER.info('Training on %s', str(worker_addr))
        worker = distributed_get_worker()
        with RabitContext(rabit_args):
            local_dtrain = dtrain.get_worker_data(worker)

            local_evals = []
            if evals:
                for mat, name in evals:
                    if mat is dtrain:
                        local_evals.append((local_dtrain, name))
                        continue
                    local_mat = mat.get_worker_data(worker)
                    local_evals.append((local_mat, name))

            local_history = {}
            local_param = params.copy()  # just to be consistent
            msg = 'Overriding `nthreads` defined in dask worker.'
            if 'nthread' in local_param.keys() and \
               local_param['nthread'] is not None and \
               local_param['nthread'] != worker.nthreads:
                msg += '`nthread` is specified. ' + msg
                LOGGER.warning(msg)
            elif 'n_jobs' in local_param.keys() and \
                 local_param['n_jobs'] is not None and \
                 local_param['n_jobs'] != worker.nthreads:
                msg = '`n_jobs` is specified. ' + msg
                LOGGER.warning(msg)
            else:
                local_param['nthread'] = worker.nthreads
            bst = worker_train(params=local_param,
                               dtrain=local_dtrain,
                               *args,
                               evals_result=local_history,
                               evals=local_evals,
                               **kwargs)
            ret = {'booster': bst, 'history': local_history}
            if local_dtrain.num_row() == 0:
                ret = None
            return ret

    futures = client.map(dispatched_train,
                         workers,
                         pure=False,
                         workers=workers)
    results = client.gather(futures)
    return list(filter(lambda ret: ret is not None, results))[0]
    return client.sync(_train_async, client, params,
                       dtrain=dtrain, *args, evals=evals, **kwargs)


def _direct_predict_impl(client, data, predict_fn):
async def _direct_predict_impl(client, data, predict_fn):
    if isinstance(data, da.Array):
        predictions = client.submit(
        predictions = await client.submit(
            da.map_blocks,
            predict_fn, data, False, drop_axis=1,
            dtype=numpy.float32
        ).result()
        return predictions
    if isinstance(data, dd.DataFrame):
        predictions = client.submit(
        predictions = await client.submit(
            dd.map_partitions,
            predict_fn, data, True,
            meta=dd.utils.make_meta({'prediction': 'f4'})
@@ -622,6 +687,100 @@ def _direct_predict_impl(client, data, predict_fn):
                        ' is not supported by direct prediction')


# pylint: disable=too-many-statements
async def _predict_async(client: Client, model, data, *args,
                         missing=numpy.nan):
    if isinstance(model, Booster):
        booster = model
    elif isinstance(model, dict):
        booster = model['booster']
    else:
        raise TypeError(_expect([Booster, dict], type(model)))
    if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)):
        raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame],
                                type(data)))

    def mapped_predict(partition, is_df):
        worker = distributed_get_worker()
        booster.set_param({'nthread': worker.nthreads})
        m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
        predt = booster.predict(m, *args, validate_features=False)
        if is_df:
            if lazy_isinstance(partition, 'cudf', 'core.dataframe.DataFrame'):
                import cudf  # pylint: disable=import-error
                predt = cudf.DataFrame(predt, columns=['prediction'])
            else:
                predt = DataFrame(predt, columns=['prediction'])
        return predt
    # Predict on dask collection directly.
    if isinstance(data, (da.Array, dd.DataFrame)):
        return await _direct_predict_impl(client, data, mapped_predict)

    # Prediction on dask DMatrix.
    worker_map = data.worker_map
    partition_order = data.partition_order
    feature_names = data.feature_names
    feature_types = data.feature_types
    missing = data.missing

    def dispatched_predict(worker_id):
        '''Perform prediction on each worker.'''
        LOGGER.info('Predicting on %d', worker_id)
        worker = distributed_get_worker()
        list_of_parts = _get_worker_x_ordered(worker_map, partition_order,
                                              worker)
        predictions = []
        booster.set_param({'nthread': worker.nthreads})
        for part, order in list_of_parts:
            local_x = DMatrix(part, feature_names=feature_names,
                              feature_types=feature_types,
                              missing=missing, nthread=worker.nthreads)
            predt = booster.predict(data=local_x,
                                    validate_features=local_x.num_row() != 0,
                                    *args)
            ret = (delayed(predt), order)
            predictions.append(ret)
        return predictions

    def dispatched_get_shape(worker_id):
        '''Get shape of data in each worker.'''
        LOGGER.info('Get shape on %d', worker_id)
        worker = distributed_get_worker()
        list_of_parts = _get_worker_x_ordered(worker_map,
                                              partition_order, worker)
        shapes = [(part.shape, order) for part, order in list_of_parts]
        return shapes

    async def map_function(func):
        '''Run function for each part of the data.'''
        futures = []
        for wid in range(len(worker_map)):
            list_of_workers = [list(worker_map.keys())[wid]]
            f = await client.submit(func, wid,
                                    pure=False,
                                    workers=list_of_workers)
            futures.append(f)
        # Get delayed objects
        results = await client.gather(futures)
        results = [t for l in results for t in l]  # flatten into 1 dim list
        # sort by order, l[0] is the delayed object, l[1] is its order
        results = sorted(results, key=lambda l: l[1])
        results = [predt for predt, order in results]  # remove order
        return results

    results = await map_function(dispatched_predict)
    shapes = await map_function(dispatched_get_shape)

    # Constructing a dask array from list of numpy arrays
    # See https://docs.dask.org/en/latest/array-creation.html
    arrays = []
    for i, shape in enumerate(shapes):
        arrays.append(da.from_delayed(results[i], shape=(shape[0], ),
                                      dtype=numpy.float32))
    predictions = await da.concatenate(arrays, axis=0)
    return predictions


def predict(client, model, data, *args, missing=numpy.nan):
    '''Run prediction with a trained booster.

@@ -651,94 +810,44 @@ def predict(client, model, data, *args, missing=numpy.nan):
    '''
    _assert_dask_support()
    client = _xgb_get_client(client)
    return client.sync(_predict_async, client, model, data, *args,
                       missing=missing)


async def _inplace_predict_async(client, model, data,
                                 iteration_range=(0, 0),
                                 predict_type='value',
                                 missing=numpy.nan):
    client = _xgb_get_client(client)
    if isinstance(model, Booster):
        booster = model
    elif isinstance(model, dict):
        booster = model['booster']
    else:
        raise TypeError(_expect([Booster, dict], type(model)))
    if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)):
        raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame],
                                type(data)))
    if not isinstance(data, (da.Array, dd.DataFrame)):
        raise TypeError(_expect([da.Array, dd.DataFrame], type(data)))

    def mapped_predict(partition, is_df):
    def mapped_predict(data, is_df):
        worker = distributed_get_worker()
        booster.set_param({'nthread': worker.nthreads})
        m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
        predt = booster.predict(m, *args, validate_features=False)
        prediction = booster.inplace_predict(
            data,
            iteration_range=iteration_range,
            predict_type=predict_type,
            missing=missing)
        if is_df:
            if lazy_isinstance(partition, 'cudf', 'core.dataframe.DataFrame'):
            if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
                import cudf  # pylint: disable=import-error
                predt = cudf.DataFrame(predt, columns=['prediction'])
                prediction = cudf.DataFrame({'prediction': prediction},
                                            dtype=numpy.float32)
            else:
                predt = DataFrame(predt, columns=['prediction'])
        return predt
                # If it's from pandas, the partition is a numpy array
                prediction = DataFrame(prediction, columns=['prediction'],
                                       dtype=numpy.float32)
        return prediction

    if isinstance(data, (da.Array, dd.DataFrame)):
        return _direct_predict_impl(client, data, mapped_predict)

    # Prediction on dask DMatrix.
    worker_map = data.worker_map

    def dispatched_predict(worker_id):
        '''Perform prediction on each worker.'''
        LOGGER.info('Predicting on %d', worker_id)
        worker = distributed_get_worker()
        list_of_parts = data.get_worker_x_ordered(worker)
        predictions = []
        booster.set_param({'nthread': worker.nthreads})
        for part, order in list_of_parts:
            local_x = DMatrix(part,
                              feature_names=data.feature_names,
                              feature_types=data.feature_types,
                              missing=data.missing,
                              nthread=worker.nthreads)
            predt = booster.predict(data=local_x,
                                    validate_features=local_x.num_row() != 0,
                                    *args)
            ret = (delayed(predt), order)
            predictions.append(ret)
        return predictions

    def dispatched_get_shape(worker_id):
        '''Get shape of data in each worker.'''
        LOGGER.info('Trying to get data shape on %d', worker_id)
        worker = distributed_get_worker()
        list_of_parts = data.get_worker_x_ordered(worker)
        shapes = []
        for part, order in list_of_parts:
            shapes.append((part.shape, order))
        return shapes

    def map_function(func):
        '''Run function for each part of the data.'''
        futures = []
        for wid in range(len(worker_map)):
            list_of_workers = [list(worker_map.keys())[wid]]
            f = client.submit(func, wid,
                              pure=False,
                              workers=list_of_workers)
            futures.append(f)

        # Get delayed objects
        results = client.gather(futures)
        results = [t for l in results for t in l]  # flatten into 1 dim list
        # sort by order, l[0] is the delayed object, l[1] is its order
        results = sorted(results, key=lambda l: l[1])
        results = [predt for predt, order in results]  # remove order
        return results

    results = map_function(dispatched_predict)
    shapes = map_function(dispatched_get_shape)

    # Constructing a dask array from list of numpy arrays
    # See https://docs.dask.org/en/latest/array-creation.html
    arrays = []
    for i, shape in enumerate(shapes):
        arrays.append(da.from_delayed(results[i], shape=(shape[0], ),
                                      dtype=numpy.float32))
    predictions = da.concatenate(arrays, axis=0)
    return predictions
    return await _direct_predict_impl(client, data, mapped_predict)


def inplace_predict(client, model, data,
@@ -770,38 +879,14 @@ def inplace_predict(client, model, data,
    '''
    _assert_dask_support()
    client = _xgb_get_client(client)
    if isinstance(model, Booster):
        booster = model
    elif isinstance(model, dict):
        booster = model['booster']
    else:
        raise TypeError(_expect([Booster, dict], type(model)))
    if not isinstance(data, (da.Array, dd.DataFrame)):
        raise TypeError(_expect([da.Array, dd.DataFrame], type(data)))

    def mapped_predict(data, is_df):
        worker = distributed_get_worker()
        booster.set_param({'nthread': worker.nthreads})
        prediction = booster.inplace_predict(
            data,
    return client.sync(_inplace_predict_async, client, model=model, data=data,
                       iteration_range=iteration_range,
                       predict_type=predict_type,
                       missing=missing)
        if is_df:
            if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
                import cudf  # pylint: disable=import-error
                prediction = cudf.DataFrame({'prediction': prediction},
                                            dtype=numpy.float32)
            else:
                # If it's from pandas, the partition is a numpy array
                prediction = DataFrame(prediction, columns=['prediction'],
                                       dtype=numpy.float32)
        return prediction

    return _direct_predict_impl(client, data, mapped_predict)


def _evaluation_matrices(client, validation_set, sample_weights, missing):
async def _evaluation_matrices(client, validation_set,
                               sample_weights, missing):
    '''
    Parameters
    ----------
@@ -826,8 +911,8 @@ def _evaluation_matrices(client, validation_set, sample_weights, missing):
        for i, e in enumerate(validation_set):
            w = (sample_weights[i]
                 if sample_weights is not None else None)
            dmat = DaskDMatrix(client=client, data=e[0], label=e[1], weight=w,
                               missing=missing)
            dmat = await DaskDMatrix(client=client, data=e[0], label=e[1],
                                     weight=w, missing=missing)
            evals.append((dmat, 'validation_{}'.format(i)))
    else:
        evals = None
@@ -840,9 +925,7 @@ class DaskScikitLearnBase(XGBModel):
    _client = None

    # pylint: disable=arguments-differ
    def fit(self,
            X,
            y,
    def fit(self, X, y,
            sample_weights=None,
            eval_set=None,
            sample_weight_eval_set=None,
@@ -879,6 +962,12 @@ class DaskScikitLearnBase(XGBModel):
        prediction : dask.array.Array'''
        raise NotImplementedError

    def __await__(self):
        # Generate a coroutine wrapper to make this class awaitable.
        async def _():
            return self
        return self.client.sync(_).__await__()

    @property
    def client(self):
        '''The dask client used in this model.'''
@@ -892,40 +981,51 @@ class DaskScikitLearnBase(XGBModel):
@xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""",
                   ['estimators', 'model'])
class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
    # pylint: disable=missing-docstring
    def fit(self,
    # pylint: disable=missing-class-docstring
    async def _fit_async(self,
                         X,
                         y,
                         sample_weights=None,
                         eval_set=None,
                         sample_weight_eval_set=None,
                         verbose=True):
        _assert_dask_support()
        dtrain = DaskDMatrix(client=self.client,
        dtrain = await DaskDMatrix(client=self.client,
                                   data=X, label=y, weight=sample_weights,
                                   missing=self.missing)
        params = self.get_xgb_params()
        evals = _evaluation_matrices(self.client,
        evals = await _evaluation_matrices(self.client,
                                           eval_set, sample_weight_eval_set,
                                           self.missing)

        results = train(self.client, params, dtrain,
        results = await train(client=self.client, params=params, dtrain=dtrain,
                              num_boost_round=self.get_num_boosting_rounds(),
                              evals=evals, verbose_eval=verbose)
        # pylint: disable=attribute-defined-outside-init
        self._Booster = results['booster']
        # pylint: disable=attribute-defined-outside-init
        self.evals_result_ = results['history']
        return self

    def predict(self, data):  # pylint: disable=arguments-differ
    # pylint: disable=missing-docstring
    def fit(self, X, y,
            sample_weights=None,
            eval_set=None,
            sample_weight_eval_set=None,
            verbose=True):
        _assert_dask_support()
        test_dmatrix = DaskDMatrix(client=self.client, data=data,
        return self.client.sync(self._fit_async, X, y, sample_weights,
                                 eval_set, sample_weight_eval_set,
                                 verbose)

    async def _predict_async(self, data):  # pylint: disable=arguments-differ
        test_dmatrix = await DaskDMatrix(client=self.client, data=data,
                                         missing=self.missing)
        pred_probs = predict(client=self.client,
        pred_probs = await predict(client=self.client,
                                   model=self.get_booster(), data=test_dmatrix)
        return pred_probs

    def predict(self, data):
        _assert_dask_support()
        return self.client.sync(self._predict_async, data)


@xgboost_model_doc(
    'Implementation of the scikit-learn API for XGBoost classification.',
@@ -935,24 +1035,21 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
    # pylint: disable=missing-docstring
    _client = None

    def fit(self,
            X,
            y,
    async def _fit_async(self, X, y,
                         sample_weights=None,
                         eval_set=None,
                         sample_weight_eval_set=None,
                         verbose=True):
        _assert_dask_support()
        dtrain = DaskDMatrix(client=self.client,
        dtrain = await DaskDMatrix(client=self.client,
                                   data=X, label=y, weight=sample_weights,
                                   missing=self.missing)
        params = self.get_xgb_params()

        # pylint: disable=attribute-defined-outside-init
        if isinstance(y, (da.Array)):
            self.classes_ = da.unique(y).compute()
            self.classes_ = await self.client.compute(da.unique(y))
        else:
            self.classes_ = y.drop_duplicates().compute()
            self.classes_ = await self.client.compute(y.drop_duplicates())
        self.n_classes_ = len(self.classes_)

        if self.n_classes_ > 2:
@@ -961,10 +1058,10 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
        else:
            params["objective"] = "binary:logistic"

        evals = _evaluation_matrices(self.client,
        evals = await _evaluation_matrices(self.client,
                                           eval_set, sample_weight_eval_set,
                                           self.missing)
        results = train(self.client, params, dtrain,
        results = await train(client=self.client, params=params, dtrain=dtrain,
                              num_boost_round=self.get_num_boosting_rounds(),
                              evals=evals, verbose_eval=verbose)
        self._Booster = results['booster']
@@ -972,10 +1069,22 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
        self.evals_result_ = results['history']
        return self

    def predict(self, data):  # pylint: disable=arguments-differ
    def fit(self, X, y,
            sample_weights=None,
            eval_set=None,
            sample_weight_eval_set=None,
            verbose=True):
        _assert_dask_support()
        test_dmatrix = DaskDMatrix(client=self.client, data=data,
        return self.client.sync(self._fit_async, X, y, sample_weights,
                                 eval_set, sample_weight_eval_set, verbose)

    async def _predict_async(self, data):
        test_dmatrix = await DaskDMatrix(client=self.client, data=data,
                                         missing=self.missing)
        pred_probs = predict(client=self.client,
        pred_probs = await predict(client=self.client,
                                   model=self.get_booster(), data=test_dmatrix)
        return pred_probs

    def predict(self, data):  # pylint: disable=arguments-differ
        _assert_dask_support()
        return self.client.sync(self._predict_async, data)

@@ -2,6 +2,7 @@ import sys
import os
import pytest
import numpy as np
import asyncio
import unittest
import xgboost
import subprocess
@@ -219,7 +220,7 @@ class TestDistributedGPU(unittest.TestCase):
        with LocalCUDACluster() as cluster:
            with Client(cluster) as client:
                workers = list(dxgb._get_client_workers(client).keys())
                rabit_args = dxgb._get_rabit_args(workers, client)
                rabit_args = client.sync(dxgb._get_rabit_args, workers, client)
                futures = client.map(runit,
                                     workers,
                                     pure=False,
@@ -242,3 +243,39 @@ class TestDistributedGPU(unittest.TestCase):
    @pytest.mark.gtest
    def test_quantile_same_on_all_workers(self):
        self.run_quantile('SameOnAllWorkers')


async def run_from_dask_array_asyncio(scheduler_address):
    async with Client(scheduler_address, asynchronous=True) as client:
        import cupy as cp
        X, y = generate_array()
        X = X.map_blocks(cp.array)
        y = y.map_blocks(cp.array)

        m = await xgboost.dask.DaskDeviceQuantileDMatrix(client, X, y)
        output = await xgboost.dask.train(client, {'tree_method': 'gpu_hist'},
                                          dtrain=m)

        with_m = await xgboost.dask.predict(client, output, m)
        with_X = await xgboost.dask.predict(client, output, X)
        inplace = await xgboost.dask.inplace_predict(client, output, X)
        assert isinstance(with_m, da.Array)
        assert isinstance(with_X, da.Array)
        assert isinstance(inplace, da.Array)

        cp.testing.assert_allclose(await client.compute(with_m),
                                   await client.compute(with_X))
        cp.testing.assert_allclose(await client.compute(with_m),
                                   await client.compute(inplace))

        client.shutdown()
        return output


def test_with_asyncio():
    with LocalCUDACluster() as cluster:
        with Client(cluster) as client:
            address = client.scheduler.address
            output = asyncio.run(run_from_dask_array_asyncio(address))
            assert isinstance(output['booster'], xgboost.Booster)
            assert isinstance(output['history'], dict)

@@ -27,8 +27,10 @@ def run_rabit_ops(client, n_workers):
    from xgboost import rabit

    workers = list(_get_client_workers(client).keys())
    rabit_args = _get_rabit_args(workers, client)
    rabit_args = client.sync(_get_rabit_args, workers, client)
    assert not rabit.is_distributed()
    n_workers_from_dask = len(workers)
    assert n_workers == n_workers_from_dask

    def local_test(worker_id):
        with RabitContext(rabit_args):

@@ -4,6 +4,7 @@ import xgboost as xgb
import sys
import numpy as np
import json
import asyncio

if sys.platform.startswith("win"):
    pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
@@ -327,3 +328,96 @@ def test_empty_dmatrix_approx():
        parameters = {'tree_method': 'approx'}
        run_empty_dmatrix_reg(client, parameters)
        run_empty_dmatrix_cls(client, parameters)


async def run_from_dask_array_asyncio(scheduler_address):
    async with Client(scheduler_address, asynchronous=True) as client:
        X, y = generate_array()
        m = await DaskDMatrix(client, X, y)
        output = await xgb.dask.train(client, {}, dtrain=m)

        with_m = await xgb.dask.predict(client, output, m)
        with_X = await xgb.dask.predict(client, output, X)
        inplace = await xgb.dask.inplace_predict(client, output, X)
        assert isinstance(with_m, da.Array)
        assert isinstance(with_X, da.Array)
        assert isinstance(inplace, da.Array)

        np.testing.assert_allclose(await client.compute(with_m),
                                   await client.compute(with_X))
        np.testing.assert_allclose(await client.compute(with_m),
                                   await client.compute(inplace))

        client.shutdown()
        return output


async def run_dask_regressor_asyncio(scheduler_address):
    async with Client(scheduler_address, asynchronous=True) as client:
        X, y = generate_array()
        regressor = await xgb.dask.DaskXGBRegressor(verbosity=1,
                                                    n_estimators=2)
        regressor.set_params(tree_method='hist')
        regressor.client = client
        await regressor.fit(X, y, eval_set=[(X, y)])
        prediction = await regressor.predict(X)

        assert prediction.ndim == 1
        assert prediction.shape[0] == kRows

        history = regressor.evals_result()

        assert isinstance(prediction, da.Array)
        assert isinstance(history, dict)

        assert list(history['validation_0'].keys())[0] == 'rmse'
        assert len(history['validation_0']['rmse']) == 2


async def run_dask_classifier_asyncio(scheduler_address):
    async with Client(scheduler_address, asynchronous=True) as client:
        X, y = generate_array()
        y = (y * 10).astype(np.int32)
        classifier = await xgb.dask.DaskXGBClassifier(
            verbosity=1, n_estimators=2)
        classifier.client = client
        await classifier.fit(X, y, eval_set=[(X, y)])
        prediction = await classifier.predict(X)

        assert prediction.ndim == 1
        assert prediction.shape[0] == kRows

        history = classifier.evals_result()

        assert isinstance(prediction, da.Array)
        assert isinstance(history, dict)

        assert list(history.keys())[0] == 'validation_0'
        assert list(history['validation_0'].keys())[0] == 'merror'
        assert len(list(history['validation_0'])) == 1
        assert len(history['validation_0']['merror']) == 2

        assert classifier.n_classes_ == 10

        # Test with dataframe.
        X_d = dd.from_dask_array(X)
        y_d = dd.from_dask_array(y)
        await classifier.fit(X_d, y_d)

        assert classifier.n_classes_ == 10
        prediction = await classifier.predict(X_d)

        assert prediction.ndim == 1
        assert prediction.shape[0] == kRows


def test_with_asyncio():
    with LocalCluster() as cluster:
        with Client(cluster) as client:
            address = client.scheduler.address
            output = asyncio.run(run_from_dask_array_asyncio(address))
            assert isinstance(output['booster'], xgb.Booster)
            assert isinstance(output['history'], dict)

            asyncio.run(run_dask_regressor_asyncio(address))
            asyncio.run(run_dask_classifier_asyncio(address))