[Breaking] Remove Scikit-Learn default parameters (#5130)

* Simplify Scikit-Learn parameter management.

* Copy base class for removing duplicated parameter signatures.
* Set all parameters to None.
* Handle None in set_param.
* Extract the doc.

Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
OrdoAbChao 2020-01-23 13:25:20 +01:00 committed by Jiaming Yuan
parent aa9a68010b
commit b4f952bd22
6 changed files with 270 additions and 293 deletions
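
In short, the wrapper constructors no longer pin their own defaults: nearly every parameter now defaults to None, and None entries are skipped before they reach the native Booster, so XGBoost's built-in defaults take effect. A minimal illustrative sketch of that filtering idea (not the library code itself; `effective_params` is a made-up helper):

    # Illustrative sketch only: wrapper arguments default to None and None
    # values are dropped before being handed to the Booster, so the native
    # XGBoost defaults apply unless the caller overrides them explicitly.
    def effective_params(sklearn_params):
        """Keep only the parameters the user actually set."""
        return {k: v for k, v in sklearn_params.items() if v is not None}

    wrapper_defaults = {"max_depth": None, "learning_rate": None,
                        "n_estimators": 100, "importance_type": "gain"}
    user_overrides = {"max_depth": 6}

    print(effective_params({**wrapper_defaults, **user_overrides}))
    # {'n_estimators': 100, 'importance_type': 'gain', 'max_depth': 6}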

View File

@@ -1219,7 +1219,9 @@ class Booster(object):
elif isinstance(params, STRING_TYPES) and value is not None:
params = [(params, value)]
for key, val in params:
_check_call(_LIB.XGBoosterSetParam(self.handle, c_str(key), c_str(str(val))))
if val is not None:
_check_call(_LIB.XGBoosterSetParam(self.handle, c_str(key),
c_str(str(val))))
def update(self, dtrain, iteration, fobj=None):
"""Update for one iteration, with objective function calculated

View File

@@ -30,7 +30,7 @@ from .compat import CUDF_INSTALLED, CUDF_DataFrame, CUDF_Series, CUDF_concat
from .core import DMatrix, Booster, _expect
from .training import train as worker_train
from .tracker import RabitTracker
from .sklearn import XGBModel, XGBClassifierBase
from .sklearn import XGBModel, XGBClassifierBase, xgboost_model_doc
# Current status is considered as initial support, many features are
# not properly supported yet.
@@ -580,13 +580,10 @@ class DaskScikitLearnBase(XGBModel):
def client(self, clt):
self._client = clt
@xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""",
['estimators', 'model'])
class DaskXGBRegressor(DaskScikitLearnBase):
# pylint: disable=missing-docstring
__doc__ = ('Implementation of the scikit-learn API for XGBoost ' +
'regression. \n\n') + '\n'.join(
XGBModel.__doc__.split('\n')[2:])
def fit(self,
X,
y,
@@ -616,12 +613,13 @@ class DaskXGBRegressor(DaskScikitLearnBase):
return pred_probs
@xgboost_model_doc(
'Implementation of the scikit-learn API for XGBoost classification.',
['estimators', 'model']
)
class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
# pylint: disable=missing-docstring
_client = None
__doc__ = ('Implementation of the scikit-learn API for XGBoost ' +
'classification.\n\n') + '\n'.join(
XGBModel.__doc__.split('\n')[2:])
def fit(self,
X,

View File

@@ -1,14 +1,15 @@
# coding: utf-8
# pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme, E0012, R0912, C0302
"""Scikit-Learn Wrapper interface for XGBoost."""
import copy
import warnings
import json
import numpy as np
from .core import Booster, DMatrix, XGBoostError
from .training import train
# Do not use class names on scikit-learn directly.
# Re-define the classes on .compat to guarantee the behavior without scikit-learn
# Do not use class names on scikit-learn directly. Re-define the classes on
# .compat to guarantee the behavior without scikit-learn
from .compat import (SKLEARN_INSTALLED, XGBModelBase,
XGBClassifierBase, XGBRegressorBase, XGBLabelEncoder)
@@ -48,18 +49,17 @@ def _objective_decorator(func):
return inner
class XGBModel(XGBModelBase):
# pylint: disable=too-many-arguments, too-many-instance-attributes, invalid-name
"""Implementation of the Scikit-Learn API for XGBoost.
__estimator_doc = '''
n_estimators : int
Number of gradient boosted trees. Equivalent to number of boosting
rounds.
'''
Parameters
----------
__model_doc = '''
max_depth : int
Maximum tree depth for base learners.
learning_rate : float
Boosting learning rate (xgb's "eta")
n_estimators : int
Number of trees to fit.
verbosity : int
The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
objective : string or callable
@@ -75,7 +75,8 @@ class XGBModel(XGBModelBase):
n_jobs : int
Number of parallel threads used to run xgboost.
gamma : float
Minimum loss reduction required to make a further partition on a leaf node of the tree.
Minimum loss reduction required to make a further partition on a leaf
node of the tree.
min_child_weight : int
Minimum sum of instance weight(hessian) needed in a child.
max_delta_step : int
@@ -112,17 +113,21 @@ class XGBModel(XGBModelBase):
importance_type: string, default "gain"
The feature importance type for the feature_importances\\_ property:
either "gain", "weight", "cover", "total_gain" or "total_cover".
\\*\\*kwargs : dict, optional
Keyword arguments for XGBoost Booster object. Full documentation of parameters can
be found here: https://github.com/dmlc/xgboost/blob/master/doc/parameter.rst.
Attempting to set a parameter via the constructor args and \\*\\*kwargs dict simultaneously
will result in a TypeError.
Keyword arguments for XGBoost Booster object. Full documentation of
parameters can be found here:
https://github.com/dmlc/xgboost/blob/master/doc/parameter.rst.
Attempting to set a parameter via the constructor args and \\*\\*kwargs
dict simultaneously will result in a TypeError.
.. note:: \\*\\*kwargs unsupported by scikit-learn
\\*\\*kwargs is unsupported by scikit-learn. We do not guarantee that parameters
passed via this argument will interact properly with scikit-learn.
\\*\\*kwargs is unsupported by scikit-learn. We do not guarantee
that parameters passed via this argument will interact properly
with scikit-learn. '''
__custom_obj_note = '''
Note
----
A custom objective function can be provided for the ``objective``
@@ -138,25 +143,72 @@ class XGBModel(XGBModelBase):
The value of the gradient for each sample point.
hess: array_like of shape [n_samples]
The value of the second derivative for each sample point
'''
"""
def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
verbosity=1, objective="reg:squarederror",
booster='gbtree', tree_method='auto', n_jobs=1, gamma=0,
min_child_weight=1, max_delta_step=0, subsample=1,
colsample_bytree=1, colsample_bylevel=1, colsample_bynode=1,
reg_alpha=0, reg_lambda=1, scale_pos_weight=1, base_score=0.5,
random_state=0, missing=None, num_parallel_tree=1,
importance_type="gain", **kwargs):
def xgboost_model_doc(header, items, extra_parameters=None, end_note=None):
'''Obtain documentation for Scikit-Learn wrappers
Parameters
----------
header: str
An introduction to the class.
items : list
A list of common doc items. Available items are:
- estimators: the meaning of n_estimators
- model: All the other parameters
- objective: note for customized objective
extra_parameters: str
Document for class specific parameters, placed at the head.
end_note: str
Extra notes put to the end.
'''
def get_doc(item):
'''Return selected item'''
__doc = {'estimators': __estimator_doc,
'model': __model_doc,
'objective': __custom_obj_note}
return __doc[item]
def adddoc(cls):
doc = ['''
Parameters
----------
''']
if extra_parameters:
doc.append(extra_parameters)
doc.extend([get_doc(i) for i in items])
if end_note:
doc.append(end_note)
full_doc = [header + '\n\n']
full_doc.extend(doc)
cls.__doc__ = ''.join(full_doc)
return cls
return adddoc
@xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""",
['estimators', 'model', 'objective'])
class XGBModel(XGBModelBase):
# pylint: disable=too-many-arguments, too-many-instance-attributes, invalid-name, missing-docstring
def __init__(self, max_depth=None, learning_rate=None, n_estimators=100,
verbosity=None, objective=None, booster=None,
tree_method=None, n_jobs=None, gamma=None,
min_child_weight=None, max_delta_step=None, subsample=None,
colsample_bytree=None, colsample_bylevel=None,
colsample_bynode=None, reg_alpha=None, reg_lambda=None,
scale_pos_weight=None, base_score=None, random_state=None,
missing=None, num_parallel_tree=None, importance_type="gain",
gpu_id=None, **kwargs):
if not SKLEARN_INSTALLED:
raise XGBoostError(
'sklearn needs to be installed in order to use this module')
self.n_estimators = n_estimators
self.objective = objective
self.max_depth = max_depth
self.learning_rate = learning_rate
self.n_estimators = n_estimators
self.verbosity = verbosity
self.objective = objective
self.booster = booster
self.tree_method = tree_method
self.gamma = gamma
@@ -176,6 +228,7 @@ class XGBModel(XGBModelBase):
self._Booster = None
self.random_state = random_state
self.n_jobs = n_jobs
self.gpu_id = gpu_id
self.importance_type = importance_type
def __setstate__(self, state):
@@ -201,18 +254,22 @@ class XGBModel(XGBModelBase):
return self._Booster
def set_params(self, **params):
"""Set the parameters of this estimator.
Modification of the sklearn method to allow unknown kwargs. This allows using
the full range of xgboost parameters that are not defined as member variables
in sklearn grid search.
"""Set the parameters of this estimator. Modification of the sklearn method to
allow unknown kwargs. This allows using the full range of xgboost
parameters that are not defined as member variables in sklearn grid
search.
Returns
-------
self
"""
if not params:
# Simple optimization to gain speed (inspect is slow)
return self
# this concatenates kwargs into parameters, enabling `get_params` for
# obtaining parameters from keyword parameters.
for key, value in params.items():
if hasattr(self, key):
setattr(self, key, value)
@@ -221,16 +278,26 @@ class XGBModel(XGBModelBase):
return self
def get_params(self, deep=False):
def get_params(self, deep=True):
# pylint: disable=attribute-defined-outside-init
"""Get parameters."""
params = super(XGBModel, self).get_params(deep=deep)
# Based on: https://stackoverflow.com/questions/59248211
# The basic flow in `get_params` is:
# 0. Return parameters in subclass first, by using inspect.
# 1. Return parameters in `XGBModel` (the base class).
# 2. Return whatever in `**kwargs`.
# 3. Merge them.
params = super().get_params(deep)
if hasattr(self, '__copy__'):
warnings.warn('Calling __copy__ on Scikit-Learn wrapper, ' +
'which may disable data cache and result in ' +
'lower performance.')
cp = copy.copy(self)
cp.__class__ = cp.__class__.__bases__[0]
params.update(cp.__class__.get_params(cp, deep))
# if kwargs is a dict, update params accordingly
if isinstance(self.kwargs, dict):
params.update(self.kwargs)
if params['missing'] is np.nan:
params['missing'] = None # sklearn doesn't handle nan. see #4725
if not params.get('eval_metric', True):
del params['eval_metric'] # don't give as None param to Booster
if isinstance(params['random_state'], np.random.RandomState):
params['random_state'] = params['random_state'].randint(
np.iinfo(np.int32).max)
@@ -354,10 +421,11 @@ class XGBModel(XGBModelBase):
[xgb.callback.reset_learning_rate(custom_rates)]
"""
trainDmatrix = DMatrix(data=X, label=y, weight=sample_weight,
base_margin=base_margin,
missing=self.missing,
nthread=self.n_jobs)
train_dmatrix = DMatrix(data=X, label=y, weight=sample_weight,
base_margin=base_margin,
missing=self.missing,
nthread=self.n_jobs)
evals_result = {}
if eval_set is not None:
@@ -389,7 +457,7 @@ class XGBModel(XGBModelBase):
else:
params.update({'eval_metric': eval_metric})
self._Booster = train(params, trainDmatrix,
self._Booster = train(params, train_dmatrix,
self.get_num_boosting_rounds(), evals=evals,
early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result, obj=obj, feval=feval,
@@ -419,13 +487,6 @@ class XGBModel(XGBModelBase):
If you want to run prediction using multiple thread, call ``xgb.copy()`` to make copies
of model object and then call ``predict()``.
.. note:: Using ``predict()`` with DART booster
If the booster object is DART type, ``predict()`` will perform dropouts, i.e. only
some of the trees will be evaluated. This will produce incorrect results if ``data`` is
not the training data. To obtain correct results on test sets, set ``ntree_limit`` to
a nonzero value, e.g.
.. code-block:: python
preds = bst.predict(dtest, ntree_limit=num_round)
@@ -539,7 +600,8 @@ class XGBModel(XGBModelBase):
feature_importances_ : array of shape ``[n_features]``
"""
if getattr(self, 'booster', None) is not None and self.booster not in {'gbtree', 'dart'}:
if getattr(self, 'booster', None) is not None and self.booster not in {
'gbtree', 'dart'}:
raise AttributeError('Feature importance is not defined for Booster type {}'
.format(self.booster))
b = self.get_booster()
@@ -555,9 +617,9 @@ class XGBModel(XGBModelBase):
.. note:: Coefficients are defined only for linear learners
Coefficients are only defined when the linear model is chosen as base
learner (`booster=gblinear`). It is not defined for other base learner types, such
as tree learners (`booster=gbtree`).
Coefficients are only defined when the linear model is chosen as
base learner (`booster=gblinear`). It is not defined for other base
learner types, such as tree learners (`booster=gbtree`).
Returns
-------
@@ -599,33 +661,13 @@ class XGBModel(XGBModelBase):
return np.array(json.loads(b.get_dump(dump_format='json')[0])['bias'])
@xgboost_model_doc(
"Implementation of the scikit-learn API for XGBoost classification.",
['model', 'objective'])
class XGBClassifier(XGBModel, XGBClassifierBase):
# pylint: disable=missing-docstring,too-many-arguments,invalid-name,too-many-instance-attributes
__doc__ = "Implementation of the scikit-learn API for XGBoost classification.\n\n" \
+ '\n'.join(XGBModel.__doc__.split('\n')[2:])
def __init__(self, max_depth=3, learning_rate=0.1,
n_estimators=100, verbosity=1,
objective="binary:logistic", booster='gbtree',
tree_method='auto', n_jobs=1, gpu_id=-1, gamma=0,
min_child_weight=1, max_delta_step=0, subsample=1,
colsample_bytree=1, colsample_bylevel=1, colsample_bynode=1,
reg_alpha=0, reg_lambda=1, scale_pos_weight=1, base_score=0.5,
random_state=0, missing=None, **kwargs):
super(XGBClassifier, self).__init__(
max_depth=max_depth, learning_rate=learning_rate,
n_estimators=n_estimators, verbosity=verbosity,
objective=objective, booster=booster, tree_method=tree_method,
n_jobs=n_jobs, gpu_id=gpu_id, gamma=gamma,
min_child_weight=min_child_weight,
max_delta_step=max_delta_step, subsample=subsample,
colsample_bytree=colsample_bytree,
colsample_bylevel=colsample_bylevel,
colsample_bynode=colsample_bynode,
reg_alpha=reg_alpha, reg_lambda=reg_lambda,
scale_pos_weight=scale_pos_weight,
base_score=base_score, random_state=random_state, missing=missing,
**kwargs)
def __init__(self, objective="binary:logistic", **kwargs):
super().__init__(objective=objective, **kwargs)
def fit(self, X, y, sample_weight=None, base_margin=None,
eval_set=None, eval_metric=None,
@@ -647,8 +689,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
obj = None
if self.n_classes_ > 2:
# Switch to using a multiclass objective in the underlying XGB instance
xgb_options["objective"] = "multi:softprob"
# Switch to using a multiclass objective in the underlying
# XGB instance
xgb_options['objective'] = 'multi:softprob'
xgb_options['num_class'] = self.n_classes_
feval = eval_metric if callable(eval_metric) else None
@@ -665,7 +708,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
if sample_weight_eval_set is None:
sample_weight_eval_set = [None] * len(eval_set)
evals = list(
DMatrix(eval_set[i][0], label=self._le.transform(eval_set[i][1]),
DMatrix(eval_set[i][0],
label=self._le.transform(eval_set[i][1]),
missing=self.missing, weight=sample_weight_eval_set[i],
nthread=self.n_jobs)
for i in range(len(eval_set))
@@ -686,8 +730,10 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
base_margin=base_margin,
missing=self.missing, nthread=self.n_jobs)
self._Booster = train(xgb_options, train_dmatrix, self.get_num_boosting_rounds(),
evals=evals, early_stopping_rounds=early_stopping_rounds,
self._Booster = train(xgb_options, train_dmatrix,
self.get_num_boosting_rounds(),
evals=evals,
early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result, obj=obj, feval=feval,
verbose_eval=verbose, xgb_model=xgb_model,
callbacks=callbacks)
@@ -696,7 +742,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
if evals_result:
for val in evals_result.items():
evals_result_key = list(val[1].keys())[0]
evals_result[val[0]][evals_result_key] = val[1][evals_result_key]
evals_result[val[0]][
evals_result_key] = val[1][evals_result_key]
self.evals_result_ = evals_result
if early_stopping_rounds is not None:
@@ -706,8 +753,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
return self
fit.__doc__ = XGBModel.fit.__doc__.replace('Fit gradient boosting model',
'Fit gradient boosting classifier', 1)
fit.__doc__ = XGBModel.fit.__doc__.replace(
'Fit gradient boosting model',
'Fit gradient boosting classifier', 1)
def predict(self, data, output_margin=False, ntree_limit=None,
validate_features=True, base_margin=None):
@@ -717,15 +765,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
.. note:: This function is not thread safe.
For each booster object, predict can only be called from one thread.
If you want to run prediction using multiple thread, call ``xgb.copy()`` to make copies
of model object and then call ``predict()``.
.. note:: Using ``predict()`` with DART booster
If the booster object is DART type, ``predict()`` will perform dropouts, i.e. only
some of the trees will be evaluated. This will produce incorrect results if ``data`` is
not the training data. To obtain correct results on test sets, set ``ntree_limit`` to
a nonzero value, e.g.
If you want to run prediction using multiple thread, call
``xgb.copy()`` to make copies of model object and then call
``predict()``.
.. code-block:: python
@@ -738,11 +780,14 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
output_margin : bool
Whether to output the raw untransformed margin value.
ntree_limit : int
Limit number of trees in the prediction; defaults to best_ntree_limit if defined
(i.e. it has been trained with early stopping), otherwise 0 (use all trees).
Limit number of trees in the prediction; defaults to
best_ntree_limit if defined (i.e. it has been trained with early
stopping), otherwise 0 (use all trees).
validate_features : bool
When this is True, validate that the Booster's and data's feature_names are identical.
Otherwise, it is assumed that the feature_names are the same.
When this is True, validate that the Booster's and data's
feature_names are identical. Otherwise, it is assumed that the
feature_names are the same.
Returns
-------
prediction : numpy array
@@ -773,9 +818,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
.. note:: This function is not thread safe
For each booster object, predict can only be called from one thread.
If you want to run prediction using multiple thread, call ``xgb.copy()`` to make copies
of model object and then call predict
For each booster object, predict can only be called from one
thread. If you want to run prediction using multiple thread, call
``xgb.copy()`` to make copies of model object and then call predict
Parameters
----------
@@ -849,30 +894,26 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
return evals_result
@xgboost_model_doc(
"scikit-learn API for XGBoost random forest classification.",
['model', 'objective'],
extra_parameters='''
n_estimators : int
Number of trees in random forest to fit.
''')
class XGBRFClassifier(XGBClassifier):
# pylint: disable=missing-docstring
__doc__ = "scikit-learn API for XGBoost random forest classification.\n\n"\
+ '\n'.join(XGBModel.__doc__.split('\n')[2:])
def __init__(self, max_depth=3, learning_rate=1, n_estimators=100,
verbosity=1, objective="binary:logistic", n_jobs=1,
gpu_id=-1, gamma=0, min_child_weight=1, max_delta_step=0,
subsample=0.8, colsample_bytree=1, colsample_bylevel=1,
colsample_bynode=0.8, reg_alpha=0, reg_lambda=1e-5,
scale_pos_weight=1, base_score=0.5, random_state=0,
missing=None, **kwargs):
super(XGBRFClassifier, self).__init__(
max_depth=max_depth, learning_rate=learning_rate,
n_estimators=n_estimators, verbosity=verbosity,
objective=objective, booster='gbtree', n_jobs=n_jobs,
gpu_id=gpu_id, gamma=gamma, min_child_weight=min_child_weight,
max_delta_step=max_delta_step,
subsample=subsample, colsample_bytree=colsample_bytree,
colsample_bylevel=colsample_bylevel,
colsample_bynode=colsample_bynode, reg_alpha=reg_alpha,
reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
base_score=base_score, random_state=random_state, missing=missing,
**kwargs)
def __init__(self,
learning_rate=1,
subsample=0.8,
colsample_bynode=0.8,
reg_lambda=1e-5,
**kwargs):
super().__init__(learning_rate=learning_rate,
subsample=subsample,
colsample_bynode=colsample_bynode,
reg_lambda=reg_lambda,
**kwargs)
def get_xgb_params(self):
params = super(XGBRFClassifier, self).get_xgb_params()
@@ -883,37 +924,25 @@ class XGBRFClassifier(XGBClassifier):
return 1
@xgboost_model_doc(
"Implementation of the scikit-learn API for XGBoost regression.",
['estimators', 'model', 'objective'])
class XGBRegressor(XGBModel, XGBRegressorBase):
# pylint: disable=missing-docstring
__doc__ = "Implementation of the scikit-learn API for XGBoost regression.\n\n"\
+ '\n'.join(XGBModel.__doc__.split('\n')[2:])
def __init__(self, objective="reg:squarederror", **kwargs):
super().__init__(objective=objective, **kwargs)
@xgboost_model_doc(
"scikit-learn API for XGBoost random forest regression.",
['model', 'objective'])
class XGBRFRegressor(XGBRegressor):
# pylint: disable=missing-docstring
__doc__ = "scikit-learn API for XGBoost random forest regression.\n\n"\
+ '\n'.join(XGBModel.__doc__.split('\n')[2:])
def __init__(self, max_depth=3, learning_rate=1, n_estimators=100,
verbosity=1, objective="reg:squarederror", n_jobs=1,
gpu_id=-1, gamma=0, min_child_weight=1,
max_delta_step=0, subsample=0.8, colsample_bytree=1,
colsample_bylevel=1, colsample_bynode=0.8, reg_alpha=0,
reg_lambda=1e-5, scale_pos_weight=1, base_score=0.5,
random_state=0, missing=None, **kwargs):
super(XGBRFRegressor, self).__init__(
max_depth=max_depth, learning_rate=learning_rate,
n_estimators=n_estimators, verbosity=verbosity,
objective=objective, booster='gbtree', n_jobs=n_jobs,
gpu_id=gpu_id, gamma=gamma, min_child_weight=min_child_weight,
max_delta_step=max_delta_step, subsample=subsample,
colsample_bytree=colsample_bytree,
colsample_bylevel=colsample_bylevel,
colsample_bynode=colsample_bynode,
reg_alpha=reg_alpha, reg_lambda=reg_lambda,
scale_pos_weight=scale_pos_weight,
base_score=base_score, random_state=random_state, missing=missing,
**kwargs)
def __init__(self, learning_rate=1, subsample=0.8, colsample_bynode=0.8,
reg_lambda=1e-5, **kwargs):
super().__init__(learning_rate=learning_rate, subsample=subsample,
colsample_bynode=colsample_bynode,
reg_lambda=reg_lambda, **kwargs)
def get_xgb_params(self):
params = super(XGBRFRegressor, self).get_xgb_params()
@@ -924,71 +953,10 @@ class XGBRFRegressor(XGBRegressor):
return 1
class XGBRanker(XGBModel):
# pylint: disable=missing-docstring,too-many-arguments,invalid-name
"""Implementation of the Scikit-Learn API for XGBoost Ranking.
Parameters
----------
max_depth : int
Maximum tree depth for base learners.
learning_rate : float
Boosting learning rate (xgb's "eta")
n_estimators : int
Number of boosted trees to fit.
verbosity : int
The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
objective : string
Specify the learning task and the corresponding learning objective.
The objective name must start with "rank:".
booster: string
Specify which booster to use: gbtree, gblinear or dart.
n_jobs : int
Number of parallel threads used to run xgboost.
gamma : float
Minimum loss reduction required to make a further partition on a leaf node of the tree.
min_child_weight : int
Minimum sum of instance weight(hessian) needed in a child.
max_delta_step : int
Maximum delta step we allow each tree's weight estimation to be.
subsample : float
Subsample ratio of the training instance.
colsample_bytree : float
Subsample ratio of columns when constructing each tree.
colsample_bylevel : float
Subsample ratio of columns for each level.
colsample_bynode : float
Subsample ratio of columns for each split.
reg_alpha : float (xgb's alpha)
L1 regularization term on weights
reg_lambda : float (xgb's lambda)
L2 regularization term on weights
scale_pos_weight : float
Balancing of positive and negative weights.
base_score:
The initial prediction score of all instances, global bias.
random_state : int
Random number seed.
.. note::
Using gblinear booster with shotgun updater is nondeterministic as
it uses Hogwild algorithm.
missing : float, optional
Value in the data which needs to be present as a missing value. If
None, defaults to np.nan.
\\*\\*kwargs : dict, optional
Keyword arguments for XGBoost Booster object. Full documentation of parameters can
be found here: https://github.com/dmlc/xgboost/blob/master/doc/parameter.rst.
Attempting to set a parameter via the constructor args and \\*\\*kwargs dict
simultaneously will result in a TypeError.
.. note:: \\*\\*kwargs unsupported by scikit-learn
\\*\\*kwargs is unsupported by scikit-learn. We do not guarantee that parameters
passed via this argument will interact properly with scikit-learn.
@xgboost_model_doc(
'Implementation of the Scikit-Learn API for XGBoost Ranking.',
['estimators', 'model'],
end_note='''
Note
----
A custom objective function is currently not supported by XGBRanker.
@@ -998,9 +966,9 @@ class XGBRanker(XGBModel):
----
Query group information is required for ranking tasks.
Before fitting the model, your data need to be sorted by query group. When
fitting the model, you need to provide an additional array that
contains the size of each query group.
Before fitting the model, your data need to be sorted by query
group. When fitting the model, you need to provide an additional array
that contains the size of each query group.
For example, if your original data look like:
@@ -1023,29 +991,11 @@ class XGBRanker(XGBModel):
+-------+-----------+---------------+
then your group array should be ``[3, 4]``.
"""
def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
verbosity=1, objective="rank:pairwise", booster='gbtree',
tree_method='auto', n_jobs=-1, gpu_id=-1, gamma=0,
min_child_weight=1, max_delta_step=0, subsample=1,
colsample_bytree=1, colsample_bylevel=1, colsample_bynode=1,
reg_alpha=0, reg_lambda=1, scale_pos_weight=1, base_score=0.5,
random_state=0, missing=None, **kwargs):
super(XGBRanker, self).__init__(
max_depth=max_depth, learning_rate=learning_rate,
n_estimators=n_estimators, verbosity=verbosity,
objective=objective, booster=booster, tree_method=tree_method,
n_jobs=n_jobs, gpu_id=gpu_id, gamma=gamma,
min_child_weight=min_child_weight, max_delta_step=max_delta_step,
subsample=subsample, colsample_bytree=colsample_bytree,
colsample_bylevel=colsample_bylevel,
colsample_bynode=colsample_bynode, reg_alpha=reg_alpha,
reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
base_score=base_score, random_state=random_state, missing=missing,
**kwargs)
''')
class XGBRanker(XGBModel):
# pylint: disable=missing-docstring,too-many-arguments,invalid-name
def __init__(self, objective='rank:pairwise', **kwargs):
super().__init__(objective=objective, **kwargs)
if callable(self.objective):
raise ValueError(
"custom objective function not supported by XGBRanker")
@@ -1058,8 +1008,7 @@ class XGBRanker(XGBModel):
early_stopping_rounds=None, verbose=False, xgb_model=None,
callbacks=None):
# pylint: disable = attribute-defined-outside-init,arguments-differ
"""
Fit gradient boosting ranker
"""Fit gradient boosting ranker
Parameters
----------
@@ -1068,17 +1017,17 @@ class XGBRanker(XGBModel):
y : array_like
Labels
group : array_like
Size of each query group of training data. Should have as many elements as
the query groups in the training data
Size of each query group of training data. Should have as many
elements as the query groups in the training data
sample_weight : array_like
Query group weights
.. note:: Weights are per-group for ranking tasks
In ranking task, one weight is assigned to each query group (not each
data point). This is because we only care about the relative ordering of
data points within each group, so it doesn't make sense to assign
weights to individual data points.
In ranking task, one weight is assigned to each query group
(not each data point). This is because we only care about the
relative ordering of data points within each group, so it
doesn't make sense to assign weights to individual data points.
base_margin : array_like
Global bias for each instance.
@@ -1112,24 +1061,26 @@ class XGBRanker(XGBModel):
The method returns the model from the last iteration (not the best one).
If there's more than one item in **eval_set**, the last entry will be used
for early stopping.
If there's more than one metric in **eval_metric**, the last metric will be
used for early stopping.
If early stopping occurs, the model will have three additional fields:
``clf.best_score``, ``clf.best_iteration`` and ``clf.best_ntree_limit``.
If there's more than one metric in **eval_metric**, the last metric
will be used for early stopping.
If early stopping occurs, the model will have three additional
fields: ``clf.best_score``, ``clf.best_iteration`` and
``clf.best_ntree_limit``.
verbose : bool
If `verbose` and an evaluation set is used, writes the evaluation
metric measured on the validation set to stderr.
xgb_model : str
file name of stored XGBoost model or 'Booster' instance XGBoost model to be
loaded before training (allows training continuation).
file name of stored XGBoost model or 'Booster' instance XGBoost
model to be loaded before training (allows training continuation).
callbacks : list of callback functions
List of callback functions that are applied at end of each iteration.
It is possible to use predefined callbacks by using :ref:`callback_api`.
Example:
List of callback functions that are applied at end of each
iteration. It is possible to use predefined callbacks by using
:ref:`callback_api`. Example:
.. code-block:: python
[xgb.callback.reset_learning_rate(custom_rates)]
"""
# check if group information is provided
if group is None:
@@ -1137,11 +1088,14 @@ class XGBRanker(XGBModel):
if eval_set is not None:
if eval_group is None:
raise ValueError("eval_group is required if eval_set is not None")
raise ValueError(
"eval_group is required if eval_set is not None")
if len(eval_group) != len(eval_set):
raise ValueError("length of eval_group should match that of eval_set")
raise ValueError(
"length of eval_group should match that of eval_set")
if any(group is None for group in eval_group):
raise ValueError("group is required for all eval datasets for ranking task")
raise ValueError(
"group is required for all eval datasets for ranking task")
def _dmat_init(group, **params):
ret = DMatrix(**params)
@@ -1158,9 +1112,13 @@ class XGBRanker(XGBModel):
if eval_set is not None:
if sample_weight_eval_set is None:
sample_weight_eval_set = [None] * len(eval_set)
evals = [_dmat_init(eval_group[i], data=eval_set[i][0], label=eval_set[i][1],
missing=self.missing, weight=sample_weight_eval_set[i],
nthread=self.n_jobs) for i in range(len(eval_set))]
evals = [_dmat_init(eval_group[i],
data=eval_set[i][0],
label=eval_set[i][1],
missing=self.missing,
weight=sample_weight_eval_set[i],
nthread=self.n_jobs)
for i in range(len(eval_set))]
nevals = len(evals)
eval_names = ["eval_{}".format(i) for i in range(nevals)]
evals = list(zip(evals, eval_names))
@@ -1172,13 +1130,14 @@ class XGBRanker(XGBModel):
feval = eval_metric if callable(eval_metric) else None
if eval_metric is not None:
if callable(eval_metric):
raise ValueError('Custom evaluation metric is not yet supported' +
'for XGBRanker.')
raise ValueError(
'Custom evaluation metric is not yet supported for XGBRanker.')
params.update({'eval_metric': eval_metric})
self._Booster = train(params, train_dmatrix,
self.n_estimators,
early_stopping_rounds=early_stopping_rounds, evals=evals,
early_stopping_rounds=early_stopping_rounds,
evals=evals,
evals_result=evals_result, feval=feval,
verbose_eval=verbose, xgb_model=xgb_model,
callbacks=callbacks)

View File

@@ -3,7 +3,8 @@
# pylint: disable=too-many-branches, too-many-statements
"""Training Library containing training routines."""
import numpy as np
from .core import Booster, STRING_TYPES, XGBoostError, CallbackEnv, EarlyStopException
from .core import Booster, STRING_TYPES, XGBoostError, CallbackEnv
from .core import EarlyStopException
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
from . import rabit
from . import callback
@@ -37,10 +38,11 @@ def _train_internal(params, dtrain,
_params = dict(params) if isinstance(params, list) else params
if 'num_parallel_tree' in _params:
if 'num_parallel_tree' in _params and _params[
'num_parallel_tree'] is not None:
num_parallel_tree = _params['num_parallel_tree']
nboost //= num_parallel_tree
if 'num_class' in _params:
if 'num_class' in _params and _params['num_class'] is not None:
nboost //= _params['num_class']
# Distributed code: Load the checkpoint from rabit.

View File

@@ -22,17 +22,17 @@ class TestEarlyStopping(unittest.TestCase):
y = digits['target']
X_train, X_test, y_train, y_test = train_test_split(X, y,
random_state=0)
clf1 = xgb.XGBClassifier()
clf1 = xgb.XGBClassifier(learning_rate=0.1)
clf1.fit(X_train, y_train, early_stopping_rounds=5, eval_metric="auc",
eval_set=[(X_test, y_test)])
clf2 = xgb.XGBClassifier()
clf2 = xgb.XGBClassifier(learning_rate=0.1)
clf2.fit(X_train, y_train, early_stopping_rounds=4, eval_metric="auc",
eval_set=[(X_test, y_test)])
# should be the same
assert clf1.best_score == clf2.best_score
assert clf1.best_score != 1
# check overfit
clf3 = xgb.XGBClassifier()
clf3 = xgb.XGBClassifier(learning_rate=0.1)
clf3.fit(X_train, y_train, early_stopping_rounds=10, eval_metric="auc",
eval_set=[(X_test, y_test)])
assert clf3.best_score == 1
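
These early-stopping tests now pass learning_rate=0.1 explicitly: with the wrapper default removed, an unset learning rate falls back to XGBoost's native eta (0.3 at the time), which shifts convergence and would change the scores the assertions expect. A hedged sketch of the difference on synthetic data (the predict_proba comparison is illustrative only):

    # Sketch: the same Python-side call no longer implies learning_rate=0.1;
    # leaving it unset defers to XGBoost's own default, so the two models
    # below will generally produce different probabilities.
    import numpy as np
    import xgboost as xgb

    rng = np.random.RandomState(0)
    X = rng.rand(128, 8)
    y = rng.randint(2, size=128)

    pinned = xgb.XGBClassifier(n_estimators=10, learning_rate=0.1).fit(X, y)
    unset = xgb.XGBClassifier(n_estimators=10).fit(X, y)  # XGBoost's own eta

    print(np.allclose(pinned.predict_proba(X), unset.predict_proba(X)))  # typically False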

View File

@@ -6,6 +6,7 @@ import os
import shutil
import pytest
import unittest
import json
rng = np.random.RandomState(1994)
@@ -117,9 +118,10 @@ def test_feature_importances_weight():
digits = load_digits(2)
y = digits['target']
X = digits['data']
xgb_model = xgb.XGBClassifier(
random_state=0, tree_method="exact", importance_type="weight").fit(X, y)
xgb_model = xgb.XGBClassifier(random_state=0,
tree_method="exact",
learning_rate=0.1,
importance_type="weight").fit(X, y)
exp = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.00833333, 0.,
0., 0., 0., 0., 0., 0., 0., 0.025, 0.14166667, 0., 0., 0.,
0., 0., 0., 0.00833333, 0.25833333, 0., 0., 0., 0.,
@@ -134,12 +136,16 @@ def test_feature_importances_weight():
import pandas as pd
y = pd.Series(digits['target'])
X = pd.DataFrame(digits['data'])
xgb_model = xgb.XGBClassifier(
random_state=0, tree_method="exact", importance_type="weight").fit(X, y)
xgb_model = xgb.XGBClassifier(random_state=0,
tree_method="exact",
learning_rate=0.1,
importance_type="weight").fit(X, y)
np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
xgb_model = xgb.XGBClassifier(
random_state=0, tree_method="exact", importance_type="weight").fit(X, y)
xgb_model = xgb.XGBClassifier(random_state=0,
tree_method="exact",
learning_rate=0.1,
importance_type="weight").fit(X, y)
np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
@@ -151,7 +157,9 @@ def test_feature_importances_gain():
y = digits['target']
X = digits['data']
xgb_model = xgb.XGBClassifier(
random_state=0, tree_method="exact", importance_type="gain").fit(X, y)
random_state=0, tree_method="exact",
learning_rate=0.1,
importance_type="gain").fit(X, y)
exp = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0.00326159, 0., 0., 0., 0., 0., 0., 0., 0.,
@@ -169,11 +177,15 @@ def test_feature_importances_gain():
y = pd.Series(digits['target'])
X = pd.DataFrame(digits['data'])
xgb_model = xgb.XGBClassifier(
random_state=0, tree_method="exact", importance_type="gain").fit(X, y)
random_state=0, tree_method="exact",
learning_rate=0.1,
importance_type="gain").fit(X, y)
np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
xgb_model = xgb.XGBClassifier(
random_state=0, tree_method="exact", importance_type="gain").fit(X, y)
random_state=0, tree_method="exact",
learning_rate=0.1,
importance_type="gain").fit(X, y)
np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
@@ -191,6 +203,10 @@ def test_num_parallel_tree():
dump = bst.get_booster().get_dump(dump_format='json')
assert len(dump) == 4
config = json.loads(bst.get_booster().save_config())
assert int(config['learner']['gradient_booster']['gbtree_train_param'][
'num_parallel_tree']) == 4
def test_boston_housing_regression():
from sklearn.metrics import mean_squared_error
@@ -244,7 +260,7 @@ def test_parameter_tuning():
boston = load_boston()
y = boston['target']
X = boston['data']
xgb_model = xgb.XGBRegressor()
xgb_model = xgb.XGBRegressor(learning_rate=0.1)
clf = GridSearchCV(xgb_model, {'max_depth': [2, 4, 6],
'n_estimators': [50, 100, 200]},
cv=3, verbose=1, iid=True)
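
The reworked get_params() (copying the instance onto its base class to recover the shared signature, then merging **kwargs) is what keeps sklearn tooling such as GridSearchCV and clone() working even though subclasses like XGBClassifier no longer repeat every argument. A hedged sketch of that guarantee, assuming the post-commit wrappers:

    # Sketch: base-class parameters are still visible through a subclass's
    # get_params(), and explicitly set values round-trip, which is what
    # sklearn's clone()/GridSearchCV rely on.
    import xgboost as xgb

    clf = xgb.XGBClassifier(max_depth=4)
    params = clf.get_params()

    assert 'learning_rate' in params and 'subsample' in params  # from XGBModel
    assert params['max_depth'] == 4
    clf.set_params(**params)  # roughly what clone() does with these values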