diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 7359ec124..71f593b71 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -16,6 +16,12 @@ from .compat import (SKLEARN_INSTALLED, XGBModelBase, XGBClassifierBase, XGBRegressorBase, XGBoostLabelEncoder) +class XGBRankerMixIn: # pylint: disable=too-few-public-methods + """MixIn for ranking, defines the _estimator_type usually defined in scikit-learn base + classes.""" + _estimator_type = "ranker" + + def _objective_decorator(func): """Decorate an objective function @@ -407,6 +413,14 @@ class XGBModel(XGBModelBase): """Gets the number of xgboost boosting rounds.""" return self.n_estimators + def _get_type(self) -> str: + if not hasattr(self, '_estimator_type'): + raise TypeError( + "`_estimator_type` undefined. " + "Please use appropriate mixin to define estimator type." + ) + return self._estimator_type # pylint: disable=no-member + def save_model(self, fname: str): """Save the model to a file. @@ -442,7 +456,7 @@ class XGBModel(XGBModelBase): meta[k] = v except TypeError: warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.') - meta['type'] = type(self).__name__ + meta['_estimator_type'] = self._get_type() meta_str = json.dumps(meta) self.get_booster().set_attr(scikit_learn=meta_str) self.get_booster().save_model(fname) @@ -484,11 +498,12 @@ class XGBModel(XGBModelBase): if k == 'use_label_encoder': self.use_label_encoder = bool(v) continue - if k == 'type' and type(self).__name__ != v: - msg = 'Current model type: {}, '.format(type(self).__name__) + \ - 'type of model in file: {}'.format(v) - raise TypeError(msg) - if k == 'type': + if k == "_estimator_type": + if self._get_type() != v: + raise TypeError( + "Loading an estimator with different type. " + f"Expecting: {self._get_type()}, got: {v}" + ) continue states[k] = v self.__dict__.update(states) @@ -1211,7 +1226,7 @@ class XGBRFRegressor(XGBRegressor): then your group array should be ``[3, 4]``. ''') -class XGBRanker(XGBModel): +class XGBRanker(XGBModel, XGBRankerMixIn): # pylint: disable=missing-docstring,too-many-arguments,invalid-name @_deprecate_positional_args def __init__(self, *, objective='rank:pairwise', **kwargs): diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index d3c2f988a..180a64400 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -10,6 +10,7 @@ from typing import List, Tuple, Union, Dict, Optional, Callable, Type import asyncio import tempfile from sklearn.datasets import make_classification +import sklearn import os import subprocess from hypothesis import given, settings, note @@ -261,6 +262,9 @@ def test_dask_regressor() -> None: with Client(cluster) as client: X, y, w = generate_array(with_weights=True) regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2) + assert regressor._estimator_type == "regressor" + assert sklearn.base.is_regressor(regressor) + regressor.set_params(tree_method='hist') regressor.client = client regressor.fit(X, y, sample_weight=w, eval_set=[(X, y)]) @@ -285,6 +289,9 @@ def test_dask_classifier() -> None: y = (y * 10).astype(np.int32) classifier = xgb.dask.DaskXGBClassifier( verbosity=1, n_estimators=2, eval_metric='merror') + assert classifier._estimator_type == "classifier" + assert sklearn.base.is_classifier(classifier) + classifier.client = client classifier.fit(X, y, sample_weight=w, eval_set=[(X, y)]) prediction = classifier.predict(X) @@ -946,6 +953,35 @@ class TestWithDask: # Subtract the on disk resource from each worker assert cnt - n_workers == n_partitions + @pytest.mark.skipif(**tm.no_sklearn()) + def test_sklearn_io(self, client: 'Client') -> None: + from sklearn.datasets import load_digits + X_, y_ = load_digits(return_X_y=True) + X, y = da.from_array(X_), da.from_array(y_) + cls = xgb.dask.DaskXGBClassifier(n_estimators=10) + cls.client = client + cls.fit(X, y) + predt_0 = cls.predict(X) + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, 'cls.json') + cls.save_model(path) + + cls = xgb.dask.DaskXGBClassifier() + cls.load_model(path) + assert cls.n_classes_ == 10 + predt_1 = cls.predict(X) + + np.testing.assert_allclose(predt_0.compute(), predt_1.compute()) + + # Use single node to load + cls = xgb.XGBClassifier() + cls.load_model(path) + assert cls.n_classes_ == 10 + predt_2 = cls.predict(X_) + + np.testing.assert_allclose(predt_0.compute(), predt_2) + class TestDaskCallbacks: @pytest.mark.skipif(**tm.no_sklearn()) diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index 5d105b5a0..99a1a5702 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -1099,3 +1099,26 @@ def test_boost_from_prediction_approx(): @pytest.mark.skipif(**tm.no_sklearn()) def test_boost_from_prediction_exact(): run_boost_from_prediction('exact') + + +def test_estimator_type(): + assert xgb.XGBClassifier._estimator_type == "classifier" + assert xgb.XGBRFClassifier._estimator_type == "classifier" + assert xgb.XGBRegressor._estimator_type == "regressor" + assert xgb.XGBRFRegressor._estimator_type == "regressor" + assert xgb.XGBRanker._estimator_type == "ranker" + + from sklearn.datasets import load_digits + + X, y = load_digits(n_class=2, return_X_y=True) + cls = xgb.XGBClassifier(n_estimators=2).fit(X, y) + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "cls.json") + cls.save_model(path) + + reg = xgb.XGBRegressor() + with pytest.raises(TypeError): + reg.load_model(path) + + cls = xgb.XGBClassifier() + cls.load_model(path) # no error