[dask] Check non-equal when setting threads. (#5421)

* Check non-equal.

`nthread` can be restored from internal parameter, which is mis-interpreted as
user defined parameter.

* Check None.
This commit is contained in:
Jiaming Yuan 2020-03-17 13:07:20 +08:00 committed by GitHub
parent b51124c158
commit 8ca06ab329
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -421,10 +421,14 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
local_history = {}
local_param = params.copy() # just to be consistent
msg = 'Overriding `nthreads` defined in dask worker.'
if 'nthread' in local_param.keys():
if 'nthread' in local_param.keys() and \
local_param['nthread'] is not None and \
local_param['nthread'] != worker.nthreads:
msg += '`nthread` is specified. ' + msg
LOGGER.warning(msg)
elif 'n_jobs' in local_param.keys():
elif 'n_jobs' in local_param.keys() and \
local_param['n_jobs'] is not None and \
local_param['n_jobs'] != worker.nthreads:
msg = '`n_jobs` is specified. ' + msg
LOGGER.warning(msg)
else: