[python-package] fix sklearn n_jobs/nthreads and seed/random_state bug (#2378)
* add a testcase causing RuntimeError * move seed/random_state/nthread/n_jobs check to get_xgb_params() * fix failed test
This commit is contained in:
parent
41efe32aa5
commit
65d2513714
@ -155,25 +155,10 @@ class XGBModel(XGBModelBase):
|
|||||||
self.missing = missing if missing is not None else np.nan
|
self.missing = missing if missing is not None else np.nan
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
self._Booster = None
|
self._Booster = None
|
||||||
if seed:
|
self.seed = seed
|
||||||
warnings.warn('The seed parameter is deprecated as of version .6.'
|
self.random_state = random_state
|
||||||
'Please use random_state instead.'
|
self.nthread = nthread
|
||||||
'seed is deprecated.', DeprecationWarning)
|
self.n_jobs = n_jobs
|
||||||
self.seed = seed
|
|
||||||
self.random_state = seed
|
|
||||||
else:
|
|
||||||
self.seed = random_state
|
|
||||||
self.random_state = random_state
|
|
||||||
|
|
||||||
if nthread:
|
|
||||||
warnings.warn('The nthread parameter is deprecated as of version .6.'
|
|
||||||
'Please use n_jobs instead.'
|
|
||||||
'nthread is deprecated.', DeprecationWarning)
|
|
||||||
self.nthread = nthread
|
|
||||||
self.n_jobs = nthread
|
|
||||||
else:
|
|
||||||
self.nthread = n_jobs
|
|
||||||
self.n_jobs = n_jobs
|
|
||||||
|
|
||||||
def __setstate__(self, state):
|
def __setstate__(self, state):
|
||||||
# backward compatibility code
|
# backward compatibility code
|
||||||
@ -211,12 +196,24 @@ class XGBModel(XGBModelBase):
|
|||||||
def get_xgb_params(self):
|
def get_xgb_params(self):
|
||||||
"""Get xgboost type parameters."""
|
"""Get xgboost type parameters."""
|
||||||
xgb_params = self.get_params()
|
xgb_params = self.get_params()
|
||||||
xgb_params.pop('random_state')
|
random_state = xgb_params.pop('random_state')
|
||||||
xgb_params.pop('n_jobs')
|
if xgb_params['seed'] is not None:
|
||||||
|
warnings.warn('The seed parameter is deprecated as of version .6.'
|
||||||
|
'Please use random_state instead.'
|
||||||
|
'seed is deprecated.', DeprecationWarning)
|
||||||
|
else:
|
||||||
|
xgb_params['seed'] = random_state
|
||||||
|
n_jobs = xgb_params.pop('n_jobs')
|
||||||
|
if xgb_params['nthread'] is not None:
|
||||||
|
warnings.warn('The nthread parameter is deprecated as of version .6.'
|
||||||
|
'Please use n_jobs instead.'
|
||||||
|
'nthread is deprecated.', DeprecationWarning)
|
||||||
|
else:
|
||||||
|
xgb_params['nthread'] = n_jobs
|
||||||
|
|
||||||
xgb_params['silent'] = 1 if self.silent else 0
|
xgb_params['silent'] = 1 if self.silent else 0
|
||||||
|
|
||||||
if self.nthread <= 0:
|
if xgb_params['nthread'] <= 0:
|
||||||
xgb_params.pop('nthread', None)
|
xgb_params.pop('nthread', None)
|
||||||
return xgb_params
|
return xgb_params
|
||||||
|
|
||||||
|
|||||||
@ -334,17 +334,17 @@ def test_sklearn_random_state():
|
|||||||
tm._skip_if_no_sklearn()
|
tm._skip_if_no_sklearn()
|
||||||
|
|
||||||
clf = xgb.XGBClassifier(random_state=402)
|
clf = xgb.XGBClassifier(random_state=402)
|
||||||
assert clf.get_params()['seed'] == 402
|
assert clf.get_xgb_params()['seed'] == 402
|
||||||
|
|
||||||
clf = xgb.XGBClassifier(seed=401)
|
clf = xgb.XGBClassifier(seed=401)
|
||||||
assert clf.get_params()['seed'] == 401
|
assert clf.get_xgb_params()['seed'] == 401
|
||||||
|
|
||||||
|
|
||||||
def test_seed_deprecation():
|
def test_seed_deprecation():
|
||||||
tm._skip_if_no_sklearn()
|
tm._skip_if_no_sklearn()
|
||||||
warnings.simplefilter("always")
|
warnings.simplefilter("always")
|
||||||
with warnings.catch_warnings(record=True) as w:
|
with warnings.catch_warnings(record=True) as w:
|
||||||
xgb.XGBClassifier(seed=1)
|
xgb.XGBClassifier(seed=1).get_xgb_params()
|
||||||
assert w[0].category == DeprecationWarning
|
assert w[0].category == DeprecationWarning
|
||||||
|
|
||||||
|
|
||||||
@ -352,17 +352,17 @@ def test_sklearn_n_jobs():
|
|||||||
tm._skip_if_no_sklearn()
|
tm._skip_if_no_sklearn()
|
||||||
|
|
||||||
clf = xgb.XGBClassifier(n_jobs=1)
|
clf = xgb.XGBClassifier(n_jobs=1)
|
||||||
assert clf.get_params()['nthread'] == 1
|
assert clf.get_xgb_params()['nthread'] == 1
|
||||||
|
|
||||||
clf = xgb.XGBClassifier(nthread=2)
|
clf = xgb.XGBClassifier(nthread=2)
|
||||||
assert clf.get_params()['nthread'] == 2
|
assert clf.get_xgb_params()['nthread'] == 2
|
||||||
|
|
||||||
|
|
||||||
def test_nthread_deprecation():
|
def test_nthread_deprecation():
|
||||||
tm._skip_if_no_sklearn()
|
tm._skip_if_no_sklearn()
|
||||||
warnings.simplefilter("always")
|
warnings.simplefilter("always")
|
||||||
with warnings.catch_warnings(record=True) as w:
|
with warnings.catch_warnings(record=True) as w:
|
||||||
xgb.XGBClassifier(nthread=1)
|
xgb.XGBClassifier(nthread=1).get_xgb_params()
|
||||||
assert w[0].category == DeprecationWarning
|
assert w[0].category == DeprecationWarning
|
||||||
|
|
||||||
|
|
||||||
@ -383,3 +383,12 @@ def test_kwargs_error():
|
|||||||
params = {'updater': 'grow_gpu', 'subsample': .5, 'n_jobs': -1}
|
params = {'updater': 'grow_gpu', 'subsample': .5, 'n_jobs': -1}
|
||||||
clf = xgb.XGBClassifier(n_jobs=1000, **params)
|
clf = xgb.XGBClassifier(n_jobs=1000, **params)
|
||||||
assert isinstance(clf, xgb.XGBClassifier)
|
assert isinstance(clf, xgb.XGBClassifier)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sklearn_clone():
|
||||||
|
tm._skip_if_no_sklearn()
|
||||||
|
from sklearn.base import clone
|
||||||
|
|
||||||
|
clf = xgb.XGBClassifier(n_jobs=2, nthread=3)
|
||||||
|
clf.n_jobs = -1
|
||||||
|
clone(clf)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user