[Breaking] Remove Scikit-Learn default parameters (#5130)
* Simplify Scikit-Learn parameter management.
* Copy base class for removing duplicated parameter signatures.
* Set all parameters to None.
* Handle None in set_param.
* Extract the doc.

Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
Parent: aa9a68010b
Commit: b4f952bd22
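With this change the scikit-learn wrappers stop hard-coding their own defaults: every tunable constructor argument now defaults to None, and None values are never forwarded to the native booster, so libxgboost's built-in defaults apply instead. A minimal sketch of the intended behaviour (illustrative usage, not part of the diff):

    import xgboost as xgb

    # Unset parameters stay None and are not sent to the native Booster, so
    # XGBoost's own defaults (e.g. eta) take effect.
    reg = xgb.XGBRegressor(objective="reg:squarederror")

    # Explicitly passed parameters are forwarded exactly as before; the tests
    # below pass learning_rate=0.1 to keep their previous behaviour.
    clf = xgb.XGBClassifier(learning_rate=0.1, n_estimators=100)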
@@ -1219,7 +1219,9 @@ class Booster(object):
         elif isinstance(params, STRING_TYPES) and value is not None:
             params = [(params, value)]
         for key, val in params:
-            _check_call(_LIB.XGBoosterSetParam(self.handle, c_str(key), c_str(str(val))))
+            if val is not None:
+                _check_call(_LIB.XGBoosterSetParam(self.handle, c_str(key),
+                                                   c_str(str(val))))
 
     def update(self, dtrain, iteration, fobj=None):
         """Update for one iteration, with objective function calculated
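The hunk above is what makes the None defaults workable at the Booster level: set_param now skips entries whose value is None instead of serializing the string "None". A small illustrative sketch (hypothetical usage, not taken from the diff):

    booster = xgb.Booster()
    # 'eta' is applied; 'max_depth' is None and is silently ignored, leaving
    # the native default in place.
    booster.set_param([('eta', 0.3), ('max_depth', None)])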
@@ -30,7 +30,7 @@ from .compat import CUDF_INSTALLED, CUDF_DataFrame, CUDF_Series, CUDF_concat
 from .core import DMatrix, Booster, _expect
 from .training import train as worker_train
 from .tracker import RabitTracker
-from .sklearn import XGBModel, XGBClassifierBase
+from .sklearn import XGBModel, XGBClassifierBase, xgboost_model_doc
 
 # Current status is considered as initial support, many features are
 # not properly supported yet.
@@ -580,13 +580,10 @@ class DaskScikitLearnBase(XGBModel):
     def client(self, clt):
         self._client = clt
 
 
+@xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""",
+                   ['estimators', 'model'])
 class DaskXGBRegressor(DaskScikitLearnBase):
     # pylint: disable=missing-docstring
-    __doc__ = ('Implementation of the scikit-learn API for XGBoost ' +
-               'regression. \n\n') + '\n'.join(
-                   XGBModel.__doc__.split('\n')[2:])
-
     def fit(self,
             X,
             y,
@@ -616,12 +613,13 @@ class DaskXGBRegressor(DaskScikitLearnBase):
         return pred_probs
 
 
+@xgboost_model_doc(
+    'Implementation of the scikit-learn API for XGBoost classification.',
+    ['estimators', 'model']
+)
 class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
     # pylint: disable=missing-docstring
     _client = None
-    __doc__ = ('Implementation of the scikit-learn API for XGBoost ' +
-               'classification.\n\n') + '\n'.join(
-                   XGBModel.__doc__.split('\n')[2:])
 
     def fit(self,
             X,
@@ -1,14 +1,15 @@
 # coding: utf-8
 # pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme, E0012, R0912, C0302
 """Scikit-Learn Wrapper interface for XGBoost."""
+import copy
 import warnings
 import json
 import numpy as np
 from .core import Booster, DMatrix, XGBoostError
 from .training import train
 
-# Do not use class names on scikit-learn directly.
-# Re-define the classes on .compat to guarantee the behavior without scikit-learn
+# Do not use class names on scikit-learn directly.  Re-define the classes on
+# .compat to guarantee the behavior without scikit-learn
 from .compat import (SKLEARN_INSTALLED, XGBModelBase,
                      XGBClassifierBase, XGBRegressorBase, XGBLabelEncoder)
 
@@ -48,18 +49,17 @@ def _objective_decorator(func):
     return inner
 
 
-class XGBModel(XGBModelBase):
-    # pylint: disable=too-many-arguments, too-many-instance-attributes, invalid-name
-    """Implementation of the Scikit-Learn API for XGBoost.
+__estimator_doc = '''
+    n_estimators : int
+        Number of gradient boosted trees.  Equivalent to number of boosting
+        rounds.
+'''
 
-    Parameters
-    ----------
+__model_doc = '''
     max_depth : int
         Maximum tree depth for base learners.
     learning_rate : float
         Boosting learning rate (xgb's "eta")
-    n_estimators : int
-        Number of trees to fit.
     verbosity : int
         The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
     objective : string or callable
@@ -75,7 +75,8 @@ class XGBModel(XGBModelBase):
     n_jobs : int
         Number of parallel threads used to run xgboost.
    gamma : float
-        Minimum loss reduction required to make a further partition on a leaf node of the tree.
+        Minimum loss reduction required to make a further partition on a leaf
+        node of the tree.
     min_child_weight : int
         Minimum sum of instance weight(hessian) needed in a child.
     max_delta_step : int
@@ -112,17 +113,21 @@ class XGBModel(XGBModelBase):
     importance_type: string, default "gain"
         The feature importance type for the feature_importances\_ property:
         either "gain", "weight", "cover", "total_gain" or "total_cover".
 
     \*\*kwargs : dict, optional
-        Keyword arguments for XGBoost Booster object.  Full documentation of parameters can
-        be found here: https://github.com/dmlc/xgboost/blob/master/doc/parameter.rst.
-        Attempting to set a parameter via the constructor args and \*\*kwargs dict simultaneously
-        will result in a TypeError.
+        Keyword arguments for XGBoost Booster object.  Full documentation of
+        parameters can be found here:
+        https://github.com/dmlc/xgboost/blob/master/doc/parameter.rst.
+        Attempting to set a parameter via the constructor args and \*\*kwargs
+        dict simultaneously will result in a TypeError.
 
         .. note:: \*\*kwargs unsupported by scikit-learn
 
-            \*\*kwargs is unsupported by scikit-learn.  We do not guarantee that parameters
-            passed via this argument will interact properly with scikit-learn.
+            \*\*kwargs is unsupported by scikit-learn.  We do not guarantee
+            that parameters passed via this argument will interact properly
+            with scikit-learn. '''
 
+__custom_obj_note = '''
     Note
     ----
     A custom objective function can be provided for the ``objective``
@@ -138,25 +143,72 @@ class XGBModel(XGBModelBase):
             The value of the gradient for each sample point.
         hess: array_like of shape [n_samples]
             The value of the second derivative for each sample point
-    """
+'''
 
-    def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
-                 verbosity=1, objective="reg:squarederror",
-                 booster='gbtree', tree_method='auto', n_jobs=1, gamma=0,
-                 min_child_weight=1, max_delta_step=0, subsample=1,
-                 colsample_bytree=1, colsample_bylevel=1, colsample_bynode=1,
-                 reg_alpha=0, reg_lambda=1, scale_pos_weight=1, base_score=0.5,
-                 random_state=0, missing=None, num_parallel_tree=1,
-                 importance_type="gain", **kwargs):
+
+def xgboost_model_doc(header, items, extra_parameters=None, end_note=None):
+    '''Obtain documentation for Scikit-Learn wrappers
+
+    Parameters
+    ----------
+    header: str
+       An introduction to the class.
+    items : list
+       A list of common doc items.  Available items are:
+         - estimators: the meaning of n_estimators
+         - model: All the other parameters
+         - objective: note for customized objective
+    extra_parameters: str
+       Document for class specific parameters, placed at the head.
+    end_note: str
+       Extra notes put to the end.
+    '''
+    def get_doc(item):
+        '''Return selected item'''
+        __doc = {'estimators': __estimator_doc,
+                 'model': __model_doc,
+                 'objective': __custom_obj_note}
+        return __doc[item]
+
+    def adddoc(cls):
+        doc = ['''
+Parameters
+----------
+''']
+        if extra_parameters:
+            doc.append(extra_parameters)
+        doc.extend([get_doc(i) for i in items])
+        if end_note:
+            doc.append(end_note)
+        full_doc = [header + '\n\n']
+        full_doc.extend(doc)
+        cls.__doc__ = ''.join(full_doc)
+        return cls
+    return adddoc
+
+
+@xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""",
+                   ['estimators', 'model', 'objective'])
+class XGBModel(XGBModelBase):
+    # pylint: disable=too-many-arguments, too-many-instance-attributes, invalid-name, missing-docstring
+    def __init__(self, max_depth=None, learning_rate=None, n_estimators=100,
+                 verbosity=None, objective=None, booster=None,
+                 tree_method=None, n_jobs=None, gamma=None,
+                 min_child_weight=None, max_delta_step=None, subsample=None,
+                 colsample_bytree=None, colsample_bylevel=None,
+                 colsample_bynode=None, reg_alpha=None, reg_lambda=None,
+                 scale_pos_weight=None, base_score=None, random_state=None,
+                 missing=None, num_parallel_tree=None, importance_type="gain",
+                 gpu_id=None, **kwargs):
         if not SKLEARN_INSTALLED:
             raise XGBoostError(
                 'sklearn needs to be installed in order to use this module')
+        self.n_estimators = n_estimators
+        self.objective = objective
+
         self.max_depth = max_depth
         self.learning_rate = learning_rate
-        self.n_estimators = n_estimators
         self.verbosity = verbosity
-        self.objective = objective
         self.booster = booster
         self.tree_method = tree_method
         self.gamma = gamma
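The new xgboost_model_doc decorator assembles each wrapper's docstring from the shared __estimator_doc, __model_doc and __custom_obj_note fragments, replacing the old pattern of splicing XGBModel.__doc__ by hand. A self-contained toy version of the same composition pattern (names here are illustrative, not the library's):

    _shared_doc = {
        'estimators': '    n_estimators : int\n        Number of boosting rounds.\n',
        'model': '    max_depth : int\n        Maximum tree depth.\n',
    }

    def model_doc(header, items, end_note=None):
        def adddoc(cls):
            # Compose: header, the common "Parameters" section, then any extra note.
            parts = [header + '\n\nParameters\n----------\n']
            parts.extend(_shared_doc[i] for i in items)
            if end_note:
                parts.append(end_note)
            cls.__doc__ = ''.join(parts)
            return cls
        return adddoc

    @model_doc('Toy regressor.', ['estimators', 'model'])
    class ToyRegressor:
        pass

    print(ToyRegressor.__doc__)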
@@ -176,6 +228,7 @@ class XGBModel(XGBModelBase):
         self._Booster = None
         self.random_state = random_state
         self.n_jobs = n_jobs
+        self.gpu_id = gpu_id
         self.importance_type = importance_type
 
     def __setstate__(self, state):
@@ -201,18 +254,22 @@ class XGBModel(XGBModelBase):
         return self._Booster
 
     def set_params(self, **params):
-        """Set the parameters of this estimator.
-        Modification of the sklearn method to allow unknown kwargs. This allows using
-        the full range of xgboost parameters that are not defined as member variables
-        in sklearn grid search.
+        """Set the parameters of this estimator.  Modification of the sklearn method to
+        allow unknown kwargs.  This allows using the full range of xgboost
+        parameters that are not defined as member variables in sklearn grid
+        search.
 
         Returns
         -------
         self
 
         """
         if not params:
             # Simple optimization to gain speed (inspect is slow)
             return self
 
+        # this concatenates kwargs into parameters, enabling `get_params` for
+        # obtaining parameters from keyword parameters.
         for key, value in params.items():
             if hasattr(self, key):
                 setattr(self, key, value)
@@ -221,16 +278,26 @@ class XGBModel(XGBModelBase):
 
         return self
 
-    def get_params(self, deep=False):
+    def get_params(self, deep=True):
+        # pylint: disable=attribute-defined-outside-init
         """Get parameters."""
-        params = super(XGBModel, self).get_params(deep=deep)
+        # Based on: https://stackoverflow.com/questions/59248211
+        # The basic flow in `get_params` is:
+        # 0. Return parameters in subclass first, by using inspect.
+        # 1. Return parameters in `XGBModel` (the base class).
+        # 2. Return whatever in `**kwargs`.
+        # 3. Merge them.
+        params = super().get_params(deep)
+        if hasattr(self, '__copy__'):
+            warnings.warn('Calling __copy__ on Scikit-Learn wrapper, ' +
+                          'which may disable data cache and result in ' +
+                          'lower performance.')
+        cp = copy.copy(self)
+        cp.__class__ = cp.__class__.__bases__[0]
+        params.update(cp.__class__.get_params(cp, deep))
         # if kwargs is a dict, update params accordingly
         if isinstance(self.kwargs, dict):
             params.update(self.kwargs)
-        if params['missing'] is np.nan:
-            params['missing'] = None  # sklearn doesn't handle nan. see #4725
-        if not params.get('eval_metric', True):
-            del params['eval_metric']  # don't give as None param to Booster
         if isinstance(params['random_state'], np.random.RandomState):
             params['random_state'] = params['random_state'].randint(
                 np.iinfo(np.int32).max)
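get_params now walks up to the base class by re-binding a shallow copy of the estimator to its parent, so parameters defined only on XGBModel are still reported from subclasses whose __init__ shrank to an objective plus **kwargs, and anything stored in **kwargs is merged into the result as well. A hedged usage sketch (assuming the behaviour shown in the hunk above):

    clf = xgb.XGBClassifier(n_estimators=10, eval_metric='auc')
    params = clf.get_params()
    # Parameters inherited from XGBModel appear with their new None defaults...
    assert params['max_depth'] is None
    # ...and keyword arguments kept in **kwargs are merged in too.
    assert params['eval_metric'] == 'auc'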
@@ -354,10 +421,11 @@ class XGBModel(XGBModelBase):
 
                 [xgb.callback.reset_learning_rate(custom_rates)]
         """
-        trainDmatrix = DMatrix(data=X, label=y, weight=sample_weight,
+        train_dmatrix = DMatrix(data=X, label=y, weight=sample_weight,
                                base_margin=base_margin,
                                missing=self.missing,
                                nthread=self.n_jobs)
 
         evals_result = {}
 
         if eval_set is not None:
@@ -389,7 +457,7 @@ class XGBModel(XGBModelBase):
         else:
             params.update({'eval_metric': eval_metric})
 
-        self._Booster = train(params, trainDmatrix,
+        self._Booster = train(params, train_dmatrix,
                               self.get_num_boosting_rounds(), evals=evals,
                               early_stopping_rounds=early_stopping_rounds,
                               evals_result=evals_result, obj=obj, feval=feval,
@@ -419,13 +487,6 @@ class XGBModel(XGBModelBase):
             If you want to run prediction using multiple thread, call ``xgb.copy()`` to make copies
             of model object and then call ``predict()``.
 
-        .. note:: Using ``predict()`` with DART booster
-
-          If the booster object is DART type, ``predict()`` will perform dropouts, i.e. only
-          some of the trees will be evaluated. This will produce incorrect results if ``data`` is
-          not the training data. To obtain correct results on test sets, set ``ntree_limit`` to
-          a nonzero value, e.g.
-
         .. code-block:: python
 
             preds = bst.predict(dtest, ntree_limit=num_round)
@@ -539,7 +600,8 @@ class XGBModel(XGBModelBase):
         feature_importances_ : array of shape ``[n_features]``
 
         """
-        if getattr(self, 'booster', None) is not None and self.booster not in {'gbtree', 'dart'}:
+        if getattr(self, 'booster', None) is not None and self.booster not in {
+                'gbtree', 'dart'}:
             raise AttributeError('Feature importance is not defined for Booster type {}'
                                  .format(self.booster))
         b = self.get_booster()
@@ -555,9 +617,9 @@ class XGBModel(XGBModelBase):
 
         .. note:: Coefficients are defined only for linear learners
 
-            Coefficients are only defined when the linear model is chosen as base
-            learner (`booster=gblinear`). It is not defined for other base learner types, such
-            as tree learners (`booster=gbtree`).
+            Coefficients are only defined when the linear model is chosen as
+            base learner (`booster=gblinear`). It is not defined for other base
+            learner types, such as tree learners (`booster=gbtree`).
 
         Returns
         -------
@@ -599,33 +661,13 @@ class XGBModel(XGBModelBase):
         return np.array(json.loads(b.get_dump(dump_format='json')[0])['bias'])
 
 
+@xgboost_model_doc(
+    "Implementation of the scikit-learn API for XGBoost classification.",
+    ['model', 'objective'])
 class XGBClassifier(XGBModel, XGBClassifierBase):
     # pylint: disable=missing-docstring,too-many-arguments,invalid-name,too-many-instance-attributes
-    __doc__ = "Implementation of the scikit-learn API for XGBoost classification.\n\n" \
-        + '\n'.join(XGBModel.__doc__.split('\n')[2:])
-
-    def __init__(self, max_depth=3, learning_rate=0.1,
-                 n_estimators=100, verbosity=1,
-                 objective="binary:logistic", booster='gbtree',
-                 tree_method='auto', n_jobs=1, gpu_id=-1, gamma=0,
-                 min_child_weight=1, max_delta_step=0, subsample=1,
-                 colsample_bytree=1, colsample_bylevel=1, colsample_bynode=1,
-                 reg_alpha=0, reg_lambda=1, scale_pos_weight=1, base_score=0.5,
-                 random_state=0, missing=None, **kwargs):
-        super(XGBClassifier, self).__init__(
-            max_depth=max_depth, learning_rate=learning_rate,
-            n_estimators=n_estimators, verbosity=verbosity,
-            objective=objective, booster=booster, tree_method=tree_method,
-            n_jobs=n_jobs, gpu_id=gpu_id, gamma=gamma,
-            min_child_weight=min_child_weight,
-            max_delta_step=max_delta_step, subsample=subsample,
-            colsample_bytree=colsample_bytree,
-            colsample_bylevel=colsample_bylevel,
-            colsample_bynode=colsample_bynode,
-            reg_alpha=reg_alpha, reg_lambda=reg_lambda,
-            scale_pos_weight=scale_pos_weight,
-            base_score=base_score, random_state=random_state, missing=missing,
-            **kwargs)
+    def __init__(self, objective="binary:logistic", **kwargs):
+        super().__init__(objective=objective, **kwargs)
 
     def fit(self, X, y, sample_weight=None, base_margin=None,
             eval_set=None, eval_metric=None,
@@ -647,8 +689,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
             obj = None
 
         if self.n_classes_ > 2:
-            # Switch to using a multiclass objective in the underlying XGB instance
-            xgb_options["objective"] = "multi:softprob"
+            # Switch to using a multiclass objective in the underlying
+            # XGB instance
+            xgb_options['objective'] = 'multi:softprob'
             xgb_options['num_class'] = self.n_classes_
 
         feval = eval_metric if callable(eval_metric) else None
@@ -665,7 +708,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
             if sample_weight_eval_set is None:
                 sample_weight_eval_set = [None] * len(eval_set)
             evals = list(
-                DMatrix(eval_set[i][0], label=self._le.transform(eval_set[i][1]),
+                DMatrix(eval_set[i][0],
+                        label=self._le.transform(eval_set[i][1]),
                         missing=self.missing, weight=sample_weight_eval_set[i],
                         nthread=self.n_jobs)
                 for i in range(len(eval_set))
@@ -686,8 +730,10 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
                                     base_margin=base_margin,
                                     missing=self.missing, nthread=self.n_jobs)
 
-        self._Booster = train(xgb_options, train_dmatrix, self.get_num_boosting_rounds(),
-                              evals=evals, early_stopping_rounds=early_stopping_rounds,
+        self._Booster = train(xgb_options, train_dmatrix,
+                              self.get_num_boosting_rounds(),
+                              evals=evals,
+                              early_stopping_rounds=early_stopping_rounds,
                               evals_result=evals_result, obj=obj, feval=feval,
                               verbose_eval=verbose, xgb_model=xgb_model,
                               callbacks=callbacks)
@@ -696,7 +742,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         if evals_result:
             for val in evals_result.items():
                 evals_result_key = list(val[1].keys())[0]
-                evals_result[val[0]][evals_result_key] = val[1][evals_result_key]
+                evals_result[val[0]][
+                    evals_result_key] = val[1][evals_result_key]
             self.evals_result_ = evals_result
 
         if early_stopping_rounds is not None:
@@ -706,7 +753,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
 
         return self
 
-    fit.__doc__ = XGBModel.fit.__doc__.replace('Fit gradient boosting model',
+    fit.__doc__ = XGBModel.fit.__doc__.replace(
+        'Fit gradient boosting model',
         'Fit gradient boosting classifier', 1)
 
     def predict(self, data, output_margin=False, ntree_limit=None,
@@ -717,15 +765,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         .. note:: This function is not thread safe.
 
             For each booster object, predict can only be called from one thread.
-            If you want to run prediction using multiple thread, call ``xgb.copy()`` to make copies
-            of model object and then call ``predict()``.
-
-        .. note:: Using ``predict()`` with DART booster
-
-          If the booster object is DART type, ``predict()`` will perform dropouts, i.e. only
-          some of the trees will be evaluated. This will produce incorrect results if ``data`` is
-          not the training data. To obtain correct results on test sets, set ``ntree_limit`` to
-          a nonzero value, e.g.
+            If you want to run prediction using multiple thread, call
+            ``xgb.copy()`` to make copies of model object and then call
+            ``predict()``.
 
         .. code-block:: python
 
@@ -738,11 +780,14 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         output_margin : bool
             Whether to output the raw untransformed margin value.
         ntree_limit : int
-            Limit number of trees in the prediction; defaults to best_ntree_limit if defined
-            (i.e. it has been trained with early stopping), otherwise 0 (use all trees).
+            Limit number of trees in the prediction; defaults to
+            best_ntree_limit if defined (i.e. it has been trained with early
+            stopping), otherwise 0 (use all trees).
         validate_features : bool
-            When this is True, validate that the Booster's and data's feature_names are identical.
-            Otherwise, it is assumed that the feature_names are the same.
+            When this is True, validate that the Booster's and data's
+            feature_names are identical.  Otherwise, it is assumed that the
+            feature_names are the same.
 
         Returns
         -------
         prediction : numpy array
@@ -773,9 +818,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
 
         .. note:: This function is not thread safe
 
-            For each booster object, predict can only be called from one thread.
-            If you want to run prediction using multiple thread, call ``xgb.copy()`` to make copies
-            of model object and then call predict
+            For each booster object, predict can only be called from one
+            thread.  If you want to run prediction using multiple thread, call
+            ``xgb.copy()`` to make copies of model object and then call predict
 
         Parameters
         ----------
@@ -849,29 +894,25 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         return evals_result
 
 
+@xgboost_model_doc(
+    "scikit-learn API for XGBoost random forest classification.",
+    ['model', 'objective'],
+    extra_parameters='''
+    n_estimators : int
+        Number of trees in random forest to fit.
+''')
 class XGBRFClassifier(XGBClassifier):
     # pylint: disable=missing-docstring
-    __doc__ = "scikit-learn API for XGBoost random forest classification.\n\n"\
-        + '\n'.join(XGBModel.__doc__.split('\n')[2:])
-
-    def __init__(self, max_depth=3, learning_rate=1, n_estimators=100,
-                 verbosity=1, objective="binary:logistic", n_jobs=1,
-                 gpu_id=-1, gamma=0, min_child_weight=1, max_delta_step=0,
-                 subsample=0.8, colsample_bytree=1, colsample_bylevel=1,
-                 colsample_bynode=0.8, reg_alpha=0, reg_lambda=1e-5,
-                 scale_pos_weight=1, base_score=0.5, random_state=0,
-                 missing=None, **kwargs):
-        super(XGBRFClassifier, self).__init__(
-            max_depth=max_depth, learning_rate=learning_rate,
-            n_estimators=n_estimators, verbosity=verbosity,
-            objective=objective, booster='gbtree', n_jobs=n_jobs,
-            gpu_id=gpu_id, gamma=gamma, min_child_weight=min_child_weight,
-            max_delta_step=max_delta_step,
-            subsample=subsample, colsample_bytree=colsample_bytree,
-            colsample_bylevel=colsample_bylevel,
-            colsample_bynode=colsample_bynode, reg_alpha=reg_alpha,
-            reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
-            base_score=base_score, random_state=random_state, missing=missing,
-            **kwargs)
+    def __init__(self,
+                 learning_rate=1,
+                 subsample=0.8,
+                 colsample_bynode=0.8,
+                 reg_lambda=1e-5,
+                 **kwargs):
+        super().__init__(learning_rate=learning_rate,
+                         subsample=subsample,
+                         colsample_bynode=colsample_bynode,
+                         reg_lambda=reg_lambda,
+                         **kwargs)
 
     def get_xgb_params(self):
@@ -883,37 +924,25 @@ class XGBRFClassifier(XGBClassifier):
         return 1
 
 
+@xgboost_model_doc(
+    "Implementation of the scikit-learn API for XGBoost regression.",
+    ['estimators', 'model', 'objective'])
 class XGBRegressor(XGBModel, XGBRegressorBase):
     # pylint: disable=missing-docstring
-    __doc__ = "Implementation of the scikit-learn API for XGBoost regression.\n\n"\
-        + '\n'.join(XGBModel.__doc__.split('\n')[2:])
+    def __init__(self, objective="reg:squarederror", **kwargs):
+        super().__init__(objective=objective, **kwargs)
 
 
+@xgboost_model_doc(
+    "scikit-learn API for XGBoost random forest regression.",
+    ['model', 'objective'])
 class XGBRFRegressor(XGBRegressor):
     # pylint: disable=missing-docstring
-    __doc__ = "scikit-learn API for XGBoost random forest regression.\n\n"\
-        + '\n'.join(XGBModel.__doc__.split('\n')[2:])
-
-    def __init__(self, max_depth=3, learning_rate=1, n_estimators=100,
-                 verbosity=1, objective="reg:squarederror", n_jobs=1,
-                 gpu_id=-1, gamma=0, min_child_weight=1,
-                 max_delta_step=0, subsample=0.8, colsample_bytree=1,
-                 colsample_bylevel=1, colsample_bynode=0.8, reg_alpha=0,
-                 reg_lambda=1e-5, scale_pos_weight=1, base_score=0.5,
-                 random_state=0, missing=None, **kwargs):
-        super(XGBRFRegressor, self).__init__(
-            max_depth=max_depth, learning_rate=learning_rate,
-            n_estimators=n_estimators, verbosity=verbosity,
-            objective=objective, booster='gbtree', n_jobs=n_jobs,
-            gpu_id=gpu_id, gamma=gamma, min_child_weight=min_child_weight,
-            max_delta_step=max_delta_step, subsample=subsample,
-            colsample_bytree=colsample_bytree,
-            colsample_bylevel=colsample_bylevel,
-            colsample_bynode=colsample_bynode,
-            reg_alpha=reg_alpha, reg_lambda=reg_lambda,
-            scale_pos_weight=scale_pos_weight,
-            base_score=base_score, random_state=random_state, missing=missing,
-            **kwargs)
+    def __init__(self, learning_rate=1, subsample=0.8, colsample_bynode=0.8,
+                 reg_lambda=1e-5, **kwargs):
+        super().__init__(learning_rate=learning_rate, subsample=subsample,
+                         colsample_bynode=colsample_bynode,
+                         reg_lambda=reg_lambda, **kwargs)
 
     def get_xgb_params(self):
         params = super(XGBRFRegressor, self).get_xgb_params()
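With the defaults gone from XGBModel, each concrete wrapper above only pins the few values it genuinely needs (the objective, or the random-forest specific rates) and forwards everything else through **kwargs. A hypothetical custom wrapper following the same pattern might look like this (sketch only, not part of the diff):

    @xgboost_model_doc(
        'A hypothetical custom wrapper built on the new base class.',
        ['estimators', 'model'])
    class MyXGBEstimator(XGBModel):
        # pylint: disable=missing-docstring
        def __init__(self, objective='reg:squarederror', **kwargs):
            super().__init__(objective=objective, **kwargs)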
@@ -924,71 +953,10 @@ class XGBRFRegressor(XGBRegressor):
         return 1
 
 
-class XGBRanker(XGBModel):
-    # pylint: disable=missing-docstring,too-many-arguments,invalid-name
-    """Implementation of the Scikit-Learn API for XGBoost Ranking.
-
-    Parameters
-    ----------
-    max_depth : int
-        Maximum tree depth for base learners.
-    learning_rate : float
-        Boosting learning rate (xgb's "eta")
-    n_estimators : int
-        Number of boosted trees to fit.
-    verbosity : int
-        The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
-    objective : string
-        Specify the learning task and the corresponding learning objective.
-        The objective name must start with "rank:".
-    booster: string
-        Specify which booster to use: gbtree, gblinear or dart.
-    n_jobs : int
-        Number of parallel threads used to run xgboost.
-    gamma : float
-        Minimum loss reduction required to make a further partition on a leaf node of the tree.
-    min_child_weight : int
-        Minimum sum of instance weight(hessian) needed in a child.
-    max_delta_step : int
-        Maximum delta step we allow each tree's weight estimation to be.
-    subsample : float
-        Subsample ratio of the training instance.
-    colsample_bytree : float
-        Subsample ratio of columns when constructing each tree.
-    colsample_bylevel : float
-        Subsample ratio of columns for each level.
-    colsample_bynode : float
-        Subsample ratio of columns for each split.
-    reg_alpha : float (xgb's alpha)
-        L1 regularization term on weights
-    reg_lambda : float (xgb's lambda)
-        L2 regularization term on weights
-    scale_pos_weight : float
-        Balancing of positive and negative weights.
-    base_score:
-        The initial prediction score of all instances, global bias.
-    random_state : int
-        Random number seed.
-
-        .. note::
-
-           Using gblinear booster with shotgun updater is nondeterministic as
-           it uses Hogwild algorithm.
-
-    missing : float, optional
-        Value in the data which needs to be present as a missing value. If
-        None, defaults to np.nan.
-    \*\*kwargs : dict, optional
-        Keyword arguments for XGBoost Booster object.  Full documentation of parameters can
-        be found here: https://github.com/dmlc/xgboost/blob/master/doc/parameter.rst.
-        Attempting to set a parameter via the constructor args and \*\*kwargs dict
-        simultaneously will result in a TypeError.
-
-        .. note:: \*\*kwargs unsupported by scikit-learn
-
-            \*\*kwargs is unsupported by scikit-learn.  We do not guarantee that parameters
-            passed via this argument will interact properly with scikit-learn.
-
+@xgboost_model_doc(
+    'Implementation of the Scikit-Learn API for XGBoost Ranking.',
+    ['estimators', 'model'],
+    end_note='''
         Note
         ----
         A custom objective function is currently not supported by XGBRanker.
@@ -998,9 +966,9 @@ class XGBRanker(XGBModel):
         ----
         Query group information is required for ranking tasks.
 
-        Before fitting the model, your data need to be sorted by query group. When
-        fitting the model, you need to provide an additional array that
-        contains the size of each query group.
+        Before fitting the model, your data need to be sorted by query
+        group. When fitting the model, you need to provide an additional array
+        that contains the size of each query group.
 
         For example, if your original data look like:
 
@@ -1023,29 +991,11 @@ class XGBRanker(XGBModel):
         +-------+-----------+---------------+
 
         then your group array should be ``[3, 4]``.
-    """
-
-    def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
-                 verbosity=1, objective="rank:pairwise", booster='gbtree',
-                 tree_method='auto', n_jobs=-1, gpu_id=-1, gamma=0,
-                 min_child_weight=1, max_delta_step=0, subsample=1,
-                 colsample_bytree=1, colsample_bylevel=1, colsample_bynode=1,
-                 reg_alpha=0, reg_lambda=1, scale_pos_weight=1, base_score=0.5,
-                 random_state=0, missing=None, **kwargs):
-
-        super(XGBRanker, self).__init__(
-            max_depth=max_depth, learning_rate=learning_rate,
-            n_estimators=n_estimators, verbosity=verbosity,
-            objective=objective, booster=booster, tree_method=tree_method,
-            n_jobs=n_jobs, gpu_id=gpu_id, gamma=gamma,
-            min_child_weight=min_child_weight, max_delta_step=max_delta_step,
-            subsample=subsample, colsample_bytree=colsample_bytree,
-            colsample_bylevel=colsample_bylevel,
-            colsample_bynode=colsample_bynode, reg_alpha=reg_alpha,
-            reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
-            base_score=base_score, random_state=random_state, missing=missing,
-            **kwargs)
+''')
+class XGBRanker(XGBModel):
+    # pylint: disable=missing-docstring,too-many-arguments,invalid-name
+    def __init__(self, objective='rank:pairwise', **kwargs):
+        super().__init__(objective=objective, **kwargs)
         if callable(self.objective):
             raise ValueError(
                 "custom objective function not supported by XGBRanker")
@@ -1058,8 +1008,7 @@ class XGBRanker(XGBModel):
             early_stopping_rounds=None, verbose=False, xgb_model=None,
             callbacks=None):
         # pylint: disable = attribute-defined-outside-init,arguments-differ
-        """
-        Fit gradient boosting ranker
+        """Fit gradient boosting ranker
 
         Parameters
         ----------
|
|||||||
y : array_like
|
y : array_like
|
||||||
Labels
|
Labels
|
||||||
group : array_like
|
group : array_like
|
||||||
Size of each query group of training data. Should have as many elements as
|
Size of each query group of training data. Should have as many
|
||||||
the query groups in the training data
|
elements as the query groups in the training data
|
||||||
sample_weight : array_like
|
sample_weight : array_like
|
||||||
Query group weights
|
Query group weights
|
||||||
|
|
||||||
.. note:: Weights are per-group for ranking tasks
|
.. note:: Weights are per-group for ranking tasks
|
||||||
|
|
||||||
In ranking task, one weight is assigned to each query group (not each
|
In ranking task, one weight is assigned to each query group
|
||||||
data point). This is because we only care about the relative ordering of
|
(not each data point). This is because we only care about the
|
||||||
data points within each group, so it doesn't make sense to assign
|
relative ordering of data points within each group, so it
|
||||||
weights to individual data points.
|
doesn't make sense to assign weights to individual data points.
|
||||||
|
|
||||||
base_margin : array_like
|
base_margin : array_like
|
||||||
Global bias for each instance.
|
Global bias for each instance.
|
||||||
@ -1112,24 +1061,26 @@ class XGBRanker(XGBModel):
|
|||||||
The method returns the model from the last iteration (not the best one).
|
The method returns the model from the last iteration (not the best one).
|
||||||
If there's more than one item in **eval_set**, the last entry will be used
|
If there's more than one item in **eval_set**, the last entry will be used
|
||||||
for early stopping.
|
for early stopping.
|
||||||
If there's more than one metric in **eval_metric**, the last metric will be
|
If there's more than one metric in **eval_metric**, the last metric
|
||||||
used for early stopping.
|
will be used for early stopping.
|
||||||
If early stopping occurs, the model will have three additional fields:
|
If early stopping occurs, the model will have three additional
|
||||||
``clf.best_score``, ``clf.best_iteration`` and ``clf.best_ntree_limit``.
|
fields: ``clf.best_score``, ``clf.best_iteration`` and
|
||||||
|
``clf.best_ntree_limit``.
|
||||||
verbose : bool
|
verbose : bool
|
||||||
If `verbose` and an evaluation set is used, writes the evaluation
|
If `verbose` and an evaluation set is used, writes the evaluation
|
||||||
metric measured on the validation set to stderr.
|
metric measured on the validation set to stderr.
|
||||||
xgb_model : str
|
xgb_model : str
|
||||||
file name of stored XGBoost model or 'Booster' instance XGBoost model to be
|
file name of stored XGBoost model or 'Booster' instance XGBoost
|
||||||
loaded before training (allows training continuation).
|
model to be loaded before training (allows training continuation).
|
||||||
callbacks : list of callback functions
|
callbacks : list of callback functions
|
||||||
List of callback functions that are applied at end of each iteration.
|
List of callback functions that are applied at end of each
|
||||||
It is possible to use predefined callbacks by using :ref:`callback_api`.
|
iteration. It is possible to use predefined callbacks by using
|
||||||
Example:
|
:ref:`callback_api`. Example:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
[xgb.callback.reset_learning_rate(custom_rates)]
|
[xgb.callback.reset_learning_rate(custom_rates)]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# check if group information is provided
|
# check if group information is provided
|
||||||
if group is None:
|
if group is None:
|
||||||
@ -1137,11 +1088,14 @@ class XGBRanker(XGBModel):
|
|||||||
|
|
||||||
if eval_set is not None:
|
if eval_set is not None:
|
||||||
if eval_group is None:
|
if eval_group is None:
|
||||||
raise ValueError("eval_group is required if eval_set is not None")
|
raise ValueError(
|
||||||
|
"eval_group is required if eval_set is not None")
|
||||||
if len(eval_group) != len(eval_set):
|
if len(eval_group) != len(eval_set):
|
||||||
raise ValueError("length of eval_group should match that of eval_set")
|
raise ValueError(
|
||||||
|
"length of eval_group should match that of eval_set")
|
||||||
if any(group is None for group in eval_group):
|
if any(group is None for group in eval_group):
|
||||||
raise ValueError("group is required for all eval datasets for ranking task")
|
raise ValueError(
|
||||||
|
"group is required for all eval datasets for ranking task")
|
||||||
|
|
||||||
def _dmat_init(group, **params):
|
def _dmat_init(group, **params):
|
||||||
ret = DMatrix(**params)
|
ret = DMatrix(**params)
|
||||||
@ -1158,9 +1112,13 @@ class XGBRanker(XGBModel):
|
|||||||
if eval_set is not None:
|
if eval_set is not None:
|
||||||
if sample_weight_eval_set is None:
|
if sample_weight_eval_set is None:
|
||||||
sample_weight_eval_set = [None] * len(eval_set)
|
sample_weight_eval_set = [None] * len(eval_set)
|
||||||
evals = [_dmat_init(eval_group[i], data=eval_set[i][0], label=eval_set[i][1],
|
evals = [_dmat_init(eval_group[i],
|
||||||
missing=self.missing, weight=sample_weight_eval_set[i],
|
data=eval_set[i][0],
|
||||||
nthread=self.n_jobs) for i in range(len(eval_set))]
|
label=eval_set[i][1],
|
||||||
|
missing=self.missing,
|
||||||
|
weight=sample_weight_eval_set[i],
|
||||||
|
nthread=self.n_jobs)
|
||||||
|
for i in range(len(eval_set))]
|
||||||
nevals = len(evals)
|
nevals = len(evals)
|
||||||
eval_names = ["eval_{}".format(i) for i in range(nevals)]
|
eval_names = ["eval_{}".format(i) for i in range(nevals)]
|
||||||
evals = list(zip(evals, eval_names))
|
evals = list(zip(evals, eval_names))
|
||||||
@ -1172,13 +1130,14 @@ class XGBRanker(XGBModel):
|
|||||||
feval = eval_metric if callable(eval_metric) else None
|
feval = eval_metric if callable(eval_metric) else None
|
||||||
if eval_metric is not None:
|
if eval_metric is not None:
|
||||||
if callable(eval_metric):
|
if callable(eval_metric):
|
||||||
raise ValueError('Custom evaluation metric is not yet supported' +
|
raise ValueError(
|
||||||
'for XGBRanker.')
|
'Custom evaluation metric is not yet supported for XGBRanker.')
|
||||||
params.update({'eval_metric': eval_metric})
|
params.update({'eval_metric': eval_metric})
|
||||||
|
|
||||||
self._Booster = train(params, train_dmatrix,
|
self._Booster = train(params, train_dmatrix,
|
||||||
self.n_estimators,
|
self.n_estimators,
|
||||||
early_stopping_rounds=early_stopping_rounds, evals=evals,
|
early_stopping_rounds=early_stopping_rounds,
|
||||||
|
evals=evals,
|
||||||
evals_result=evals_result, feval=feval,
|
evals_result=evals_result, feval=feval,
|
||||||
verbose_eval=verbose, xgb_model=xgb_model,
|
verbose_eval=verbose, xgb_model=xgb_model,
|
||||||
callbacks=callbacks)
|
callbacks=callbacks)
|
||||||
|
|||||||
@@ -3,7 +3,8 @@
 # pylint: disable=too-many-branches, too-many-statements
 """Training Library containing training routines."""
 import numpy as np
-from .core import Booster, STRING_TYPES, XGBoostError, CallbackEnv, EarlyStopException
+from .core import Booster, STRING_TYPES, XGBoostError, CallbackEnv
+from .core import EarlyStopException
 from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
 from . import rabit
 from . import callback
@@ -37,10 +38,11 @@ def _train_internal(params, dtrain,
 
     _params = dict(params) if isinstance(params, list) else params
 
-    if 'num_parallel_tree' in _params:
+    if 'num_parallel_tree' in _params and params[
+            'num_parallel_tree'] is not None:
         num_parallel_tree = _params['num_parallel_tree']
         nboost //= num_parallel_tree
-    if 'num_class' in _params:
+    if 'num_class' in _params and _params['num_class'] is not None:
         nboost //= _params['num_class']
 
     # Distributed code: Load the checkpoint from rabit.
@@ -22,17 +22,17 @@ class TestEarlyStopping(unittest.TestCase):
         y = digits['target']
         X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                             random_state=0)
-        clf1 = xgb.XGBClassifier()
+        clf1 = xgb.XGBClassifier(learning_rate=0.1)
         clf1.fit(X_train, y_train, early_stopping_rounds=5, eval_metric="auc",
                  eval_set=[(X_test, y_test)])
-        clf2 = xgb.XGBClassifier()
+        clf2 = xgb.XGBClassifier(learning_rate=0.1)
         clf2.fit(X_train, y_train, early_stopping_rounds=4, eval_metric="auc",
                  eval_set=[(X_test, y_test)])
         # should be the same
         assert clf1.best_score == clf2.best_score
         assert clf1.best_score != 1
         # check overfit
-        clf3 = xgb.XGBClassifier()
+        clf3 = xgb.XGBClassifier(learning_rate=0.1)
         clf3.fit(X_train, y_train, early_stopping_rounds=10, eval_metric="auc",
                  eval_set=[(X_test, y_test)])
         assert clf3.best_score == 1
@@ -6,6 +6,7 @@ import os
 import shutil
 import pytest
 import unittest
+import json
 
 rng = np.random.RandomState(1994)
 
@@ -117,9 +118,10 @@ def test_feature_importances_weight():
     digits = load_digits(2)
     y = digits['target']
     X = digits['data']
-    xgb_model = xgb.XGBClassifier(
-        random_state=0, tree_method="exact", importance_type="weight").fit(X, y)
+    xgb_model = xgb.XGBClassifier(random_state=0,
+                                  tree_method="exact",
+                                  learning_rate=0.1,
+                                  importance_type="weight").fit(X, y)
 
     exp = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.00833333, 0.,
                     0., 0., 0., 0., 0., 0., 0., 0.025, 0.14166667, 0., 0., 0.,
                     0., 0., 0., 0.00833333, 0.25833333, 0., 0., 0., 0.,
@@ -134,12 +136,16 @@ def test_feature_importances_weight():
     import pandas as pd
     y = pd.Series(digits['target'])
     X = pd.DataFrame(digits['data'])
-    xgb_model = xgb.XGBClassifier(
-        random_state=0, tree_method="exact", importance_type="weight").fit(X, y)
+    xgb_model = xgb.XGBClassifier(random_state=0,
+                                  tree_method="exact",
+                                  learning_rate=0.1,
+                                  importance_type="weight").fit(X, y)
     np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
 
-    xgb_model = xgb.XGBClassifier(
-        random_state=0, tree_method="exact", importance_type="weight").fit(X, y)
+    xgb_model = xgb.XGBClassifier(random_state=0,
+                                  tree_method="exact",
+                                  learning_rate=0.1,
+                                  importance_type="weight").fit(X, y)
     np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
 
 
|
|||||||
y = digits['target']
|
y = digits['target']
|
||||||
X = digits['data']
|
X = digits['data']
|
||||||
xgb_model = xgb.XGBClassifier(
|
xgb_model = xgb.XGBClassifier(
|
||||||
random_state=0, tree_method="exact", importance_type="gain").fit(X, y)
|
random_state=0, tree_method="exact",
|
||||||
|
learning_rate=0.1,
|
||||||
|
importance_type="gain").fit(X, y)
|
||||||
|
|
||||||
exp = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
|
exp = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
|
||||||
0.00326159, 0., 0., 0., 0., 0., 0., 0., 0.,
|
0.00326159, 0., 0., 0., 0., 0., 0., 0., 0.,
|
||||||
@@ -169,11 +177,15 @@ def test_feature_importances_gain():
     y = pd.Series(digits['target'])
     X = pd.DataFrame(digits['data'])
     xgb_model = xgb.XGBClassifier(
-        random_state=0, tree_method="exact", importance_type="gain").fit(X, y)
+        random_state=0, tree_method="exact",
+        learning_rate=0.1,
+        importance_type="gain").fit(X, y)
     np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
 
     xgb_model = xgb.XGBClassifier(
-        random_state=0, tree_method="exact", importance_type="gain").fit(X, y)
+        random_state=0, tree_method="exact",
+        learning_rate=0.1,
+        importance_type="gain").fit(X, y)
     np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
 
 
@@ -191,6 +203,10 @@ def test_num_parallel_tree():
     dump = bst.get_booster().get_dump(dump_format='json')
     assert len(dump) == 4
 
+    config = json.loads(bst.get_booster().save_config())
+    assert int(config['learner']['gradient_booster']['gbtree_train_param'][
+        'num_parallel_tree']) == 4
+
 
 def test_boston_housing_regression():
     from sklearn.metrics import mean_squared_error
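Because the wrapper no longer pins defaults, the num_parallel_tree assertion added above checks the value the native booster actually used by reading its saved JSON configuration rather than a Python attribute. A sketch of the same inspection (assuming the Booster.save_config() API used in the test; X and y are any small training set):

    import json

    model = xgb.XGBClassifier(n_estimators=1, num_parallel_tree=4).fit(X, y)
    config = json.loads(model.get_booster().save_config())
    print(config['learner']['gradient_booster']['gbtree_train_param']['num_parallel_tree'])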
@@ -244,7 +260,7 @@ def test_parameter_tuning():
     boston = load_boston()
     y = boston['target']
     X = boston['data']
-    xgb_model = xgb.XGBRegressor()
+    xgb_model = xgb.XGBRegressor(learning_rate=0.1)
     clf = GridSearchCV(xgb_model, {'max_depth': [2, 4, 6],
                                    'n_estimators': [50, 100, 200]},
                        cv=3, verbose=1, iid=True)