diff --git a/.gitignore b/.gitignore index 803e0e6d8..e7d241dc1 100644 --- a/.gitignore +++ b/.gitignore @@ -96,3 +96,6 @@ plugin/updater_gpu/test/cpp/data # files from R-package source install **/config.status R-package/src/Makevars + +# Python install +python-package/xgboost/tracker.py diff --git a/demo/dask/README.md b/demo/dask/README.md new file mode 100644 index 000000000..7e10dd321 --- /dev/null +++ b/demo/dask/README.md @@ -0,0 +1,20 @@ +# Dask Integration + +[Dask](https://dask.org/) is a parallel computing library built on Python. Dask allows easy management of distributed workers and excels at handling large distributed data science workflows. + +The simple demo shows how to train and make predictions for an xgboost model on a distributed dask environment. We make use of first-class support in xgboost for launching dask workers. Workers launched in this manner are automatically connected via XGBoost's underlying communication framework, Rabit. The calls to `xgb.train()` and `Booster.predict()` occur in parallel on each worker and are synchronized. + +The GPU demo shows how to configure and use GPUs on the local machine for training on a large dataset. + +## Requirements +Dask is trivial to install using either pip or conda. [See here for official install documentation](https://docs.dask.org/en/latest/install.html). + +The GPU demo requires [GPUtil](https://github.com/anderskm/gputil) for detecting system GPUs. 
# Define the function to be executed on each worker
def train(X, y, available_devices):
    """Train a GPU model on the rows of X and y that are local to this worker.

    :param X: Dask collection holding the features
    :param y: Dask collection holding the labels
    :param available_devices: List of GPU ids; the entry at this worker's rabit rank is used
    """
    dtrain = xgb.dask.create_worker_dmatrix(X, y)
    rank = xgb.rabit.get_rank()
    # One worker is launched per device, so the rank doubles as an index into the device list
    local_device = available_devices[rank]
    # Specify the GPU algorithm and device for this worker
    params = {"tree_method": "gpu_hist", "gpu_id": local_device}
    print("Worker {} starting training on {} rows".format(rank, dtrain.num_row()))
    start = time.time()
    xgb.train(params, dtrain, num_boost_round=500)
    end = time.time()
    print("Worker {} finished training in {:0.2f}s".format(rank, end - start))


def main():
    """Launch one local dask worker per available GPU and train on random regression data.

    :raises RuntimeError: If GPUtil reports no available GPU
    """
    max_devices = 16
    # Check which devices we have locally
    available_devices = GPUtil.getAvailable(limit=max_devices)
    if not available_devices:
        # Without this guard the cluster would be created with zero workers and
        # xgb.dask.run would fail with an unhelpful IndexError.
        raise RuntimeError("No available GPUs were detected; this demo needs at least one.")
    # Use one worker per device
    cluster = LocalCluster(n_workers=len(available_devices), threads_per_worker=4)
    client = Client(cluster)

    # Set up a relatively large regression problem
    n = 100
    m = 10000000
    partition_size = 100000
    X = da.random.random((m, n), partition_size)
    y = da.random.random(m, partition_size)

    xgb.dask.run(client, train, X, y, available_devices)


if __name__ == '__main__':
    main()
# Define the function to be executed on each worker
def train(X, y):
    """Train a model on this worker's share of the data and return its text dump.

    :param X: Dask collection holding the features
    :param y: Dask collection holding the labels
    :return: Text representation of the trained model
    """
    rank = xgb.rabit.get_rank()
    print("Start training with worker #{}".format(rank))
    # X and y are dask objects spread over the cluster. create_worker_dmatrix takes the
    # same arguments as the standard xgb.DMatrix() constructor, but first 'unpacks' any
    # dask object into the data held locally by this worker.
    local_dmatrix = xgb.dask.create_worker_dmatrix(X, y)

    # Workers communicate and synchronise while training, so every worker is expected
    # to end up with an identical model.
    booster = xgb.train({}, local_dmatrix)
    # Make predictions on local data
    booster.predict(local_dmatrix)
    print("Finished training with worker #{}".format(rank))
    # Get text representation of the model
    return booster.get_dump()


