[Breaking] Update sklearn interface. (#4929)

* Remove nthread, seed, silent. Add tree_method, gpu_id, num_parallel_tree. Fix #4909.
* Check data shape. Fix #4896.
* Check that each element of eval_set is a tuple. Fix #4875.
* Add doc for random_state with hogwild. Fix #4919.
parent c2cce4fac3 · commit 4bbf062ed3
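In practice the change maps the removed constructor arguments onto their new spellings; a minimal migration sketch (parameter values are illustrative, not defaults):

    import xgboost as xgb

    # Before: xgb.XGBClassifier(nthread=4, seed=0, silent=True)
    # After: direct replacements, plus the new knobs introduced here:
    clf = xgb.XGBClassifier(n_jobs=4,             # replaces nthread
                            random_state=0,       # replaces seed
                            verbosity=0,          # replaces silent
                            tree_method='hist',   # new
                            num_parallel_tree=1)  # new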
@@ -35,6 +35,7 @@ struct GenericParameter : public dmlc::Parameter<GenericParameter> {
   DMLC_DECLARE_PARAMETER(GenericParameter) {
     DMLC_DECLARE_FIELD(seed).set_default(0).describe(
         "Random number seed during training.");
+    DMLC_DECLARE_ALIAS(seed, random_state);
     DMLC_DECLARE_FIELD(seed_per_iteration)
         .set_default(false)
         .describe(
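With the alias in place, the native parameter dictionary accepts ``random_state`` as a synonym for ``seed``. A minimal sketch of the low-level API (synthetic data, illustrative values):

    import numpy as np
    import xgboost as xgb

    X = np.random.rand(100, 4)
    y = np.random.randint(2, size=100)
    dtrain = xgb.DMatrix(X, label=y)

    # 'random_state' resolves to 'seed' through DMLC_DECLARE_ALIAS above.
    params = {'objective': 'binary:logistic', 'random_state': 0}
    booster = xgb.train(params, dtrain, num_boost_round=4)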
@@ -62,17 +62,18 @@ class XGBModel(XGBModelBase):
         Number of trees to fit.
     verbosity : int
         The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
-    silent : boolean
-        Whether to print messages while running boosting. Deprecated. Use verbosity instead.
     objective : string or callable
         Specify the learning task and the corresponding learning objective or
         a custom objective function to be used (see note below).
     booster: string
         Specify which booster to use: gbtree, gblinear or dart.
-    nthread : int
-        Number of parallel threads used to run xgboost. (Deprecated, please use ``n_jobs``)
+    tree_method: string
+        Specify which tree method to use. Default to auto. If this parameter
+        is set to default, XGBoost will choose the most conservative option
+        available. It's recommended to study this option from parameters
+        document.
     n_jobs : int
-        Number of parallel threads used to run xgboost. (replaces ``nthread``)
+        Number of parallel threads used to run xgboost.
     gamma : float
         Minimum loss reduction required to make a further partition on a leaf node of the tree.
     min_child_weight : int
@@ -95,13 +96,17 @@ class XGBModel(XGBModelBase):
         Balancing of positive and negative weights.
     base_score:
         The initial prediction score of all instances, global bias.
-    seed : int
-        Random number seed. (Deprecated, please use random_state)
     random_state : int
-        Random number seed. (replaces seed)
+        Random number seed.
+
+        .. note:: Using gblinear booster with shotgun updater is
+           nondeterministic as it uses Hogwild algorithm.
+
     missing : float, optional
         Value in the data which needs to be present as a missing value. If
         None, defaults to np.nan.
+    num_parallel_tree: int
+        Used for boosting random forest.
     importance_type: string, default "gain"
         The feature importance type for the feature_importances\_ property:
         either "gain", "weight", "cover", "total_gain" or "total_cover".
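The ``num_parallel_tree`` parameter documented here enables boosted random forests from the wrapper: each boosting round grows that many trees. A hedged sketch (synthetic data, illustrative values):

    from sklearn.datasets import make_regression
    import xgboost as xgb

    X, y = make_regression(n_samples=200, n_features=10, random_state=0)

    # Two boosting rounds of four parallel trees each -> eight trees in total.
    reg = xgb.XGBRegressor(n_estimators=2, num_parallel_tree=4,
                           tree_method='hist', random_state=0)
    reg.fit(X, y)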
@@ -131,25 +136,27 @@ class XGBModel(XGBModelBase):
         The value of the gradient for each sample point.
     hess: array_like of shape [n_samples]
         The value of the second derivative for each sample point

     """

     def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
-                 verbosity=1, silent=None, objective="reg:squarederror",
-                 booster='gbtree', n_jobs=1, nthread=None, gamma=0,
+                 verbosity=1, objective="reg:squarederror",
+                 booster='gbtree', tree_method='auto', n_jobs=1, gamma=0,
                  min_child_weight=1, max_delta_step=0, subsample=1,
                  colsample_bytree=1, colsample_bylevel=1, colsample_bynode=1,
                  reg_alpha=0, reg_lambda=1, scale_pos_weight=1, base_score=0.5,
-                 random_state=0, seed=None, missing=None,
+                 random_state=0, missing=None, num_parallel_tree=1,
                  importance_type="gain", **kwargs):
         if not SKLEARN_INSTALLED:
-            raise XGBoostError('sklearn needs to be installed in order to use this module')
+            raise XGBoostError(
+                'sklearn needs to be installed in order to use this module')
         self.max_depth = max_depth
         self.learning_rate = learning_rate
         self.n_estimators = n_estimators
         self.verbosity = verbosity
-        self.silent = silent
         self.objective = objective
         self.booster = booster
+        self.tree_method = tree_method
         self.gamma = gamma
         self.min_child_weight = min_child_weight
         self.max_delta_step = max_delta_step
@@ -162,11 +169,10 @@ class XGBModel(XGBModelBase):
         self.scale_pos_weight = scale_pos_weight
         self.base_score = base_score
         self.missing = missing if missing is not None else np.nan
+        self.num_parallel_tree = num_parallel_tree
         self.kwargs = kwargs
         self._Booster = None
-        self.seed = seed
         self.random_state = random_state
-        self.nthread = nthread
         self.n_jobs = n_jobs
         self.importance_type = importance_type

@@ -227,33 +233,6 @@ class XGBModel(XGBModelBase):
     def get_xgb_params(self):
         """Get xgboost type parameters."""
         xgb_params = self.get_params()
-        random_state = xgb_params.pop('random_state')
-        if 'seed' in xgb_params and xgb_params['seed'] is not None:
-            warnings.warn('The seed parameter is deprecated as of version .6.'
-                          'Please use random_state instead.'
-                          'seed is deprecated.', DeprecationWarning)
-        else:
-            xgb_params['seed'] = random_state
-        n_jobs = xgb_params.pop('n_jobs')
-        if 'nthread' in xgb_params and xgb_params['nthread'] is not None:
-            warnings.warn('The nthread parameter is deprecated as of version .6.'
-                          'Please use n_jobs instead.'
-                          'nthread is deprecated.', DeprecationWarning)
-        else:
-            xgb_params['nthread'] = n_jobs
-
-        if 'silent' in xgb_params and xgb_params['silent'] is not None:
-            warnings.warn('The silent parameter is deprecated.'
-                          'Please use verbosity instead.'
-                          'silent is depreated', DeprecationWarning)
-            # TODO(canonizer): set verbosity explicitly if silent is removed from xgboost,
-            # but remains in this API
-        else:
-            # silent=None shouldn't be passed to xgboost
-            xgb_params.pop('silent', None)
-
-        if xgb_params['nthread'] <= 0:
-            xgb_params.pop('nthread', None)
         return xgb_params

     def get_num_boosting_rounds(self):
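With the translation shim removed, ``get_xgb_params`` returns the estimator's parameters verbatim; ``random_state`` and ``n_jobs`` are no longer rewritten to ``seed`` and ``nthread`` (see the updated tests further down). For example:

    import xgboost as xgb

    clf = xgb.XGBClassifier(random_state=402, n_jobs=2)
    params = clf.get_xgb_params()
    assert params['random_state'] == 402  # previously surfaced as 'seed'
    assert params['n_jobs'] == 2          # previously surfaced as 'nthread'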
@@ -301,7 +280,7 @@ class XGBModel(XGBModelBase):
         Input file name or memory buffer(see also save_raw)
         """
         if self._Booster is None:
-            self._Booster = Booster({'nthread': self.n_jobs})
+            self._Booster = Booster({'n_jobs': self.n_jobs})
         self._Booster.load_model(fname)

     def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
@@ -364,13 +343,17 @@ class XGBModel(XGBModelBase):
         """
         if sample_weight is not None:
             trainDmatrix = DMatrix(X, label=y, weight=sample_weight,
-                                   missing=self.missing, nthread=self.n_jobs)
+                                   missing=self.missing,
+                                   nthread=self.n_jobs)
         else:
-            trainDmatrix = DMatrix(X, label=y, missing=self.missing, nthread=self.n_jobs)
+            trainDmatrix = DMatrix(X, label=y, missing=self.missing,
+                                   nthread=self.n_jobs)

         evals_result = {}

         if eval_set is not None:
+            if not isinstance(eval_set[0], (list, tuple)):
+                raise TypeError('Unexpected input type for `eval_set`')
             if sample_weight_eval_set is None:
                 sample_weight_eval_set = [None] * len(eval_set)
             evals = list(
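The new check rejects a bare ``(X, y)`` pair passed as ``eval_set``; each element must itself be a tuple (or list). Illustrative usage with synthetic data:

    import numpy as np
    import xgboost as xgb

    X, y = np.random.rand(100, 4), np.random.randint(2, size=100)
    X_val, y_val = np.random.rand(20, 4), np.random.randint(2, size=20)

    clf = xgb.XGBClassifier(n_estimators=4)
    clf.fit(X, y, eval_set=[(X_val, y_val)])  # correct: a list of (X, y) tuples
    # clf.fit(X, y, eval_set=(X_val, y_val))  # now raises TypeError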
@@ -610,22 +593,27 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
     __doc__ = "Implementation of the scikit-learn API for XGBoost classification.\n\n" \
         + '\n'.join(XGBModel.__doc__.split('\n')[2:])

-    def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
-                 verbosity=1, silent=None,
+    def __init__(self, max_depth=3, learning_rate=0.1,
+                 n_estimators=100, verbosity=1,
                  objective="binary:logistic", booster='gbtree',
-                 n_jobs=1, nthread=None, gamma=0, min_child_weight=1, max_delta_step=0,
-                 subsample=1, colsample_bytree=1, colsample_bylevel=1,
-                 colsample_bynode=1, reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
-                 base_score=0.5, random_state=0, seed=None, missing=None, **kwargs):
+                 tree_method='auto', n_jobs=1, gpu_id=-1, gamma=0,
+                 min_child_weight=1, max_delta_step=0, subsample=1,
+                 colsample_bytree=1, colsample_bylevel=1, colsample_bynode=1,
+                 reg_alpha=0, reg_lambda=1, scale_pos_weight=1, base_score=0.5,
+                 random_state=0, missing=None, **kwargs):
         super(XGBClassifier, self).__init__(
-            max_depth=max_depth, learning_rate=learning_rate, n_estimators=n_estimators,
-            verbosity=verbosity, silent=silent, objective=objective, booster=booster,
-            n_jobs=n_jobs, nthread=nthread, gamma=gamma,
-            min_child_weight=min_child_weight, max_delta_step=max_delta_step,
-            subsample=subsample, colsample_bytree=colsample_bytree,
-            colsample_bylevel=colsample_bylevel, colsample_bynode=colsample_bynode,
-            reg_alpha=reg_alpha, reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
-            base_score=base_score, random_state=random_state, seed=seed, missing=missing,
+            max_depth=max_depth, learning_rate=learning_rate,
+            n_estimators=n_estimators, verbosity=verbosity,
+            objective=objective, booster=booster, tree_method=tree_method,
+            n_jobs=n_jobs, gpu_id=gpu_id, gamma=gamma,
+            min_child_weight=min_child_weight,
+            max_delta_step=max_delta_step, subsample=subsample,
+            colsample_bytree=colsample_bytree,
+            colsample_bylevel=colsample_bylevel,
+            colsample_bynode=colsample_bynode,
+            reg_alpha=reg_alpha, reg_lambda=reg_lambda,
+            scale_pos_weight=scale_pos_weight,
+            base_score=base_score, random_state=random_state, missing=missing,
             **kwargs)

     def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
@@ -676,6 +664,11 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         else:
             evals = ()

+        if len(X.shape) != 2:
+            # Simply raise an error here since there might be many
+            # different ways of reshaping
+            raise ValueError(
+                'Please reshape the input data X into 2-dimensional matrix.')
         self._features_count = X.shape[1]

         if sample_weight is not None:
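Rather than guessing a reshape, ``fit`` now raises; callers holding a flat array must make the sample/feature layout explicit, e.g.:

    import numpy as np

    x = np.arange(8)      # shape (8,): XGBClassifier.fit would raise ValueError
    X = x.reshape(-1, 4)  # shape (2, 4): two samples with four features each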
@@ -846,26 +839,27 @@ class XGBClassifier(XGBModel, XGBClassifierBase):

 class XGBRFClassifier(XGBClassifier):
     # pylint: disable=missing-docstring
-    __doc__ = "Experimental implementation of the scikit-learn API "\
-        + "for XGBoost random forest classification.\n\n"\
-        + '\n'.join(XGBModel.__doc__.split('\n')[2:])
+    __doc__ = "scikit-learn API for XGBoost random forest classification.\n\n"\
+        + '\n'.join(XGBModel.__doc__.split('\n')[2:])

     def __init__(self, max_depth=3, learning_rate=1, n_estimators=100,
-                 verbosity=1, silent=None,
-                 objective="binary:logistic", n_jobs=1, nthread=None, gamma=0,
-                 min_child_weight=1, max_delta_step=0, subsample=0.8, colsample_bytree=1,
-                 colsample_bylevel=1, colsample_bynode=0.8, reg_alpha=0, reg_lambda=1e-5,
-                 scale_pos_weight=1, base_score=0.5, random_state=0, seed=None,
+                 verbosity=1, objective="binary:logistic", n_jobs=1,
+                 gpu_id=-1, gamma=0, min_child_weight=1, max_delta_step=0,
+                 subsample=0.8, colsample_bytree=1, colsample_bylevel=1,
+                 colsample_bynode=0.8, reg_alpha=0, reg_lambda=1e-5,
+                 scale_pos_weight=1, base_score=0.5, random_state=0,
                  missing=None, **kwargs):
         super(XGBRFClassifier, self).__init__(
-            max_depth=max_depth, learning_rate=learning_rate, n_estimators=n_estimators,
-            verbosity=verbosity, silent=silent, objective=objective, booster='gbtree',
-            n_jobs=n_jobs, nthread=nthread, gamma=gamma,
-            min_child_weight=min_child_weight, max_delta_step=max_delta_step,
+            max_depth=max_depth, learning_rate=learning_rate,
+            n_estimators=n_estimators, verbosity=verbosity,
+            objective=objective, booster='gbtree', n_jobs=n_jobs,
+            gpu_id=gpu_id, gamma=gamma, min_child_weight=min_child_weight,
+            max_delta_step=max_delta_step,
             subsample=subsample, colsample_bytree=colsample_bytree,
-            colsample_bylevel=colsample_bylevel, colsample_bynode=colsample_bynode,
-            reg_alpha=reg_alpha, reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
-            base_score=base_score, random_state=random_state, seed=seed, missing=missing,
+            colsample_bylevel=colsample_bylevel,
+            colsample_bynode=colsample_bynode, reg_alpha=reg_alpha,
+            reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
+            base_score=base_score, random_state=random_state, missing=missing,
             **kwargs)

     def get_xgb_params(self):
@@ -885,26 +879,28 @@ class XGBRegressor(XGBModel, XGBRegressorBase):

 class XGBRFRegressor(XGBRegressor):
     # pylint: disable=missing-docstring
-    __doc__ = "Experimental implementation of the scikit-learn API "\
-        + "for XGBoost random forest regression.\n\n"\
-        + '\n'.join(XGBModel.__doc__.split('\n')[2:])
+    __doc__ = "scikit-learn API for XGBoost random forest regression.\n\n"\
+        + '\n'.join(XGBModel.__doc__.split('\n')[2:])

     def __init__(self, max_depth=3, learning_rate=1, n_estimators=100,
-                 verbosity=1, silent=None,
-                 objective="reg:squarederror", n_jobs=1, nthread=None, gamma=0,
-                 min_child_weight=1, max_delta_step=0, subsample=0.8, colsample_bytree=1,
-                 colsample_bylevel=1, colsample_bynode=0.8, reg_alpha=0, reg_lambda=1e-5,
-                 scale_pos_weight=1, base_score=0.5, random_state=0, seed=None,
-                 missing=None, **kwargs):
+                 verbosity=1, objective="reg:squarederror", n_jobs=1,
+                 gpu_id=-1, gamma=0, min_child_weight=1,
+                 max_delta_step=0, subsample=0.8, colsample_bytree=1,
+                 colsample_bylevel=1, colsample_bynode=0.8, reg_alpha=0,
+                 reg_lambda=1e-5, scale_pos_weight=1, base_score=0.5,
+                 random_state=0, missing=None, **kwargs):
         super(XGBRFRegressor, self).__init__(
-            max_depth=max_depth, learning_rate=learning_rate, n_estimators=n_estimators,
-            verbosity=verbosity, silent=silent, objective=objective, booster='gbtree',
-            n_jobs=n_jobs, nthread=nthread, gamma=gamma,
-            min_child_weight=min_child_weight, max_delta_step=max_delta_step,
-            subsample=subsample, colsample_bytree=colsample_bytree,
-            colsample_bylevel=colsample_bylevel, colsample_bynode=colsample_bynode,
-            reg_alpha=reg_alpha, reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
-            base_score=base_score, random_state=random_state, seed=seed, missing=missing,
+            max_depth=max_depth, learning_rate=learning_rate,
+            n_estimators=n_estimators, verbosity=verbosity,
+            objective=objective, booster='gbtree', n_jobs=n_jobs,
+            gpu_id=gpu_id, gamma=gamma, min_child_weight=min_child_weight,
+            max_delta_step=max_delta_step, subsample=subsample,
+            colsample_bytree=colsample_bytree,
+            colsample_bylevel=colsample_bylevel,
+            colsample_bynode=colsample_bynode,
+            reg_alpha=reg_alpha, reg_lambda=reg_lambda,
+            scale_pos_weight=scale_pos_weight,
+            base_score=base_score, random_state=random_state, missing=missing,
             **kwargs)

     def get_xgb_params(self):
@@ -930,17 +926,13 @@ class XGBRanker(XGBModel):
         Number of boosted trees to fit.
     verbosity : int
         The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
-    silent : boolean
-        Whether to print messages while running boosting. Deprecated. Use verbosity instead.
     objective : string
         Specify the learning task and the corresponding learning objective.
         The objective name must start with "rank:".
     booster: string
         Specify which booster to use: gbtree, gblinear or dart.
-    nthread : int
-        Number of parallel threads used to run xgboost. (Deprecated, please use ``n_jobs``)
     n_jobs : int
-        Number of parallel threads used to run xgboost. (replaces ``nthread``)
+        Number of parallel threads used to run xgboost.
     gamma : float
         Minimum loss reduction required to make a further partition on a leaf node of the tree.
     min_child_weight : int
@@ -963,10 +955,12 @@ class XGBRanker(XGBModel):
         Balancing of positive and negative weights.
     base_score:
         The initial prediction score of all instances, global bias.
-    seed : int
-        Random number seed. (Deprecated, please use random_state)
     random_state : int
-        Random number seed. (replaces seed)
+        Random number seed.
+
+        .. note:: Using gblinear booster with shotgun updater is
+           nondeterministic as it uses Hogwild algorithm.
+
     missing : float, optional
         Value in the data which needs to be present as a missing value. If
         None, defaults to np.nan.
@@ -1015,33 +1009,39 @@ class XGBRanker(XGBModel):
     +-------+-----------+---------------+

         then your group array should be ``[3, 4]``.
-    """
+
+    """

     def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
-                 verbosity=1, silent=None, objective="rank:pairwise", booster='gbtree',
-                 n_jobs=-1, nthread=None, gamma=0, min_child_weight=1, max_delta_step=0,
-                 subsample=1, colsample_bytree=1, colsample_bylevel=1, colsample_bynode=1,
-                 reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
-                 base_score=0.5, random_state=0, seed=None, missing=None, **kwargs):
+                 verbosity=1, objective="rank:pairwise", booster='gbtree',
+                 tree_method='auto', n_jobs=-1, gpu_id=-1, gamma=0,
+                 min_child_weight=1, max_delta_step=0, subsample=1,
+                 colsample_bytree=1, colsample_bylevel=1, colsample_bynode=1,
+                 reg_alpha=0, reg_lambda=1, scale_pos_weight=1, base_score=0.5,
+                 random_state=0, missing=None, **kwargs):

         super(XGBRanker, self).__init__(
-            max_depth=max_depth, learning_rate=learning_rate, n_estimators=n_estimators,
-            verbosity=verbosity, silent=silent, objective=objective, booster=booster,
-            n_jobs=n_jobs, nthread=nthread, gamma=gamma,
+            max_depth=max_depth, learning_rate=learning_rate,
+            n_estimators=n_estimators, verbosity=verbosity,
+            objective=objective, booster=booster, tree_method=tree_method,
+            n_jobs=n_jobs, gpu_id=gpu_id, gamma=gamma,
             min_child_weight=min_child_weight, max_delta_step=max_delta_step,
             subsample=subsample, colsample_bytree=colsample_bytree,
-            colsample_bylevel=colsample_bylevel, colsample_bynode=colsample_bynode,
-            reg_alpha=reg_alpha, reg_lambda=reg_lambda,
-            scale_pos_weight=scale_pos_weight, base_score=base_score,
-            random_state=random_state, seed=seed, missing=missing, **kwargs)
+            colsample_bylevel=colsample_bylevel,
+            colsample_bynode=colsample_bynode, reg_alpha=reg_alpha,
+            reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
+            base_score=base_score, random_state=random_state, missing=missing,
+            **kwargs)
         if callable(self.objective):
-            raise ValueError("custom objective function not supported by XGBRanker")
+            raise ValueError(
+                "custom objective function not supported by XGBRanker")
         if "rank:" not in self.objective:
             raise ValueError("please use XGBRanker for ranking task")

-    def fit(self, X, y, group, sample_weight=None, eval_set=None, sample_weight_eval_set=None,
-            eval_group=None, eval_metric=None, early_stopping_rounds=None,
-            verbose=False, xgb_model=None, callbacks=None):
+    def fit(self, X, y, group, sample_weight=None, eval_set=None,
+            sample_weight_eval_set=None, eval_group=None, eval_metric=None,
+            early_stopping_rounds=None, verbose=False, xgb_model=None,
+            callbacks=None):
         # pylint: disable = attribute-defined-outside-init,arguments-differ
         """
         Fit gradient boosting ranker
@@ -1132,11 +1132,13 @@ class XGBRanker(XGBModel):
             return ret

         if sample_weight is not None:
-            train_dmatrix = _dmat_init(group, data=X, label=y, weight=sample_weight,
-                                       missing=self.missing, nthread=self.n_jobs)
+            train_dmatrix = _dmat_init(
+                group, data=X, label=y, weight=sample_weight,
+                missing=self.missing, nthread=self.n_jobs)
         else:
-            train_dmatrix = _dmat_init(group, data=X, label=y,
-                                       missing=self.missing, nthread=self.n_jobs)
+            train_dmatrix = _dmat_init(
+                group, data=X, label=y,
+                missing=self.missing, nthread=self.n_jobs)

         evals_result = {}

@@ -601,7 +601,7 @@ class LearnerImpl : public Learner {
   gbm_->Configure(args);

   if (this->gbm_->UseGPU()) {
-    if (cfg_.find("gpu_id") == cfg_.cend()) {
+    if (generic_param_.gpu_id == -1) {
       generic_param_.gpu_id = 0;
     }
   }
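Here ``gpu_id == -1`` serves as the "unset" sentinel: when a GPU-capable booster is configured and no device was chosen, the learner falls back to GPU 0. From the wrapper side (illustrative):

    import xgboost as xgb

    # gpu_id defaults to -1; with a GPU algorithm the learner rewrites it to 0.
    clf = xgb.XGBClassifier(tree_method='gpu_hist')

    # Selecting a specific device explicitly:
    clf = xgb.XGBClassifier(tree_method='gpu_hist', gpu_id=1)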
tests/python-gpu/test_gpu_with_sklearn.py (new file)
@@ -0,0 +1,31 @@
+import xgboost as xgb
+import pytest
+import sys
+import numpy as np
+
+sys.path.append("tests/python")
+import testing as tm
+
+pytestmark = pytest.mark.skipif(**tm.no_sklearn())
+
+rng = np.random.RandomState(1994)
+
+
+def test_gpu_binary_classification():
+    from sklearn.datasets import load_digits
+    from sklearn.model_selection import KFold
+
+    digits = load_digits(2)
+    y = digits['target']
+    X = digits['data']
+    kf = KFold(n_splits=2, shuffle=True, random_state=rng)
+    for cls in (xgb.XGBClassifier, xgb.XGBRFClassifier):
+        for train_index, test_index in kf.split(X, y):
+            xgb_model = cls(
+                random_state=42, tree_method='gpu_hist',
+                n_estimators=4, gpu_id='0').fit(X[train_index], y[train_index])
+            preds = xgb_model.predict(X[test_index])
+            labels = y[test_index]
+            err = sum(1 for i in range(len(preds))
+                      if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
+            assert err < 0.1
@@ -175,6 +175,21 @@ def test_feature_importances_gain():
     np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)


+def test_num_parallel_tree():
+    from sklearn.datasets import load_boston
+    reg = xgb.XGBRegressor(n_estimators=4, num_parallel_tree=4,
+                           tree_method='hist')
+    boston = load_boston()
+    bst = reg.fit(X=boston['data'], y=boston['target'])
+    dump = bst.get_booster().get_dump(dump_format='json')
+    assert len(dump) == 16
+
+    reg = xgb.XGBRFRegressor(n_estimators=4)
+    bst = reg.fit(X=boston['data'], y=boston['target'])
+    dump = bst.get_booster().get_dump(dump_format='json')
+    assert len(dump) == 4
+
+
 def test_boston_housing_regression():
     from sklearn.metrics import mean_squared_error
     from sklearn.datasets import load_boston
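The asserted dump sizes follow from rounds × parallel trees: ``XGBRegressor`` grows 4 boosting rounds of 4 parallel trees each, giving 16, while ``XGBRFRegressor`` appears to grow all of its ``n_estimators`` trees in a single boosting round, hence 4. As arithmetic:

    n_estimators, num_parallel_tree = 4, 4
    assert n_estimators * num_parallel_tree == 16  # XGBRegressor dump size
    assert 1 * n_estimators == 4                   # XGBRFRegressor: one round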
@@ -430,18 +445,18 @@ def test_split_value_histograms():

 def test_sklearn_random_state():
     clf = xgb.XGBClassifier(random_state=402)
-    assert clf.get_xgb_params()['seed'] == 402
+    assert clf.get_xgb_params()['random_state'] == 402

-    clf = xgb.XGBClassifier(seed=401)
-    assert clf.get_xgb_params()['seed'] == 401
+    clf = xgb.XGBClassifier(random_state=401)
+    assert clf.get_xgb_params()['random_state'] == 401


 def test_sklearn_n_jobs():
     clf = xgb.XGBClassifier(n_jobs=1)
-    assert clf.get_xgb_params()['nthread'] == 1
+    assert clf.get_xgb_params()['n_jobs'] == 1

-    clf = xgb.XGBClassifier(nthread=2)
-    assert clf.get_xgb_params()['nthread'] == 2
+    clf = xgb.XGBClassifier(n_jobs=2)
+    assert clf.get_xgb_params()['n_jobs'] == 2


 def test_kwargs():
def test_kwargs():
|
def test_kwargs():
|
||||||
@ -482,7 +497,7 @@ def test_kwargs_error():
|
|||||||
def test_sklearn_clone():
|
def test_sklearn_clone():
|
||||||
from sklearn.base import clone
|
from sklearn.base import clone
|
||||||
|
|
||||||
clf = xgb.XGBClassifier(n_jobs=2, nthread=3)
|
clf = xgb.XGBClassifier(n_jobs=2)
|
||||||
clf.n_jobs = -1
|
clf.n_jobs = -1
|
||||||
clone(clf)
|
clone(clf)
|
||||||
|
|
||||||
|
|||||||