[SKL] Propagate parameters to booster during set_param. (#6416)
This commit is contained in:
parent
cc581b3b6b
commit
2ce2a1a4d8
@ -334,6 +334,10 @@ class XGBModel(XGBModelBase):
|
|||||||
else:
|
else:
|
||||||
self.kwargs[key] = value
|
self.kwargs[key] = value
|
||||||
|
|
||||||
|
if hasattr(self, '_Booster'):
|
||||||
|
parameters = self.get_xgb_params()
|
||||||
|
self.get_booster().set_param(parameters)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def get_params(self, deep=True):
|
def get_params(self, deep=True):
|
||||||
|
|||||||
@ -542,13 +542,29 @@ def test_sklearn_n_jobs():
|
|||||||
assert clf.get_xgb_params()['n_jobs'] == 2
|
assert clf.get_xgb_params()['n_jobs'] == 2
|
||||||
|
|
||||||
|
|
||||||
def test_kwargs():
|
def test_parameters_access():
|
||||||
|
from sklearn import datasets
|
||||||
params = {'updater': 'grow_gpu_hist', 'subsample': .5, 'n_jobs': -1}
|
params = {'updater': 'grow_gpu_hist', 'subsample': .5, 'n_jobs': -1}
|
||||||
clf = xgb.XGBClassifier(n_estimators=1000, **params)
|
clf = xgb.XGBClassifier(n_estimators=1000, **params)
|
||||||
assert clf.get_params()['updater'] == 'grow_gpu_hist'
|
assert clf.get_params()['updater'] == 'grow_gpu_hist'
|
||||||
assert clf.get_params()['subsample'] == .5
|
assert clf.get_params()['subsample'] == .5
|
||||||
assert clf.get_params()['n_estimators'] == 1000
|
assert clf.get_params()['n_estimators'] == 1000
|
||||||
|
|
||||||
|
clf = xgb.XGBClassifier(n_estimators=1, nthread=4)
|
||||||
|
X, y = datasets.load_iris(return_X_y=True)
|
||||||
|
clf.fit(X, y)
|
||||||
|
|
||||||
|
config = json.loads(clf.get_booster().save_config())
|
||||||
|
assert int(config['learner']['generic_param']['nthread']) == 4
|
||||||
|
|
||||||
|
clf.set_params(nthread=16)
|
||||||
|
config = json.loads(clf.get_booster().save_config())
|
||||||
|
assert int(config['learner']['generic_param']['nthread']) == 16
|
||||||
|
|
||||||
|
clf.predict(X)
|
||||||
|
config = json.loads(clf.get_booster().save_config())
|
||||||
|
assert int(config['learner']['generic_param']['nthread']) == 16
|
||||||
|
|
||||||
|
|
||||||
def test_kwargs_error():
|
def test_kwargs_error():
|
||||||
params = {'updater': 'grow_gpu_hist', 'subsample': .5, 'n_jobs': -1}
|
params = {'updater': 'grow_gpu_hist', 'subsample': .5, 'n_jobs': -1}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user