Sklearn convention update (#2323)
* Added n_jobs and random_state to keep up to date with sklearn API. Deprecated nthread and seed. Added tests for new params and deprecations. * Fixed docstring to reflect updates to n_jobs and random_state. * Fixed whitespace issues and removed nose import. * Added deprecation note for nthread and seed in docstring. * Attempted fix of deprecation tests. * Second attempted fix to tests. * Set n_jobs to 1.
This commit is contained in:
@@ -2,6 +2,7 @@ import numpy as np
|
||||
import random
|
||||
import xgboost as xgb
|
||||
import testing as tm
|
||||
import warnings
|
||||
|
||||
rng = np.random.RandomState(1994)
|
||||
|
||||
@@ -326,3 +327,39 @@ def test_split_value_histograms():
|
||||
assert gbdt.get_split_value_histogram("f28", bins=2).shape[0] == 2
|
||||
assert gbdt.get_split_value_histogram("f28", bins=5).shape[0] == 2
|
||||
assert gbdt.get_split_value_histogram("f28", bins=None).shape[0] == 2
|
||||
|
||||
|
||||
def test_sklearn_random_state():
|
||||
tm._skip_if_no_sklearn()
|
||||
|
||||
clf = xgb.XGBClassifier(random_state=402)
|
||||
assert clf.get_params()['seed'] == 402
|
||||
|
||||
clf = xgb.XGBClassifier(seed=401)
|
||||
assert clf.get_params()['seed'] == 401
|
||||
|
||||
|
||||
def test_seed_deprecation():
|
||||
tm._skip_if_no_sklearn()
|
||||
warnings.simplefilter("always")
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
xgb.XGBClassifier(seed=1)
|
||||
assert w[0].category == DeprecationWarning
|
||||
|
||||
|
||||
def test_sklearn_n_jobs():
|
||||
tm._skip_if_no_sklearn()
|
||||
|
||||
clf = xgb.XGBClassifier(n_jobs=1)
|
||||
assert clf.get_params()['nthread'] == 1
|
||||
|
||||
clf = xgb.XGBClassifier(nthread=2)
|
||||
assert clf.get_params()['nthread'] == 2
|
||||
|
||||
|
||||
def test_nthread_deprecation():
|
||||
tm._skip_if_no_sklearn()
|
||||
warnings.simplefilter("always")
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
xgb.XGBClassifier(nthread=1)
|
||||
assert w[0].category == DeprecationWarning
|
||||
|
||||
Reference in New Issue
Block a user