[python-package] fix sklearn n_jobs/nthreads and seed/random_state bug (#2378)

* add a testcase causing RuntimeError

* move seed/random_state/nthread/n_jobs check to get_xgb_params()

* fix failed test
This commit is contained in:
wxchan 2017-06-12 21:33:42 +08:00 committed by Yuan (Terry) Tang
parent 41efe32aa5
commit 65d2513714
2 changed files with 34 additions and 28 deletions

View File

@ -155,25 +155,10 @@ class XGBModel(XGBModelBase):
self.missing = missing if missing is not None else np.nan
self.kwargs = kwargs
self._Booster = None
if seed:
warnings.warn('The seed parameter is deprecated as of version .6.'
'Please use random_state instead.'
'seed is deprecated.', DeprecationWarning)
self.seed = seed
self.random_state = seed
else:
self.seed = random_state
self.random_state = random_state
if nthread:
warnings.warn('The nthread parameter is deprecated as of version .6.'
'Please use n_jobs instead.'
'nthread is deprecated.', DeprecationWarning)
self.nthread = nthread
self.n_jobs = nthread
else:
self.nthread = n_jobs
self.n_jobs = n_jobs
self.seed = seed
self.random_state = random_state
self.nthread = nthread
self.n_jobs = n_jobs
def __setstate__(self, state):
# backward compatibility code
@ -211,12 +196,24 @@ class XGBModel(XGBModelBase):
def get_xgb_params(self):
"""Get xgboost type parameters."""
xgb_params = self.get_params()
xgb_params.pop('random_state')
xgb_params.pop('n_jobs')
random_state = xgb_params.pop('random_state')
if xgb_params['seed'] is not None:
warnings.warn('The seed parameter is deprecated as of version .6.'
'Please use random_state instead.'
'seed is deprecated.', DeprecationWarning)
else:
xgb_params['seed'] = random_state
n_jobs = xgb_params.pop('n_jobs')
if xgb_params['nthread'] is not None:
warnings.warn('The nthread parameter is deprecated as of version .6.'
'Please use n_jobs instead.'
'nthread is deprecated.', DeprecationWarning)
else:
xgb_params['nthread'] = n_jobs
xgb_params['silent'] = 1 if self.silent else 0
if self.nthread <= 0:
if xgb_params['nthread'] <= 0:
xgb_params.pop('nthread', None)
return xgb_params

View File

@ -334,17 +334,17 @@ def test_sklearn_random_state():
tm._skip_if_no_sklearn()
clf = xgb.XGBClassifier(random_state=402)
assert clf.get_params()['seed'] == 402
assert clf.get_xgb_params()['seed'] == 402
clf = xgb.XGBClassifier(seed=401)
assert clf.get_params()['seed'] == 401
assert clf.get_xgb_params()['seed'] == 401
def test_seed_deprecation():
tm._skip_if_no_sklearn()
warnings.simplefilter("always")
with warnings.catch_warnings(record=True) as w:
xgb.XGBClassifier(seed=1)
xgb.XGBClassifier(seed=1).get_xgb_params()
assert w[0].category == DeprecationWarning
@ -352,17 +352,17 @@ def test_sklearn_n_jobs():
tm._skip_if_no_sklearn()
clf = xgb.XGBClassifier(n_jobs=1)
assert clf.get_params()['nthread'] == 1
assert clf.get_xgb_params()['nthread'] == 1
clf = xgb.XGBClassifier(nthread=2)
assert clf.get_params()['nthread'] == 2
assert clf.get_xgb_params()['nthread'] == 2
def test_nthread_deprecation():
tm._skip_if_no_sklearn()
warnings.simplefilter("always")
with warnings.catch_warnings(record=True) as w:
xgb.XGBClassifier(nthread=1)
xgb.XGBClassifier(nthread=1).get_xgb_params()
assert w[0].category == DeprecationWarning
@ -383,3 +383,12 @@ def test_kwargs_error():
params = {'updater': 'grow_gpu', 'subsample': .5, 'n_jobs': -1}
clf = xgb.XGBClassifier(n_jobs=1000, **params)
assert isinstance(clf, xgb.XGBClassifier)
def test_sklearn_clone():
tm._skip_if_no_sklearn()
from sklearn.base import clone
clf = xgb.XGBClassifier(n_jobs=2, nthread=3)
clf.n_jobs = -1
clone(clf)