diff --git a/demo/dask/README.md b/demo/dask/README.md
deleted file mode 100644
index 7e10dd321..000000000
--- a/demo/dask/README.md
+++ /dev/null
@@ -1,20 +0,0 @@
-# Dask Integration
-
-[Dask](https://dask.org/) is a parallel computing library built on Python. Dask allows easy management of distributed workers and excels handling large distributed data science workflows.
-
-The simple demo shows how to train and make predictions for an xgboost model on a distributed dask environment. We make use of first-class support in xgboost for launching dask workers. Workers launched in this manner are automatically connected via xgboosts underlying communication framework, Rabit. The calls to `xgb.train()` and `xgb.predict()` occur in parallel on each worker and are synchronized.
-
-The GPU demo shows how to configure and use GPUs on the local machine for training on a large dataset.
-
-## Requirements
-Dask is trivial to install using either pip or conda. [See here for official install documentation](https://docs.dask.org/en/latest/install.html).
-
-The GPU demo requires [GPUtil](https://github.com/anderskm/gputil) for detecting system GPUs.
-
-Install via `pip install gputil`
-
-## Running the scripts
-```bash
-python dask_simple_demo.py
-python dask_gpu_demo.py
-```
\ No newline at end of file
diff --git a/demo/dask/cpu_training.py b/demo/dask/cpu_training.py
new file mode 100644
index 000000000..a949a27bf
--- /dev/null
+++ b/demo/dask/cpu_training.py
@@ -0,0 +1,35 @@
+import xgboost as xgb
+from xgboost.dask import DaskDMatrix
+from dask.distributed import Client
+from dask.distributed import LocalCluster
+from dask import array as da
+
+
+def main(client):
+    n = 100
+    m = 100000
+    partition_size = 1000
+    X = da.random.random((m, n), partition_size)
+    y = da.random.random(m, partition_size)
+
+    dtrain = DaskDMatrix(client, X, y)
+
+    output = xgb.dask.train(client,
+                            {'verbosity': 2,
+                             'nthread': 1,
+                             'tree_method': 'hist'},
+                            dtrain,
+                            num_boost_round=4, evals=[(dtrain, 'train')])
+    bst = output['booster']
+    history = output['history']
+
+    prediction = xgb.dask.predict(client, bst, dtrain)
+    print('Evaluation history:', history)
+    return prediction
+
+
+if __name__ == '__main__':
+    # or use any other clusters
+    cluster = LocalCluster(n_workers=4, threads_per_worker=1)
+    client = Client(cluster)
+    main(client)
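The demo above uses ``LocalCluster``. A minimal sketch of driving the same ``main`` against an already-running distributed cluster (the scheduler address below is hypothetical; connecting ``Client`` to an address string is standard dask, not part of this PR):

```python
from dask.distributed import Client

# Hypothetical address of an existing dask scheduler.
client = Client('tcp://scheduler-host:8786')
prediction = main(client)  # reuses main() from cpu_training.py above
```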
diff --git a/demo/dask/dask_gpu_demo.py b/demo/dask/dask_gpu_demo.py
deleted file mode 100644
index d3f98df8b..000000000
--- a/demo/dask/dask_gpu_demo.py
+++ /dev/null
@@ -1,42 +0,0 @@
-from dask.distributed import Client, LocalCluster
-import dask.dataframe as dd
-import dask.array as da
-import numpy as np
-import xgboost as xgb
-import GPUtil
-import time
-
-
-# Define the function to be executed on each worker
-def train(X, y, available_devices):
-    dtrain = xgb.dask.create_worker_dmatrix(X, y)
-    local_device = available_devices[xgb.rabit.get_rank()]
-    # Specify the GPU algorithm and device for this worker
-    params = {"tree_method": "gpu_hist", "gpu_id": local_device}
-    print("Worker {} starting training on {} rows".format(xgb.rabit.get_rank(), dtrain.num_row()))
-    start = time.time()
-    xgb.train(params, dtrain, num_boost_round=500)
-    end = time.time()
-    print("Worker {} finished training in {:0.2f}s".format(xgb.rabit.get_rank(), end - start))
-
-
-def main():
-    max_devices = 16
-    # Check which devices we have locally
-    available_devices = GPUtil.getAvailable(limit=max_devices)
-    # Use one worker per device
-    cluster = LocalCluster(n_workers=len(available_devices), threads_per_worker=4)
-    client = Client(cluster)
-
-    # Set up a relatively large regression problem
-    n = 100
-    m = 10000000
-    partition_size = 100000
-    X = da.random.random((m, n), partition_size)
-    y = da.random.random(m, partition_size)
-
-    xgb.dask.run(client, train, X, y, available_devices)
-
-
-if __name__ == '__main__':
-    main()
diff --git a/demo/dask/dask_simple_demo.py b/demo/dask/dask_simple_demo.py
deleted file mode 100644
index 56f07cdae..000000000
--- a/demo/dask/dask_simple_demo.py
+++ /dev/null
@@ -1,68 +0,0 @@
-from dask.distributed import Client, LocalCluster
-import dask.dataframe as dd
-import dask.array as da
-import numpy as np
-import xgboost as xgb
-
-
-# Define the function to be executed on each worker
-def train(X, y):
-    print("Start training with worker #{}".format(xgb.rabit.get_rank()))
-    # X,y are dask objects distributed across the cluster.
-    # We must obtain the data local to this worker and convert it to DMatrix for training.
-    # xgb.dask.create_worker_dmatrix follows the API exactly of the standard DMatrix constructor
-    # (xgb.DMatrix()), except that it 'unpacks' dask distributed objects to obtain data local to
-    # this worker
-    dtrain = xgb.dask.create_worker_dmatrix(X, y)
-
-    # Train on the data. Each worker will communicate and synchronise during training. The output
-    # model is expected to be identical on each worker.
-    bst = xgb.train({}, dtrain)
-    # Make predictions on local data
-    pred = bst.predict(dtrain)
-    print("Finished training with worker #{}".format(xgb.rabit.get_rank()))
-    # Get text representation of the model
-    return bst.get_dump()
-
-
-def train_with_sklearn(X, y):
-    print("Training with worker #{} using the sklearn API".format(xgb.rabit.get_rank()))
-    X_local = xgb.dask.get_local_data(X)
-    y_local = xgb.dask.get_local_data(y)
-    model = xgb.XGBRegressor(n_estimators=10)
-    model.fit(X_local, y_local)
-    print("Finished training with worker #{} using the sklearn API".format(xgb.rabit.get_rank()))
-    return model.predict(X_local)
-
-
-def main():
-    # Launch a very simple local cluster using two distributed workers with two CPU threads each
-    cluster = LocalCluster(n_workers=2, threads_per_worker=2)
-    client = Client(cluster)
-
-    # Generate some small test data as a dask array
-    # These data frames are internally split into partitions of 20 rows each and then distributed
-    # among workers, so we will have 5 partitions distributed among 2 workers
-    # Note that the partition size MUST be consistent across different dask dataframes/arrays
-    n = 10
-    m = 100
-    partition_size = 20
-    X = da.random.random((m, n), partition_size)
-    y = da.random.random(m, partition_size)
-
-    # xgb.dask.run launches an arbitrary function and its arguments on the cluster
-    # Here train(X, y) will be called on each worker
-    # This function blocks until all work is complete
-    models = xgb.dask.run(client, train, X, y)
-
-    # models contains a dictionary mapping workers to results
-    # We expect that the models are the same over all workers
-    first_model = next(iter(models.values()))
-    assert all(model == first_model for worker, model in models.items())
-
-    # We can also train using the sklearn API
-    results = xgb.dask.run(client, train_with_sklearn, X, y)
-
-
-if __name__ == '__main__':
-    main()
diff --git a/demo/dask/gpu_training.py b/demo/dask/gpu_training.py
new file mode 100644
index 000000000..469c6a7ee
--- /dev/null
+++ b/demo/dask/gpu_training.py
@@ -0,0 +1,41 @@
+from dask_cuda import LocalCUDACluster
+from dask.distributed import Client
+from dask import array as da
+import xgboost as xgb
+from xgboost.dask import DaskDMatrix
+
+
+def main(client):
+    n = 100
+    m = 100000
+    partition_size = 1000
+    X = da.random.random((m, n), partition_size)
+    y = da.random.random(m, partition_size)
+
+    # DaskDMatrix acts like a normal DMatrix and works as a proxy for the
+    # local DMatrices scattered around the workers.
+    dtrain = DaskDMatrix(client, X, y)
+
+    # Use the train method from xgboost.dask instead of xgboost.  This
+    # distributed version of train returns a dictionary containing the
+    # resulting booster and the evaluation history obtained from the
+    # evaluation metrics.
+    output = xgb.dask.train(client,
+                            {'verbosity': 2,
+                             'nthread': 1,
+                             'tree_method': 'gpu_hist'},
+                            dtrain,
+                            num_boost_round=4, evals=[(dtrain, 'train')])
+    bst = output['booster']
+    history = output['history']
+
+    prediction = xgb.dask.predict(client, bst, dtrain)
+    print('Evaluation history:', history)
+    return prediction
+
+
+if __name__ == '__main__':
+    # or use any other clusters
+    cluster = LocalCUDACluster(n_workers=4, threads_per_worker=1)
+    client = Client(cluster)
+    main(client)
diff --git a/demo/dask/sklearn_cpu_training.py b/demo/dask/sklearn_cpu_training.py
new file mode 100644
index 000000000..4a16f9b4d
--- /dev/null
+++ b/demo/dask/sklearn_cpu_training.py
@@ -0,0 +1,30 @@
+'''Dask interface demo:
+
+Use the scikit-learn regressor interface with the CPU histogram tree method.'''
+from dask.distributed import Client
+from dask.distributed import LocalCluster
+from dask import array as da
+import xgboost
+
+if __name__ == '__main__':
+    cluster = LocalCluster(n_workers=2, silence_logs=False)  # or use any other clusters
+    client = Client(cluster)
+
+    n = 100
+    m = 10000
+    partition_size = 100
+    X = da.random.random((m, n), partition_size)
+    y = da.random.random(m, partition_size)
+
+    regressor = xgboost.dask.DaskXGBRegressor(verbosity=2, n_estimators=2)
+    regressor.set_params(tree_method='hist')
+    regressor.client = client
+
+    regressor.fit(X, y, eval_set=[(X, y)])
+    prediction = regressor.predict(X)
+
+    bst = regressor.get_booster()
+    history = regressor.evals_result()
+
+    print('Evaluation history:', history)
+    assert isinstance(prediction, da.Array)
diff --git a/demo/dask/sklearn_gpu_training.py b/demo/dask/sklearn_gpu_training.py
new file mode 100644
index 000000000..caa58cfe1
--- /dev/null
+++ b/demo/dask/sklearn_gpu_training.py
@@ -0,0 +1,31 @@
+'''Dask interface demo:
+
+Use the scikit-learn regressor interface with the GPU histogram tree method.'''
+
+from dask.distributed import Client
+# It's recommended to use dask_cuda for GPU assignment
+from dask_cuda import LocalCUDACluster
+from dask import array as da
+import xgboost
+
+if __name__ == '__main__':
+    cluster = LocalCUDACluster()
+    client = Client(cluster)
+
+    n = 100
+    m = 1000000
+    partition_size = 10000
+    X = da.random.random((m, n), partition_size)
+    y = da.random.random(m, partition_size)
+
+    regressor = xgboost.dask.DaskXGBRegressor(verbosity=2)
+    regressor.set_params(tree_method='gpu_hist')
+    regressor.client = client
+
+    regressor.fit(X, y, eval_set=[(X, y)])
+    prediction = regressor.predict(X)
+
+    bst = regressor.get_booster()
+    history = regressor.evals_result()
+
+    print('Evaluation history:', history)
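In all of these demos ``prediction`` is a lazy ``dask.array.Array``; nothing is pulled back to the client until it is computed. A small illustrative follow-up (``.compute()`` is standard dask, not something this PR adds):

```python
# Materialize the distributed prediction as a local numpy array.
local_pred = prediction.compute()
print(local_pred.shape)  # (m,) -- one prediction per input row
```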
diff --git a/doc/python/python_api.rst b/doc/python/python_api.rst
index 5a19f02f6..6ef42c067 100644
--- a/doc/python/python_api.rst
+++ b/doc/python/python_api.rst
@@ -80,9 +80,10 @@ Dask API
 --------
 
 .. automodule:: xgboost.dask
 
-.. autofunction:: xgboost.dask.run
+.. autofunction:: xgboost.dask.DaskDMatrix
 
-.. autofunction:: xgboost.dask.create_worker_dmatrix
+.. autofunction:: xgboost.dask.predict
 
-.. autofunction:: xgboost.dask.get_local_data
+.. autofunction:: xgboost.dask.DaskXGBClassifier
+.. autofunction:: xgboost.dask.DaskXGBRegressor
diff --git a/doc/tutorials/dask.rst b/doc/tutorials/dask.rst
new file mode 100644
index 000000000..d5079b403
--- /dev/null
+++ b/doc/tutorials/dask.rst
@@ -0,0 +1,92 @@
+#############################
+Distributed XGBoost with Dask
+#############################
+
+`Dask <https://dask.org/>`_ is a parallel computing library built on Python. Dask allows
+easy management of distributed workers and excels at handling large distributed data
+science workflows. The implementation in XGBoost originates from `dask-xgboost
+<https://github.com/dask/dask-xgboost>`_ with some extended functionality and a
+different interface. Right now it is still under construction and may change (with proper
+warnings) in the future.
+
+************
+Requirements
+************
+
+Dask is trivial to install using either pip or conda. `See here for official install
+documentation <https://docs.dask.org/en/latest/install.html>`_. For accelerating XGBoost
+with GPUs, `dask-cuda <https://github.com/rapidsai/dask-cuda>`_ is recommended for
+creating GPU clusters.
+
+
+********
+Overview
+********
+
+From a user's perspective, there are three components in a dask cluster: a scheduler, a
+set of workers, and one or more clients connecting to the scheduler. To use XGBoost with
+dask, one calls the XGBoost dask interface from the client side. A small example
+illustrates the basic usage:
+
+.. code-block:: python
+
+    cluster = LocalCluster(n_workers=4, threads_per_worker=1)
+    client = Client(cluster)
+
+    dtrain = xgb.dask.DaskDMatrix(client, X, y)  # X and y are dask dataframes or arrays
+
+    output = xgb.dask.train(client,
+                            {'verbosity': 2,
+                             'nthread': 1,
+                             'tree_method': 'hist'},
+                            dtrain,
+                            num_boost_round=4, evals=[(dtrain, 'train')])
+
+Here we first create a cluster in single-node mode with ``distributed.LocalCluster``,
+then connect a ``client`` to this cluster, setting up an environment for later
+computation. As in the non-distributed interface, we create a ``DMatrix`` object and
+pass it to ``train`` along with some other parameters. In the dask interface, however,
+``client`` is an extra argument for carrying out the computation; when set to ``None``,
+XGBoost uses the default client returned by dask.
+
+There are two sets of APIs implemented in XGBoost. The first is the functional API
+illustrated in the above example. Given the data and a set of parameters, the ``train``
+function returns a model and the computation history as a Python dictionary:
+
+.. code-block:: python
+
+    {'booster': Booster,
+     'history': dict}
+
+For prediction, pass the ``output`` returned by ``train`` into ``xgb.dask.predict``:
+
+.. code-block:: python
+
+    prediction = xgb.dask.predict(client, output, dtrain)
+
+Or equivalently, pass ``output['booster']``:
+
+.. code-block:: python
+
+    prediction = xgb.dask.predict(client, output['booster'], dtrain)
+
+Here ``prediction`` is a dask ``Array`` object containing the predictions from the model.
+
+The second set of APIs is a Scikit-Learn wrapper, which mimics the stateful Scikit-Learn
+interface with ``DaskXGBClassifier`` and ``DaskXGBRegressor``. See ``xgboost/demo/dask``
+for more examples, and the short sketch following this file.
+
+
+***********
+Limitations
+***********
+
+Basic functionality, including training and generating predictions for regression and
+classification, is implemented. But there are still some other limitations we haven't
+addressed yet:
+
+- Label encoding for the Scikit-Learn classifier.
+- Ranking.
+- Callback functions are not tested.
+- To use cross validation one needs to explicitly train different models instead of using
+  a functional API like ``xgboost.cv``.
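A minimal sketch of the Scikit-Learn wrapper mentioned above, condensed from ``demo/dask/sklearn_cpu_training.py`` in this PR (dataset sizes are arbitrary):

```python
from dask.distributed import Client, LocalCluster
from dask import array as da
import xgboost

if __name__ == '__main__':
    client = Client(LocalCluster(n_workers=2))

    X = da.random.random((10000, 100), 1000)
    y = da.random.random(10000, 1000)

    # Stateful wrapper: attach the client, then fit/predict as usual.
    regressor = xgboost.dask.DaskXGBRegressor(n_estimators=2, tree_method='hist')
    regressor.client = client
    regressor.fit(X, y, eval_set=[(X, y)])
    prediction = regressor.predict(X)  # lazy dask array
    print('Evaluation history:', regressor.evals_result())
```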
diff --git a/doc/tutorials/index.rst b/doc/tutorials/index.rst
index 14bff5ad0..65481fcbc 100644
--- a/doc/tutorials/index.rst
+++ b/doc/tutorials/index.rst
@@ -21,3 +21,4 @@ See `Awesome XGBoost `_ for mo
   param_tuning
   external_memory
   custom_metric_obj
+  dask
diff --git a/python-package/xgboost/__init__.py b/python-package/xgboost/__init__.py
index 43fa6baef..24ca59bb1 100644
--- a/python-package/xgboost/__init__.py
+++ b/python-package/xgboost/__init__.py
@@ -11,9 +11,9 @@ import os
 from .core import DMatrix, Booster
 from .training import train, cv
 from . import rabit  # noqa
-from . import dask  # noqa
 from . import tracker  # noqa
 from .tracker import RabitTracker  # noqa
+from . import dask
 try:
     from .sklearn import XGBModel, XGBClassifier, XGBRegressor, XGBRanker
     from .sklearn import XGBRFClassifier, XGBRFRegressor
@@ -30,4 +30,4 @@ __all__ = ['DMatrix', 'Booster',
            'RabitTracker',
            'XGBModel', 'XGBClassifier', 'XGBRegressor', 'XGBRanker',
            'XGBRFClassifier', 'XGBRFRegressor',
-           'plot_importance', 'plot_tree', 'to_graphviz']
+           'plot_importance', 'plot_tree', 'to_graphviz', 'dask']
diff --git a/python-package/xgboost/compat.py b/python-package/xgboost/compat.py
index 25edc2435..cbcd6822e 100644
--- a/python-package/xgboost/compat.py
+++ b/python-package/xgboost/compat.py
@@ -96,14 +96,17 @@ except ImportError:
 
 # pandas
 try:
-    from pandas import DataFrame
+    from pandas import DataFrame, Series
     from pandas import MultiIndex
+    from pandas import concat as pandas_concat
     PANDAS_INSTALLED = True
 except ImportError:
     MultiIndex = object
     DataFrame = object
+    Series = object
+    pandas_concat = None
     PANDAS_INSTALLED = False
 
 # dt
@@ -169,16 +172,35 @@ except ImportError:
 
 # dask
 try:
-    from dask.dataframe import DataFrame as DaskDataFrame
-    from dask.dataframe import Series as DaskSeries
-    from dask.array import Array as DaskArray
+    import dask
+    from dask import delayed
+    from dask import dataframe as dd
+    from dask import array as da
+    from dask.distributed import Client, get_client
+    from dask.distributed import comm as distributed_comm
+    from dask.distributed import wait as distributed_wait
     from distributed import get_worker as distributed_get_worker
     DASK_INSTALLED = True
 except ImportError:
-    DaskDataFrame = object
-    DaskSeries = object
-    DaskArray = object
+    dd = None
+    da = None
+    Client = None
+    delayed = None
+    get_client = None
+    distributed_comm = None
+    distributed_wait = None
     distributed_get_worker = None
+    dask = None
     DASK_INSTALLED = False
+
+
+try:
+    import sparse
+    import scipy.sparse as scipy_sparse
+    SCIPY_INSTALLED = True
+except ImportError:
+    sparse = False
+    scipy_sparse = False
+    SCIPY_INSTALLED = False
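``compat.py`` keeps every optional dependency behind the same guard: import at module load, fall back to a sentinel on failure, and expose a boolean flag for callers. A minimal sketch of the idiom, mirroring the ``DASK_INSTALLED`` flag and the ``_assert_dask_installed`` helper introduced later in this PR:

```python
try:
    from dask import array as da
    DASK_INSTALLED = True
except ImportError:
    da = None  # sentinel; callers must check the flag before use
    DASK_INSTALLED = False


def _assert_dask_installed():
    # Fail loudly at the point of use rather than at import time.
    if not DASK_INSTALLED:
        raise ImportError(
            'Dask needs to be installed in order to use this module')
```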
diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index d2c62b71d..5052fe029 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -106,6 +106,28 @@ def from_cstr_to_pystr(data, length):
     return res
 
 
+def _expect(expectations, got):
+    '''Translate input error into string.
+
+    Parameters
+    ----------
+    expectations: sequence
+        a list of expected values.
+    got:
+        actual input
+
+    Returns
+    -------
+    msg: str'''
+    msg = 'Expecting '
+    for t in range(len(expectations) - 1):
+        msg += str(expectations[t])
+        msg += ' or '
+    msg += str(expectations[-1])
+    msg += '. Got ' + str(got)
+    return msg
+
+
 def _log_callback(msg):
     """Redirect logs from native library into Python console"""
     print("{0:s}".format(py_str(msg)))
@@ -513,7 +535,8 @@ class DMatrix(object):
         and type if memory use is a concern.
         """
         if len(mat.shape) != 2:
-            raise ValueError('Input numpy.ndarray must be 2 dimensional')
+            raise ValueError('Expecting 2 dimensional numpy.ndarray, got: ',
+                             mat.shape)
         # flatten the array by rows and ensure it is float32.
         # we try to avoid data copies if possible (reshape returns a view when possible
         # and we explicitly tell np.array to try and avoid copying)
@@ -1010,7 +1033,7 @@ class Booster(object):
         """
         for d in cache:
             if not isinstance(d, DMatrix):
-                raise TypeError('invalid cache item: {}'.format(type(d).__name__))
+                raise TypeError('invalid cache item: {}'.format(type(d).__name__), cache)
             self._validate_features(d)
 
         dmats = c_array(ctypes.c_void_p, [d.handle for d in cache])
@@ -1353,6 +1376,10 @@ class Booster(object):
         if pred_interactions:
             option_mask |= 0x10
 
+        if not isinstance(data, DMatrix):
+            raise TypeError('Expecting data to be a DMatrix object, got: ',
+                            type(data))
+
         if validate_features:
             self._validate_features(data)
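For reference, what ``_expect`` actually produces, checked in an interpreter session with illustrative inputs:

```python
>>> from xgboost.core import _expect
>>> _expect(['numpy.ndarray'], 'list')
'Expecting numpy.ndarray. Got list'
>>> _expect([int, float], str)
"Expecting <class 'int'> or <class 'float'>. Got <class 'str'>"
```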
diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py
index 6ebfcf961..62918c630 100644
--- a/python-package/xgboost/dask.py
+++ b/python-package/xgboost/dask.py
@@ -1,121 +1,609 @@
-# pylint: disable=wrong-import-position,wrong-import-order,import-error
-"""Dask extensions for distributed training. See xgboost/demo/dask for examples."""
-import os
-import math
-import platform
-import logging
-from threading import Thread
-from . import rabit
-from .core import DMatrix
-from .compat import (DaskDataFrame, DaskSeries, DaskArray,
-                     distributed_get_worker)
-
-from .tracker import RabitTracker
-
-
-def _start_tracker(n_workers):
-    """ Start Rabit tracker """
-    host = distributed_get_worker().address
-    if '://' in host:
-        host = host.rsplit('://', 1)[1]
-    host, port = host.split(':')
-    port = int(port)
-    env = {'DMLC_NUM_WORKER': n_workers}
-    rabit_context = RabitTracker(hostIP=host, nslave=n_workers)
-    env.update(rabit_context.slave_envs())
-
-    rabit_context.start(n_workers)
-    thread = Thread(target=rabit_context.join)
-    thread.daemon = True
-    thread.start()
-    return env
-
-
-def get_local_data(data):
-    """
-    Unpacks a distributed data object to get the rows local to this worker
-
-    :param data: A distributed dask data object
-    :return: Local data partition e.g. numpy or pandas
-    """
-    if isinstance(data, DaskArray):
-        total_partitions = len(data.chunks[0])
-    else:
-        total_partitions = data.npartitions
-    partition_size = int(math.ceil(total_partitions / rabit.get_world_size()))
-    begin_partition = partition_size * rabit.get_rank()
-    end_partition = min(begin_partition + partition_size, total_partitions)
-    if isinstance(data, DaskArray):
-        return data.blocks[begin_partition:end_partition].compute()
-
-    return data.partitions[begin_partition:end_partition].compute()
-
-
-def create_worker_dmatrix(*args, **kwargs):
-    """
-    Creates a DMatrix object local to a given worker. Simply forwards arguments onto the standard
-    DMatrix constructor, if one of the arguments is a dask dataframe, unpack the data frame to
-    get the local components.
-
-    All dask dataframe arguments must use the same partitioning.
-
-    :param args: DMatrix constructor args.
-    :return: DMatrix object containing data local to current dask worker
-    """
-    dmatrix_args = []
-    dmatrix_kwargs = {}
-    # Convert positional args
-    for arg in args:
-        if isinstance(arg, (DaskDataFrame, DaskSeries, DaskArray)):
-            dmatrix_args.append(get_local_data(arg))
-        else:
-            dmatrix_args.append(arg)
-
-    # Convert keyword args
-    for k, v in kwargs.items():
-        if isinstance(v, (DaskDataFrame, DaskSeries, DaskArray)):
-            dmatrix_kwargs[k] = get_local_data(v)
-        else:
-            dmatrix_kwargs[k] = v
-
-    return DMatrix(*dmatrix_args, **dmatrix_kwargs)
-
-
-def _run_with_rabit(rabit_args, func, *args):
-    worker = distributed_get_worker()
-    try:
-        os.environ["OMP_NUM_THREADS"] = str(worker.ncores)
-    except AttributeError:
-        os.environ["OMP_NUM_THREADS"] = str(worker.nthreads)
-    try:
-        rabit.init(rabit_args)
-        result = func(*args)
-    finally:
-        rabit.finalize()
-    return result
-
-
-def run(client, func, *args):
-    """Launch arbitrary function on dask workers. Workers are connected by rabit,
-    allowing distributed training. The environment variable OMP_NUM_THREADS is
-    defined on each worker according to dask - this means that calls to
-    xgb.train() will use the threads allocated by dask by default, unless the
-    user overrides the nthread parameter.
-
-    Note: Windows platforms are not officially
-    supported. Contributions are welcome here.
-
-    :param client: Dask client representing the cluster
-    :param func: Python function to be executed by each worker. Typically
-    contains xgboost training code.
-    :param args: Arguments to be forwarded to func
-    :return: Dict containing the function return value for each worker
-
-    """
-    if platform.system() == 'Windows':
-        logging.warning('Windows is not officially supported for dask/xgboost'
-                        'integration. Contributions welcome.')
-    workers = list(client.scheduler_info()['workers'].keys())
-    env = client.run(_start_tracker, len(workers), workers=[workers[0]])
-    rabit_args = [('%s=%s' % item).encode() for item in env[workers[0]].items()]
-    return client.run(_run_with_rabit, rabit_args, func, *args)
+# pylint: disable=too-many-arguments, too-many-locals
+"""Dask extensions for distributed training. See
+https://xgboost.readthedocs.io/en/latest/tutorials/dask.html for a simple
+tutorial.  Also see xgboost/demo/dask for some examples.
+
+There are two sets of APIs in this module: one is the functional API,
+including the ``train`` and ``predict`` methods.  The other is a stateful
+Scikit-Learn wrapper inherited from the single-node Scikit-Learn interface.
+
+The implementation is heavily influenced by dask_xgboost:
+https://github.com/dask/dask-xgboost
+
+"""
+import platform
+import logging
+from collections import defaultdict
+from threading import Thread
+
+import numpy
+
+from . import rabit
+
+from .compat import DASK_INSTALLED
+from .compat import distributed_get_worker, distributed_wait, distributed_comm
+from .compat import da, dd, delayed, get_client
+from .compat import sparse, scipy_sparse
+from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat
+
+from .core import DMatrix, Booster, _expect
+from .training import train as worker_train
+from .tracker import RabitTracker
+from .sklearn import XGBModel, XGBClassifierBase
+
+# Current status is considered as initial support; many features are
+# not properly supported yet.
+#
+# TODOs:
+# - Callback.
+# - Label encoding.
+# - CV
+# - Ranking
+
+
+def _start_tracker(host, n_workers):
+    """Start Rabit tracker"""
+    env = {'DMLC_NUM_WORKER': n_workers}
+    rabit_context = RabitTracker(hostIP=host, nslave=n_workers)
+    env.update(rabit_context.slave_envs())
+
+    rabit_context.start(n_workers)
+    thread = Thread(target=rabit_context.join)
+    thread.daemon = True
+    thread.start()
+    return env
+
+
+def _assert_dask_installed():
+    if not DASK_INSTALLED:
+        raise ImportError(
+            'Dask needs to be installed in order to use this module')
+
+
+class RabitContext:
+    '''A context controlling rabit initialization and finalization.'''
+    def __init__(self, args):
+        self.args = args
+
+    def __enter__(self):
+        rabit.init(self.args)
+        logging.debug('-------------- rabit say hello ------------------')
+
+    def __exit__(self, *args):
+        rabit.finalize()
+        logging.debug('--------------- rabit say bye ------------------')
+
+
+def concat(value):
+    '''To be replaced with dask builtin.'''
+    if isinstance(value[0], numpy.ndarray):
+        return numpy.concatenate(value, axis=0)
+    if scipy_sparse and isinstance(value[0], scipy_sparse.spmatrix):
+        return scipy_sparse.vstack(value, format='csr')
+    if sparse and isinstance(value[0], sparse.SparseArray):
+        return sparse.concatenate(value, axis=0)
+    if PANDAS_INSTALLED and isinstance(value[0], (DataFrame, Series)):
+        return pandas_concat(value, axis=0)
+    return dd.multi.concat(list(value), axis=0)
+
+
+def _xgb_get_client(client):
+    '''Simple wrapper around testing None.'''
+    ret = get_client() if client is None else client
+    return ret
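``concat`` simply dispatches on the type of the partitions it receives. An illustrative check of the numpy branch (assumes only numpy is involved; ``concat`` is module-level, so the import below should work once this PR is installed):

```python
import numpy
from xgboost.dask import concat

a = numpy.arange(6).reshape(3, 2)
b = numpy.arange(6, 12).reshape(3, 2)
stacked = concat([a, b])  # hits the numpy.concatenate branch
assert stacked.shape == (6, 2)
```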
+
+
+class DaskDMatrix:
+    # pylint: disable=missing-docstring, too-many-instance-attributes
+    '''DMatrix holding on references to Dask DataFrame or Dask Array.
+
+    Parameters
+    ----------
+    client: dask.distributed.Client
+        Specify the dask client used for training.  Use default client
+        returned from dask if it's set to None.
+    data : dask.array.Array/dask.dataframe.DataFrame
+        data source of DMatrix.
+    label: dask.array.Array/dask.dataframe.DataFrame
+        label used for training.
+    missing : float, optional
+        Value in the input data (e.g. `numpy.ndarray`) which needs
+        to be present as a missing value. If None, defaults to np.nan.
+    weight : dask.array.Array/dask.dataframe.DataFrame
+        Weight for each instance.
+    feature_names : list, optional
+        Set names for features.
+    feature_types : list, optional
+        Set types for features.
+
+    '''
+
+    _feature_names = None  # for previous version's pickle
+    _feature_types = None
+
+    def __init__(self,
+                 client,
+                 data,
+                 label=None,
+                 missing=None,
+                 weight=None,
+                 feature_names=None,
+                 feature_types=None):
+        _assert_dask_installed()
+
+        self._feature_names = feature_names
+        self._feature_types = feature_types
+        self._missing = missing
+
+        if len(data.shape) != 2:
+            raise ValueError(_expect(['2 dimensions input'], data.shape))
+        self.n_rows = data.shape[0]
+        self.n_cols = data.shape[1]
+
+        if not any(isinstance(data, t) for t in (dd.DataFrame, da.Array)):
+            raise TypeError(_expect((dd.DataFrame, da.Array), type(data)))
+        if not any(
+                isinstance(label, t)
+                for t in (dd.DataFrame, da.Array, dd.Series, type(None))):
+            raise TypeError(
+                _expect((dd.DataFrame, da.Array, dd.Series), type(label)))
+
+        self.worker_map = None
+        self.has_label = label is not None
+        self.has_weights = weight is not None
+
+        client = _xgb_get_client(client)
+        client.sync(self.map_local_data, client, data, label, weight)
+
+    async def map_local_data(self, client, data, label=None, weights=None):
+        '''Obtain references to local data.'''
+        data = data.persist()
+        if label is not None:
+            label = label.persist()
+        if weights is not None:
+            weights = weights.persist()
+        # Breaking data into partitions, a trick borrowed from dask_xgboost.
+
+        # `to_delayed` downgrades high-level objects into numpy or pandas
+        # equivalents.
+        X_parts = data.to_delayed()
+        if isinstance(X_parts, numpy.ndarray):
+            assert X_parts.shape[1] == 1
+            X_parts = X_parts.flatten().tolist()
+
+        if label is not None:
+            y_parts = label.to_delayed()
+            if isinstance(y_parts, numpy.ndarray):
+                assert y_parts.ndim == 1 or y_parts.shape[1] == 1
+                y_parts = y_parts.flatten().tolist()
+        if weights is not None:
+            w_parts = weights.to_delayed()
+            if isinstance(w_parts, numpy.ndarray):
+                assert w_parts.ndim == 1 or w_parts.shape[1] == 1
+                w_parts = w_parts.flatten().tolist()
+
+        parts = [X_parts]
+        if label is not None:
+            assert len(X_parts) == len(
+                y_parts), 'Partitions between X and y are not consistent'
+            parts.append(y_parts)
+        if weights is not None:
+            assert len(X_parts) == len(
+                w_parts), 'Partitions between X and weight are not consistent.'
+            parts.append(w_parts)
+        parts = list(map(delayed, zip(*parts)))
+
+        parts = client.compute(parts)
+        await distributed_wait(parts)  # async wait for parts to be computed
+
+        for part in parts:
+            assert part.status == 'finished'
+
+        key_to_partition = {part.key: part for part in parts}
+        who_has = await client.scheduler.who_has(
+            keys=[part.key for part in parts])
+
+        worker_map = defaultdict(list)
+        for key, workers in who_has.items():
+            worker_map[next(iter(workers))].append(key_to_partition[key])
+
+        self.worker_map = worker_map
+
+    def get_worker_parts(self, worker):
+        '''Get mapped parts of data in each worker.'''
+        list_of_parts = self.worker_map[worker.address]
+        assert list_of_parts, 'data in ' + worker.address + ' was moved.'
+        assert isinstance(list_of_parts, list)
+
+        # `get_worker_parts` is launched inside a worker.  On the dask side
+        # this should be equal to `worker._get_client`.
+        client = get_client()
+        list_of_parts = client.gather(list_of_parts)
+
+        if self.has_label:
+            if self.has_weights:
+                data, labels, weights = zip(*list_of_parts)
+            else:
+                data, labels = zip(*list_of_parts)
+                weights = None
+        else:
+            data = [d[0] for d in list_of_parts]
+            labels = None
+            weights = None
+        return data, labels, weights
+
+    def get_worker_data(self, worker):
+        '''Get the data that is local to a worker.
+
+        Parameters
+        ----------
+        worker: The worker used as key to data.
+
+        Returns
+        -------
+        A DMatrix object.
+
+        '''
+        data, labels, weights = self.get_worker_parts(worker)
+
+        data = concat(data)
+
+        if self.has_label:
+            labels = concat(labels)
+        else:
+            labels = None
+        if self.has_weights:
+            weights = concat(weights)
+        else:
+            weights = None
+
+        dmatrix = DMatrix(data,
+                          labels,
+                          weight=weights,
+                          missing=self._missing,
+                          feature_names=self._feature_names,
+                          feature_types=self._feature_types)
+        return dmatrix
+
+    def get_worker_data_shape(self, worker):
+        '''Get the shape of data X in each worker.'''
+        data, _, _ = self.get_worker_parts(worker)
+
+        shapes = [d.shape for d in data]
+        rows = 0
+        cols = 0
+        for shape in shapes:
+            rows += shape[0]
+            # Partitions of X stack row-wise, so rows accumulate while the
+            # width stays constant; summing widths would over-count.
+            if cols == 0:
+                cols = shape[1]
+            else:
+                assert cols == shape[1], 'Partitions have different widths.'
+        return (rows, cols)
+
+    def num_row(self):
+        return self.n_rows
+
+    def num_col(self):
+        return self.n_cols
+
+
+def _get_rabit_args(worker_map, client):
+    '''Get rabit context arguments from data distribution in DaskDMatrix.'''
+    host = distributed_comm.get_address_host(client.scheduler.address)
+
+    env = client.run_on_scheduler(_start_tracker, host.strip('/:'),
+                                  len(worker_map))
+    rabit_args = [('%s=%s' % item).encode() for item in env.items()]
+    return rabit_args
+
+# train and predict methods are supposed to be "functional", which meets the
+# dask paradigm.  But as a side effect, the `evals_result` in the single-node
+# API is no longer supported, since it mutates the input parameter and it's
+# not intuitive to sync the mutated result.  Therefore, a dictionary
+# containing the evaluation history is returned instead.
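The comment above fixes the training contract: instead of mutating an ``evals_result`` argument, ``train`` hands everything back in one dictionary. A sketch of consuming it, matching ``demo/dask/cpu_training.py`` from this PR:

```python
output = xgb.dask.train(client,
                        {'tree_method': 'hist'},
                        dtrain,
                        num_boost_round=4,
                        evals=[(dtrain, 'train')])
bst = output['booster']      # xgboost.Booster trained across the workers
history = output['history']  # e.g. {'train': {'rmse': [...]}}
```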
+
+
+def train(client, params, dtrain, *args, evals=(), **kwargs):
+    '''Train XGBoost model.
+
+    Parameters
+    ----------
+    client: dask.distributed.Client
+        Specify the dask client used for training.  Use default client
+        returned from dask if it's set to None.
+
+    Other parameters are the same as `xgboost.train` except for
+    `evals_result`, which is returned as part of the function return value
+    instead of being passed as an argument.
+
+    Returns
+    -------
+    results: dict
+        A dictionary containing trained booster and evaluation history.
+        The `history` field is the same as `evals_result` from
+        `xgboost.train`.
+
+        .. code-block:: python
+
+            {'booster': xgboost.Booster,
+             'history': {'train': {'logloss': ['0.48253', '0.35953']},
+                         'eval': {'logloss': ['0.480385', '0.357756']}}}
+
+    '''
+    _assert_dask_installed()
+    if platform.system() == 'Windows':
+        msg = 'Windows is not officially supported for dask/xgboost,'
+        msg += ' contributions are welcome.'
+        logging.warning(msg)
+
+    if 'evals_result' in kwargs.keys():
+        raise ValueError(
+            'evals_result is not supported in dask interface.',
+            'The evaluation history is returned as result of training.')
+
+    client = _xgb_get_client(client)
+
+    worker_map = dtrain.worker_map
+    rabit_args = _get_rabit_args(worker_map, client)
+
+    def dispatched_train(worker_id):
+        '''Perform training on worker.'''
+        logging.info('Training on %d', worker_id)
+        worker = distributed_get_worker()
+        local_dtrain = dtrain.get_worker_data(worker)
+
+        local_evals = []
+        if evals:
+            for mat, name in evals:
+                local_mat = mat.get_worker_data(worker)
+                local_evals.append((local_mat, name))
+
+        with RabitContext(rabit_args):
+            local_history = {}
+            local_param = params.copy()  # just to be consistent
+            bst = worker_train(params=local_param,
+                               dtrain=local_dtrain,
+                               *args,
+                               evals_result=local_history,
+                               evals=local_evals,
+                               **kwargs)
+            ret = {'booster': bst, 'history': local_history}
+            if rabit.get_rank() != 0:
+                ret = None
+            return ret
+
+    futures = client.map(dispatched_train,
+                         range(len(worker_map)),
+                         workers=list(worker_map.keys()))
+    results = client.gather(futures)
+    return list(filter(lambda ret: ret is not None, results))[0]
+
+
+def predict(client, model, data, *args):
+    '''Run prediction with a trained booster.
+
+    Parameters
+    ----------
+    client: dask.distributed.Client
+        Specify the dask client used for training.  Use default client
+        returned from dask if it's set to None.
+    model: A Booster or a dictionary returned by `xgboost.dask.train`.
+        The trained model.
+    data: DaskDMatrix
+        Input data used for prediction.
+
+    Returns
+    -------
+    prediction: dask.array.Array
+
+    '''
+    _assert_dask_installed()
+    if isinstance(model, Booster):
+        booster = model
+    elif isinstance(model, dict):
+        booster = model['booster']
+    else:
+        raise TypeError(_expect([Booster, dict], type(model)))
+
+    if not isinstance(data, DaskDMatrix):
+        raise TypeError(_expect([DaskDMatrix], type(data)))
+
+    worker_map = data.worker_map
+    client = _xgb_get_client(client)
+
+    rabit_args = _get_rabit_args(worker_map, client)
+
+    def dispatched_predict(worker_id):
+        '''Perform prediction on each worker.'''
+        logging.info('Predicting on %d', worker_id)
+        worker = distributed_get_worker()
+        local_x = data.get_worker_data(worker)
+
+        with RabitContext(rabit_args):
+            local_predictions = booster.predict(data=local_x, *args)
+            return local_predictions
+
+    futures = client.map(dispatched_predict,
+                         range(len(worker_map)),
+                         workers=list(worker_map.keys()))
+
+    def dispatched_get_shape(worker_id):
+        '''Get shape of data in each worker.'''
+        logging.info('Trying to get data shape on %d', worker_id)
+        worker = distributed_get_worker()
+        rows, cols = data.get_worker_data_shape(worker)
+        return rows, cols
+
+    # Constructing a dask array from a list of numpy arrays
+    # See https://docs.dask.org/en/latest/array-creation.html
+    futures_shape = client.map(dispatched_get_shape,
+                               range(len(worker_map)),
+                               workers=list(worker_map.keys()))
+    shapes = client.gather(futures_shape)
+    arrays = []
+    for i in range(len(futures_shape)):
+        arrays.append(da.from_delayed(futures[i], shape=shapes[i],
+                                      dtype=numpy.float32))
+    predictions = da.concatenate(arrays, axis=0)
+    return predictions
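As the type check at the top of ``predict`` shows, ``model`` may be either the raw booster or the whole dictionary returned by ``train``; the two calls below are equivalent (``output`` and ``dtrain`` as in the earlier sketch):

```python
prediction = xgb.dask.predict(client, output, dtrain)             # dict form
prediction = xgb.dask.predict(client, output['booster'], dtrain)  # Booster form
```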
+
+
+def _evaluation_matrices(client, validation_set, sample_weights):
+    '''
+    Parameters
+    ----------
+    validation_set: list of tuples
+        Each tuple contains a validation dataset including input X and label
+        y.  E.g.:
+
+        .. code-block:: python
+
+            [(X_0, y_0), (X_1, y_1), ... ]
+
+    sample_weights: list of arrays
+        The weight vector for validation data.
+
+    Returns
+    -------
+    evals: list of validation DMatrix
+    '''
+    evals = []
+    if validation_set is not None:
+        assert isinstance(validation_set, list)
+        for i, e in enumerate(validation_set):
+            w = (sample_weights[i]
+                 if sample_weights is not None else None)
+            dmat = DaskDMatrix(client=client, data=e[0], label=e[1], weight=w)
+            evals.append((dmat, 'validation_{}'.format(i)))
+    else:
+        evals = None
+    return evals
+
+
+class DaskScikitLearnBase(XGBModel):
+    '''Base class for implementing scikit-learn interface with Dask'''
+
+    _client = None
+
+    # pylint: disable=arguments-differ
+    def fit(self,
+            X,
+            y,
+            sample_weights=None,
+            eval_set=None,
+            sample_weight_eval_set=None):
+        '''Fit the regressor.
+
+        Parameters
+        ----------
+        X : array_like
+            Feature matrix
+        y : array_like
+            Labels
+        sample_weights : array_like
+            instance weights
+        eval_set : list, optional
+            A list of (X, y) tuple pairs to use as validation sets, for which
+            metrics will be computed.
+            Validation metrics will help us track the performance of the
+            model.
+        sample_weight_eval_set : list, optional
+            A list of the form [L_1, L_2, ..., L_n], where each L_i is a list
+            of instance weights on the i-th validation set.'''
+        raise NotImplementedError
+
+    def predict(self, data):  # pylint: disable=arguments-differ
+        '''Predict with `data`.
+
+        Parameters
+        ----------
+        data: data that can be used to construct a DaskDMatrix
+
+        Returns
+        -------
+        prediction : dask.array.Array'''
+        raise NotImplementedError
+
+    @property
+    def client(self):
+        '''The dask client used in this model.'''
+        client = _xgb_get_client(self._client)
+        return client
+
+    @client.setter
+    def client(self, clt):
+        self._client = clt
+
+
+class DaskXGBRegressor(DaskScikitLearnBase):
+    # pylint: disable=missing-docstring
+    __doc__ = ('Implementation of the scikit-learn API for XGBoost '
+               'regression.\n\n') + '\n'.join(
+                   XGBModel.__doc__.split('\n')[2:])
+
+    def fit(self,
+            X,
+            y,
+            sample_weights=None,
+            eval_set=None,
+            sample_weight_eval_set=None):
+        _assert_dask_installed()
+        dtrain = DaskDMatrix(client=self.client,
+                             data=X, label=y, weight=sample_weights)
+        params = self.get_xgb_params()
+        evals = _evaluation_matrices(self.client,
+                                     eval_set, sample_weight_eval_set)
+
+        results = train(self.client, params, dtrain,
+                        num_boost_round=self.get_num_boosting_rounds(),
+                        evals=evals)
+        self._Booster = results['booster']
+        # pylint: disable=attribute-defined-outside-init
+        self.evals_result_ = results['history']
+        return self
+
+    def predict(self, data):  # pylint: disable=arguments-differ
+        _assert_dask_installed()
+        test_dmatrix = DaskDMatrix(client=self.client, data=data)
+        pred_probs = predict(client=self.client,
+                             model=self.get_booster(), data=test_dmatrix)
+        return pred_probs
+
+
+class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
+    # pylint: disable=missing-docstring
+    _client = None
+    __doc__ = ('Implementation of the scikit-learn API for XGBoost '
+               'classification.\n\n') + '\n'.join(
+                   XGBModel.__doc__.split('\n')[2:])
+
+    def fit(self,
+            X,
+            y,
+            sample_weights=None,
+            eval_set=None,
+            sample_weight_eval_set=None):
+        _assert_dask_installed()
+        dtrain = DaskDMatrix(client=self.client,
+                             data=X, label=y, weight=sample_weights)
+        params = self.get_xgb_params()
+
+        # pylint: disable=attribute-defined-outside-init
+        self.classes_ = da.unique(y).compute()
+        self.n_classes_ = len(self.classes_)
+
+        if self.n_classes_ > 2:
+            params["objective"] = "multi:softprob"
+            params['num_class'] = self.n_classes_
+        else:
+            params["objective"] = "binary:logistic"
+            params.setdefault('num_class', self.n_classes_)
+
+        evals = _evaluation_matrices(self.client,
+                                     eval_set, sample_weight_eval_set)
+        results = train(self.client, params, dtrain,
+                        num_boost_round=self.get_num_boosting_rounds(),
+                        evals=evals)
+        self._Booster = results['booster']
+        # pylint: disable=attribute-defined-outside-init
+        self.evals_result_ = results['history']
+        return self
+
+    def predict(self, data):  # pylint: disable=arguments-differ
+        _assert_dask_installed()
+        test_dmatrix = DaskDMatrix(client=self.client, data=data)
+        pred_probs = predict(client=self.client,
+                             model=self.get_booster(), data=test_dmatrix)
+        return pred_probs
diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py
index 019ac5875..829724258 100644
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -1,8 +1,6 @@
 # coding: utf-8
 # pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme, E0012, R0912, C0302
 """Scikit-Learn Wrapper interface for XGBoost."""
-from __future__ import absolute_import
-
 import warnings
 import json
 import numpy as np
@@ -282,7 +280,8 @@ class XGBModel(XGBModelBase):
                       "object {} will be lost. ".format(type(self).__name__) +
                       "If you did not mean to export the model to " +
                       "a non-Python binding of XGBoost, consider " +
-                      "using `pickle` or `joblib` to save your model.", Warning)
+                      "using `pickle` or `joblib` to save your model.",
+                      Warning)
         self.get_booster().save_model(fname)
 
     def load_model(self, fname):
diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py
index 238efc3db..663beefb4 100644
--- a/tests/python/test_with_dask.py
+++ b/tests/python/test_with_dask.py
@@ -1,93 +1,96 @@
-import testing as tm
-import pytest
-import xgboost as xgb
-import numpy as np
-import sys
-
-if sys.platform.startswith("win"):
-    pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
-
-try:
-    from distributed.utils_test import client, loop, cluster_fixture
-    import dask.dataframe as dd
-    import dask.array as da
-except ImportError:
-    pass
-
-pytestmark = pytest.mark.skipif(**tm.no_dask())
-
-
-def run_train():
-    # Contains one label equal to rank
-    dmat = xgb.DMatrix(np.array([[0]]), label=[xgb.rabit.get_rank()])
-    bst = xgb.train({"eta": 1.0, "lambda": 0.0}, dmat, 1)
-    pred = bst.predict(dmat)
-    expected_result = np.average(range(xgb.rabit.get_world_size()))
-    assert all(p == expected_result for p in pred)
-
-
-def test_train(client):
-    # Train two workers, the first has label 0, the second has label 1
-    # If they build the model together the output should be 0.5
-    xgb.dask.run(client, run_train)
-    # Run again to check we can have multiple sessions
-    xgb.dask.run(client, run_train)
-
-
-def run_create_dmatrix(X, y, weights):
-    dmat = xgb.dask.create_worker_dmatrix(X, y, weight=weights)
-    # Expect this worker to get two partitions and concatenate them
-    assert dmat.num_row() == 50
-
-
-def test_dask_dataframe(client):
-    n = 10
-    m = 100
-    partition_size = 25
-    X = dd.from_array(np.random.random((m, n)), partition_size)
-    y = dd.from_array(np.random.random(m), partition_size)
-    weights = dd.from_array(np.random.random(m), partition_size)
-    xgb.dask.run(client, run_create_dmatrix, X, y, weights)
-
-
-def test_dask_array(client):
-    n = 10
-    m = 100
-    partition_size = 25
-    X = da.random.random((m, n), partition_size)
-    y = da.random.random(m, partition_size)
-    weights = da.random.random(m, partition_size)
-    xgb.dask.run(client, run_create_dmatrix, X, y, weights)
-
-
-def run_get_local_data(X, y):
-    X_local = xgb.dask.get_local_data(X)
-    y_local = xgb.dask.get_local_data(y)
-    assert (X_local.shape == (50, 10))
-    assert (y_local.shape == (50,))
-
-
-def test_get_local_data(client):
-    n = 10
-    m = 100
-    partition_size = 25
-    X = da.random.random((m, n), partition_size)
-    y = da.random.random(m, partition_size)
-    xgb.dask.run(client, run_get_local_data, X, y)
-
-
-def run_sklearn():
-    # Contains one label equal to rank
-    X = np.array([[0]])
-    y = [xgb.rabit.get_rank()]
-    model = xgb.XGBRegressor(learning_rate=1.0)
-    model.fit(X, y)
-    pred = model.predict(X)
-    expected_result = np.average(range(xgb.rabit.get_world_size()))
-    assert all(p == expected_result for p in pred)
-    return pred
-
-
-def test_sklearn(client):
-    result = xgb.dask.run(client, run_sklearn)
-    print(result)
+import testing as tm
+import pytest
+import xgboost as xgb
+import sys
+import numpy as np
+
+if sys.platform.startswith("win"):
+    pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
+
+try:
+    from distributed.utils_test import client, loop, cluster_fixture
+    import dask.dataframe as dd
+    import dask.array as da
+    from xgboost.dask import DaskDMatrix
+except ImportError:
+    pass
+
+pytestmark = pytest.mark.skipif(**tm.no_dask())
+
+kRows = 1000
+
+
+def generate_array():
+    n = 10
+    partition_size = 20
+    X = da.random.random((kRows, n), partition_size)
+    y = da.random.random(kRows, partition_size)
+    return X, y
+
+
+def test_from_dask_dataframe(client):
+    X, y = generate_array()
+
+    X = dd.from_dask_array(X)
+    y = dd.from_dask_array(y)
+
+    dtrain = DaskDMatrix(client, X, y)
+    booster = xgb.dask.train(
+        client, {}, dtrain, num_boost_round=2)['booster']
+
+    prediction = xgb.dask.predict(client, model=booster, data=dtrain)
+
+    assert isinstance(prediction, da.Array)
+    assert prediction.shape[0] == kRows, prediction
+
+    with pytest.raises(ValueError):
+        # evals_result is not supported in dask interface.
+        xgb.dask.train(
+            client, {}, dtrain, num_boost_round=2, evals_result={})
+
+
+def test_from_dask_array(client):
+    X, y = generate_array()
+    dtrain = DaskDMatrix(client, X, y)
+    # result is {'booster': Booster, 'history': {...}}
+    result = xgb.dask.train(client, {}, dtrain)
+
+    prediction = xgb.dask.predict(client, result, dtrain)
+
+    assert isinstance(prediction, da.Array)
+
+
+def test_regressor(client):
+    X, y = generate_array()
+    regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
+    regressor.set_params(tree_method='hist')
+    regressor.client = client
+    regressor.fit(X, y, eval_set=[(X, y)])
+    prediction = regressor.predict(X)
+
+    history = regressor.evals_result()
+
+    assert isinstance(prediction, da.Array)
+    assert isinstance(history, dict)
+
+    assert list(history['validation_0'].keys())[0] == 'rmse'
+    assert len(history['validation_0']['rmse']) == 2
+
+
+def test_classifier(client):
+    X, y = generate_array()
+    y = (y * 10).astype(np.int32)
+    classifier = xgb.dask.DaskXGBClassifier(verbosity=1, n_estimators=2)
+    classifier.client = client
+    classifier.fit(X, y, eval_set=[(X, y)])
+    prediction = classifier.predict(X)
+
+    history = classifier.evals_result()
+
+    assert isinstance(prediction, da.Array)
+    assert isinstance(history, dict)
+
+    assert list(history.keys())[0] == 'validation_0'
+    assert list(history['validation_0'].keys())[0] == 'merror'
+    assert len(list(history['validation_0'])) == 1
+    assert len(history['validation_0']['merror']) == 2
diff --git a/tests/travis/run_test.sh b/tests/travis/run_test.sh
index 1fa0f513a..3d9602fd1 100755
--- a/tests/travis/run_test.sh
+++ b/tests/travis/run_test.sh
@@ -18,6 +18,7 @@ if [ ${TASK} == "python_test" ]; then
     conda install numpy scipy pandas matplotlib scikit-learn
 
     python -m pip install graphviz pytest pytest-cov codecov
+    python -m pip install dask distributed dask[dataframe]
    python -m pip install https://h2o-release.s3.amazonaws.com/datatable/stable/datatable-0.7.0/datatable-0.7.0-cp37-cp37m-linux_x86_64.whl
     python -m pytest -v --fulltrace -s tests/python --cov=python-package/xgboost || exit -1
     codecov