[sklearn] Fix loading model attributes. (#9808)

Jiaming Yuan 2023-11-27 17:19:01 +08:00 committed by GitHub
parent 3f4e22015a
commit e9f149481e
4 changed files with 48 additions and 42 deletions

View File

@@ -79,7 +79,6 @@ from xgboost.data import _is_cudf_ser, _is_cupy_array
 from xgboost.sklearn import (
     XGBClassifier,
     XGBClassifierBase,
-    XGBClassifierMixIn,
     XGBModel,
     XGBRanker,
     XGBRankerMixIn,
@@ -1863,7 +1862,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
     "Implementation of the scikit-learn API for XGBoost classification.",
     ["estimators", "model"],
 )
-class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierMixIn, XGBClassifierBase):
+class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
     # pylint: disable=missing-class-docstring
     async def _fit_async(
         self,
@@ -2045,10 +2044,6 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierMixIn, XGBClassifierBase):
         preds = da.map_blocks(_argmax, pred_probs, drop_axis=1)
         return preds
 
-    def load_model(self, fname: ModelIn) -> None:
-        super().load_model(fname)
-        self._load_model_attributes(self.get_booster())
-
 
 @xgboost_model_doc(
     """Implementation of the Scikit-Learn API for XGBoost Ranking.

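With the mix-in gone, DaskXGBClassifier no longer needs its own load_model override; attribute restoration now happens in the inherited XGBModel.load_model (see the sklearn.py hunks below). A minimal usage sketch, assuming a previously saved classifier at the illustrative path "clf.json":

from xgboost import dask as dxgb

clf = dxgb.DaskXGBClassifier()
clf.load_model("clf.json")  # inherited from XGBModel, no Dask-specific override
print(clf.n_classes_)       # restored from the booster's saved config
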
View File

@@ -43,19 +43,6 @@ from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array, _is_pandas_df
 from .training import train
 
-
-class XGBClassifierMixIn:  # pylint: disable=too-few-public-methods
-    """MixIn for classification."""
-
-    def __init__(self, *args: Any, **kwargs: Any) -> None:
-        super().__init__(*args, **kwargs)
-
-    def _load_model_attributes(self, booster: Booster) -> None:
-        config = json.loads(booster.save_config())
-        self.n_classes_ = int(config["learner"]["learner_model_param"]["num_class"])
-        # binary classification is treated as regression in XGBoost.
-        self.n_classes_ = 2 if self.n_classes_ < 2 else self.n_classes_
-
-
 class XGBRankerMixIn:  # pylint: disable=too-few-public-methods
     """MixIn for ranking, defines the _estimator_type usually defined in scikit-learn
     base classes.
@@ -850,9 +837,7 @@ class XGBModel(XGBModelBase):
         self.get_booster().load_model(fname)
         meta_str = self.get_booster().attr("scikit_learn")
-        if meta_str is None:
-            return
-
-        meta = json.loads(meta_str)
-        t = meta.get("_estimator_type", None)
-        if t is not None and t != self._get_type():
-            raise TypeError(
+        if meta_str is not None:
+            meta = json.loads(meta_str)
+            t = meta.get("_estimator_type", None)
+            if t is not None and t != self._get_type():
+                raise TypeError(
@@ -860,11 +845,30 @@ class XGBModel(XGBModelBase):
-                "Loading an estimator with different type. Expecting: "
-                f"{self._get_type()}, got: {t}"
-            )
+                    "Loading an estimator with different type. Expecting: "
+                    f"{self._get_type()}, got: {t}"
+                )
 
-        self.feature_types = self.get_booster().feature_types
         self.get_booster().set_attr(scikit_learn=None)
+        config = json.loads(self.get_booster().save_config())
+        self._load_model_attributes(config)
 
     load_model.__doc__ = f"""{Booster.load_model.__doc__}"""
 
+    def _load_model_attributes(self, config: dict) -> None:
+        """Load model attributes without hyper-parameters."""
+        from sklearn.base import is_classifier
+
+        booster = self.get_booster()
+
+        self.objective = config["learner"]["objective"]["name"]
+        self.booster = config["learner"]["gradient_booster"]["name"]
+        self.base_score = config["learner"]["learner_model_param"]["base_score"]
+        self.feature_types = booster.feature_types
+
+        if is_classifier(self):
+            self.n_classes_ = int(config["learner"]["learner_model_param"]["num_class"])
+            # binary classification is treated as regression in XGBoost.
+            self.n_classes_ = 2 if self.n_classes_ < 2 else self.n_classes_
+
     # pylint: disable=too-many-branches
     def _configure_fit(
         self,
@@ -1414,7 +1418,7 @@ def _cls_predict_proba(n_classes: int, prediction: PredtT, vstack: Callable) ->
     Number of boosting rounds.
 """,
 )
-class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase):
+class XGBClassifier(XGBModel, XGBClassifierBase):
     # pylint: disable=missing-docstring,invalid-name,too-many-instance-attributes
     @_deprecate_positional_args
     def __init__(
@@ -1642,10 +1646,6 @@ class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase):
     def classes_(self) -> np.ndarray:
         return np.arange(self.n_classes_)
 
-    def load_model(self, fname: ModelIn) -> None:
-        super().load_model(fname)
-        self._load_model_attributes(self.get_booster())
-
 
 @xgboost_model_doc(
     "scikit-learn API for XGBoost random forest classification.",

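The new _load_model_attributes reads everything from the JSON document returned by Booster.save_config() instead of the "scikit_learn" attribute blob. A sketch of the fields it touches, using a throwaway binary model (data and parameters are illustrative):

import json

import numpy as np
import xgboost as xgb

X, y = np.random.rand(32, 4), np.random.randint(0, 2, 32)
bst = xgb.train({"objective": "binary:logistic"}, xgb.DMatrix(X, label=y), num_boost_round=2)

learner = json.loads(bst.save_config())["learner"]
print(learner["objective"]["name"])                  # "binary:logistic"
print(learner["gradient_booster"]["name"])           # "gbtree"
print(learner["learner_model_param"]["base_score"])  # stored as a string
print(learner["learner_model_param"]["num_class"])   # "0" here: binary is treated as
                                                     # regression, hence the clamp to 2 above
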
View File

@@ -944,6 +944,7 @@ def save_load_model(model_path):
     predt_0 = clf.predict(X)
     clf.save_model(model_path)
     clf.load_model(model_path)
+    assert clf.booster == "gblinear"
     predt_1 = clf.predict(X)
     np.testing.assert_allclose(predt_0, predt_1)
     assert clf.best_iteration == best_iteration
@@ -959,25 +960,26 @@ def save_load_model(model_path):
 def test_save_load_model():
     with tempfile.TemporaryDirectory() as tempdir:
-        model_path = os.path.join(tempdir, 'digits.model')
+        model_path = os.path.join(tempdir, "digits.model")
         save_load_model(model_path)
 
     with tempfile.TemporaryDirectory() as tempdir:
-        model_path = os.path.join(tempdir, 'digits.model.json')
+        model_path = os.path.join(tempdir, "digits.model.json")
         save_load_model(model_path)
 
     from sklearn.datasets import load_digits
     from sklearn.model_selection import train_test_split
 
     with tempfile.TemporaryDirectory() as tempdir:
-        model_path = os.path.join(tempdir, 'digits.model.ubj')
+        model_path = os.path.join(tempdir, "digits.model.ubj")
         digits = load_digits(n_class=2)
-        y = digits['target']
-        X = digits['data']
-        booster = xgb.train({'tree_method': 'hist',
-                             'objective': 'binary:logistic'},
-                            dtrain=xgb.DMatrix(X, y),
-                            num_boost_round=4)
+        y = digits["target"]
+        X = digits["data"]
+        booster = xgb.train(
+            {"tree_method": "hist", "objective": "binary:logistic"},
+            dtrain=xgb.DMatrix(X, y),
+            num_boost_round=4,
+        )
         predt_0 = booster.predict(xgb.DMatrix(X))
         booster.save_model(model_path)
         cls = xgb.XGBClassifier()
@@ -1011,6 +1013,8 @@ def test_save_load_model():
         clf = xgb.XGBClassifier()
         clf.load_model(model_path)
         assert clf.classes_.size == 10
+        assert clf.objective == "multi:softprob"
+
         np.testing.assert_equal(clf.classes_, np.arange(10))
         assert clf.n_classes_ == 10

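The added assertions amount to the round trip below; a hedged sketch, assuming scikit-learn is available and writing to an illustrative path:

import numpy as np
import xgboost as xgb
from sklearn.datasets import load_digits

X, y = load_digits(return_X_y=True)          # 10 classes
clf = xgb.XGBClassifier(n_estimators=4).fit(X, y)
clf.save_model("digits.json")                # illustrative path

loaded = xgb.XGBClassifier()
loaded.load_model("digits.json")
assert loaded.n_classes_ == 10               # recovered from the booster config
assert loaded.objective == "multi:softprob"  # likewise, without sklearn metadata
np.testing.assert_equal(loaded.classes_, np.arange(10))
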
View File

@@ -1931,6 +1931,7 @@ class TestWithDask:
         cls.client = client
         cls.fit(X, y)
         predt_0 = cls.predict(X)
+        proba_0 = cls.predict_proba(X)
 
         with tempfile.TemporaryDirectory() as tmpdir:
             path = os.path.join(tmpdir, "model.pkl")
@@ -1940,7 +1941,9 @@ class TestWithDask:
             with open(path, "rb") as fd:
                 cls = pickle.load(fd)
             predt_1 = cls.predict(X)
+            proba_1 = cls.predict_proba(X)
             np.testing.assert_allclose(predt_0.compute(), predt_1.compute())
+            np.testing.assert_allclose(proba_0.compute(), proba_1.compute())
 
             path = os.path.join(tmpdir, "cls.json")
             cls.save_model(path)
@@ -1949,16 +1952,20 @@ class TestWithDask:
             cls.load_model(path)
             assert cls.n_classes_ == 10
             predt_2 = cls.predict(X)
+            proba_2 = cls.predict_proba(X)
             np.testing.assert_allclose(predt_0.compute(), predt_2.compute())
+            np.testing.assert_allclose(proba_0.compute(), proba_2.compute())
 
             # Use single node to load
             cls = xgb.XGBClassifier()
             cls.load_model(path)
             assert cls.n_classes_ == 10
             predt_3 = cls.predict(X_)
+            proba_3 = cls.predict_proba(X_)
             np.testing.assert_allclose(predt_0.compute(), predt_3)
+            np.testing.assert_allclose(proba_0.compute(), proba_3)
 
 
 def test_dask_unsupported_features(client: "Client") -> None:
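
The single-node branch is the point of this test: a model saved by the distributed estimator reloads into the plain XGBClassifier with the same restored attributes. A rough sketch, assuming a running distributed.Client and dask collections X and y (all names illustrative):

import xgboost as xgb
from xgboost import dask as dxgb

clf = dxgb.DaskXGBClassifier(n_estimators=4)
clf.client = client                        # an existing distributed.Client
clf.fit(X, y)                              # X, y are dask arrays
clf.save_model("cls.json")                 # illustrative path

local = xgb.XGBClassifier()                # single-node estimator
local.load_model("cls.json")
assert local.n_classes_ == clf.n_classes_  # attributes restored from the config
proba = local.predict_proba(X.compute())   # agrees with the distributed predict_proba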