def train_with_sklearn(X, y):
    """Fit an XGBRegressor on this worker's local data and return its predictions.

    :param X: Dask collection holding the features
    :param y: Dask collection holding the labels
    :return: Predictions for the local rows of X
    """
    rank = xgb.rabit.get_rank()
    print("Training with worker #{} using the sklearn API".format(rank))
    local_features = xgb.dask.get_local_data(X)
    local_labels = xgb.dask.get_local_data(y)
    regressor = xgb.XGBRegressor(n_estimators=10)
    regressor.fit(local_features, local_labels)
    print("Finished training with worker #{} using the sklearn API".format(rank))
    return regressor.predict(local_features)


def main():
    """Run both training variants on a small two-worker local cluster."""
    # A very simple local cluster: two distributed workers with two CPU threads each
    client = Client(LocalCluster(n_workers=2, threads_per_worker=2))

    # Small random test data as dask arrays. The arrays are split into partitions of
    # 20 rows each, giving 5 partitions distributed among the 2 workers.
    # Note that the partition size MUST be consistent across different dask dataframes/arrays
    rows, columns, rows_per_partition = 100, 10, 20
    X = da.random.random((rows, columns), rows_per_partition)
    y = da.random.random(rows, rows_per_partition)

    # xgb.dask.run launches an arbitrary function and its arguments on the cluster:
    # train(X, y) is called on each worker and the call blocks until all work is complete.
    # The result is a dictionary mapping workers to return values.
    dumps_by_worker = xgb.dask.run(client, train, X, y)

    # We expect that the models are the same over all workers
    reference_dump = next(iter(dumps_by_worker.values()))
    assert all(dump == reference_dump for dump in dumps_by_worker.values())

    # We can also train using the sklearn API
    xgb.dask.run(client, train_with_sklearn, X, y)


if __name__ == '__main__':
    main()
autofunction:: xgboost.dask.get_local_data diff --git a/python-package/setup.py b/python-package/setup.py index d9ef113c1..30e0d070f 100644 --- a/python-package/setup.py +++ b/python-package/setup.py @@ -3,8 +3,10 @@ from __future__ import absolute_import import io import sys +import shutil import os from setuptools import setup, find_packages + # import subprocess sys.path.insert(0, '.') @@ -27,6 +29,10 @@ for libfile in libpath['find_lib_path'](): continue print("Install libxgboost from: %s" % LIB_PATH) + +# Get dmlc tracker script +shutil.copy('../dmlc-core/tracker/dmlc_tracker/tracker.py', 'xgboost/') + # Please use setup_pip.py for generating and deploying pip installation # detailed instruction in setup_pip.py setup(name='xgboost', diff --git a/python-package/xgboost/__init__.py b/python-package/xgboost/__init__.py index 217858659..f31d9fbba 100644 --- a/python-package/xgboost/__init__.py +++ b/python-package/xgboost/__init__.py @@ -11,6 +11,7 @@ import os from .core import DMatrix, Booster from .training import train, cv from . import rabit # noqa +from . 
import dask # noqa try: from .sklearn import XGBModel, XGBClassifier, XGBRegressor, XGBRanker from .sklearn import XGBRFClassifier, XGBRFRegressor diff --git a/python-package/xgboost/compat.py b/python-package/xgboost/compat.py index 245f7fd66..21bb4585c 100644 --- a/python-package/xgboost/compat.py +++ b/python-package/xgboost/compat.py @@ -6,49 +6,48 @@ from __future__ import absolute_import import sys - PY3 = (sys.version_info[0] == 3) if PY3: # pylint: disable=invalid-name, redefined-builtin STRING_TYPES = (str,) + def py_str(x): """convert c string back to python string""" return x.decode('utf-8') else: STRING_TYPES = (basestring,) # pylint: disable=undefined-variable + def py_str(x): """convert c string back to python string""" return x try: - import cPickle as pickle # noqa + import cPickle as pickle # noqa except ImportError: - import pickle # noqa - + import pickle # noqa # pandas try: from pandas import DataFrame from pandas import MultiIndex + PANDAS_INSTALLED = True except ImportError: - # pylint: disable=too-few-public-methods - class MultiIndex(object): - """ dummy for pandas.MultiIndex """ - - # pylint: disable=too-few-public-methods - class DataFrame(object): - """ dummy for pandas.DataFrame """ - + MultiIndex = object + DataFrame = object PANDAS_INSTALLED = False # dt try: + # Workaround for #4473, compatibility with dask + if sys.__stdin__.closed: + sys.__stdin__ = None import datatable + if hasattr(datatable, "Frame"): DataTable = datatable.Frame else: @@ -60,6 +59,7 @@ except ImportError: class DataTable(object): """ dummy for datatable.DataTable """ + DT_INSTALLED = False # sklearn @@ -67,6 +67,7 @@ try: from sklearn.base import BaseEstimator from sklearn.base import RegressorMixin, ClassifierMixin from sklearn.preprocessing import LabelEncoder + try: from sklearn.model_selection import KFold, StratifiedKFold except ImportError: @@ -92,3 +93,20 @@ except ImportError: XGBKFold = None XGBStratifiedKFold = None XGBLabelEncoder = None + + +# dask 
def _start_tracker(n_workers):
    """
    Start a Rabit tracker on the dask worker executing this function. The tracker is
    joined from a daemon thread so this call returns immediately.

    :param n_workers: Number of rabit workers the tracker should expect
    :return: Dict of settings (DMLC_NUM_WORKER plus the tracker's slave_envs()) that
        workers need in order to connect
    """
    address = distributed_get_worker().address
    # Drop the scheme, e.g. 'tcp://10.0.0.1:8786' -> '10.0.0.1:8786'
    if '://' in address:
        address = address.rsplit('://', 1)[1]
    # Only the host is handed to the tracker; the dask worker's own port is not used.
    # Splitting once from the right tolerates addresses with additional colons.
    host = address.rsplit(':', 1)[0]
    env = {'DMLC_NUM_WORKER': n_workers}
    rabit_context = RabitTracker(hostIP=host, nslave=n_workers)
    env.update(rabit_context.slave_envs())

    rabit_context.start(n_workers)
    thread = Thread(target=rabit_context.join)
    thread.daemon = True
    thread.start()
    return env


def get_local_data(data):
    """
    Unpacks a distributed data object to get the rows local to this worker.

    Partitions are handed out in contiguous ranges of ceil(partitions / world size)
    partitions per rabit rank; the trailing ranks may receive fewer, or none.

    :param data: A distributed dask data object
    :return: Local data partition e.g. numpy or pandas
    """
    is_array = isinstance(data, DaskArray)
    total_partitions = len(data.chunks[0]) if is_array else data.npartitions
    # float() forces true division. Under Python 2 a plain '/' between two ints floors
    # before math.ceil is applied, which silently drops the trailing partitions.
    partition_size = int(math.ceil(total_partitions / float(rabit.get_world_size())))
    begin_partition = partition_size * rabit.get_rank()
    end_partition = min(begin_partition + partition_size, total_partitions)
    if is_array:
        return data.blocks[begin_partition:end_partition].compute()

    return data.partitions[begin_partition:end_partition].compute()


def _localize(value):
    """ Return the worker-local part of a dask collection; pass anything else through """
    if isinstance(value, (DaskDataFrame, DaskSeries, DaskArray)):
        return get_local_data(value)
    return value


def create_worker_dmatrix(*args, **kwargs):
    """
    Creates a DMatrix object local to a given worker. Simply forwards arguments onto the standard
    DMatrix constructor, if one of the arguments is a dask dataframe, unpack the data frame to
    get the local components.

    All dask dataframe arguments must use the same partitioning.

    :param args: DMatrix constructor args.
    :param kwargs: DMatrix constructor keyword args.
    :return: DMatrix object containing data local to current dask worker
    """
    dmatrix_args = [_localize(arg) for arg in args]
    dmatrix_kwargs = {key: _localize(value) for key, value in kwargs.items()}
    return DMatrix(*dmatrix_args, **dmatrix_kwargs)


def _run_with_rabit(rabit_args, func, *args):
    """
    Run func(*args) on the current dask worker inside a rabit session.

    :param rabit_args: Encoded 'key=value' settings passed to rabit.init
    :param func: Function to execute
    :param args: Arguments to be forwarded to func
    :return: Return value of func
    """
    # Match the OpenMP thread count to the cores dask assigned to this worker
    os.environ["OMP_NUM_THREADS"] = str(distributed_get_worker().ncores)
    try:
        rabit.init(rabit_args)
        result = func(*args)
    finally:
        # Always leave the rabit session, even when func raises
        rabit.finalize()
    return result


