Support _estimator_type. (#6582)

* Use `_estimator_type`.

For more info, see: https://scikit-learn.org/stable/developers/develop.html#estimator-types

* Model trained from dask can be loaded by single node skl interface.
This commit is contained in:
Jiaming Yuan 2021-01-08 10:01:16 +08:00 committed by GitHub
parent 8747885a8b
commit f5ff90cd87
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 81 additions and 7 deletions

View File

@ -16,6 +16,12 @@ from .compat import (SKLEARN_INSTALLED, XGBModelBase,
XGBClassifierBase, XGBRegressorBase, XGBoostLabelEncoder)
class XGBRankerMixIn: # pylint: disable=too-few-public-methods
"""MixIn for ranking, defines the _estimator_type usually defined in scikit-learn base
classes."""
_estimator_type = "ranker"
def _objective_decorator(func):
"""Decorate an objective function
@ -407,6 +413,14 @@ class XGBModel(XGBModelBase):
"""Gets the number of xgboost boosting rounds."""
return self.n_estimators
def _get_type(self) -> str:
if not hasattr(self, '_estimator_type'):
raise TypeError(
"`_estimator_type` undefined. "
"Please use appropriate mixin to define estimator type."
)
return self._estimator_type # pylint: disable=no-member
def save_model(self, fname: str):
"""Save the model to a file.
@ -442,7 +456,7 @@ class XGBModel(XGBModelBase):
meta[k] = v
except TypeError:
warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.')
meta['type'] = type(self).__name__
meta['_estimator_type'] = self._get_type()
meta_str = json.dumps(meta)
self.get_booster().set_attr(scikit_learn=meta_str)
self.get_booster().save_model(fname)
@ -484,11 +498,12 @@ class XGBModel(XGBModelBase):
if k == 'use_label_encoder':
self.use_label_encoder = bool(v)
continue
if k == 'type' and type(self).__name__ != v:
msg = 'Current model type: {}, '.format(type(self).__name__) + \
'type of model in file: {}'.format(v)
raise TypeError(msg)
if k == 'type':
if k == "_estimator_type":
if self._get_type() != v:
raise TypeError(
"Loading an estimator with different type. "
f"Expecting: {self._get_type()}, got: {v}"
)
continue
states[k] = v
self.__dict__.update(states)
@ -1211,7 +1226,7 @@ class XGBRFRegressor(XGBRegressor):
then your group array should be ``[3, 4]``.
''')
class XGBRanker(XGBModel):
class XGBRanker(XGBModel, XGBRankerMixIn):
# pylint: disable=missing-docstring,too-many-arguments,invalid-name
@_deprecate_positional_args
def __init__(self, *, objective='rank:pairwise', **kwargs):

View File

@ -10,6 +10,7 @@ from typing import List, Tuple, Union, Dict, Optional, Callable, Type
import asyncio
import tempfile
from sklearn.datasets import make_classification
import sklearn
import os
import subprocess
from hypothesis import given, settings, note
@ -261,6 +262,9 @@ def test_dask_regressor() -> None:
with Client(cluster) as client:
X, y, w = generate_array(with_weights=True)
regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
assert regressor._estimator_type == "regressor"
assert sklearn.base.is_regressor(regressor)
regressor.set_params(tree_method='hist')
regressor.client = client
regressor.fit(X, y, sample_weight=w, eval_set=[(X, y)])
@ -285,6 +289,9 @@ def test_dask_classifier() -> None:
y = (y * 10).astype(np.int32)
classifier = xgb.dask.DaskXGBClassifier(
verbosity=1, n_estimators=2, eval_metric='merror')
assert classifier._estimator_type == "classifier"
assert sklearn.base.is_classifier(classifier)
classifier.client = client
classifier.fit(X, y, sample_weight=w, eval_set=[(X, y)])
prediction = classifier.predict(X)
@ -946,6 +953,35 @@ class TestWithDask:
# Subtract the on disk resource from each worker
assert cnt - n_workers == n_partitions
@pytest.mark.skipif(**tm.no_sklearn())
def test_sklearn_io(self, client: 'Client') -> None:
from sklearn.datasets import load_digits
X_, y_ = load_digits(return_X_y=True)
X, y = da.from_array(X_), da.from_array(y_)
cls = xgb.dask.DaskXGBClassifier(n_estimators=10)
cls.client = client
cls.fit(X, y)
predt_0 = cls.predict(X)
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, 'cls.json')
cls.save_model(path)
cls = xgb.dask.DaskXGBClassifier()
cls.load_model(path)
assert cls.n_classes_ == 10
predt_1 = cls.predict(X)
np.testing.assert_allclose(predt_0.compute(), predt_1.compute())
# Use single node to load
cls = xgb.XGBClassifier()
cls.load_model(path)
assert cls.n_classes_ == 10
predt_2 = cls.predict(X_)
np.testing.assert_allclose(predt_0.compute(), predt_2)
class TestDaskCallbacks:
@pytest.mark.skipif(**tm.no_sklearn())

View File

@ -1099,3 +1099,26 @@ def test_boost_from_prediction_approx():
@pytest.mark.skipif(**tm.no_sklearn())
def test_boost_from_prediction_exact():
run_boost_from_prediction('exact')
def test_estimator_type():
assert xgb.XGBClassifier._estimator_type == "classifier"
assert xgb.XGBRFClassifier._estimator_type == "classifier"
assert xgb.XGBRegressor._estimator_type == "regressor"
assert xgb.XGBRFRegressor._estimator_type == "regressor"
assert xgb.XGBRanker._estimator_type == "ranker"
from sklearn.datasets import load_digits
X, y = load_digits(n_class=2, return_X_y=True)
cls = xgb.XGBClassifier(n_estimators=2).fit(X, y)
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, "cls.json")
cls.save_model(path)
reg = xgb.XGBRegressor()
with pytest.raises(TypeError):
reg.load_model(path)
cls = xgb.XGBClassifier()
cls.load_model(path) # no error