[dask] Honor nthreads from dask worker. (#5414)
This commit is contained in:
@@ -42,6 +42,9 @@ from .sklearn import XGBModel, XGBClassifierBase, xgboost_model_doc
|
||||
# - Ranking
|
||||
|
||||
|
||||
LOGGER = logging.getLogger('[xgboost.dask]')
|
||||
|
||||
|
||||
def _start_tracker(host, n_workers):
|
||||
"""Start Rabit tracker """
|
||||
env = {'DMLC_NUM_WORKER': n_workers}
|
||||
@@ -62,7 +65,7 @@ def _assert_dask_support():
|
||||
if platform.system() == 'Windows':
|
||||
msg = 'Windows is not officially supported for dask/xgboost,'
|
||||
msg += ' contribution are welcomed.'
|
||||
logging.warning(msg)
|
||||
LOGGER.warning(msg)
|
||||
|
||||
|
||||
class RabitContext:
|
||||
@@ -75,11 +78,11 @@ class RabitContext:
|
||||
|
||||
def __enter__(self):
|
||||
rabit.init(self.args)
|
||||
logging.debug('-------------- rabit say hello ------------------')
|
||||
LOGGER.debug('-------------- rabit say hello ------------------')
|
||||
|
||||
def __exit__(self, *args):
|
||||
rabit.finalize()
|
||||
logging.debug('--------------- rabit say bye ------------------')
|
||||
LOGGER.debug('--------------- rabit say bye ------------------')
|
||||
|
||||
|
||||
def concat(value):
|
||||
@@ -301,7 +304,7 @@ class DaskDMatrix:
|
||||
'All workers associated with this DMatrix: {workers}'.format(
|
||||
address=worker.address,
|
||||
workers=set(self.worker_map.keys()))
|
||||
logging.warning(msg)
|
||||
LOGGER.warning(msg)
|
||||
d = DMatrix(numpy.empty((0, 0)),
|
||||
feature_names=self.feature_names,
|
||||
feature_types=self.feature_types)
|
||||
@@ -324,7 +327,8 @@ class DaskDMatrix:
|
||||
weight=weights,
|
||||
missing=self.missing,
|
||||
feature_names=self.feature_names,
|
||||
feature_types=self.feature_types)
|
||||
feature_types=self.feature_types,
|
||||
nthread=worker.nthreads)
|
||||
return dmatrix
|
||||
|
||||
def get_worker_data_shape(self, worker):
|
||||
@@ -399,7 +403,7 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
|
||||
|
||||
def dispatched_train(worker_addr):
|
||||
'''Perform training on a single worker.'''
|
||||
logging.info('Training on %s', str(worker_addr))
|
||||
LOGGER.info('Training on %s', str(worker_addr))
|
||||
worker = distributed_get_worker()
|
||||
with RabitContext(rabit_args):
|
||||
local_dtrain = dtrain.get_worker_data(worker)
|
||||
@@ -415,6 +419,15 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
|
||||
|
||||
local_history = {}
|
||||
local_param = params.copy() # just to be consistent
|
||||
# Decide which thread count the local training parameters carry.
# If the user already put `nthread` or `n_jobs` in the parameters, that
# value is left untouched and a warning is logged; otherwise the dask
# worker's `nthreads` is injected as `nthread`.
msg = 'Overriding `nthreads` defined in dask worker.'
if 'nthread' in local_param.keys():
    # Plain assignment (not `+=`): appending here would emit the base
    # sentence twice in the warning.
    msg = '`nthread` is specified. ' + msg
    LOGGER.warning(msg)
elif 'n_jobs' in local_param.keys():
    msg = '`n_jobs` is specified. ' + msg
    LOGGER.warning(msg)
else:
    local_param['nthread'] = worker.nthreads
|
||||
bst = worker_train(params=local_param,
|
||||
dtrain=local_dtrain,
|
||||
*args,
|
||||
@@ -477,15 +490,17 @@ def predict(client, model, data, *args):
|
||||
|
||||
def dispatched_predict(worker_id):
|
||||
'''Perform prediction on each worker.'''
|
||||
logging.info('Predicting on %d', worker_id)
|
||||
LOGGER.info('Predicting on %d', worker_id)
|
||||
worker = distributed_get_worker()
|
||||
list_of_parts = data.get_worker_x_ordered(worker)
|
||||
predictions = []
|
||||
booster.set_param({'nthread': worker.nthreads})
|
||||
for part, order in list_of_parts:
|
||||
local_x = DMatrix(part,
|
||||
feature_names=feature_names,
|
||||
feature_types=feature_types,
|
||||
missing=missing)
|
||||
missing=missing,
|
||||
nthread=worker.nthreads)
|
||||
predt = booster.predict(data=local_x,
|
||||
validate_features=local_x.num_row() != 0,
|
||||
*args)
|
||||
@@ -495,7 +510,7 @@ def predict(client, model, data, *args):
|
||||
|
||||
def dispatched_get_shape(worker_id):
|
||||
'''Get shape of data in each worker.'''
|
||||
logging.info('Trying to get data shape on %d', worker_id)
|
||||
LOGGER.info('Trying to get data shape on %d', worker_id)
|
||||
worker = distributed_get_worker()
|
||||
list_of_parts = data.get_worker_x_ordered(worker)
|
||||
shapes = []
|
||||
|
||||
Reference in New Issue
Block a user