[python-package] fix sklearn n_jobs/nthreads and seed/random_state bug (#2378)
* add a testcase causing RuntimeError * move seed/random_state/nthread/n_jobs check to get_xgb_params() * fix failed test
This commit is contained in:
committed by
Yuan (Terry) Tang
parent
41efe32aa5
commit
65d2513714
@@ -155,25 +155,10 @@ class XGBModel(XGBModelBase):
|
||||
self.missing = missing if missing is not None else np.nan
|
||||
self.kwargs = kwargs
|
||||
self._Booster = None
|
||||
if seed:
|
||||
warnings.warn('The seed parameter is deprecated as of version .6.'
|
||||
'Please use random_state instead.'
|
||||
'seed is deprecated.', DeprecationWarning)
|
||||
self.seed = seed
|
||||
self.random_state = seed
|
||||
else:
|
||||
self.seed = random_state
|
||||
self.random_state = random_state
|
||||
|
||||
if nthread:
|
||||
warnings.warn('The nthread parameter is deprecated as of version .6.'
|
||||
'Please use n_jobs instead.'
|
||||
'nthread is deprecated.', DeprecationWarning)
|
||||
self.nthread = nthread
|
||||
self.n_jobs = nthread
|
||||
else:
|
||||
self.nthread = n_jobs
|
||||
self.n_jobs = n_jobs
|
||||
self.seed = seed
|
||||
self.random_state = random_state
|
||||
self.nthread = nthread
|
||||
self.n_jobs = n_jobs
|
||||
|
||||
def __setstate__(self, state):
|
||||
# backward compatibility code
|
||||
@@ -211,12 +196,24 @@ class XGBModel(XGBModelBase):
|
||||
def get_xgb_params(self):
|
||||
"""Get xgboost type parameters."""
|
||||
xgb_params = self.get_params()
|
||||
xgb_params.pop('random_state')
|
||||
xgb_params.pop('n_jobs')
|
||||
random_state = xgb_params.pop('random_state')
|
||||
if xgb_params['seed'] is not None:
|
||||
warnings.warn('The seed parameter is deprecated as of version .6.'
|
||||
'Please use random_state instead.'
|
||||
'seed is deprecated.', DeprecationWarning)
|
||||
else:
|
||||
xgb_params['seed'] = random_state
|
||||
n_jobs = xgb_params.pop('n_jobs')
|
||||
if xgb_params['nthread'] is not None:
|
||||
warnings.warn('The nthread parameter is deprecated as of version .6.'
|
||||
'Please use n_jobs instead.'
|
||||
'nthread is deprecated.', DeprecationWarning)
|
||||
else:
|
||||
xgb_params['nthread'] = n_jobs
|
||||
|
||||
xgb_params['silent'] = 1 if self.silent else 0
|
||||
|
||||
if self.nthread <= 0:
|
||||
if xgb_params['nthread'] <= 0:
|
||||
xgb_params.pop('nthread', None)
|
||||
return xgb_params
|
||||
|
||||
|
||||
Reference in New Issue
Block a user