Added SKLearn-like random forest Python API. (#4148)
* Added SKLearn-like random forest Python API. - added XGBRFClassifier and XGBRFRegressor classes to SKL-like xgboost API - also added n_gpus and gpu_id parameters to SKL classes - added documentation describing how to use xgboost for random forests, as well as existing caveats
This commit is contained in:
parent
6fb4c5efef
commit
a36c3ed4f4
89
doc/rf.rst
Normal file
89
doc/rf.rst
Normal file
@ -0,0 +1,89 @@
|
||||
#########################
|
||||
Random Forests in XGBoost
|
||||
#########################
|
||||
|
||||
XGBoost is normally used to train gradient-boosted decision trees and other gradient
|
||||
boosted models. Random forests use the same model representation and inference, as
|
||||
gradient-boosted decision trees, but a different training algorithm. There are XGBoost
|
||||
parameters that enable training a forest in a random forest fashion.
|
||||
|
||||
|
||||
****************
|
||||
With XGBoost API
|
||||
****************
|
||||
|
||||
The following parameters must be set to enable random forest training.
|
||||
|
||||
* ``booster`` should be set to ``gbtree``, as we are training forests. Note that as this
|
||||
is the default, this parameter needn't be set explicitly.
|
||||
* ``subsample`` must be set to a value less than 1 to enable random selection of training
|
||||
cases (rows).
|
||||
* One of ``colsample_by*`` parameters must be set to a value less than 1 to enable random
|
||||
selection of columns. Normally, ``colsample_bynode`` would be set to a value less than 1
|
||||
to randomly sample columns at each tree split.
|
||||
* ``num_parallel_tree`` should be set to the size of the forest being trained.
|
||||
* ``num_boost_round`` should be set to 1. Note that this is a keyword argument to
|
||||
``train()``, and is not part of the parameter dictionary.
|
||||
* ``eta`` (alias: ``learning_rate``) must be set to 1 when training random forest
|
||||
regression.
|
||||
* ``random_state`` can be used to seed the random number generator.
|
||||
|
||||
|
||||
Other parameters should be set in a similar way they are set for gradient boosting. For
|
||||
instance, ``objective`` will typically be ``reg:linear`` for regression and
|
||||
``binary:logistic`` for classification, ``lambda`` should be set according to a desired
|
||||
regularization weight, etc.
|
||||
|
||||
If both ``num_parallel_tree`` and ``num_boost_round`` are greater than 1, training will
|
||||
use a combination of random forest and gradient boosting strategy. It will perform
|
||||
``num_boost_round`` rounds, boosting a random forest of ``num_parallel_tree`` trees at
|
||||
each round. If early stopping is not enabled, the final model will consist of
|
||||
``num_parallel_tree`` * ``num_boost_round`` trees.
|
||||
|
||||
Here is a sample parameter dictionary for training a random forest on a GPU using
|
||||
xgboost::
|
||||
|
||||
params = {
|
||||
'colsample_bynode': 0.8,
|
||||
'learning_rate': 1,
|
||||
'max_depth': 5,
|
||||
'num_parallel_tree': 100,
|
||||
'objective': 'binary:logistic',
|
||||
'subsample': 0.8,
|
||||
'tree_method': 'gpu_hist'
|
||||
}
|
||||
|
||||
A random forest model can then be trained as follows::
|
||||
|
||||
bst = train(params, dmatrix, num_boost_round=1)
|
||||
|
||||
|
||||
**************************
|
||||
With Scikit-Learn-Like API
|
||||
**************************
|
||||
|
||||
``XGBRFClassifier`` and ``XGBRFRegressor`` are SKL-like classes that provide random forest
|
||||
functionality. They are basically versions of ``XGBClassifier`` and ``XGBRegressor`` that
|
||||
train random forest instead of gradient boosting, and have default values and meaning of
|
||||
some of the parameters adjusted accordingly. In particular:
|
||||
|
||||
* ``n_estimators`` specifies the size of the forest to be trained; it is converted to
|
||||
``num_parallel_tree``, instead of the number of boosting rounds
|
||||
* ``learning_rate`` is set to 1 by default
|
||||
* ``colsample_bynode`` and ``subsample`` are set to 0.8 by default
|
||||
* ``booster`` is always ``gbtree``
|
||||
|
||||
Note that these classes have a smaller selection of parameters compared to using
|
||||
``train()``. In particular, it is impossible to combine random forests with gradient
|
||||
boosting using this API.
|
||||
|
||||
|
||||
*******
|
||||
Caveats
|
||||
*******
|
||||
|
||||
* XGBoost uses 2nd order approximation to the objective function. This can lead to results
|
||||
that differ from a random forest implementation that uses the exact value of the
|
||||
objective function.
|
||||
* XGBoost does not perform replacement when subsampling training cases. Each training case
|
||||
can occur in a subsampled set either 0 or 1 time.
|
||||
@ -13,6 +13,7 @@ from .training import train, cv
|
||||
from . import rabit # noqa
|
||||
try:
|
||||
from .sklearn import XGBModel, XGBClassifier, XGBRegressor, XGBRanker
|
||||
from .sklearn import XGBRFClassifier, XGBRFRegressor
|
||||
from .plotting import plot_importance, plot_tree, to_graphviz
|
||||
except ImportError:
|
||||
pass
|
||||
@ -24,4 +25,5 @@ with open(VERSION_FILE) as f:
|
||||
__all__ = ['DMatrix', 'Booster',
|
||||
'train', 'cv',
|
||||
'XGBModel', 'XGBClassifier', 'XGBRegressor', 'XGBRanker',
|
||||
'XGBRFClassifier', 'XGBRFRegressor',
|
||||
'plot_importance', 'plot_tree', 'to_graphviz']
|
||||
|
||||
@ -61,9 +61,9 @@ class XGBModel(XGBModelBase):
|
||||
learning_rate : float
|
||||
Boosting learning rate (xgb's "eta")
|
||||
n_estimators : int
|
||||
Number of boosted trees to fit.
|
||||
silent : boolean
|
||||
Whether to print messages while running boosting.
|
||||
Number of trees to fit.
|
||||
verbosity : int
|
||||
The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
|
||||
objective : string or callable
|
||||
Specify the learning task and the corresponding learning objective or
|
||||
a custom objective function to be used (see note below).
|
||||
@ -84,7 +84,9 @@ class XGBModel(XGBModelBase):
|
||||
colsample_bytree : float
|
||||
Subsample ratio of columns when constructing each tree.
|
||||
colsample_bylevel : float
|
||||
Subsample ratio of columns for each split, in each level.
|
||||
Subsample ratio of columns for each level.
|
||||
colsample_bynode : float
|
||||
Subsample ratio of columns for each split.
|
||||
reg_alpha : float (xgb's alpha)
|
||||
L1 regularization term on weights
|
||||
reg_lambda : float (xgb's lambda)
|
||||
@ -132,10 +134,10 @@ class XGBModel(XGBModelBase):
|
||||
"""
|
||||
|
||||
def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
|
||||
silent=True, objective="reg:linear", booster='gbtree',
|
||||
n_jobs=1, nthread=None, gamma=0, min_child_weight=1, max_delta_step=0,
|
||||
subsample=1, colsample_bytree=1, colsample_bylevel=1,
|
||||
reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
|
||||
verbosity=1, objective="reg:linear", booster='gbtree',
|
||||
n_jobs=1, nthread=None, gamma=0, min_child_weight=1,
|
||||
max_delta_step=0, subsample=1, colsample_bytree=1, colsample_bylevel=1,
|
||||
colsample_bynode=1, reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
|
||||
base_score=0.5, random_state=0, seed=None, missing=None,
|
||||
importance_type="gain", **kwargs):
|
||||
if not SKLEARN_INSTALLED:
|
||||
@ -143,7 +145,7 @@ class XGBModel(XGBModelBase):
|
||||
self.max_depth = max_depth
|
||||
self.learning_rate = learning_rate
|
||||
self.n_estimators = n_estimators
|
||||
self.silent = silent
|
||||
self.verbosity = verbosity
|
||||
self.objective = objective
|
||||
self.booster = booster
|
||||
self.gamma = gamma
|
||||
@ -152,6 +154,7 @@ class XGBModel(XGBModelBase):
|
||||
self.subsample = subsample
|
||||
self.colsample_bytree = colsample_bytree
|
||||
self.colsample_bylevel = colsample_bylevel
|
||||
self.colsample_bynode = colsample_bynode
|
||||
self.reg_alpha = reg_alpha
|
||||
self.reg_lambda = reg_lambda
|
||||
self.scale_pos_weight = scale_pos_weight
|
||||
@ -237,12 +240,14 @@ class XGBModel(XGBModelBase):
|
||||
else:
|
||||
xgb_params['nthread'] = n_jobs
|
||||
|
||||
xgb_params['verbosity'] = 0 if self.silent else 0
|
||||
|
||||
if xgb_params['nthread'] <= 0:
|
||||
xgb_params.pop('nthread', None)
|
||||
return xgb_params
|
||||
|
||||
def get_num_boosting_rounds(self):
|
||||
"""Gets the number of xgboost boosting rounds."""
|
||||
return self.n_estimators
|
||||
|
||||
def save_model(self, fname):
|
||||
"""
|
||||
Save the model to a file.
|
||||
@ -371,7 +376,7 @@ class XGBModel(XGBModelBase):
|
||||
params.update({'eval_metric': eval_metric})
|
||||
|
||||
self._Booster = train(params, trainDmatrix,
|
||||
self.n_estimators, evals=evals,
|
||||
self.get_num_boosting_rounds(), evals=evals,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
evals_result=evals_result, obj=obj, feval=feval,
|
||||
verbose_eval=verbose, xgb_model=xgb_model,
|
||||
@ -583,21 +588,22 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
__doc__ = "Implementation of the scikit-learn API for XGBoost classification.\n\n" \
|
||||
+ '\n'.join(XGBModel.__doc__.split('\n')[2:])
|
||||
|
||||
def __init__(self, max_depth=3, learning_rate=0.1,
|
||||
n_estimators=100, silent=True,
|
||||
def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100, verbosity=1,
|
||||
objective="binary:logistic", booster='gbtree',
|
||||
n_jobs=1, nthread=None, gamma=0, min_child_weight=1,
|
||||
max_delta_step=0, subsample=1, colsample_bytree=1, colsample_bylevel=1,
|
||||
reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
|
||||
n_jobs=1, nthread=None, gamma=0, min_child_weight=1, max_delta_step=0,
|
||||
subsample=1, colsample_bytree=1, colsample_bylevel=1,
|
||||
colsample_bynode=1, reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
|
||||
base_score=0.5, random_state=0, seed=None, missing=None, **kwargs):
|
||||
super(XGBClassifier, self).__init__(max_depth, learning_rate,
|
||||
n_estimators, silent, objective, booster,
|
||||
n_jobs, nthread, gamma, min_child_weight,
|
||||
max_delta_step, subsample,
|
||||
colsample_bytree, colsample_bylevel,
|
||||
reg_alpha, reg_lambda,
|
||||
scale_pos_weight, base_score,
|
||||
random_state, seed, missing, **kwargs)
|
||||
super(XGBClassifier, self).__init__(
|
||||
max_depth=max_depth, learning_rate=learning_rate, n_estimators=n_estimators,
|
||||
verbosity=verbosity, objective=objective, booster=booster,
|
||||
n_jobs=n_jobs, nthread=nthread, gamma=gamma,
|
||||
min_child_weight=min_child_weight, max_delta_step=max_delta_step,
|
||||
subsample=subsample, colsample_bytree=colsample_bytree,
|
||||
colsample_bylevel=colsample_bylevel, colsample_bynode=colsample_bynode,
|
||||
reg_alpha=reg_alpha, reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
|
||||
base_score=base_score, random_state=random_state, seed=seed, missing=missing,
|
||||
**kwargs)
|
||||
|
||||
def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
|
||||
early_stopping_rounds=None, verbose=True, xgb_model=None,
|
||||
@ -705,9 +711,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
train_dmatrix = DMatrix(X, label=training_labels,
|
||||
missing=self.missing, nthread=self.n_jobs)
|
||||
|
||||
self._Booster = train(xgb_options, train_dmatrix, self.n_estimators,
|
||||
evals=evals,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
self._Booster = train(xgb_options, train_dmatrix, self.get_num_boosting_rounds(),
|
||||
evals=evals, early_stopping_rounds=early_stopping_rounds,
|
||||
evals_result=evals_result, obj=obj, feval=feval,
|
||||
verbose_eval=verbose, xgb_model=xgb_model,
|
||||
callbacks=callbacks)
|
||||
@ -863,12 +868,76 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
return evals_result
|
||||
|
||||
|
||||
class XGBRFClassifier(XGBClassifier):
|
||||
# pylint: disable=missing-docstring
|
||||
__doc__ = "Implementation of the scikit-learn API "\
|
||||
+ "for XGBoost random forest classification.\n\n"\
|
||||
+ '\n'.join(XGBModel.__doc__.split('\n')[2:])
|
||||
|
||||
def __init__(self, max_depth=3, learning_rate=1, n_estimators=100, verbosity=1,
|
||||
objective="binary:logistic", n_jobs=1, nthread=None, gamma=0,
|
||||
min_child_weight=1, max_delta_step=0, subsample=0.8, colsample_bytree=1,
|
||||
colsample_bylevel=1, colsample_bynode=0.8, reg_alpha=0, reg_lambda=1,
|
||||
scale_pos_weight=1, base_score=0.5, random_state=0, seed=None,
|
||||
missing=None, **kwargs):
|
||||
super(XGBRFClassifier, self).__init__(
|
||||
max_depth=max_depth, learning_rate=learning_rate, n_estimators=n_estimators,
|
||||
verbosity=verbosity, objective=objective, booster='gbtree',
|
||||
n_jobs=n_jobs, nthread=nthread, gamma=gamma,
|
||||
min_child_weight=min_child_weight, max_delta_step=max_delta_step,
|
||||
subsample=subsample, colsample_bytree=colsample_bytree,
|
||||
colsample_bylevel=colsample_bylevel, colsample_bynode=colsample_bynode,
|
||||
reg_alpha=reg_alpha, reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
|
||||
base_score=base_score, random_state=random_state, seed=seed, missing=missing,
|
||||
**kwargs)
|
||||
|
||||
def get_xgb_params(self):
|
||||
params = super(XGBRFClassifier, self).get_xgb_params()
|
||||
params['num_parallel_tree'] = self.n_estimators
|
||||
return params
|
||||
|
||||
def get_num_boosting_rounds(self):
|
||||
return 1
|
||||
|
||||
|
||||
class XGBRegressor(XGBModel, XGBRegressorBase):
|
||||
# pylint: disable=missing-docstring
|
||||
__doc__ = "Implementation of the scikit-learn API for XGBoost regression.\n\n"\
|
||||
+ '\n'.join(XGBModel.__doc__.split('\n')[2:])
|
||||
|
||||
|
||||
class XGBRFRegressor(XGBRegressor):
|
||||
# pylint: disable=missing-docstring
|
||||
__doc__ = "Implementation of the scikit-learn API "\
|
||||
+ "for XGBoost random forest regression.\n\n"\
|
||||
+ '\n'.join(XGBModel.__doc__.split('\n')[2:])
|
||||
|
||||
def __init__(self, max_depth=3, learning_rate=1, n_estimators=100, verbosity=1,
|
||||
objective="reg:linear", n_jobs=1, nthread=None, gamma=0,
|
||||
min_child_weight=1, max_delta_step=0, subsample=0.8, colsample_bytree=1,
|
||||
colsample_bylevel=1, colsample_bynode=0.8, reg_alpha=0, reg_lambda=1,
|
||||
scale_pos_weight=1, base_score=0.5, random_state=0, seed=None,
|
||||
missing=None, **kwargs):
|
||||
super(XGBRFRegressor, self).__init__(
|
||||
max_depth=max_depth, learning_rate=learning_rate, n_estimators=n_estimators,
|
||||
verbosity=verbosity, objective=objective, booster='gbtree',
|
||||
n_jobs=n_jobs, nthread=nthread, gamma=gamma,
|
||||
min_child_weight=min_child_weight, max_delta_step=max_delta_step,
|
||||
subsample=subsample, colsample_bytree=colsample_bytree,
|
||||
colsample_bylevel=colsample_bylevel, colsample_bynode=colsample_bynode,
|
||||
reg_alpha=reg_alpha, reg_lambda=reg_lambda, scale_pos_weight=scale_pos_weight,
|
||||
base_score=base_score, random_state=random_state, seed=seed, missing=missing,
|
||||
**kwargs)
|
||||
|
||||
def get_xgb_params(self):
|
||||
params = super(XGBRFRegressor, self).get_xgb_params()
|
||||
params['num_parallel_tree'] = self.n_estimators
|
||||
return params
|
||||
|
||||
def get_num_boosting_rounds(self):
|
||||
return 1
|
||||
|
||||
|
||||
class XGBRanker(XGBModel):
|
||||
# pylint: disable=missing-docstring,too-many-arguments,invalid-name
|
||||
"""Implementation of the Scikit-Learn API for XGBoost Ranking.
|
||||
@ -881,8 +950,8 @@ class XGBRanker(XGBModel):
|
||||
Boosting learning rate (xgb's "eta")
|
||||
n_estimators : int
|
||||
Number of boosted trees to fit.
|
||||
silent : boolean
|
||||
Whether to print messages while running boosting.
|
||||
verbosity : int
|
||||
The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
|
||||
objective : string
|
||||
Specify the learning task and the corresponding learning objective.
|
||||
The objective name must start with "rank:".
|
||||
@ -903,7 +972,9 @@ class XGBRanker(XGBModel):
|
||||
colsample_bytree : float
|
||||
Subsample ratio of columns when constructing each tree.
|
||||
colsample_bylevel : float
|
||||
Subsample ratio of columns for each split, in each level.
|
||||
Subsample ratio of columns for each level.
|
||||
colsample_bynode : float
|
||||
Subsample ratio of columns for each split.
|
||||
reg_alpha : float (xgb's alpha)
|
||||
L1 regularization term on weights
|
||||
reg_lambda : float (xgb's lambda)
|
||||
@ -966,18 +1037,22 @@ class XGBRanker(XGBModel):
|
||||
"""
|
||||
|
||||
def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
|
||||
silent=True, objective="rank:pairwise", booster='gbtree',
|
||||
verbosity=1, objective="rank:pairwise", booster='gbtree',
|
||||
n_jobs=-1, nthread=None, gamma=0, min_child_weight=1, max_delta_step=0,
|
||||
subsample=1, colsample_bytree=1, colsample_bylevel=1,
|
||||
subsample=1, colsample_bytree=1, colsample_bylevel=1, colsample_bynode=1,
|
||||
reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
|
||||
base_score=0.5, random_state=0, seed=None, missing=None, **kwargs):
|
||||
|
||||
super(XGBRanker, self).__init__(max_depth, learning_rate,
|
||||
n_estimators, silent, objective, booster,
|
||||
n_jobs, nthread, gamma, min_child_weight, max_delta_step,
|
||||
subsample, colsample_bytree, colsample_bylevel,
|
||||
reg_alpha, reg_lambda, scale_pos_weight,
|
||||
base_score, random_state, seed, missing)
|
||||
super(XGBRanker, self).__init__(
|
||||
max_depth=max_depth, learning_rate=learning_rate, n_estimators=n_estimators,
|
||||
verbosity=verbosity, objective=objective, booster=booster,
|
||||
n_jobs=n_jobs, nthread=nthread, gamma=gamma,
|
||||
min_child_weight=min_child_weight, max_delta_step=max_delta_step,
|
||||
subsample=subsample, colsample_bytree=colsample_bytree,
|
||||
colsample_bylevel=colsample_bylevel, colsample_bynode=colsample_bynode,
|
||||
reg_alpha=reg_alpha, reg_lambda=reg_lambda,
|
||||
scale_pos_weight=scale_pos_weight, base_score=base_score,
|
||||
random_state=random_state, seed=seed, missing=missing, **kwargs)
|
||||
if callable(self.objective):
|
||||
raise ValueError("custom objective function not supported by XGBRanker")
|
||||
elif "rank:" not in self.objective:
|
||||
|
||||
@ -29,13 +29,14 @@ def test_binary_classification():
|
||||
y = digits['target']
|
||||
X = digits['data']
|
||||
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
|
||||
for train_index, test_index in kf.split(X, y):
|
||||
xgb_model = xgb.XGBClassifier().fit(X[train_index], y[train_index])
|
||||
preds = xgb_model.predict(X[test_index])
|
||||
labels = y[test_index]
|
||||
err = sum(1 for i in range(len(preds))
|
||||
if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
|
||||
assert err < 0.1
|
||||
for cls in (xgb.XGBClassifier, xgb.XGBRFClassifier):
|
||||
for train_index, test_index in kf.split(X, y):
|
||||
xgb_model = cls(random_state=42).fit(X[train_index], y[train_index])
|
||||
preds = xgb_model.predict(X[test_index])
|
||||
labels = y[test_index]
|
||||
err = sum(1 for i in range(len(preds))
|
||||
if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
|
||||
assert err < 0.1
|
||||
|
||||
|
||||
def test_multiclass_classification():
|
||||
@ -83,8 +84,8 @@ def test_ranking():
|
||||
valid_group = np.repeat(50, 4)
|
||||
x_test = np.random.rand(100, 10)
|
||||
|
||||
params = {'objective': 'rank:pairwise', 'learning_rate': 0.1,
|
||||
'gamma': 1.0, 'min_child_weight': 0.1,
|
||||
params = {'tree_method': 'exact', 'objective': 'rank:pairwise',
|
||||
'learning_rate': 0.1, 'gamma': 1.0, 'min_child_weight': 0.1,
|
||||
'max_depth': 6, 'n_estimators': 4}
|
||||
model = xgb.sklearn.XGBRanker(**params)
|
||||
model.fit(x_train, y_train, train_group,
|
||||
@ -97,7 +98,8 @@ def test_ranking():
|
||||
train_data.set_group(train_group)
|
||||
valid_data.set_group(valid_group)
|
||||
|
||||
params_orig = {'objective': 'rank:pairwise', 'eta': 0.1, 'gamma': 1.0,
|
||||
params_orig = {'tree_method': 'exact', 'objective': 'rank:pairwise',
|
||||
'eta': 0.1, 'gamma': 1.0,
|
||||
'min_child_weight': 0.1, 'max_depth': 6}
|
||||
xgb_model_orig = xgb.train(params_orig, train_data, num_boost_round=4,
|
||||
evals=[(valid_data, 'validation')])
|
||||
@ -113,7 +115,7 @@ def test_feature_importances_weight():
|
||||
y = digits['target']
|
||||
X = digits['data']
|
||||
xgb_model = xgb.XGBClassifier(
|
||||
random_state=0, importance_type="weight").fit(X, y)
|
||||
random_state=0, tree_method="exact", importance_type="weight").fit(X, y)
|
||||
|
||||
exp = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.00833333, 0.,
|
||||
0., 0., 0., 0., 0., 0., 0., 0.025, 0.14166667, 0., 0., 0.,
|
||||
@ -130,11 +132,11 @@ def test_feature_importances_weight():
|
||||
y = pd.Series(digits['target'])
|
||||
X = pd.DataFrame(digits['data'])
|
||||
xgb_model = xgb.XGBClassifier(
|
||||
random_state=0, importance_type="weight").fit(X, y)
|
||||
random_state=0, tree_method="exact", importance_type="weight").fit(X, y)
|
||||
np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
|
||||
|
||||
xgb_model = xgb.XGBClassifier(
|
||||
random_state=0, importance_type="weight").fit(X, y)
|
||||
random_state=0, tree_method="exact", importance_type="weight").fit(X, y)
|
||||
np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
|
||||
|
||||
|
||||
@ -145,7 +147,7 @@ def test_feature_importances_gain():
|
||||
y = digits['target']
|
||||
X = digits['data']
|
||||
xgb_model = xgb.XGBClassifier(
|
||||
random_state=0, importance_type="gain").fit(X, y)
|
||||
random_state=0, tree_method="exact", importance_type="gain").fit(X, y)
|
||||
|
||||
exp = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
|
||||
0.00326159, 0., 0., 0., 0., 0., 0., 0., 0.,
|
||||
@ -163,11 +165,11 @@ def test_feature_importances_gain():
|
||||
y = pd.Series(digits['target'])
|
||||
X = pd.DataFrame(digits['data'])
|
||||
xgb_model = xgb.XGBClassifier(
|
||||
random_state=0, importance_type="gain").fit(X, y)
|
||||
random_state=0, tree_method="exact", importance_type="gain").fit(X, y)
|
||||
np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
|
||||
|
||||
xgb_model = xgb.XGBClassifier(
|
||||
random_state=0, importance_type="gain").fit(X, y)
|
||||
random_state=0, tree_method="exact", importance_type="gain").fit(X, y)
|
||||
np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
|
||||
|
||||
|
||||
@ -199,6 +201,23 @@ def test_boston_housing_regression():
|
||||
assert mean_squared_error(preds4, labels) < 350
|
||||
|
||||
|
||||
def test_boston_housing_rf_regression():
|
||||
from sklearn.metrics import mean_squared_error
|
||||
from sklearn.datasets import load_boston
|
||||
from sklearn.model_selection import KFold
|
||||
|
||||
boston = load_boston()
|
||||
y = boston['target']
|
||||
X = boston['data']
|
||||
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
|
||||
for train_index, test_index in kf.split(X, y):
|
||||
xgb_model = xgb.XGBRFRegressor(random_state=42).fit(
|
||||
X[train_index], y[train_index])
|
||||
preds = xgb_model.predict(X[test_index])
|
||||
labels = y[test_index]
|
||||
assert mean_squared_error(preds, labels) < 35
|
||||
|
||||
|
||||
def test_parameter_tuning():
|
||||
from sklearn.model_selection import GridSearchCV
|
||||
from sklearn.datasets import load_boston
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user