[python-package] fix sklearn n_jobs/nthreads and seed/random_state bug (#2378)

* add a testcase causing RuntimeError

* move seed/random_state/nthread/n_jobs check to get_xgb_params()

* fix failed test
This commit is contained in:
wxchan
2017-06-12 21:33:42 +08:00
committed by Yuan (Terry) Tang
parent 41efe32aa5
commit 65d2513714
2 changed files with 34 additions and 28 deletions

View File

@@ -155,25 +155,10 @@ class XGBModel(XGBModelBase):
self.missing = missing if missing is not None else np.nan
self.kwargs = kwargs
self._Booster = None
if seed:
warnings.warn('The seed parameter is deprecated as of version .6.'
'Please use random_state instead.'
'seed is deprecated.', DeprecationWarning)
self.seed = seed
self.random_state = seed
else:
self.seed = random_state
self.random_state = random_state
if nthread:
warnings.warn('The nthread parameter is deprecated as of version .6.'
'Please use n_jobs instead.'
'nthread is deprecated.', DeprecationWarning)
self.nthread = nthread
self.n_jobs = nthread
else:
self.nthread = n_jobs
self.n_jobs = n_jobs
self.seed = seed
self.random_state = random_state
self.nthread = nthread
self.n_jobs = n_jobs
def __setstate__(self, state):
# backward compatibility code
@@ -211,12 +196,24 @@ class XGBModel(XGBModelBase):
def get_xgb_params(self):
"""Get xgboost type parameters."""
xgb_params = self.get_params()
xgb_params.pop('random_state')
xgb_params.pop('n_jobs')
random_state = xgb_params.pop('random_state')
if xgb_params['seed'] is not None:
warnings.warn('The seed parameter is deprecated as of version .6.'
'Please use random_state instead.'
'seed is deprecated.', DeprecationWarning)
else:
xgb_params['seed'] = random_state
n_jobs = xgb_params.pop('n_jobs')
if xgb_params['nthread'] is not None:
warnings.warn('The nthread parameter is deprecated as of version .6.'
'Please use n_jobs instead.'
'nthread is deprecated.', DeprecationWarning)
else:
xgb_params['nthread'] = n_jobs
xgb_params['silent'] = 1 if self.silent else 0
if self.nthread <= 0:
if xgb_params['nthread'] <= 0:
xgb_params.pop('nthread', None)
return xgb_params