Sklearn convention update (#2323)

* Added n_jobs and random_state to keep up to date with sklearn API.
Deprecated nthread and seed.  Added tests for new params and
deprecations.

* Fixed docstring to reflect updates to n_jobs and random_state.

* Fixed whitespace issues and removed nose import.

* Added deprecation note for nthread and seed in docstring.

* Attempted fix of deprecation tests.

* Second attempted fix to tests.

* Set n_jobs to 1.
This commit is contained in:
gaw89
2017-05-22 09:22:05 -04:00
committed by Yuan (Terry) Tang
parent da1629e848
commit 6cea1e3fb7
2 changed files with 74 additions and 11 deletions

View File

@@ -2,6 +2,7 @@ import numpy as np
import random
import xgboost as xgb
import testing as tm
import warnings
rng = np.random.RandomState(1994)
@@ -326,3 +327,39 @@ def test_split_value_histograms():
assert gbdt.get_split_value_histogram("f28", bins=2).shape[0] == 2
assert gbdt.get_split_value_histogram("f28", bins=5).shape[0] == 2
assert gbdt.get_split_value_histogram("f28", bins=None).shape[0] == 2
def test_sklearn_random_state():
tm._skip_if_no_sklearn()
clf = xgb.XGBClassifier(random_state=402)
assert clf.get_params()['seed'] == 402
clf = xgb.XGBClassifier(seed=401)
assert clf.get_params()['seed'] == 401
def test_seed_deprecation():
tm._skip_if_no_sklearn()
warnings.simplefilter("always")
with warnings.catch_warnings(record=True) as w:
xgb.XGBClassifier(seed=1)
assert w[0].category == DeprecationWarning
def test_sklearn_n_jobs():
tm._skip_if_no_sklearn()
clf = xgb.XGBClassifier(n_jobs=1)
assert clf.get_params()['nthread'] == 1
clf = xgb.XGBClassifier(nthread=2)
assert clf.get_params()['nthread'] == 2
def test_nthread_deprecation():
tm._skip_if_no_sklearn()
warnings.simplefilter("always")
with warnings.catch_warnings(record=True) as w:
xgb.XGBClassifier(nthread=1)
assert w[0].category == DeprecationWarning