sklearn api for ranking (#3560)

* added xgbranker
* fixed predict method and ranking test
* reformatted code in accordance with pep8
* fixed lint error
* fixed docstring and added checks on objective
* added ranking demo for python
* fixed suffix in rank.py

This commit is contained in:
parent b13c3a8bcc
commit 24a268a2e3
@@ -1,6 +1,6 @@
 Learning to rank
 ====
-XGBoost supports accomplishing ranking tasks. In ranking scenario, data are often grouped and we need the [group information file](../../doc/input_format.md#group-input-format) to specify ranking tasks. The model used in XGBoost for ranking is the LambdaRank, this function is not yet completed. Currently, we provide pairwise rank.
+XGBoost supports ranking tasks. In the ranking scenario, data are often grouped, and we need the [group information file](../../doc/tutorials/input_format.md#group-input-format) to specify the ranking task. The model XGBoost uses for ranking is LambdaRank; this objective is not yet complete, and currently only pairwise rank is provided.
 
 ### Parameters
 The configuration settings are similar to those for regression and binary classification, except that the user needs to specify the ranking objective:
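For orientation, the group file referenced above is plain text with one line per query group; each line gives the number of consecutive instances in the data file that belong to that group. A hypothetical two-query example (3 documents for the first query, 4 for the second):

```
3
4
```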
@@ -15,14 +15,27 @@ For more usage details please refer to the [binary classification demo](../binar
 Instructions
 ====
 The dataset for the ranking demo is from LETOR 4.0 MQ2008 fold 1.
-You can use the following command to run the example:
+Before running the examples, you need to get the data by running:
 
-Get the data:
 ```
 ./wgetdata.sh
 ```
 
+### Command Line
 Run the example:
 ```
 ./runexp.sh
 ```
+
+### Python
+There are two ways of doing ranking in Python.
+
+Run the example using `xgboost.train`:
+```
+python rank.py
+```
+
+Run the example using `XGBRanker`:
+```
+python rank_sklearn.py
+```
demo/rank/rank.py (new file, 41 lines)
@@ -0,0 +1,41 @@
#!/usr/bin/python
import xgboost as xgb
from xgboost import DMatrix
from sklearn.datasets import load_svmlight_file


# This script demonstrates how to do ranking with xgboost.train.
x_train, y_train = load_svmlight_file("mq2008.train")
x_valid, y_valid = load_svmlight_file("mq2008.vali")
x_test, y_test = load_svmlight_file("mq2008.test")

# Each line of a .group file is the size of one query group.
group_train = []
with open("mq2008.train.group", "r") as f:
    data = f.readlines()
    for line in data:
        group_train.append(int(line.split("\n")[0]))

group_valid = []
with open("mq2008.vali.group", "r") as f:
    data = f.readlines()
    for line in data:
        group_valid.append(int(line.split("\n")[0]))

group_test = []
with open("mq2008.test.group", "r") as f:
    data = f.readlines()
    for line in data:
        group_test.append(int(line.split("\n")[0]))

train_dmatrix = DMatrix(x_train, y_train)
valid_dmatrix = DMatrix(x_valid, y_valid)
test_dmatrix = DMatrix(x_test)

# Attach the query-group sizes to the DMatrix objects.
train_dmatrix.set_group(group_train)
valid_dmatrix.set_group(group_valid)

params = {'objective': 'rank:pairwise', 'eta': 0.1, 'gamma': 1.0,
          'min_child_weight': 0.1, 'max_depth': 6}
xgb_model = xgb.train(params, train_dmatrix, num_boost_round=4,
                      evals=[(valid_dmatrix, 'validation')])
pred = xgb_model.predict(test_dmatrix)
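The script reads `mq2008.test.group` but never uses it; the per-document scores in `pred` only become a ranking once they are sorted within each query group. A minimal sketch of that step (not part of the demo), reusing `group_test` and `pred` from above:

```python
import numpy as np

offset = 0
rankings = []
for size in group_test:
    scores = pred[offset:offset + size]
    # higher score means predicted more relevant, so sort descending
    rankings.append(np.argsort(-scores))
    offset += size
```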
demo/rank/rank_sklearn.py (new file, 35 lines)
@@ -0,0 +1,35 @@
#!/usr/bin/python
import xgboost as xgb
from sklearn.datasets import load_svmlight_file


# This script demonstrates how to do ranking with XGBRanker.
x_train, y_train = load_svmlight_file("mq2008.train")
x_valid, y_valid = load_svmlight_file("mq2008.vali")
x_test, y_test = load_svmlight_file("mq2008.test")

# Each line of a .group file is the size of one query group.
group_train = []
with open("mq2008.train.group", "r") as f:
    data = f.readlines()
    for line in data:
        group_train.append(int(line.split("\n")[0]))

group_valid = []
with open("mq2008.vali.group", "r") as f:
    data = f.readlines()
    for line in data:
        group_valid.append(int(line.split("\n")[0]))

group_test = []
with open("mq2008.test.group", "r") as f:
    data = f.readlines()
    for line in data:
        group_test.append(int(line.split("\n")[0]))

params = {'objective': 'rank:pairwise', 'learning_rate': 0.1,
          'gamma': 1.0, 'min_child_weight': 0.1,
          'max_depth': 6, 'n_estimators': 4}
model = xgb.sklearn.XGBRanker(**params)
# XGBRanker takes the raw feature matrices plus the group-size arrays;
# no manual DMatrix construction is needed.
model.fit(x_train, y_train, group_train,
          eval_set=[(x_valid, y_valid)], eval_group=[group_valid])
pred = model.predict(x_test)
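Since `XGBRanker.fit` also accepts `eval_metric` and `early_stopping_rounds` (see the class added below), an illustrative variant of the call above (not part of the demo) could enable early stopping on the validation query groups; `'ndcg'` is one of XGBoost's built-in ranking metrics:

```python
model = xgb.sklearn.XGBRanker(**params)
model.fit(x_train, y_train, group_train,
          eval_set=[(x_valid, y_valid)], eval_group=[group_valid],
          eval_metric='ndcg', early_stopping_rounds=2)
pred = model.predict(x_test)
```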
@@ -1 +1 @@
-Subproject commit f2afdc7788ee8ed6fd06cc095b6838d4ce61bb5a
+Subproject commit 459ab734d15acd68fd437abf845c7c1730b5a38f
@ -101,7 +101,7 @@ class XGBModel(XGBModelBase):
|
|||||||
None, defaults to np.nan.
|
None, defaults to np.nan.
|
||||||
**kwargs : dict, optional
|
**kwargs : dict, optional
|
||||||
Keyword arguments for XGBoost Booster object. Full documentation of parameters can
|
Keyword arguments for XGBoost Booster object. Full documentation of parameters can
|
||||||
be found here: https://github.com/dmlc/xgboost/blob/master/doc/parameter.md.
|
be found here: https://github.com/dmlc/xgboost/blob/master/doc/parameter.rst.
|
||||||
Attempting to set a parameter via the constructor args and **kwargs dict simultaneously
|
Attempting to set a parameter via the constructor args and **kwargs dict simultaneously
|
||||||
will result in a TypeError.
|
will result in a TypeError.
|
||||||
Note:
|
Note:
|
||||||
@ -259,7 +259,7 @@ class XGBModel(XGBModelBase):
|
|||||||
instance weights on the i-th validation set.
|
instance weights on the i-th validation set.
|
||||||
eval_metric : str, callable, optional
|
eval_metric : str, callable, optional
|
||||||
If a str, should be a built-in evaluation metric to use. See
|
If a str, should be a built-in evaluation metric to use. See
|
||||||
doc/parameter.md. If callable, a custom evaluation metric. The call
|
doc/parameter.rst. If callable, a custom evaluation metric. The call
|
||||||
signature is func(y_predicted, y_true) where y_true will be a
|
signature is func(y_predicted, y_true) where y_true will be a
|
||||||
DMatrix object such that you may need to call the get_label
|
DMatrix object such that you may need to call the get_label
|
||||||
method. It must return a str, value pair where the str is a name
|
method. It must return a str, value pair where the str is a name
|
||||||
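The callable contract described in this docstring (raw predictions plus an evaluation DMatrix, returning a name/value pair that is minimized) is easy to get wrong, so here is a minimal sketch; the metric and its name are illustrative, not part of this patch:

```python
import numpy as np

def mae(y_predicted, y_true):
    # y_true is a DMatrix, so the labels must be pulled out with get_label()
    labels = y_true.get_label()
    # return a (name, value) pair; xgboost treats lower values as better
    return 'mae', float(np.mean(np.abs(y_predicted - labels)))

# usage sketch: model.fit(X, y, eval_set=[(X_val, y_val)], eval_metric=mae)
```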
@ -465,7 +465,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
instance weights on the i-th validation set.
|
instance weights on the i-th validation set.
|
||||||
eval_metric : str, callable, optional
|
eval_metric : str, callable, optional
|
||||||
If a str, should be a built-in evaluation metric to use. See
|
If a str, should be a built-in evaluation metric to use. See
|
||||||
doc/parameter.md. If callable, a custom evaluation metric. The call
|
doc/parameter.rst. If callable, a custom evaluation metric. The call
|
||||||
signature is func(y_predicted, y_true) where y_true will be a
|
signature is func(y_predicted, y_true) where y_true will be a
|
||||||
DMatrix object such that you may need to call the get_label
|
DMatrix object such that you may need to call the get_label
|
||||||
method. It must return a str, value pair where the str is a name
|
method. It must return a str, value pair where the str is a name
|
||||||
@ -679,3 +679,232 @@ class XGBRegressor(XGBModel, XGBRegressorBase):
|
|||||||
# pylint: disable=missing-docstring
|
# pylint: disable=missing-docstring
|
||||||
__doc__ = """Implementation of the scikit-learn API for XGBoost regression.
|
__doc__ = """Implementation of the scikit-learn API for XGBoost regression.
|
||||||
""" + '\n'.join(XGBModel.__doc__.split('\n')[2:])
|
""" + '\n'.join(XGBModel.__doc__.split('\n')[2:])
|
||||||
|
|
||||||
|
|
||||||
|
class XGBRanker(XGBModel):
|
||||||
|
# pylint: disable=missing-docstring,too-many-arguments,invalid-name
|
||||||
|
"""Implementation of the Scikit-Learn API for XGBoost Ranking.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
max_depth : int
|
||||||
|
Maximum tree depth for base learners.
|
||||||
|
learning_rate : float
|
||||||
|
Boosting learning rate (xgb's "eta")
|
||||||
|
n_estimators : int
|
||||||
|
Number of boosted trees to fit.
|
||||||
|
silent : boolean
|
||||||
|
Whether to print messages while running boosting.
|
||||||
|
objective : string
|
||||||
|
Specify the learning task and the corresponding learning objective.
|
||||||
|
Only "rank:pairwise" is supported currently.
|
||||||
|
booster: string
|
||||||
|
Specify which booster to use: gbtree, gblinear or dart.
|
||||||
|
nthread : int
|
||||||
|
Number of parallel threads used to run xgboost. (Deprecated, please use n_jobs)
|
||||||
|
n_jobs : int
|
||||||
|
Number of parallel threads used to run xgboost. (replaces nthread)
|
||||||
|
gamma : float
|
||||||
|
Minimum loss reduction required to make a further partition on a leaf node of the tree.
|
||||||
|
min_child_weight : int
|
||||||
|
Minimum sum of instance weight(hessian) needed in a child.
|
||||||
|
max_delta_step : int
|
||||||
|
Maximum delta step we allow each tree's weight estimation to be.
|
||||||
|
subsample : float
|
||||||
|
Subsample ratio of the training instance.
|
||||||
|
colsample_bytree : float
|
||||||
|
Subsample ratio of columns when constructing each tree.
|
||||||
|
colsample_bylevel : float
|
||||||
|
Subsample ratio of columns for each split, in each level.
|
||||||
|
reg_alpha : float (xgb's alpha)
|
||||||
|
L1 regularization term on weights
|
||||||
|
reg_lambda : float (xgb's lambda)
|
||||||
|
L2 regularization term on weights
|
||||||
|
scale_pos_weight : float
|
||||||
|
Balancing of positive and negative weights.
|
||||||
|
base_score:
|
||||||
|
The initial prediction score of all instances, global bias.
|
||||||
|
seed : int
|
||||||
|
Random number seed. (Deprecated, please use random_state)
|
||||||
|
random_state : int
|
||||||
|
Random number seed. (replaces seed)
|
||||||
|
missing : float, optional
|
||||||
|
Value in the data which needs to be present as a missing value. If
|
||||||
|
None, defaults to np.nan.
|
||||||
|
**kwargs : dict, optional
|
||||||
|
Keyword arguments for XGBoost Booster object. Full documentation of parameters can
|
||||||
|
be found here: https://github.com/dmlc/xgboost/blob/master/doc/parameter.rst.
|
||||||
|
Attempting to set a parameter via the constructor args and **kwargs dict simultaneously
|
||||||
|
will result in a TypeError.
|
||||||
|
Note:
|
||||||
|
**kwargs is unsupported by Sklearn. We do not guarantee that parameters passed via
|
||||||
|
this argument will interact properly with Sklearn.
|
||||||
|
|
||||||
|
Note
|
||||||
|
----
|
||||||
|
A custom objective function is currently not supported by XGBRanker.
|
||||||
|
|
||||||
|
Group information is required for ranking tasks. Before fitting the model, your data need to
|
||||||
|
be sorted by group. When fitting the model, you need to provide an additional array that
|
||||||
|
contains the size of each group.
|
||||||
|
|
||||||
|
For example, if your original data look like:
|
||||||
|
|
||||||
|
| qid | label | features |
|
||||||
|
| 1 | 0 | x_1 |
|
||||||
|
| 1 | 1 | x_2 |
|
||||||
|
| 1 | 0 | x_3 |
|
||||||
|
| 2 | 0 | x_4 |
|
||||||
|
| 2 | 1 | x_5 |
|
||||||
|
| 2 | 1 | x_6 |
|
||||||
|
| 2 | 1 | x_7 |
|
||||||
|
|
||||||
|
then your group array should be [3, 4].
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
|
||||||
|
silent=True, objective="rank:pairwise", booster='gbtree',
|
||||||
|
n_jobs=-1, nthread=None, gamma=0, min_child_weight=1, max_delta_step=0,
|
||||||
|
subsample=1, colsample_bytree=1, colsample_bylevel=1,
|
||||||
|
reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
|
||||||
|
base_score=0.5, random_state=0, seed=None, missing=None, **kwargs):
|
||||||
|
|
||||||
|
super(XGBRanker, self).__init__(max_depth, learning_rate,
|
||||||
|
n_estimators, silent, objective, booster,
|
||||||
|
n_jobs, nthread, gamma, min_child_weight, max_delta_step,
|
||||||
|
subsample, colsample_bytree, colsample_bylevel,
|
||||||
|
reg_alpha, reg_lambda, scale_pos_weight,
|
||||||
|
base_score, random_state, seed, missing)
|
||||||
|
if callable(self.objective):
|
||||||
|
raise ValueError("custom objective function not supported by XGBRanker")
|
||||||
|
elif self.objective != "rank:pairwise":
|
||||||
|
raise ValueError("please use XGBRanker for ranking task")
|
||||||
|
|
||||||
|
def fit(self, X, y, group, sample_weight=None, eval_set=None, sample_weight_eval_set=None,
|
||||||
|
eval_group=None, eval_metric=None, early_stopping_rounds=None,
|
||||||
|
verbose=False, xgb_model=None):
|
||||||
|
# pylint: disable = attribute-defined-outside-init,arguments-differ
|
||||||
|
"""
|
||||||
|
Fit the gradient boosting model
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
X : array_like
|
||||||
|
Feature matrix
|
||||||
|
y : array_like
|
||||||
|
Labels
|
||||||
|
group : array_like
|
||||||
|
group size of training data
|
||||||
|
sample_weight : array_like
|
||||||
|
instance weights
|
||||||
|
eval_set : list, optional
|
||||||
|
A list of (X, y) tuple pairs to use as a validation set for
|
||||||
|
early-stopping
|
||||||
|
sample_weight_eval_set : list, optional
|
||||||
|
A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
|
||||||
|
instance weights on the i-th validation set.
|
||||||
|
eval_group : list of arrays, optional
|
||||||
|
A list that contains the group size corresponds to each
|
||||||
|
(X, y) pair in eval_set
|
||||||
|
eval_metric : str, callable, optional
|
||||||
|
If a str, should be a built-in evaluation metric to use. See
|
||||||
|
doc/parameter.rst. If callable, a custom evaluation metric. The call
|
||||||
|
signature is func(y_predicted, y_true) where y_true will be a
|
||||||
|
DMatrix object such that you may need to call the get_label
|
||||||
|
method. It must return a str, value pair where the str is a name
|
||||||
|
for the evaluation and value is the value of the evaluation
|
||||||
|
function. This objective is always minimized.
|
||||||
|
early_stopping_rounds : int
|
||||||
|
Activates early stopping. Validation error needs to decrease at
|
||||||
|
least every <early_stopping_rounds> round(s) to continue training.
|
||||||
|
Requires at least one item in evals. If there's more than one,
|
||||||
|
will use the last. Returns the model from the last iteration
|
||||||
|
(not the best one). If early stopping occurs, the model will
|
||||||
|
have three additional fields: bst.best_score, bst.best_iteration
|
||||||
|
and bst.best_ntree_limit.
|
||||||
|
(Use bst.best_ntree_limit to get the correct value if num_parallel_tree
|
||||||
|
and/or num_class appears in the parameters)
|
||||||
|
verbose : bool
|
||||||
|
If `verbose` and an evaluation set is used, writes the evaluation
|
||||||
|
metric measured on the validation set to stderr.
|
||||||
|
xgb_model : str
|
||||||
|
file name of stored xgb model or 'Booster' instance Xgb model to be
|
||||||
|
loaded before training (allows training continuation).
|
||||||
|
"""
|
||||||
|
# check if group information is provided
|
||||||
|
if group is None:
|
||||||
|
raise ValueError("group is required for ranking task")
|
||||||
|
|
||||||
|
if eval_set is not None:
|
||||||
|
if eval_group is None:
|
||||||
|
raise ValueError("eval_group is required if eval_set is not None")
|
||||||
|
elif len(eval_group) != len(eval_set):
|
||||||
|
raise ValueError("length of eval_group should match that of eval_set")
|
||||||
|
elif any(group is None for group in eval_group):
|
||||||
|
raise ValueError("group is required for all eval datasets for ranking task")
|
||||||
|
|
||||||
|
def _dmat_init(group, **params):
|
||||||
|
ret = DMatrix(**params)
|
||||||
|
ret.set_group(group)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
if sample_weight is not None:
|
||||||
|
train_dmatrix = _dmat_init(group, data=X, label=y, weight=sample_weight,
|
||||||
|
missing=self.missing, nthread=self.n_jobs)
|
||||||
|
else:
|
||||||
|
train_dmatrix = _dmat_init(group, data=X, label=y,
|
||||||
|
missing=self.missing, nthread=self.n_jobs)
|
||||||
|
|
||||||
|
evals_result = {}
|
||||||
|
|
||||||
|
if eval_set is not None:
|
||||||
|
if sample_weight_eval_set is None:
|
||||||
|
sample_weight_eval_set = [None] * len(eval_set)
|
||||||
|
evals = [_dmat_init(eval_group[i], data=eval_set[i][0], label=eval_set[i][1],
|
||||||
|
missing=self.missing, weight=sample_weight_eval_set[i],
|
||||||
|
nthread=self.n_jobs) for i in range(len(eval_set))]
|
||||||
|
nevals = len(evals)
|
||||||
|
eval_names = ["eval_{}".format(i) for i in range(nevals)]
|
||||||
|
evals = list(zip(evals, eval_names))
|
||||||
|
else:
|
||||||
|
evals = ()
|
||||||
|
|
||||||
|
params = self.get_xgb_params()
|
||||||
|
|
||||||
|
feval = eval_metric if callable(eval_metric) else None
|
||||||
|
if eval_metric is not None:
|
||||||
|
if callable(eval_metric):
|
||||||
|
eval_metric = None
|
||||||
|
else:
|
||||||
|
params.update({'eval_metric': eval_metric})
|
||||||
|
|
||||||
|
self._Booster = train(params, train_dmatrix,
|
||||||
|
self.n_estimators,
|
||||||
|
early_stopping_rounds=early_stopping_rounds, evals=evals,
|
||||||
|
evals_result=evals_result, feval=feval,
|
||||||
|
verbose_eval=verbose, xgb_model=xgb_model)
|
||||||
|
|
||||||
|
self.objective = params["objective"]
|
||||||
|
|
||||||
|
if evals_result:
|
||||||
|
for val in evals_result.items():
|
||||||
|
evals_result_key = list(val[1].keys())[0]
|
||||||
|
evals_result[val[0]][evals_result_key] = val[1][evals_result_key]
|
||||||
|
self.evals_result = evals_result
|
||||||
|
|
||||||
|
if early_stopping_rounds is not None:
|
||||||
|
self.best_score = self._Booster.best_score
|
||||||
|
self.best_iteration = self._Booster.best_iteration
|
||||||
|
self.best_ntree_limit = self._Booster.best_ntree_limit
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def predict(self, data, output_margin=False, ntree_limit=0):
|
||||||
|
|
||||||
|
test_dmatrix = DMatrix(data, missing=self.missing)
|
||||||
|
if ntree_limit is None:
|
||||||
|
ntree_limit = getattr(self, "best_ntree_limit", 0)
|
||||||
|
|
||||||
|
return self.get_booster().predict(test_dmatrix,
|
||||||
|
output_margin=output_margin,
|
||||||
|
ntree_limit=ntree_limit)
|
||||||
|
@@ -77,6 +77,40 @@ def test_multiclass_classification():
     check_pred(preds4, labels)


def test_ranking():
    tm._skip_if_no_sklearn()
    # generate random data: 1000 training rows in 20 groups of 50,
    # 200 validation rows in 4 groups of 50
    x_train = np.random.rand(1000, 10)
    y_train = np.random.randint(5, size=1000)
    train_group = np.repeat(50, 20)
    x_valid = np.random.rand(200, 10)
    y_valid = np.random.randint(5, size=200)
    valid_group = np.repeat(50, 4)
    x_test = np.random.rand(100, 10)

    # train via the sklearn wrapper
    params = {'objective': 'rank:pairwise', 'learning_rate': 0.1,
              'gamma': 1.0, 'min_child_weight': 0.1,
              'max_depth': 6, 'n_estimators': 4}
    model = xgb.sklearn.XGBRanker(**params)
    model.fit(x_train, y_train, train_group,
              eval_set=[(x_valid, y_valid)], eval_group=[valid_group])
    pred = model.predict(x_test)

    # train the same model via the native API
    train_data = xgb.DMatrix(x_train, y_train)
    valid_data = xgb.DMatrix(x_valid, y_valid)
    test_data = xgb.DMatrix(x_test)
    train_data.set_group(train_group)
    valid_data.set_group(valid_group)

    params_orig = {'objective': 'rank:pairwise', 'eta': 0.1, 'gamma': 1.0,
                   'min_child_weight': 0.1, 'max_depth': 6}
    xgb_model_orig = xgb.train(params_orig, train_data, num_boost_round=4,
                               evals=[(valid_data, 'validation')])
    pred_orig = xgb_model_orig.predict(test_data)

    # the two APIs should produce identical predictions
    np.testing.assert_almost_equal(pred, pred_orig)


def test_feature_importances():
    tm._skip_if_no_sklearn()
    from sklearn.datasets import load_digits