Added SKLearn-like random forest Python API. (#4148)

* Added SKLearn-like random forest Python API.

- added XGBRFClassifier and XGBRFRegressor classes to SKL-like xgboost API
- also added n_gpus and gpu_id parameters to SKL classes
- added documentation describing how to use xgboost for random forests,
  as well as existing caveats
This commit is contained in:
Andy Adinets 2019-03-12 15:28:19 +01:00 committed by Jiaming Yuan
parent 6fb4c5efef
commit a36c3ed4f4
4 changed files with 240 additions and 55 deletions

89
doc/rf.rst Normal file
View File

@ -0,0 +1,89 @@
#########################
Random Forests in XGBoost
#########################
XGBoost is normally used to train gradient-boosted decision trees and other gradient
boosted models. Random forests use the same model representation and inference, as
gradient-boosted decision trees, but a different training algorithm. There are XGBoost
parameters that enable training a forest in a random forest fashion.
****************
With XGBoost API
****************
The following parameters must be set to enable random forest training.
* ``booster`` should be set to ``gbtree``, as we are training forests. Note that as this
is the default, this parameter needn't be set explicitly.
* ``subsample`` must be set to a value less than 1 to enable random selection of training
cases (rows).
* One of ``colsample_by*`` parameters must be set to a value less than 1 to enable random
selection of columns. Normally, ``colsample_bynode`` would be set to a value less than 1
to randomly sample columns at each tree split.
* ``num_parallel_tree`` should be set to the size of the forest being trained.
* ``num_boost_round`` should be set to 1. Note that this is a keyword argument to
``train()``, and is not part of the parameter dictionary.
* ``eta`` (alias: ``learning_rate``) must be set to 1 when training random forest
regression.
* ``random_state`` can be used to seed the random number generator.
Other parameters should be set in a similar way they are set for gradient boosting. For
instance, ``objective`` will typically be ``reg:linear`` for regression and
``binary:logistic`` for classification, ``lambda`` should be set according to a desired
regularization weight, etc.
If both ``num_parallel_tree`` and ``num_boost_round`` are greater than 1, training will
use a combination of random forest and gradient boosting strategy. It will perform
``num_boost_round`` rounds, boosting a random forest of ``num_parallel_tree`` trees at
each round. If early stopping is not enabled, the final model will consist of
``num_parallel_tree`` * ``num_boost_round`` trees.
Here is a sample parameter dictionary for training a random forest on a GPU using
xgboost::
params = {
'colsample_bynode': 0.8,
'learning_rate': 1,
'max_depth': 5,
'num_parallel_tree': 100,
'objective': 'binary:logistic',
'subsample': 0.8,
'tree_method': 'gpu_hist'
}
A random forest model can then be trained as follows::
bst = train(params, dmatrix, num_boost_round=1)
**************************
With Scikit-Learn-Like API
**************************
``XGBRFClassifier`` and ``XGBRFRegressor`` are SKL-like classes that provide random forest
functionality. They are basically versions of ``XGBClassifier`` and ``XGBRegressor`` that
train random forest instead of gradient boosting, and have default values and meaning of
some of the parameters adjusted accordingly. In particular:
* ``n_estimators`` specifies the size of the forest to be trained; it is converted to
``num_parallel_tree``, instead of the number of boosting rounds
* ``learning_rate`` is set to 1 by default
* ``colsample_bynode`` and ``subsample`` are set to 0.8 by default
* ``booster`` is always ``gbtree``
Note that these classes have a smaller selection of parameters compared to using
``train()``. In particular, it is impossible to combine random forests with gradient
boosting using this API.
*******
Caveats
*******
* XGBoost uses 2nd order approximation to the objective function. This can lead to results
that differ from a random forest implementation that uses the exact value of the
objective function.
* XGBoost does not perform replacement when subsampling training cases. Each training case
can occur in a subsampled set either 0 or 1 time.

View File

@ -13,6 +13,7 @@ from .training import train, cv
from . import rabit # noqa
try:
from .sklearn import XGBModel, XGBClassifier, XGBRegressor, XGBRanker
from .sklearn import XGBRFClassifier, XGBRFRegressor
from .plotting import plot_importance, plot_tree, to_graphviz
except ImportError:
pass
@ -24,4 +25,5 @@ with open(VERSION_FILE) as f:
__all__ = ['DMatrix', 'Booster',
'train', 'cv',
'XGBModel', 'XGBClassifier', 'XGBRegressor', 'XGBRanker',
'XGBRFClassifier', 'XGBRFRegressor',
'plot_importance', 'plot_tree', 'to_graphviz']

View File

@ -61,9 +61,9 @@ class XGBModel(XGBModelBase):
learning_rate : float
Boosting learning rate (xgb's "eta")
n_estimators : int
Number of boosted trees to fit.
silent : boolean
Whether to print messages while running boosting.
Number of trees to fit.
verbosity : int
The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
objective : string or callable
Specify the learning task and the corresponding learning objective or
a custom objective function to be used (see note below).
@ -84,7 +84,9 @@ class XGBModel(XGBModelBase):
colsample_bytree : float
Subsample ratio of columns when constructing each tree.
colsample_bylevel : float
Subsample ratio of columns for each split, in each level.
Subsample ratio of columns for each level.
colsample_bynode : float
Subsample ratio of columns for each split.
reg_alpha : float (xgb's alpha)
L1 regularization term on weights
reg_lambda : float (xgb's lambda)
@ -132,10 +134,10 @@ class XGBModel(XGBModelBase):
"""
def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
silent=True, objective="reg:linear", booster='gbtree',
n_jobs=1, nthread=None, gamma=0, min_child_weight=1, max_delta_step=0,
subsample=1, colsample_bytree=1, colsample_bylevel=1,
reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
verbosity=1, objective="reg:linear", booster='gbtree',
n_jobs=1, nthread=None, gamma=0, min_child_weight=1,
max_delta_step=0, subsample=1, colsample_bytree=1, colsample_bylevel=1,
colsample_bynode=1, reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
base_score=0.5, random_state=0, seed=None, missing=None,
importance_type="gain", **kwargs):
if not SKLEARN_INSTALLED:
@ -143,7 +145,7 @@ class XGBModel(XGBModelBase):
self.max_depth = max_depth
self.learning_rate = learning_rate
self.n_estimators = n_estimators
self.silent = silent
self.verbosity = verbosity
self.objective = objective
self.booster = booster
self.gamma = gamma
@ -152,6 +154,7 @@ class XGBModel(XGBModelBase):
self.subsample = subsample
self.colsample_bytree = colsample_bytree
self.colsample_bylevel = colsample_bylevel
self.colsample_bynode = colsample_bynode
self.reg_alpha = reg_alpha
self.reg_lambda = reg_lambda
self.scale_pos_weight = scale_pos_weight
@ -237,12 +240,14 @@ class XGBModel(XGBModelBase):
else:
xgb_params['nthread'] = n_jobs
xgb_params['verbosity'] = 0 if self.silent else 0
if xgb_params['nthread'] <= 0:
xgb_params.pop('nthread', None)
return xgb_params
def get_num_boosting_rounds(self):
"""Gets the number of xgboost boosting rounds."""
return self.n_estimators
def save_model(self, fname):
"""
Save the model to a file.
@ -371,7 +376,7 @@ class XGBModel(XGBModelBase):
params.update({'eval_metric': eval_metric})
self._Booster = train(params, trainDmatrix,
self.n_estimators, evals=evals,
self.get_num_boosting_rounds(), evals=evals,
early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result, obj=obj, feval=feval,
verbose_eval=verbose, xgb_model=xgb_model,
@ -583,21 +588,22 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
__doc__ = "Implementation of the scikit-learn API for XGBoost classification.\n\n" \
+ '\n'.join(XGBModel.__doc__.split('\n')[2:])
def __init__(self, max_depth=3, learning_rate=0.1,
n_estimators=100, silent=True,
def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100, verbosity=1,
objective="binary:logistic", booster='gbtree',
n_jobs=1, nthread=None, gamma=0, min_child_weight=1,
max_delta_step=0, subsample=1, colsample_bytree=1, colsample_bylevel=1,
reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
n_jobs=1, nthread=None, gamma=0, min_child_weight=1, max_delta_step=0,
subsample=1, colsample_bytree=1, colsample_bylevel=1,
colsample_bynode=1, reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
base_score=0.5, random_state=0, seed=None, missing=None, **kwargs):
super(XGBClassifier, self).__init__(max_depth, learning_rate,
n_estimators, silent, objective, booster,
n_jobs, nthread, gamma, min_child_weight,
max_delta_step, subsample,
colsample_bytree, colsample_bylevel,
reg_alpha, reg_lambda,
scale_pos_weight, base_score,
random_state, seed, missing, **kwargs)
super(XGBClassifier, self).__init__(
max_depth=max_depth, learning_rate=learning_rate, n_estimators=n_estimators,
verbosity=verbosity, objective=objective, booster=booster,
n_jobs=n_jobs, nthread=nthread, gamma=gamma,
min_child_weight=min_child_weight, max_delta_step=max_delta_step,
subsample=subsample, colsample_bytree=colsample_bytree,
colsample_bylevel=colsample_bylevel, colsample_bynode=colsample_bynode,
reg_alpha=reg_alpha, reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
base_score=base_score, random_state=random_state, seed=seed, missing=missing,
**kwargs)
def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
early_stopping_rounds=None, verbose=True, xgb_model=None,
@ -705,9 +711,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
train_dmatrix = DMatrix(X, label=training_labels,
missing=self.missing, nthread=self.n_jobs)
self._Booster = train(xgb_options, train_dmatrix, self.n_estimators,
evals=evals,
early_stopping_rounds=early_stopping_rounds,
self._Booster = train(xgb_options, train_dmatrix, self.get_num_boosting_rounds(),
evals=evals, early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result, obj=obj, feval=feval,
verbose_eval=verbose, xgb_model=xgb_model,
callbacks=callbacks)
@ -863,12 +868,76 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
return evals_result
class XGBRFClassifier(XGBClassifier):
# pylint: disable=missing-docstring
__doc__ = "Implementation of the scikit-learn API "\
+ "for XGBoost random forest classification.\n\n"\
+ '\n'.join(XGBModel.__doc__.split('\n')[2:])
def __init__(self, max_depth=3, learning_rate=1, n_estimators=100, verbosity=1,
objective="binary:logistic", n_jobs=1, nthread=None, gamma=0,
min_child_weight=1, max_delta_step=0, subsample=0.8, colsample_bytree=1,
colsample_bylevel=1, colsample_bynode=0.8, reg_alpha=0, reg_lambda=1,
scale_pos_weight=1, base_score=0.5, random_state=0, seed=None,
missing=None, **kwargs):
super(XGBRFClassifier, self).__init__(
max_depth=max_depth, learning_rate=learning_rate, n_estimators=n_estimators,
verbosity=verbosity, objective=objective, booster='gbtree',
n_jobs=n_jobs, nthread=nthread, gamma=gamma,
min_child_weight=min_child_weight, max_delta_step=max_delta_step,
subsample=subsample, colsample_bytree=colsample_bytree,
colsample_bylevel=colsample_bylevel, colsample_bynode=colsample_bynode,
reg_alpha=reg_alpha, reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
base_score=base_score, random_state=random_state, seed=seed, missing=missing,
**kwargs)
def get_xgb_params(self):
params = super(XGBRFClassifier, self).get_xgb_params()
params['num_parallel_tree'] = self.n_estimators
return params
def get_num_boosting_rounds(self):
return 1
class XGBRegressor(XGBModel, XGBRegressorBase):
# pylint: disable=missing-docstring
__doc__ = "Implementation of the scikit-learn API for XGBoost regression.\n\n"\
+ '\n'.join(XGBModel.__doc__.split('\n')[2:])
class XGBRFRegressor(XGBRegressor):
# pylint: disable=missing-docstring
__doc__ = "Implementation of the scikit-learn API "\
+ "for XGBoost random forest regression.\n\n"\
+ '\n'.join(XGBModel.__doc__.split('\n')[2:])
def __init__(self, max_depth=3, learning_rate=1, n_estimators=100, verbosity=1,
objective="reg:linear", n_jobs=1, nthread=None, gamma=0,
min_child_weight=1, max_delta_step=0, subsample=0.8, colsample_bytree=1,
colsample_bylevel=1, colsample_bynode=0.8, reg_alpha=0, reg_lambda=1,
scale_pos_weight=1, base_score=0.5, random_state=0, seed=None,
missing=None, **kwargs):
super(XGBRFRegressor, self).__init__(
max_depth=max_depth, learning_rate=learning_rate, n_estimators=n_estimators,
verbosity=verbosity, objective=objective, booster='gbtree',
n_jobs=n_jobs, nthread=nthread, gamma=gamma,
min_child_weight=min_child_weight, max_delta_step=max_delta_step,
subsample=subsample, colsample_bytree=colsample_bytree,
colsample_bylevel=colsample_bylevel, colsample_bynode=colsample_bynode,
reg_alpha=reg_alpha, reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
base_score=base_score, random_state=random_state, seed=seed, missing=missing,
**kwargs)
def get_xgb_params(self):
params = super(XGBRFRegressor, self).get_xgb_params()
params['num_parallel_tree'] = self.n_estimators
return params
def get_num_boosting_rounds(self):
return 1
class XGBRanker(XGBModel):
# pylint: disable=missing-docstring,too-many-arguments,invalid-name
"""Implementation of the Scikit-Learn API for XGBoost Ranking.
@ -881,8 +950,8 @@ class XGBRanker(XGBModel):
Boosting learning rate (xgb's "eta")
n_estimators : int
Number of boosted trees to fit.
silent : boolean
Whether to print messages while running boosting.
verbosity : int
The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
objective : string
Specify the learning task and the corresponding learning objective.
The objective name must start with "rank:".
@ -903,7 +972,9 @@ class XGBRanker(XGBModel):
colsample_bytree : float
Subsample ratio of columns when constructing each tree.
colsample_bylevel : float
Subsample ratio of columns for each split, in each level.
Subsample ratio of columns for each level.
colsample_bynode : float
Subsample ratio of columns for each split.
reg_alpha : float (xgb's alpha)
L1 regularization term on weights
reg_lambda : float (xgb's lambda)
@ -966,18 +1037,22 @@ class XGBRanker(XGBModel):
"""
def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
silent=True, objective="rank:pairwise", booster='gbtree',
verbosity=1, objective="rank:pairwise", booster='gbtree',
n_jobs=-1, nthread=None, gamma=0, min_child_weight=1, max_delta_step=0,
subsample=1, colsample_bytree=1, colsample_bylevel=1,
subsample=1, colsample_bytree=1, colsample_bylevel=1, colsample_bynode=1,
reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
base_score=0.5, random_state=0, seed=None, missing=None, **kwargs):
super(XGBRanker, self).__init__(max_depth, learning_rate,
n_estimators, silent, objective, booster,
n_jobs, nthread, gamma, min_child_weight, max_delta_step,
subsample, colsample_bytree, colsample_bylevel,
reg_alpha, reg_lambda, scale_pos_weight,
base_score, random_state, seed, missing)
super(XGBRanker, self).__init__(
max_depth=max_depth, learning_rate=learning_rate, n_estimators=n_estimators,
verbosity=verbosity, objective=objective, booster=booster,
n_jobs=n_jobs, nthread=nthread, gamma=gamma,
min_child_weight=min_child_weight, max_delta_step=max_delta_step,
subsample=subsample, colsample_bytree=colsample_bytree,
colsample_bylevel=colsample_bylevel, colsample_bynode=colsample_bynode,
reg_alpha=reg_alpha, reg_lambda=reg_lambda,
scale_pos_weight=scale_pos_weight, base_score=base_score,
random_state=random_state, seed=seed, missing=missing, **kwargs)
if callable(self.objective):
raise ValueError("custom objective function not supported by XGBRanker")
elif "rank:" not in self.objective:

View File

@ -29,13 +29,14 @@ def test_binary_classification():
y = digits['target']
X = digits['data']
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
for train_index, test_index in kf.split(X, y):
xgb_model = xgb.XGBClassifier().fit(X[train_index], y[train_index])
preds = xgb_model.predict(X[test_index])
labels = y[test_index]
err = sum(1 for i in range(len(preds))
if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
assert err < 0.1
for cls in (xgb.XGBClassifier, xgb.XGBRFClassifier):
for train_index, test_index in kf.split(X, y):
xgb_model = cls(random_state=42).fit(X[train_index], y[train_index])
preds = xgb_model.predict(X[test_index])
labels = y[test_index]
err = sum(1 for i in range(len(preds))
if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
assert err < 0.1
def test_multiclass_classification():
@ -83,8 +84,8 @@ def test_ranking():
valid_group = np.repeat(50, 4)
x_test = np.random.rand(100, 10)
params = {'objective': 'rank:pairwise', 'learning_rate': 0.1,
'gamma': 1.0, 'min_child_weight': 0.1,
params = {'tree_method': 'exact', 'objective': 'rank:pairwise',
'learning_rate': 0.1, 'gamma': 1.0, 'min_child_weight': 0.1,
'max_depth': 6, 'n_estimators': 4}
model = xgb.sklearn.XGBRanker(**params)
model.fit(x_train, y_train, train_group,
@ -97,7 +98,8 @@ def test_ranking():
train_data.set_group(train_group)
valid_data.set_group(valid_group)
params_orig = {'objective': 'rank:pairwise', 'eta': 0.1, 'gamma': 1.0,
params_orig = {'tree_method': 'exact', 'objective': 'rank:pairwise',
'eta': 0.1, 'gamma': 1.0,
'min_child_weight': 0.1, 'max_depth': 6}
xgb_model_orig = xgb.train(params_orig, train_data, num_boost_round=4,
evals=[(valid_data, 'validation')])
@ -113,7 +115,7 @@ def test_feature_importances_weight():
y = digits['target']
X = digits['data']
xgb_model = xgb.XGBClassifier(
random_state=0, importance_type="weight").fit(X, y)
random_state=0, tree_method="exact", importance_type="weight").fit(X, y)
exp = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.00833333, 0.,
0., 0., 0., 0., 0., 0., 0., 0.025, 0.14166667, 0., 0., 0.,
@ -130,11 +132,11 @@ def test_feature_importances_weight():
y = pd.Series(digits['target'])
X = pd.DataFrame(digits['data'])
xgb_model = xgb.XGBClassifier(
random_state=0, importance_type="weight").fit(X, y)
random_state=0, tree_method="exact", importance_type="weight").fit(X, y)
np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
xgb_model = xgb.XGBClassifier(
random_state=0, importance_type="weight").fit(X, y)
random_state=0, tree_method="exact", importance_type="weight").fit(X, y)
np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
@ -145,7 +147,7 @@ def test_feature_importances_gain():
y = digits['target']
X = digits['data']
xgb_model = xgb.XGBClassifier(
random_state=0, importance_type="gain").fit(X, y)
random_state=0, tree_method="exact", importance_type="gain").fit(X, y)
exp = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0.00326159, 0., 0., 0., 0., 0., 0., 0., 0.,
@ -163,11 +165,11 @@ def test_feature_importances_gain():
y = pd.Series(digits['target'])
X = pd.DataFrame(digits['data'])
xgb_model = xgb.XGBClassifier(
random_state=0, importance_type="gain").fit(X, y)
random_state=0, tree_method="exact", importance_type="gain").fit(X, y)
np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
xgb_model = xgb.XGBClassifier(
random_state=0, importance_type="gain").fit(X, y)
random_state=0, tree_method="exact", importance_type="gain").fit(X, y)
np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
@ -199,6 +201,23 @@ def test_boston_housing_regression():
assert mean_squared_error(preds4, labels) < 350
def test_boston_housing_rf_regression():
from sklearn.metrics import mean_squared_error
from sklearn.datasets import load_boston
from sklearn.model_selection import KFold
boston = load_boston()
y = boston['target']
X = boston['data']
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
for train_index, test_index in kf.split(X, y):
xgb_model = xgb.XGBRFRegressor(random_state=42).fit(
X[train_index], y[train_index])
preds = xgb_model.predict(X[test_index])
labels = y[test_index]
assert mean_squared_error(preds, labels) < 35
def test_parameter_tuning():
from sklearn.model_selection import GridSearchCV
from sklearn.datasets import load_boston