[backport] Do not return internal value for get_params. (#8634) (#8642)

Jiaming Yuan 2023-01-06 02:28:39 +08:00 committed by GitHub
parent a2085bf223
commit 899e4c8988
3 changed files with 73 additions and 45 deletions
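
For orientation, a minimal sketch (not part of the diff; it assumes xgboost's scikit-learn wrapper and scikit-learn's iris dataset) of the behaviour this backport enforces: get_params() reports only what the caller passed to the constructor, while the values the booster actually resolved remain available through save_config().

import json

import xgboost as xgb
from sklearn.datasets import load_iris

X, y = load_iris(return_X_y=True)

# tree_method is deliberately left unspecified.
clf = xgb.XGBClassifier(n_estimators=2)
clf.fit(X, y)

# With this change, get_params() no longer backfills the unset parameter from
# the booster's internal configuration; it stays None, mirroring the constructor call.
assert clf.get_params()["tree_method"] is None

# The resolved value is still recoverable from the native configuration.
config = json.loads(clf.get_booster().save_config())
print(config["learner"]["gradient_booster"]["gbtree_train_param"]["tree_method"])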

View File

@@ -674,7 +674,7 @@ class XGBModel(XGBModelBase):
                    self.kwargs = {}
                self.kwargs[key] = value

-        if hasattr(self, "_Booster"):
+        if self.__sklearn_is_fitted__():
            parameters = self.get_xgb_params()
            self.get_booster().set_param(parameters)
@@ -701,39 +701,12 @@ class XGBModel(XGBModelBase):
                np.iinfo(np.int32).max
            )

-        def parse_parameter(value: Any) -> Optional[Union[int, float, str]]:
-            for t in (int, float, str):
-                try:
-                    ret = t(value)
-                    return ret
-                except ValueError:
-                    continue
-            return None
-
-        # Get internal parameter values
-        try:
-            config = json.loads(self.get_booster().save_config())
-            stack = [config]
-            internal = {}
-            while stack:
-                obj = stack.pop()
-                for k, v in obj.items():
-                    if k.endswith("_param"):
-                        for p_k, p_v in v.items():
-                            internal[p_k] = p_v
-                    elif isinstance(v, dict):
-                        stack.append(v)
-            for k, v in internal.items():
-                if k in params and params[k] is None:
-                    params[k] = parse_parameter(v)
-        except ValueError:
-            pass
        return params

    def get_xgb_params(self) -> Dict[str, Any]:
        """Get xgboost specific parameters."""
-        params = self.get_params()
+        params: Dict[str, Any] = self.get_params()
        # Parameters that should not go into native learner.
        wrapper_specific = {
            "importance_type",
@@ -750,6 +723,7 @@ class XGBModel(XGBModelBase):
        for k, v in params.items():
            if k not in wrapper_specific and not callable(v):
                filtered[k] = v
+
        return filtered

    def get_num_boosting_rounds(self) -> int:
@@ -1070,7 +1044,7 @@ class XGBModel(XGBModelBase):
        # error with incompatible data type.
        # Inplace predict doesn't handle as many data types as DMatrix, but it's
        # sufficient for dask interface where input is simpler.
-        predictor = self.get_params().get("predictor", None)
+        predictor = self.get_xgb_params().get("predictor", None)
        if predictor in ("auto", None) and self.booster != "gblinear":
            return True
        return False
@@ -1336,7 +1310,7 @@ class XGBModel(XGBModelBase):
        -------
        coef_ : array of shape ``[n_features]`` or ``[n_classes, n_features]``
        """
-        if self.get_params()["booster"] != "gblinear":
+        if self.get_xgb_params()["booster"] != "gblinear":
            raise AttributeError(
                f"Coefficients are not defined for Booster type {self.booster}"
            )
@@ -1366,7 +1340,7 @@ class XGBModel(XGBModelBase):
        -------
        intercept_ : array of shape ``(1,)`` or ``[n_classes]``
        """
-        if self.get_params()["booster"] != "gblinear":
+        if self.get_xgb_params()["booster"] != "gblinear":
            raise AttributeError(
                f"Intercept (bias) is not defined for Booster type {self.booster}"
            )

View File

@@ -112,7 +112,6 @@ class TestPandas:
        # test Index as columns
        df = pd.DataFrame([[1, 1.1], [2, 2.2]], columns=pd.Index([1, 2]))
-        print(df.columns, isinstance(df.columns, pd.Index))
        Xy = xgb.DMatrix(df)
        np.testing.assert_equal(np.array(Xy.feature_names), np.array(["1", "2"]))

View File

@@ -2,6 +2,7 @@ import collections
import importlib.util
import json
import os
+import pickle
import random
import tempfile
from typing import Callable, Optional
@@ -636,26 +637,74 @@ def test_sklearn_n_jobs():
def test_parameters_access():
    from sklearn import datasets

-    params = {'updater': 'grow_gpu_hist', 'subsample': .5, 'n_jobs': -1}
+    params = {"updater": "grow_gpu_hist", "subsample": 0.5, "n_jobs": -1}
    clf = xgb.XGBClassifier(n_estimators=1000, **params)
-    assert clf.get_params()['updater'] == 'grow_gpu_hist'
-    assert clf.get_params()['subsample'] == .5
-    assert clf.get_params()['n_estimators'] == 1000
+    assert clf.get_params()["updater"] == "grow_gpu_hist"
+    assert clf.get_params()["subsample"] == 0.5
+    assert clf.get_params()["n_estimators"] == 1000

    clf = xgb.XGBClassifier(n_estimators=1, nthread=4)
    X, y = datasets.load_iris(return_X_y=True)
    clf.fit(X, y)

    config = json.loads(clf.get_booster().save_config())
-    assert int(config['learner']['generic_param']['nthread']) == 4
+    assert int(config["learner"]["generic_param"]["nthread"]) == 4

    clf.set_params(nthread=16)
    config = json.loads(clf.get_booster().save_config())
-    assert int(config['learner']['generic_param']['nthread']) == 16
+    assert int(config["learner"]["generic_param"]["nthread"]) == 16

    clf.predict(X)
    config = json.loads(clf.get_booster().save_config())
-    assert int(config['learner']['generic_param']['nthread']) == 16
+    assert int(config["learner"]["generic_param"]["nthread"]) == 16
+
+    clf = xgb.XGBClassifier(n_estimators=2)
+    assert clf.tree_method is None
+    assert clf.get_params()["tree_method"] is None
+    clf.fit(X, y)
+    assert clf.get_params()["tree_method"] is None
+
+    def save_load(clf: xgb.XGBClassifier) -> xgb.XGBClassifier:
+        with tempfile.TemporaryDirectory() as tmpdir:
+            path = os.path.join(tmpdir, "model.json")
+            clf.save_model(path)
+            clf = xgb.XGBClassifier()
+            clf.load_model(path)
+        return clf
+
+    def get_tm(clf: xgb.XGBClassifier) -> str:
+        tm = json.loads(clf.get_booster().save_config())["learner"]["gradient_booster"][
+            "gbtree_train_param"
+        ]["tree_method"]
+        return tm
+
+    assert get_tm(clf) == "exact"
+
+    clf = pickle.loads(pickle.dumps(clf))
+    assert clf.tree_method is None
+    assert clf.n_estimators == 2
+    assert clf.get_params()["tree_method"] is None
+    assert clf.get_params()["n_estimators"] == 2
+    assert get_tm(clf) == "exact"  # preserved for pickle
+
+    clf = save_load(clf)
+    assert clf.tree_method is None
+    assert clf.n_estimators == 2
+    assert clf.get_params()["tree_method"] is None
+    assert clf.get_params()["n_estimators"] == 2
+    assert get_tm(clf) == "auto"  # discarded for save/load_model
+
+    clf.set_params(tree_method="hist")
+    assert clf.get_params()["tree_method"] == "hist"
+    clf = pickle.loads(pickle.dumps(clf))
+    assert clf.get_params()["tree_method"] == "hist"
+    clf = save_load(clf)
+    # FIXME(jiamingy): We should remove this behavior once we remove parameters
+    # serialization for skl save/load_model.
+    assert clf.get_params()["tree_method"] == "hist"


def test_kwargs_error():
@@ -695,13 +744,19 @@ def test_sklearn_clone():
def test_sklearn_get_default_params():
    from sklearn.datasets import load_digits

    digits_2class = load_digits(n_class=2)
-    X = digits_2class['data']
-    y = digits_2class['target']
+    X = digits_2class["data"]
+    y = digits_2class["target"]
    cls = xgb.XGBClassifier()
-    assert cls.get_params()['base_score'] is None
+    assert cls.get_params()["base_score"] is None
    cls.fit(X[:4, ...], y[:4, ...])
-    assert cls.get_params()['base_score'] is not None
+    base_score = float(
+        json.loads(cls.get_booster().save_config())["learner"]["learner_model_param"][
+            "base_score"
+        ]
+    )
+    np.testing.assert_equal(base_score, 0.5)


def run_validation_weights(model):
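
As a closing illustration, a minimal sketch (again not part of the diff; same assumptions as the sketch near the top) of the distinction the new test_parameters_access assertions pin down: pickling a fitted estimator round-trips the booster's resolved configuration, whereas save_model followed by load_model into a fresh estimator drops tree_method back to "auto".

import json
import os
import pickle
import tempfile

import xgboost as xgb
from sklearn.datasets import load_iris


def tree_method_in_config(clf: xgb.XGBClassifier) -> str:
    # Read the resolved tree_method from the native booster configuration.
    cfg = json.loads(clf.get_booster().save_config())
    return cfg["learner"]["gradient_booster"]["gbtree_train_param"]["tree_method"]


X, y = load_iris(return_X_y=True)
clf = xgb.XGBClassifier(n_estimators=2).fit(X, y)

# Pickle preserves the resolved value ("exact" on this small dataset).
restored = pickle.loads(pickle.dumps(clf))
print(tree_method_in_config(restored))

# save_model/load_model into a fresh estimator discards it ("auto").
with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "model.json")
    clf.save_model(path)
    fresh = xgb.XGBClassifier()
    fresh.load_model(path)
    print(tree_method_in_config(fresh))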