[python-package] fix sklearn n_jobs/nthreads and seed/random_state bug (#2378)
* add a testcase causing RuntimeError * move seed/random_state/nthread/n_jobs check to get_xgb_params() * fix failed test
This commit is contained in:
parent
41efe32aa5
commit
65d2513714
@ -155,25 +155,10 @@ class XGBModel(XGBModelBase):
|
||||
self.missing = missing if missing is not None else np.nan
|
||||
self.kwargs = kwargs
|
||||
self._Booster = None
|
||||
if seed:
|
||||
warnings.warn('The seed parameter is deprecated as of version .6.'
|
||||
'Please use random_state instead.'
|
||||
'seed is deprecated.', DeprecationWarning)
|
||||
self.seed = seed
|
||||
self.random_state = seed
|
||||
else:
|
||||
self.seed = random_state
|
||||
self.random_state = random_state
|
||||
|
||||
if nthread:
|
||||
warnings.warn('The nthread parameter is deprecated as of version .6.'
|
||||
'Please use n_jobs instead.'
|
||||
'nthread is deprecated.', DeprecationWarning)
|
||||
self.nthread = nthread
|
||||
self.n_jobs = nthread
|
||||
else:
|
||||
self.nthread = n_jobs
|
||||
self.n_jobs = n_jobs
|
||||
self.seed = seed
|
||||
self.random_state = random_state
|
||||
self.nthread = nthread
|
||||
self.n_jobs = n_jobs
|
||||
|
||||
def __setstate__(self, state):
|
||||
# backward compatibility code
|
||||
@ -211,12 +196,24 @@ class XGBModel(XGBModelBase):
|
||||
def get_xgb_params(self):
|
||||
"""Get xgboost type parameters."""
|
||||
xgb_params = self.get_params()
|
||||
xgb_params.pop('random_state')
|
||||
xgb_params.pop('n_jobs')
|
||||
random_state = xgb_params.pop('random_state')
|
||||
if xgb_params['seed'] is not None:
|
||||
warnings.warn('The seed parameter is deprecated as of version .6.'
|
||||
'Please use random_state instead.'
|
||||
'seed is deprecated.', DeprecationWarning)
|
||||
else:
|
||||
xgb_params['seed'] = random_state
|
||||
n_jobs = xgb_params.pop('n_jobs')
|
||||
if xgb_params['nthread'] is not None:
|
||||
warnings.warn('The nthread parameter is deprecated as of version .6.'
|
||||
'Please use n_jobs instead.'
|
||||
'nthread is deprecated.', DeprecationWarning)
|
||||
else:
|
||||
xgb_params['nthread'] = n_jobs
|
||||
|
||||
xgb_params['silent'] = 1 if self.silent else 0
|
||||
|
||||
if self.nthread <= 0:
|
||||
if xgb_params['nthread'] <= 0:
|
||||
xgb_params.pop('nthread', None)
|
||||
return xgb_params
|
||||
|
||||
|
||||
@ -334,17 +334,17 @@ def test_sklearn_random_state():
|
||||
tm._skip_if_no_sklearn()
|
||||
|
||||
clf = xgb.XGBClassifier(random_state=402)
|
||||
assert clf.get_params()['seed'] == 402
|
||||
assert clf.get_xgb_params()['seed'] == 402
|
||||
|
||||
clf = xgb.XGBClassifier(seed=401)
|
||||
assert clf.get_params()['seed'] == 401
|
||||
assert clf.get_xgb_params()['seed'] == 401
|
||||
|
||||
|
||||
def test_seed_deprecation():
|
||||
tm._skip_if_no_sklearn()
|
||||
warnings.simplefilter("always")
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
xgb.XGBClassifier(seed=1)
|
||||
xgb.XGBClassifier(seed=1).get_xgb_params()
|
||||
assert w[0].category == DeprecationWarning
|
||||
|
||||
|
||||
@ -352,17 +352,17 @@ def test_sklearn_n_jobs():
|
||||
tm._skip_if_no_sklearn()
|
||||
|
||||
clf = xgb.XGBClassifier(n_jobs=1)
|
||||
assert clf.get_params()['nthread'] == 1
|
||||
assert clf.get_xgb_params()['nthread'] == 1
|
||||
|
||||
clf = xgb.XGBClassifier(nthread=2)
|
||||
assert clf.get_params()['nthread'] == 2
|
||||
assert clf.get_xgb_params()['nthread'] == 2
|
||||
|
||||
|
||||
def test_nthread_deprecation():
|
||||
tm._skip_if_no_sklearn()
|
||||
warnings.simplefilter("always")
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
xgb.XGBClassifier(nthread=1)
|
||||
xgb.XGBClassifier(nthread=1).get_xgb_params()
|
||||
assert w[0].category == DeprecationWarning
|
||||
|
||||
|
||||
@ -383,3 +383,12 @@ def test_kwargs_error():
|
||||
params = {'updater': 'grow_gpu', 'subsample': .5, 'n_jobs': -1}
|
||||
clf = xgb.XGBClassifier(n_jobs=1000, **params)
|
||||
assert isinstance(clf, xgb.XGBClassifier)
|
||||
|
||||
|
||||
def test_sklearn_clone():
|
||||
tm._skip_if_no_sklearn()
|
||||
from sklearn.base import clone
|
||||
|
||||
clf = xgb.XGBClassifier(n_jobs=2, nthread=3)
|
||||
clf.n_jobs = -1
|
||||
clone(clf)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user