This commit is contained in:
parent
a2085bf223
commit
899e4c8988
@ -674,7 +674,7 @@ class XGBModel(XGBModelBase):
|
||||
self.kwargs = {}
|
||||
self.kwargs[key] = value
|
||||
|
||||
if hasattr(self, "_Booster"):
|
||||
if self.__sklearn_is_fitted__():
|
||||
parameters = self.get_xgb_params()
|
||||
self.get_booster().set_param(parameters)
|
||||
|
||||
@ -701,39 +701,12 @@ class XGBModel(XGBModelBase):
|
||||
np.iinfo(np.int32).max
|
||||
)
|
||||
|
||||
def parse_parameter(value: Any) -> Optional[Union[int, float, str]]:
|
||||
for t in (int, float, str):
|
||||
try:
|
||||
ret = t(value)
|
||||
return ret
|
||||
except ValueError:
|
||||
continue
|
||||
return None
|
||||
|
||||
# Get internal parameter values
|
||||
try:
|
||||
config = json.loads(self.get_booster().save_config())
|
||||
stack = [config]
|
||||
internal = {}
|
||||
while stack:
|
||||
obj = stack.pop()
|
||||
for k, v in obj.items():
|
||||
if k.endswith("_param"):
|
||||
for p_k, p_v in v.items():
|
||||
internal[p_k] = p_v
|
||||
elif isinstance(v, dict):
|
||||
stack.append(v)
|
||||
|
||||
for k, v in internal.items():
|
||||
if k in params and params[k] is None:
|
||||
params[k] = parse_parameter(v)
|
||||
except ValueError:
|
||||
pass
|
||||
return params
|
||||
|
||||
def get_xgb_params(self) -> Dict[str, Any]:
|
||||
"""Get xgboost specific parameters."""
|
||||
params = self.get_params()
|
||||
params: Dict[str, Any] = self.get_params()
|
||||
|
||||
# Parameters that should not go into native learner.
|
||||
wrapper_specific = {
|
||||
"importance_type",
|
||||
@ -750,6 +723,7 @@ class XGBModel(XGBModelBase):
|
||||
for k, v in params.items():
|
||||
if k not in wrapper_specific and not callable(v):
|
||||
filtered[k] = v
|
||||
|
||||
return filtered
|
||||
|
||||
def get_num_boosting_rounds(self) -> int:
|
||||
@ -1070,7 +1044,7 @@ class XGBModel(XGBModelBase):
|
||||
# error with incompatible data type.
|
||||
# Inplace predict doesn't handle as many data types as DMatrix, but it's
|
||||
# sufficient for dask interface where input is simpiler.
|
||||
predictor = self.get_params().get("predictor", None)
|
||||
predictor = self.get_xgb_params().get("predictor", None)
|
||||
if predictor in ("auto", None) and self.booster != "gblinear":
|
||||
return True
|
||||
return False
|
||||
@ -1336,7 +1310,7 @@ class XGBModel(XGBModelBase):
|
||||
-------
|
||||
coef_ : array of shape ``[n_features]`` or ``[n_classes, n_features]``
|
||||
"""
|
||||
if self.get_params()["booster"] != "gblinear":
|
||||
if self.get_xgb_params()["booster"] != "gblinear":
|
||||
raise AttributeError(
|
||||
f"Coefficients are not defined for Booster type {self.booster}"
|
||||
)
|
||||
@ -1366,7 +1340,7 @@ class XGBModel(XGBModelBase):
|
||||
-------
|
||||
intercept_ : array of shape ``(1,)`` or ``[n_classes]``
|
||||
"""
|
||||
if self.get_params()["booster"] != "gblinear":
|
||||
if self.get_xgb_params()["booster"] != "gblinear":
|
||||
raise AttributeError(
|
||||
f"Intercept (bias) is not defined for Booster type {self.booster}"
|
||||
)
|
||||
|
||||
@ -112,7 +112,6 @@ class TestPandas:
|
||||
|
||||
# test Index as columns
|
||||
df = pd.DataFrame([[1, 1.1], [2, 2.2]], columns=pd.Index([1, 2]))
|
||||
print(df.columns, isinstance(df.columns, pd.Index))
|
||||
Xy = xgb.DMatrix(df)
|
||||
np.testing.assert_equal(np.array(Xy.feature_names), np.array(["1", "2"]))
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@ import collections
|
||||
import importlib.util
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import random
|
||||
import tempfile
|
||||
from typing import Callable, Optional
|
||||
@ -636,26 +637,74 @@ def test_sklearn_n_jobs():
|
||||
|
||||
def test_parameters_access():
|
||||
from sklearn import datasets
|
||||
params = {'updater': 'grow_gpu_hist', 'subsample': .5, 'n_jobs': -1}
|
||||
|
||||
params = {"updater": "grow_gpu_hist", "subsample": 0.5, "n_jobs": -1}
|
||||
clf = xgb.XGBClassifier(n_estimators=1000, **params)
|
||||
assert clf.get_params()['updater'] == 'grow_gpu_hist'
|
||||
assert clf.get_params()['subsample'] == .5
|
||||
assert clf.get_params()['n_estimators'] == 1000
|
||||
assert clf.get_params()["updater"] == "grow_gpu_hist"
|
||||
assert clf.get_params()["subsample"] == 0.5
|
||||
assert clf.get_params()["n_estimators"] == 1000
|
||||
|
||||
clf = xgb.XGBClassifier(n_estimators=1, nthread=4)
|
||||
X, y = datasets.load_iris(return_X_y=True)
|
||||
clf.fit(X, y)
|
||||
|
||||
config = json.loads(clf.get_booster().save_config())
|
||||
assert int(config['learner']['generic_param']['nthread']) == 4
|
||||
assert int(config["learner"]["generic_param"]["nthread"]) == 4
|
||||
|
||||
clf.set_params(nthread=16)
|
||||
config = json.loads(clf.get_booster().save_config())
|
||||
assert int(config['learner']['generic_param']['nthread']) == 16
|
||||
assert int(config["learner"]["generic_param"]["nthread"]) == 16
|
||||
|
||||
clf.predict(X)
|
||||
config = json.loads(clf.get_booster().save_config())
|
||||
assert int(config['learner']['generic_param']['nthread']) == 16
|
||||
assert int(config["learner"]["generic_param"]["nthread"]) == 16
|
||||
|
||||
clf = xgb.XGBClassifier(n_estimators=2)
|
||||
assert clf.tree_method is None
|
||||
assert clf.get_params()["tree_method"] is None
|
||||
clf.fit(X, y)
|
||||
assert clf.get_params()["tree_method"] is None
|
||||
|
||||
def save_load(clf: xgb.XGBClassifier) -> xgb.XGBClassifier:
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, "model.json")
|
||||
clf.save_model(path)
|
||||
clf = xgb.XGBClassifier()
|
||||
clf.load_model(path)
|
||||
return clf
|
||||
|
||||
def get_tm(clf: xgb.XGBClassifier) -> str:
|
||||
tm = json.loads(clf.get_booster().save_config())["learner"]["gradient_booster"][
|
||||
"gbtree_train_param"
|
||||
]["tree_method"]
|
||||
return tm
|
||||
|
||||
assert get_tm(clf) == "exact"
|
||||
|
||||
clf = pickle.loads(pickle.dumps(clf))
|
||||
|
||||
assert clf.tree_method is None
|
||||
assert clf.n_estimators == 2
|
||||
assert clf.get_params()["tree_method"] is None
|
||||
assert clf.get_params()["n_estimators"] == 2
|
||||
assert get_tm(clf) == "exact" # preserved for pickle
|
||||
|
||||
clf = save_load(clf)
|
||||
|
||||
assert clf.tree_method is None
|
||||
assert clf.n_estimators == 2
|
||||
assert clf.get_params()["tree_method"] is None
|
||||
assert clf.get_params()["n_estimators"] == 2
|
||||
assert get_tm(clf) == "auto" # discarded for save/load_model
|
||||
|
||||
clf.set_params(tree_method="hist")
|
||||
assert clf.get_params()["tree_method"] == "hist"
|
||||
clf = pickle.loads(pickle.dumps(clf))
|
||||
assert clf.get_params()["tree_method"] == "hist"
|
||||
clf = save_load(clf)
|
||||
# FIXME(jiamingy): We should remove this behavior once we remove parameters
|
||||
# serialization for skl save/load_model.
|
||||
assert clf.get_params()["tree_method"] == "hist"
|
||||
|
||||
|
||||
def test_kwargs_error():
|
||||
@ -695,13 +744,19 @@ def test_sklearn_clone():
|
||||
|
||||
def test_sklearn_get_default_params():
|
||||
from sklearn.datasets import load_digits
|
||||
|
||||
digits_2class = load_digits(n_class=2)
|
||||
X = digits_2class['data']
|
||||
y = digits_2class['target']
|
||||
X = digits_2class["data"]
|
||||
y = digits_2class["target"]
|
||||
cls = xgb.XGBClassifier()
|
||||
assert cls.get_params()['base_score'] is None
|
||||
assert cls.get_params()["base_score"] is None
|
||||
cls.fit(X[:4, ...], y[:4, ...])
|
||||
assert cls.get_params()['base_score'] is not None
|
||||
base_score = float(
|
||||
json.loads(cls.get_booster().save_config())["learner"]["learner_model_param"][
|
||||
"base_score"
|
||||
]
|
||||
)
|
||||
np.testing.assert_equal(base_score, 0.5)
|
||||
|
||||
|
||||
def run_validation_weights(model):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user