[Breaking] Update sklearn interface. (#4929)

* Remove nthread, seed, silent. Add tree_method, gpu_id, num_parallel_tree. Fix #4909.
* Check data shape. Fix #4896.
* Check that each element of eval_set is a tuple. Fix #4875.
* Add doc for random_state with Hogwild. Fix #4919.
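
For migration, here is a minimal usage sketch of the wrapper interface after this change, using only constructor and fit arguments that appear in the diff below; the data is synthetic and purely illustrative.

import numpy as np
import xgboost as xgb

X_train = np.random.rand(100, 5)           # X must be a 2-D matrix; other shapes now raise ValueError
y_train = np.random.randint(2, size=100)
X_valid = np.random.rand(20, 5)
y_valid = np.random.randint(2, size=20)

# Canonical names replace the removed aliases: n_jobs (was nthread), random_state (was seed),
# verbosity (was silent). tree_method is now a first-class constructor argument.
clf = xgb.XGBClassifier(n_estimators=10, n_jobs=4, random_state=42,
                        tree_method='hist', verbosity=1)

# Each eval_set element must be an (X, y) pair; anything else now raises TypeError.
clf.fit(X_train, y_train, eval_set=[(X_valid, y_valid)],
        eval_metric='logloss', verbose=False)

# num_parallel_tree grows several trees per boosting round (used for boosted random forests).
reg = xgb.XGBRegressor(n_estimators=4, num_parallel_tree=4, tree_method='hist')
reg.fit(X_train, y_train)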
Jiaming Yuan 2019-10-12 02:50:09 -04:00 committed by GitHub
parent c2cce4fac3
commit 4bbf062ed3
5 changed files with 177 additions and 128 deletions


@@ -35,6 +35,7 @@ struct GenericParameter : public dmlc::Parameter<GenericParameter> {
DMLC_DECLARE_PARAMETER(GenericParameter) {
DMLC_DECLARE_FIELD(seed).set_default(0).describe(
"Random number seed during training.");
DMLC_DECLARE_ALIAS(seed, random_state);
DMLC_DECLARE_FIELD(seed_per_iteration)
.set_default(false)
.describe(
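
The new DMLC_DECLARE_ALIAS line above makes random_state an alias of seed in the native learner parameters, so the scikit-learn wrapper can forward random_state without translating it. A minimal sketch with the low-level API, assuming the alias is applied during parameter parsing:

import numpy as np
import xgboost as xgb

X = np.random.rand(64, 4)
y = np.random.randint(2, size=64)
dtrain = xgb.DMatrix(X, label=y)

# Either spelling configures the same underlying seed.
bst = xgb.train({'objective': 'binary:logistic', 'random_state': 7},
                dtrain, num_boost_round=2)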


@@ -62,17 +62,18 @@ class XGBModel(XGBModelBase):
Number of trees to fit.
verbosity : int
The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
silent : boolean
Whether to print messages while running boosting. Deprecated. Use verbosity instead.
objective : string or callable
Specify the learning task and the corresponding learning objective or
a custom objective function to be used (see note below).
booster: string
Specify which booster to use: gbtree, gblinear or dart.
nthread : int
Number of parallel threads used to run xgboost. (Deprecated, please use ``n_jobs``)
tree_method: string
Specify which tree method to use. Defaults to auto. If this parameter is
set to default, XGBoost will choose the most conservative option
available. It's recommended to study this option from the parameters
document.
n_jobs : int
Number of parallel threads used to run xgboost. (replaces ``nthread``)
Number of parallel threads used to run xgboost.
gamma : float
Minimum loss reduction required to make a further partition on a leaf node of the tree.
min_child_weight : int
@@ -95,13 +96,17 @@ class XGBModel(XGBModelBase):
Balancing of positive and negative weights.
base_score:
The initial prediction score of all instances, global bias.
seed : int
Random number seed. (Deprecated, please use random_state)
random_state : int
Random number seed. (replaces seed)
Random number seed.
.. note:: Using gblinear booster with shotgun updater is
nondeterministic as it uses the Hogwild algorithm.
missing : float, optional
Value in the data which needs to be present as a missing value. If
None, defaults to np.nan.
num_parallel_tree: int
Used for boosting random forest.
importance_type: string, default "gain"
The feature importance type for the feature_importances\\_ property:
either "gain", "weight", "cover", "total_gain" or "total_cover".
@@ -131,25 +136,27 @@ class XGBModel(XGBModelBase):
The value of the gradient for each sample point.
hess: array_like of shape [n_samples]
The value of the second derivative for each sample point
"""
def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
verbosity=1, silent=None, objective="reg:squarederror",
booster='gbtree', n_jobs=1, nthread=None, gamma=0,
verbosity=1, objective="reg:squarederror",
booster='gbtree', tree_method='auto', n_jobs=1, gamma=0,
min_child_weight=1, max_delta_step=0, subsample=1,
colsample_bytree=1, colsample_bylevel=1, colsample_bynode=1,
reg_alpha=0, reg_lambda=1, scale_pos_weight=1, base_score=0.5,
random_state=0, seed=None, missing=None,
random_state=0, missing=None, num_parallel_tree=1,
importance_type="gain", **kwargs):
if not SKLEARN_INSTALLED:
raise XGBoostError('sklearn needs to be installed in order to use this module')
raise XGBoostError(
'sklearn needs to be installed in order to use this module')
self.max_depth = max_depth
self.learning_rate = learning_rate
self.n_estimators = n_estimators
self.verbosity = verbosity
self.silent = silent
self.objective = objective
self.booster = booster
self.tree_method = tree_method
self.gamma = gamma
self.min_child_weight = min_child_weight
self.max_delta_step = max_delta_step
@@ -162,11 +169,10 @@ class XGBModel(XGBModelBase):
self.scale_pos_weight = scale_pos_weight
self.base_score = base_score
self.missing = missing if missing is not None else np.nan
self.num_parallel_tree = num_parallel_tree
self.kwargs = kwargs
self._Booster = None
self.seed = seed
self.random_state = random_state
self.nthread = nthread
self.n_jobs = n_jobs
self.importance_type = importance_type
@@ -227,33 +233,6 @@ class XGBModel(XGBModelBase):
def get_xgb_params(self):
"""Get xgboost type parameters."""
xgb_params = self.get_params()
random_state = xgb_params.pop('random_state')
if 'seed' in xgb_params and xgb_params['seed'] is not None:
warnings.warn('The seed parameter is deprecated as of version .6.'
'Please use random_state instead.'
'seed is deprecated.', DeprecationWarning)
else:
xgb_params['seed'] = random_state
n_jobs = xgb_params.pop('n_jobs')
if 'nthread' in xgb_params and xgb_params['nthread'] is not None:
warnings.warn('The nthread parameter is deprecated as of version .6.'
'Please use n_jobs instead.'
'nthread is deprecated.', DeprecationWarning)
else:
xgb_params['nthread'] = n_jobs
if 'silent' in xgb_params and xgb_params['silent'] is not None:
warnings.warn('The silent parameter is deprecated.'
'Please use verbosity instead.'
'silent is depreated', DeprecationWarning)
# TODO(canonizer): set verbosity explicitly if silent is removed from xgboost,
# but remains in this API
else:
# silent=None shouldn't be passed to xgboost
xgb_params.pop('silent', None)
if xgb_params['nthread'] <= 0:
xgb_params.pop('nthread', None)
return xgb_params
def get_num_boosting_rounds(self):
@@ -301,7 +280,7 @@ class XGBModel(XGBModelBase):
Input file name or memory buffer(see also save_raw)
"""
if self._Booster is None:
self._Booster = Booster({'nthread': self.n_jobs})
self._Booster = Booster({'n_jobs': self.n_jobs})
self._Booster.load_model(fname)
def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
@@ -364,13 +343,17 @@ class XGBModel(XGBModelBase):
"""
if sample_weight is not None:
trainDmatrix = DMatrix(X, label=y, weight=sample_weight,
missing=self.missing, nthread=self.n_jobs)
missing=self.missing,
nthread=self.n_jobs)
else:
trainDmatrix = DMatrix(X, label=y, missing=self.missing, nthread=self.n_jobs)
trainDmatrix = DMatrix(X, label=y, missing=self.missing,
nthread=self.n_jobs)
evals_result = {}
if eval_set is not None:
if not isinstance(eval_set[0], (list, tuple)):
raise TypeError('Unexpected input type for `eval_set`')
if sample_weight_eval_set is None:
sample_weight_eval_set = [None] * len(eval_set)
evals = list(
@@ -610,22 +593,27 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
__doc__ = "Implementation of the scikit-learn API for XGBoost classification.\n\n" \
+ '\n'.join(XGBModel.__doc__.split('\n')[2:])
def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
verbosity=1, silent=None,
def __init__(self, max_depth=3, learning_rate=0.1,
n_estimators=100, verbosity=1,
objective="binary:logistic", booster='gbtree',
n_jobs=1, nthread=None, gamma=0, min_child_weight=1, max_delta_step=0,
subsample=1, colsample_bytree=1, colsample_bylevel=1,
colsample_bynode=1, reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
base_score=0.5, random_state=0, seed=None, missing=None, **kwargs):
tree_method='auto', n_jobs=1, gpu_id=-1, gamma=0,
min_child_weight=1, max_delta_step=0, subsample=1,
colsample_bytree=1, colsample_bylevel=1, colsample_bynode=1,
reg_alpha=0, reg_lambda=1, scale_pos_weight=1, base_score=0.5,
random_state=0, missing=None, **kwargs):
super(XGBClassifier, self).__init__(
max_depth=max_depth, learning_rate=learning_rate, n_estimators=n_estimators,
verbosity=verbosity, silent=silent, objective=objective, booster=booster,
n_jobs=n_jobs, nthread=nthread, gamma=gamma,
min_child_weight=min_child_weight, max_delta_step=max_delta_step,
subsample=subsample, colsample_bytree=colsample_bytree,
colsample_bylevel=colsample_bylevel, colsample_bynode=colsample_bynode,
reg_alpha=reg_alpha, reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
base_score=base_score, random_state=random_state, seed=seed, missing=missing,
max_depth=max_depth, learning_rate=learning_rate,
n_estimators=n_estimators, verbosity=verbosity,
objective=objective, booster=booster, tree_method=tree_method,
n_jobs=n_jobs, gpu_id=gpu_id, gamma=gamma,
min_child_weight=min_child_weight,
max_delta_step=max_delta_step, subsample=subsample,
colsample_bytree=colsample_bytree,
colsample_bylevel=colsample_bylevel,
colsample_bynode=colsample_bynode,
reg_alpha=reg_alpha, reg_lambda=reg_lambda,
scale_pos_weight=scale_pos_weight,
base_score=base_score, random_state=random_state, missing=missing,
**kwargs)
def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
@@ -676,6 +664,11 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
else:
evals = ()
if len(X.shape) != 2:
# Simply raise an error here since there might be many
# different ways of reshaping
raise ValueError(
'Please reshape the input data X into 2-dimensional matrix.')
self._features_count = X.shape[1]
if sample_weight is not None:
@@ -846,26 +839,27 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
class XGBRFClassifier(XGBClassifier):
# pylint: disable=missing-docstring
__doc__ = "Experimental implementation of the scikit-learn API "\
+ "for XGBoost random forest classification.\n\n"\
__doc__ = "scikit-learn API for XGBoost random forest classification.\n\n"\
+ '\n'.join(XGBModel.__doc__.split('\n')[2:])
def __init__(self, max_depth=3, learning_rate=1, n_estimators=100,
verbosity=1, silent=None,
objective="binary:logistic", n_jobs=1, nthread=None, gamma=0,
min_child_weight=1, max_delta_step=0, subsample=0.8, colsample_bytree=1,
colsample_bylevel=1, colsample_bynode=0.8, reg_alpha=0, reg_lambda=1e-5,
scale_pos_weight=1, base_score=0.5, random_state=0, seed=None,
verbosity=1, objective="binary:logistic", n_jobs=1,
gpu_id=-1, gamma=0, min_child_weight=1, max_delta_step=0,
subsample=0.8, colsample_bytree=1, colsample_bylevel=1,
colsample_bynode=0.8, reg_alpha=0, reg_lambda=1e-5,
scale_pos_weight=1, base_score=0.5, random_state=0,
missing=None, **kwargs):
super(XGBRFClassifier, self).__init__(
max_depth=max_depth, learning_rate=learning_rate, n_estimators=n_estimators,
verbosity=verbosity, silent=silent, objective=objective, booster='gbtree',
n_jobs=n_jobs, nthread=nthread, gamma=gamma,
min_child_weight=min_child_weight, max_delta_step=max_delta_step,
max_depth=max_depth, learning_rate=learning_rate,
n_estimators=n_estimators, verbosity=verbosity,
objective=objective, booster='gbtree', n_jobs=n_jobs,
gpu_id=gpu_id, gamma=gamma, min_child_weight=min_child_weight,
max_delta_step=max_delta_step,
subsample=subsample, colsample_bytree=colsample_bytree,
colsample_bylevel=colsample_bylevel, colsample_bynode=colsample_bynode,
reg_alpha=reg_alpha, reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
base_score=base_score, random_state=random_state, seed=seed, missing=missing,
colsample_bylevel=colsample_bylevel,
colsample_bynode=colsample_bynode, reg_alpha=reg_alpha,
reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
base_score=base_score, random_state=random_state, missing=missing,
**kwargs)
def get_xgb_params(self):
@@ -885,26 +879,28 @@ class XGBRegressor(XGBModel, XGBRegressorBase):
class XGBRFRegressor(XGBRegressor):
# pylint: disable=missing-docstring
__doc__ = "Experimental implementation of the scikit-learn API "\
+ "for XGBoost random forest regression.\n\n"\
__doc__ = "scikit-learn API for XGBoost random forest regression.\n\n"\
+ '\n'.join(XGBModel.__doc__.split('\n')[2:])
def __init__(self, max_depth=3, learning_rate=1, n_estimators=100,
verbosity=1, silent=None,
objective="reg:squarederror", n_jobs=1, nthread=None, gamma=0,
min_child_weight=1, max_delta_step=0, subsample=0.8, colsample_bytree=1,
colsample_bylevel=1, colsample_bynode=0.8, reg_alpha=0, reg_lambda=1e-5,
scale_pos_weight=1, base_score=0.5, random_state=0, seed=None,
missing=None, **kwargs):
verbosity=1, objective="reg:squarederror", n_jobs=1,
gpu_id=-1, gamma=0, min_child_weight=1,
max_delta_step=0, subsample=0.8, colsample_bytree=1,
colsample_bylevel=1, colsample_bynode=0.8, reg_alpha=0,
reg_lambda=1e-5, scale_pos_weight=1, base_score=0.5,
random_state=0, missing=None, **kwargs):
super(XGBRFRegressor, self).__init__(
max_depth=max_depth, learning_rate=learning_rate, n_estimators=n_estimators,
verbosity=verbosity, silent=silent, objective=objective, booster='gbtree',
n_jobs=n_jobs, nthread=nthread, gamma=gamma,
min_child_weight=min_child_weight, max_delta_step=max_delta_step,
subsample=subsample, colsample_bytree=colsample_bytree,
colsample_bylevel=colsample_bylevel, colsample_bynode=colsample_bynode,
reg_alpha=reg_alpha, reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
base_score=base_score, random_state=random_state, seed=seed, missing=missing,
max_depth=max_depth, learning_rate=learning_rate,
n_estimators=n_estimators, verbosity=verbosity,
objective=objective, booster='gbtree', n_jobs=n_jobs,
gpu_id=gpu_id, gamma=gamma, min_child_weight=min_child_weight,
max_delta_step=max_delta_step, subsample=subsample,
colsample_bytree=colsample_bytree,
colsample_bylevel=colsample_bylevel,
colsample_bynode=colsample_bynode,
reg_alpha=reg_alpha, reg_lambda=reg_lambda,
scale_pos_weight=scale_pos_weight,
base_score=base_score, random_state=random_state, missing=missing,
**kwargs)
def get_xgb_params(self):
@@ -930,17 +926,13 @@ class XGBRanker(XGBModel):
Number of boosted trees to fit.
verbosity : int
The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
silent : boolean
Whether to print messages while running boosting. Deprecated. Use verbosity instead.
objective : string
Specify the learning task and the corresponding learning objective.
The objective name must start with "rank:".
booster: string
Specify which booster to use: gbtree, gblinear or dart.
nthread : int
Number of parallel threads used to run xgboost. (Deprecated, please use ``n_jobs``)
n_jobs : int
Number of parallel threads used to run xgboost. (replaces ``nthread``)
Number of parallel threads used to run xgboost.
gamma : float
Minimum loss reduction required to make a further partition on a leaf node of the tree.
min_child_weight : int
@@ -963,10 +955,12 @@ class XGBRanker(XGBModel):
Balancing of positive and negative weights.
base_score:
The initial prediction score of all instances, global bias.
seed : int
Random number seed. (Deprecated, please use random_state)
random_state : int
Random number seed. (replaces seed)
Random number seed.
.. note:: Using gblinear booster with shotgun updater is
nondeterministic as it uses the Hogwild algorithm.
missing : float, optional
Value in the data which needs to be present as a missing value. If
None, defaults to np.nan.
@@ -1015,33 +1009,39 @@ class XGBRanker(XGBModel):
+-------+-----------+---------------+
then your group array should be ``[3, 4]``.
"""
def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
verbosity=1, silent=None, objective="rank:pairwise", booster='gbtree',
n_jobs=-1, nthread=None, gamma=0, min_child_weight=1, max_delta_step=0,
subsample=1, colsample_bytree=1, colsample_bylevel=1, colsample_bynode=1,
reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
base_score=0.5, random_state=0, seed=None, missing=None, **kwargs):
verbosity=1, objective="rank:pairwise", booster='gbtree',
tree_method='auto', n_jobs=-1, gpu_id=-1, gamma=0,
min_child_weight=1, max_delta_step=0, subsample=1,
colsample_bytree=1, colsample_bylevel=1, colsample_bynode=1,
reg_alpha=0, reg_lambda=1, scale_pos_weight=1, base_score=0.5,
random_state=0, missing=None, **kwargs):
super(XGBRanker, self).__init__(
max_depth=max_depth, learning_rate=learning_rate, n_estimators=n_estimators,
verbosity=verbosity, silent=silent, objective=objective, booster=booster,
n_jobs=n_jobs, nthread=nthread, gamma=gamma,
max_depth=max_depth, learning_rate=learning_rate,
n_estimators=n_estimators, verbosity=verbosity,
objective=objective, booster=booster, tree_method=tree_method,
n_jobs=n_jobs, gpu_id=gpu_id, gamma=gamma,
min_child_weight=min_child_weight, max_delta_step=max_delta_step,
subsample=subsample, colsample_bytree=colsample_bytree,
colsample_bylevel=colsample_bylevel, colsample_bynode=colsample_bynode,
reg_alpha=reg_alpha, reg_lambda=reg_lambda,
scale_pos_weight=scale_pos_weight, base_score=base_score,
random_state=random_state, seed=seed, missing=missing, **kwargs)
colsample_bylevel=colsample_bylevel,
colsample_bynode=colsample_bynode, reg_alpha=reg_alpha,
reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
base_score=base_score, random_state=random_state, missing=missing,
**kwargs)
if callable(self.objective):
raise ValueError("custom objective function not supported by XGBRanker")
raise ValueError(
"custom objective function not supported by XGBRanker")
if "rank:" not in self.objective:
raise ValueError("please use XGBRanker for ranking task")
def fit(self, X, y, group, sample_weight=None, eval_set=None, sample_weight_eval_set=None,
eval_group=None, eval_metric=None, early_stopping_rounds=None,
verbose=False, xgb_model=None, callbacks=None):
def fit(self, X, y, group, sample_weight=None, eval_set=None,
sample_weight_eval_set=None, eval_group=None, eval_metric=None,
early_stopping_rounds=None, verbose=False, xgb_model=None,
callbacks=None):
# pylint: disable = attribute-defined-outside-init,arguments-differ
"""
Fit gradient boosting ranker
@@ -1132,10 +1132,12 @@ class XGBRanker(XGBModel):
return ret
if sample_weight is not None:
train_dmatrix = _dmat_init(group, data=X, label=y, weight=sample_weight,
train_dmatrix = _dmat_init(
group, data=X, label=y, weight=sample_weight,
missing=self.missing, nthread=self.n_jobs)
else:
train_dmatrix = _dmat_init(group, data=X, label=y,
train_dmatrix = _dmat_init(
group, data=X, label=y,
missing=self.missing, nthread=self.n_jobs)
evals_result = {}
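
To round out the XGBRanker changes above, a hedged usage sketch following the group-array convention from the docstring; the data and group sizes are illustrative only:

import numpy as np
import xgboost as xgb

# Two query groups of 3 and 4 documents, so the group array is [3, 4].
X = np.random.rand(7, 5)
y = np.random.randint(3, size=7)      # relevance label per document
group = [3, 4]

ranker = xgb.XGBRanker(n_estimators=5, random_state=0)
ranker.fit(X, y, group)
scores = ranker.predict(X)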


@@ -601,7 +601,7 @@ class LearnerImpl : public Learner {
gbm_->Configure(args);
if (this->gbm_->UseGPU()) {
if (cfg_.find("gpu_id") == cfg_.cend()) {
if (generic_param_.gpu_id == -1) {
generic_param_.gpu_id = 0;
}
}
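
With gpu_id now a generic parameter defaulting to -1, the learner falls back to device 0 whenever a GPU-capable booster is configured but no device was chosen explicitly, as the hunk above shows. A hedged sketch of the corresponding wrapper usage (requires a CUDA-enabled build):

import xgboost as xgb

# gpu_id is left at its default of -1, so gpu_hist resolves to device 0.
clf = xgb.XGBClassifier(tree_method='gpu_hist', n_estimators=4)

# Selecting a device explicitly instead:
clf_on_device_1 = xgb.XGBClassifier(tree_method='gpu_hist', gpu_id=1, n_estimators=4)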


@@ -0,0 +1,31 @@
import xgboost as xgb
import pytest
import sys
import numpy as np
sys.path.append("tests/python")
import testing as tm
pytestmark = pytest.mark.skipif(**tm.no_sklearn())
rng = np.random.RandomState(1994)
def test_gpu_binary_classification():
from sklearn.datasets import load_digits
from sklearn.model_selection import KFold
digits = load_digits(2)
y = digits['target']
X = digits['data']
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
for cls in (xgb.XGBClassifier, xgb.XGBRFClassifier):
for train_index, test_index in kf.split(X, y):
xgb_model = cls(
random_state=42, tree_method='gpu_hist',
n_estimators=4, gpu_id='0').fit(X[train_index], y[train_index])
preds = xgb_model.predict(X[test_index])
labels = y[test_index]
err = sum(1 for i in range(len(preds))
if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
assert err < 0.1


@@ -175,6 +175,21 @@ def test_feature_importances_gain():
np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
def test_num_parallel_tree():
from sklearn.datasets import load_boston
reg = xgb.XGBRegressor(n_estimators=4, num_parallel_tree=4,
tree_method='hist')
boston = load_boston()
bst = reg.fit(X=boston['data'], y=boston['target'])
dump = bst.get_booster().get_dump(dump_format='json')
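# 4 boosting rounds x num_parallel_tree=4 -> the JSON dump contains 16 trees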
assert len(dump) == 16
reg = xgb.XGBRFRegressor(n_estimators=4)
bst = reg.fit(X=boston['data'], y=boston['target'])
dump = bst.get_booster().get_dump(dump_format='json')
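# The random-forest wrapper grows all of its n_estimators=4 trees in a single boosting round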
assert len(dump) == 4
def test_boston_housing_regression():
from sklearn.metrics import mean_squared_error
from sklearn.datasets import load_boston
@@ -430,18 +445,18 @@ def test_split_value_histograms():
def test_sklearn_random_state():
clf = xgb.XGBClassifier(random_state=402)
assert clf.get_xgb_params()['seed'] == 402
assert clf.get_xgb_params()['random_state'] == 402
clf = xgb.XGBClassifier(seed=401)
assert clf.get_xgb_params()['seed'] == 401
clf = xgb.XGBClassifier(random_state=401)
assert clf.get_xgb_params()['random_state'] == 401
def test_sklearn_n_jobs():
clf = xgb.XGBClassifier(n_jobs=1)
assert clf.get_xgb_params()['nthread'] == 1
assert clf.get_xgb_params()['n_jobs'] == 1
clf = xgb.XGBClassifier(nthread=2)
assert clf.get_xgb_params()['nthread'] == 2
clf = xgb.XGBClassifier(n_jobs=2)
assert clf.get_xgb_params()['n_jobs'] == 2
def test_kwargs():
@@ -482,7 +497,7 @@ def test_kwargs_error():
def test_sklearn_clone():
from sklearn.base import clone
clf = xgb.XGBClassifier(n_jobs=2, nthread=3)
clf = xgb.XGBClassifier(n_jobs=2)
clf.n_jobs = -1
clone(clf)