[breaking][skl] Remove parameter serialization. (#8963)

- Remove parameter serialization in the scikit-learn interface.

The scikit-lear interface `save_model` will save only the model and discard all
hyper-parameters. This is to align with the native XGBoost interface, which distinguishes
the hyper-parameter and model parameters.

With the scikit-learn interface, model parameters are attributes of the estimator. For
instance, `n_features_in_`, `n_classes_` are always accessible with
`estimator.n_features_in_` and `estimator.n_classes_`, but not with the
`estimator.get_params`.

- Define a `load_model` method for classifier to load its own attributes.

- Set n_estimators to None by default.
This commit is contained in:
Jiaming Yuan 2023-03-27 21:34:10 +08:00 committed by GitHub
parent 90645c4957
commit c2b3a13e70
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 134 additions and 160 deletions

View File

@ -43,6 +43,8 @@ FPreProcCallable = Callable
# c_bst_ulong corresponds to bst_ulong defined in xgboost/c_api.h # c_bst_ulong corresponds to bst_ulong defined in xgboost/c_api.h
c_bst_ulong = ctypes.c_uint64 # pylint: disable=C0103 c_bst_ulong = ctypes.c_uint64 # pylint: disable=C0103
ModelIn = Union[str, bytearray, os.PathLike]
CTypeT = TypeVar( CTypeT = TypeVar(
"CTypeT", "CTypeT",
ctypes.c_void_p, ctypes.c_void_p,

View File

@ -88,31 +88,6 @@ def is_cudf_available() -> bool:
return False return False
class XGBoostLabelEncoder(LabelEncoder):
"""Label encoder with JSON serialization methods."""
def to_json(self) -> Dict:
"""Returns a JSON compatible dictionary"""
meta = {}
for k, v in self.__dict__.items():
if isinstance(v, np.ndarray):
meta[k] = v.tolist()
else:
meta[k] = v
return meta
def from_json(self, doc: Dict) -> None:
# pylint: disable=attribute-defined-outside-init
"""Load the encoder back from a JSON compatible dict."""
meta = {}
for k, v in doc.items():
if k == "classes_":
self.classes_ = np.array(v)
continue
meta[k] = v
self.__dict__.update(meta)
try: try:
import scipy.sparse as scipy_sparse import scipy.sparse as scipy_sparse
from scipy.sparse import csr_matrix as scipy_csr from scipy.sparse import csr_matrix as scipy_csr

View File

@ -47,6 +47,7 @@ from ._typing import (
FeatureInfo, FeatureInfo,
FeatureNames, FeatureNames,
FeatureTypes, FeatureTypes,
ModelIn,
NumpyOrCupy, NumpyOrCupy,
c_bst_ulong, c_bst_ulong,
) )
@ -2477,7 +2478,7 @@ class Booster:
) )
return ctypes2buffer(cptr, length.value) return ctypes2buffer(cptr, length.value)
def load_model(self, fname: Union[str, bytearray, os.PathLike]) -> None: def load_model(self, fname: ModelIn) -> None:
"""Load the model from a file or bytearray. Path to file can be local """Load the model from a file or bytearray. Path to file can be local
or as an URI. or as an URI.

View File

