parent 41ce8f28b2
commit e4ee4e79dc
@@ -78,7 +78,6 @@ from .data import _is_cudf_ser, _is_cupy_array
 from .sklearn import (
     XGBClassifier,
     XGBClassifierBase,
-    XGBClassifierMixIn,
     XGBModel,
     XGBRanker,
     XGBRankerMixIn,
@@ -1854,7 +1853,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
     "Implementation of the scikit-learn API for XGBoost classification.",
     ["estimators", "model"],
 )
-class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierMixIn, XGBClassifierBase):
+class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
     # pylint: disable=missing-class-docstring
     async def _fit_async(
         self,
@@ -2036,10 +2035,6 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierMixIn, XGBClassifierBase):
         preds = da.map_blocks(_argmax, pred_probs, drop_axis=1)
         return preds
 
-    def load_model(self, fname: ModelIn) -> None:
-        super().load_model(fname)
-        self._load_model_attributes(self.get_booster())
-
 
 @xgboost_model_doc(
     """Implementation of the Scikit-Learn API for XGBoost Ranking.
@@ -43,19 +43,6 @@ from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array, _is_pandas_df
 from .training import train
 
 
-class XGBClassifierMixIn:  # pylint: disable=too-few-public-methods
-    """MixIn for classification."""
-
-    def __init__(self, *args: Any, **kwargs: Any) -> None:
-        super().__init__(*args, **kwargs)
-
-    def _load_model_attributes(self, booster: Booster) -> None:
-        config = json.loads(booster.save_config())
-        self.n_classes_ = int(config["learner"]["learner_model_param"]["num_class"])
-        # binary classification is treated as regression in XGBoost.
-        self.n_classes_ = 2 if self.n_classes_ < 2 else self.n_classes_
-
-
 class XGBRankerMixIn:  # pylint: disable=too-few-public-methods
     """MixIn for ranking, defines the _estimator_type usually defined in scikit-learn
     base classes.
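Note: the `num_class` handling deleted with the mixin is not lost; it reappears in `XGBModel._load_model_attributes` in the -855 hunk below. A minimal sketch of why the clamp to 2 exists, using the value the saved config reports for a binary objective (illustrative, not part of the diff):

    # XGBoost treats binary classification as regression on the logit, so the
    # saved config reports num_class as 0, while scikit-learn expects 2.
    num_class = 0  # illustrative: what save_config() yields for binary:logistic
    n_classes_ = 2 if num_class < 2 else num_class
    assert n_classes_ == 2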
@@ -845,9 +832,7 @@ class XGBModel(XGBModelBase):
         self.get_booster().load_model(fname)
 
         meta_str = self.get_booster().attr("scikit_learn")
-        if meta_str is None:
-            return
-
+        if meta_str is not None:
             meta = json.loads(meta_str)
             t = meta.get("_estimator_type", None)
             if t is not None and t != self._get_type():
@@ -855,11 +840,30 @@ class XGBModel(XGBModelBase):
                     "Loading an estimator with different type. Expecting: "
                     f"{self._get_type()}, got: {t}"
                 )
+
         self.feature_types = self.get_booster().feature_types
         self.get_booster().set_attr(scikit_learn=None)
+        config = json.loads(self.get_booster().save_config())
+        self._load_model_attributes(config)
 
     load_model.__doc__ = f"""{Booster.load_model.__doc__}"""
 
+    def _load_model_attributes(self, config: dict) -> None:
+        """Load model attributes without hyper-parameters."""
+        from sklearn.base import is_classifier
+
+        booster = self.get_booster()
+
+        self.objective = config["learner"]["objective"]["name"]
+        self.booster = config["learner"]["gradient_booster"]["name"]
+        self.base_score = config["learner"]["learner_model_param"]["base_score"]
+        self.feature_types = booster.feature_types
+
+        if is_classifier(self):
+            self.n_classes_ = int(config["learner"]["learner_model_param"]["num_class"])
+            # binary classification is treated as regression in XGBoost.
+            self.n_classes_ = 2 if self.n_classes_ < 2 else self.n_classes_
+
     # pylint: disable=too-many-branches
     def _configure_fit(
         self,
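Taken together, the two hunks above make `load_model` restore the scikit-learn attributes straight from the booster's saved config, so subclasses no longer need their own overrides. A hedged round-trip sketch (data shape and file name are illustrative):

    import numpy as np
    import xgboost as xgb

    X = np.random.rand(64, 4)
    y = np.random.randint(0, 3, size=64)  # three classes

    clf = xgb.XGBClassifier(n_estimators=2).fit(X, y)
    clf.save_model("clf.json")  # illustrative path

    fresh = xgb.XGBClassifier()
    fresh.load_model("clf.json")
    # Restored by _load_model_attributes per this commit:
    assert fresh.n_classes_ == 3
    assert fresh.objective == "multi:softprob"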
@@ -1409,7 +1413,7 @@ def _cls_predict_proba(n_classes: int, prediction: PredtT, vstack: Callable) -> PredtT:
     Number of boosting rounds.
 """,
 )
-class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase):
+class XGBClassifier(XGBModel, XGBClassifierBase):
     # pylint: disable=missing-docstring,invalid-name,too-many-instance-attributes
     @_deprecate_positional_args
     def __init__(
@@ -1637,10 +1641,6 @@ class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase):
     def classes_(self) -> np.ndarray:
         return np.arange(self.n_classes_)
 
-    def load_model(self, fname: ModelIn) -> None:
-        super().load_model(fname)
-        self._load_model_attributes(self.get_booster())
-
 
 @xgboost_model_doc(
     "scikit-learn API for XGBoost random forest classification.",
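The deleted override passed the booster itself; the new base-class hook instead receives the already-parsed config dictionary. A short sketch of the JSON paths it reads, assuming `clf` is a fitted `XGBClassifier` as in the sketch above (the config stores model params as strings, hence the `int(...)` in the new hook):

    import json

    config = json.loads(clf.get_booster().save_config())
    learner = config["learner"]
    print(learner["objective"]["name"])                 # e.g. "multi:softprob"
    print(learner["gradient_booster"]["name"])          # e.g. "gbtree"
    print(learner["learner_model_param"]["num_class"])  # e.g. "3" (a string)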
@@ -940,6 +940,7 @@ def save_load_model(model_path):
         predt_0 = clf.predict(X)
         clf.save_model(model_path)
         clf.load_model(model_path)
+        assert clf.booster == "gblinear"
         predt_1 = clf.predict(X)
         np.testing.assert_allclose(predt_0, predt_1)
         assert clf.best_iteration == best_iteration
@@ -955,25 +956,26 @@ def save_load_model(model_path):
 
 def test_save_load_model():
     with tempfile.TemporaryDirectory() as tempdir:
-        model_path = os.path.join(tempdir, 'digits.model')
+        model_path = os.path.join(tempdir, "digits.model")
         save_load_model(model_path)
 
     with tempfile.TemporaryDirectory() as tempdir:
-        model_path = os.path.join(tempdir, 'digits.model.json')
+        model_path = os.path.join(tempdir, "digits.model.json")
         save_load_model(model_path)
 
     from sklearn.datasets import load_digits
     from sklearn.model_selection import train_test_split
 
     with tempfile.TemporaryDirectory() as tempdir:
-        model_path = os.path.join(tempdir, 'digits.model.ubj')
+        model_path = os.path.join(tempdir, "digits.model.ubj")
         digits = load_digits(n_class=2)
-        y = digits['target']
-        X = digits['data']
-        booster = xgb.train({'tree_method': 'hist',
-                             'objective': 'binary:logistic'},
-                            dtrain=xgb.DMatrix(X, y),
-                            num_boost_round=4)
+        y = digits["target"]
+        X = digits["data"]
+        booster = xgb.train(
+            {"tree_method": "hist", "objective": "binary:logistic"},
+            dtrain=xgb.DMatrix(X, y),
+            num_boost_round=4,
+        )
         predt_0 = booster.predict(xgb.DMatrix(X))
         booster.save_model(model_path)
         cls = xgb.XGBClassifier()
@@ -1007,6 +1009,8 @@ def test_save_load_model():
         clf = xgb.XGBClassifier()
         clf.load_model(model_path)
         assert clf.classes_.size == 10
+        assert clf.objective == "multi:softprob"
+
         np.testing.assert_equal(clf.classes_, np.arange(10))
         assert clf.n_classes_ == 10
 
@@ -1932,6 +1932,7 @@ class TestWithDask:
         cls.client = client
         cls.fit(X, y)
         predt_0 = cls.predict(X)
+        proba_0 = cls.predict_proba(X)
 
         with tempfile.TemporaryDirectory() as tmpdir:
             path = os.path.join(tmpdir, "model.pkl")
@@ -1941,7 +1942,9 @@ class TestWithDask:
             with open(path, "rb") as fd:
                 cls = pickle.load(fd)
             predt_1 = cls.predict(X)
+            proba_1 = cls.predict_proba(X)
             np.testing.assert_allclose(predt_0.compute(), predt_1.compute())
+            np.testing.assert_allclose(proba_0.compute(), proba_1.compute())
 
             path = os.path.join(tmpdir, "cls.json")
             cls.save_model(path)
@@ -1950,16 +1953,20 @@ class TestWithDask:
             cls.load_model(path)
             assert cls.n_classes_ == 10
             predt_2 = cls.predict(X)
+            proba_2 = cls.predict_proba(X)
 
             np.testing.assert_allclose(predt_0.compute(), predt_2.compute())
+            np.testing.assert_allclose(proba_0.compute(), proba_2.compute())
 
             # Use single node to load
             cls = xgb.XGBClassifier()
             cls.load_model(path)
             assert cls.n_classes_ == 10
             predt_3 = cls.predict(X_)
+            proba_3 = cls.predict_proba(X_)
 
             np.testing.assert_allclose(predt_0.compute(), predt_3)
+            np.testing.assert_allclose(proba_0.compute(), proba_3)
 
 
 def test_dask_unsupported_features(client: "Client") -> None:
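For context, a sketch of the distributed-to-single-node hand-off this test exercises, assuming a local cluster; the path and data sizes are illustrative:

    from dask import array as da
    from dask.distributed import Client, LocalCluster
    import xgboost as xgb
    from xgboost import dask as dxgb

    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        X = da.random.random((1000, 10), chunks=(100, 10))
        y = (da.random.random(1000, chunks=100) * 10).astype(int)  # labels 0..9
        dcls = dxgb.DaskXGBClassifier(n_estimators=2)
        dcls.client = client
        dcls.fit(X, y)
        dcls.save_model("cls.json")  # illustrative path

    # The saved file loads into the single-node estimator with classes intact.
    cls = xgb.XGBClassifier()
    cls.load_model("cls.json")
    assert cls.n_classes_ == 10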