[dask] Honor nthreads from dask worker. (#5414)
This commit is contained in:
parent
21b671aa06
commit
761a5dbdfc
@ -22,7 +22,6 @@ def main(client):
|
|||||||
# evaluation metrics.
|
# evaluation metrics.
|
||||||
output = xgb.dask.train(client,
|
output = xgb.dask.train(client,
|
||||||
{'verbosity': 1,
|
{'verbosity': 1,
|
||||||
'nthread': 1,
|
|
||||||
'tree_method': 'hist'},
|
'tree_method': 'hist'},
|
||||||
dtrain,
|
dtrain,
|
||||||
num_boost_round=4, evals=[(dtrain, 'train')])
|
num_boost_round=4, evals=[(dtrain, 'train')])
|
||||||
@ -37,6 +36,6 @@ def main(client):
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# or use other clusters for scaling
|
# or use other clusters for scaling
|
||||||
with LocalCluster(n_workers=7, threads_per_worker=1) as cluster:
|
with LocalCluster(n_workers=7, threads_per_worker=4) as cluster:
|
||||||
with Client(cluster) as client:
|
with Client(cluster) as client:
|
||||||
main(client)
|
main(client)
|
||||||
|
|||||||
@ -22,7 +22,6 @@ def main(client):
|
|||||||
# evaluation metrics.
|
# evaluation metrics.
|
||||||
output = xgb.dask.train(client,
|
output = xgb.dask.train(client,
|
||||||
{'verbosity': 2,
|
{'verbosity': 2,
|
||||||
'nthread': 1,
|
|
||||||
# Golden line for GPU training
|
# Golden line for GPU training
|
||||||
'tree_method': 'gpu_hist'},
|
'tree_method': 'gpu_hist'},
|
||||||
dtrain,
|
dtrain,
|
||||||
@ -41,6 +40,6 @@ if __name__ == '__main__':
|
|||||||
# `LocalCUDACluster` is used for assigning GPU to XGBoost processes. Here
|
# `LocalCUDACluster` is used for assigning GPU to XGBoost processes. Here
|
||||||
# `n_workers` represents the number of GPUs since we use one GPU per worker
|
# `n_workers` represents the number of GPUs since we use one GPU per worker
|
||||||
# process.
|
# process.
|
||||||
with LocalCUDACluster(n_workers=2, threads_per_worker=1) as cluster:
|
with LocalCUDACluster(n_workers=2, threads_per_worker=4) as cluster:
|
||||||
with Client(cluster) as client:
|
with Client(cluster) as client:
|
||||||
main(client)
|
main(client)
|
||||||
|
|||||||
@ -37,7 +37,6 @@ illustrates the basic usage:
|
|||||||
|
|
||||||
output = xgb.dask.train(client,
|
output = xgb.dask.train(client,
|
||||||
{'verbosity': 2,
|
{'verbosity': 2,
|
||||||
'nthread': 1,
|
|
||||||
'tree_method': 'hist'},
|
'tree_method': 'hist'},
|
||||||
dtrain,
|
dtrain,
|
||||||
num_boost_round=4, evals=[(dtrain, 'train')])
|
num_boost_round=4, evals=[(dtrain, 'train')])
|
||||||
@ -76,6 +75,32 @@ Another set of API is a Scikit-Learn wrapper, which mimics the stateful Scikit-L
|
|||||||
interface with ``DaskXGBClassifier`` and ``DaskXGBRegressor``. See ``xgboost/demo/dask``
|
interface with ``DaskXGBClassifier`` and ``DaskXGBRegressor``. See ``xgboost/demo/dask``
|
||||||
for more examples.
|
for more examples.
|
||||||
|
|
||||||
|
*******
|
||||||
|
Threads
|
||||||
|
*******
|
||||||
|
|
||||||
|
XGBoost has built in support for parallel computation through threads by the setting
|
||||||
|
``nthread`` parameter (``n_jobs`` for scikit-learn). If these parameters are set, they
|
||||||
|
will override the configuration in Dask. For example:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
with LocalCluster(n_workers=7, threads_per_worker=4) as cluster:
|
||||||
|
|
||||||
|
There are 4 threads allocated for each dask worker. Then by default XGBoost will use 4
|
||||||
|
threads in each process for both training and prediction. But if ``nthread`` parameter is
|
||||||
|
set:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
output = xgb.dask.train(client,
|
||||||
|
{'verbosity': 1,
|
||||||
|
'nthread': 8,
|
||||||
|
'tree_method': 'hist'},
|
||||||
|
dtrain,
|
||||||
|
num_boost_round=4, evals=[(dtrain, 'train')])
|
||||||
|
|
||||||
|
XGBoost will use 8 threads in each training process.
|
||||||
|
|
||||||
*****************************************************************************
|
*****************************************************************************
|
||||||
Why is the initialization of ``DaskDMatrix`` so slow and throws weird errors
|
Why is the initialization of ``DaskDMatrix`` so slow and throws weird errors
|
||||||
|
|||||||
@ -42,6 +42,9 @@ from .sklearn import XGBModel, XGBClassifierBase, xgboost_model_doc
|
|||||||
# - Ranking
|
# - Ranking
|
||||||
|
|
||||||
|
|
||||||
|
LOGGER = logging.getLogger('[xgboost.dask]')
|
||||||
|
|
||||||
|
|
||||||
def _start_tracker(host, n_workers):
|
def _start_tracker(host, n_workers):
|
||||||
"""Start Rabit tracker """
|
"""Start Rabit tracker """
|
||||||
env = {'DMLC_NUM_WORKER': n_workers}
|
env = {'DMLC_NUM_WORKER': n_workers}
|
||||||
@ -62,7 +65,7 @@ def _assert_dask_support():
|
|||||||
if platform.system() == 'Windows':
|
if platform.system() == 'Windows':
|
||||||
msg = 'Windows is not officially supported for dask/xgboost,'
|
msg = 'Windows is not officially supported for dask/xgboost,'
|
||||||
msg += ' contribution are welcomed.'
|
msg += ' contribution are welcomed.'
|
||||||
logging.warning(msg)
|
LOGGER.warning(msg)
|
||||||
|
|
||||||
|
|
||||||
class RabitContext:
|
class RabitContext:
|
||||||
@ -75,11 +78,11 @@ class RabitContext:
|
|||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
rabit.init(self.args)
|
rabit.init(self.args)
|
||||||
logging.debug('-------------- rabit say hello ------------------')
|
LOGGER.debug('-------------- rabit say hello ------------------')
|
||||||
|
|
||||||
def __exit__(self, *args):
|
def __exit__(self, *args):
|
||||||
rabit.finalize()
|
rabit.finalize()
|
||||||
logging.debug('--------------- rabit say bye ------------------')
|
LOGGER.debug('--------------- rabit say bye ------------------')
|
||||||
|
|
||||||
|
|
||||||
def concat(value):
|
def concat(value):
|
||||||
@ -301,7 +304,7 @@ class DaskDMatrix:
|
|||||||
'All workers associated with this DMatrix: {workers}'.format(
|
'All workers associated with this DMatrix: {workers}'.format(
|
||||||
address=worker.address,
|
address=worker.address,
|
||||||
workers=set(self.worker_map.keys()))
|
workers=set(self.worker_map.keys()))
|
||||||
logging.warning(msg)
|
LOGGER.warning(msg)
|
||||||
d = DMatrix(numpy.empty((0, 0)),
|
d = DMatrix(numpy.empty((0, 0)),
|
||||||
feature_names=self.feature_names,
|
feature_names=self.feature_names,
|
||||||
feature_types=self.feature_types)
|
feature_types=self.feature_types)
|
||||||
@ -324,7 +327,8 @@ class DaskDMatrix:
|
|||||||
weight=weights,
|
weight=weights,
|
||||||
missing=self.missing,
|
missing=self.missing,
|
||||||
feature_names=self.feature_names,
|
feature_names=self.feature_names,
|
||||||
feature_types=self.feature_types)
|
feature_types=self.feature_types,
|
||||||
|
nthread=worker.nthreads)
|
||||||
return dmatrix
|
return dmatrix
|
||||||
|
|
||||||
def get_worker_data_shape(self, worker):
|
def get_worker_data_shape(self, worker):
|
||||||
@ -399,7 +403,7 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
|
|||||||
|
|
||||||
def dispatched_train(worker_addr):
|
def dispatched_train(worker_addr):
|
||||||
'''Perform training on a single worker.'''
|
'''Perform training on a single worker.'''
|
||||||
logging.info('Training on %s', str(worker_addr))
|
LOGGER.info('Training on %s', str(worker_addr))
|
||||||
worker = distributed_get_worker()
|
worker = distributed_get_worker()
|
||||||
with RabitContext(rabit_args):
|
with RabitContext(rabit_args):
|
||||||
local_dtrain = dtrain.get_worker_data(worker)
|
local_dtrain = dtrain.get_worker_data(worker)
|
||||||
@ -415,6 +419,15 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
|
|||||||
|
|
||||||
local_history = {}
|
local_history = {}
|
||||||
local_param = params.copy() # just to be consistent
|
local_param = params.copy() # just to be consistent
|
||||||
|
msg = 'Overriding `nthreads` defined in dask worker.'
|
||||||
|
if 'nthread' in local_param.keys():
|
||||||
|
msg += '`nthread` is specified. ' + msg
|
||||||
|
LOGGER.warning(msg)
|
||||||
|
elif 'n_jobs' in local_param.keys():
|
||||||
|
msg = '`n_jobs` is specified. ' + msg
|
||||||
|
LOGGER.warning(msg)
|
||||||
|
else:
|
||||||
|
local_param['nthread'] = worker.nthreads
|
||||||
bst = worker_train(params=local_param,
|
bst = worker_train(params=local_param,
|
||||||
dtrain=local_dtrain,
|
dtrain=local_dtrain,
|
||||||
*args,
|
*args,
|
||||||
@ -477,15 +490,17 @@ def predict(client, model, data, *args):
|
|||||||
|
|
||||||
def dispatched_predict(worker_id):
|
def dispatched_predict(worker_id):
|
||||||
'''Perform prediction on each worker.'''
|
'''Perform prediction on each worker.'''
|
||||||
logging.info('Predicting on %d', worker_id)
|
LOGGER.info('Predicting on %d', worker_id)
|
||||||
worker = distributed_get_worker()
|
worker = distributed_get_worker()
|
||||||
list_of_parts = data.get_worker_x_ordered(worker)
|
list_of_parts = data.get_worker_x_ordered(worker)
|
||||||
predictions = []
|
predictions = []
|
||||||
|
booster.set_param({'nthread': worker.nthreads})
|
||||||
for part, order in list_of_parts:
|
for part, order in list_of_parts:
|
||||||
local_x = DMatrix(part,
|
local_x = DMatrix(part,
|
||||||
feature_names=feature_names,
|
feature_names=feature_names,
|
||||||
feature_types=feature_types,
|
feature_types=feature_types,
|
||||||
missing=missing)
|
missing=missing,
|
||||||
|
nthread=worker.nthreads)
|
||||||
predt = booster.predict(data=local_x,
|
predt = booster.predict(data=local_x,
|
||||||
validate_features=local_x.num_row() != 0,
|
validate_features=local_x.num_row() != 0,
|
||||||
*args)
|
*args)
|
||||||
@ -495,7 +510,7 @@ def predict(client, model, data, *args):
|
|||||||
|
|
||||||
def dispatched_get_shape(worker_id):
|
def dispatched_get_shape(worker_id):
|
||||||
'''Get shape of data in each worker.'''
|
'''Get shape of data in each worker.'''
|
||||||
logging.info('Trying to get data shape on %d', worker_id)
|
LOGGER.info('Trying to get data shape on %d', worker_id)
|
||||||
worker = distributed_get_worker()
|
worker = distributed_get_worker()
|
||||||
list_of_parts = data.get_worker_x_ordered(worker)
|
list_of_parts = data.get_worker_x_ordered(worker)
|
||||||
shapes = []
|
shapes = []
|
||||||
|
|||||||
@ -3,6 +3,7 @@ import pytest
|
|||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
import sys
|
import sys
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import json
|
||||||
|
|
||||||
if sys.platform.startswith("win"):
|
if sys.platform.startswith("win"):
|
||||||
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
|
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
|
||||||
@ -60,7 +61,7 @@ def test_from_dask_dataframe():
|
|||||||
|
|
||||||
|
|
||||||
def test_from_dask_array():
|
def test_from_dask_array():
|
||||||
with LocalCluster(n_workers=5) as cluster:
|
with LocalCluster(n_workers=5, threads_per_worker=5) as cluster:
|
||||||
with Client(cluster) as client:
|
with Client(cluster) as client:
|
||||||
X, y = generate_array()
|
X, y = generate_array()
|
||||||
dtrain = DaskDMatrix(client, X, y)
|
dtrain = DaskDMatrix(client, X, y)
|
||||||
@ -74,11 +75,15 @@ def test_from_dask_array():
|
|||||||
# force prediction to be computed
|
# force prediction to be computed
|
||||||
prediction = prediction.compute()
|
prediction = prediction.compute()
|
||||||
|
|
||||||
single_node_predt = result['booster'].predict(
|
booster = result['booster']
|
||||||
|
single_node_predt = booster.predict(
|
||||||
xgb.DMatrix(X.compute())
|
xgb.DMatrix(X.compute())
|
||||||
)
|
)
|
||||||
np.testing.assert_allclose(prediction, single_node_predt)
|
np.testing.assert_allclose(prediction, single_node_predt)
|
||||||
|
|
||||||
|
config = json.loads(booster.save_config())
|
||||||
|
assert int(config['learner']['generic_param']['nthread']) == 5
|
||||||
|
|
||||||
|
|
||||||
def test_dask_regressor():
|
def test_dask_regressor():
|
||||||
with LocalCluster(n_workers=5) as cluster:
|
with LocalCluster(n_workers=5) as cluster:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user