Support _estimator_type. (#6582)
* Use `_estimator_type`. For more info, see: https://scikit-learn.org/stable/developers/develop.html#estimator-types * Model trained from dask can be loaded by single node skl interface.
This commit is contained in:
@@ -16,6 +16,12 @@ from .compat import (SKLEARN_INSTALLED, XGBModelBase,
|
||||
XGBClassifierBase, XGBRegressorBase, XGBoostLabelEncoder)
|
||||
|
||||
|
||||
class XGBRankerMixIn: # pylint: disable=too-few-public-methods
|
||||
"""MixIn for ranking, defines the _estimator_type usually defined in scikit-learn base
|
||||
classes."""
|
||||
_estimator_type = "ranker"
|
||||
|
||||
|
||||
def _objective_decorator(func):
|
||||
"""Decorate an objective function
|
||||
|
||||
@@ -407,6 +413,14 @@ class XGBModel(XGBModelBase):
|
||||
"""Gets the number of xgboost boosting rounds."""
|
||||
return self.n_estimators
|
||||
|
||||
def _get_type(self) -> str:
|
||||
if not hasattr(self, '_estimator_type'):
|
||||
raise TypeError(
|
||||
"`_estimator_type` undefined. "
|
||||
"Please use appropriate mixin to define estimator type."
|
||||
)
|
||||
return self._estimator_type # pylint: disable=no-member
|
||||
|
||||
def save_model(self, fname: str):
|
||||
"""Save the model to a file.
|
||||
|
||||
@@ -442,7 +456,7 @@ class XGBModel(XGBModelBase):
|
||||
meta[k] = v
|
||||
except TypeError:
|
||||
warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.')
|
||||
meta['type'] = type(self).__name__
|
||||
meta['_estimator_type'] = self._get_type()
|
||||
meta_str = json.dumps(meta)
|
||||
self.get_booster().set_attr(scikit_learn=meta_str)
|
||||
self.get_booster().save_model(fname)
|
||||
@@ -484,11 +498,12 @@ class XGBModel(XGBModelBase):
|
||||
if k == 'use_label_encoder':
|
||||
self.use_label_encoder = bool(v)
|
||||
continue
|
||||
if k == 'type' and type(self).__name__ != v:
|
||||
msg = 'Current model type: {}, '.format(type(self).__name__) + \
|
||||
'type of model in file: {}'.format(v)
|
||||
raise TypeError(msg)
|
||||
if k == 'type':
|
||||
if k == "_estimator_type":
|
||||
if self._get_type() != v:
|
||||
raise TypeError(
|
||||
"Loading an estimator with different type. "
|
||||
f"Expecting: {self._get_type()}, got: {v}"
|
||||
)
|
||||
continue
|
||||
states[k] = v
|
||||
self.__dict__.update(states)
|
||||
@@ -1211,7 +1226,7 @@ class XGBRFRegressor(XGBRegressor):
|
||||
|
||||
then your group array should be ``[3, 4]``.
|
||||
''')
|
||||
class XGBRanker(XGBModel):
|
||||
class XGBRanker(XGBModel, XGBRankerMixIn):
|
||||
# pylint: disable=missing-docstring,too-many-arguments,invalid-name
|
||||
@_deprecate_positional_args
|
||||
def __init__(self, *, objective='rank:pairwise', **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user