[Breaking] Remove Scikit-Learn default parameters (#5130)

* Simplify Scikit-Learn parameter management.

* Copy the base class to avoid duplicating parameter signatures.
* Set all constructor parameters to None by default.
* Handle None in set_param.
* Extract the shared docstring.

Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
OrdoAbChao 2020-01-23 13:25:20 +01:00 committed by Jiaming Yuan
parent aa9a68010b
commit b4f952bd22
6 changed files with 270 additions and 293 deletions
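In short: every wrapper argument now defaults to None, and None-valued entries are never forwarded to the native booster, so libxgboost's own defaults apply. A minimal sketch of the pattern (the helper name is illustrative, not part of the commit):

    # Illustrative only: drop unset (None) entries before they reach
    # the native library, letting its built-in defaults take effect.
    def effective_params(params):
        return {k: v for k, v in params.items() if v is not None}

    assert effective_params({'max_depth': None, 'eta': 0.3}) == {'eta': 0.3}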

python-package/xgboost/core.py

@@ -1219,7 +1219,9 @@ class Booster(object):
         elif isinstance(params, STRING_TYPES) and value is not None:
             params = [(params, value)]
         for key, val in params:
-            _check_call(_LIB.XGBoosterSetParam(self.handle, c_str(key), c_str(str(val))))
+            if val is not None:
+                _check_call(_LIB.XGBoosterSetParam(self.handle, c_str(key),
+                                                   c_str(str(val))))
 
     def update(self, dtrain, iteration, fobj=None):
         """Update for one iteration, with objective function calculated

python-package/xgboost/dask.py

@@ -30,7 +30,7 @@ from .compat import CUDF_INSTALLED, CUDF_DataFrame, CUDF_Series, CUDF_concat
 from .core import DMatrix, Booster, _expect
 from .training import train as worker_train
 from .tracker import RabitTracker
-from .sklearn import XGBModel, XGBClassifierBase
+from .sklearn import XGBModel, XGBClassifierBase, xgboost_model_doc
 
 # Current status is considered as initial support, many features are
 # not properly supported yet.
@@ -580,13 +580,10 @@ class DaskScikitLearnBase(XGBModel):
     def client(self, clt):
         self._client = clt
 
 
+@xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""",
+                   ['estimators', 'model'])
 class DaskXGBRegressor(DaskScikitLearnBase):
     # pylint: disable=missing-docstring
-    __doc__ = ('Implementation of the scikit-learn API for XGBoost ' +
-               'regression. \n\n') + '\n'.join(
-                   XGBModel.__doc__.split('\n')[2:])
-
     def fit(self,
             X,
             y,
@@ -616,12 +613,13 @@ class DaskXGBRegressor(DaskScikitLearnBase):
         return pred_probs
 
 
+@xgboost_model_doc(
+    'Implementation of the scikit-learn API for XGBoost classification.',
+    ['estimators', 'model']
+)
 class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
     # pylint: disable=missing-docstring
     _client = None
-    __doc__ = ('Implementation of the scikit-learn API for XGBoost ' +
-               'classification.\n\n') + '\n'.join(
-                   XGBModel.__doc__.split('\n')[2:])
 
     def fit(self,
             X,

python-package/xgboost/sklearn.py

@@ -1,14 +1,15 @@
 # coding: utf-8
 # pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme, E0012, R0912, C0302
 """Scikit-Learn Wrapper interface for XGBoost."""
+import copy
 import warnings
 import json
 import numpy as np
 
 from .core import Booster, DMatrix, XGBoostError
 from .training import train
 
-# Do not use class names on scikit-learn directly.
-# Re-define the classes on .compat to guarantee the behavior without scikit-learn
+# Do not use class names on scikit-learn directly.  Re-define the classes on
+# .compat to guarantee the behavior without scikit-learn
 from .compat import (SKLEARN_INSTALLED, XGBModelBase,
                      XGBClassifierBase, XGBRegressorBase, XGBLabelEncoder)
@@ -48,18 +49,17 @@ def _objective_decorator(func):
     return inner
 
 
-class XGBModel(XGBModelBase):
-    # pylint: disable=too-many-arguments, too-many-instance-attributes, invalid-name
-    """Implementation of the Scikit-Learn API for XGBoost.
-
-    Parameters
-    ----------
+__estimator_doc = '''
+    n_estimators : int
+        Number of gradient boosted trees.  Equivalent to number of boosting
+        rounds.
+'''
+
+__model_doc = '''
     max_depth : int
         Maximum tree depth for base learners.
     learning_rate : float
         Boosting learning rate (xgb's "eta")
-    n_estimators : int
-        Number of trees to fit.
     verbosity : int
         The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
     objective : string or callable
@@ -75,7 +75,8 @@ class XGBModel(XGBModelBase):
     n_jobs : int
         Number of parallel threads used to run xgboost.
     gamma : float
-        Minimum loss reduction required to make a further partition on a leaf node of the tree.
+        Minimum loss reduction required to make a further partition on a leaf
+        node of the tree.
     min_child_weight : int
         Minimum sum of instance weight(hessian) needed in a child.
     max_delta_step : int
@@ -112,17 +113,21 @@ class XGBModel(XGBModelBase):
     importance_type: string, default "gain"
         The feature importance type for the feature_importances\\_ property:
         either "gain", "weight", "cover", "total_gain" or "total_cover".
 
     \\*\\*kwargs : dict, optional
-        Keyword arguments for XGBoost Booster object.  Full documentation of parameters can
-        be found here: https://github.com/dmlc/xgboost/blob/master/doc/parameter.rst.
-        Attempting to set a parameter via the constructor args and \\*\\*kwargs dict simultaneously
-        will result in a TypeError.
+        Keyword arguments for XGBoost Booster object.  Full documentation of
+        parameters can be found here:
+        https://github.com/dmlc/xgboost/blob/master/doc/parameter.rst.
+        Attempting to set a parameter via the constructor args and \\*\\*kwargs
+        dict simultaneously will result in a TypeError.
 
         .. note:: \\*\\*kwargs unsupported by scikit-learn
 
-            \\*\\*kwargs is unsupported by scikit-learn.  We do not guarantee that parameters
-            passed via this argument will interact properly with scikit-learn.
+            \\*\\*kwargs is unsupported by scikit-learn.  We do not guarantee
+            that parameters passed via this argument will interact properly
+            with scikit-learn.
+'''
 
+__custom_obj_note = '''
     Note
     ----
@@ -138,25 +143,72 @@ class XGBModel(XGBModelBase):
         The value of the gradient for each sample point.
     hess: array_like of shape [n_samples]
         The value of the second derivative for each sample point
-    """
+'''
 
-    def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
-                 verbosity=1, objective="reg:squarederror",
-                 booster='gbtree', tree_method='auto', n_jobs=1, gamma=0,
-                 min_child_weight=1, max_delta_step=0, subsample=1,
-                 colsample_bytree=1, colsample_bylevel=1, colsample_bynode=1,
-                 reg_alpha=0, reg_lambda=1, scale_pos_weight=1, base_score=0.5,
-                 random_state=0, missing=None, num_parallel_tree=1,
-                 importance_type="gain", **kwargs):
+
+def xgboost_model_doc(header, items, extra_parameters=None, end_note=None):
+    '''Obtain documentation for Scikit-Learn wrappers
+
+    Parameters
+    ----------
+    header: str
+       An introduction to the class.
+    items : list
+       A list of common doc items.  Available items are:
+         - estimators: the meaning of n_estimators
+         - model: All the other parameters
+         - objective: note for customized objective
+    extra_parameters: str
+       Document for class specific parameters, placed at the head.
+    end_note: str
+       Extra notes put to the end.
+    '''
+    def get_doc(item):
+        '''Return selected item'''
+        __doc = {'estimators': __estimator_doc,
+                 'model': __model_doc,
+                 'objective': __custom_obj_note}
+        return __doc[item]
+
+    def adddoc(cls):
+        doc = ['''
+    Parameters
+    ----------
+''']
+        if extra_parameters:
+            doc.append(extra_parameters)
+        doc.extend([get_doc(i) for i in items])
+        if end_note:
+            doc.append(end_note)
+        full_doc = [header + '\n\n']
+        full_doc.extend(doc)
+        cls.__doc__ = ''.join(full_doc)
+        return cls
+    return adddoc
+
+
+@xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""",
+                   ['estimators', 'model', 'objective'])
+class XGBModel(XGBModelBase):
+    # pylint: disable=too-many-arguments, too-many-instance-attributes, invalid-name, missing-docstring
+    def __init__(self, max_depth=None, learning_rate=None, n_estimators=100,
+                 verbosity=None, objective=None, booster=None,
+                 tree_method=None, n_jobs=None, gamma=None,
+                 min_child_weight=None, max_delta_step=None, subsample=None,
+                 colsample_bytree=None, colsample_bylevel=None,
+                 colsample_bynode=None, reg_alpha=None, reg_lambda=None,
+                 scale_pos_weight=None, base_score=None, random_state=None,
+                 missing=None, num_parallel_tree=None, importance_type="gain",
+                 gpu_id=None, **kwargs):
         if not SKLEARN_INSTALLED:
             raise XGBoostError(
                 'sklearn needs to be installed in order to use this module')
+        self.n_estimators = n_estimators
+        self.objective = objective
         self.max_depth = max_depth
         self.learning_rate = learning_rate
-        self.n_estimators = n_estimators
         self.verbosity = verbosity
-        self.objective = objective
         self.booster = booster
         self.tree_method = tree_method
         self.gamma = gamma
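To picture what the decorator produces, a small sketch (ToyModel is hypothetical): xgboost_model_doc assembles the header, a Parameters heading, the selected shared blocks, and any extra notes into __doc__, so every wrapper documents the common parameters from a single source.

    # Hypothetical class, shown only to illustrate the decorator's effect.
    @xgboost_model_doc('A toy wrapper.', ['estimators', 'model'])
    class ToyModel(XGBModel):
        pass

    # __doc__ now reads: header, blank line, 'Parameters', the
    # n_estimators block, then the shared model-parameter block.
    print(ToyModel.__doc__)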
@@ -176,6 +228,7 @@ class XGBModel(XGBModelBase):
         self._Booster = None
         self.random_state = random_state
         self.n_jobs = n_jobs
+        self.gpu_id = gpu_id
         self.importance_type = importance_type
 
     def __setstate__(self, state):
@@ -201,18 +254,22 @@ class XGBModel(XGBModelBase):
         return self._Booster
 
     def set_params(self, **params):
-        """Set the parameters of this estimator.
-
-        Modification of the sklearn method to allow unknown kwargs. This allows using
-        the full range of xgboost parameters that are not defined as member variables
-        in sklearn grid search.
+        """Set the parameters of this estimator.  Modification of the sklearn
+        method to allow unknown kwargs.  This allows using the full range of
+        xgboost parameters that are not defined as member variables in sklearn
+        grid search.
 
         Returns
         -------
         self
 
         """
         if not params:
             # Simple optimization to gain speed (inspect is slow)
             return self
 
+        # this concatenates kwargs into parameters, enabling `get_params` for
+        # obtaining parameters from keyword parameters.
         for key, value in params.items():
             if hasattr(self, key):
                 setattr(self, key, value)
@@ -221,16 +278,26 @@ class XGBModel(XGBModelBase):
         return self
 
-    def get_params(self, deep=False):
+    def get_params(self, deep=True):
+        # pylint: disable=attribute-defined-outside-init
         """Get parameters."""
-        params = super(XGBModel, self).get_params(deep=deep)
-
+        # Based on: https://stackoverflow.com/questions/59248211
+        # The basic flow in `get_params` is:
+        # 0. Return parameters in subclass first, by using inspect.
+        # 1. Return parameters in `XGBModel` (the base class).
+        # 2. Return whatever in `**kwargs`.
+        # 3. Merge them.
+        params = super().get_params(deep)
+        if hasattr(self, '__copy__'):
+            warnings.warn('Calling __copy__ on Scikit-Learn wrapper, ' +
+                          'which may disable data cache and result in ' +
+                          'lower performance.')
+        cp = copy.copy(self)
+        cp.__class__ = cp.__class__.__bases__[0]
+        params.update(cp.__class__.get_params(cp, deep))
         # if kwargs is a dict, update params accordingly
         if isinstance(self.kwargs, dict):
             params.update(self.kwargs)
-        if params['missing'] is np.nan:
-            params['missing'] = None  # sklearn doesn't handle nan. see #4725
-        if not params.get('eval_metric', True):
-            del params['eval_metric']  # don't give as None param to Booster
         if isinstance(params['random_state'], np.random.RandomState):
             params['random_state'] = params['random_state'].randint(
                 np.iinfo(np.int32).max)
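The base-class walk above is easier to see in isolation; a minimal standalone sketch of the same idea (Base and Sub are hypothetical): sklearn's get_params inspects only the immediate __init__ signature, so the wrapper runs it a second time with the instance temporarily cast to its base class and merges the results.

    import copy

    class Base:
        def __init__(self, a=None):
            self.a = a
        def get_params(self):
            return {'a': self.a}  # stand-in for sklearn's inspect-based version

    class Sub(Base):
        def __init__(self, b=None, **kwargs):
            super().__init__(**kwargs)
            self.b = b
        def get_params(self):
            params = {'b': self.b}                      # subclass's own parameters
            cp = copy.copy(self)
            cp.__class__ = cp.__class__.__bases__[0]    # cast the copy to Base
            params.update(cp.__class__.get_params(cp))  # merge Base's parameters
            return params

    assert Sub(a=1, b=2).get_params() == {'a': 1, 'b': 2}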
@@ -354,10 +421,11 @@ class XGBModel(XGBModelBase):
                 [xgb.callback.reset_learning_rate(custom_rates)]
         """
-        trainDmatrix = DMatrix(data=X, label=y, weight=sample_weight,
-                               base_margin=base_margin,
-                               missing=self.missing, nthread=self.n_jobs)
+        train_dmatrix = DMatrix(data=X, label=y, weight=sample_weight,
+                                base_margin=base_margin,
+                                missing=self.missing,
+                                nthread=self.n_jobs)
 
         evals_result = {}
 
         if eval_set is not None:
@@ -389,7 +457,7 @@ class XGBModel(XGBModelBase):
             else:
                 params.update({'eval_metric': eval_metric})
 
-        self._Booster = train(params, trainDmatrix,
+        self._Booster = train(params, train_dmatrix,
                               self.get_num_boosting_rounds(), evals=evals,
                               early_stopping_rounds=early_stopping_rounds,
                               evals_result=evals_result, obj=obj, feval=feval,
@@ -419,13 +487,6 @@ class XGBModel(XGBModelBase):
             If you want to run prediction using multiple thread, call ``xgb.copy()`` to make copies
             of model object and then call ``predict()``.
 
-        .. note:: Using ``predict()`` with DART booster
-
-            If the booster object is DART type, ``predict()`` will perform dropouts, i.e. only
-            some of the trees will be evaluated. This will produce incorrect results if ``data`` is
-            not the training data.  To obtain correct results on test sets, set ``ntree_limit`` to
-            a nonzero value, e.g.
-
             .. code-block:: python
 
                 preds = bst.predict(dtest, ntree_limit=num_round)
@@ -539,7 +600,8 @@ class XGBModel(XGBModelBase):
         feature_importances_ : array of shape ``[n_features]``
 
         """
-        if getattr(self, 'booster', None) is not None and self.booster not in {'gbtree', 'dart'}:
+        if getattr(self, 'booster', None) is not None and self.booster not in {
+                'gbtree', 'dart'}:
             raise AttributeError('Feature importance is not defined for Booster type {}'
                                  .format(self.booster))
         b = self.get_booster()
@@ -555,9 +617,9 @@ class XGBModel(XGBModelBase):
 
         .. note:: Coefficients are defined only for linear learners
 
-            Coefficients are only defined when the linear model is chosen as base
-            learner (`booster=gblinear`). It is not defined for other base learner types, such
-            as tree learners (`booster=gbtree`).
+            Coefficients are only defined when the linear model is chosen as
+            base learner (`booster=gblinear`). It is not defined for other base
+            learner types, such as tree learners (`booster=gbtree`).
 
         Returns
         -------
@@ -599,33 +661,13 @@ class XGBModel(XGBModelBase):
         return np.array(json.loads(b.get_dump(dump_format='json')[0])['bias'])
 
 
+@xgboost_model_doc(
+    "Implementation of the scikit-learn API for XGBoost classification.",
+    ['model', 'objective'])
 class XGBClassifier(XGBModel, XGBClassifierBase):
     # pylint: disable=missing-docstring,too-many-arguments,invalid-name,too-many-instance-attributes
-    __doc__ = "Implementation of the scikit-learn API for XGBoost classification.\n\n" \
-        + '\n'.join(XGBModel.__doc__.split('\n')[2:])
-
-    def __init__(self, max_depth=3, learning_rate=0.1,
-                 n_estimators=100, verbosity=1,
-                 objective="binary:logistic", booster='gbtree',
-                 tree_method='auto', n_jobs=1, gpu_id=-1, gamma=0,
-                 min_child_weight=1, max_delta_step=0, subsample=1,
-                 colsample_bytree=1, colsample_bylevel=1, colsample_bynode=1,
-                 reg_alpha=0, reg_lambda=1, scale_pos_weight=1, base_score=0.5,
-                 random_state=0, missing=None, **kwargs):
-        super(XGBClassifier, self).__init__(
-            max_depth=max_depth, learning_rate=learning_rate,
-            n_estimators=n_estimators, verbosity=verbosity,
-            objective=objective, booster=booster, tree_method=tree_method,
-            n_jobs=n_jobs, gpu_id=gpu_id, gamma=gamma,
-            min_child_weight=min_child_weight,
-            max_delta_step=max_delta_step, subsample=subsample,
-            colsample_bytree=colsample_bytree,
-            colsample_bylevel=colsample_bylevel,
-            colsample_bynode=colsample_bynode,
-            reg_alpha=reg_alpha, reg_lambda=reg_lambda,
-            scale_pos_weight=scale_pos_weight,
-            base_score=base_score, random_state=random_state, missing=missing,
-            **kwargs)
+    def __init__(self, objective="binary:logistic", **kwargs):
+        super().__init__(objective=objective, **kwargs)
 
     def fit(self, X, y, sample_weight=None, base_margin=None,
             eval_set=None, eval_metric=None,
@@ -647,8 +689,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
             obj = None
 
         if self.n_classes_ > 2:
-            # Switch to using a multiclass objective in the underlying XGB instance
-            xgb_options["objective"] = "multi:softprob"
+            # Switch to using a multiclass objective in the underlying
+            # XGB instance
+            xgb_options['objective'] = 'multi:softprob'
             xgb_options['num_class'] = self.n_classes_
 
         feval = eval_metric if callable(eval_metric) else None
@@ -665,7 +708,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
             if sample_weight_eval_set is None:
                 sample_weight_eval_set = [None] * len(eval_set)
             evals = list(
-                DMatrix(eval_set[i][0], label=self._le.transform(eval_set[i][1]),
+                DMatrix(eval_set[i][0],
+                        label=self._le.transform(eval_set[i][1]),
                         missing=self.missing, weight=sample_weight_eval_set[i],
                         nthread=self.n_jobs)
                 for i in range(len(eval_set))
@@ -686,8 +730,10 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
                                    base_margin=base_margin,
                                    missing=self.missing, nthread=self.n_jobs)
 
-        self._Booster = train(xgb_options, train_dmatrix, self.get_num_boosting_rounds(),
-                              evals=evals, early_stopping_rounds=early_stopping_rounds,
+        self._Booster = train(xgb_options, train_dmatrix,
+                              self.get_num_boosting_rounds(),
+                              evals=evals,
+                              early_stopping_rounds=early_stopping_rounds,
                               evals_result=evals_result, obj=obj, feval=feval,
                               verbose_eval=verbose, xgb_model=xgb_model,
                               callbacks=callbacks)
@@ -696,7 +742,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         if evals_result:
             for val in evals_result.items():
                 evals_result_key = list(val[1].keys())[0]
-                evals_result[val[0]][evals_result_key] = val[1][evals_result_key]
+                evals_result[val[0]][
+                    evals_result_key] = val[1][evals_result_key]
             self.evals_result_ = evals_result
 
         if early_stopping_rounds is not None:
@@ -706,8 +753,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
 
         return self
 
-    fit.__doc__ = XGBModel.fit.__doc__.replace('Fit gradient boosting model',
-                                               'Fit gradient boosting classifier', 1)
+    fit.__doc__ = XGBModel.fit.__doc__.replace(
+        'Fit gradient boosting model',
+        'Fit gradient boosting classifier', 1)
 
     def predict(self, data, output_margin=False, ntree_limit=None,
                 validate_features=True, base_margin=None):
@@ -717,15 +765,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         .. note:: This function is not thread safe.
 
             For each booster object, predict can only be called from one thread.
-            If you want to run prediction using multiple thread, call ``xgb.copy()`` to make copies
-            of model object and then call ``predict()``.
-
-        .. note:: Using ``predict()`` with DART booster
-
-            If the booster object is DART type, ``predict()`` will perform dropouts, i.e. only
-            some of the trees will be evaluated. This will produce incorrect results if ``data`` is
-            not the training data.  To obtain correct results on test sets, set ``ntree_limit`` to
-            a nonzero value, e.g.
+            If you want to run prediction using multiple thread, call
+            ``xgb.copy()`` to make copies of model object and then call
+            ``predict()``.
 
             .. code-block:: python
@@ -738,11 +780,14 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         output_margin : bool
             Whether to output the raw untransformed margin value.
         ntree_limit : int
-            Limit number of trees in the prediction; defaults to best_ntree_limit if defined
-            (i.e. it has been trained with early stopping), otherwise 0 (use all trees).
+            Limit number of trees in the prediction; defaults to
+            best_ntree_limit if defined (i.e. it has been trained with early
+            stopping), otherwise 0 (use all trees).
         validate_features : bool
-            When this is True, validate that the Booster's and data's feature_names are identical.
-            Otherwise, it is assumed that the feature_names are the same.
+            When this is True, validate that the Booster's and data's
+            feature_names are identical.  Otherwise, it is assumed that the
+            feature_names are the same.
 
         Returns
         -------
         prediction : numpy array
@@ -773,9 +818,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
 
         .. note:: This function is not thread safe
 
-            For each booster object, predict can only be called from one thread.
-            If you want to run prediction using multiple thread, call ``xgb.copy()`` to make copies
-            of model object and then call predict
+            For each booster object, predict can only be called from one
+            thread.  If you want to run prediction using multiple thread, call
+            ``xgb.copy()`` to make copies of model object and then call predict
 
         Parameters
         ----------
return evals_result return evals_result
@xgboost_model_doc(
"scikit-learn API for XGBoost random forest classification.",
['model', 'objective'],
extra_parameters='''
n_estimators : int
Number of trees in random forest to fit.
''')
class XGBRFClassifier(XGBClassifier): class XGBRFClassifier(XGBClassifier):
# pylint: disable=missing-docstring # pylint: disable=missing-docstring
__doc__ = "scikit-learn API for XGBoost random forest classification.\n\n"\ def __init__(self,
+ '\n'.join(XGBModel.__doc__.split('\n')[2:]) learning_rate=1,
subsample=0.8,
def __init__(self, max_depth=3, learning_rate=1, n_estimators=100, colsample_bynode=0.8,
verbosity=1, objective="binary:logistic", n_jobs=1, reg_lambda=1e-5,
gpu_id=-1, gamma=0, min_child_weight=1, max_delta_step=0, **kwargs):
subsample=0.8, colsample_bytree=1, colsample_bylevel=1, super().__init__(learning_rate=learning_rate,
colsample_bynode=0.8, reg_alpha=0, reg_lambda=1e-5, subsample=subsample,
scale_pos_weight=1, base_score=0.5, random_state=0, colsample_bynode=colsample_bynode,
missing=None, **kwargs): reg_lambda=reg_lambda,
super(XGBRFClassifier, self).__init__( **kwargs)
max_depth=max_depth, learning_rate=learning_rate,
n_estimators=n_estimators, verbosity=verbosity,
objective=objective, booster='gbtree', n_jobs=n_jobs,
gpu_id=gpu_id, gamma=gamma, min_child_weight=min_child_weight,
max_delta_step=max_delta_step,
subsample=subsample, colsample_bytree=colsample_bytree,
colsample_bylevel=colsample_bylevel,
colsample_bynode=colsample_bynode, reg_alpha=reg_alpha,
reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
base_score=base_score, random_state=random_state, missing=missing,
**kwargs)
def get_xgb_params(self): def get_xgb_params(self):
params = super(XGBRFClassifier, self).get_xgb_params() params = super(XGBRFClassifier, self).get_xgb_params()
@@ -883,37 +924,25 @@ class XGBRFClassifier(XGBClassifier):
         return 1
 
 
+@xgboost_model_doc(
+    "Implementation of the scikit-learn API for XGBoost regression.",
+    ['estimators', 'model', 'objective'])
 class XGBRegressor(XGBModel, XGBRegressorBase):
     # pylint: disable=missing-docstring
-    __doc__ = "Implementation of the scikit-learn API for XGBoost regression.\n\n"\
-        + '\n'.join(XGBModel.__doc__.split('\n')[2:])
+    def __init__(self, objective="reg:squarederror", **kwargs):
+        super().__init__(objective=objective, **kwargs)
 
 
+@xgboost_model_doc(
+    "scikit-learn API for XGBoost random forest regression.",
+    ['model', 'objective'])
 class XGBRFRegressor(XGBRegressor):
     # pylint: disable=missing-docstring
-    __doc__ = "scikit-learn API for XGBoost random forest regression.\n\n"\
-        + '\n'.join(XGBModel.__doc__.split('\n')[2:])
-
-    def __init__(self, max_depth=3, learning_rate=1, n_estimators=100,
-                 verbosity=1, objective="reg:squarederror", n_jobs=1,
-                 gpu_id=-1, gamma=0, min_child_weight=1,
-                 max_delta_step=0, subsample=0.8, colsample_bytree=1,
-                 colsample_bylevel=1, colsample_bynode=0.8, reg_alpha=0,
-                 reg_lambda=1e-5, scale_pos_weight=1, base_score=0.5,
-                 random_state=0, missing=None, **kwargs):
-        super(XGBRFRegressor, self).__init__(
-            max_depth=max_depth, learning_rate=learning_rate,
-            n_estimators=n_estimators, verbosity=verbosity,
-            objective=objective, booster='gbtree', n_jobs=n_jobs,
-            gpu_id=gpu_id, gamma=gamma, min_child_weight=min_child_weight,
-            max_delta_step=max_delta_step, subsample=subsample,
-            colsample_bytree=colsample_bytree,
-            colsample_bylevel=colsample_bylevel,
-            colsample_bynode=colsample_bynode,
-            reg_alpha=reg_alpha, reg_lambda=reg_lambda,
-            scale_pos_weight=scale_pos_weight,
-            base_score=base_score, random_state=random_state, missing=missing,
-            **kwargs)
+    def __init__(self, learning_rate=1, subsample=0.8, colsample_bynode=0.8,
+                 reg_lambda=1e-5, **kwargs):
+        super().__init__(learning_rate=learning_rate, subsample=subsample,
+                         colsample_bynode=colsample_bynode,
+                         reg_lambda=reg_lambda, **kwargs)
 
     def get_xgb_params(self):
         params = super(XGBRFRegressor, self).get_xgb_params()
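For context on the get_xgb_params/get_num_boosting_rounds overrides in these random-forest wrappers: they map n_estimators onto trees grown in parallel within a single boosting round. A hedged sketch of that mapping (simplified, the helper name is illustrative):

    # Simplified: n_estimators becomes num_parallel_tree, with one
    # boosting round, so the "forest" is built in parallel.
    def rf_params(n_estimators):
        return {'num_parallel_tree': n_estimators}, 1  # (params, rounds)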
@@ -924,71 +953,10 @@ class XGBRFRegressor(XGBRegressor):
         return 1
 
 
-class XGBRanker(XGBModel):
-    # pylint: disable=missing-docstring,too-many-arguments,invalid-name
-    """Implementation of the Scikit-Learn API for XGBoost Ranking.
-
-    Parameters
-    ----------
-    max_depth : int
-        Maximum tree depth for base learners.
-    learning_rate : float
-        Boosting learning rate (xgb's "eta")
-    n_estimators : int
-        Number of boosted trees to fit.
-    verbosity : int
-        The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
-    objective : string
-        Specify the learning task and the corresponding learning objective.
-        The objective name must start with "rank:".
-    booster: string
-        Specify which booster to use: gbtree, gblinear or dart.
-    n_jobs : int
-        Number of parallel threads used to run xgboost.
-    gamma : float
-        Minimum loss reduction required to make a further partition on a leaf node of the tree.
-    min_child_weight : int
-        Minimum sum of instance weight(hessian) needed in a child.
-    max_delta_step : int
-        Maximum delta step we allow each tree's weight estimation to be.
-    subsample : float
-        Subsample ratio of the training instance.
-    colsample_bytree : float
-        Subsample ratio of columns when constructing each tree.
-    colsample_bylevel : float
-        Subsample ratio of columns for each level.
-    colsample_bynode : float
-        Subsample ratio of columns for each split.
-    reg_alpha : float (xgb's alpha)
-        L1 regularization term on weights
-    reg_lambda : float (xgb's lambda)
-        L2 regularization term on weights
-    scale_pos_weight : float
-        Balancing of positive and negative weights.
-    base_score:
-        The initial prediction score of all instances, global bias.
-    random_state : int
-        Random number seed.
-
-        .. note::
-
-           Using gblinear booster with shotgun updater is nondeterministic as
-           it uses Hogwild algorithm.
-
-    missing : float, optional
-        Value in the data which needs to be present as a missing value. If
-        None, defaults to np.nan.
-    \\*\\*kwargs : dict, optional
-        Keyword arguments for XGBoost Booster object.  Full documentation of parameters can
-        be found here: https://github.com/dmlc/xgboost/blob/master/doc/parameter.rst.
-        Attempting to set a parameter via the constructor args and \\*\\*kwargs dict
-        simultaneously will result in a TypeError.
-
-        .. note:: \\*\\*kwargs unsupported by scikit-learn
-
-            \\*\\*kwargs is unsupported by scikit-learn.  We do not guarantee that parameters
-            passed via this argument will interact properly with scikit-learn.
-
+@xgboost_model_doc(
+    'Implementation of the Scikit-Learn API for XGBoost Ranking.',
+    ['estimators', 'model'],
+    end_note='''
     Note
     ----
     A custom objective function is currently not supported by XGBRanker.
@@ -998,9 +966,9 @@ class XGBRanker(XGBModel):
     ----
     Query group information is required for ranking tasks.
 
-    Before fitting the model, your data need to be sorted by query group. When
-    fitting the model, you need to provide an additional array that
-    contains the size of each query group.
+    Before fitting the model, your data need to be sorted by query
+    group. When fitting the model, you need to provide an additional array
+    that contains the size of each query group.
 
     For example, if your original data look like:
@@ -1023,29 +991,11 @@ class XGBRanker(XGBModel):
     +-------+-----------+---------------+
 
     then your group array should be ``[3, 4]``.
-
-    """
-
-    def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
-                 verbosity=1, objective="rank:pairwise", booster='gbtree',
-                 tree_method='auto', n_jobs=-1, gpu_id=-1, gamma=0,
-                 min_child_weight=1, max_delta_step=0, subsample=1,
-                 colsample_bytree=1, colsample_bylevel=1, colsample_bynode=1,
-                 reg_alpha=0, reg_lambda=1, scale_pos_weight=1, base_score=0.5,
-                 random_state=0, missing=None, **kwargs):
-        super(XGBRanker, self).__init__(
-            max_depth=max_depth, learning_rate=learning_rate,
-            n_estimators=n_estimators, verbosity=verbosity,
-            objective=objective, booster=booster, tree_method=tree_method,
-            n_jobs=n_jobs, gpu_id=gpu_id, gamma=gamma,
-            min_child_weight=min_child_weight, max_delta_step=max_delta_step,
-            subsample=subsample, colsample_bytree=colsample_bytree,
-            colsample_bylevel=colsample_bylevel,
-            colsample_bynode=colsample_bynode, reg_alpha=reg_alpha,
-            reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
-            base_score=base_score, random_state=random_state, missing=missing,
-            **kwargs)
-
+''')
+class XGBRanker(XGBModel):
+    # pylint: disable=missing-docstring,too-many-arguments,invalid-name
+    def __init__(self, objective='rank:pairwise', **kwargs):
+        super().__init__(objective=objective, **kwargs)
         if callable(self.objective):
             raise ValueError(
                 "custom objective function not supported by XGBRanker")
@@ -1058,8 +1008,7 @@ class XGBRanker(XGBModel):
               early_stopping_rounds=None, verbose=False, xgb_model=None,
               callbacks=None):
         # pylint: disable = attribute-defined-outside-init,arguments-differ
-        """
-        Fit gradient boosting ranker
+        """Fit gradient boosting ranker
 
         Parameters
         ----------
@@ -1068,17 +1017,17 @@ class XGBRanker(XGBModel):
         y : array_like
             Labels
         group : array_like
-            Size of each query group of training data. Should have as many elements as
-            the query groups in the training data
+            Size of each query group of training data. Should have as many
+            elements as the query groups in the training data
         sample_weight : array_like
             Query group weights
 
             .. note:: Weights are per-group for ranking tasks
 
-                In ranking task, one weight is assigned to each query group (not each
-                data point). This is because we only care about the relative ordering of
-                data points within each group, so it doesn't make sense to assign
-                weights to individual data points.
+                In ranking task, one weight is assigned to each query group
+                (not each data point). This is because we only care about the
+                relative ordering of data points within each group, so it
+                doesn't make sense to assign weights to individual data points.
 
         base_margin : array_like
             Global bias for each instance.
@@ -1112,24 +1061,26 @@ class XGBRanker(XGBModel):
             The method returns the model from the last iteration (not the best one).
             If there's more than one item in **eval_set**, the last entry will be used
             for early stopping.
-            If there's more than one metric in **eval_metric**, the last metric will be
-            used for early stopping.
-            If early stopping occurs, the model will have three additional fields:
-            ``clf.best_score``, ``clf.best_iteration`` and ``clf.best_ntree_limit``.
+            If there's more than one metric in **eval_metric**, the last metric
+            will be used for early stopping.
+            If early stopping occurs, the model will have three additional
+            fields: ``clf.best_score``, ``clf.best_iteration`` and
+            ``clf.best_ntree_limit``.
         verbose : bool
             If `verbose` and an evaluation set is used, writes the evaluation
             metric measured on the validation set to stderr.
         xgb_model : str
-            file name of stored XGBoost model or 'Booster' instance XGBoost model to be
-            loaded before training (allows training continuation).
+            file name of stored XGBoost model or 'Booster' instance XGBoost
+            model to be loaded before training (allows training continuation).
         callbacks : list of callback functions
-            List of callback functions that are applied at end of each iteration.
-            It is possible to use predefined callbacks by using :ref:`callback_api`.
-            Example:
+            List of callback functions that are applied at end of each
+            iteration.  It is possible to use predefined callbacks by using
+            :ref:`callback_api`.  Example:
 
             .. code-block:: python
 
                 [xgb.callback.reset_learning_rate(custom_rates)]
         """
         # check if group information is provided
         if group is None:
@@ -1137,11 +1088,14 @@ class XGBRanker(XGBModel):
 
         if eval_set is not None:
             if eval_group is None:
-                raise ValueError("eval_group is required if eval_set is not None")
+                raise ValueError(
+                    "eval_group is required if eval_set is not None")
             if len(eval_group) != len(eval_set):
-                raise ValueError("length of eval_group should match that of eval_set")
+                raise ValueError(
+                    "length of eval_group should match that of eval_set")
             if any(group is None for group in eval_group):
-                raise ValueError("group is required for all eval datasets for ranking task")
+                raise ValueError(
+                    "group is required for all eval datasets for ranking task")
 
         def _dmat_init(group, **params):
             ret = DMatrix(**params)
@@ -1158,9 +1112,13 @@ class XGBRanker(XGBModel):
         if eval_set is not None:
             if sample_weight_eval_set is None:
                 sample_weight_eval_set = [None] * len(eval_set)
-            evals = [_dmat_init(eval_group[i], data=eval_set[i][0], label=eval_set[i][1],
-                                missing=self.missing, weight=sample_weight_eval_set[i],
-                                nthread=self.n_jobs) for i in range(len(eval_set))]
+            evals = [_dmat_init(eval_group[i],
+                                data=eval_set[i][0],
+                                label=eval_set[i][1],
+                                missing=self.missing,
+                                weight=sample_weight_eval_set[i],
+                                nthread=self.n_jobs)
+                     for i in range(len(eval_set))]
             nevals = len(evals)
             eval_names = ["eval_{}".format(i) for i in range(nevals)]
             evals = list(zip(evals, eval_names))
@@ -1172,13 +1130,14 @@ class XGBRanker(XGBModel):
         feval = eval_metric if callable(eval_metric) else None
         if eval_metric is not None:
             if callable(eval_metric):
-                raise ValueError('Custom evaluation metric is not yet supported' +
-                                 'for XGBRanker.')
+                raise ValueError(
+                    'Custom evaluation metric is not yet supported for XGBRanker.')
             params.update({'eval_metric': eval_metric})
 
         self._Booster = train(params, train_dmatrix,
                               self.n_estimators,
-                              early_stopping_rounds=early_stopping_rounds, evals=evals,
+                              early_stopping_rounds=early_stopping_rounds,
+                              evals=evals,
                               evals_result=evals_result, feval=feval,
                               verbose_eval=verbose, xgb_model=xgb_model,
                               callbacks=callbacks)

python-package/xgboost/training.py

@@ -3,7 +3,8 @@
 # pylint: disable=too-many-branches, too-many-statements
 """Training Library containing training routines."""
 import numpy as np
-from .core import Booster, STRING_TYPES, XGBoostError, CallbackEnv, EarlyStopException
+from .core import Booster, STRING_TYPES, XGBoostError, CallbackEnv
+from .core import EarlyStopException
 from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
 from . import rabit
 from . import callback
@@ -37,10 +38,11 @@ def _train_internal(params, dtrain,
     _params = dict(params) if isinstance(params, list) else params
 
-    if 'num_parallel_tree' in _params:
+    if 'num_parallel_tree' in _params and _params[
+            'num_parallel_tree'] is not None:
         num_parallel_tree = _params['num_parallel_tree']
         nboost //= num_parallel_tree
-    if 'num_class' in _params:
+    if 'num_class' in _params and _params['num_class'] is not None:
         nboost //= _params['num_class']
 
     # Distributed code: Load the checkpoint from rabit.
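These guards matter because a resumed training session derives its round offset from the tree count, and with defaults no longer pinned these keys may legitimately hold None. A simplified sketch of that bookkeeping (the helper name is illustrative):

    # Derive completed boosting rounds from a raw tree count,
    # tolerating absent (None) parameters.
    def completed_rounds(n_trees, params):
        num_parallel_tree = params.get('num_parallel_tree') or 1
        num_class = params.get('num_class') or 1
        return n_trees // (num_parallel_tree * num_class)

    assert completed_rounds(8, {'num_parallel_tree': 4, 'num_class': None}) == 2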

tests/python/test_early_stopping.py

@@ -22,17 +22,17 @@ class TestEarlyStopping(unittest.TestCase):
         y = digits['target']
         X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                             random_state=0)
-        clf1 = xgb.XGBClassifier()
+        clf1 = xgb.XGBClassifier(learning_rate=0.1)
         clf1.fit(X_train, y_train, early_stopping_rounds=5, eval_metric="auc",
                  eval_set=[(X_test, y_test)])
-        clf2 = xgb.XGBClassifier()
+        clf2 = xgb.XGBClassifier(learning_rate=0.1)
         clf2.fit(X_train, y_train, early_stopping_rounds=4, eval_metric="auc",
                  eval_set=[(X_test, y_test)])
         # should be the same
         assert clf1.best_score == clf2.best_score
         assert clf1.best_score != 1
         # check overfit
-        clf3 = xgb.XGBClassifier()
+        clf3 = xgb.XGBClassifier(learning_rate=0.1)
         clf3.fit(X_train, y_train, early_stopping_rounds=10, eval_metric="auc",
                  eval_set=[(X_test, y_test)])
         assert clf3.best_score == 1

tests/python/test_with_sklearn.py

@@ -6,6 +6,7 @@ import os
 import shutil
 import pytest
 import unittest
+import json
 
 rng = np.random.RandomState(1994)
@@ -117,9 +118,10 @@ def test_feature_importances_weight():
     digits = load_digits(2)
     y = digits['target']
     X = digits['data']
-    xgb_model = xgb.XGBClassifier(
-        random_state=0, tree_method="exact", importance_type="weight").fit(X, y)
+    xgb_model = xgb.XGBClassifier(random_state=0,
+                                  tree_method="exact",
+                                  learning_rate=0.1,
+                                  importance_type="weight").fit(X, y)
 
     exp = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.00833333, 0.,
                     0., 0., 0., 0., 0., 0., 0., 0.025, 0.14166667, 0., 0., 0.,
                     0., 0., 0., 0.00833333, 0.25833333, 0., 0., 0., 0.,
@@ -134,12 +136,16 @@ def test_feature_importances_weight():
     import pandas as pd
     y = pd.Series(digits['target'])
     X = pd.DataFrame(digits['data'])
-    xgb_model = xgb.XGBClassifier(
-        random_state=0, tree_method="exact", importance_type="weight").fit(X, y)
+    xgb_model = xgb.XGBClassifier(random_state=0,
+                                  tree_method="exact",
+                                  learning_rate=0.1,
+                                  importance_type="weight").fit(X, y)
     np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
 
-    xgb_model = xgb.XGBClassifier(
-        random_state=0, tree_method="exact", importance_type="weight").fit(X, y)
+    xgb_model = xgb.XGBClassifier(random_state=0,
+                                  tree_method="exact",
+                                  learning_rate=0.1,
+                                  importance_type="weight").fit(X, y)
     np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
@@ -151,7 +157,9 @@ def test_feature_importances_gain():
     y = digits['target']
     X = digits['data']
     xgb_model = xgb.XGBClassifier(
-        random_state=0, tree_method="exact", importance_type="gain").fit(X, y)
+        random_state=0, tree_method="exact",
+        learning_rate=0.1,
+        importance_type="gain").fit(X, y)
 
     exp = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                     0.00326159, 0., 0., 0., 0., 0., 0., 0., 0.,
@@ -169,11 +177,15 @@ def test_feature_importances_gain():
     y = pd.Series(digits['target'])
     X = pd.DataFrame(digits['data'])
     xgb_model = xgb.XGBClassifier(
-        random_state=0, tree_method="exact", importance_type="gain").fit(X, y)
+        random_state=0, tree_method="exact",
+        learning_rate=0.1,
+        importance_type="gain").fit(X, y)
     np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
 
     xgb_model = xgb.XGBClassifier(
-        random_state=0, tree_method="exact", importance_type="gain").fit(X, y)
+        random_state=0, tree_method="exact",
+        learning_rate=0.1,
+        importance_type="gain").fit(X, y)
     np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
@@ -191,6 +203,10 @@ def test_num_parallel_tree():
     dump = bst.get_booster().get_dump(dump_format='json')
     assert len(dump) == 4
 
+    config = json.loads(bst.get_booster().save_config())
+    assert int(config['learner']['gradient_booster']['gbtree_train_param'][
+        'num_parallel_tree']) == 4
+
 
 def test_boston_housing_regression():
     from sklearn.metrics import mean_squared_error
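The new assertion relies on Booster.save_config, which serializes the resolved native configuration as JSON. A hedged usage sketch outside the test (X and y stand for any small training set):

    import json
    import xgboost as xgb

    bst = xgb.XGBRegressor(n_estimators=2).fit(X, y)  # X, y: assumed data
    config = json.loads(bst.get_booster().save_config())
    # Shows the values the library actually resolved, including defaults
    # the wrapper no longer pins.
    print(config['learner']['gradient_booster']['gbtree_train_param'])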
@@ -244,7 +260,7 @@ def test_parameter_tuning():
     boston = load_boston()
     y = boston['target']
     X = boston['data']
-    xgb_model = xgb.XGBRegressor()
+    xgb_model = xgb.XGBRegressor(learning_rate=0.1)
     clf = GridSearchCV(xgb_model, {'max_depth': [2, 4, 6],
                                    'n_estimators': [50, 100, 200]},
                        cv=3, verbose=1, iid=True)