[dask] Honor nthreads from dask worker. (#5414)

Jiaming Yuan 2020-03-16 04:51:24 +08:00 committed by GitHub
parent 21b671aa06
commit 761a5dbdfc
5 changed files with 59 additions and 16 deletions

View File

@@ -22,7 +22,6 @@ def main(client):
# evaluation metrics.
output = xgb.dask.train(client,
{'verbosity': 1,
'nthread': 1,
'tree_method': 'hist'},
dtrain,
num_boost_round=4, evals=[(dtrain, 'train')])
@@ -37,6 +36,6 @@ def main(client):
if __name__ == '__main__':
# or use other clusters for scaling
with LocalCluster(n_workers=7, threads_per_worker=1) as cluster:
with LocalCluster(n_workers=7, threads_per_worker=4) as cluster:
with Client(cluster) as client:
main(client)

View File

@@ -22,7 +22,6 @@ def main(client):
# evaluation metrics.
output = xgb.dask.train(client,
{'verbosity': 2,
'nthread': 1,
# Golden line for GPU training
'tree_method': 'gpu_hist'},
dtrain,
@@ -41,6 +40,6 @@ if __name__ == '__main__':
# `LocalCUDACluster` is used for assigning GPU to XGBoost processes. Here
# `n_workers` represents the number of GPUs since we use one GPU per worker
# process.
with LocalCUDACluster(n_workers=2, threads_per_worker=1) as cluster:
with LocalCUDACluster(n_workers=2, threads_per_worker=4) as cluster:
with Client(cluster) as client:
main(client)

View File

@@ -37,7 +37,6 @@ illustrates the basic usage:
output = xgb.dask.train(client,
{'verbosity': 2,
'nthread': 1,
'tree_method': 'hist'},
dtrain,
num_boost_round=4, evals=[(dtrain, 'train')])
@@ -76,6 +75,32 @@ Another set of API is a Scikit-Learn wrapper, which mimics the stateful Scikit-L
interface with ``DaskXGBClassifier`` and ``DaskXGBRegressor``. See ``xgboost/demo/dask``
for more examples.
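As a rough illustration of the wrapper interface, a minimal sketch might look like the
following (toy data; it assumes the estimator is bound to the active Dask client, here by
assigning to its ``client`` attribute):

.. code-block:: python

    from dask import array as da
    from dask.distributed import Client, LocalCluster
    from xgboost.dask import DaskXGBRegressor

    with LocalCluster(n_workers=2, threads_per_worker=4) as cluster:
        with Client(cluster) as client:
            # Toy data partitioned into dask chunks.
            X = da.random.random((1000, 10), chunks=(100, 10))
            y = da.random.random(1000, chunks=(100,))

            model = DaskXGBRegressor(n_estimators=4, tree_method='hist')
            model.client = client          # assumed way of binding the client
            model.fit(X, y)
            prediction = model.predict(X)  # a lazy dask array
            print(prediction.compute())

Since the wrappers call ``xgb.dask.train`` internally, the thread-count behaviour described
in the next section applies to them as well, with ``n_jobs`` taking the role of ``nthread``.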
*******
Threads
*******
XGBoost has built-in support for parallel computation through threads, controlled by the
``nthread`` parameter (``n_jobs`` for the scikit-learn interface). If this parameter is
set, it will override the thread configuration from Dask. For example:
.. code-block:: python
with LocalCluster(n_workers=7, threads_per_worker=4) as cluster:
Here 4 threads are allocated to each dask worker, so by default XGBoost will use 4
threads in each worker process for both training and prediction. But if the ``nthread``
parameter is set explicitly:
.. code-block:: python
output = xgb.dask.train(client,
{'verbosity': 1,
'nthread': 8,
'tree_method': 'hist'},
dtrain,
num_boost_round=4, evals=[(dtrain, 'train')])
XGBoost will use 8 threads in each training process.
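To confirm which thread count actually took effect, the value can be read back from the
trained booster's configuration; the test added in this commit performs the same check.
A small sketch, assuming ``output`` is the dictionary returned by ``xgb.dask.train`` above:

.. code-block:: python

    import json

    booster = output['booster']
    config = json.loads(booster.save_config())
    # The number of threads this booster ended up using.
    print(config['learner']['generic_param']['nthread'])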
*****************************************************************************
Why is the initialization of ``DaskDMatrix`` so slow and throws weird errors

View File

@@ -42,6 +42,9 @@ from .sklearn import XGBModel, XGBClassifierBase, xgboost_model_doc
# - Ranking
LOGGER = logging.getLogger('[xgboost.dask]')
def _start_tracker(host, n_workers):
"""Start Rabit tracker """
env = {'DMLC_NUM_WORKER': n_workers}
@@ -62,7 +65,7 @@ def _assert_dask_support():
if platform.system() == 'Windows':
msg = 'Windows is not officially supported for dask/xgboost,'
msg += ' contributions are welcome.'
logging.warning(msg)
LOGGER.warning(msg)
class RabitContext:
@@ -75,11 +78,11 @@ class RabitContext:
def __enter__(self):
rabit.init(self.args)
logging.debug('-------------- rabit say hello ------------------')
LOGGER.debug('-------------- rabit say hello ------------------')
def __exit__(self, *args):
rabit.finalize()
logging.debug('--------------- rabit say bye ------------------')
LOGGER.debug('--------------- rabit say bye ------------------')
def concat(value):
@@ -301,7 +304,7 @@ class DaskDMatrix:
'All workers associated with this DMatrix: {workers}'.format(
address=worker.address,
workers=set(self.worker_map.keys()))
logging.warning(msg)
LOGGER.warning(msg)
d = DMatrix(numpy.empty((0, 0)),
feature_names=self.feature_names,
feature_types=self.feature_types)
@@ -324,7 +327,8 @@ class DaskDMatrix:
weight=weights,
missing=self.missing,
feature_names=self.feature_names,
feature_types=self.feature_types)
feature_types=self.feature_types,
nthread=worker.nthreads)
return dmatrix
def get_worker_data_shape(self, worker):
@@ -399,7 +403,7 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
def dispatched_train(worker_addr):
'''Perform training on a single worker.'''
logging.info('Training on %s', str(worker_addr))
LOGGER.info('Training on %s', str(worker_addr))
worker = distributed_get_worker()
with RabitContext(rabit_args):
local_dtrain = dtrain.get_worker_data(worker)
@@ -415,6 +419,15 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
local_history = {}
local_param = params.copy() # just to be consistent
# Respect a user-supplied thread count; otherwise adopt the dask worker's nthreads.
msg = 'Overriding `nthreads` defined in dask worker.'
if 'nthread' in local_param.keys():
    msg = '`nthread` is specified. ' + msg
    LOGGER.warning(msg)
elif 'n_jobs' in local_param.keys():
    msg = '`n_jobs` is specified. ' + msg
    LOGGER.warning(msg)
else:
    local_param['nthread'] = worker.nthreads
bst = worker_train(params=local_param,
dtrain=local_dtrain,
*args,
@@ -477,15 +490,17 @@ def predict(client, model, data, *args):
def dispatched_predict(worker_id):
'''Perform prediction on each worker.'''
logging.info('Predicting on %d', worker_id)
LOGGER.info('Predicting on %d', worker_id)
worker = distributed_get_worker()
list_of_parts = data.get_worker_x_ordered(worker)
predictions = []
booster.set_param({'nthread': worker.nthreads})
for part, order in list_of_parts:
local_x = DMatrix(part,
feature_names=feature_names,
feature_types=feature_types,
missing=missing)
missing=missing,
nthread=worker.nthreads)
predt = booster.predict(data=local_x,
validate_features=local_x.num_row() != 0,
*args)
@@ -495,7 +510,7 @@ def predict(client, model, data, *args):
def dispatched_get_shape(worker_id):
'''Get shape of data in each worker.'''
logging.info('Trying to get data shape on %d', worker_id)
LOGGER.info('Trying to get data shape on %d', worker_id)
worker = distributed_get_worker()
list_of_parts = data.get_worker_x_ordered(worker)
shapes = []

View File

@@ -3,6 +3,7 @@ import pytest
import xgboost as xgb
import sys
import numpy as np
import json
if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
@@ -60,7 +61,7 @@ def test_from_dask_dataframe():
def test_from_dask_array():
with LocalCluster(n_workers=5) as cluster:
with LocalCluster(n_workers=5, threads_per_worker=5) as cluster:
with Client(cluster) as client:
X, y = generate_array()
dtrain = DaskDMatrix(client, X, y)
@@ -74,11 +75,15 @@ def test_from_dask_array():
# force prediction to be computed
prediction = prediction.compute()
single_node_predt = result['booster'].predict(
booster = result['booster']
single_node_predt = booster.predict(
xgb.DMatrix(X.compute())
)
np.testing.assert_allclose(prediction, single_node_predt)
config = json.loads(booster.save_config())
assert int(config['learner']['generic_param']['nthread']) == 5
def test_dask_regressor():
with LocalCluster(n_workers=5) as cluster: