@@ -78,7 +78,6 @@ from .data import _is_cudf_ser, _is_cupy_array
|
||||
from .sklearn import (
|
||||
XGBClassifier,
|
||||
XGBClassifierBase,
|
||||
XGBClassifierMixIn,
|
||||
XGBModel,
|
||||
XGBRanker,
|
||||
XGBRankerMixIn,
|
||||
@@ -1854,7 +1853,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
||||
"Implementation of the scikit-learn API for XGBoost classification.",
|
||||
["estimators", "model"],
|
||||
)
|
||||
class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierMixIn, XGBClassifierBase):
|
||||
class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
# pylint: disable=missing-class-docstring
|
||||
async def _fit_async(
|
||||
self,
|
||||
@@ -2036,10 +2035,6 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierMixIn, XGBClassifierBa
|
||||
preds = da.map_blocks(_argmax, pred_probs, drop_axis=1)
|
||||
return preds
|
||||
|
||||
def load_model(self, fname: ModelIn) -> None:
|
||||
super().load_model(fname)
|
||||
self._load_model_attributes(self.get_booster())
|
||||
|
||||
|
||||
@xgboost_model_doc(
|
||||
"""Implementation of the Scikit-Learn API for XGBoost Ranking.
|
||||
|
||||
@@ -43,19 +43,6 @@ from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array, _is_pandas_df
|
||||
from .training import train
|
||||
|
||||
|
||||
class XGBClassifierMixIn: # pylint: disable=too-few-public-methods
|
||||
"""MixIn for classification."""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _load_model_attributes(self, booster: Booster) -> None:
|
||||
config = json.loads(booster.save_config())
|
||||
self.n_classes_ = int(config["learner"]["learner_model_param"]["num_class"])
|
||||
# binary classification is treated as regression in XGBoost.
|
||||
self.n_classes_ = 2 if self.n_classes_ < 2 else self.n_classes_
|
||||
|
||||
|
||||
class XGBRankerMixIn: # pylint: disable=too-few-public-methods
|
||||
"""MixIn for ranking, defines the _estimator_type usually defined in scikit-learn
|
||||
base classes.
|
||||
@@ -845,21 +832,38 @@ class XGBModel(XGBModelBase):
|
||||
self.get_booster().load_model(fname)
|
||||
|
||||
meta_str = self.get_booster().attr("scikit_learn")
|
||||
if meta_str is None:
|
||||
return
|
||||
if meta_str is not None:
|
||||
meta = json.loads(meta_str)
|
||||
t = meta.get("_estimator_type", None)
|
||||
if t is not None and t != self._get_type():
|
||||
raise TypeError(
|
||||
"Loading an estimator with different type. Expecting: "
|
||||
f"{self._get_type()}, got: {t}"
|
||||
)
|
||||
|
||||
meta = json.loads(meta_str)
|
||||
t = meta.get("_estimator_type", None)
|
||||
if t is not None and t != self._get_type():
|
||||
raise TypeError(
|
||||
"Loading an estimator with different type. Expecting: "
|
||||
f"{self._get_type()}, got: {t}"
|
||||
)
|
||||
self.feature_types = self.get_booster().feature_types
|
||||
self.get_booster().set_attr(scikit_learn=None)
|
||||
config = json.loads(self.get_booster().save_config())
|
||||
self._load_model_attributes(config)
|
||||
|
||||
load_model.__doc__ = f"""{Booster.load_model.__doc__}"""
|
||||
|
||||
def _load_model_attributes(self, config: dict) -> None:
|
||||
"""Load model attributes without hyper-parameters."""
|
||||
from sklearn.base import is_classifier
|
||||
|
||||
booster = self.get_booster()
|
||||
|
||||
self.objective = config["learner"]["objective"]["name"]
|
||||
self.booster = config["learner"]["gradient_booster"]["name"]
|
||||
self.base_score = config["learner"]["learner_model_param"]["base_score"]
|
||||
self.feature_types = booster.feature_types
|
||||
|
||||
if is_classifier(self):
|
||||
self.n_classes_ = int(config["learner"]["learner_model_param"]["num_class"])
|
||||
# binary classification is treated as regression in XGBoost.
|
||||
self.n_classes_ = 2 if self.n_classes_ < 2 else self.n_classes_
|
||||
|
||||
# pylint: disable=too-many-branches
|
||||
def _configure_fit(
|
||||
self,
|
||||
@@ -1409,7 +1413,7 @@ def _cls_predict_proba(n_classes: int, prediction: PredtT, vstack: Callable) ->
|
||||
Number of boosting rounds.
|
||||
""",
|
||||
)
|
||||
class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase):
|
||||
class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
# pylint: disable=missing-docstring,invalid-name,too-many-instance-attributes
|
||||
@_deprecate_positional_args
|
||||
def __init__(
|
||||
@@ -1637,10 +1641,6 @@ class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase):
|
||||
def classes_(self) -> np.ndarray:
|
||||
return np.arange(self.n_classes_)
|
||||
|
||||
def load_model(self, fname: ModelIn) -> None:
|
||||
super().load_model(fname)
|
||||
self._load_model_attributes(self.get_booster())
|
||||
|
||||
|
||||
@xgboost_model_doc(
|
||||
"scikit-learn API for XGBoost random forest classification.",
|
||||
|
||||
Reference in New Issue
Block a user