Sklearn kwargs (#2338)
* Added kwargs support for Sklearn API * Updated NEWS and CONTRIBUTORS * Fixed CONTRIBUTORS.md * Added clarification of **kwargs and test for proper usage * Fixed lint error * Fixed more lint errors and clf assigned but never used * Fixed more lint errors * Fixed more lint errors * Fixed issue with changes from different branch bleeding over * Fixed issue with changes from other branch bleeding over * Added note that kwargs may not be compatible with Sklearn * Fixed linting on kwargs note
This commit is contained in:
@@ -3,6 +3,7 @@ import random
|
||||
import xgboost as xgb
|
||||
import testing as tm
|
||||
import warnings
|
||||
from nose.tools import raises
|
||||
|
||||
rng = np.random.RandomState(1994)
|
||||
|
||||
@@ -363,3 +364,22 @@ def test_nthread_deprecation():
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
xgb.XGBClassifier(nthread=1)
|
||||
assert w[0].category == DeprecationWarning
|
||||
|
||||
|
||||
def test_kwargs():
|
||||
tm._skip_if_no_sklearn()
|
||||
|
||||
params = {'updater': 'grow_gpu', 'subsample': .5, 'n_jobs': -1}
|
||||
clf = xgb.XGBClassifier(n_estimators=1000, **params)
|
||||
assert clf.get_params()['updater'] == 'grow_gpu'
|
||||
assert clf.get_params()['subsample'] == .5
|
||||
assert clf.get_params()['n_estimators'] == 1000
|
||||
|
||||
|
||||
@raises(TypeError)
|
||||
def test_kwargs_error():
|
||||
tm._skip_if_no_sklearn()
|
||||
|
||||
params = {'updater': 'grow_gpu', 'subsample': .5, 'n_jobs': -1}
|
||||
clf = xgb.XGBClassifier(n_jobs=1000, **params)
|
||||
assert isinstance(clf, xgb.XGBClassifier)
|
||||
|
||||
Reference in New Issue
Block a user