[Breaking] Update sklearn interface. (#4929)
* Remove nthread, seed, silent. Add tree_method, gpu_id, num_parallel_tree. Fix #4909.
* Check data shape. Fix #4896.
* Check element of eval_set is tuple. Fix #4875.
* Add doc for random_state with hogwild. Fixes #4919.
This commit is contained in:
@@ -62,17 +62,18 @@ class XGBModel(XGBModelBase):
|
||||
Number of trees to fit.
|
||||
verbosity : int
|
||||
The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
|
||||
silent : boolean
|
||||
Whether to print messages while running boosting. Deprecated. Use verbosity instead.
|
||||
objective : string or callable
|
||||
Specify the learning task and the corresponding learning objective or
|
||||
a custom objective function to be used (see note below).
|
||||
booster: string
|
||||
Specify which booster to use: gbtree, gblinear or dart.
|
||||
nthread : int
|
||||
Number of parallel threads used to run xgboost. (Deprecated, please use ``n_jobs``)
|
||||
tree_method: string
|
||||
Specify which tree method to use. Default to auto. If this parameter
|
||||
is set to default, XGBoost will choose the most conservative option
|
||||
available. It's recommended to study this option from parameters
|
||||
document.
|
||||
n_jobs : int
|
||||
Number of parallel threads used to run xgboost. (replaces ``nthread``)
|
||||
Number of parallel threads used to run xgboost.
|
||||
gamma : float
|
||||
Minimum loss reduction required to make a further partition on a leaf node of the tree.
|
||||
min_child_weight : int
|
||||
@@ -95,13 +96,17 @@ class XGBModel(XGBModelBase):
|
||||
Balancing of positive and negative weights.
|
||||
base_score:
|
||||
The initial prediction score of all instances, global bias.
|
||||
seed : int
|
||||
Random number seed. (Deprecated, please use random_state)
|
||||
random_state : int
|
||||
Random number seed. (replaces seed)
|
||||
Random number seed.
|
||||
|
||||
.. note:: Using gblinear booster with shotgun updater is
|
||||
nondeterministic as it uses Hogwild algorithm.
|
||||
|
||||
missing : float, optional
|
||||
Value in the data which needs to be present as a missing value. If
|
||||
None, defaults to np.nan.
|
||||
num_parallel_tree: int
|
||||
Used for boosting random forest.
|
||||
importance_type: string, default "gain"
|
||||
The feature importance type for the feature_importances\\_ property:
|
||||
either "gain", "weight", "cover", "total_gain" or "total_cover".
|
||||
@@ -131,25 +136,27 @@ class XGBModel(XGBModelBase):
|
||||
The value of the gradient for each sample point.
|
||||
hess: array_like of shape [n_samples]
|
||||
The value of the second derivative for each sample point
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
|
||||
verbosity=1, silent=None, objective="reg:squarederror",
|
||||
booster='gbtree', n_jobs=1, nthread=None, gamma=0,
|
||||
verbosity=1, objective="reg:squarederror",
|
||||
booster='gbtree', tree_method='auto', n_jobs=1, gamma=0,
|
||||
min_child_weight=1, max_delta_step=0, subsample=1,
|
||||
colsample_bytree=1, colsample_bylevel=1, colsample_bynode=1,
|
||||
reg_alpha=0, reg_lambda=1, scale_pos_weight=1, base_score=0.5,
|
||||
random_state=0, seed=None, missing=None,
|
||||
random_state=0, missing=None, num_parallel_tree=1,
|
||||
importance_type="gain", **kwargs):
|
||||
if not SKLEARN_INSTALLED:
|
||||
raise XGBoostError('sklearn needs to be installed in order to use this module')
|
||||
raise XGBoostError(
|
||||
'sklearn needs to be installed in order to use this module')
|
||||
self.max_depth = max_depth
|
||||
self.learning_rate = learning_rate
|
||||
self.n_estimators = n_estimators
|
||||
self.verbosity = verbosity
|
||||
self.silent = silent
|
||||
self.objective = objective
|
||||
self.booster = booster
|
||||
self.tree_method = tree_method
|
||||
self.gamma = gamma
|
||||
self.min_child_weight = min_child_weight
|
||||
self.max_delta_step = max_delta_step
|
||||
@@ -162,11 +169,10 @@ class XGBModel(XGBModelBase):
|
||||
self.scale_pos_weight = scale_pos_weight
|
||||
self.base_score = base_score
|
||||
self.missing = missing if missing is not None else np.nan
|
||||
self.num_parallel_tree = num_parallel_tree
|
||||
self.kwargs = kwargs
|
||||
self._Booster = None
|
||||
self.seed = seed
|
||||
self.random_state = random_state
|
||||
self.nthread = nthread
|
||||
self.n_jobs = n_jobs
|
||||
self.importance_type = importance_type
|
||||
|
||||
@@ -227,33 +233,6 @@ class XGBModel(XGBModelBase):
|
||||
def get_xgb_params(self):
|
||||
"""Get xgboost type parameters."""
|
||||
xgb_params = self.get_params()
|
||||
random_state = xgb_params.pop('random_state')
|
||||
if 'seed' in xgb_params and xgb_params['seed'] is not None:
|
||||
warnings.warn('The seed parameter is deprecated as of version .6.'
|
||||
'Please use random_state instead.'
|
||||
'seed is deprecated.', DeprecationWarning)
|
||||
else:
|
||||
xgb_params['seed'] = random_state
|
||||
n_jobs = xgb_params.pop('n_jobs')
|
||||
if 'nthread' in xgb_params and xgb_params['nthread'] is not None:
|
||||
warnings.warn('The nthread parameter is deprecated as of version .6.'
|
||||
'Please use n_jobs instead.'
|
||||
'nthread is deprecated.', DeprecationWarning)
|
||||
else:
|
||||
xgb_params['nthread'] = n_jobs
|
||||
|
||||
if 'silent' in xgb_params and xgb_params['silent'] is not None:
|
||||
warnings.warn('The silent parameter is deprecated.'
|
||||
'Please use verbosity instead.'
|
||||
'silent is depreated', DeprecationWarning)
|
||||
# TODO(canonizer): set verbosity explicitly if silent is removed from xgboost,
|
||||
# but remains in this API
|
||||
else:
|
||||
# silent=None shouldn't be passed to xgboost
|
||||
xgb_params.pop('silent', None)
|
||||
|
||||
if xgb_params['nthread'] <= 0:
|
||||
xgb_params.pop('nthread', None)
|
||||
return xgb_params
|
||||
|
||||
def get_num_boosting_rounds(self):
|
||||
@@ -301,7 +280,7 @@ class XGBModel(XGBModelBase):
|
||||
Input file name or memory buffer(see also save_raw)
|
||||
"""
|
||||
if self._Booster is None:
|
||||
self._Booster = Booster({'nthread': self.n_jobs})
|
||||
self._Booster = Booster({'n_jobs': self.n_jobs})
|
||||
self._Booster.load_model(fname)
|
||||
|
||||
def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
|
||||
@@ -364,13 +343,17 @@ class XGBModel(XGBModelBase):
|
||||
"""
|
||||
if sample_weight is not None:
|
||||
trainDmatrix = DMatrix(X, label=y, weight=sample_weight,
|
||||
missing=self.missing, nthread=self.n_jobs)
|
||||
missing=self.missing,
|
||||
nthread=self.n_jobs)
|
||||
else:
|
||||
trainDmatrix = DMatrix(X, label=y, missing=self.missing, nthread=self.n_jobs)
|
||||
trainDmatrix = DMatrix(X, label=y, missing=self.missing,
|
||||
nthread=self.n_jobs)
|
||||
|
||||
evals_result = {}
|
||||
|
||||
if eval_set is not None:
|
||||
if not isinstance(eval_set[0], (list, tuple)):
|
||||
raise TypeError('Unexpected input type for `eval_set`')
|
||||
if sample_weight_eval_set is None:
|
||||
sample_weight_eval_set = [None] * len(eval_set)
|
||||
evals = list(
|
||||
@@ -610,22 +593,27 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
__doc__ = "Implementation of the scikit-learn API for XGBoost classification.\n\n" \
|
||||
+ '\n'.join(XGBModel.__doc__.split('\n')[2:])
|
||||
|
||||
def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
|
||||
verbosity=1, silent=None,
|
||||
def __init__(self, max_depth=3, learning_rate=0.1,
|
||||
n_estimators=100, verbosity=1,
|
||||
objective="binary:logistic", booster='gbtree',
|
||||
n_jobs=1, nthread=None, gamma=0, min_child_weight=1, max_delta_step=0,
|
||||
subsample=1, colsample_bytree=1, colsample_bylevel=1,
|
||||
colsample_bynode=1, reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
|
||||
base_score=0.5, random_state=0, seed=None, missing=None, **kwargs):
|
||||
tree_method='auto', n_jobs=1, gpu_id=-1, gamma=0,
|
||||
min_child_weight=1, max_delta_step=0, subsample=1,
|
||||
colsample_bytree=1, colsample_bylevel=1, colsample_bynode=1,
|
||||
reg_alpha=0, reg_lambda=1, scale_pos_weight=1, base_score=0.5,
|
||||
random_state=0, missing=None, **kwargs):
|
||||
super(XGBClassifier, self).__init__(
|
||||
max_depth=max_depth, learning_rate=learning_rate, n_estimators=n_estimators,
|
||||
verbosity=verbosity, silent=silent, objective=objective, booster=booster,
|
||||
n_jobs=n_jobs, nthread=nthread, gamma=gamma,
|
||||
min_child_weight=min_child_weight, max_delta_step=max_delta_step,
|
||||
subsample=subsample, colsample_bytree=colsample_bytree,
|
||||
colsample_bylevel=colsample_bylevel, colsample_bynode=colsample_bynode,
|
||||
reg_alpha=reg_alpha, reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
|
||||
base_score=base_score, random_state=random_state, seed=seed, missing=missing,
|
||||
max_depth=max_depth, learning_rate=learning_rate,
|
||||
n_estimators=n_estimators, verbosity=verbosity,
|
||||
objective=objective, booster=booster, tree_method=tree_method,
|
||||
n_jobs=n_jobs, gpu_id=gpu_id, gamma=gamma,
|
||||
min_child_weight=min_child_weight,
|
||||
max_delta_step=max_delta_step, subsample=subsample,
|
||||
colsample_bytree=colsample_bytree,
|
||||
colsample_bylevel=colsample_bylevel,
|
||||
colsample_bynode=colsample_bynode,
|
||||
reg_alpha=reg_alpha, reg_lambda=reg_lambda,
|
||||
scale_pos_weight=scale_pos_weight,
|
||||
base_score=base_score, random_state=random_state, missing=missing,
|
||||
**kwargs)
|
||||
|
||||
def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
|
||||
@@ -676,6 +664,11 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
else:
|
||||
evals = ()
|
||||
|
||||
if len(X.shape) != 2:
|
||||
# Simply raise an error here since there might be many
|
||||
# different ways of reshaping
|
||||
raise ValueError(
|
||||
'Please reshape the input data X into 2-dimensional matrix.')
|
||||
self._features_count = X.shape[1]
|
||||
|
||||
if sample_weight is not None:
|
||||
@@ -846,26 +839,27 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
|
||||
class XGBRFClassifier(XGBClassifier):
|
||||
# pylint: disable=missing-docstring
|
||||
__doc__ = "Experimental implementation of the scikit-learn API "\
|
||||
+ "for XGBoost random forest classification.\n\n"\
|
||||
+ '\n'.join(XGBModel.__doc__.split('\n')[2:])
|
||||
__doc__ = "scikit-learn API for XGBoost random forest classification.\n\n"\
|
||||
+ '\n'.join(XGBModel.__doc__.split('\n')[2:])
|
||||
|
||||
def __init__(self, max_depth=3, learning_rate=1, n_estimators=100,
|
||||
verbosity=1, silent=None,
|
||||
objective="binary:logistic", n_jobs=1, nthread=None, gamma=0,
|
||||
min_child_weight=1, max_delta_step=0, subsample=0.8, colsample_bytree=1,
|
||||
colsample_bylevel=1, colsample_bynode=0.8, reg_alpha=0, reg_lambda=1e-5,
|
||||
scale_pos_weight=1, base_score=0.5, random_state=0, seed=None,
|
||||
verbosity=1, objective="binary:logistic", n_jobs=1,
|
||||
gpu_id=-1, gamma=0, min_child_weight=1, max_delta_step=0,
|
||||
subsample=0.8, colsample_bytree=1, colsample_bylevel=1,
|
||||
colsample_bynode=0.8, reg_alpha=0, reg_lambda=1e-5,
|
||||
scale_pos_weight=1, base_score=0.5, random_state=0,
|
||||
missing=None, **kwargs):
|
||||
super(XGBRFClassifier, self).__init__(
|
||||
max_depth=max_depth, learning_rate=learning_rate, n_estimators=n_estimators,
|
||||
verbosity=verbosity, silent=silent, objective=objective, booster='gbtree',
|
||||
n_jobs=n_jobs, nthread=nthread, gamma=gamma,
|
||||
min_child_weight=min_child_weight, max_delta_step=max_delta_step,
|
||||
max_depth=max_depth, learning_rate=learning_rate,
|
||||
n_estimators=n_estimators, verbosity=verbosity,
|
||||
objective=objective, booster='gbtree', n_jobs=n_jobs,
|
||||
gpu_id=gpu_id, gamma=gamma, min_child_weight=min_child_weight,
|
||||
max_delta_step=max_delta_step,
|
||||
subsample=subsample, colsample_bytree=colsample_bytree,
|
||||
colsample_bylevel=colsample_bylevel, colsample_bynode=colsample_bynode,
|
||||
reg_alpha=reg_alpha, reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
|
||||
base_score=base_score, random_state=random_state, seed=seed, missing=missing,
|
||||
colsample_bylevel=colsample_bylevel,
|
||||
colsample_bynode=colsample_bynode, reg_alpha=reg_alpha,
|
||||
reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
|
||||
base_score=base_score, random_state=random_state, missing=missing,
|
||||
**kwargs)
|
||||
|
||||
def get_xgb_params(self):
|
||||
@@ -885,26 +879,28 @@ class XGBRegressor(XGBModel, XGBRegressorBase):
|
||||
|
||||
class XGBRFRegressor(XGBRegressor):
|
||||
# pylint: disable=missing-docstring
|
||||
__doc__ = "Experimental implementation of the scikit-learn API "\
|
||||
+ "for XGBoost random forest regression.\n\n"\
|
||||
+ '\n'.join(XGBModel.__doc__.split('\n')[2:])
|
||||
__doc__ = "scikit-learn API for XGBoost random forest regression.\n\n"\
|
||||
+ '\n'.join(XGBModel.__doc__.split('\n')[2:])
|
||||
|
||||
def __init__(self, max_depth=3, learning_rate=1, n_estimators=100,
|
||||
verbosity=1, silent=None,
|
||||
objective="reg:squarederror", n_jobs=1, nthread=None, gamma=0,
|
||||
min_child_weight=1, max_delta_step=0, subsample=0.8, colsample_bytree=1,
|
||||
colsample_bylevel=1, colsample_bynode=0.8, reg_alpha=0, reg_lambda=1e-5,
|
||||
scale_pos_weight=1, base_score=0.5, random_state=0, seed=None,
|
||||
missing=None, **kwargs):
|
||||
verbosity=1, objective="reg:squarederror", n_jobs=1,
|
||||
gpu_id=-1, gamma=0, min_child_weight=1,
|
||||
max_delta_step=0, subsample=0.8, colsample_bytree=1,
|
||||
colsample_bylevel=1, colsample_bynode=0.8, reg_alpha=0,
|
||||
reg_lambda=1e-5, scale_pos_weight=1, base_score=0.5,
|
||||
random_state=0, missing=None, **kwargs):
|
||||
super(XGBRFRegressor, self).__init__(
|
||||
max_depth=max_depth, learning_rate=learning_rate, n_estimators=n_estimators,
|
||||
verbosity=verbosity, silent=silent, objective=objective, booster='gbtree',
|
||||
n_jobs=n_jobs, nthread=nthread, gamma=gamma,
|
||||
min_child_weight=min_child_weight, max_delta_step=max_delta_step,
|
||||
subsample=subsample, colsample_bytree=colsample_bytree,
|
||||
colsample_bylevel=colsample_bylevel, colsample_bynode=colsample_bynode,
|
||||
reg_alpha=reg_alpha, reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
|
||||
base_score=base_score, random_state=random_state, seed=seed, missing=missing,
|
||||
max_depth=max_depth, learning_rate=learning_rate,
|
||||
n_estimators=n_estimators, verbosity=verbosity,
|
||||
objective=objective, booster='gbtree', n_jobs=n_jobs,
|
||||
gpu_id=gpu_id, gamma=gamma, min_child_weight=min_child_weight,
|
||||
max_delta_step=max_delta_step, subsample=subsample,
|
||||
colsample_bytree=colsample_bytree,
|
||||
colsample_bylevel=colsample_bylevel,
|
||||
colsample_bynode=colsample_bynode,
|
||||
reg_alpha=reg_alpha, reg_lambda=reg_lambda,
|
||||
scale_pos_weight=scale_pos_weight,
|
||||
base_score=base_score, random_state=random_state, missing=missing,
|
||||
**kwargs)
|
||||
|
||||
def get_xgb_params(self):
|
||||
@@ -930,17 +926,13 @@ class XGBRanker(XGBModel):
|
||||
Number of boosted trees to fit.
|
||||
verbosity : int
|
||||
The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
|
||||
silent : boolean
|
||||
Whether to print messages while running boosting. Deprecated. Use verbosity instead.
|
||||
objective : string
|
||||
Specify the learning task and the corresponding learning objective.
|
||||
The objective name must start with "rank:".
|
||||
booster: string
|
||||
Specify which booster to use: gbtree, gblinear or dart.
|
||||
nthread : int
|
||||
Number of parallel threads used to run xgboost. (Deprecated, please use ``n_jobs``)
|
||||
n_jobs : int
|
||||
Number of parallel threads used to run xgboost. (replaces ``nthread``)
|
||||
Number of parallel threads used to run xgboost.
|
||||
gamma : float
|
||||
Minimum loss reduction required to make a further partition on a leaf node of the tree.
|
||||
min_child_weight : int
|
||||
@@ -963,10 +955,12 @@ class XGBRanker(XGBModel):
|
||||
Balancing of positive and negative weights.
|
||||
base_score:
|
||||
The initial prediction score of all instances, global bias.
|
||||
seed : int
|
||||
Random number seed. (Deprecated, please use random_state)
|
||||
random_state : int
|
||||
Random number seed. (replaces seed)
|
||||
Random number seed.
|
||||
|
||||
.. note:: Using gblinear booster with shotgun updater is
|
||||
nondeterministic as it uses Hogwild algorithm.
|
||||
|
||||
missing : float, optional
|
||||
Value in the data which needs to be present as a missing value. If
|
||||
None, defaults to np.nan.
|
||||
@@ -1015,33 +1009,39 @@ class XGBRanker(XGBModel):
|
||||
+-------+-----------+---------------+
|
||||
|
||||
then your group array should be ``[3, 4]``.
|
||||
"""
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
|
||||
verbosity=1, silent=None, objective="rank:pairwise", booster='gbtree',
|
||||
n_jobs=-1, nthread=None, gamma=0, min_child_weight=1, max_delta_step=0,
|
||||
subsample=1, colsample_bytree=1, colsample_bylevel=1, colsample_bynode=1,
|
||||
reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
|
||||
base_score=0.5, random_state=0, seed=None, missing=None, **kwargs):
|
||||
verbosity=1, objective="rank:pairwise", booster='gbtree',
|
||||
tree_method='auto', n_jobs=-1, gpu_id=-1, gamma=0,
|
||||
min_child_weight=1, max_delta_step=0, subsample=1,
|
||||
colsample_bytree=1, colsample_bylevel=1, colsample_bynode=1,
|
||||
reg_alpha=0, reg_lambda=1, scale_pos_weight=1, base_score=0.5,
|
||||
random_state=0, missing=None, **kwargs):
|
||||
|
||||
super(XGBRanker, self).__init__(
|
||||
max_depth=max_depth, learning_rate=learning_rate, n_estimators=n_estimators,
|
||||
verbosity=verbosity, silent=silent, objective=objective, booster=booster,
|
||||
n_jobs=n_jobs, nthread=nthread, gamma=gamma,
|
||||
max_depth=max_depth, learning_rate=learning_rate,
|
||||
n_estimators=n_estimators, verbosity=verbosity,
|
||||
objective=objective, booster=booster, tree_method=tree_method,
|
||||
n_jobs=n_jobs, gpu_id=gpu_id, gamma=gamma,
|
||||
min_child_weight=min_child_weight, max_delta_step=max_delta_step,
|
||||
subsample=subsample, colsample_bytree=colsample_bytree,
|
||||
colsample_bylevel=colsample_bylevel, colsample_bynode=colsample_bynode,
|
||||
reg_alpha=reg_alpha, reg_lambda=reg_lambda,
|
||||
scale_pos_weight=scale_pos_weight, base_score=base_score,
|
||||
random_state=random_state, seed=seed, missing=missing, **kwargs)
|
||||
colsample_bylevel=colsample_bylevel,
|
||||
colsample_bynode=colsample_bynode, reg_alpha=reg_alpha,
|
||||
reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
|
||||
base_score=base_score, random_state=random_state, missing=missing,
|
||||
**kwargs)
|
||||
if callable(self.objective):
|
||||
raise ValueError("custom objective function not supported by XGBRanker")
|
||||
raise ValueError(
|
||||
"custom objective function not supported by XGBRanker")
|
||||
if "rank:" not in self.objective:
|
||||
raise ValueError("please use XGBRanker for ranking task")
|
||||
|
||||
def fit(self, X, y, group, sample_weight=None, eval_set=None, sample_weight_eval_set=None,
|
||||
eval_group=None, eval_metric=None, early_stopping_rounds=None,
|
||||
verbose=False, xgb_model=None, callbacks=None):
|
||||
def fit(self, X, y, group, sample_weight=None, eval_set=None,
|
||||
sample_weight_eval_set=None, eval_group=None, eval_metric=None,
|
||||
early_stopping_rounds=None, verbose=False, xgb_model=None,
|
||||
callbacks=None):
|
||||
# pylint: disable = attribute-defined-outside-init,arguments-differ
|
||||
"""
|
||||
Fit gradient boosting ranker
|
||||
@@ -1132,11 +1132,13 @@ class XGBRanker(XGBModel):
|
||||
return ret
|
||||
|
||||
if sample_weight is not None:
|
||||
train_dmatrix = _dmat_init(group, data=X, label=y, weight=sample_weight,
|
||||
missing=self.missing, nthread=self.n_jobs)
|
||||
train_dmatrix = _dmat_init(
|
||||
group, data=X, label=y, weight=sample_weight,
|
||||
missing=self.missing, nthread=self.n_jobs)
|
||||
else:
|
||||
train_dmatrix = _dmat_init(group, data=X, label=y,
|
||||
missing=self.missing, nthread=self.n_jobs)
|
||||
train_dmatrix = _dmat_init(
|
||||
group, data=X, label=y,
|
||||
missing=self.missing, nthread=self.n_jobs)
|
||||
|
||||
evals_result = {}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user