@ -60,7 +60,7 @@ from typing import (
import numpy import numpy
from . import collective, config from . import collective, config
from ._typing import _T, FeatureNames, FeatureTypes from ._typing import _T, FeatureNames, FeatureTypes, ModelIn
from .callback import TrainingCallback from .callback import TrainingCallback
from .compat import DataFrame, LazyLoader, concat, lazy_isinstance from .compat import DataFrame, LazyLoader, concat, lazy_isinstance
from .core import ( from .core import (
@ -76,6 +76,7 @@ from .core import (
from .sklearn import ( from .sklearn import (
XGBClassifier, XGBClassifier,
XGBClassifierBase, XGBClassifierBase,
XGBClassifierMixIn,
XGBModel, XGBModel,
XGBRanker, XGBRanker,
XGBRankerMixIn, XGBRankerMixIn,
@ -1839,7 +1840,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
"Implementation of the scikit-learn API for XGBoost classification.", "Implementation of the scikit-learn API for XGBoost classification.",
["estimators", "model"], ["estimators", "model"],
) )
class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase): class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierMixIn, XGBClassifierBase):
# pylint: disable=missing-class-docstring # pylint: disable=missing-class-docstring
async def _fit_async( async def _fit_async(
self, self,
@ -2019,6 +2020,10 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
preds = da.map_blocks(_argmax, pred_probs, drop_axis=1) preds = da.map_blocks(_argmax, pred_probs, drop_axis=1)
return preds return preds
def load_model(self, fname: ModelIn) -> None:
super().load_model(fname)
self._load_model_attributes(self.get_booster())
@xgboost_model_doc( @xgboost_model_doc(
"""Implementation of the Scikit-Learn API for XGBoost Ranking. """Implementation of the Scikit-Learn API for XGBoost Ranking.

View File

@ -55,7 +55,7 @@ def find_lib_path() -> List[str]:
# XGBOOST_BUILD_DOC is defined by sphinx conf. # XGBOOST_BUILD_DOC is defined by sphinx conf.
if not lib_path and not os.environ.get("XGBOOST_BUILD_DOC", False): if not lib_path and not os.environ.get("XGBOOST_BUILD_DOC", False):
link = "https://xgboost.readthedocs.io/en/latest/build.html" link = "https://xgboost.readthedocs.io/en/stable/install.html"
msg = ( msg = (
"Cannot find XGBoost Library in the candidate path. " "Cannot find XGBoost Library in the candidate path. "
+ "List of candidates:\n- " + "List of candidates:\n- "

View File

@ -22,23 +22,18 @@ from typing import (
import numpy as np import numpy as np
from scipy.special import softmax from scipy.special import softmax
from ._typing import ArrayLike, FeatureNames, FeatureTypes from ._typing import ArrayLike, FeatureNames, FeatureTypes, ModelIn
from .callback import TrainingCallback from .callback import TrainingCallback
# Do not use class names on scikit-learn directly. Re-define the classes on # Do not use class names on scikit-learn directly. Re-define the classes on
# .compat to guarantee the behavior without scikit-learn # .compat to guarantee the behavior without scikit-learn
from .compat import ( from .compat import SKLEARN_INSTALLED, XGBClassifierBase, XGBModelBase, XGBRegressorBase
SKLEARN_INSTALLED,
XGBClassifierBase,
XGBModelBase,
XGBoostLabelEncoder,
XGBRegressorBase,
)
from .config import config_context from .config import config_context
from .core import ( from .core import (
Booster, Booster,
DMatrix, DMatrix,
Metric, Metric,
Objective,
QuantileDMatrix, QuantileDMatrix,
XGBoostError, XGBoostError,
_convert_ntree_limit, _convert_ntree_limit,
@ -49,9 +44,24 @@ from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array, _is_pandas_df
from .training import train from .training import train
class XGBClassifierMixIn: # pylint: disable=too-few-public-methods
"""MixIn for classification."""
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
def _load_model_attributes(self, booster: Booster) -> None:
config = json.loads(booster.save_config())
self.n_classes_ = int(config["learner"]["learner_model_param"]["num_class"])
# binary classification is treated as regression in XGBoost.
self.n_classes_ = 2 if self.n_classes_ < 2 else self.n_classes_
class XGBRankerMixIn: # pylint: disable=too-few-public-methods class XGBRankerMixIn: # pylint: disable=too-few-public-methods
"""MixIn for ranking, defines the _estimator_type usually defined in scikit-learn base """MixIn for ranking, defines the _estimator_type usually defined in scikit-learn
classes.""" base classes.
"""
_estimator_type = "ranker" _estimator_type = "ranker"
@ -74,7 +84,7 @@ SklObjective = Optional[
def _objective_decorator( def _objective_decorator(
func: Callable[[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]] func: Callable[[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]
) -> Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]]: ) -> Objective:
"""Decorate an objective function """Decorate an objective function
Converts an objective function using the typical sklearn metrics Converts an objective function using the typical sklearn metrics
@ -173,7 +183,7 @@ def ltr_metric_decorator(func: Callable, n_jobs: Optional[int]) -> Metric:
__estimator_doc = """ __estimator_doc = """
n_estimators : int n_estimators : Optional[int]
Number of gradient boosted trees. Equivalent to number of boosting Number of gradient boosted trees. Equivalent to number of boosting
rounds. rounds.
""" """
@ -598,6 +608,9 @@ def _wrap_evaluation_matrices(
return train_dmatrix, evals return train_dmatrix, evals
DEFAULT_N_ESTIMATORS = 100
@xgboost_model_doc( @xgboost_model_doc(
"""Implementation of the Scikit-Learn API for XGBoost.""", """Implementation of the Scikit-Learn API for XGBoost.""",
["estimators", "model", "objective"], ["estimators", "model", "objective"],
@ -611,7 +624,7 @@ class XGBModel(XGBModelBase):
max_bin: Optional[int] = None, max_bin: Optional[int] = None,
grow_policy: Optional[str] = None, grow_policy: Optional[str] = None,
learning_rate: Optional[float] = None, learning_rate: Optional[float] = None,
n_estimators: int = 100, n_estimators: Optional[int] = None,
verbosity: Optional[int] = None, verbosity: Optional[int] = None,
objective: SklObjective = None, objective: SklObjective = None,
booster: Optional[str] = None, booster: Optional[str] = None,
@ -797,7 +810,7 @@ class XGBModel(XGBModelBase):
def get_num_boosting_rounds(self) -> int: def get_num_boosting_rounds(self) -> int:
"""Gets the number of xgboost boosting rounds.""" """Gets the number of xgboost boosting rounds."""
return self.n_estimators return DEFAULT_N_ESTIMATORS if self.n_estimators is None else self.n_estimators
def _get_type(self) -> str: def _get_type(self) -> str:
if not hasattr(self, "_estimator_type"): if not hasattr(self, "_estimator_type"):
@ -809,72 +822,33 @@ class XGBModel(XGBModelBase):
def save_model(self, fname: Union[str, os.PathLike]) -> None: def save_model(self, fname: Union[str, os.PathLike]) -> None:
meta: Dict[str, Any] = {} meta: Dict[str, Any] = {}
for k, v in self.__dict__.items(): # For validation.
if k == "_le":
meta["_le"] = self._le.to_json()
continue
if k == "_Booster":
continue
if k == "classes_":
# numpy array is not JSON serializable
meta["classes_"] = self.classes_.tolist()
continue
if k == "feature_types":
# Use the `feature_types` attribute from booster instead.
meta["feature_types"] = None
continue
try:
json.dumps({k: v})
meta[k] = v
except TypeError:
warnings.warn(
str(k) + " is not saved in Scikit-Learn meta.", UserWarning
)
meta["_estimator_type"] = self._get_type() meta["_estimator_type"] = self._get_type()
meta_str = json.dumps(meta) meta_str = json.dumps(meta)
self.get_booster().set_attr(scikit_learn=meta_str) self.get_booster().set_attr(scikit_learn=meta_str)
self.get_booster().save_model(fname) self.get_booster().save_model(fname)
# Delete the attribute after save
self.get_booster().set_attr(scikit_learn=None) self.get_booster().set_attr(scikit_learn=None)
save_model.__doc__ = f"""{Booster.save_model.__doc__}""" save_model.__doc__ = f"""{Booster.save_model.__doc__}"""
def load_model(self, fname: Union[str, bytearray, os.PathLike]) -> None: def load_model(self, fname: ModelIn) -> None:
# pylint: disable=attribute-defined-outside-init # pylint: disable=attribute-defined-outside-init
if not hasattr(self, "_Booster"): if not self.__sklearn_is_fitted__():
self._Booster = Booster({"n_jobs": self.n_jobs}) self._Booster = Booster({"n_jobs": self.n_jobs})
self.get_booster().load_model(fname) self.get_booster().load_model(fname)
meta_str = self.get_booster().attr("scikit_learn") meta_str = self.get_booster().attr("scikit_learn")
if meta_str is None: if meta_str is None:
# FIXME(jiaming): This doesn't have to be a problem as most of the needed
# information like num_class and objective is in Learner class.
warnings.warn("Loading a native XGBoost model with Scikit-Learn interface.")
return return
meta = json.loads(meta_str) meta = json.loads(meta_str)
states = {} t = meta.get("_estimator_type", None)
for k, v in meta.items(): if t is not None and t != self._get_type():
if k == "_le": raise TypeError(
self._le = XGBoostLabelEncoder() "Loading an estimator with different type. Expecting: "
self._le.from_json(v) f"{self._get_type()}, got: {t}"
continue )
# FIXME(jiaming): This can be removed once label encoder is gone since we can self.feature_types = self.get_booster().feature_types
# generate it from `np.arange(self.n_classes_)`
if k == "classes_":
self.classes_ = np.array(v)
continue
if k == "feature_types":
self.feature_types = self.get_booster().feature_types
continue
if k == "_estimator_type":
if self._get_type() != v:
raise TypeError(
"Loading an estimator with different type. "
f"Expecting: {self._get_type()}, got: {v}"
)
continue
states[k] = v
self.__dict__.update(states)
# Delete the attribute after load
self.get_booster().set_attr(scikit_learn=None) self.get_booster().set_attr(scikit_learn=None)
load_model.__doc__ = f"""{Booster.load_model.__doc__}""" load_model.__doc__ = f"""{Booster.load_model.__doc__}"""
@ -965,7 +939,6 @@ class XGBModel(XGBModelBase):
"Experimental support for categorical data is not implemented for" "Experimental support for categorical data is not implemented for"
" current tree method yet." " current tree method yet."
) )
return model, metric, params, early_stopping_rounds, callbacks return model, metric, params, early_stopping_rounds, callbacks
def _create_dmatrix(self, ref: Optional[DMatrix], **kwargs: Any) -> DMatrix: def _create_dmatrix(self, ref: Optional[DMatrix], **kwargs: Any) -> DMatrix:
@ -1086,9 +1059,7 @@ class XGBModel(XGBModelBase):
params = self.get_xgb_params() params = self.get_xgb_params()
if callable(self.objective): if callable(self.objective):
obj: Optional[ obj: Optional[Objective] = _objective_decorator(self.objective)
Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]]
] = _objective_decorator(self.objective)
params["objective"] = "reg:squarederror" params["objective"] = "reg:squarederror"
else: else:
obj = None obj = None
@ -1304,8 +1275,10 @@ class XGBModel(XGBModelBase):
@property @property
def feature_names_in_(self) -> np.ndarray: def feature_names_in_(self) -> np.ndarray:
"""Names of features seen during :py:meth:`fit`. Defined only when `X` has feature """Names of features seen during :py:meth:`fit`. Defined only when `X` has
names that are all strings.""" feature names that are all strings.
"""
feature_names = self.get_booster().feature_names feature_names = self.get_booster().feature_names
if feature_names is None: if feature_names is None:
raise AttributeError( raise AttributeError(
@ -1453,26 +1426,19 @@ def _cls_predict_proba(n_classes: int, prediction: PredtT, vstack: Callable) ->
"Implementation of the scikit-learn API for XGBoost classification.", "Implementation of the scikit-learn API for XGBoost classification.",
["model", "objective"], ["model", "objective"],
extra_parameters=""" extra_parameters="""
n_estimators : int n_estimators : Optional[int]
Number of boosting rounds. Number of boosting rounds.
""", """,
) )
class XGBClassifier(XGBModel, XGBClassifierBase): class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase):
# pylint: disable=missing-docstring,invalid-name,too-many-instance-attributes # pylint: disable=missing-docstring,invalid-name,too-many-instance-attributes
@_deprecate_positional_args @_deprecate_positional_args
def __init__( def __init__(
self, self,
*, *,
objective: SklObjective = "binary:logistic", objective: SklObjective = "binary:logistic",
use_label_encoder: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
# must match the parameters for `get_params`
self.use_label_encoder = use_label_encoder
if use_label_encoder is True:
raise ValueError("Label encoder was removed in 1.6.0.")
if use_label_encoder is not None:
warnings.warn("`use_label_encoder` is deprecated in 1.7.0.")
super().__init__(objective=objective, **kwargs) super().__init__(objective=objective, **kwargs)
@_deprecate_positional_args @_deprecate_positional_args
@ -1496,38 +1462,38 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
# pylint: disable = attribute-defined-outside-init,too-many-statements # pylint: disable = attribute-defined-outside-init,too-many-statements
with config_context(verbosity=self.verbosity): with config_context(verbosity=self.verbosity):
evals_result: TrainingCallback.EvalsLog = {} evals_result: TrainingCallback.EvalsLog = {}
# We keep the n_classes_ as a simple member instead of loading it from
# booster in a Python property. This way we can have efficient and
# thread-safe prediction.
if _is_cudf_df(y) or _is_cudf_ser(y): if _is_cudf_df(y) or _is_cudf_ser(y):
import cupy as cp # pylint: disable=E0401 import cupy as cp # pylint: disable=E0401
self.classes_ = cp.unique(y.values) classes = cp.unique(y.values)
self.n_classes_ = len(self.classes_) self.n_classes_ = len(classes)
expected_classes = cp.arange(self.n_classes_) expected_classes = cp.array(self.classes_)
elif _is_cupy_array(y): elif _is_cupy_array(y):
import cupy as cp # pylint: disable=E0401 import cupy as cp # pylint: disable=E0401
self.classes_ = cp.unique(y) classes = cp.unique(y)
self.n_classes_ = len(self.classes_) self.n_classes_ = len(classes)
expected_classes = cp.arange(self.n_classes_) expected_classes = cp.array(self.classes_)
else: else:
self.classes_ = np.unique(np.asarray(y)) classes = np.unique(np.asarray(y))
self.n_classes_ = len(self.classes_) self.n_classes_ = len(classes)
expected_classes = np.arange(self.n_classes_) expected_classes = self.classes_
if ( if (
self.classes_.shape != expected_classes.shape classes.shape != expected_classes.shape
or not (self.classes_ == expected_classes).all() or not (classes == expected_classes).all()
): ):
raise ValueError( raise ValueError(
f"Invalid classes inferred from unique values of `y`. " f"Invalid classes inferred from unique values of `y`. "
f"Expected: {expected_classes}, got {self.classes_}" f"Expected: {expected_classes}, got {classes}"
) )
params = self.get_xgb_params() params = self.get_xgb_params()
if callable(self.objective): if callable(self.objective):
obj: Optional[ obj: Optional[Objective] = _objective_decorator(self.objective)
Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]]
] = _objective_decorator(self.objective)
# Use default value. Is it really not used ? # Use default value. Is it really not used ?
params["objective"] = "binary:logistic" params["objective"] = "binary:logistic"
else: else:
@ -1616,7 +1582,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
if len(class_probs.shape) > 1 and self.n_classes_ != 2: if len(class_probs.shape) > 1 and self.n_classes_ != 2:
# multi-class, turns softprob into softmax # multi-class, turns softprob into softmax
column_indexes: np.ndarray = np.argmax(class_probs, axis=1) # type: ignore column_indexes: np.ndarray = np.argmax(class_probs, axis=1)
elif len(class_probs.shape) > 1 and class_probs.shape[1] != 1: elif len(class_probs.shape) > 1 and class_probs.shape[1] != 1:
# multi-label # multi-label
column_indexes = np.zeros(class_probs.shape) column_indexes = np.zeros(class_probs.shape)
@ -1628,8 +1594,6 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
column_indexes = np.repeat(0, class_probs.shape[0]) column_indexes = np.repeat(0, class_probs.shape[0])
column_indexes[class_probs > 0.5] = 1 column_indexes[class_probs > 0.5] = 1
if hasattr(self, "_le"):
return self._le.inverse_transform(column_indexes)
return column_indexes return column_indexes
def predict_proba( def predict_proba(
@ -1693,17 +1657,22 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
base_margin=base_margin, base_margin=base_margin,
iteration_range=iteration_range, iteration_range=iteration_range,
) )
# If model is loaded from a raw booster there's no `n_classes_` return _cls_predict_proba(self.n_classes_, class_probs, np.vstack)
return _cls_predict_proba(
getattr(self, "n_classes_", 0), class_probs, np.vstack @property
) def classes_(self) -> np.ndarray:
return np.arange(self.n_classes_)
def load_model(self, fname: ModelIn) -> None:
super().load_model(fname)
self._load_model_attributes(self.get_booster())
@xgboost_model_doc( @xgboost_model_doc(
"scikit-learn API for XGBoost random forest classification.", "scikit-learn API for XGBoost random forest classification.",
["model", "objective"], ["model", "objective"],
extra_parameters=""" extra_parameters="""
n_estimators : int n_estimators : Optional[int]
Number of trees in random forest to fit. Number of trees in random forest to fit.
""", """,
) )
@ -1730,7 +1699,7 @@ class XGBRFClassifier(XGBClassifier):
def get_xgb_params(self) -> Dict[str, Any]: def get_xgb_params(self) -> Dict[str, Any]:
params = super().get_xgb_params() params = super().get_xgb_params()
params["num_parallel_tree"] = self.n_estimators params["num_parallel_tree"] = super().get_num_boosting_rounds()
return params return params
def get_num_boosting_rounds(self) -> int: def get_num_boosting_rounds(self) -> int:
@ -1778,7 +1747,7 @@ class XGBRegressor(XGBModel, XGBRegressorBase):
"scikit-learn API for XGBoost random forest regression.", "scikit-learn API for XGBoost random forest regression.",
["model", "objective"], ["model", "objective"],
extra_parameters=""" extra_parameters="""
n_estimators : int n_estimators : Optional[int]
Number of trees in random forest to fit. Number of trees in random forest to fit.
""", """,
) )
@ -1805,7 +1774,7 @@ class XGBRFRegressor(XGBRegressor):
def get_xgb_params(self) -> Dict[str, Any]: def get_xgb_params(self) -> Dict[str, Any]:
params = super().get_xgb_params() params = super().get_xgb_params()
params["num_parallel_tree"] = self.n_estimators params["num_parallel_tree"] = super().get_num_boosting_rounds()
return params return params
def get_num_boosting_rounds(self) -> int: def get_num_boosting_rounds(self) -> int:

View File

@ -39,6 +39,7 @@ import xgboost
from xgboost import XGBClassifier, XGBRanker, XGBRegressor from xgboost import XGBClassifier, XGBRanker, XGBRegressor
from xgboost.compat import is_cudf_available from xgboost.compat import is_cudf_available
from xgboost.core import Booster from xgboost.core import Booster
from xgboost.sklearn import DEFAULT_N_ESTIMATORS
from xgboost.training import train as worker_train from xgboost.training import train as worker_train
from .data import ( from .data import (
@ -215,6 +216,7 @@ class _SparkXGBParams(
filtered_params_dict = { filtered_params_dict = {
k: params_dict[k] for k in params_dict if k not in _unsupported_xgb_params k: params_dict[k] for k in params_dict if k not in _unsupported_xgb_params
} }
filtered_params_dict["n_estimators"] = DEFAULT_N_ESTIMATORS
return filtered_params_dict return filtered_params_dict
def _set_xgb_params_default(self): def _set_xgb_params_default(self):

View File

@ -66,7 +66,6 @@ def run_scikit_model_check(name, path):
cls.load_model(path) cls.load_model(path)
if name.find('0.90') == -1: if name.find('0.90') == -1:
assert len(cls.classes_) == gm.kClasses assert len(cls.classes_) == gm.kClasses
assert len(cls._le.classes_) == gm.kClasses
assert cls.n_classes_ == gm.kClasses assert cls.n_classes_ == gm.kClasses
assert (len(cls.get_booster().get_dump()) == assert (len(cls.get_booster().get_dump()) ==
gm.kRounds * gm.kForests * gm.kClasses), path gm.kRounds * gm.kForests * gm.kClasses), path

View File

@ -38,36 +38,34 @@ def test_binary_classification():
assert err < 0.1 assert err < 0.1
@pytest.mark.parametrize('objective', ['multi:softmax', 'multi:softprob']) @pytest.mark.parametrize("objective", ["multi:softmax", "multi:softprob"])
def test_multiclass_classification(objective): def test_multiclass_classification(objective):
from sklearn.datasets import load_iris from sklearn.datasets import load_iris
from sklearn.model_selection import KFold from sklearn.model_selection import KFold
def check_pred(preds, labels, output_margin): def check_pred(preds, labels, output_margin):
if output_margin: if output_margin:
err = sum(1 for i in range(len(preds)) err = sum(
if preds[i].argmax() != labels[i]) / float(len(preds)) 1 for i in range(len(preds)) if preds[i].argmax() != labels[i]
) / float(len(preds))
else: else:
err = sum(1 for i in range(len(preds)) err = sum(1 for i in range(len(preds)) if preds[i] != labels[i]) / float(
if preds[i] != labels[i]) / float(len(preds)) len(preds)
)
assert err < 0.4 assert err < 0.4
iris = load_iris() X, y = load_iris(return_X_y=True)
y = iris['target']
X = iris['data']
kf = KFold(n_splits=2, shuffle=True, random_state=rng) kf = KFold(n_splits=2, shuffle=True, random_state=rng)
for train_index, test_index in kf.split(X, y): for train_index, test_index in kf.split(X, y):
xgb_model = xgb.XGBClassifier(objective=objective).fit(X[train_index], y[train_index]) xgb_model = xgb.XGBClassifier(objective=objective).fit(
assert (xgb_model.get_booster().num_boosted_rounds() == X[train_index], y[train_index]
xgb_model.n_estimators) )
assert xgb_model.get_booster().num_boosted_rounds() == 100
preds = xgb_model.predict(X[test_index]) preds = xgb_model.predict(X[test_index])
# test other params in XGBClassifier().fit # test other params in XGBClassifier().fit
preds2 = xgb_model.predict(X[test_index], output_margin=True, preds2 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=3)
ntree_limit=3) preds3 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=0)
preds3 = xgb_model.predict(X[test_index], output_margin=True, preds4 = xgb_model.predict(X[test_index], output_margin=False, ntree_limit=3)
ntree_limit=0)
preds4 = xgb_model.predict(X[test_index], output_margin=False,
ntree_limit=3)
labels = y[test_index] labels = y[test_index]
check_pred(preds, labels, output_margin=False) check_pred(preds, labels, output_margin=False)
@ -761,9 +759,9 @@ def test_parameters_access():
clf = save_load(clf) clf = save_load(clf)
assert clf.tree_method is None assert clf.tree_method is None
assert clf.n_estimators == 2 assert clf.n_estimators is None
assert clf.get_params()["tree_method"] is None assert clf.get_params()["tree_method"] is None
assert clf.get_params()["n_estimators"] == 2 assert clf.get_params()["n_estimators"] is None
assert get_tm(clf) == "auto" # discarded for save/load_model assert get_tm(clf) == "auto" # discarded for save/load_model
clf.set_params(tree_method="hist") clf.set_params(tree_method="hist")
@ -771,9 +769,7 @@ def test_parameters_access():
clf = pickle.loads(pickle.dumps(clf)) clf = pickle.loads(pickle.dumps(clf))
assert clf.get_params()["tree_method"] == "hist" assert clf.get_params()["tree_method"] == "hist"
clf = save_load(clf) clf = save_load(clf)
# FIXME(jiamingy): We should remove this behavior once we remove parameters assert clf.get_params()["tree_method"] is None
# serialization for skl save/load_model.
assert clf.get_params()["tree_method"] == "hist"
def test_kwargs_error(): def test_kwargs_error():
@ -902,6 +898,7 @@ def save_load_model(model_path):
xgb_model.load_model(model_path) xgb_model.load_model(model_path)
assert isinstance(xgb_model.classes_, np.ndarray) assert isinstance(xgb_model.classes_, np.ndarray)
np.testing.assert_equal(xgb_model.classes_, np.array([0, 1]))
assert isinstance(xgb_model._Booster, xgb.Booster) assert isinstance(xgb_model._Booster, xgb.Booster)
preds = xgb_model.predict(X[test_index]) preds = xgb_model.predict(X[test_index])
@ -933,8 +930,10 @@ def test_save_load_model():
save_load_model(model_path) save_load_model(model_path)
from sklearn.datasets import load_digits from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
with tempfile.TemporaryDirectory() as tempdir: with tempfile.TemporaryDirectory() as tempdir:
model_path = os.path.join(tempdir, 'digits.model.json') model_path = os.path.join(tempdir, 'digits.model.ubj')
digits = load_digits(n_class=2) digits = load_digits(n_class=2)
y = digits['target'] y = digits['target']
X = digits['data'] X = digits['data']
@ -959,6 +958,28 @@ def test_save_load_model():
predt_1 = cls.predict(X) predt_1 = cls.predict(X)
assert np.allclose(predt_0, predt_1) assert np.allclose(predt_0, predt_1)
# mclass
X, y = load_digits(n_class=10, return_X_y=True)
# small test_size to force early stop
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.01, random_state=1
)
clf = xgb.XGBClassifier(
n_estimators=64, tree_method="hist", early_stopping_rounds=2
)
clf.fit(X_train, y_train, eval_set=[(X_test, y_test)])
score = clf.best_score
clf.save_model(model_path)
clf = xgb.XGBClassifier()
clf.load_model(model_path)
assert clf.classes_.size == 10
np.testing.assert_equal(clf.classes_, np.arange(10))
assert clf.n_classes_ == 10
assert clf.best_iteration == 27
assert clf.best_score == score
def test_RFECV(): def test_RFECV():
from sklearn.datasets import load_breast_cancer, load_diabetes, load_iris from sklearn.datasets import load_breast_cancer, load_diabetes, load_iris