[backport][sklearn] Fix loading model attributes. (#9808) (#9880)

Jiaming Yuan, 2023-12-13 14:20:04 +08:00 (committed via GitHub)
parent 41ce8f28b2
commit e4ee4e79dc

4 changed files with 48 additions and 42 deletions
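
The attribute restore previously lived in XGBClassifierMixIn and ran only for the
classifiers that inherited it; this change moves it into
XGBModel._load_model_attributes, so objective, booster, base_score, and
feature_types are recovered from the saved booster config for every estimator,
with n_classes_ restored for classifiers. A minimal sketch of the behaviour the
new tests assert (the file name and the digits setup are illustrative, not taken
from the diff):

    import xgboost as xgb
    from sklearn.datasets import load_digits

    X, y = load_digits(return_X_y=True)

    clf = xgb.XGBClassifier(n_estimators=4).fit(X, y)
    clf.save_model("digits.json")

    loaded = xgb.XGBClassifier()
    loaded.load_model("digits.json")
    # After this fix, these attributes are populated from the saved config
    # instead of being left unset on the freshly constructed estimator.
    assert loaded.n_classes_ == 10
    assert loaded.objective == "multi:softprob"
    assert loaded.booster == "gbtree"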

python-package/xgboost/dask.py

@@ -78,7 +78,6 @@ from .data import _is_cudf_ser, _is_cupy_array
 from .sklearn import (
     XGBClassifier,
     XGBClassifierBase,
-    XGBClassifierMixIn,
     XGBModel,
     XGBRanker,
     XGBRankerMixIn,
@@ -1854,7 +1853,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
     "Implementation of the scikit-learn API for XGBoost classification.",
     ["estimators", "model"],
 )
-class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierMixIn, XGBClassifierBase):
+class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
     # pylint: disable=missing-class-docstring
     async def _fit_async(
         self,
@@ -2036,10 +2035,6 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierMixIn, XGBClassifierBase):
         preds = da.map_blocks(_argmax, pred_probs, drop_axis=1)
         return preds

-    def load_model(self, fname: ModelIn) -> None:
-        super().load_model(fname)
-        self._load_model_attributes(self.get_booster())
-

 @xgboost_model_doc(
     """Implementation of the Scikit-Learn API for XGBoost Ranking.

python-package/xgboost/sklearn.py

@@ -43,19 +43,6 @@ from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array, _is_pandas_df
 from .training import train


-class XGBClassifierMixIn:  # pylint: disable=too-few-public-methods
-    """MixIn for classification."""
-
-    def __init__(self, *args: Any, **kwargs: Any) -> None:
-        super().__init__(*args, **kwargs)
-
-    def _load_model_attributes(self, booster: Booster) -> None:
-        config = json.loads(booster.save_config())
-        self.n_classes_ = int(config["learner"]["learner_model_param"]["num_class"])
-        # binary classification is treated as regression in XGBoost.
-        self.n_classes_ = 2 if self.n_classes_ < 2 else self.n_classes_
-
-
 class XGBRankerMixIn:  # pylint: disable=too-few-public-methods
     """MixIn for ranking, defines the _estimator_type usually defined in scikit-learn
     base classes.
@@ -845,21 +832,38 @@ class XGBModel(XGBModelBase):
         self.get_booster().load_model(fname)
         meta_str = self.get_booster().attr("scikit_learn")
-        if meta_str is None:
-            return
-
-        meta = json.loads(meta_str)
-        t = meta.get("_estimator_type", None)
-        if t is not None and t != self._get_type():
-            raise TypeError(
-                "Loading an estimator with different type. Expecting: "
-                f"{self._get_type()}, got: {t}"
-            )
+        if meta_str is not None:
+            meta = json.loads(meta_str)
+            t = meta.get("_estimator_type", None)
+            if t is not None and t != self._get_type():
+                raise TypeError(
+                    "Loading an estimator with different type. Expecting: "
+                    f"{self._get_type()}, got: {t}"
+                )

         self.feature_types = self.get_booster().feature_types
         self.get_booster().set_attr(scikit_learn=None)
+        config = json.loads(self.get_booster().save_config())
+        self._load_model_attributes(config)

     load_model.__doc__ = f"""{Booster.load_model.__doc__}"""

+    def _load_model_attributes(self, config: dict) -> None:
+        """Load model attributes without hyper-parameters."""
+        from sklearn.base import is_classifier
+
+        booster = self.get_booster()
+
+        self.objective = config["learner"]["objective"]["name"]
+        self.booster = config["learner"]["gradient_booster"]["name"]
+        self.base_score = config["learner"]["learner_model_param"]["base_score"]
+        self.feature_types = booster.feature_types
+
+        if is_classifier(self):
+            self.n_classes_ = int(config["learner"]["learner_model_param"]["num_class"])
+            # binary classification is treated as regression in XGBoost.
+            self.n_classes_ = 2 if self.n_classes_ < 2 else self.n_classes_
+
     # pylint: disable=too-many-branches
     def _configure_fit(
         self,
@@ -1409,7 +1413,7 @@ def _cls_predict_proba(n_classes: int, prediction: PredtT, vstack: Callable) ->
     Number of boosting rounds.
 """,
 )
-class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase):
+class XGBClassifier(XGBModel, XGBClassifierBase):
     # pylint: disable=missing-docstring,invalid-name,too-many-instance-attributes
     @_deprecate_positional_args
     def __init__(
@@ -1637,10 +1641,6 @@ class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase):
     def classes_(self) -> np.ndarray:
         return np.arange(self.n_classes_)

-    def load_model(self, fname: ModelIn) -> None:
-        super().load_model(fname)
-        self._load_model_attributes(self.get_booster())
-

 @xgboost_model_doc(
     "scikit-learn API for XGBoost random forest classification.",

tests/python/test_with_sklearn.py

@@ -940,6 +940,7 @@ def save_load_model(model_path):
     predt_0 = clf.predict(X)
     clf.save_model(model_path)
     clf.load_model(model_path)
+    assert clf.booster == "gblinear"
     predt_1 = clf.predict(X)
     np.testing.assert_allclose(predt_0, predt_1)
     assert clf.best_iteration == best_iteration
@@ -955,25 +956,26 @@ def save_load_model(model_path):
 def test_save_load_model():
     with tempfile.TemporaryDirectory() as tempdir:
-        model_path = os.path.join(tempdir, 'digits.model')
+        model_path = os.path.join(tempdir, "digits.model")
         save_load_model(model_path)

     with tempfile.TemporaryDirectory() as tempdir:
-        model_path = os.path.join(tempdir, 'digits.model.json')
+        model_path = os.path.join(tempdir, "digits.model.json")
         save_load_model(model_path)

     from sklearn.datasets import load_digits
     from sklearn.model_selection import train_test_split

     with tempfile.TemporaryDirectory() as tempdir:
-        model_path = os.path.join(tempdir, 'digits.model.ubj')
+        model_path = os.path.join(tempdir, "digits.model.ubj")
         digits = load_digits(n_class=2)
-        y = digits['target']
-        X = digits['data']
-        booster = xgb.train({'tree_method': 'hist',
-                             'objective': 'binary:logistic'},
-                            dtrain=xgb.DMatrix(X, y),
-                            num_boost_round=4)
+        y = digits["target"]
+        X = digits["data"]
+        booster = xgb.train(
+            {"tree_method": "hist", "objective": "binary:logistic"},
+            dtrain=xgb.DMatrix(X, y),
+            num_boost_round=4,
+        )
         predt_0 = booster.predict(xgb.DMatrix(X))
         booster.save_model(model_path)
         cls = xgb.XGBClassifier()
@@ -1007,6 +1009,8 @@ def test_save_load_model():
     clf = xgb.XGBClassifier()
     clf.load_model(model_path)
     assert clf.classes_.size == 10
+    assert clf.objective == "multi:softprob"
+
     np.testing.assert_equal(clf.classes_, np.arange(10))
     assert clf.n_classes_ == 10

tests/test_distributed/test_with_dask/test_with_dask.py

@@ -1932,6 +1932,7 @@ class TestWithDask:
         cls.client = client
         cls.fit(X, y)
         predt_0 = cls.predict(X)
+        proba_0 = cls.predict_proba(X)

         with tempfile.TemporaryDirectory() as tmpdir:
             path = os.path.join(tmpdir, "model.pkl")
@@ -1941,7 +1942,9 @@ class TestWithDask:
             with open(path, "rb") as fd:
                 cls = pickle.load(fd)
             predt_1 = cls.predict(X)
+            proba_1 = cls.predict_proba(X)
             np.testing.assert_allclose(predt_0.compute(), predt_1.compute())
+            np.testing.assert_allclose(proba_0.compute(), proba_1.compute())

             path = os.path.join(tmpdir, "cls.json")
             cls.save_model(path)
@@ -1950,16 +1953,20 @@ class TestWithDask:
             cls.load_model(path)
             assert cls.n_classes_ == 10
             predt_2 = cls.predict(X)
+            proba_2 = cls.predict_proba(X)
             np.testing.assert_allclose(predt_0.compute(), predt_2.compute())
+            np.testing.assert_allclose(proba_0.compute(), proba_2.compute())

             # Use single node to load
             cls = xgb.XGBClassifier()
             cls.load_model(path)
             assert cls.n_classes_ == 10
             predt_3 = cls.predict(X_)
+            proba_3 = cls.predict_proba(X_)
             np.testing.assert_allclose(predt_0.compute(), predt_3)
+            np.testing.assert_allclose(proba_0.compute(), proba_3)


 def test_dask_unsupported_features(client: "Client") -> None: