[dask] Honor nthreads from dask worker. (#5414)
This commit is contained in:
parent
21b671aa06
commit
761a5dbdfc
@ -22,7 +22,6 @@ def main(client):
|
||||
# evaluation metrics.
|
||||
output = xgb.dask.train(client,
|
||||
{'verbosity': 1,
|
||||
'nthread': 1,
|
||||
'tree_method': 'hist'},
|
||||
dtrain,
|
||||
num_boost_round=4, evals=[(dtrain, 'train')])
|
||||
@ -37,6 +36,6 @@ def main(client):
|
||||
|
||||
if __name__ == '__main__':
|
||||
# or use other clusters for scaling
|
||||
with LocalCluster(n_workers=7, threads_per_worker=1) as cluster:
|
||||
with LocalCluster(n_workers=7, threads_per_worker=4) as cluster:
|
||||
with Client(cluster) as client:
|
||||
main(client)
|
||||
|
||||
@ -22,7 +22,6 @@ def main(client):
|
||||
# evaluation metrics.
|
||||
output = xgb.dask.train(client,
|
||||
{'verbosity': 2,
|
||||
'nthread': 1,
|
||||
# Golden line for GPU training
|
||||
'tree_method': 'gpu_hist'},
|
||||
dtrain,
|
||||
@ -41,6 +40,6 @@ if __name__ == '__main__':
|
||||
# `LocalCUDACluster` is used for assigning GPU to XGBoost processes. Here
|
||||
# `n_workers` represents the number of GPUs since we use one GPU per worker
|
||||
# process.
|
||||
with LocalCUDACluster(n_workers=2, threads_per_worker=1) as cluster:
|
||||
with LocalCUDACluster(n_workers=2, threads_per_worker=4) as cluster:
|
||||
with Client(cluster) as client:
|
||||
main(client)
|
||||
|
||||
@ -37,7 +37,6 @@ illustrates the basic usage:
|
||||
|
||||
output = xgb.dask.train(client,
|
||||
{'verbosity': 2,
|
||||
'nthread': 1,
|
||||
'tree_method': 'hist'},
|
||||
dtrain,
|
||||
num_boost_round=4, evals=[(dtrain, 'train')])
|
||||
@ -76,6 +75,32 @@ Another set of API is a Scikit-Learn wrapper, which mimics the stateful Scikit-L
|
||||
interface with ``DaskXGBClassifier`` and ``DaskXGBRegressor``. See ``xgboost/demo/dask``
|
||||
for more examples.
|
||||
|
||||
*******
|
||||
Threads
|
||||
*******
|
||||
|
||||
XGBoost has built in support for parallel computation through threads by the setting
|
||||
``nthread`` parameter (``n_jobs`` for scikit-learn). If these parameters are set, they
|
||||
will override the configuration in Dask. For example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
with LocalCluster(n_workers=7, threads_per_worker=4) as cluster:
|
||||
|
||||
There are 4 threads allocated for each dask worker. Then by default XGBoost will use 4
|
||||
threads in each process for both training and prediction. But if ``nthread`` parameter is
|
||||
set:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
output = xgb.dask.train(client,
|
||||
{'verbosity': 1,
|
||||
'nthread': 8,
|
||||
'tree_method': 'hist'},
|
||||
dtrain,
|
||||
num_boost_round=4, evals=[(dtrain, 'train')])
|
||||
|
||||
XGBoost will use 8 threads in each training process.
|
||||
|
||||
*****************************************************************************
|
||||
Why is the initialization of ``DaskDMatrix`` so slow and throws weird errors
|
||||
|
||||
@ -42,6 +42,9 @@ from .sklearn import XGBModel, XGBClassifierBase, xgboost_model_doc
|
||||
# - Ranking
|
||||
|
||||
|
||||
LOGGER = logging.getLogger('[xgboost.dask]')
|
||||
|
||||
|
||||
def _start_tracker(host, n_workers):
|
||||
"""Start Rabit tracker """
|
||||
env = {'DMLC_NUM_WORKER': n_workers}
|
||||
@ -62,7 +65,7 @@ def _assert_dask_support():
|
||||
if platform.system() == 'Windows':
|
||||
msg = 'Windows is not officially supported for dask/xgboost,'
|
||||
msg += ' contribution are welcomed.'
|
||||
logging.warning(msg)
|
||||
LOGGER.warning(msg)
|
||||
|
||||
|
||||
class RabitContext:
|
||||
@ -75,11 +78,11 @@ class RabitContext:
|
||||
|
||||
def __enter__(self):
|
||||
rabit.init(self.args)
|
||||
logging.debug('-------------- rabit say hello ------------------')
|
||||
LOGGER.debug('-------------- rabit say hello ------------------')
|
||||
|
||||
def __exit__(self, *args):
|
||||
rabit.finalize()
|
||||
logging.debug('--------------- rabit say bye ------------------')
|
||||
LOGGER.debug('--------------- rabit say bye ------------------')
|
||||
|
||||
|
||||
def concat(value):
|
||||
@ -301,7 +304,7 @@ class DaskDMatrix:
|
||||
'All workers associated with this DMatrix: {workers}'.format(
|
||||
address=worker.address,
|
||||
workers=set(self.worker_map.keys()))
|
||||
logging.warning(msg)
|
||||
LOGGER.warning(msg)
|
||||
d = DMatrix(numpy.empty((0, 0)),
|
||||
feature_names=self.feature_names,
|
||||
feature_types=self.feature_types)
|
||||
@ -324,7 +327,8 @@ class DaskDMatrix:
|
||||
weight=weights,
|
||||
missing=self.missing,
|
||||
feature_names=self.feature_names,
|
||||
feature_types=self.feature_types)
|
||||
feature_types=self.feature_types,
|
||||
nthread=worker.nthreads)
|
||||
return dmatrix
|
||||
|
||||
def get_worker_data_shape(self, worker):
|
||||
@ -399,7 +403,7 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
|
||||
|
||||
def dispatched_train(worker_addr):
|
||||
'''Perform training on a single worker.'''
|
||||
logging.info('Training on %s', str(worker_addr))
|
||||
LOGGER.info('Training on %s', str(worker_addr))
|
||||
worker = distributed_get_worker()
|
||||
with RabitContext(rabit_args):
|
||||
local_dtrain = dtrain.get_worker_data(worker)
|
||||
@ -415,6 +419,15 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
|
||||
|
||||
local_history = {}
|
||||
local_param = params.copy() # just to be consistent
|
||||
msg = 'Overriding `nthreads` defined in dask worker.'
|
||||
if 'nthread' in local_param.keys():
|
||||
msg += '`nthread` is specified. ' + msg
|
||||
LOGGER.warning(msg)
|
||||
elif 'n_jobs' in local_param.keys():
|
||||
msg = '`n_jobs` is specified. ' + msg
|
||||
LOGGER.warning(msg)
|
||||
else:
|
||||
local_param['nthread'] = worker.nthreads
|
||||
bst = worker_train(params=local_param,
|
||||
dtrain=local_dtrain,
|
||||
*args,
|
||||
@ -477,15 +490,17 @@ def predict(client, model, data, *args):
|
||||
|
||||
def dispatched_predict(worker_id):
|
||||
'''Perform prediction on each worker.'''
|
||||
logging.info('Predicting on %d', worker_id)
|
||||
LOGGER.info('Predicting on %d', worker_id)
|
||||
worker = distributed_get_worker()
|
||||
list_of_parts = data.get_worker_x_ordered(worker)
|
||||
predictions = []
|
||||
booster.set_param({'nthread': worker.nthreads})
|
||||
for part, order in list_of_parts:
|
||||
local_x = DMatrix(part,
|
||||
feature_names=feature_names,
|
||||
feature_types=feature_types,
|
||||
missing=missing)
|
||||
missing=missing,
|
||||
nthread=worker.nthreads)
|
||||
predt = booster.predict(data=local_x,
|
||||
validate_features=local_x.num_row() != 0,
|
||||
*args)
|
||||
@ -495,7 +510,7 @@ def predict(client, model, data, *args):
|
||||
|
||||
def dispatched_get_shape(worker_id):
|
||||
'''Get shape of data in each worker.'''
|
||||
logging.info('Trying to get data shape on %d', worker_id)
|
||||
LOGGER.info('Trying to get data shape on %d', worker_id)
|
||||
worker = distributed_get_worker()
|
||||
list_of_parts = data.get_worker_x_ordered(worker)
|
||||
shapes = []
|
||||
|
||||
@ -3,6 +3,7 @@ import pytest
|
||||
import xgboost as xgb
|
||||
import sys
|
||||
import numpy as np
|
||||
import json
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
|
||||
@ -60,7 +61,7 @@ def test_from_dask_dataframe():
|
||||
|
||||
|
||||
def test_from_dask_array():
|
||||
with LocalCluster(n_workers=5) as cluster:
|
||||
with LocalCluster(n_workers=5, threads_per_worker=5) as cluster:
|
||||
with Client(cluster) as client:
|
||||
X, y = generate_array()
|
||||
dtrain = DaskDMatrix(client, X, y)
|
||||
@ -74,11 +75,15 @@ def test_from_dask_array():
|
||||
# force prediction to be computed
|
||||
prediction = prediction.compute()
|
||||
|
||||
single_node_predt = result['booster'].predict(
|
||||
booster = result['booster']
|
||||
single_node_predt = booster.predict(
|
||||
xgb.DMatrix(X.compute())
|
||||
)
|
||||
np.testing.assert_allclose(prediction, single_node_predt)
|
||||
|
||||
config = json.loads(booster.save_config())
|
||||
assert int(config['learner']['generic_param']['nthread']) == 5
|
||||
|
||||
|
||||
def test_dask_regressor():
|
||||
with LocalCluster(n_workers=5) as cluster:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user