# pylint: disable=wrong-import-position,wrong-import-order,import-error
|
|
"""Dask extensions for distributed training. See xgboost/demo/dask for examples."""
|
|
import os
|
|
import sys
|
|
import math
|
|
import platform
|
|
import logging
|
|
from threading import Thread
|
|
from . import rabit
|
|
from .core import DMatrix
|
|
from .compat import (DaskDataFrame, DaskSeries, DaskArray,
|
|
distributed_get_worker)
|
|
|
|
# Try to find the dmlc tracker script
|
|
# For developers it will be the following
|
|
TRACKER_PATH = os.path.dirname(__file__) + "/../../dmlc-core/tracker/dmlc_tracker"
|
|
sys.path.append(TRACKER_PATH)
|
|
try:
|
|
from tracker import RabitTracker # noqa
|
|
except ImportError:
|
|
# If packaged it will be local
|
|
from .tracker import RabitTracker # noqa
|
|
|
|
|
|
def _start_tracker(n_workers):
    """Start a Rabit tracker on this worker.

    The tracker coordinates the rabit ring; its join loop runs on a daemon
    thread so it never blocks worker shutdown.

    :param n_workers: Total number of workers that will connect.
    :return: Dict of DMLC_* environment variables each worker must pass to
        rabit.init in order to join the tracker.
    """
    host = distributed_get_worker().address
    # Strip any scheme prefix, e.g. 'tcp://1.2.3.4:8786' -> '1.2.3.4:8786'
    if '://' in host:
        host = host.rsplit('://', 1)[1]
    # Split on the LAST colon so host strings containing extra colons
    # (e.g. IPv6 literals) do not break unpacking. The port belongs to the
    # dask comm channel and is not needed by the tracker, so it is dropped.
    host = host.rsplit(':', 1)[0]
    env = {'DMLC_NUM_WORKER': n_workers}
    rabit_context = RabitTracker(hostIP=host, nslave=n_workers)
    env.update(rabit_context.slave_envs())

    rabit_context.start(n_workers)
    # Reap the tracker in the background; daemon=True so an abandoned
    # tracker cannot keep the worker process alive.
    thread = Thread(target=rabit_context.join)
    thread.daemon = True
    thread.start()
    return env
|
|
|
|
|
|
def get_local_data(data):
    """
    Unpacks a distributed data object to get the rows local to this worker

    :param data: A distributed dask data object
    :return: Local data partition e.g. numpy or pandas
    """
    is_array = isinstance(data, DaskArray)
    # Chunk count along axis 0 for arrays; partition count for frames/series.
    n_parts = len(data.chunks[0]) if is_array else data.npartitions
    # Divide partitions evenly across the rabit world; the final worker
    # may receive a short slice.
    per_worker = int(math.ceil(n_parts / rabit.get_world_size()))
    start = per_worker * rabit.get_rank()
    stop = min(start + per_worker, n_parts)
    if is_array:
        return data.blocks[start:stop].compute()
    return data.partitions[start:stop].compute()
|
|
|
|
|
|
def create_worker_dmatrix(*args, **kwargs):
    """
    Creates a DMatrix object local to a given worker. Simply forwards arguments onto the standard
    DMatrix constructor, if one of the arguments is a dask dataframe, unpack the data frame to
    get the local components.

    All dask dataframe arguments must use the same partitioning.

    :param args: DMatrix constructor args.
    :return: DMatrix object containing data local to current dask worker
    """
    def _localise(value):
        # Replace distributed dask objects with this worker's partitions;
        # forward everything else to the DMatrix constructor untouched.
        if isinstance(value, (DaskDataFrame, DaskSeries, DaskArray)):
            return get_local_data(value)
        return value

    dmatrix_args = [_localise(arg) for arg in args]
    dmatrix_kwargs = {key: _localise(value) for key, value in kwargs.items()}
    return DMatrix(*dmatrix_args, **dmatrix_kwargs)
|
|
|
|
|
|
def _run_with_rabit(rabit_args, func, *args):
    """Execute ``func`` on this worker inside an initialised rabit context.

    OMP_NUM_THREADS is set from the dask worker's thread allocation, so
    xgboost training uses the threads dask assigned unless the caller
    overrides nthread explicitly.

    :param rabit_args: Encoded DMLC environment entries for rabit.init.
    :param func: Callable executed once rabit is up.
    :param args: Forwarded to ``func``.
    :return: Whatever ``func`` returns.
    """
    worker = distributed_get_worker()
    # Older dask releases expose 'ncores'; newer ones renamed it 'nthreads'.
    try:
        os.environ["OMP_NUM_THREADS"] = str(worker.ncores)
    except AttributeError:
        os.environ["OMP_NUM_THREADS"] = str(worker.nthreads)
    try:
        rabit.init(rabit_args)
        result = func(*args)
    finally:
        # Always tear rabit down, even if init or func failed partway.
        rabit.finalize()
    return result
|
|
|
|
|
|
def run(client, func, *args):
    """Launch arbitrary function on dask workers. Workers are connected by rabit,
    allowing distributed training. The environment variable OMP_NUM_THREADS is
    defined on each worker according to dask - this means that calls to
    xgb.train() will use the threads allocated by dask by default, unless the
    user overrides the nthread parameter.

    Note: Windows platforms are not officially
    supported. Contributions are welcome here.

    :param client: Dask client representing the cluster
    :param func: Python function to be executed by each worker. Typically
        contains xgboost training code.
    :param args: Arguments to be forwarded to func
    :return: Dict containing the function return value for each worker
    """
    if platform.system() == 'Windows':
        # Fix: the original adjacent string literals were missing a space,
        # producing "dask/xgboostintegration." in the log message.
        logging.warning('Windows is not officially supported for dask/xgboost '
                        'integration. Contributions welcome.')
    workers = list(client.scheduler_info()['workers'].keys())
    # Start the tracker on the first worker only; it returns the DMLC_*
    # environment every worker needs to join the rabit ring.
    env = client.run(_start_tracker, len(workers), workers=[workers[0]])
    rabit_args = [('%s=%s' % item).encode() for item in env[workers[0]].items()]
    return client.run(_run_with_rabit, rabit_args, func, *args)
|