[breaking][skl] Remove parameter serialization. (#8963)
- Remove parameter serialization in the scikit-learn interface. The scikit-lear interface `save_model` will save only the model and discard all hyper-parameters. This is to align with the native XGBoost interface, which distinguishes the hyper-parameter and model parameters. With the scikit-learn interface, model parameters are attributes of the estimator. For instance, `n_features_in_`, `n_classes_` are always accessible with `estimator.n_features_in_` and `estimator.n_classes_`, but not with the `estimator.get_params`. - Define a `load_model` method for classifier to load its own attributes. - Set n_estimators to None by default.
This commit is contained in:
parent
90645c4957
commit
c2b3a13e70
@ -43,6 +43,8 @@ FPreProcCallable = Callable
|
||||
# c_bst_ulong corresponds to bst_ulong defined in xgboost/c_api.h
|
||||
c_bst_ulong = ctypes.c_uint64 # pylint: disable=C0103
|
||||
|
||||
ModelIn = Union[str, bytearray, os.PathLike]
|
||||
|
||||
CTypeT = TypeVar(
|
||||
"CTypeT",
|
||||
ctypes.c_void_p,
|
||||
|
||||
@ -88,31 +88,6 @@ def is_cudf_available() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class XGBoostLabelEncoder(LabelEncoder):
|
||||
"""Label encoder with JSON serialization methods."""
|
||||
|
||||
def to_json(self) -> Dict:
|
||||
"""Returns a JSON compatible dictionary"""
|
||||
meta = {}
|
||||
for k, v in self.__dict__.items():
|
||||
if isinstance(v, np.ndarray):
|
||||
meta[k] = v.tolist()
|
||||
else:
|
||||
meta[k] = v
|
||||
return meta
|
||||
|
||||
def from_json(self, doc: Dict) -> None:
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
"""Load the encoder back from a JSON compatible dict."""
|
||||
meta = {}
|
||||
for k, v in doc.items():
|
||||
if k == "classes_":
|
||||
self.classes_ = np.array(v)
|
||||
continue
|
||||
meta[k] = v
|
||||
self.__dict__.update(meta)
|
||||
|
||||
|
||||
try:
|
||||
import scipy.sparse as scipy_sparse
|
||||
from scipy.sparse import csr_matrix as scipy_csr
|
||||
|
||||
@ -47,6 +47,7 @@ from ._typing import (
|
||||
FeatureInfo,
|
||||
FeatureNames,
|
||||
FeatureTypes,
|
||||
ModelIn,
|
||||
NumpyOrCupy,
|
||||
c_bst_ulong,
|
||||
)
|
||||
@ -2477,7 +2478,7 @@ class Booster:
|
||||
)
|
||||
return ctypes2buffer(cptr, length.value)
|
||||
|
||||
def load_model(self, fname: Union[str, bytearray, os.PathLike]) -> None:
|
||||
def load_model(self, fname: ModelIn) -> None:
|
||||
"""Load the model from a file or bytearray. Path to file can be local
|
||||
or as an URI.
|
||||
|
||||
|
||||
@ -60,7 +60,7 @@ from typing import (
|
||||
import numpy
|
||||
|
||||
from . import collective, config
|
||||
from ._typing import _T, FeatureNames, FeatureTypes
|
||||
from ._typing import _T, FeatureNames, FeatureTypes, ModelIn
|
||||
from .callback import TrainingCallback
|
||||
from .compat import DataFrame, LazyLoader, concat, lazy_isinstance
|
||||
from .core import (
|
||||
@ -76,6 +76,7 @@ from .core import (
|
||||
from .sklearn import (
|
||||
XGBClassifier,
|
||||
XGBClassifierBase,
|
||||
XGBClassifierMixIn,
|
||||
XGBModel,
|
||||
XGBRanker,
|
||||
XGBRankerMixIn,
|
||||
@ -1839,7 +1840,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
||||
"Implementation of the scikit-learn API for XGBoost classification.",
|
||||
["estimators", "model"],
|
||||
)
|
||||
class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierMixIn, XGBClassifierBase):
|
||||
# pylint: disable=missing-class-docstring
|
||||
async def _fit_async(
|
||||
self,
|
||||
@ -2019,6 +2020,10 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
preds = da.map_blocks(_argmax, pred_probs, drop_axis=1)
|
||||
return preds
|
||||
|
||||
def load_model(self, fname: ModelIn) -> None:
|
||||
super().load_model(fname)
|
||||
self._load_model_attributes(self.get_booster())
|
||||
|
||||
|
||||
@xgboost_model_doc(
|
||||
"""Implementation of the Scikit-Learn API for XGBoost Ranking.
|
||||
|
||||
@ -55,7 +55,7 @@ def find_lib_path() -> List[str]:
|
||||
|
||||
# XGBOOST_BUILD_DOC is defined by sphinx conf.
|
||||
if not lib_path and not os.environ.get("XGBOOST_BUILD_DOC", False):
|
||||
link = "https://xgboost.readthedocs.io/en/latest/build.html"
|
||||
link = "https://xgboost.readthedocs.io/en/stable/install.html"
|
||||
msg = (
|
||||
"Cannot find XGBoost Library in the candidate path. "
|
||||
+ "List of candidates:\n- "
|
||||
|
||||
@ -22,23 +22,18 @@ from typing import (
|
||||
import numpy as np
|
||||
from scipy.special import softmax
|
||||
|
||||
from ._typing import ArrayLike, FeatureNames, FeatureTypes
|
||||
from ._typing import ArrayLike, FeatureNames, FeatureTypes, ModelIn
|
||||
from .callback import TrainingCallback
|
||||
|
||||
# Do not use class names on scikit-learn directly. Re-define the classes on
|
||||
# .compat to guarantee the behavior without scikit-learn
|
||||
from .compat import (
|
||||
SKLEARN_INSTALLED,
|
||||
XGBClassifierBase,
|
||||
XGBModelBase,
|
||||
XGBoostLabelEncoder,
|
||||
XGBRegressorBase,
|
||||
)
|
||||
from .compat import SKLEARN_INSTALLED, XGBClassifierBase, XGBModelBase, XGBRegressorBase
|
||||
from .config import config_context
|
||||
from .core import (
|
||||
Booster,
|
||||
DMatrix,
|
||||
Metric,
|
||||
Objective,
|
||||
QuantileDMatrix,
|
||||
XGBoostError,
|
||||
_convert_ntree_limit,
|
||||
@ -49,9 +44,24 @@ from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array, _is_pandas_df
|
||||
from .training import train
|
||||
|
||||
|
||||
class XGBClassifierMixIn: # pylint: disable=too-few-public-methods
|
||||
"""MixIn for classification."""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _load_model_attributes(self, booster: Booster) -> None:
|
||||
config = json.loads(booster.save_config())
|
||||
self.n_classes_ = int(config["learner"]["learner_model_param"]["num_class"])
|
||||
# binary classification is treated as regression in XGBoost.
|
||||
self.n_classes_ = 2 if self.n_classes_ < 2 else self.n_classes_
|
||||
|
||||
|
||||
class XGBRankerMixIn: # pylint: disable=too-few-public-methods
|
||||
"""MixIn for ranking, defines the _estimator_type usually defined in scikit-learn base
|
||||
classes."""
|
||||
"""MixIn for ranking, defines the _estimator_type usually defined in scikit-learn
|
||||
base classes.
|
||||
|
||||
"""
|
||||
|
||||
_estimator_type = "ranker"
|
||||
|
||||
@ -74,7 +84,7 @@ SklObjective = Optional[
|
||||
|
||||
def _objective_decorator(
|
||||
func: Callable[[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]
|
||||
) -> Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]]:
|
||||
) -> Objective:
|
||||
"""Decorate an objective function
|
||||
|
||||
Converts an objective function using the typical sklearn metrics
|
||||
@ -173,7 +183,7 @@ def ltr_metric_decorator(func: Callable, n_jobs: Optional[int]) -> Metric:
|
||||
|
||||
|
||||
__estimator_doc = """
|
||||
n_estimators : int
|
||||
n_estimators : Optional[int]
|
||||
Number of gradient boosted trees. Equivalent to number of boosting
|
||||
rounds.
|
||||
"""
|
||||
@ -598,6 +608,9 @@ def _wrap_evaluation_matrices(
|
||||
return train_dmatrix, evals
|
||||
|
||||
|
||||
DEFAULT_N_ESTIMATORS = 100
|
||||
|
||||
|
||||
@xgboost_model_doc(
|
||||
"""Implementation of the Scikit-Learn API for XGBoost.""",
|
||||
["estimators", "model", "objective"],
|
||||
@ -611,7 +624,7 @@ class XGBModel(XGBModelBase):
|
||||
max_bin: Optional[int] = None,
|
||||
grow_policy: Optional[str] = None,
|
||||
learning_rate: Optional[float] = None,
|
||||
n_estimators: int = 100,
|
||||
n_estimators: Optional[int] = None,
|
||||
verbosity: Optional[int] = None,
|
||||
objective: SklObjective = None,
|
||||
booster: Optional[str] = None,
|
||||
@ -797,7 +810,7 @@ class XGBModel(XGBModelBase):
|
||||
|
||||
def get_num_boosting_rounds(self) -> int:
|
||||
"""Gets the number of xgboost boosting rounds."""
|
||||
return self.n_estimators
|
||||
return DEFAULT_N_ESTIMATORS if self.n_estimators is None else self.n_estimators
|
||||
|
||||
def _get_type(self) -> str:
|
||||
if not hasattr(self, "_estimator_type"):
|
||||
@ -809,72 +822,33 @@ class XGBModel(XGBModelBase):
|
||||
|
||||
def save_model(self, fname: Union[str, os.PathLike]) -> None:
|
||||
meta: Dict[str, Any] = {}
|
||||
for k, v in self.__dict__.items():
|
||||
if k == "_le":
|
||||
meta["_le"] = self._le.to_json()
|
||||
continue
|
||||
if k == "_Booster":
|
||||
continue
|
||||
if k == "classes_":
|
||||
# numpy array is not JSON serializable
|
||||
meta["classes_"] = self.classes_.tolist()
|
||||
continue
|
||||
if k == "feature_types":
|
||||
# Use the `feature_types` attribute from booster instead.
|
||||
meta["feature_types"] = None
|
||||
continue
|
||||
try:
|
||||
json.dumps({k: v})
|
||||
meta[k] = v
|
||||
except TypeError:
|
||||
warnings.warn(
|
||||
str(k) + " is not saved in Scikit-Learn meta.", UserWarning
|
||||
)
|
||||
# For validation.
|
||||
meta["_estimator_type"] = self._get_type()
|
||||
meta_str = json.dumps(meta)
|
||||
self.get_booster().set_attr(scikit_learn=meta_str)
|
||||
self.get_booster().save_model(fname)
|
||||
# Delete the attribute after save
|
||||
self.get_booster().set_attr(scikit_learn=None)
|
||||
|
||||
save_model.__doc__ = f"""{Booster.save_model.__doc__}"""
|
||||
|
||||
def load_model(self, fname: Union[str, bytearray, os.PathLike]) -> None:
|
||||
def load_model(self, fname: ModelIn) -> None:
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
if not hasattr(self, "_Booster"):
|
||||
if not self.__sklearn_is_fitted__():
|
||||
self._Booster = Booster({"n_jobs": self.n_jobs})
|
||||
self.get_booster().load_model(fname)
|
||||
|
||||
meta_str = self.get_booster().attr("scikit_learn")
|
||||
if meta_str is None:
|
||||
# FIXME(jiaming): This doesn't have to be a problem as most of the needed
|
||||
# information like num_class and objective is in Learner class.
|
||||
warnings.warn("Loading a native XGBoost model with Scikit-Learn interface.")
|
||||
return
|
||||
|
||||
meta = json.loads(meta_str)
|
||||
states = {}
|
||||
for k, v in meta.items():
|
||||
if k == "_le":
|
||||
self._le = XGBoostLabelEncoder()
|
||||
self._le.from_json(v)
|
||||
continue
|
||||
# FIXME(jiaming): This can be removed once label encoder is gone since we can
|
||||
# generate it from `np.arange(self.n_classes_)`
|
||||
if k == "classes_":
|
||||
self.classes_ = np.array(v)
|
||||
continue
|
||||
if k == "feature_types":
|
||||
self.feature_types = self.get_booster().feature_types
|
||||
continue
|
||||
if k == "_estimator_type":
|
||||
if self._get_type() != v:
|
||||
t = meta.get("_estimator_type", None)
|
||||
if t is not None and t != self._get_type():
|
||||
raise TypeError(
|
||||
"Loading an estimator with different type. "
|
||||
f"Expecting: {self._get_type()}, got: {v}"
|
||||
"Loading an estimator with different type. Expecting: "
|
||||
f"{self._get_type()}, got: {t}"
|
||||
)
|
||||
continue
|
||||
states[k] = v
|
||||
self.__dict__.update(states)
|
||||
# Delete the attribute after load
|
||||
self.feature_types = self.get_booster().feature_types
|
||||
self.get_booster().set_attr(scikit_learn=None)
|
||||
|
||||
load_model.__doc__ = f"""{Booster.load_model.__doc__}"""
|
||||
@ -965,7 +939,6 @@ class XGBModel(XGBModelBase):
|
||||
"Experimental support for categorical data is not implemented for"
|
||||
" current tree method yet."
|
||||
)
|
||||
|
||||
return model, metric, params, early_stopping_rounds, callbacks
|
||||
|
||||
def _create_dmatrix(self, ref: Optional[DMatrix], **kwargs: Any) -> DMatrix:
|
||||
@ -1086,9 +1059,7 @@ class XGBModel(XGBModelBase):
|
||||
params = self.get_xgb_params()
|
||||
|
||||
if callable(self.objective):
|
||||
obj: Optional[
|
||||
Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]]
|
||||
] = _objective_decorator(self.objective)
|
||||
obj: Optional[Objective] = _objective_decorator(self.objective)
|
||||
params["objective"] = "reg:squarederror"
|
||||
else:
|
||||
obj = None
|
||||
@ -1304,8 +1275,10 @@ class XGBModel(XGBModelBase):
|
||||
|
||||
@property
|
||||
def feature_names_in_(self) -> np.ndarray:
|
||||
"""Names of features seen during :py:meth:`fit`. Defined only when `X` has feature
|
||||
names that are all strings."""
|
||||
"""Names of features seen during :py:meth:`fit`. Defined only when `X` has
|
||||
feature names that are all strings.
|
||||
|
||||
"""
|
||||
feature_names = self.get_booster().feature_names
|
||||
if feature_names is None:
|
||||
raise AttributeError(
|
||||
@ -1453,26 +1426,19 @@ def _cls_predict_proba(n_classes: int, prediction: PredtT, vstack: Callable) ->
|
||||
"Implementation of the scikit-learn API for XGBoost classification.",
|
||||
["model", "objective"],
|
||||
extra_parameters="""
|
||||
n_estimators : int
|
||||
n_estimators : Optional[int]
|
||||
Number of boosting rounds.
|
||||
""",
|
||||
)
|
||||
class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase):
|
||||
# pylint: disable=missing-docstring,invalid-name,too-many-instance-attributes
|
||||
@_deprecate_positional_args
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
objective: SklObjective = "binary:logistic",
|
||||
use_label_encoder: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# must match the parameters for `get_params`
|
||||
self.use_label_encoder = use_label_encoder
|
||||
if use_label_encoder is True:
|
||||
raise ValueError("Label encoder was removed in 1.6.0.")
|
||||
if use_label_encoder is not None:
|
||||
warnings.warn("`use_label_encoder` is deprecated in 1.7.0.")
|
||||
super().__init__(objective=objective, **kwargs)
|
||||
|
||||
@_deprecate_positional_args
|
||||
@ -1496,38 +1462,38 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
# pylint: disable = attribute-defined-outside-init,too-many-statements
|
||||
with config_context(verbosity=self.verbosity):
|
||||
evals_result: TrainingCallback.EvalsLog = {}
|
||||
|
||||
# We keep the n_classes_ as a simple member instead of loading it from
|
||||
# booster in a Python property. This way we can have efficient and
|
||||
# thread-safe prediction.
|
||||
if _is_cudf_df(y) or _is_cudf_ser(y):
|
||||
import cupy as cp # pylint: disable=E0401
|
||||
|
||||
self.classes_ = cp.unique(y.values)
|
||||
self.n_classes_ = len(self.classes_)
|
||||
expected_classes = cp.arange(self.n_classes_)
|
||||
classes = cp.unique(y.values)
|
||||
self.n_classes_ = len(classes)
|
||||
expected_classes = cp.array(self.classes_)
|
||||
elif _is_cupy_array(y):
|
||||
import cupy as cp # pylint: disable=E0401
|
||||
|
||||
self.classes_ = cp.unique(y)
|
||||
self.n_classes_ = len(self.classes_)
|
||||
expected_classes = cp.arange(self.n_classes_)
|
||||
classes = cp.unique(y)
|
||||
self.n_classes_ = len(classes)
|
||||
expected_classes = cp.array(self.classes_)
|
||||
else:
|
||||
self.classes_ = np.unique(np.asarray(y))
|
||||
self.n_classes_ = len(self.classes_)
|
||||
expected_classes = np.arange(self.n_classes_)
|
||||
classes = np.unique(np.asarray(y))
|
||||
self.n_classes_ = len(classes)
|
||||
expected_classes = self.classes_
|
||||
if (
|
||||
self.classes_.shape != expected_classes.shape
|
||||
or not (self.classes_ == expected_classes).all()
|
||||
classes.shape != expected_classes.shape
|
||||
or not (classes == expected_classes).all()
|
||||
):
|
||||
raise ValueError(
|
||||
f"Invalid classes inferred from unique values of `y`. "
|
||||
f"Expected: {expected_classes}, got {self.classes_}"
|
||||
f"Expected: {expected_classes}, got {classes}"
|
||||
)
|
||||
|
||||
params = self.get_xgb_params()
|
||||
|
||||
if callable(self.objective):
|
||||
obj: Optional[
|
||||
Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]]
|
||||
] = _objective_decorator(self.objective)
|
||||
obj: Optional[Objective] = _objective_decorator(self.objective)
|
||||
# Use default value. Is it really not used ?
|
||||
params["objective"] = "binary:logistic"
|
||||
else:
|
||||
@ -1616,7 +1582,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
|
||||
if len(class_probs.shape) > 1 and self.n_classes_ != 2:
|
||||
# multi-class, turns softprob into softmax
|
||||
column_indexes: np.ndarray = np.argmax(class_probs, axis=1) # type: ignore
|
||||
column_indexes: np.ndarray = np.argmax(class_probs, axis=1)
|
||||
elif len(class_probs.shape) > 1 and class_probs.shape[1] != 1:
|
||||
# multi-label
|
||||
column_indexes = np.zeros(class_probs.shape)
|
||||
@ -1628,8 +1594,6 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
column_indexes = np.repeat(0, class_probs.shape[0])
|
||||
column_indexes[class_probs > 0.5] = 1
|
||||
|
||||
if hasattr(self, "_le"):
|
||||
return self._le.inverse_transform(column_indexes)
|
||||
return column_indexes
|
||||
|
||||
def predict_proba(
|
||||
@ -1693,17 +1657,22 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
base_margin=base_margin,
|
||||
iteration_range=iteration_range,
|
||||
)
|
||||
# If model is loaded from a raw booster there's no `n_classes_`
|
||||
return _cls_predict_proba(
|
||||
getattr(self, "n_classes_", 0), class_probs, np.vstack
|
||||
)
|
||||
return _cls_predict_proba(self.n_classes_, class_probs, np.vstack)
|
||||
|
||||
@property
|
||||
def classes_(self) -> np.ndarray:
|
||||
return np.arange(self.n_classes_)
|
||||
|
||||
def load_model(self, fname: ModelIn) -> None:
|
||||
super().load_model(fname)
|
||||
self._load_model_attributes(self.get_booster())
|
||||
|
||||
|
||||
@xgboost_model_doc(
|
||||
"scikit-learn API for XGBoost random forest classification.",
|
||||
["model", "objective"],
|
||||
extra_parameters="""
|
||||
n_estimators : int
|
||||
n_estimators : Optional[int]
|
||||
Number of trees in random forest to fit.
|
||||
""",
|
||||
)
|
||||
@ -1730,7 +1699,7 @@ class XGBRFClassifier(XGBClassifier):
|
||||
|
||||
def get_xgb_params(self) -> Dict[str, Any]:
|
||||
params = super().get_xgb_params()
|
||||
params["num_parallel_tree"] = self.n_estimators
|
||||
params["num_parallel_tree"] = super().get_num_boosting_rounds()
|
||||
return params
|
||||
|
||||
def get_num_boosting_rounds(self) -> int:
|
||||
@ -1778,7 +1747,7 @@ class XGBRegressor(XGBModel, XGBRegressorBase):
|
||||
"scikit-learn API for XGBoost random forest regression.",
|
||||
["model", "objective"],
|
||||
extra_parameters="""
|
||||
n_estimators : int
|
||||
n_estimators : Optional[int]
|
||||
Number of trees in random forest to fit.
|
||||
""",
|
||||
)
|
||||
@ -1805,7 +1774,7 @@ class XGBRFRegressor(XGBRegressor):
|
||||
|
||||
def get_xgb_params(self) -> Dict[str, Any]:
|
||||
params = super().get_xgb_params()
|
||||
params["num_parallel_tree"] = self.n_estimators
|
||||
params["num_parallel_tree"] = super().get_num_boosting_rounds()
|
||||
return params
|
||||
|
||||
def get_num_boosting_rounds(self) -> int:
|
||||
|
||||
@ -39,6 +39,7 @@ import xgboost
|
||||
from xgboost import XGBClassifier, XGBRanker, XGBRegressor
|
||||
from xgboost.compat import is_cudf_available
|
||||
from xgboost.core import Booster
|
||||
from xgboost.sklearn import DEFAULT_N_ESTIMATORS
|
||||
from xgboost.training import train as worker_train
|
||||
|
||||
from .data import (
|
||||
@ -215,6 +216,7 @@ class _SparkXGBParams(
|
||||
filtered_params_dict = {
|
||||
k: params_dict[k] for k in params_dict if k not in _unsupported_xgb_params
|
||||
}
|
||||
filtered_params_dict["n_estimators"] = DEFAULT_N_ESTIMATORS
|
||||
return filtered_params_dict
|
||||
|
||||
def _set_xgb_params_default(self):
|
||||
|
||||
@ -66,7 +66,6 @@ def run_scikit_model_check(name, path):
|
||||
cls.load_model(path)
|
||||
if name.find('0.90') == -1:
|
||||
assert len(cls.classes_) == gm.kClasses
|
||||
assert len(cls._le.classes_) == gm.kClasses
|
||||
assert cls.n_classes_ == gm.kClasses
|
||||
assert (len(cls.get_booster().get_dump()) ==
|
||||
gm.kRounds * gm.kForests * gm.kClasses), path
|
||||
|
||||
@ -38,36 +38,34 @@ def test_binary_classification():
|
||||
assert err < 0.1
|
||||
|
||||
|
||||
@pytest.mark.parametrize('objective', ['multi:softmax', 'multi:softprob'])
|
||||
@pytest.mark.parametrize("objective", ["multi:softmax", "multi:softprob"])
|
||||
def test_multiclass_classification(objective):
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.model_selection import KFold
|
||||
|
||||
def check_pred(preds, labels, output_margin):
|
||||
if output_margin:
|
||||
err = sum(1 for i in range(len(preds))
|
||||
if preds[i].argmax() != labels[i]) / float(len(preds))
|
||||
err = sum(
|
||||
1 for i in range(len(preds)) if preds[i].argmax() != labels[i]
|
||||
) / float(len(preds))
|
||||
else:
|
||||
err = sum(1 for i in range(len(preds))
|
||||
if preds[i] != labels[i]) / float(len(preds))
|
||||
err = sum(1 for i in range(len(preds)) if preds[i] != labels[i]) / float(
|
||||
len(preds)
|
||||
)
|
||||
assert err < 0.4
|
||||
|
||||
iris = load_iris()
|
||||
y = iris['target']
|
||||
X = iris['data']
|
||||
X, y = load_iris(return_X_y=True)
|
||||
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
|
||||
for train_index, test_index in kf.split(X, y):
|
||||
xgb_model = xgb.XGBClassifier(objective=objective).fit(X[train_index], y[train_index])
|
||||
assert (xgb_model.get_booster().num_boosted_rounds() ==
|
||||
xgb_model.n_estimators)
|
||||
xgb_model = xgb.XGBClassifier(objective=objective).fit(
|
||||
X[train_index], y[train_index]
|
||||
)
|
||||
assert xgb_model.get_booster().num_boosted_rounds() == 100
|
||||
preds = xgb_model.predict(X[test_index])
|
||||
# test other params in XGBClassifier().fit
|
||||
preds2 = xgb_model.predict(X[test_index], output_margin=True,
|
||||
ntree_limit=3)
|
||||
preds3 = xgb_model.predict(X[test_index], output_margin=True,
|
||||
ntree_limit=0)
|
||||
preds4 = xgb_model.predict(X[test_index], output_margin=False,
|
||||
ntree_limit=3)
|
||||
preds2 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=3)
|
||||
preds3 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=0)
|
||||
preds4 = xgb_model.predict(X[test_index], output_margin=False, ntree_limit=3)
|
||||
labels = y[test_index]
|
||||
|
||||
check_pred(preds, labels, output_margin=False)
|
||||
@ -761,9 +759,9 @@ def test_parameters_access():
|
||||
clf = save_load(clf)
|
||||
|
||||
assert clf.tree_method is None
|
||||
assert clf.n_estimators == 2
|
||||
assert clf.n_estimators is None
|
||||
assert clf.get_params()["tree_method"] is None
|
||||
assert clf.get_params()["n_estimators"] == 2
|
||||
assert clf.get_params()["n_estimators"] is None
|
||||
assert get_tm(clf) == "auto" # discarded for save/load_model
|
||||
|
||||
clf.set_params(tree_method="hist")
|
||||
@ -771,9 +769,7 @@ def test_parameters_access():
|
||||
clf = pickle.loads(pickle.dumps(clf))
|
||||
assert clf.get_params()["tree_method"] == "hist"
|
||||
clf = save_load(clf)
|
||||
# FIXME(jiamingy): We should remove this behavior once we remove parameters
|
||||
# serialization for skl save/load_model.
|
||||
assert clf.get_params()["tree_method"] == "hist"
|
||||
assert clf.get_params()["tree_method"] is None
|
||||
|
||||
|
||||
def test_kwargs_error():
|
||||
@ -902,6 +898,7 @@ def save_load_model(model_path):
|
||||
xgb_model.load_model(model_path)
|
||||
|
||||
assert isinstance(xgb_model.classes_, np.ndarray)
|
||||
np.testing.assert_equal(xgb_model.classes_, np.array([0, 1]))
|
||||
assert isinstance(xgb_model._Booster, xgb.Booster)
|
||||
|
||||
preds = xgb_model.predict(X[test_index])
|
||||
@ -933,8 +930,10 @@ def test_save_load_model():
|
||||
save_load_model(model_path)
|
||||
|
||||
from sklearn.datasets import load_digits
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
model_path = os.path.join(tempdir, 'digits.model.json')
|
||||
model_path = os.path.join(tempdir, 'digits.model.ubj')
|
||||
digits = load_digits(n_class=2)
|
||||
y = digits['target']
|
||||
X = digits['data']
|
||||
@ -959,6 +958,28 @@ def test_save_load_model():
|
||||
predt_1 = cls.predict(X)
|
||||
assert np.allclose(predt_0, predt_1)
|
||||
|
||||
# mclass
|
||||
X, y = load_digits(n_class=10, return_X_y=True)
|
||||
# small test_size to force early stop
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
X, y, test_size=0.01, random_state=1
|
||||
)
|
||||
clf = xgb.XGBClassifier(
|
||||
n_estimators=64, tree_method="hist", early_stopping_rounds=2
|
||||
)
|
||||
clf.fit(X_train, y_train, eval_set=[(X_test, y_test)])
|
||||
score = clf.best_score
|
||||
clf.save_model(model_path)
|
||||
|
||||
clf = xgb.XGBClassifier()
|
||||
clf.load_model(model_path)
|
||||
assert clf.classes_.size == 10
|
||||
np.testing.assert_equal(clf.classes_, np.arange(10))
|
||||
assert clf.n_classes_ == 10
|
||||
|
||||
assert clf.best_iteration == 27
|
||||
assert clf.best_score == score
|
||||
|
||||
|
||||
def test_RFECV():
|
||||
from sklearn.datasets import load_breast_cancer, load_diabetes, load_iris
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user