[python-package] fix sklearn n_jobs/nthreads and seed/random_state bug (#2378)
* add a testcase causing RuntimeError * move seed/random_state/nthread/n_jobs check to get_xgb_params() * fix failed test
This commit is contained in:
committed by
Yuan (Terry) Tang
parent
41efe32aa5
commit
65d2513714
@@ -334,17 +334,17 @@ def test_sklearn_random_state():
|
||||
tm._skip_if_no_sklearn()
|
||||
|
||||
clf = xgb.XGBClassifier(random_state=402)
|
||||
assert clf.get_params()['seed'] == 402
|
||||
assert clf.get_xgb_params()['seed'] == 402
|
||||
|
||||
clf = xgb.XGBClassifier(seed=401)
|
||||
assert clf.get_params()['seed'] == 401
|
||||
assert clf.get_xgb_params()['seed'] == 401
|
||||
|
||||
|
||||
def test_seed_deprecation():
|
||||
tm._skip_if_no_sklearn()
|
||||
warnings.simplefilter("always")
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
xgb.XGBClassifier(seed=1)
|
||||
xgb.XGBClassifier(seed=1).get_xgb_params()
|
||||
assert w[0].category == DeprecationWarning
|
||||
|
||||
|
||||
@@ -352,17 +352,17 @@ def test_sklearn_n_jobs():
|
||||
tm._skip_if_no_sklearn()
|
||||
|
||||
clf = xgb.XGBClassifier(n_jobs=1)
|
||||
assert clf.get_params()['nthread'] == 1
|
||||
assert clf.get_xgb_params()['nthread'] == 1
|
||||
|
||||
clf = xgb.XGBClassifier(nthread=2)
|
||||
assert clf.get_params()['nthread'] == 2
|
||||
assert clf.get_xgb_params()['nthread'] == 2
|
||||
|
||||
|
||||
def test_nthread_deprecation():
|
||||
tm._skip_if_no_sklearn()
|
||||
warnings.simplefilter("always")
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
xgb.XGBClassifier(nthread=1)
|
||||
xgb.XGBClassifier(nthread=1).get_xgb_params()
|
||||
assert w[0].category == DeprecationWarning
|
||||
|
||||
|
||||
@@ -383,3 +383,12 @@ def test_kwargs_error():
|
||||
params = {'updater': 'grow_gpu', 'subsample': .5, 'n_jobs': -1}
|
||||
clf = xgb.XGBClassifier(n_jobs=1000, **params)
|
||||
assert isinstance(clf, xgb.XGBClassifier)
|
||||
|
||||
|
||||
def test_sklearn_clone():
|
||||
tm._skip_if_no_sklearn()
|
||||
from sklearn.base import clone
|
||||
|
||||
clf = xgb.XGBClassifier(n_jobs=2, nthread=3)
|
||||
clf.n_jobs = -1
|
||||
clone(clf)
|
||||
|
||||
Reference in New Issue
Block a user