diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index a06a25502..541adcd84 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -421,10 +421,14 @@ def train(client, params, dtrain, *args, evals=(), **kwargs): local_history = {} local_param = params.copy() # just to be consistent msg = 'Overriding `nthreads` defined in dask worker.' - if 'nthread' in local_param.keys(): + if 'nthread' in local_param.keys() and \ + local_param['nthread'] is not None and \ + local_param['nthread'] != worker.nthreads: msg += '`nthread` is specified. ' + msg LOGGER.warning(msg) - elif 'n_jobs' in local_param.keys(): + elif 'n_jobs' in local_param.keys() and \ + local_param['n_jobs'] is not None and \ + local_param['n_jobs'] != worker.nthreads: msg = '`n_jobs` is specified. ' + msg LOGGER.warning(msg) else: