Add native support for Dask (#4473)

* Add native support for Dask

* Add multi-GPU demo

* Add sklearn example
Rory Mitchell 2019-05-27 13:29:28 +12:00 committed by GitHub
parent 55e645c5f5
commit 09b90d9329
13 changed files with 407 additions and 16 deletions

.gitignore

@@ -96,3 +96,6 @@ plugin/updater_gpu/test/cpp/data
# files from R-package source install
**/config.status
R-package/src/Makevars
# Python install
python-package/xgboost/tracker.py

demo/dask/README.md

@@ -0,0 +1,20 @@
# Dask Integration
[Dask](https://dask.org/) is a parallel computing library built on Python. Dask allows easy management of distributed workers and excels at handling large distributed data science workflows.
The simple demo shows how to train and make predictions with an xgboost model in a distributed dask environment. We make use of xgboost's first-class support for launching dask workers. Workers launched in this manner are automatically connected via xgboost's underlying communication framework, Rabit. The calls to `xgb.train()` and `xgb.predict()` run in parallel on each worker and are synchronized.
The GPU demo shows how to configure and use GPUs on the local machine for training on a large dataset.
## Requirements
Dask is trivial to install using either pip or conda. [See here for official install documentation](https://docs.dask.org/en/latest/install.html).
The GPU demo requires [GPUtil](https://github.com/anderskm/gputil) for detecting system GPUs.
Install via `pip install gputil`
## Running the scripts
```bash
python dask_simple_demo.py
python dask_gpu_demo.py
```
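For orientation, the core pattern shared by both scripts looks roughly like the following. This is only a minimal sketch assuming a local cluster and random data; see `dask_simple_demo.py` in this commit for the complete version.

```python
from dask.distributed import Client, LocalCluster
import dask.array as da
import xgboost as xgb

def train(X, y):
    # Runs once per worker; X and y are the distributed dask collections.
    dtrain = xgb.dask.create_worker_dmatrix(X, y)  # uses only this worker's partitions
    bst = xgb.train({}, dtrain)
    return bst.get_dump()

if __name__ == '__main__':
    client = Client(LocalCluster(n_workers=2, threads_per_worker=2))
    X = da.random.random((100, 10), chunks=20)
    y = da.random.random(100, chunks=20)
    models = xgb.dask.run(client, train, X, y)  # dict of per-worker results
```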

demo/dask/dask_gpu_demo.py

@@ -0,0 +1,42 @@
from dask.distributed import Client, LocalCluster
import dask.dataframe as dd
import dask.array as da
import numpy as np
import xgboost as xgb
import GPUtil
import time
# Define the function to be executed on each worker
def train(X, y, available_devices):
dtrain = xgb.dask.create_worker_dmatrix(X, y)
local_device = available_devices[xgb.rabit.get_rank()]
# Specify the GPU algorithm and device for this worker
params = {"tree_method": "gpu_hist", "gpu_id": local_device}
print("Worker {} starting training on {} rows".format(xgb.rabit.get_rank(), dtrain.num_row()))
start = time.time()
xgb.train(params, dtrain, num_boost_round=500)
end = time.time()
print("Worker {} finished training in {:0.2f}s".format(xgb.rabit.get_rank(), end - start))
def main():
max_devices = 16
# Check which devices we have locally
available_devices = GPUtil.getAvailable(limit=max_devices)
# Use one worker per device
cluster = LocalCluster(n_workers=len(available_devices), threads_per_worker=4)
client = Client(cluster)
# Set up a relatively large regression problem
n = 100
m = 10000000
partition_size = 100000
X = da.random.random((m, n), partition_size)
y = da.random.random(m, partition_size)
xgb.dask.run(client, train, X, y, available_devices)
if __name__ == '__main__':
main()

demo/dask/dask_simple_demo.py

@@ -0,0 +1,68 @@
from dask.distributed import Client, LocalCluster
import dask.dataframe as dd
import dask.array as da
import numpy as np
import xgboost as xgb
# Define the function to be executed on each worker
def train(X, y):
print("Start training with worker #{}".format(xgb.rabit.get_rank()))
# X,y are dask objects distributed across the cluster.
# We must obtain the data local to this worker and convert it to DMatrix for training.
# xgb.dask.create_worker_dmatrix follows exactly the API of the standard DMatrix constructor
# (xgb.DMatrix()), except that it 'unpacks' dask distributed objects to obtain data local to
# this worker
dtrain = xgb.dask.create_worker_dmatrix(X, y)
# Train on the data. Each worker will communicate and synchronise during training. The output
# model is expected to be identical on each worker.
bst = xgb.train({}, dtrain)
# Make predictions on local data
pred = bst.predict(dtrain)
print("Finished training with worker #{}".format(xgb.rabit.get_rank()))
# Get text representation of the model
return bst.get_dump()
def train_with_sklearn(X, y):
print("Training with worker #{} using the sklearn API".format(xgb.rabit.get_rank()))
X_local = xgb.dask.get_local_data(X)
y_local = xgb.dask.get_local_data(y)
model = xgb.XGBRegressor(n_estimators=10)
model.fit(X_local, y_local)
print("Finished training with worker #{} using the sklearn API".format(xgb.rabit.get_rank()))
return model.predict(X_local)
def main():
# Launch a very simple local cluster using two distributed workers with two CPU threads each
cluster = LocalCluster(n_workers=2, threads_per_worker=2)
client = Client(cluster)
# Generate some small test data as a dask array
# These arrays are internally split into partitions (chunks) of 20 rows each and then distributed
# among workers, so we will have 5 partitions distributed among the 2 workers
# Note that the partition size MUST be consistent across different dask dataframes/arrays
n = 10
m = 100
partition_size = 20
X = da.random.random((m, n), partition_size)
y = da.random.random(m, partition_size)
# xgb.dask.run launches an arbitrary function and its arguments on the cluster
# Here train(X, y) will be called on each worker
# This function blocks until all work is complete
models = xgb.dask.run(client, train, X, y)
# models contains a dictionary mapping workers to results
# We expect that the models are the same over all workers
first_model = next(iter(models.values()))
assert all(model == first_model for worker, model in models.items())
# We can also train using the sklearn API
results = xgb.dask.run(client, train_with_sklearn, X, y)
if __name__ == '__main__':
main()

doc/python/python_api.rst

@@ -73,3 +73,13 @@ Callback API
.. autofunction:: xgboost.callback.reset_learning_rate
.. autofunction:: xgboost.callback.early_stop
Dask API
--------
.. automodule:: xgboost.dask
.. autofunction:: xgboost.dask.run
.. autofunction:: xgboost.dask.create_worker_dmatrix
.. autofunction:: xgboost.dask.get_local_data
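As a rough illustration of how these entry points fit together, the sketch below mirrors the `train_with_sklearn` example from the demo in this commit; the surrounding `client`, `X` and `y` objects are assumed to already exist.

```python
import xgboost as xgb

def fit_locally(X, y):
    # Pull only this worker's partitions out of the distributed collections,
    # then train with the regular sklearn-style estimator on the local data.
    X_local = xgb.dask.get_local_data(X)
    y_local = xgb.dask.get_local_data(y)
    model = xgb.XGBRegressor(n_estimators=10)
    model.fit(X_local, y_local)
    return model.predict(X_local)

# results = xgb.dask.run(client, fit_locally, X, y)
```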

python-package/setup.py

@@ -3,8 +3,10 @@
from __future__ import absolute_import
import io
import sys
import shutil
import os
from setuptools import setup, find_packages
# import subprocess
sys.path.insert(0, '.')
@@ -27,6 +29,10 @@ for libfile in libpath['find_lib_path']():
continue
print("Install libxgboost from: %s" % LIB_PATH)
# Get dmlc tracker script
shutil.copy('../dmlc-core/tracker/dmlc_tracker/tracker.py', 'xgboost/')
# Please use setup_pip.py for generating and deploying pip installation
# detailed instruction in setup_pip.py
setup(name='xgboost',

python-package/xgboost/__init__.py

@@ -11,6 +11,7 @@ import os
from .core import DMatrix, Booster
from .training import train, cv
from . import rabit # noqa
from . import dask # noqa
try:
from .sklearn import XGBModel, XGBClassifier, XGBRegressor, XGBRanker
from .sklearn import XGBRFClassifier, XGBRFRegressor

python-package/xgboost/compat.py

@@ -6,19 +6,20 @@ from __future__ import absolute_import
import sys
PY3 = (sys.version_info[0] == 3)
if PY3:
# pylint: disable=invalid-name, redefined-builtin
STRING_TYPES = (str,)
def py_str(x):
"""convert c string back to python string"""
return x.decode('utf-8')
else:
STRING_TYPES = (basestring,) # pylint: disable=undefined-variable
def py_str(x):
"""convert c string back to python string"""
return x
@@ -28,27 +29,25 @@ try:
except ImportError:
import pickle # noqa
# pandas
try:
from pandas import DataFrame
from pandas import MultiIndex
PANDAS_INSTALLED = True
except ImportError:
# pylint: disable=too-few-public-methods
class MultiIndex(object):
""" dummy for pandas.MultiIndex """
# pylint: disable=too-few-public-methods
class DataFrame(object):
""" dummy for pandas.DataFrame """
MultiIndex = object
DataFrame = object
PANDAS_INSTALLED = False
# dt
try:
# Workaround for #4473, compatibility with dask
if sys.__stdin__.closed:
sys.__stdin__ = None
import datatable
if hasattr(datatable, "Frame"):
DataTable = datatable.Frame
else:
@@ -60,6 +59,7 @@ except ImportError:
class DataTable(object):
""" dummy for datatable.DataTable """
DT_INSTALLED = False
# sklearn
@@ -67,6 +67,7 @@ try:
from sklearn.base import BaseEstimator
from sklearn.base import RegressorMixin, ClassifierMixin
from sklearn.preprocessing import LabelEncoder
try:
from sklearn.model_selection import KFold, StratifiedKFold
except ImportError:
@@ -92,3 +93,20 @@ except ImportError:
XGBKFold = None
XGBStratifiedKFold = None
XGBLabelEncoder = None
# dask
try:
from dask.dataframe import DataFrame as DaskDataFrame
from dask.dataframe import Series as DaskSeries
from dask.array import Array as DaskArray
from distributed import get_worker as distributed_get_worker
DASK_INSTALLED = True
except ImportError:
DaskDataFrame = object
DaskSeries = object
DaskArray = object
distributed_get_worker = None
DASK_INSTALLED = False

python-package/xgboost/dask.py

@@ -0,0 +1,123 @@
# pylint: disable=wrong-import-position,wrong-import-order,import-error
"""Dask extensions for distributed training. See xgboost/demo/dask for examples."""
import os
import sys
import math
import platform
import logging
from threading import Thread
from . import rabit
from .core import DMatrix
from .compat import (DaskDataFrame, DaskSeries, DaskArray,
distributed_get_worker)
# Try to find the dmlc tracker script
# For developers it will be the following
TRACKER_PATH = os.path.dirname(__file__) + "/../../dmlc-core/tracker/dmlc_tracker"
sys.path.append(TRACKER_PATH)
try:
from tracker import RabitTracker # noqa
except ImportError:
# If packaged it will be local
from .tracker import RabitTracker # noqa
def _start_tracker(n_workers):
""" Start Rabit tracker """
host = distributed_get_worker().address
if '://' in host:
host = host.rsplit('://', 1)[1]
host, port = host.split(':')
port = int(port)
env = {'DMLC_NUM_WORKER': n_workers}
rabit_context = RabitTracker(hostIP=host, nslave=n_workers)
env.update(rabit_context.slave_envs())
rabit_context.start(n_workers)
thread = Thread(target=rabit_context.join)
thread.daemon = True
thread.start()
return env
def get_local_data(data):
"""
Unpacks a distributed data object to get the rows local to this worker
:param data: A distributed dask data object
:return: Local data partition e.g. numpy or pandas
"""
if isinstance(data, DaskArray):
total_partitions = len(data.chunks[0])
else:
total_partitions = data.npartitions
partition_size = int(math.ceil(total_partitions / rabit.get_world_size()))
begin_partition = partition_size * rabit.get_rank()
end_partition = min(begin_partition + partition_size, total_partitions)
if isinstance(data, DaskArray):
return data.blocks[begin_partition:end_partition].compute()
return data.partitions[begin_partition:end_partition].compute()
def create_worker_dmatrix(*args, **kwargs):
"""
Creates a DMatrix object local to a given worker. Simply forwards arguments onto the standard
DMatrix constructor; if one of the arguments is a dask dataframe or array, it is unpacked to
obtain the components local to this worker.
All dask dataframe arguments must use the same partitioning.
:param args: DMatrix constructor args.
:return: DMatrix object containing data local to current dask worker
"""
dmatrix_args = []
dmatrix_kwargs = {}
# Convert positional args
for arg in args:
if isinstance(arg, (DaskDataFrame, DaskSeries, DaskArray)):
dmatrix_args.append(get_local_data(arg))
else:
dmatrix_args.append(arg)
# Convert keyword args
for k, v in kwargs.items():
if isinstance(v, (DaskDataFrame, DaskSeries, DaskArray)):
dmatrix_kwargs[k] = get_local_data(v)
else:
dmatrix_kwargs[k] = v
return DMatrix(*dmatrix_args, **dmatrix_kwargs)
def _run_with_rabit(rabit_args, func, *args):
os.environ["OMP_NUM_THREADS"] = str(distributed_get_worker().ncores)
try:
rabit.init(rabit_args)
result = func(*args)
finally:
rabit.finalize()
return result
def run(client, func, *args):
"""
Launch arbitrary function on dask workers. Workers are connected by rabit, allowing
distributed training. The environment variable OMP_NUM_THREADS is set on each worker to the
number of cores that dask has allocated to it - this means that calls to xgb.train() will use
the threads allocated by dask by default, unless the user overrides the nthread parameter.
Note: Windows platforms are not officially supported. Contributions are welcome here.
:param client: Dask client representing the cluster
:param func: Python function to be executed by each worker. Typically contains xgboost
training code.
:param args: Arguments to be forwarded to func
:return: Dict containing the function return value for each worker
"""
if platform.system() == 'Windows':
logging.warning(
'Windows is not officially supported for dask/xgboost integration. Contributions '
'welcome.')
workers = list(client.scheduler_info()['workers'].keys())
env = client.run(_start_tracker, len(workers), workers=[workers[0]])
rabit_args = [('%s=%s' % item).encode() for item in env[workers[0]].items()]
return client.run(_run_with_rabit, rabit_args, func, *args)
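As the docstring above notes, training uses the threads dask allocates to each worker unless the caller overrides `nthread`. A hedged sketch of what such an override looks like from a worker function; the `client`, `X` and `y` objects and the parameter values here are illustrative only.

```python
import xgboost as xgb

def train_capped(X, y):
    dtrain = xgb.dask.create_worker_dmatrix(X, y)
    # An explicit nthread takes precedence over the OMP_NUM_THREADS value
    # that xgb.dask.run sets from the dask worker's core allocation.
    params = {"max_depth": 3, "nthread": 2}
    return xgb.train(params, dtrain, num_boost_round=10)

# bst_dumps = xgb.dask.run(client, train_capped, X, y)
```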


@@ -23,7 +23,8 @@ ENV GOSU_VERSION 1.10
RUN \
pip install pyyaml cpplint pylint astroid sphinx numpy scipy pandas matplotlib sh recommonmark guzzle_sphinx_theme mock \
breathe matplotlib graphviz pytest scikit-learn wheel kubernetes urllib3 && \
pip install https://h2o-release.s3.amazonaws.com/datatable/stable/datatable-0.7.0/datatable-0.7.0-cp37-cp37m-linux_x86_64.whl
pip install https://h2o-release.s3.amazonaws.com/datatable/stable/datatable-0.7.0/datatable-0.7.0-cp37-cp37m-linux_x86_64.whl && \
conda install dask
# Install lightweight sudo (not bound to TTY)
RUN set -ex; \


@@ -16,7 +16,8 @@ ENV PATH=/opt/python/bin:$PATH
# Install Python packages
RUN \
pip install numpy pytest scipy scikit-learn pandas matplotlib wheel kubernetes urllib3 graphviz
pip install numpy pytest scipy scikit-learn pandas matplotlib wheel kubernetes urllib3 graphviz && \
conda install dask
ENV GOSU_VERSION 1.10


@@ -0,0 +1,93 @@
import testing as tm
import pytest
import xgboost as xgb
import numpy as np
import sys
if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
try:
from distributed.utils_test import client, loop, cluster_fixture
import dask.dataframe as dd
import dask.array as da
except ImportError:
pass
pytestmark = pytest.mark.skipif(**tm.no_dask())
def run_train():
# Contains one label equal to rank
dmat = xgb.DMatrix([[0]], label=[xgb.rabit.get_rank()])
bst = xgb.train({"eta": 1.0, "lambda": 0.0}, dmat, 1)
pred = bst.predict(dmat)
expected_result = np.average(range(xgb.rabit.get_world_size()))
assert all(p == expected_result for p in pred)
def test_train(client):
# Train two workers, the first has label 0, the second has label 1
# If they build the model together the output should be 0.5
xgb.dask.run(client, run_train)
# Run again to check we can have multiple sessions
xgb.dask.run(client, run_train)
def run_create_dmatrix(X, y, weights):
dmat = xgb.dask.create_worker_dmatrix(X, y, weight=weights)
# Expect this worker to get two partitions and concatenate them
assert dmat.num_row() == 50
def test_dask_dataframe(client):
n = 10
m = 100
partition_size = 25
X = dd.from_array(np.random.random((m, n)), partition_size)
y = dd.from_array(np.random.random(m), partition_size)
weights = dd.from_array(np.random.random(m), partition_size)
xgb.dask.run(client, run_create_dmatrix, X, y, weights)
def test_dask_array(client):
n = 10
m = 100
partition_size = 25
X = da.random.random((m, n), partition_size)
y = da.random.random(m, partition_size)
weights = da.random.random(m, partition_size)
xgb.dask.run(client, run_create_dmatrix, X, y, weights)
def run_get_local_data(X, y):
X_local = xgb.dask.get_local_data(X)
y_local = xgb.dask.get_local_data(y)
assert (X_local.shape == (50, 10))
assert (y_local.shape == (50,))
def test_get_local_data(client):
n = 10
m = 100
partition_size = 25
X = da.random.random((m, n), partition_size)
y = da.random.random(m, partition_size)
xgb.dask.run(client, run_get_local_data, X, y)
def run_sklearn():
# Contains one label equal to rank
X = [[0]]
y = [xgb.rabit.get_rank()]
model = xgb.XGBRegressor(learning_rate=1.0)
model.fit(X, y)
pred = model.predict(X)
expected_result = np.average(range(xgb.rabit.get_world_size()))
assert all(p == expected_result for p in pred)
return pred
def test_sklearn(client):
result = xgb.dask.run(client, run_sklearn)
print(result)


@@ -1,5 +1,5 @@
# coding: utf-8
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED, DT_INSTALLED
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED, DT_INSTALLED, DASK_INSTALLED
def no_sklearn():
@@ -7,6 +7,11 @@ def no_sklearn():
'reason': 'Scikit-Learn is not installed'}
def no_dask():
return {'condition': not DASK_INSTALLED,
'reason': 'Dask is not installed'}
def no_pandas():
return {'condition': not PANDAS_INSTALLED,
'reason': 'Pandas is not installed.'}