Add native support for Dask (#4473)

* Add native support for Dask

* Add multi-GPU demo

* Add sklearn example
This commit is contained in:
Rory Mitchell
2019-05-27 13:29:28 +12:00
committed by GitHub
parent 55e645c5f5
commit 09b90d9329
13 changed files with 407 additions and 16 deletions

View File

@@ -3,8 +3,10 @@
from __future__ import absolute_import
import io
import sys
import shutil
import os
from setuptools import setup, find_packages
# import subprocess
sys.path.insert(0, '.')
@@ -27,6 +29,10 @@ for libfile in libpath['find_lib_path']():
continue
print("Install libxgboost from: %s" % LIB_PATH)
# Get dmlc tracker script
shutil.copy('../dmlc-core/tracker/dmlc_tracker/tracker.py', 'xgboost/')
# Please use setup_pip.py for generating and deploying pip installation
# detailed instruction in setup_pip.py
setup(name='xgboost',

View File

@@ -11,6 +11,7 @@ import os
from .core import DMatrix, Booster
from .training import train, cv
from . import rabit # noqa
from . import dask # noqa
try:
from .sklearn import XGBModel, XGBClassifier, XGBRegressor, XGBRanker
from .sklearn import XGBRFClassifier, XGBRFRegressor

View File

@@ -6,49 +6,48 @@ from __future__ import absolute_import
import sys
# True when running under a Python 3 interpreter.
PY3 = sys.version_info[0] == 3

if not PY3:
    # Python 2: both str and unicode derive from basestring.
    STRING_TYPES = (basestring,)  # pylint: disable=undefined-variable

    def py_str(x):
        """Return ``x`` unchanged; no decoding is performed on Python 2."""
        return x
else:
    # pylint: disable=invalid-name, redefined-builtin
    STRING_TYPES = (str,)

    def py_str(x):
        """Decode the UTF-8 encoded bytes ``x`` into a Python string."""
        return x.decode('utf-8')
try:
import cPickle as pickle # noqa
import cPickle as pickle # noqa
except ImportError:
import pickle # noqa
import pickle # noqa
# pandas
try:
from pandas import DataFrame
from pandas import MultiIndex
PANDAS_INSTALLED = True
except ImportError:
# pylint: disable=too-few-public-methods
class MultiIndex(object):
""" dummy for pandas.MultiIndex """
# pylint: disable=too-few-public-methods
class DataFrame(object):
""" dummy for pandas.DataFrame """
MultiIndex = object
DataFrame = object
PANDAS_INSTALLED = False
# dt
try:
# Workaround for #4473, compatibility with dask
if sys.__stdin__.closed:
sys.__stdin__ = None
import datatable
if hasattr(datatable, "Frame"):
DataTable = datatable.Frame
else:
@@ -60,6 +59,7 @@ except ImportError:
class DataTable(object):
""" dummy for datatable.DataTable """
DT_INSTALLED = False
# sklearn
@@ -67,6 +67,7 @@ try:
from sklearn.base import BaseEstimator
from sklearn.base import RegressorMixin, ClassifierMixin
from sklearn.preprocessing import LabelEncoder
try:
from sklearn.model_selection import KFold, StratifiedKFold
except ImportError:
@@ -92,3 +93,20 @@ except ImportError:
XGBKFold = None
XGBStratifiedKFold = None
XGBLabelEncoder = None
# dask
try:
from dask.dataframe import DataFrame as DaskDataFrame
from dask.dataframe import Series as DaskSeries
from dask.array import Array as DaskArray
from distributed import get_worker as distributed_get_worker
DASK_INSTALLED = True
except ImportError:
DaskDataFrame = object
DaskSeries = object
DaskArray = object
distributed_get_worker = None
DASK_INSTALLED = False

View File

@@ -0,0 +1,123 @@
# pylint: disable=wrong-import-position,wrong-import-order,import-error
"""Dask extensions for distributed training. See xgboost/demo/dask for examples."""
import os
import sys
import math
import platform
import logging
from threading import Thread
from . import rabit
from .core import DMatrix
from .compat import (DaskDataFrame, DaskSeries, DaskArray,
distributed_get_worker)
# Locate the dmlc tracker script, which provides RabitTracker.
# For developers working from a source tree, the script is expected in the
# dmlc-core directory two levels above this file, so that directory is
# appended to the module search path before the import is attempted.
TRACKER_PATH = os.path.dirname(__file__) + "/../../dmlc-core/tracker/dmlc_tracker"
sys.path.append(TRACKER_PATH)
try:
    from tracker import RabitTracker  # noqa
except ImportError:
    # If packaged, tracker.py is expected to sit next to this module inside
    # the package, so fall back to a package-relative import.
    from .tracker import RabitTracker  # noqa
def _start_tracker(n_workers):
    """
    Start a Rabit tracker on the dask worker executing this function.

    The tracker is created with the host part of the current worker's address.
    Its ``join`` call runs in a daemon thread so this function returns
    without waiting for it.

    :param n_workers: Number of workers the tracker is started for
    :return: Dict holding 'DMLC_NUM_WORKER' plus the entries reported by the
        tracker's ``slave_envs()``
    """
    address = distributed_get_worker().address
    # Strip a leading scheme (the part up to and including '://') if present.
    if '://' in address:
        address = address.rsplit('://', 1)[1]
    # The address is expected to look like 'host:port'.
    host, port = address.split(':')
    port = int(port)  # not used afterwards; raises if the port is not numeric

    tracker = RabitTracker(hostIP=host, nslave=n_workers)
    env = {'DMLC_NUM_WORKER': n_workers}
    env.update(tracker.slave_envs())
    tracker.start(n_workers)

    # Wait on the tracker in the background; a daemon thread does not keep
    # the process alive on its own.
    joiner = Thread(target=tracker.join)
    joiner.daemon = True
    joiner.start()
    return env
def get_local_data(data):
    """
    Unpacks a distributed data object to get the rows local to this worker

    The partitions of ``data`` are divided into contiguous ranges of
    ``ceil(total_partitions / world_size)`` partitions each. The worker with
    rabit rank ``r`` computes and returns the ``r``-th range.

    :param data: A distributed dask data object
    :return: Local data partition e.g. numpy or pandas
    """
    is_array = isinstance(data, DaskArray)
    if is_array:
        # For arrays the partition count is the number of entries in the
        # first element of ``chunks``.
        total_partitions = len(data.chunks[0])
    else:
        total_partitions = data.npartitions
    # Force true division: with two ints Python 2 would floor the quotient
    # before math.ceil sees it, making partition_size too small and leaving
    # the trailing partitions unassigned to any worker.
    partition_size = int(math.ceil(float(total_partitions) / rabit.get_world_size()))
    begin_partition = partition_size * rabit.get_rank()
    # The last range may be shorter than partition_size.
    end_partition = min(begin_partition + partition_size, total_partitions)
    if is_array:
        return data.blocks[begin_partition:end_partition].compute()
    return data.partitions[begin_partition:end_partition].compute()
def create_worker_dmatrix(*args, **kwargs):
    """
    Build a DMatrix from the data local to the current dask worker.

    Arguments are forwarded to the standard DMatrix constructor. Any
    positional or keyword argument that is a dask dataframe, series or array
    is first replaced by its local portion via ``get_local_data``; all other
    arguments are passed through untouched.

    All dask dataframe arguments must use the same partitioning.

    :param args: DMatrix constructor positional args.
    :param kwargs: DMatrix constructor keyword args.
    :return: DMatrix object containing data local to current dask worker
    """
    dask_types = (DaskDataFrame, DaskSeries, DaskArray)

    def _to_local(value):
        """Swap a dask collection for its local part; pass others through."""
        if isinstance(value, dask_types):
            return get_local_data(value)
        return value

    local_args = [_to_local(arg) for arg in args]
    local_kwargs = {key: _to_local(value) for key, value in kwargs.items()}
    return DMatrix(*local_args, **local_kwargs)
def _run_with_rabit(rabit_args, func, *args):
    """
    Call ``func(*args)`` on the current worker between rabit init and finalize.

    OMP_NUM_THREADS is set in this process's environment to the worker's
    ``ncores`` value before rabit is initialised.

    :param rabit_args: Arguments handed to ``rabit.init``
    :param func: Function to execute
    :param args: Arguments to be forwarded to func
    :return: Whatever ``func`` returns
    """
    worker = distributed_get_worker()
    os.environ["OMP_NUM_THREADS"] = str(worker.ncores)
    try:
        rabit.init(rabit_args)
        return func(*args)
    finally:
        # Runs on both success and failure of the calls above.
        rabit.finalize()
def run(client, func, *args):
    """
    Execute a function on every dask worker with the workers linked by rabit.

    A rabit tracker is started on the first worker listed by the scheduler.
    Its environment is turned into rabit arguments, and ``func`` is then run
    on the workers inside a rabit session, which allows distributed training.
    OMP_NUM_THREADS is defined on each worker according to dask, so calls to
    xgb.train() use the threads allocated by dask by default, unless the user
    overrides the nthread parameter.

    Note: Windows platforms are not officially supported. Contributions are welcome here.

    :param client: Dask client representing the cluster
    :param func: Python function to be executed by each worker. Typically contains xgboost
        training code.
    :param args: Arguments to be forwarded to func
    :return: Dict containing the function return value for each worker
    """
    if platform.system() == 'Windows':
        logging.warning(
            'Windows is not officially supported for dask/xgboost integration. Contributions '
            'welcome.')

    worker_addresses = list(client.scheduler_info()['workers'])
    tracker_worker = worker_addresses[0]
    # Start the tracker on a single worker; the result is indexed by that
    # worker's key to obtain the tracker environment.
    tracker_result = client.run(_start_tracker, len(worker_addresses),
                                workers=[tracker_worker])
    tracker_env = tracker_result[tracker_worker]

    # Encode each environment entry as a 'key=value' byte string for rabit.
    rabit_args = []
    for key, value in tracker_env.items():
        rabit_args.append(('%s=%s' % (key, value)).encode())
    return client.run(_run_with_rabit, rabit_args, func, *args)