From 8ca06ab329b29f3fc3a92174bcc35ff9896d3a5c Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 17 Mar 2020 13:07:20 +0800 Subject: [PATCH] [dask] Check non-equal when setting threads. (#5421) * Check non-equal. `nthread` can be restored from internal parameter, which is mis-interpreted as user defined parameter. * Check None. --- python-package/xgboost/dask.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index a06a25502..541adcd84 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -421,10 +421,14 @@ def train(client, params, dtrain, *args, evals=(), **kwargs): local_history = {} local_param = params.copy() # just to be consistent msg = 'Overriding `nthreads` defined in dask worker.' - if 'nthread' in local_param.keys(): + if 'nthread' in local_param.keys() and \ + local_param['nthread'] is not None and \ + local_param['nthread'] != worker.nthreads: msg += '`nthread` is specified. ' + msg LOGGER.warning(msg) - elif 'n_jobs' in local_param.keys(): + elif 'n_jobs' in local_param.keys() and \ + local_param['n_jobs'] is not None and \ + local_param['n_jobs'] != worker.nthreads: msg = '`n_jobs` is specified. ' + msg LOGGER.warning(msg) else: