Rewrite Dask interface. (#4819)
This commit is contained in:
parent
562bb0ae31
commit
b8433c455a
@ -1,20 +0,0 @@
|
|||||||
# Dask Integration
|
|
||||||
|
|
||||||
[Dask](https://dask.org/) is a parallel computing library built on Python. Dask allows easy management of distributed workers and excels at handling large distributed data science workflows.
|
|
||||||
|
|
||||||
The simple demo shows how to train and make predictions for an xgboost model on a distributed dask environment. We make use of first-class support in xgboost for launching dask workers. Workers launched in this manner are automatically connected via xgboost's underlying communication framework, Rabit. The calls to `xgb.train()` and `xgb.predict()` occur in parallel on each worker and are synchronized.
|
|
||||||
|
|
||||||
The GPU demo shows how to configure and use GPUs on the local machine for training on a large dataset.
|
|
||||||
|
|
||||||
## Requirements
|
|
||||||
Dask is trivial to install using either pip or conda. [See here for official install documentation](https://docs.dask.org/en/latest/install.html).
|
|
||||||
|
|
||||||
The GPU demo requires [GPUtil](https://github.com/anderskm/gputil) for detecting system GPUs.
|
|
||||||
|
|
||||||
Install via `pip install gputil`
|
|
||||||
|
|
||||||
## Running the scripts
|
|
||||||
```bash
|
|
||||||
python dask_simple_demo.py
|
|
||||||
python dask_gpu_demo.py
|
|
||||||
```
|
|
||||||
35
demo/dask/cpu_training.py
Normal file
35
demo/dask/cpu_training.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
import xgboost as xgb
|
||||||
|
from xgboost.dask import DaskDMatrix
|
||||||
|
from dask.distributed import Client
|
||||||
|
from dask.distributed import LocalCluster
|
||||||
|
from dask import array as da
|
||||||
|
|
||||||
|
|
||||||
|
def main(client):
    '''Train a CPU ``hist`` model on random dask arrays.

    Parameters
    ----------
    client:
        Dask client handed to ``DaskDMatrix``, ``xgb.dask.train`` and
        ``xgb.dask.predict``.

    Returns
    -------
    The value returned by ``xgb.dask.predict`` for the training matrix.
    '''
    n_cols = 100
    n_rows = 100000
    chunk_size = 1000
    X = da.random.random((n_rows, n_cols), chunk_size)
    y = da.random.random(n_rows, chunk_size)

    dtrain = DaskDMatrix(client, X, y)

    params = {'verbosity': 2,
              'nthread': 1,
              'tree_method': 'hist'}
    # The training matrix doubles as the evaluation set named 'train'.
    output = xgb.dask.train(client,
                            params,
                            dtrain,
                            num_boost_round=4,
                            evals=[(dtrain, 'train')])
    bst = output['booster']
    history = output['history']

    prediction = xgb.dask.predict(client, bst, dtrain)
    print('Evaluation history:', history)
    return prediction
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# or use any other clusters
|
||||||
|
cluster = LocalCluster(n_workers=4, threads_per_worker=1)
|
||||||
|
client = Client(cluster)
|
||||||
|
main(client)
|
||||||
@ -1,42 +0,0 @@
|
|||||||
from dask.distributed import Client, LocalCluster
|
|
||||||
import dask.dataframe as dd
|
|
||||||
import dask.array as da
|
|
||||||
import numpy as np
|
|
||||||
import xgboost as xgb
|
|
||||||
import GPUtil
|
|
||||||
import time
|
|
||||||
|
|
||||||
|
|
||||||
# Define the function to be executed on each worker
|
|
||||||
def train(X, y, available_devices):
|
|
||||||
dtrain = xgb.dask.create_worker_dmatrix(X, y)
|
|
||||||
local_device = available_devices[xgb.rabit.get_rank()]
|
|
||||||
# Specify the GPU algorithm and device for this worker
|
|
||||||
params = {"tree_method": "gpu_hist", "gpu_id": local_device}
|
|
||||||
print("Worker {} starting training on {} rows".format(xgb.rabit.get_rank(), dtrain.num_row()))
|
|
||||||
start = time.time()
|
|
||||||
xgb.train(params, dtrain, num_boost_round=500)
|
|
||||||
end = time.time()
|
|
||||||
print("Worker {} finished training in {:0.2f}s".format(xgb.rabit.get_rank(), end - start))
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
max_devices = 16
|
|
||||||
# Check which devices we have locally
|
|
||||||
available_devices = GPUtil.getAvailable(limit=max_devices)
|
|
||||||
# Use one worker per device
|
|
||||||
cluster = LocalCluster(n_workers=len(available_devices), threads_per_worker=4)
|
|
||||||
client = Client(cluster)
|
|
||||||
|
|
||||||
# Set up a relatively large regression problem
|
|
||||||
n = 100
|
|
||||||
m = 10000000
|
|
||||||
partition_size = 100000
|
|
||||||
X = da.random.random((m, n), partition_size)
|
|
||||||
y = da.random.random(m, partition_size)
|
|
||||||
|
|
||||||
xgb.dask.run(client, train, X, y, available_devices)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
||||||
@ -1,68 +0,0 @@
|
|||||||
from dask.distributed import Client, LocalCluster
|
|
||||||
import dask.dataframe as dd
|
|
||||||
import dask.array as da
|
|
||||||
import numpy as np
|
|
||||||
import xgboost as xgb
|
|
||||||
|
|
||||||
|
|
||||||
# Define the function to be executed on each worker
|
|
||||||
def train(X, y):
|
|
||||||
print("Start training with worker #{}".format(xgb.rabit.get_rank()))
|
|
||||||
# X,y are dask objects distributed across the cluster.
|
|
||||||
# We must obtain the data local to this worker and convert it to DMatrix for training.
|
|
||||||
# xgb.dask.create_worker_dmatrix follows the API exactly of the standard DMatrix constructor
|
|
||||||
# (xgb.DMatrix()), except that it 'unpacks' dask distributed objects to obtain data local to
|
|
||||||
# this worker
|
|
||||||
dtrain = xgb.dask.create_worker_dmatrix(X, y)
|
|
||||||
|
|
||||||
# Train on the data. Each worker will communicate and synchronise during training. The output
|
|
||||||
# model is expected to be identical on each worker.
|
|
||||||
bst = xgb.train({}, dtrain)
|
|
||||||
# Make predictions on local data
|
|
||||||
pred = bst.predict(dtrain)
|
|
||||||
print("Finished training with worker #{}".format(xgb.rabit.get_rank()))
|
|
||||||
# Get text representation of the model
|
|
||||||
return bst.get_dump()
|
|
||||||
|
|
||||||
|
|
||||||
def train_with_sklearn(X, y):
|
|
||||||
print("Training with worker #{} using the sklearn API".format(xgb.rabit.get_rank()))
|
|
||||||
X_local = xgb.dask.get_local_data(X)
|
|
||||||
y_local = xgb.dask.get_local_data(y)
|
|
||||||
model = xgb.XGBRegressor(n_estimators=10)
|
|
||||||
model.fit(X_local, y_local)
|
|
||||||
print("Finished training with worker #{} using the sklearn API".format(xgb.rabit.get_rank()))
|
|
||||||
return model.predict(X_local)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
# Launch a very simple local cluster using two distributed workers with two CPU threads each
|
|
||||||
cluster = LocalCluster(n_workers=2, threads_per_worker=2)
|
|
||||||
client = Client(cluster)
|
|
||||||
|
|
||||||
# Generate some small test data as a dask array
|
|
||||||
# These data frames are internally split into partitions of 20 rows each and then distributed
|
|
||||||
# among workers, so we will have 5 partitions distributed among 2 workers
|
|
||||||
# Note that the partition size MUST be consistent across different dask dataframes/arrays
|
|
||||||
n = 10
|
|
||||||
m = 100
|
|
||||||
partition_size = 20
|
|
||||||
X = da.random.random((m, n), partition_size)
|
|
||||||
y = da.random.random(m, partition_size)
|
|
||||||
|
|
||||||
# xgb.dask.run launches an arbitrary function and its arguments on the cluster
|
|
||||||
# Here train(X, y) will be called on each worker
|
|
||||||
# This function blocks until all work is complete
|
|
||||||
models = xgb.dask.run(client, train, X, y)
|
|
||||||
|
|
||||||
# models contains a dictionary mapping workers to results
|
|
||||||
# We expect that the models are the same over all workers
|
|
||||||
first_model = next(iter(models.values()))
|
|
||||||
assert all(model == first_model for worker, model in models.items())
|
|
||||||
|
|
||||||
# We can also train using the sklearn API
|
|
||||||
results = xgb.dask.run(client, train_with_sklearn, X, y)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
||||||
41
demo/dask/gpu_training.py
Normal file
41
demo/dask/gpu_training.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
from dask_cuda import LocalCUDACluster
|
||||||
|
from dask.distributed import Client
|
||||||
|
from dask import array as da
|
||||||
|
import xgboost as xgb
|
||||||
|
from xgboost.dask import DaskDMatrix
|
||||||
|
|
||||||
|
|
||||||
|
def main(client):
    '''Train a ``gpu_hist`` model on random dask arrays.

    Parameters
    ----------
    client:
        Dask client handed to ``DaskDMatrix``, ``xgb.dask.train`` and
        ``xgb.dask.predict``.

    Returns
    -------
    The value returned by ``xgb.dask.predict`` for the training matrix.
    '''
    n_cols = 100
    n_rows = 100000
    chunk_size = 1000
    X = da.random.random((n_rows, n_cols), chunk_size)
    y = da.random.random(n_rows, chunk_size)

    # DaskDMatrix is used where a normal DMatrix would be; it acts as a
    # proxy for the local DMatrix objects scattered around the workers.
    dtrain = DaskDMatrix(client, X, y)

    params = {'verbosity': 2,
              'nthread': 1,
              'tree_method': 'gpu_hist'}
    # Call train from xgboost.dask rather than xgboost.  The distributed
    # version returns a dictionary holding the resulting booster and the
    # evaluation history; both entries are read out below.
    output = xgb.dask.train(client,
                            params,
                            dtrain,
                            num_boost_round=4,
                            evals=[(dtrain, 'train')])
    bst = output['booster']
    history = output['history']

    prediction = xgb.dask.predict(client, bst, dtrain)
    print('Evaluation history:', history)
    return prediction
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# or use any other clusters
|
||||||
|
cluster = LocalCUDACluster(n_workers=4, threads_per_worker=1)
|
||||||
|
client = Client(cluster)
|
||||||
|
main(client)
|
||||||
30
demo/dask/sklearn_cpu_training.py
Normal file
30
demo/dask/sklearn_cpu_training.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
'''Dask interface demo:
|
||||||
|
|
||||||
|
Use scikit-learn regressor interface with CPU histogram tree method.'''
|
||||||
|
from dask.distributed import Client
|
||||||
|
from dask.distributed import LocalCluster
|
||||||
|
from dask import array as da
|
||||||
|
import xgboost
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
cluster = LocalCluster(n_workers=2, silence_logs=False) # or use any other clusters
|
||||||
|
client = Client(cluster)
|
||||||
|
|
||||||
|
n = 100
|
||||||
|
m = 10000
|
||||||
|
partition_size = 100
|
||||||
|
X = da.random.random((m, n), partition_size)
|
||||||
|
y = da.random.random(m, partition_size)
|
||||||
|
|
||||||
|
regressor = xgboost.dask.DaskXGBRegressor(verbosity=2, n_estimators=2)
|
||||||
|
regressor.set_params(tree_method='hist')
|
||||||
|
regressor.client = client
|
||||||
|
|
||||||
|
regressor.fit(X, y, eval_set=[(X, y)])
|
||||||
|
prediction = regressor.predict(X)
|
||||||
|
|
||||||
|
bst = regressor.get_booster()
|
||||||
|
history = regressor.evals_result()
|
||||||
|
|
||||||
|
print('Evaluation history:', history)
|
||||||
|
assert isinstance(prediction, da.Array)
|
||||||
31
demo/dask/sklearn_gpu_training.py
Normal file
31
demo/dask/sklearn_gpu_training.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
'''Dask interface demo:
|
||||||
|
|
||||||
|
Use scikit-learn regressor interface with GPU histogram tree method.'''
|
||||||
|
|
||||||
|
from dask.distributed import Client
|
||||||
|
# It's recommended to use dask_cuda for GPU assignment
|
||||||
|
from dask_cuda import LocalCUDACluster
|
||||||
|
from dask import array as da
|
||||||
|
import xgboost
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
cluster = LocalCUDACluster()
|
||||||
|
client = Client(cluster)
|
||||||
|
|
||||||
|
n = 100
|
||||||
|
m = 1000000
|
||||||
|
partition_size = 10000
|
||||||
|
X = da.random.random((m, n), partition_size)
|
||||||
|
y = da.random.random(m, partition_size)
|
||||||
|
|
||||||
|
regressor = xgboost.dask.DaskXGBRegressor(verbosity=2)
|
||||||
|
regressor.set_params(tree_method='gpu_hist')
|
||||||
|
regressor.client = client
|
||||||
|
|
||||||
|
regressor.fit(X, y, eval_set=[(X, y)])
|
||||||
|
prediction = regressor.predict(X)
|
||||||
|
|
||||||
|
bst = regressor.get_booster()
|
||||||
|
history = regressor.evals_result()
|
||||||
|
|
||||||
|
print('Evaluation history:', history)
|
||||||
@ -80,9 +80,10 @@ Dask API
|
|||||||
--------
|
--------
|
||||||
.. automodule:: xgboost.dask
|
.. automodule:: xgboost.dask
|
||||||
|
|
||||||
.. autofunction:: xgboost.dask.run
|
.. autofunction:: xgboost.dask.DaskDMatrix
|
||||||
|
|
||||||
.. autofunction:: xgboost.dask.create_worker_dmatrix
|
.. autofunction:: xgboost.dask.predict
|
||||||
|
|
||||||
.. autofunction:: xgboost.dask.get_local_data
|
.. autofunction:: xgboost.dask.DaskXGBClassifier
|
||||||
|
|
||||||
|
.. autofunction:: xgboost.dask.DaskXGBRegressor
|
||||||
|
|||||||
92
doc/tutorials/dask.rst
Normal file
92
doc/tutorials/dask.rst
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
#############################
|
||||||
|
Distributed XGBoost with Dask
|
||||||
|
#############################
|
||||||
|
|
||||||
|
`Dask <https://dask.org>`_ is a parallel computing library built on Python. Dask allows
|
||||||
|
easy management of distributed workers and excels at handling large distributed data science
|
||||||
|
workflows. The implementation in XGBoost originates from `dask-xgboost
|
||||||
|
<https://github.com/dask/dask-xgboost>`_ with some extended functionalities and a
|
||||||
|
different interface. Right now it is still under construction and may change (with proper
|
||||||
|
warnings) in the future.
|
||||||
|
|
||||||
|
************
|
||||||
|
Requirements
|
||||||
|
************
|
||||||
|
|
||||||
|
Dask is trivial to install using either pip or conda. `See here for official install
|
||||||
|
documentation <https://docs.dask.org/en/latest/install.html>`_. For accelerating XGBoost
|
||||||
|
with GPU, `dask-cuda <https://github.com/rapidsai/dask-cuda>`_ is recommended for creating
|
||||||
|
GPU clusters.
|
||||||
|
|
||||||
|
|
||||||
|
********
|
||||||
|
Overview
|
||||||
|
********
|
||||||
|
|
||||||
|
There are 3 different components in dask from a user's perspective, namely a scheduler,
|
||||||
|
a bunch of workers and some clients connecting to the scheduler. For using XGBoost with
|
||||||
|
dask, one needs to call XGBoost dask interface from the client side. A small example
|
||||||
|
illustrates the basic usage:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
cluster = LocalCluster(n_workers=4, threads_per_worker=1)
|
||||||
|
client = Client(cluster)
|
||||||
|
|
||||||
|
dtrain = xgb.dask.DaskDMatrix(client, X, y) # X and y are dask dataframes or arrays
|
||||||
|
|
||||||
|
output = xgb.dask.train(client,
|
||||||
|
{'verbosity': 2,
|
||||||
|
'nthread': 1,
|
||||||
|
'tree_method': 'hist'},
|
||||||
|
dtrain,
|
||||||
|
num_boost_round=4, evals=[(dtrain, 'train')])
|
||||||
|
|
||||||
|
Here we first create a cluster in single-node mode with ``distributed.LocalCluster``, then
|
||||||
|
connect a ``client`` to this cluster, setting up environment for later computation.
|
||||||
|
Similar to the non-distributed interface, we create a ``DaskDMatrix`` object and pass it to
|
||||||
|
``train`` along with some other parameters. Except in dask interface, client is an extra
|
||||||
|
argument for carrying out the computation, when set to ``None`` XGBoost will use the
|
||||||
|
default client returned from dask.
|
||||||
|
|
||||||
|
There are two sets of APIs implemented in XGBoost. The first set is functional API
|
||||||
|
illustrated in the above example. Given the data and a set of parameters, the ``train`` function
|
||||||
|
returns a model and the computation history as a Python dictionary:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
{'booster': Booster,
|
||||||
|
'history': dict}
|
||||||
|
|
||||||
|
For prediction, pass the ``output`` returned by ``train`` into ``xgb.dask.predict``:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
prediction = xgb.dask.predict(client, output, dtrain)
|
||||||
|
|
||||||
|
Or equivalently, pass ``output['booster']``:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
prediction = xgb.dask.predict(client, output['booster'], dtrain)
|
||||||
|
|
||||||
|
Here ``prediction`` is a dask ``Array`` object containing predictions from model.
|
||||||
|
|
||||||
|
Another set of API is a Scikit-Learn wrapper, which mimics the stateful Scikit-Learn
|
||||||
|
interface with ``DaskXGBClassifier`` and ``DaskXGBRegressor``. See ``xgboost/demo/dask``
|
||||||
|
for more examples.
|
||||||
|
|
||||||
|
|
||||||
|
***********
|
||||||
|
Limitations
|
||||||
|
***********
|
||||||
|
|
||||||
|
Basic functionalities including training and generating predictions for regression and
|
||||||
|
classification are implemented. But there are still some other limitations we haven't
|
||||||
|
addressed yet.
|
||||||
|
|
||||||
|
- Label encoding for Scikit-Learn classifier.
|
||||||
|
- Ranking
|
||||||
|
- Callback functions are not tested.
|
||||||
|
- To use cross validation one needs to explicitly train different models instead of using
|
||||||
|
a functional API like ``xgboost.cv``.
|
||||||
@ -21,3 +21,4 @@ See `Awesome XGBoost <https://github.com/dmlc/xgboost/tree/master/demo>`_ for mo
|
|||||||
param_tuning
|
param_tuning
|
||||||
external_memory
|
external_memory
|
||||||
custom_metric_obj
|
custom_metric_obj
|
||||||
|
dask
|
||||||
|
|||||||
@ -11,9 +11,9 @@ import os
|
|||||||
from .core import DMatrix, Booster
|
from .core import DMatrix, Booster
|
||||||
from .training import train, cv
|
from .training import train, cv
|
||||||
from . import rabit # noqa
|
from . import rabit # noqa
|
||||||
from . import dask # noqa
|
|
||||||
from . import tracker # noqa
|
from . import tracker # noqa
|
||||||
from .tracker import RabitTracker # noqa
|
from .tracker import RabitTracker # noqa
|
||||||
|
from . import dask
|
||||||
try:
|
try:
|
||||||
from .sklearn import XGBModel, XGBClassifier, XGBRegressor, XGBRanker
|
from .sklearn import XGBModel, XGBClassifier, XGBRegressor, XGBRanker
|
||||||
from .sklearn import XGBRFClassifier, XGBRFRegressor
|
from .sklearn import XGBRFClassifier, XGBRFRegressor
|
||||||
@ -30,4 +30,4 @@ __all__ = ['DMatrix', 'Booster',
|
|||||||
'RabitTracker',
|
'RabitTracker',
|
||||||
'XGBModel', 'XGBClassifier', 'XGBRegressor', 'XGBRanker',
|
'XGBModel', 'XGBClassifier', 'XGBRegressor', 'XGBRanker',
|
||||||
'XGBRFClassifier', 'XGBRFRegressor',
|
'XGBRFClassifier', 'XGBRFRegressor',
|
||||||
'plot_importance', 'plot_tree', 'to_graphviz']
|
'plot_importance', 'plot_tree', 'to_graphviz', 'dask']
|
||||||
|
|||||||
@ -96,14 +96,17 @@ except ImportError:
|
|||||||
|
|
||||||
# pandas
|
# pandas
|
||||||
try:
|
try:
|
||||||
from pandas import DataFrame
|
from pandas import DataFrame, Series
|
||||||
from pandas import MultiIndex
|
from pandas import MultiIndex
|
||||||
|
from pandas import concat as pandas_concat
|
||||||
|
|
||||||
PANDAS_INSTALLED = True
|
PANDAS_INSTALLED = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|
||||||
MultiIndex = object
|
MultiIndex = object
|
||||||
DataFrame = object
|
DataFrame = object
|
||||||
|
Series = object
|
||||||
|
pandas_concat = None
|
||||||
PANDAS_INSTALLED = False
|
PANDAS_INSTALLED = False
|
||||||
|
|
||||||
# dt
|
# dt
|
||||||
@ -169,16 +172,35 @@ except ImportError:
|
|||||||
|
|
||||||
# dask
|
# dask
|
||||||
try:
|
try:
|
||||||
from dask.dataframe import DataFrame as DaskDataFrame
|
import dask
|
||||||
from dask.dataframe import Series as DaskSeries
|
from dask import delayed
|
||||||
from dask.array import Array as DaskArray
|
from dask import dataframe as dd
|
||||||
|
from dask import array as da
|
||||||
|
from dask.distributed import Client, get_client
|
||||||
|
from dask.distributed import comm as distributed_comm
|
||||||
|
from dask.distributed import wait as distributed_wait
|
||||||
from distributed import get_worker as distributed_get_worker
|
from distributed import get_worker as distributed_get_worker
|
||||||
|
|
||||||
DASK_INSTALLED = True
|
DASK_INSTALLED = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
DaskDataFrame = object
|
dd = None
|
||||||
DaskSeries = object
|
da = None
|
||||||
DaskArray = object
|
Client = None
|
||||||
|
delayed = None
|
||||||
|
get_client = None
|
||||||
|
distributed_comm = None
|
||||||
|
distributed_wait = None
|
||||||
distributed_get_worker = None
|
distributed_get_worker = None
|
||||||
|
dask = None
|
||||||
|
|
||||||
DASK_INSTALLED = False
|
DASK_INSTALLED = False
|
||||||
|
|
||||||
|
|
||||||
|
# Optional sparse backends.  The two packages are probed independently so
# that a missing `sparse` package does not mask an installed SciPy: with a
# single shared `try`, failing to import `sparse` also left `scipy_sparse`
# set to False and `SCIPY_INSTALLED` set to False.
try:
    import scipy.sparse as scipy_sparse
    SCIPY_INSTALLED = True
except ImportError:
    # False (rather than None) keeps `if scipy_sparse and ...` checks working.
    scipy_sparse = False
    SCIPY_INSTALLED = False

try:
    import sparse
except ImportError:
    sparse = False
|
||||||
|
|||||||
@ -106,6 +106,28 @@ def from_cstr_to_pystr(data, length):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def _expect(expectations, got):
|
||||||
|
'''Translate input error into string.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
expectations: sequence
|
||||||
|
a list of expected value.
|
||||||
|
got:
|
||||||
|
actual input
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
msg: str'''
|
||||||
|
msg = 'Expecting '
|
||||||
|
for t in range(len(expectations) - 1):
|
||||||
|
msg += str(expectations[t])
|
||||||
|
msg += ' or '
|
||||||
|
msg += str(expectations[-1])
|
||||||
|
msg += '. Got ' + str(got)
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
def _log_callback(msg):
|
def _log_callback(msg):
|
||||||
"""Redirect logs from native library into Python console"""
|
"""Redirect logs from native library into Python console"""
|
||||||
print("{0:s}".format(py_str(msg)))
|
print("{0:s}".format(py_str(msg)))
|
||||||
@ -513,7 +535,8 @@ class DMatrix(object):
|
|||||||
and type if memory use is a concern.
|
and type if memory use is a concern.
|
||||||
"""
|
"""
|
||||||
if len(mat.shape) != 2:
|
if len(mat.shape) != 2:
|
||||||
raise ValueError('Input numpy.ndarray must be 2 dimensional')
|
raise ValueError('Expecting 2 dimensional numpy.ndarray, got: ',
|
||||||
|
mat.shape)
|
||||||
# flatten the array by rows and ensure it is float32.
|
# flatten the array by rows and ensure it is float32.
|
||||||
# we try to avoid data copies if possible (reshape returns a view when possible
|
# we try to avoid data copies if possible (reshape returns a view when possible
|
||||||
# and we explicitly tell np.array to try and avoid copying)
|
# and we explicitly tell np.array to try and avoid copying)
|
||||||
@ -1010,7 +1033,7 @@ class Booster(object):
|
|||||||
"""
|
"""
|
||||||
for d in cache:
|
for d in cache:
|
||||||
if not isinstance(d, DMatrix):
|
if not isinstance(d, DMatrix):
|
||||||
raise TypeError('invalid cache item: {}'.format(type(d).__name__))
|
raise TypeError('invalid cache item: {}'.format(type(d).__name__), cache)
|
||||||
self._validate_features(d)
|
self._validate_features(d)
|
||||||
|
|
||||||
dmats = c_array(ctypes.c_void_p, [d.handle for d in cache])
|
dmats = c_array(ctypes.c_void_p, [d.handle for d in cache])
|
||||||
@ -1353,6 +1376,10 @@ class Booster(object):
|
|||||||
if pred_interactions:
|
if pred_interactions:
|
||||||
option_mask |= 0x10
|
option_mask |= 0x10
|
||||||
|
|
||||||
|
if not isinstance(data, DMatrix):
|
||||||
|
raise TypeError('Expecting data to be a DMatrix object, got: ',
|
||||||
|
type(data))
|
||||||
|
|
||||||
if validate_features:
|
if validate_features:
|
||||||
self._validate_features(data)
|
self._validate_features(data)
|
||||||
|
|
||||||
|
|||||||
@ -1,25 +1,48 @@
|
|||||||
# pylint: disable=wrong-import-position,wrong-import-order,import-error
|
# pylint: disable=too-many-arguments, too-many-locals
|
||||||
"""Dask extensions for distributed training. See xgboost/demo/dask for examples."""
|
"""Dask extensions for distributed training. See
|
||||||
import os
|
https://xgboost.readthedocs.io/en/latest/tutorials/dask.html for simple
|
||||||
import math
|
tutorial. Also xgboost/demo/dask for some examples.
|
||||||
|
|
||||||
|
There are two sets of APIs in this module, one is the functional API including
|
||||||
|
``train`` and ``predict`` methods. Another is the stateful Scikit-Learn wrapper
|
||||||
|
inherited from single-node Scikit-Learn interface.
|
||||||
|
|
||||||
|
The implementation is heavily influenced by dask_xgboost:
|
||||||
|
https://github.com/dask/dask-xgboost
|
||||||
|
|
||||||
|
"""
|
||||||
import platform
|
import platform
|
||||||
import logging
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
|
||||||
|
import numpy
|
||||||
|
|
||||||
from . import rabit
|
from . import rabit
|
||||||
from .core import DMatrix
|
|
||||||
from .compat import (DaskDataFrame, DaskSeries, DaskArray,
|
|
||||||
distributed_get_worker)
|
|
||||||
|
|
||||||
|
from .compat import DASK_INSTALLED
|
||||||
|
from .compat import distributed_get_worker, distributed_wait, distributed_comm
|
||||||
|
from .compat import da, dd, delayed, get_client
|
||||||
|
from .compat import sparse, scipy_sparse
|
||||||
|
from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat
|
||||||
|
|
||||||
|
from .core import DMatrix, Booster, _expect
|
||||||
|
from .training import train as worker_train
|
||||||
from .tracker import RabitTracker
|
from .tracker import RabitTracker
|
||||||
|
from .sklearn import XGBModel, XGBClassifierBase
|
||||||
|
|
||||||
|
# Current status is considered as initial support, many features are
|
||||||
|
# not properly supported yet.
|
||||||
|
#
|
||||||
|
# TODOs:
|
||||||
|
# - Callback.
|
||||||
|
# - Label encoding.
|
||||||
|
# - CV
|
||||||
|
# - Ranking
|
||||||
|
|
||||||
|
|
||||||
def _start_tracker(n_workers):
|
def _start_tracker(host, n_workers):
|
||||||
""" Start Rabit tracker """
|
"""Start Rabit tracker """
|
||||||
host = distributed_get_worker().address
|
|
||||||
if '://' in host:
|
|
||||||
host = host.rsplit('://', 1)[1]
|
|
||||||
host, port = host.split(':')
|
|
||||||
port = int(port)
|
|
||||||
env = {'DMLC_NUM_WORKER': n_workers}
|
env = {'DMLC_NUM_WORKER': n_workers}
|
||||||
rabit_context = RabitTracker(hostIP=host, nslave=n_workers)
|
rabit_context = RabitTracker(hostIP=host, nslave=n_workers)
|
||||||
env.update(rabit_context.slave_envs())
|
env.update(rabit_context.slave_envs())
|
||||||
@ -31,91 +54,556 @@ def _start_tracker(n_workers):
|
|||||||
return env
|
return env
|
||||||
|
|
||||||
|
|
||||||
def get_local_data(data):
|
def _assert_dask_installed():
|
||||||
"""
|
if not DASK_INSTALLED:
|
||||||
Unpacks a distributed data object to get the rows local to this worker
|
raise ImportError(
|
||||||
|
'Dask needs to be installed in order to use this module')
|
||||||
:param data: A distributed dask data object
|
|
||||||
:return: Local data partition e.g. numpy or pandas
|
|
||||||
"""
|
|
||||||
if isinstance(data, DaskArray):
|
|
||||||
total_partitions = len(data.chunks[0])
|
|
||||||
else:
|
|
||||||
total_partitions = data.npartitions
|
|
||||||
partition_size = int(math.ceil(total_partitions / rabit.get_world_size()))
|
|
||||||
begin_partition = partition_size * rabit.get_rank()
|
|
||||||
end_partition = min(begin_partition + partition_size, total_partitions)
|
|
||||||
if isinstance(data, DaskArray):
|
|
||||||
return data.blocks[begin_partition:end_partition].compute()
|
|
||||||
|
|
||||||
return data.partitions[begin_partition:end_partition].compute()
|
|
||||||
|
|
||||||
|
|
||||||
def create_worker_dmatrix(*args, **kwargs):
|
class RabitContext:
|
||||||
"""
|
'''A context controling rabit initialization and finalization.'''
|
||||||
Creates a DMatrix object local to a given worker. Simply forwards arguments onto the standard
|
def __init__(self, args):
|
||||||
DMatrix constructor, if one of the arguments is a dask dataframe, unpack the data frame to
|
self.args = args
|
||||||
get the local components.
|
|
||||||
|
|
||||||
All dask dataframe arguments must use the same partitioning.
|
def __enter__(self):
|
||||||
|
rabit.init(self.args)
|
||||||
|
logging.debug('-------------- rabit say hello ------------------')
|
||||||
|
|
||||||
:param args: DMatrix constructor args.
|
def __exit__(self, *args):
|
||||||
:return: DMatrix object containing data local to current dask worker
|
|
||||||
"""
|
|
||||||
dmatrix_args = []
|
|
||||||
dmatrix_kwargs = {}
|
|
||||||
# Convert positional args
|
|
||||||
for arg in args:
|
|
||||||
if isinstance(arg, (DaskDataFrame, DaskSeries, DaskArray)):
|
|
||||||
dmatrix_args.append(get_local_data(arg))
|
|
||||||
else:
|
|
||||||
dmatrix_args.append(arg)
|
|
||||||
|
|
||||||
# Convert keyword args
|
|
||||||
for k, v in kwargs.items():
|
|
||||||
if isinstance(v, (DaskDataFrame, DaskSeries, DaskArray)):
|
|
||||||
dmatrix_kwargs[k] = get_local_data(v)
|
|
||||||
else:
|
|
||||||
dmatrix_kwargs[k] = v
|
|
||||||
|
|
||||||
return DMatrix(*dmatrix_args, **dmatrix_kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def _run_with_rabit(rabit_args, func, *args):
|
|
||||||
worker = distributed_get_worker()
|
|
||||||
try:
|
|
||||||
os.environ["OMP_NUM_THREADS"] = str(worker.ncores)
|
|
||||||
except AttributeError:
|
|
||||||
os.environ["OMP_NUM_THREADS"] = str(worker.nthreads)
|
|
||||||
try:
|
|
||||||
rabit.init(rabit_args)
|
|
||||||
result = func(*args)
|
|
||||||
finally:
|
|
||||||
rabit.finalize()
|
rabit.finalize()
|
||||||
return result
|
logging.debug('--------------- rabit say bye ------------------')
|
||||||
|
|
||||||
|
|
||||||
def run(client, func, *args):
|
def concat(value):
|
||||||
"""Launch arbitrary function on dask workers. Workers are connected by rabit,
|
'''To be replaced with dask builtin.'''
|
||||||
allowing distributed training. The environment variable OMP_NUM_THREADS is
|
if isinstance(value[0], numpy.ndarray):
|
||||||
defined on each worker according to dask - this means that calls to
|
return numpy.concatenate(value, axis=0)
|
||||||
xgb.train() will use the threads allocated by dask by default, unless the
|
if scipy_sparse and isinstance(value[0], scipy_sparse.spmatrix):
|
||||||
user overrides the nthread parameter.
|
return scipy_sparse.vstack(value, format='csr')
|
||||||
|
if sparse and isinstance(value[0], sparse.SparseArray):
|
||||||
|
return sparse.concatenate(value, axis=0)
|
||||||
|
if PANDAS_INSTALLED and isinstance(value[0], (DataFrame, Series)):
|
||||||
|
return pandas_concat(value, axis=0)
|
||||||
|
return dd.multi.concat(list(value), axis=0)
|
||||||
|
|
||||||
Note: Windows platforms are not officially
|
|
||||||
supported. Contributions are welcome here.
|
|
||||||
|
|
||||||
:param client: Dask client representing the cluster
|
def _xgb_get_client(client):
|
||||||
:param func: Python function to be executed by each worker. Typically
|
'''Simple wrapper around testing None.'''
|
||||||
contains xgboost training code.
|
ret = get_client() if client is None else client
|
||||||
:param args: Arguments to be forwarded to func
|
return ret
|
||||||
:return: Dict containing the function return value for each worker
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
class DaskDMatrix:
    # pylint: disable=missing-docstring, too-many-instance-attributes
    '''DMatrix holding on references to Dask DataFrame or Dask Array.

    Construction persists the inputs on the cluster and records, keyed by
    worker address, the partitions each worker holds (``worker_map``).

    Parameters
    ----------
    client: dask.distributed.Client
        Specify the dask client used for training.  Use default client
        returned from dask if it's set to None.
    data : dask.array.Array/dask.dataframe.DataFrame
        data source of DMatrix.
    label: dask.array.Array/dask.dataframe.DataFrame
        label used for training.
    missing : float, optional
        Value in the input data (e.g. `numpy.ndarray`) which needs
        to be present as a missing value. If None, defaults to np.nan.
    weight : dask.array.Array/dask.dataframe.DataFrame
        Weight for each instance.
    feature_names : list, optional
        Set names for features.
    feature_types : list, optional
        Set types for features

    Raises
    ------
    TypeError
        If `data` or `label` is not one of the supported dask collections.
    ValueError
        If `data` is not 2 dimensional.

    '''

    _feature_names = None  # for previous version's pickle
    _feature_types = None

    def __init__(self,
                 client,
                 data,
                 label=None,
                 missing=None,
                 weight=None,
                 feature_names=None,
                 feature_types=None):
        _assert_dask_installed()

        self._feature_names = feature_names
        self._feature_types = feature_types
        self._missing = missing

        # Validate the container type first so unsupported inputs get the
        # intended TypeError instead of failing on a missing `.shape`.
        if not isinstance(data, (dd.DataFrame, da.Array)):
            raise TypeError(_expect((dd.DataFrame, da.Array), type(data)))
        if not isinstance(label,
                          (dd.DataFrame, da.Array, dd.Series, type(None))):
            raise TypeError(
                _expect((dd.DataFrame, da.Array, dd.Series), type(label)))

        # The original built this message but never raised it, so inputs
        # that were not 2 dimensional were silently accepted.
        if len(data.shape) != 2:
            raise ValueError(_expect('2 dimensions input', data.shape))
        self.n_rows = data.shape[0]
        self.n_cols = data.shape[1]

        self.worker_map = None
        self.has_label = label is not None
        self.has_weights = weight is not None

        client = _xgb_get_client(client)
        client.sync(self.map_local_data, client, data, label, weight)

    async def map_local_data(self, client, data, label=None, weights=None):
        '''Obtain references to local data.

        Persists the inputs, groups matching partitions of data, label and
        weight into one delayed tuple per partition, computes them, and
        stores in ``self.worker_map`` a mapping from worker address to the
        list of partition futures attributed to that worker.
        '''
        data = data.persist()
        if label is not None:
            label = label.persist()
        if weights is not None:
            weights = weights.persist()
        # Breaking data into partitions, a trick borrowed from dask_xgboost.

        # `to_delayed` downgrades high-level objects into numpy or pandas
        # equivalents.
        X_parts = data.to_delayed()
        if isinstance(X_parts, numpy.ndarray):
            # Only row-wise chunking is handled: one block per row of blocks.
            assert X_parts.shape[1] == 1
            X_parts = X_parts.flatten().tolist()

        if label is not None:
            y_parts = label.to_delayed()
            if isinstance(y_parts, numpy.ndarray):
                assert y_parts.ndim == 1 or y_parts.shape[1] == 1
                y_parts = y_parts.flatten().tolist()
        if weights is not None:
            w_parts = weights.to_delayed()
            if isinstance(w_parts, numpy.ndarray):
                assert w_parts.ndim == 1 or w_parts.shape[1] == 1
                w_parts = w_parts.flatten().tolist()

        parts = [X_parts]
        if label is not None:
            assert len(X_parts) == len(
                y_parts), 'Partitions between X and y are not consistent'
            parts.append(y_parts)
        if weights is not None:
            assert len(X_parts) == len(
                w_parts), 'Partitions between X and weight are not consistent.'
            parts.append(w_parts)
        # One delayed tuple (X[, y[, w]]) per partition.
        parts = list(map(delayed, zip(*parts)))

        parts = client.compute(parts)
        await distributed_wait(parts)  # async wait for parts to be computed

        for part in parts:
            assert part.status == 'finished'

        key_to_partition = {part.key: part for part in parts}
        who_has = await client.scheduler.who_has(
            keys=[part.key for part in parts])

        worker_map = defaultdict(list)
        for key, workers in who_has.items():
            # Attribute each partition to the first worker reported for it.
            worker_map[next(iter(workers))].append(key_to_partition[key])

        self.worker_map = worker_map

    def get_worker_parts(self, worker):
        '''Get mapped parts of data in each worker.

        Returns a ``(data, labels, weights)`` triple of per-partition
        sequences; ``labels`` and ``weights`` are None when absent.
        '''
        list_of_parts = self.worker_map[worker.address]
        assert list_of_parts, 'data in ' + worker.address + ' was moved.'
        assert isinstance(list_of_parts, list)

        # `get_worker_parts` is launched inside worker.  In dask side
        # this should be equal to `worker._get_client`.
        client = get_client()
        list_of_parts = client.gather(list_of_parts)

        if self.has_label:
            if self.has_weights:
                data, labels, weights = zip(*list_of_parts)
            else:
                data, labels = zip(*list_of_parts)
                weights = None
        else:
            data = [d[0] for d in list_of_parts]
            labels = None
            weights = None
        return data, labels, weights

    def get_worker_data(self, worker):
        '''Get data that local to worker.

        Parameters
        ----------
        worker: The worker used as key to data.

        Returns
        -------
        A DMatrix object.

        '''
        data, labels, weights = self.get_worker_parts(worker)

        data = concat(data)

        if self.has_label:
            labels = concat(labels)
        else:
            labels = None
        if self.has_weights:
            weights = concat(weights)
        else:
            weights = None

        dmatrix = DMatrix(data,
                          labels,
                          weight=weights,
                          missing=self._missing,
                          feature_names=self._feature_names,
                          feature_types=self._feature_types)
        return dmatrix

    @staticmethod
    def _combine_shapes(shapes):
        '''Fold per-partition (rows, cols) shapes into the stacked shape.

        Partitions are stacked along axis 0, so rows add up while the
        column count must be the same for every partition.
        '''
        rows = 0
        cols = 0
        for shape in shapes:
            rows += shape[0]
            if cols not in (0, shape[1]):
                raise ValueError(
                    'Shape between partitions are not the same. '
                    'Got: {left} and {right}'.format(left=shape[1],
                                                     right=cols))
            cols = shape[1]
        return (rows, cols)

    def get_worker_data_shape(self, worker):
        '''Get the shape of data X in each worker.'''
        data, _, _ = self.get_worker_parts(worker)
        return self._combine_shapes([d.shape for d in data])

    def num_row(self):
        '''Number of rows of the input data, as recorded at construction.'''
        return self.n_rows

    def num_col(self):
        '''Number of columns of the input data, as recorded at construction.'''
        return self.n_cols
|
||||||
|
|
||||||
|
|
||||||
|
def _get_rabit_args(worker_map, client):
    '''Get rabit context arguments from data distribution in DaskDMatrix.

    Runs `_start_tracker` on the scheduler with the scheduler's host and the
    number of workers holding data, then encodes every returned environment
    entry as a ``b'key=value'`` argument.
    '''
    scheduler_host = distributed_comm.get_address_host(
        client.scheduler.address)
    n_workers = len(worker_map)

    tracker_env = client.run_on_scheduler(
        _start_tracker, scheduler_host.strip('/:'), n_workers)
    return [('%s=%s' % (key, value)).encode()
            for key, value in tracker_env.items()]
|
||||||
|
|
||||||
|
# train and predict methods are supposed to be "functional", which meets the
|
||||||
|
# dask paradigm. But as a side effect, the `evals_result` in single-node API
|
||||||
|
# is no longer supported since it mutates the input parameter, and it's not
|
||||||
|
# intuitive to sync the mutation result. Therefore, a dictionary containing
|
||||||
|
# evaluation history is instead returned.
|
||||||
|
|
||||||
|
|
||||||
|
def train(client, params, dtrain, *args, evals=(), **kwargs):
    '''Train XGBoost model.

    Parameters
    ----------
    client: dask.distributed.Client
        Specify the dask client used for training.  Use default client
        returned from dask if it's set to None.

    Other parameters are the same as `xgboost.train` except for `evals_result`,
    which is returned as part of function return value instead of argument.

    Returns
    -------
    results: dict
        A dictionary containing trained booster and evaluation history.
        `history` field is the same as `eval_result` from `xgboost.train`.

        .. code-block:: python

            {'booster': xgboost.Booster,
             'history': {'train': {'logloss': ['0.48253', '0.35953']},
                         'eval': {'logloss': ['0.480385', '0.357756']}}}

    Raises
    ------
    ValueError
        If `evals_result` is passed as a keyword argument.

    '''
    _assert_dask_installed()
    if platform.system() == 'Windows':
        msg = 'Windows is not officially supported for dask/xgboost,'
        msg += ' contribution are welcomed.'
        logging.warning(msg)

    if 'evals_result' in kwargs:
        # A single message string; the original passed two arguments, which
        # made the exception text render as a tuple.
        raise ValueError(
            'evals_result is not supported in dask interface. '
            'The evaluation history is returned as result of training.')

    client = _xgb_get_client(client)

    worker_map = dtrain.worker_map
    rabit_args = _get_rabit_args(worker_map, client)

    def dispatched_train(worker_id):
        '''Perform training on worker.'''
        logging.info('Training on %d', worker_id)
        worker = distributed_get_worker()
        local_dtrain = dtrain.get_worker_data(worker)

        local_evals = []
        if evals:
            for mat, name in evals:
                local_mat = mat.get_worker_data(worker)
                local_evals.append((local_mat, name))

        with RabitContext(rabit_args):
            local_history = {}
            local_param = params.copy()  # just to be consistent
            # params/dtrain go positionally: combining `params=...` with a
            # non-empty `*args` raises "multiple values for argument".
            # NOTE(review): assumes worker_train has the xgboost.train
            # signature (params, dtrain, ...) -- confirm against its import.
            bst = worker_train(local_param,
                               local_dtrain,
                               *args,
                               evals_result=local_history,
                               evals=local_evals,
                               **kwargs)
            ret = {'booster': bst, 'history': local_history}
            # Only the rank 0 worker reports a result.
            if rabit.get_rank() != 0:
                ret = None
            return ret

    futures = client.map(dispatched_train,
                         range(len(worker_map)),
                         workers=list(worker_map.keys()))
    results = client.gather(futures)
    return [ret for ret in results if ret is not None][0]
|
||||||
|
|
||||||
|
|
||||||
|
def predict(client, model, data, *args):
    '''Run prediction with a trained booster.

    Parameters
    ----------
    client: dask.distributed.Client
        Specify the dask client used for training.  Use default client
        returned from dask if it's set to None.
    model: A Booster or a dictionary returned by `xgboost.dask.train`.
        The trained model.
    data: DaskDMatrix
        Input data used for prediction.
    *args:
        Extra positional arguments forwarded to `Booster.predict`.

    Returns
    -------
    prediction: dask.array.Array

    Raises
    ------
    TypeError
        If `model` is neither a Booster nor a dict, or `data` is not a
        DaskDMatrix.

    '''
    _assert_dask_installed()
    if isinstance(model, Booster):
        booster = model
    elif isinstance(model, dict):
        booster = model['booster']
    else:
        raise TypeError(_expect([Booster, dict], type(model)))

    if not isinstance(data, DaskDMatrix):
        raise TypeError(_expect([DaskDMatrix], type(data)))

    worker_map = data.worker_map
    client = _xgb_get_client(client)

    rabit_args = _get_rabit_args(worker_map, client)
    worker_addresses = list(worker_map.keys())

    def dispatched_predict(worker_id):
        '''Perform prediction on each worker.'''
        logging.info('Predicting on %d', worker_id)
        worker = distributed_get_worker()
        local_x = data.get_worker_data(worker)

        with RabitContext(rabit_args):
            # Positional `local_x`: `data=local_x` together with a non-empty
            # `*args` raises "multiple values for argument 'data'".
            local_predictions = booster.predict(local_x, *args)
        return local_predictions

    futures = client.map(dispatched_predict,
                         range(len(worker_map)),
                         workers=worker_addresses)

    def dispatched_get_shape(worker_id):
        '''Get shape of data in each worker.'''
        logging.info('Trying to get data shape on %d', worker_id)
        worker = distributed_get_worker()
        rows, cols = data.get_worker_data_shape(worker)
        return rows, cols

    # Constructing a dask array from list of numpy arrays
    # See https://docs.dask.org/en/latest/array-creation.html
    futures_shape = client.map(dispatched_get_shape,
                               range(len(worker_map)),
                               workers=worker_addresses)
    shapes = client.gather(futures_shape)
    # Predictions hold one value per input row, so only the row count of the
    # input shape is used; passing the full (rows, cols) input shape declared
    # a wrong array shape to dask.
    # NOTE(review): objectives returning one column per class would need a
    # 2-D shape here -- confirm before relying on multi-class output.
    arrays = [
        da.from_delayed(future, shape=(shape[0],), dtype=numpy.float32)
        for future, shape in zip(futures, shapes)
    ]
    predictions = da.concatenate(arrays, axis=0)
    return predictions
|
||||||
|
|
||||||
|
|
||||||
|
def _evaluation_matrices(client, validation_set, sample_weights):
|
||||||
|
'''
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
validation_set: list of tuples
|
||||||
|
Each tuple contains a validation dataset including input X and label y.
|
||||||
|
E.g.:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
[(X_0, y_0), (X_1, y_1), ... ]
|
||||||
|
|
||||||
|
sample_weights: list of arrays
|
||||||
|
The weight vector for validation data.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
evals: list of validation DMatrix
|
||||||
|
'''
|
||||||
|
evals = []
|
||||||
|
if validation_set is not None:
|
||||||
|
assert isinstance(validation_set, list)
|
||||||
|
for i, e in enumerate(validation_set):
|
||||||
|
w = (sample_weights[i]
|
||||||
|
if sample_weights is not None else None)
|
||||||
|
dmat = DaskDMatrix(client=client, data=e[0], label=e[1], weight=w)
|
||||||
|
evals.append((dmat, 'validation_{}'.format(i)))
|
||||||
|
else:
|
||||||
|
evals = None
|
||||||
|
return evals
|
||||||
|
|
||||||
|
|
||||||
|
class DaskScikitLearnBase(XGBModel):
    '''Base class for implementing scikit-learn interface with Dask'''

    _client = None

    # pylint: disable=arguments-differ
    def fit(self, X, y, sample_weights=None, eval_set=None,
            sample_weight_eval_set=None):
        '''Fit the model; concrete subclasses must override this.

        Parameters
        ----------
        X : array_like
            Feature matrix
        y : array_like
            Labels
        sample_weights : array_like
            instance weights
        eval_set : list, optional
            A list of (X, y) tuple pairs to use as validation sets, for which
            metrics will be computed.
            Validation metrics will help us track the performance of the model.
        sample_weight_eval_set : list, optional
            A list of the form [L_1, L_2, ..., L_n], where each L_i is a list
            of group weights on the i-th validation set.'''
        raise NotImplementedError

    def predict(self, data):  # pylint: disable=arguments-differ
        '''Predict with `data`; concrete subclasses must override this.

        Parameters
        ----------
          data: data that can be used to construct a DaskDMatrix
        Returns
        -------
        prediction : dask.array.Array'''
        raise NotImplementedError

    @property
    def client(self):
        '''The dask client used in this model.'''
        # Falls back to dask's default client while none has been assigned.
        return _xgb_get_client(self._client)

    @client.setter
    def client(self, clt):
        self._client = clt
|
||||||
|
|
||||||
|
|
||||||
|
class DaskXGBRegressor(DaskScikitLearnBase):
    # pylint: disable=missing-docstring
    # The class docstring reuses XGBModel's parameter documentation and
    # replaces only its first two lines.
    __doc__ = ('Implementation of the scikit-learn API for XGBoost ' +
               'regression. \n\n') + '\n'.join(
                   XGBModel.__doc__.split('\n')[2:])

    def fit(self, X, y, sample_weights=None, eval_set=None,
            sample_weight_eval_set=None):
        _assert_dask_installed()
        training_matrix = DaskDMatrix(client=self.client,
                                      data=X, label=y, weight=sample_weights)
        xgb_params = self.get_xgb_params()
        validation = _evaluation_matrices(self.client,
                                          eval_set, sample_weight_eval_set)

        output = train(self.client, xgb_params, training_matrix,
                       num_boost_round=self.get_num_boosting_rounds(),
                       evals=validation)
        # Keep the trained booster and the evaluation history on the model.
        self._Booster = output['booster']
        # pylint: disable=attribute-defined-outside-init
        self.evals_result_ = output['history']
        return self

    def predict(self, data):  # pylint: disable=arguments-differ
        _assert_dask_installed()
        # Wrap the raw dask collection before dispatching to the workers.
        wrapped = DaskDMatrix(client=self.client, data=data)
        return predict(client=self.client,
                       model=self.get_booster(), data=wrapped)
|
||||||
|
|
||||||
|
|
||||||
|
class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
    # pylint: disable=missing-docstring
    _client = None
    # The class docstring reuses XGBModel's parameter documentation and
    # replaces only its first two lines.
    __doc__ = ('Implementation of the scikit-learn API for XGBoost ' +
               'classification.\n\n') + '\n'.join(
                   XGBModel.__doc__.split('\n')[2:])

    def fit(self, X, y, sample_weights=None, eval_set=None,
            sample_weight_eval_set=None):
        _assert_dask_installed()
        training_matrix = DaskDMatrix(client=self.client,
                                      data=X, label=y, weight=sample_weights)
        xgb_params = self.get_xgb_params()

        # The distinct labels are computed eagerly to pick the objective.
        # pylint: disable=attribute-defined-outside-init
        self.classes_ = da.unique(y).compute()
        self.n_classes_ = len(self.classes_)

        multiclass = self.n_classes_ > 2
        if multiclass:
            xgb_params["objective"] = "multi:softprob"
            xgb_params['num_class'] = self.n_classes_
        else:
            xgb_params["objective"] = "binary:logistic"
            # Keep a caller-supplied num_class when one is already present.
            xgb_params.setdefault('num_class', self.n_classes_)

        validation = _evaluation_matrices(self.client,
                                          eval_set, sample_weight_eval_set)
        output = train(self.client, xgb_params, training_matrix,
                       num_boost_round=self.get_num_boosting_rounds(),
                       evals=validation)
        self._Booster = output['booster']
        # pylint: disable=attribute-defined-outside-init
        self.evals_result_ = output['history']
        return self

    def predict(self, data):  # pylint: disable=arguments-differ
        _assert_dask_installed()
        # Wrap the raw dask collection before dispatching to the workers.
        wrapped = DaskDMatrix(client=self.client, data=data)
        return predict(client=self.client,
                       model=self.get_booster(), data=wrapped)
|
||||||
|
|||||||
@ -1,8 +1,6 @@
|
|||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
# pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme, E0012, R0912, C0302
|
# pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme, E0012, R0912, C0302
|
||||||
"""Scikit-Learn Wrapper interface for XGBoost."""
|
"""Scikit-Learn Wrapper interface for XGBoost."""
|
||||||
from __future__ import absolute_import
|
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
import json
|
import json
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -282,7 +280,8 @@ class XGBModel(XGBModelBase):
|
|||||||
"object {} will be lost. ".format(type(self).__name__) +
|
"object {} will be lost. ".format(type(self).__name__) +
|
||||||
"If you did not mean to export the model to " +
|
"If you did not mean to export the model to " +
|
||||||
"a non-Python binding of XGBoost, consider " +
|
"a non-Python binding of XGBoost, consider " +
|
||||||
"using `pickle` or `joblib` to save your model.", Warning)
|
"using `pickle` or `joblib` to save your model.",
|
||||||
|
Warning)
|
||||||
self.get_booster().save_model(fname)
|
self.get_booster().save_model(fname)
|
||||||
|
|
||||||
def load_model(self, fname):
|
def load_model(self, fname):
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
import testing as tm
|
import testing as tm
|
||||||
import pytest
|
import pytest
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
import numpy as np
|
|
||||||
import sys
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
if sys.platform.startswith("win"):
|
if sys.platform.startswith("win"):
|
||||||
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
|
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
|
||||||
@ -11,83 +11,86 @@ try:
|
|||||||
from distributed.utils_test import client, loop, cluster_fixture
|
from distributed.utils_test import client, loop, cluster_fixture
|
||||||
import dask.dataframe as dd
|
import dask.dataframe as dd
|
||||||
import dask.array as da
|
import dask.array as da
|
||||||
|
from xgboost.dask import DaskDMatrix
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
pytestmark = pytest.mark.skipif(**tm.no_dask())
|
pytestmark = pytest.mark.skipif(**tm.no_dask())
|
||||||
|
|
||||||
|
kRows = 1000
|
||||||
def run_train():
|
|
||||||
# Contains one label equal to rank
|
|
||||||
dmat = xgb.DMatrix(np.array([[0]]), label=[xgb.rabit.get_rank()])
|
|
||||||
bst = xgb.train({"eta": 1.0, "lambda": 0.0}, dmat, 1)
|
|
||||||
pred = bst.predict(dmat)
|
|
||||||
expected_result = np.average(range(xgb.rabit.get_world_size()))
|
|
||||||
assert all(p == expected_result for p in pred)
|
|
||||||
|
|
||||||
|
|
||||||
def test_train(client):
|
def generate_array():
|
||||||
# Train two workers, the first has label 0, the second has label 1
|
|
||||||
# If they build the model together the output should be 0.5
|
|
||||||
xgb.dask.run(client, run_train)
|
|
||||||
# Run again to check we can have multiple sessions
|
|
||||||
xgb.dask.run(client, run_train)
|
|
||||||
|
|
||||||
|
|
||||||
def run_create_dmatrix(X, y, weights):
|
|
||||||
dmat = xgb.dask.create_worker_dmatrix(X, y, weight=weights)
|
|
||||||
# Expect this worker to get two partitions and concatenate them
|
|
||||||
assert dmat.num_row() == 50
|
|
||||||
|
|
||||||
|
|
||||||
def test_dask_dataframe(client):
|
|
||||||
n = 10
|
n = 10
|
||||||
m = 100
|
partition_size = 20
|
||||||
partition_size = 25
|
X = da.random.random((kRows, n), partition_size)
|
||||||
X = dd.from_array(np.random.random((m, n)), partition_size)
|
y = da.random.random(kRows, partition_size)
|
||||||
y = dd.from_array(np.random.random(m), partition_size)
|
return X, y
|
||||||
weights = dd.from_array(np.random.random(m), partition_size)
|
|
||||||
xgb.dask.run(client, run_create_dmatrix, X, y, weights)
|
|
||||||
|
|
||||||
|
|
||||||
def test_dask_array(client):
|
def test_from_dask_dataframe(client):
|
||||||
n = 10
|
X, y = generate_array()
|
||||||
m = 100
|
|
||||||
partition_size = 25
|
X = dd.from_dask_array(X)
|
||||||
X = da.random.random((m, n), partition_size)
|
y = dd.from_dask_array(y)
|
||||||
y = da.random.random(m, partition_size)
|
|
||||||
weights = da.random.random(m, partition_size)
|
dtrain = DaskDMatrix(client, X, y)
|
||||||
xgb.dask.run(client, run_create_dmatrix, X, y, weights)
|
booster = xgb.dask.train(
|
||||||
|
client, {}, dtrain, num_boost_round=2)['booster']
|
||||||
|
|
||||||
|
prediction = xgb.dask.predict(client, model=booster, data=dtrain)
|
||||||
|
|
||||||
|
assert isinstance(prediction, da.Array)
|
||||||
|
assert prediction.shape[0] == kRows, prediction
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
# evals_result is not supported in dask interface.
|
||||||
|
xgb.dask.train(
|
||||||
|
client, {}, dtrain, num_boost_round=2, evals_result={})
|
||||||
|
|
||||||
|
|
||||||
def run_get_local_data(X, y):
|
def test_from_dask_array(client):
|
||||||
X_local = xgb.dask.get_local_data(X)
|
X, y = generate_array()
|
||||||
y_local = xgb.dask.get_local_data(y)
|
dtrain = DaskDMatrix(client, X, y)
|
||||||
assert (X_local.shape == (50, 10))
|
# results is {'booster': Booster, 'history': {...}}
|
||||||
assert (y_local.shape == (50,))
|
result = xgb.dask.train(client, {}, dtrain)
|
||||||
|
|
||||||
|
prediction = xgb.dask.predict(client, result, dtrain)
|
||||||
|
|
||||||
|
assert isinstance(prediction, da.Array)
|
||||||
|
|
||||||
|
|
||||||
def test_get_local_data(client):
|
def test_regressor(client):
|
||||||
n = 10
|
X, y = generate_array()
|
||||||
m = 100
|
regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
|
||||||
partition_size = 25
|
regressor.set_params(tree_method='hist')
|
||||||
X = da.random.random((m, n), partition_size)
|
regressor.client = client
|
||||||
y = da.random.random(m, partition_size)
|
regressor.fit(X, y, eval_set=[(X, y)])
|
||||||
xgb.dask.run(client, run_get_local_data, X, y)
|
prediction = regressor.predict(X)
|
||||||
|
|
||||||
|
history = regressor.evals_result()
|
||||||
|
|
||||||
|
assert isinstance(prediction, da.Array)
|
||||||
|
assert isinstance(history, dict)
|
||||||
|
|
||||||
|
assert list(history['validation_0'].keys())[0] == 'rmse'
|
||||||
|
assert len(history['validation_0']['rmse']) == 2
|
||||||
|
|
||||||
|
|
||||||
def run_sklearn():
|
def test_classifier(client):
|
||||||
# Contains one label equal to rank
|
X, y = generate_array()
|
||||||
X = np.array([[0]])
|
y = (y * 10).astype(np.int32)
|
||||||
y = [xgb.rabit.get_rank()]
|
classifier = xgb.dask.DaskXGBClassifier(verbosity=1, n_estimators=2)
|
||||||
model = xgb.XGBRegressor(learning_rate=1.0)
|
classifier.client = client
|
||||||
model.fit(X, y)
|
classifier.fit(X, y, eval_set=[(X, y)])
|
||||||
pred = model.predict(X)
|
prediction = classifier.predict(X)
|
||||||
expected_result = np.average(range(xgb.rabit.get_world_size()))
|
|
||||||
assert all(p == expected_result for p in pred)
|
|
||||||
return pred
|
|
||||||
|
|
||||||
|
history = classifier.evals_result()
|
||||||
|
|
||||||
def test_sklearn(client):
|
assert isinstance(prediction, da.Array)
|
||||||
result = xgb.dask.run(client, run_sklearn)
|
assert isinstance(history, dict)
|
||||||
print(result)
|
|
||||||
|
assert list(history.keys())[0] == 'validation_0'
|
||||||
|
assert list(history['validation_0'].keys())[0] == 'merror'
|
||||||
|
assert len(list(history['validation_0'])) == 1
|
||||||
|
assert len(history['validation_0']['merror']) == 2
|
||||||
|
|||||||
@ -18,6 +18,7 @@ if [ ${TASK} == "python_test" ]; then
|
|||||||
conda install numpy scipy pandas matplotlib scikit-learn
|
conda install numpy scipy pandas matplotlib scikit-learn
|
||||||
|
|
||||||
python -m pip install graphviz pytest pytest-cov codecov
|
python -m pip install graphviz pytest pytest-cov codecov
|
||||||
|
python -m pip install dask distributed dask[dataframe]
|
||||||
python -m pip install https://h2o-release.s3.amazonaws.com/datatable/stable/datatable-0.7.0/datatable-0.7.0-cp37-cp37m-linux_x86_64.whl
|
python -m pip install https://h2o-release.s3.amazonaws.com/datatable/stable/datatable-0.7.0/datatable-0.7.0-cp37-cp37m-linux_x86_64.whl
|
||||||
python -m pytest -v --fulltrace -s tests/python --cov=python-package/xgboost || exit -1
|
python -m pytest -v --fulltrace -s tests/python --cov=python-package/xgboost || exit -1
|
||||||
codecov
|
codecov
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user