Sklearn kwargs (#2338)

* Added kwargs support for Sklearn API

* Updated NEWS and CONTRIBUTORS

* Fixed CONTRIBUTORS.md

* Added clarification of **kwargs and test for proper usage

* Fixed lint error

* Fixed more lint errors and clf assigned but never used

* Fixed more lint errors

* Fixed more lint errors

* Fixed issue with changes from different branch bleeding over

* Fixed issue with changes from other branch bleeding over

* Added note that kwargs may not be compatible with Sklearn

* Fixed linting on kwargs note
This commit is contained in:
gaw89
2017-05-23 22:47:53 -04:00
committed by Yuan (Terry) Tang
parent 6cea1e3fb7
commit 0f3a404d91
4 changed files with 38 additions and 4 deletions

View File

@@ -3,6 +3,7 @@ import random
import xgboost as xgb
import testing as tm
import warnings
from nose.tools import raises
rng = np.random.RandomState(1994)
@@ -363,3 +364,22 @@ def test_nthread_deprecation():
with warnings.catch_warnings(record=True) as w:
xgb.XGBClassifier(nthread=1)
assert w[0].category == DeprecationWarning
def test_kwargs():
tm._skip_if_no_sklearn()
params = {'updater': 'grow_gpu', 'subsample': .5, 'n_jobs': -1}
clf = xgb.XGBClassifier(n_estimators=1000, **params)
assert clf.get_params()['updater'] == 'grow_gpu'
assert clf.get_params()['subsample'] == .5
assert clf.get_params()['n_estimators'] == 1000
@raises(TypeError)
def test_kwargs_error():
tm._skip_if_no_sklearn()
params = {'updater': 'grow_gpu', 'subsample': .5, 'n_jobs': -1}
clf = xgb.XGBClassifier(n_jobs=1000, **params)
assert isinstance(clf, xgb.XGBClassifier)