def run(client, func, *args):
    """
    Launch arbitrary function on dask workers. Workers are connected by rabit, allowing
    distributed training. The environment variable OMP_NUM_THREADS is defined on each worker
    according to dask - this means that calls to xgb.train() will use the threads allocated by
    dask by default, unless the user overrides the nthread parameter.

    Note: Windows platforms are not officially supported. Contributions are welcome here.

    :param client: Dask client representing the cluster
    :param func: Python function to be executed by each worker. Typically contains xgboost
        training code.
    :param args: Arguments to be forwarded to func
    :return: Dict containing the function return value for each worker
    """
    if platform.system() == 'Windows':
        logging.warning(
            'Windows is not officially supported for dask/xgboost integration. Contributions '
            'welcome.')
    workers = list(client.scheduler_info()['workers'].keys())
    # The tracker is hosted by the first worker; client.run returns a dict keyed by worker
    env = client.run(_start_tracker, len(workers), workers=[workers[0]])
    rabit_args = [('%s=%s' % item).encode() for item in env[workers[0]].items()]
    return client.run(_run_with_rabit, rabit_args, func, *args)
if sys.platform.startswith("win"):
    pytest.skip("Skipping dask tests on Windows", allow_module_level=True)

try:
    # client is requested by the tests below; loop and cluster_fixture are imported
    # alongside it - presumably fixtures that client depends on (confirm against
    # distributed.utils_test).
    from distributed.utils_test import client, loop, cluster_fixture
    import dask.dataframe as dd
    import dask.array as da
except ImportError:
    # Deliberately ignored: pytestmark below skips every test when dask is missing
    pass

pytestmark = pytest.mark.skipif(**tm.no_dask())


def run_train():
    """Executed on each worker: train on a single row whose label is the worker's rank."""
    # Contains one label equal to rank
    dmat = xgb.DMatrix([[0]], label=[xgb.rabit.get_rank()])
    bst = xgb.train({"eta": 1.0, "lambda": 0.0}, dmat, 1)
    pred = bst.predict(dmat)
    # A jointly built model predicts the mean label over all workers
    expected_result = np.average(range(xgb.rabit.get_world_size()))
    assert all(p == expected_result for p in pred)


def test_train(client):
    # Train two workers, the first has label 0, the second has label 1
    # If they build the model together the output should be 0.5
    xgb.dask.run(client, run_train)
    # Run again to check we can have multiple sessions
    xgb.dask.run(client, run_train)


def run_create_dmatrix(X, y, weights):
    """Executed on each worker: build the local DMatrix and check its size."""
    dmat = xgb.dask.create_worker_dmatrix(X, y, weight=weights)
    # Expect this worker to get two partitions and concatenate them
    assert dmat.num_row() == 50


def test_dask_dataframe(client):
    n = 10
    m = 100
    partition_size = 25
    X = dd.from_array(np.random.random((m, n)), partition_size)
    y = dd.from_array(np.random.random(m), partition_size)
    weights = dd.from_array(np.random.random(m), partition_size)
    xgb.dask.run(client, run_create_dmatrix, X, y, weights)


def test_dask_array(client):
    n = 10
    m = 100
    partition_size = 25
    X = da.random.random((m, n), partition_size)
    y = da.random.random(m, partition_size)
    weights = da.random.random(m, partition_size)
    xgb.dask.run(client, run_create_dmatrix, X, y, weights)


def run_get_local_data(X, y):
    """Executed on each worker: check the shape of the locally held data."""
    X_local = xgb.dask.get_local_data(X)
    y_local = xgb.dask.get_local_data(y)
    assert (X_local.shape == (50, 10))
    assert (y_local.shape == (50,))


def test_get_local_data(client):
    n = 10
    m = 100
    partition_size = 25
    X = da.random.random((m, n), partition_size)
    y = da.random.random(m, partition_size)
    xgb.dask.run(client, run_get_local_data, X, y)


def run_sklearn():
    """Executed on each worker: same check as run_train, through the sklearn wrapper."""
    # Contains one label equal to rank
    X = [[0]]
    y = [xgb.rabit.get_rank()]
    model = xgb.XGBRegressor(learning_rate=1.0)
    model.fit(X, y)
    pred = model.predict(X)
    expected_result = np.average(range(xgb.rabit.get_world_size()))
    assert all(p == expected_result for p in pred)
    return pred


# XGBRegressor needs scikit-learn, which the module-level dask skip does not cover
@pytest.mark.skipif(**tm.no_sklearn())
def test_sklearn(client):
    result = xgb.dask.run(client, run_sklearn)
    print(result)
try: - import matplotlib.pyplot as _ # noqa + import matplotlib.pyplot as _ # noqa return {'condition': False, 'reason': reason} except ImportError: