Add native support for Dask (#4473)
* Add native support for Dask
* Add multi-GPU demo
* Add sklearn example
This commit is contained in:
parent 55e645c5f5
commit 09b90d9329
3 .gitignore (vendored)
@@ -96,3 +96,6 @@ plugin/updater_gpu/test/cpp/data
# files from R-package source install
**/config.status
R-package/src/Makevars

# Python install
python-package/xgboost/tracker.py
20 demo/dask/README.md (new file)
@@ -0,0 +1,20 @@
# Dask Integration

[Dask](https://dask.org/) is a parallel computing library built on Python. Dask allows easy management of distributed workers and excels at handling large distributed data science workflows.

The simple demo shows how to train an xgboost model and make predictions in a distributed dask environment. We make use of first-class support in xgboost for launching dask workers. Workers launched in this manner are automatically connected via xgboost's underlying communication framework, Rabit. The calls to `xgb.train()` and `xgb.predict()` occur in parallel on each worker and are synchronized.

The GPU demo shows how to configure and use GPUs on the local machine for training on a large dataset.

## Requirements
Dask is trivial to install using either pip or conda. [See here for the official install documentation](https://docs.dask.org/en/latest/install.html).

The GPU demo requires [GPUtil](https://github.com/anderskm/gputil) for detecting system GPUs.

Install it via `pip install gputil`.

## Running the scripts
```bash
python dask_simple_demo.py
python dask_gpu_demo.py
```
42 demo/dask/dask_gpu_demo.py (new file)
@@ -0,0 +1,42 @@
from dask.distributed import Client, LocalCluster
import dask.dataframe as dd
import dask.array as da
import numpy as np
import xgboost as xgb
import GPUtil
import time


# Define the function to be executed on each worker
def train(X, y, available_devices):
    dtrain = xgb.dask.create_worker_dmatrix(X, y)
    local_device = available_devices[xgb.rabit.get_rank()]
    # Specify the GPU algorithm and device for this worker
    params = {"tree_method": "gpu_hist", "gpu_id": local_device}
    print("Worker {} starting training on {} rows".format(xgb.rabit.get_rank(), dtrain.num_row()))
    start = time.time()
    xgb.train(params, dtrain, num_boost_round=500)
    end = time.time()
    print("Worker {} finished training in {:0.2f}s".format(xgb.rabit.get_rank(), end - start))


def main():
    max_devices = 16
    # Check which devices we have locally
    available_devices = GPUtil.getAvailable(limit=max_devices)
    # Use one worker per device
    cluster = LocalCluster(n_workers=len(available_devices), threads_per_worker=4)
    client = Client(cluster)

    # Set up a relatively large regression problem
    n = 100
    m = 10000000
    partition_size = 100000
    X = da.random.random((m, n), partition_size)
    y = da.random.random(m, partition_size)

    xgb.dask.run(client, train, X, y, available_devices)


if __name__ == '__main__':
    main()
68 demo/dask/dask_simple_demo.py (new file)
@@ -0,0 +1,68 @@
from dask.distributed import Client, LocalCluster
import dask.dataframe as dd
import dask.array as da
import numpy as np
import xgboost as xgb


# Define the function to be executed on each worker
def train(X, y):
    print("Start training with worker #{}".format(xgb.rabit.get_rank()))
    # X, y are dask objects distributed across the cluster.
    # We must obtain the data local to this worker and convert it to a DMatrix for training.
    # xgb.dask.create_worker_dmatrix follows exactly the API of the standard DMatrix constructor
    # (xgb.DMatrix()), except that it 'unpacks' dask distributed objects to obtain the data local
    # to this worker
    dtrain = xgb.dask.create_worker_dmatrix(X, y)

    # Train on the data. Each worker will communicate and synchronise during training. The output
    # model is expected to be identical on each worker.
    bst = xgb.train({}, dtrain)
    # Make predictions on local data
    pred = bst.predict(dtrain)
    print("Finished training with worker #{}".format(xgb.rabit.get_rank()))
    # Get text representation of the model
    return bst.get_dump()


def train_with_sklearn(X, y):
    print("Training with worker #{} using the sklearn API".format(xgb.rabit.get_rank()))
    X_local = xgb.dask.get_local_data(X)
    y_local = xgb.dask.get_local_data(y)
    model = xgb.XGBRegressor(n_estimators=10)
    model.fit(X_local, y_local)
    print("Finished training with worker #{} using the sklearn API".format(xgb.rabit.get_rank()))
    return model.predict(X_local)


def main():
    # Launch a very simple local cluster using two distributed workers with two CPU threads each
    cluster = LocalCluster(n_workers=2, threads_per_worker=2)
    client = Client(cluster)

    # Generate some small test data as dask arrays
    # These arrays are internally split into partitions of 20 rows each and then distributed
    # among the workers, so we will have 5 partitions distributed among 2 workers
    # Note that the partition size MUST be consistent across different dask dataframes/arrays
    n = 10
    m = 100
    partition_size = 20
    X = da.random.random((m, n), partition_size)
    y = da.random.random(m, partition_size)

    # xgb.dask.run launches an arbitrary function and its arguments on the cluster
    # Here train(X, y) will be called on each worker
    # This function blocks until all work is complete
    models = xgb.dask.run(client, train, X, y)

    # models contains a dictionary mapping workers to results
    # We expect the models to be identical on all workers
    first_model = next(iter(models.values()))
    assert all(model == first_model for worker, model in models.items())

    # We can also train using the sklearn API
    results = xgb.dask.run(client, train_with_sklearn, X, y)


if __name__ == '__main__':
    main()
@@ -73,3 +73,13 @@ Callback API
.. autofunction:: xgboost.callback.reset_learning_rate

.. autofunction:: xgboost.callback.early_stop

Dask API
--------
.. automodule:: xgboost.dask

.. autofunction:: xgboost.dask.run

.. autofunction:: xgboost.dask.create_worker_dmatrix

.. autofunction:: xgboost.dask.get_local_data
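Taken together, these three functions cover the workflow documented above. As a rough illustration (condensed from the demo scripts elsewhere in this commit, not itself part of the diff), a typical use looks like:

```python
from dask.distributed import Client, LocalCluster
import dask.array as da
import xgboost as xgb

def train(X, y):
    # Build a DMatrix from the partitions assigned to this worker, then train;
    # workers synchronise through Rabit inside xgb.train().
    dtrain = xgb.dask.create_worker_dmatrix(X, y)
    return xgb.train({}, dtrain).get_dump()

client = Client(LocalCluster(n_workers=2))
X = da.random.random((100, 10), 20)  # 5 partitions of 20 rows
y = da.random.random(100, 20)
# Returns a dict mapping each worker to its return value
models = xgb.dask.run(client, train, X, y)
```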
@@ -3,8 +3,10 @@
from __future__ import absolute_import
import io
import sys
import shutil
import os
from setuptools import setup, find_packages
# import subprocess
sys.path.insert(0, '.')

@@ -27,6 +29,10 @@ for libfile in libpath['find_lib_path']():
        continue

print("Install libxgboost from: %s" % LIB_PATH)

# Get dmlc tracker script
shutil.copy('../dmlc-core/tracker/dmlc_tracker/tracker.py', 'xgboost/')

# Please use setup_pip.py for generating and deploying pip installation
# detailed instruction in setup_pip.py
setup(name='xgboost',
@@ -11,6 +11,7 @@ import os
from .core import DMatrix, Booster
from .training import train, cv
from . import rabit  # noqa
from . import dask  # noqa
try:
    from .sklearn import XGBModel, XGBClassifier, XGBRegressor, XGBRanker
    from .sklearn import XGBRFClassifier, XGBRFRegressor
@@ -6,49 +6,48 @@ from __future__ import absolute_import

import sys


PY3 = (sys.version_info[0] == 3)

if PY3:
    # pylint: disable=invalid-name, redefined-builtin
    STRING_TYPES = (str,)

    def py_str(x):
        """convert c string back to python string"""
        return x.decode('utf-8')
else:
    STRING_TYPES = (basestring,)  # pylint: disable=undefined-variable

    def py_str(x):
        """convert c string back to python string"""
        return x

try:
    import cPickle as pickle  # noqa
except ImportError:
    import pickle  # noqa

# pandas
try:
    from pandas import DataFrame
    from pandas import MultiIndex

    PANDAS_INSTALLED = True
except ImportError:
    # pylint: disable=too-few-public-methods
    class MultiIndex(object):
        """ dummy for pandas.MultiIndex """

    # pylint: disable=too-few-public-methods
    class DataFrame(object):
        """ dummy for pandas.DataFrame """

    MultiIndex = object
    DataFrame = object
    PANDAS_INSTALLED = False

# dt
try:
    # Workaround for #4473, compatibility with dask
    if sys.__stdin__.closed:
        sys.__stdin__ = None
    import datatable

    if hasattr(datatable, "Frame"):
        DataTable = datatable.Frame
    else:
@@ -60,6 +59,7 @@ except ImportError:
    class DataTable(object):
        """ dummy for datatable.DataTable """


    DT_INSTALLED = False

# sklearn
@@ -67,6 +67,7 @@ try:
    from sklearn.base import BaseEstimator
    from sklearn.base import RegressorMixin, ClassifierMixin
    from sklearn.preprocessing import LabelEncoder

    try:
        from sklearn.model_selection import KFold, StratifiedKFold
    except ImportError:
@@ -92,3 +93,20 @@ except ImportError:
    XGBKFold = None
    XGBStratifiedKFold = None
    XGBLabelEncoder = None


# dask
try:
    from dask.dataframe import DataFrame as DaskDataFrame
    from dask.dataframe import Series as DaskSeries
    from dask.array import Array as DaskArray
    from distributed import get_worker as distributed_get_worker

    DASK_INSTALLED = True
except ImportError:
    DaskDataFrame = object
    DaskSeries = object
    DaskArray = object
    distributed_get_worker = None

    DASK_INSTALLED = False
123 python-package/xgboost/dask.py (new file)
@@ -0,0 +1,123 @@
# pylint: disable=wrong-import-position,wrong-import-order,import-error
"""Dask extensions for distributed training. See xgboost/demo/dask for examples."""
import os
import sys
import math
import platform
import logging
from threading import Thread
from . import rabit
from .core import DMatrix
from .compat import (DaskDataFrame, DaskSeries, DaskArray,
                     distributed_get_worker)

# Try to find the dmlc tracker script
# For developers it will be in the following location
TRACKER_PATH = os.path.dirname(__file__) + "/../../dmlc-core/tracker/dmlc_tracker"
sys.path.append(TRACKER_PATH)
try:
    from tracker import RabitTracker  # noqa
except ImportError:
    # If packaged it will be local
    from .tracker import RabitTracker  # noqa


def _start_tracker(n_workers):
    """ Start Rabit tracker """
    host = distributed_get_worker().address
    if '://' in host:
        host = host.rsplit('://', 1)[1]
    host, port = host.split(':')
    port = int(port)
    env = {'DMLC_NUM_WORKER': n_workers}
    rabit_context = RabitTracker(hostIP=host, nslave=n_workers)
    env.update(rabit_context.slave_envs())

    rabit_context.start(n_workers)
    thread = Thread(target=rabit_context.join)
    thread.daemon = True
    thread.start()
    return env


def get_local_data(data):
    """
    Unpacks a distributed data object to get the rows local to this worker

    :param data: A distributed dask data object
    :return: Local data partition e.g. numpy or pandas
    """
    if isinstance(data, DaskArray):
        total_partitions = len(data.chunks[0])
    else:
        total_partitions = data.npartitions
    partition_size = int(math.ceil(total_partitions / rabit.get_world_size()))
    begin_partition = partition_size * rabit.get_rank()
    end_partition = min(begin_partition + partition_size, total_partitions)
    if isinstance(data, DaskArray):
        return data.blocks[begin_partition:end_partition].compute()

    return data.partitions[begin_partition:end_partition].compute()


def create_worker_dmatrix(*args, **kwargs):
    """
    Creates a DMatrix object local to a given worker. It simply forwards its arguments to the
    standard DMatrix constructor; if one of the arguments is a dask object, it is unpacked to
    obtain the components local to this worker.

    All dask dataframe arguments must use the same partitioning.

    :param args: DMatrix constructor args.
    :return: DMatrix object containing data local to the current dask worker
    """
    dmatrix_args = []
    dmatrix_kwargs = {}
    # Convert positional args
    for arg in args:
        if isinstance(arg, (DaskDataFrame, DaskSeries, DaskArray)):
            dmatrix_args.append(get_local_data(arg))
        else:
            dmatrix_args.append(arg)

    # Convert keyword args
    for k, v in kwargs.items():
        if isinstance(v, (DaskDataFrame, DaskSeries, DaskArray)):
            dmatrix_kwargs[k] = get_local_data(v)
        else:
            dmatrix_kwargs[k] = v

    return DMatrix(*dmatrix_args, **dmatrix_kwargs)


def _run_with_rabit(rabit_args, func, *args):
    os.environ["OMP_NUM_THREADS"] = str(distributed_get_worker().ncores)
    try:
        rabit.init(rabit_args)
        result = func(*args)
    finally:
        rabit.finalize()
    return result


def run(client, func, *args):
    """
    Launch an arbitrary function on the dask workers. The workers are connected by Rabit,
    allowing distributed training. The environment variable OMP_NUM_THREADS is set on each
    worker according to dask - this means that calls to xgb.train() will use the threads
    allocated by dask by default, unless the user overrides the nthread parameter.

    Note: Windows platforms are not officially supported. Contributions are welcome here.

    :param client: Dask client representing the cluster
    :param func: Python function to be executed by each worker. Typically contains xgboost
        training code.
    :param args: Arguments to be forwarded to func
    :return: Dict containing the function return value for each worker
    """
    if platform.system() == 'Windows':
        logging.warning(
            'Windows is not officially supported for dask/xgboost integration. Contributions '
            'welcome.')
    workers = list(client.scheduler_info()['workers'].keys())
    env = client.run(_start_tracker, len(workers), workers=[workers[0]])
    rabit_args = [('%s=%s' % item).encode() for item in env[workers[0]].items()]
    return client.run(_run_with_rabit, rabit_args, func, *args)
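For intuition on how get_local_data() assigns partitions, here is a small worked example (illustrative only, mirroring the arithmetic in the function above, not part of the diff): with 5 partitions and 2 workers, each worker claims ceil(5 / 2) = 3 partitions, so rank 0 receives partitions [0, 3) and rank 1 receives partitions [3, 5).

```python
import math

def partition_range(rank, world_size, total_partitions):
    # Same arithmetic as xgboost.dask.get_local_data: contiguous blocks of
    # ceil(total / world_size) partitions per worker, clamped to the total.
    per_worker = int(math.ceil(total_partitions / float(world_size)))
    begin = per_worker * rank
    end = min(begin + per_worker, total_partitions)
    return begin, end

assert partition_range(0, 2, 5) == (0, 3)
assert partition_range(1, 2, 5) == (3, 5)
```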
@@ -23,7 +23,8 @@ ENV GOSU_VERSION 1.10
RUN \
    pip install pyyaml cpplint pylint astroid sphinx numpy scipy pandas matplotlib sh recommonmark guzzle_sphinx_theme mock \
                breathe matplotlib graphviz pytest scikit-learn wheel kubernetes urllib3 && \
    pip install https://h2o-release.s3.amazonaws.com/datatable/stable/datatable-0.7.0/datatable-0.7.0-cp37-cp37m-linux_x86_64.whl
    pip install https://h2o-release.s3.amazonaws.com/datatable/stable/datatable-0.7.0/datatable-0.7.0-cp37-cp37m-linux_x86_64.whl && \
    conda install dask

# Install lightweight sudo (not bound to TTY)
RUN set -ex; \

@@ -16,7 +16,8 @@ ENV PATH=/opt/python/bin:$PATH

# Install Python packages
RUN \
    pip install numpy pytest scipy scikit-learn pandas matplotlib wheel kubernetes urllib3 graphviz
    pip install numpy pytest scipy scikit-learn pandas matplotlib wheel kubernetes urllib3 graphviz && \
    conda install dask

ENV GOSU_VERSION 1.10
93 tests/python/test_with_dask.py (new file)
@@ -0,0 +1,93 @@
import testing as tm
import pytest
import xgboost as xgb
import numpy as np
import sys

if sys.platform.startswith("win"):
    pytest.skip("Skipping dask tests on Windows", allow_module_level=True)

try:
    from distributed.utils_test import client, loop, cluster_fixture
    import dask.dataframe as dd
    import dask.array as da
except ImportError:
    pass

pytestmark = pytest.mark.skipif(**tm.no_dask())


def run_train():
    # Contains one label equal to rank
    dmat = xgb.DMatrix([[0]], label=[xgb.rabit.get_rank()])
    bst = xgb.train({"eta": 1.0, "lambda": 0.0}, dmat, 1)
    pred = bst.predict(dmat)
    expected_result = np.average(range(xgb.rabit.get_world_size()))
    assert all(p == expected_result for p in pred)


def test_train(client):
    # Train with two workers, where the first has label 0 and the second has label 1
    # If they build the model together the output should be 0.5
    xgb.dask.run(client, run_train)
    # Run again to check we can have multiple sessions
    xgb.dask.run(client, run_train)


def run_create_dmatrix(X, y, weights):
    dmat = xgb.dask.create_worker_dmatrix(X, y, weight=weights)
    # Expect this worker to get two partitions and concatenate them
    assert dmat.num_row() == 50


def test_dask_dataframe(client):
    n = 10
    m = 100
    partition_size = 25
    X = dd.from_array(np.random.random((m, n)), partition_size)
    y = dd.from_array(np.random.random(m), partition_size)
    weights = dd.from_array(np.random.random(m), partition_size)
    xgb.dask.run(client, run_create_dmatrix, X, y, weights)


def test_dask_array(client):
    n = 10
    m = 100
    partition_size = 25
    X = da.random.random((m, n), partition_size)
    y = da.random.random(m, partition_size)
    weights = da.random.random(m, partition_size)
    xgb.dask.run(client, run_create_dmatrix, X, y, weights)


def run_get_local_data(X, y):
    X_local = xgb.dask.get_local_data(X)
    y_local = xgb.dask.get_local_data(y)
    assert X_local.shape == (50, 10)
    assert y_local.shape == (50,)


def test_get_local_data(client):
    n = 10
    m = 100
    partition_size = 25
    X = da.random.random((m, n), partition_size)
    y = da.random.random(m, partition_size)
    xgb.dask.run(client, run_get_local_data, X, y)


def run_sklearn():
    # Contains one label equal to rank
    X = [[0]]
    y = [xgb.rabit.get_rank()]
    model = xgb.XGBRegressor(learning_rate=1.0)
    model.fit(X, y)
    pred = model.predict(X)
    expected_result = np.average(range(xgb.rabit.get_world_size()))
    assert all(p == expected_result for p in pred)
    return pred


def test_sklearn(client):
    result = xgb.dask.run(client, run_sklearn)
    print(result)
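A quick sanity check of the 50-row expectation in run_create_dmatrix and run_get_local_data (assuming the shared client fixture provides the usual two-worker test cluster, as distributed.utils_test does by default): 100 rows split into partitions of 25 gives 4 partitions, and with 2 workers each rank takes ceil(4 / 2) = 2 contiguous partitions, i.e. 2 x 25 = 50 rows per worker.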
@@ -1,5 +1,5 @@
# coding: utf-8
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED, DT_INSTALLED
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED, DT_INSTALLED, DASK_INSTALLED


def no_sklearn():

@@ -7,6 +7,11 @@ def no_sklearn():
            'reason': 'Scikit-Learn is not installed'}


def no_dask():
    return {'condition': not DASK_INSTALLED,
            'reason': 'Dask is not installed'}


def no_pandas():
    return {'condition': not PANDAS_INSTALLED,
            'reason': 'Pandas is not installed.'}

@@ -20,7 +25,7 @@ def no_dt():
def no_matplotlib():
    reason = 'Matplotlib is not installed.'
    try:
        import matplotlib.pyplot as _  # noqa
        return {'condition': False,
                'reason': reason}
    except ImportError:
||||
Loading…
x
Reference in New Issue
Block a user