sklearn api for ranking (#3560)
* added xgbranker * fixed predict method and ranking test * reformatted code in accordance with pep8 * fixed lint error * fixed docstring and added checks on objective * added ranking demo for python * fixed suffix in rank.py
This commit is contained in:
parent
b13c3a8bcc
commit
24a268a2e3
@ -1,6 +1,6 @@
|
||||
Learning to rank
|
||||
====
|
||||
XGBoost supports accomplishing ranking tasks. In ranking scenario, data are often grouped and we need the [group information file](../../doc/input_format.md#group-input-format) to specify ranking tasks. The model used in XGBoost for ranking is the LambdaRank, this function is not yet completed. Currently, we provide pairwise rank.
|
||||
XGBoost supports ranking tasks. In a ranking scenario, data are often grouped, and we need the [group information file](../../doc/tutorials/input_format.md#group-input-format) to specify ranking tasks. The model used in XGBoost for ranking is LambdaRank; this feature is not yet complete. Currently, we provide pairwise ranking.
|
||||
|
||||
### Parameters
|
||||
The configuration setting is similar to the regression and binary classification settings, except that the user needs to specify the objective:
|
||||
@ -15,14 +15,27 @@ For more usage details please refer to the [binary classification demo](../binar
|
||||
Instructions
|
||||
====
|
||||
The dataset for ranking demo is from LETOR04 MQ2008 fold1.
|
||||
You can use the following command to run the example:
|
||||
Before running the examples, you need to get the data by running:
|
||||
|
||||
Get the data:
|
||||
```
|
||||
./wgetdata.sh
|
||||
```
|
||||
|
||||
### Command Line
|
||||
Run the example:
|
||||
```
|
||||
./runexp.sh
|
||||
```
|
||||
|
||||
### Python
|
||||
There are two ways of doing ranking in Python.
|
||||
|
||||
Run the example using `xgboost.train`:
|
||||
```
|
||||
python rank.py
|
||||
```
|
||||
|
||||
Run the example using `XGBRanker`:
|
||||
```
|
||||
python rank_sklearn.py
|
||||
```
|
||||
|
||||
41
demo/rank/rank.py
Normal file
41
demo/rank/rank.py
Normal file
@ -0,0 +1,41 @@
|
||||
#!/usr/bin/python
|
||||
import xgboost as xgb
|
||||
from xgboost import DMatrix
|
||||
from sklearn.datasets import load_svmlight_file
|
||||
|
||||
|
||||
# This script demonstrates how to do ranking with xgboost.train


def read_group_sizes(path):
    """Return the group sizes stored in *path*, one integer per line.

    Blank lines (for instance a trailing empty line at the end of the file)
    are skipped instead of aborting the run with a ValueError.
    """
    with open(path, "r") as group_file:
        return [int(line) for line in group_file if line.strip()]


x_train, y_train = load_svmlight_file("mq2008.train")
x_valid, y_valid = load_svmlight_file("mq2008.vali")
x_test, y_test = load_svmlight_file("mq2008.test")

# Each *.group file lists, in row order, how many consecutive rows belong to
# the same query.
group_train = read_group_sizes("mq2008.train.group")
group_valid = read_group_sizes("mq2008.vali.group")
# Loaded for completeness; this demo does not attach it to the test DMatrix.
group_test = read_group_sizes("mq2008.test.group")

train_dmatrix = DMatrix(x_train, y_train)
valid_dmatrix = DMatrix(x_valid, y_valid)
test_dmatrix = DMatrix(x_test)

# Attach the query-group boundaries to the training and validation data.
train_dmatrix.set_group(group_train)
valid_dmatrix.set_group(group_valid)

params = {'objective': 'rank:pairwise', 'eta': 0.1, 'gamma': 1.0,
          'min_child_weight': 0.1, 'max_depth': 6}
xgb_model = xgb.train(params, train_dmatrix, num_boost_round=4,
                      evals=[(valid_dmatrix, 'validation')])
pred = xgb_model.predict(test_dmatrix)
|
||||
35
demo/rank/rank_sklearn.py
Normal file
35
demo/rank/rank_sklearn.py
Normal file
@ -0,0 +1,35 @@
|
||||
#!/usr/bin/python
|
||||
import xgboost as xgb
|
||||
from sklearn.datasets import load_svmlight_file
|
||||
|
||||
|
||||
# This script demonstrates how to do ranking with XGBRanker


def read_group_sizes(path):
    """Return the group sizes stored in *path*, one integer per line.

    Blank lines (for instance a trailing empty line at the end of the file)
    are skipped instead of aborting the run with a ValueError.
    """
    with open(path, "r") as group_file:
        return [int(line) for line in group_file if line.strip()]


x_train, y_train = load_svmlight_file("mq2008.train")
x_valid, y_valid = load_svmlight_file("mq2008.vali")
x_test, y_test = load_svmlight_file("mq2008.test")

# Each *.group file lists, in row order, how many consecutive rows belong to
# the same query.
group_train = read_group_sizes("mq2008.train.group")
group_valid = read_group_sizes("mq2008.vali.group")
# Loaded for completeness; predict() below is called without group data.
group_test = read_group_sizes("mq2008.test.group")

params = {'objective': 'rank:pairwise', 'learning_rate': 0.1,
          'gamma': 1.0, 'min_child_weight': 0.1,
          'max_depth': 6, 'n_estimators': 4}
model = xgb.sklearn.XGBRanker(**params)
# eval_group must hold one group-size array per (X, y) pair in eval_set.
model.fit(x_train, y_train, group_train,
          eval_set=[(x_valid, y_valid)], eval_group=[group_valid])
pred = model.predict(x_test)
|
||||
@ -1 +1 @@
|
||||
Subproject commit f2afdc7788ee8ed6fd06cc095b6838d4ce61bb5a
|
||||
Subproject commit 459ab734d15acd68fd437abf845c7c1730b5a38f
|
||||
@ -101,7 +101,7 @@ class XGBModel(XGBModelBase):
|
||||
None, defaults to np.nan.
|
||||
**kwargs : dict, optional
|
||||
Keyword arguments for XGBoost Booster object. Full documentation of parameters can
|
||||
be found here: https://github.com/dmlc/xgboost/blob/master/doc/parameter.md.
|
||||
be found here: https://github.com/dmlc/xgboost/blob/master/doc/parameter.rst.
|
||||
Attempting to set a parameter via the constructor args and **kwargs dict simultaneously
|
||||
will result in a TypeError.
|
||||
Note:
|
||||
@ -259,7 +259,7 @@ class XGBModel(XGBModelBase):
|
||||
instance weights on the i-th validation set.
|
||||
eval_metric : str, callable, optional
|
||||
If a str, should be a built-in evaluation metric to use. See
|
||||
doc/parameter.md. If callable, a custom evaluation metric. The call
|
||||
doc/parameter.rst. If callable, a custom evaluation metric. The call
|
||||
signature is func(y_predicted, y_true) where y_true will be a
|
||||
DMatrix object such that you may need to call the get_label
|
||||
method. It must return a str, value pair where the str is a name
|
||||
@ -465,7 +465,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
instance weights on the i-th validation set.
|
||||
eval_metric : str, callable, optional
|
||||
If a str, should be a built-in evaluation metric to use. See
|
||||
doc/parameter.md. If callable, a custom evaluation metric. The call
|
||||
doc/parameter.rst. If callable, a custom evaluation metric. The call
|
||||
signature is func(y_predicted, y_true) where y_true will be a
|
||||
DMatrix object such that you may need to call the get_label
|
||||
method. It must return a str, value pair where the str is a name
|
||||
@ -679,3 +679,232 @@ class XGBRegressor(XGBModel, XGBRegressorBase):
|
||||
# pylint: disable=missing-docstring
|
||||
__doc__ = """Implementation of the scikit-learn API for XGBoost regression.
|
||||
""" + '\n'.join(XGBModel.__doc__.split('\n')[2:])
|
||||
|
||||
|
||||
class XGBRanker(XGBModel):
    # pylint: disable=missing-docstring,too-many-arguments,invalid-name
    """Implementation of the Scikit-Learn API for XGBoost Ranking.

    Parameters
    ----------
    max_depth : int
        Maximum tree depth for base learners.
    learning_rate : float
        Boosting learning rate (xgb's "eta")
    n_estimators : int
        Number of boosted trees to fit.
    silent : boolean
        Whether to print messages while running boosting.
    objective : string
        Specify the learning task and the corresponding learning objective.
        Only "rank:pairwise" is supported currently.
    booster : string
        Specify which booster to use: gbtree, gblinear or dart.
    nthread : int
        Number of parallel threads used to run xgboost.  (Deprecated, please use n_jobs)
    n_jobs : int
        Number of parallel threads used to run xgboost.  (replaces nthread)
    gamma : float
        Minimum loss reduction required to make a further partition on a leaf node of the tree.
    min_child_weight : int
        Minimum sum of instance weight(hessian) needed in a child.
    max_delta_step : int
        Maximum delta step we allow each tree's weight estimation to be.
    subsample : float
        Subsample ratio of the training instance.
    colsample_bytree : float
        Subsample ratio of columns when constructing each tree.
    colsample_bylevel : float
        Subsample ratio of columns for each split, in each level.
    reg_alpha : float (xgb's alpha)
        L1 regularization term on weights
    reg_lambda : float (xgb's lambda)
        L2 regularization term on weights
    scale_pos_weight : float
        Balancing of positive and negative weights.
    base_score : float
        The initial prediction score of all instances, global bias.
    seed : int
        Random number seed.  (Deprecated, please use random_state)
    random_state : int
        Random number seed.  (replaces seed)
    missing : float, optional
        Value in the data which needs to be present as a missing value. If
        None, defaults to np.nan.
    **kwargs : dict, optional
        Keyword arguments for XGBoost Booster object.  Full documentation of parameters can
        be found here: https://github.com/dmlc/xgboost/blob/master/doc/parameter.rst.
        Attempting to set a parameter via the constructor args and **kwargs dict simultaneously
        will result in a TypeError.
        Note:
            **kwargs is unsupported by Sklearn.  We do not guarantee that parameters passed via
            this argument will interact properly with Sklearn.

    Note
    ----
    A custom objective function is currently not supported by XGBRanker.

    Group information is required for ranking tasks. Before fitting the model, your data need to
    be sorted by group. When fitting the model, you need to provide an additional array that
    contains the size of each group.

    For example, if your original data look like:

    | qid | label | features |
    | 1   | 0     | x_1      |
    | 1   | 1     | x_2      |
    | 1   | 0     | x_3      |
    | 2   | 0     | x_4      |
    | 2   | 1     | x_5      |
    | 2   | 1     | x_6      |
    | 2   | 1     | x_7      |

    then your group array should be [3, 4].
    """

    def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
                 silent=True, objective="rank:pairwise", booster='gbtree',
                 n_jobs=-1, nthread=None, gamma=0, min_child_weight=1, max_delta_step=0,
                 subsample=1, colsample_bytree=1, colsample_bylevel=1,
                 reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
                 base_score=0.5, random_state=0, seed=None, missing=None, **kwargs):
        """Store the hyper-parameters and reject unsupported objectives.

        Raises
        ------
        ValueError
            If ``objective`` is callable or is anything other than
            "rank:pairwise".
        """
        # Forward **kwargs as well: they are accepted by this signature and
        # would otherwise be silently discarded.
        super(XGBRanker, self).__init__(max_depth, learning_rate,
                                        n_estimators, silent, objective, booster,
                                        n_jobs, nthread, gamma, min_child_weight, max_delta_step,
                                        subsample, colsample_bytree, colsample_bylevel,
                                        reg_alpha, reg_lambda, scale_pos_weight,
                                        base_score, random_state, seed, missing, **kwargs)
        if callable(self.objective):
            raise ValueError("custom objective function not supported by XGBRanker")
        if self.objective != "rank:pairwise":
            raise ValueError("only the 'rank:pairwise' objective is supported by XGBRanker")

    def fit(self, X, y, group, sample_weight=None, eval_set=None, sample_weight_eval_set=None,
            eval_group=None, eval_metric=None, early_stopping_rounds=None,
            verbose=False, xgb_model=None):
        # pylint: disable = attribute-defined-outside-init,arguments-differ
        """
        Fit the gradient boosting model

        Parameters
        ----------
        X : array_like
            Feature matrix
        y : array_like
            Labels
        group : array_like
            group size of training data
        sample_weight : array_like
            instance weights
        eval_set : list, optional
            A list of (X, y) tuple pairs to use as a validation set for
            early-stopping
        sample_weight_eval_set : list, optional
            A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
            instance weights on the i-th validation set.
        eval_group : list of arrays, optional
            A list that contains the group size corresponding to each
            (X, y) pair in eval_set. Required whenever eval_set is given.
        eval_metric : str, callable, optional
            If a str, should be a built-in evaluation metric to use. See
            doc/parameter.rst. If callable, a custom evaluation metric. The call
            signature is func(y_predicted, y_true) where y_true will be a
            DMatrix object such that you may need to call the get_label
            method. It must return a str, value pair where the str is a name
            for the evaluation and value is the value of the evaluation
            function. This objective is always minimized.
        early_stopping_rounds : int
            Activates early stopping. Validation error needs to decrease at
            least every <early_stopping_rounds> round(s) to continue training.
            Requires at least one item in eval_set. If there's more than one,
            will use the last. Returns the model from the last iteration
            (not the best one). If early stopping occurs, the model will
            have three additional fields: bst.best_score, bst.best_iteration
            and bst.best_ntree_limit.
            (Use bst.best_ntree_limit to get the correct value if num_parallel_tree
            and/or num_class appears in the parameters)
        verbose : bool
            If `verbose` and an evaluation set is used, writes the evaluation
            metric measured on the validation set to stderr.
        xgb_model : str
            file name of stored xgb model or 'Booster' instance Xgb model to be
            loaded before training (allows training continuation).

        Returns
        -------
        self : XGBRanker
            The fitted estimator.

        Raises
        ------
        ValueError
            If ``group`` is None, or if ``eval_set`` is given without a
            matching ``eval_group`` entry for every evaluation dataset.
        """
        # check if group information is provided
        if group is None:
            raise ValueError("group is required for ranking task")

        if eval_set is not None:
            if eval_group is None:
                raise ValueError("eval_group is required if eval_set is not None")
            if len(eval_group) != len(eval_set):
                raise ValueError("length of eval_group should match that of eval_set")
            if any(eval_grp is None for eval_grp in eval_group):
                raise ValueError("group is required for all eval datasets for ranking task")

        def _dmat_init(group, **params):
            # Build a DMatrix and attach the query-group sizes to it.
            ret = DMatrix(**params)
            ret.set_group(group)
            return ret

        # weight=None is what the evaluation matrices below pass when no
        # weights are supplied, so a single call covers both cases.
        train_dmatrix = _dmat_init(group, data=X, label=y, weight=sample_weight,
                                   missing=self.missing, nthread=self.n_jobs)

        evals_result = {}

        if eval_set is not None:
            if sample_weight_eval_set is None:
                sample_weight_eval_set = [None] * len(eval_set)
            evals = []
            for i, eval_pair in enumerate(eval_set):
                eval_dmatrix = _dmat_init(eval_group[i], data=eval_pair[0], label=eval_pair[1],
                                          missing=self.missing,
                                          weight=sample_weight_eval_set[i],
                                          nthread=self.n_jobs)
                evals.append((eval_dmatrix, "eval_{}".format(i)))
        else:
            evals = ()

        params = self.get_xgb_params()

        # A callable metric is handed to train() as feval; a string metric is
        # passed through the booster parameters instead.
        feval = eval_metric if callable(eval_metric) else None
        if eval_metric is not None and feval is None:
            params.update({'eval_metric': eval_metric})

        self._Booster = train(params, train_dmatrix,
                              self.n_estimators,
                              early_stopping_rounds=early_stopping_rounds, evals=evals,
                              evals_result=evals_result, feval=feval,
                              verbose_eval=verbose, xgb_model=xgb_model)

        self.objective = params["objective"]

        # Only expose evaluation history when train() actually recorded some.
        if evals_result:
            self.evals_result = evals_result

        if early_stopping_rounds is not None:
            self.best_score = self._Booster.best_score
            self.best_iteration = self._Booster.best_iteration
            self.best_ntree_limit = self._Booster.best_ntree_limit

        return self

    def predict(self, data, output_margin=False, ntree_limit=0):
        """Predict ranking scores for ``data`` with the fitted booster.

        Parameters
        ----------
        data : array_like
            Feature matrix to score; it is wrapped in a DMatrix using
            ``self.missing`` as the missing-value marker.
        output_margin : bool
            Forwarded unchanged to the booster's ``predict``.
        ntree_limit : int
            Forwarded to the booster's ``predict``. If None is passed
            explicitly, ``best_ntree_limit`` is used when the attribute
            exists, otherwise 0.

        Returns
        -------
        The value returned by the booster's ``predict`` call.
        """
        # Same thread setting as the matrices built in fit().
        test_dmatrix = DMatrix(data, missing=self.missing, nthread=self.n_jobs)
        if ntree_limit is None:
            ntree_limit = getattr(self, "best_ntree_limit", 0)

        return self.get_booster().predict(test_dmatrix,
                                          output_margin=output_margin,
                                          ntree_limit=ntree_limit)
|
||||
|
||||
@ -77,6 +77,40 @@ def test_multiclass_classification():
|
||||
check_pred(preds4, labels)
|
||||
|
||||
|
||||
def test_ranking():
    """XGBRanker must reproduce the predictions of an equivalent xgb.train run."""
    tm._skip_if_no_sklearn()

    # generate random data; groups are 50 rows each
    n_features = 10
    x_train = np.random.rand(1000, n_features)
    y_train = np.random.randint(5, size=1000)
    train_group = np.repeat(50, 20)
    x_valid = np.random.rand(200, n_features)
    y_valid = np.random.randint(5, size=200)
    valid_group = np.repeat(50, 4)
    x_test = np.random.rand(100, n_features)

    # hyper-parameters common to both APIs
    shared = {'objective': 'rank:pairwise', 'gamma': 1.0,
              'min_child_weight': 0.1, 'max_depth': 6}
    n_rounds = 4

    # scikit-learn style wrapper
    params = dict(shared, learning_rate=0.1, n_estimators=n_rounds)
    model = xgb.sklearn.XGBRanker(**params)
    model.fit(x_train, y_train, train_group,
              eval_set=[(x_valid, y_valid)], eval_group=[valid_group])
    pred = model.predict(x_test)

    # native training API on the same data
    train_data = xgb.DMatrix(x_train, y_train)
    valid_data = xgb.DMatrix(x_valid, y_valid)
    test_data = xgb.DMatrix(x_test)
    train_data.set_group(train_group)
    valid_data.set_group(valid_group)

    params_orig = dict(shared, eta=0.1)
    xgb_model_orig = xgb.train(params_orig, train_data, num_boost_round=n_rounds,
                               evals=[(valid_data, 'validation')])
    pred_orig = xgb_model_orig.predict(test_data)

    np.testing.assert_almost_equal(pred, pred_orig)
|
||||
|
||||
|
||||
def test_feature_importances():
|
||||
tm._skip_if_no_sklearn()
|
||||
from sklearn.datasets import load_digits
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user