sklearn api for ranking (#3560)

* added xgbranker

* fixed predict method and ranking test

* reformatted code in accordance with pep8

* fixed lint error

* fixed docstring and added checks on objective

* added ranking demo for python

* fixed suffix in rank.py
This commit is contained in:
Shiki-H 2018-08-21 11:26:48 -04:00 committed by Philip Hyunsu Cho
parent b13c3a8bcc
commit 24a268a2e3
6 changed files with 359 additions and 7 deletions

View File

@ -1,6 +1,6 @@
Learning to rank
====
XGBoost supports accomplishing ranking tasks. In ranking scenario, data are often grouped and we need the [group information file](../../doc/input_format.md#group-input-format) to specify ranking tasks. The model used in XGBoost for ranking is the LambdaRank, this function is not yet completed. Currently, we provide pairwise rank.
XGBoost supports accomplishing ranking tasks. In ranking scenario, data are often grouped and we need the [group information file](../../doc/tutorials/input_format.md#group-input-format) to specify ranking tasks. The model used in XGBoost for ranking is the LambdaRank, this function is not yet completed. Currently, we provide pairwise rank.
### Parameters
The configuration setting is similar to the regression and binary classification setting, except user need to specify the objectives:
@ -15,14 +15,27 @@ For more usage details please refer to the [binary classification demo](../binar
Instructions
====
The dataset for ranking demo is from LETOR04 MQ2008 fold1.
You can use the following command to run the example:
Before running the examples, you need to get the data by running:
Get the data:
```
./wgetdata.sh
```
### Command Line
Run the example:
```
./runexp.sh
```
### Python
There are two ways of doing ranking in python.
Run the example using `xgboost.train`:
```
python rank.py
```
Run the example using `XGBRanker`:
```
python rank_sklearn.py
```

41
demo/rank/rank.py Normal file
View File

@ -0,0 +1,41 @@
#!/usr/bin/python
import xgboost as xgb
from xgboost import DMatrix
from sklearn.datasets import load_svmlight_file
# This script demonstrate how to do ranking with xgboost.train
x_train, y_train = load_svmlight_file("mq2008.train")
x_valid, y_valid = load_svmlight_file("mq2008.vali")
x_test, y_test = load_svmlight_file("mq2008.test")
group_train = []
with open("mq2008.train.group", "r") as f:
data = f.readlines()
for line in data:
group_train.append(int(line.split("\n")[0]))
group_valid = []
with open("mq2008.vali.group", "r") as f:
data = f.readlines()
for line in data:
group_valid.append(int(line.split("\n")[0]))
group_test = []
with open("mq2008.test.group", "r") as f:
data = f.readlines()
for line in data:
group_test.append(int(line.split("\n")[0]))
train_dmatrix = DMatrix(x_train, y_train)
valid_dmatrix = DMatrix(x_valid, y_valid)
test_dmatrix = DMatrix(x_test)
train_dmatrix.set_group(group_train)
valid_dmatrix.set_group(group_valid)
params = {'objective': 'rank:pairwise', 'eta': 0.1, 'gamma': 1.0,
'min_child_weight': 0.1, 'max_depth': 6}
xgb_model = xgb.train(params, train_dmatrix, num_boost_round=4,
evals=[(valid_dmatrix, 'validation')])
pred = xgb_model.predict(test_dmatrix)

35
demo/rank/rank_sklearn.py Normal file
View File

@ -0,0 +1,35 @@
#!/usr/bin/python
import xgboost as xgb
from sklearn.datasets import load_svmlight_file
# This script demonstrate how to do ranking with XGBRanker
x_train, y_train = load_svmlight_file("mq2008.train")
x_valid, y_valid = load_svmlight_file("mq2008.vali")
x_test, y_test = load_svmlight_file("mq2008.test")
group_train = []
with open("mq2008.train.group", "r") as f:
data = f.readlines()
for line in data:
group_train.append(int(line.split("\n")[0]))
group_valid = []
with open("mq2008.vali.group", "r") as f:
data = f.readlines()
for line in data:
group_valid.append(int(line.split("\n")[0]))
group_test = []
with open("mq2008.test.group", "r") as f:
data = f.readlines()
for line in data:
group_test.append(int(line.split("\n")[0]))
params = {'objective': 'rank:pairwise', 'learning_rate': 0.1,
'gamma': 1.0, 'min_child_weight': 0.1,
'max_depth': 6, 'n_estimators': 4}
model = xgb.sklearn.XGBRanker(**params)
model.fit(x_train, y_train, group_train,
eval_set=[(x_valid, y_valid)], eval_group=[group_valid])
pred = model.predict(x_test)

@ -1 +1 @@
Subproject commit f2afdc7788ee8ed6fd06cc095b6838d4ce61bb5a
Subproject commit 459ab734d15acd68fd437abf845c7c1730b5a38f

View File

@ -101,7 +101,7 @@ class XGBModel(XGBModelBase):
None, defaults to np.nan.
**kwargs : dict, optional
Keyword arguments for XGBoost Booster object. Full documentation of parameters can
be found here: https://github.com/dmlc/xgboost/blob/master/doc/parameter.md.
be found here: https://github.com/dmlc/xgboost/blob/master/doc/parameter.rst.
Attempting to set a parameter via the constructor args and **kwargs dict simultaneously
will result in a TypeError.
Note:
@ -259,7 +259,7 @@ class XGBModel(XGBModelBase):
instance weights on the i-th validation set.
eval_metric : str, callable, optional
If a str, should be a built-in evaluation metric to use. See
doc/parameter.md. If callable, a custom evaluation metric. The call
doc/parameter.rst. If callable, a custom evaluation metric. The call
signature is func(y_predicted, y_true) where y_true will be a
DMatrix object such that you may need to call the get_label
method. It must return a str, value pair where the str is a name
@ -465,7 +465,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
instance weights on the i-th validation set.
eval_metric : str, callable, optional
If a str, should be a built-in evaluation metric to use. See
doc/parameter.md. If callable, a custom evaluation metric. The call
doc/parameter.rst. If callable, a custom evaluation metric. The call
signature is func(y_predicted, y_true) where y_true will be a
DMatrix object such that you may need to call the get_label
method. It must return a str, value pair where the str is a name
@ -679,3 +679,232 @@ class XGBRegressor(XGBModel, XGBRegressorBase):
# pylint: disable=missing-docstring
__doc__ = """Implementation of the scikit-learn API for XGBoost regression.
""" + '\n'.join(XGBModel.__doc__.split('\n')[2:])
class XGBRanker(XGBModel):
# pylint: disable=missing-docstring,too-many-arguments,invalid-name
"""Implementation of the Scikit-Learn API for XGBoost Ranking.
Parameters
----------
max_depth : int
Maximum tree depth for base learners.
learning_rate : float
Boosting learning rate (xgb's "eta")
n_estimators : int
Number of boosted trees to fit.
silent : boolean
Whether to print messages while running boosting.
objective : string
Specify the learning task and the corresponding learning objective.
Only "rank:pairwise" is supported currently.
booster: string
Specify which booster to use: gbtree, gblinear or dart.
nthread : int
Number of parallel threads used to run xgboost. (Deprecated, please use n_jobs)
n_jobs : int
Number of parallel threads used to run xgboost. (replaces nthread)
gamma : float
Minimum loss reduction required to make a further partition on a leaf node of the tree.
min_child_weight : int
Minimum sum of instance weight(hessian) needed in a child.
max_delta_step : int
Maximum delta step we allow each tree's weight estimation to be.
subsample : float
Subsample ratio of the training instance.
colsample_bytree : float
Subsample ratio of columns when constructing each tree.
colsample_bylevel : float
Subsample ratio of columns for each split, in each level.
reg_alpha : float (xgb's alpha)
L1 regularization term on weights
reg_lambda : float (xgb's lambda)
L2 regularization term on weights
scale_pos_weight : float
Balancing of positive and negative weights.
base_score:
The initial prediction score of all instances, global bias.
seed : int
Random number seed. (Deprecated, please use random_state)
random_state : int
Random number seed. (replaces seed)
missing : float, optional
Value in the data which needs to be present as a missing value. If
None, defaults to np.nan.
**kwargs : dict, optional
Keyword arguments for XGBoost Booster object. Full documentation of parameters can
be found here: https://github.com/dmlc/xgboost/blob/master/doc/parameter.rst.
Attempting to set a parameter via the constructor args and **kwargs dict simultaneously
will result in a TypeError.
Note:
**kwargs is unsupported by Sklearn. We do not guarantee that parameters passed via
this argument will interact properly with Sklearn.
Note
----
A custom objective function is currently not supported by XGBRanker.
Group information is required for ranking tasks. Before fitting the model, your data need to
be sorted by group. When fitting the model, you need to provide an additional array that
contains the size of each group.
For example, if your original data look like:
| qid | label | features |
| 1 | 0 | x_1 |
| 1 | 1 | x_2 |
| 1 | 0 | x_3 |
| 2 | 0 | x_4 |
| 2 | 1 | x_5 |
| 2 | 1 | x_6 |
| 2 | 1 | x_7 |
then your group array should be [3, 4].
"""
def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
silent=True, objective="rank:pairwise", booster='gbtree',
n_jobs=-1, nthread=None, gamma=0, min_child_weight=1, max_delta_step=0,
subsample=1, colsample_bytree=1, colsample_bylevel=1,
reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
base_score=0.5, random_state=0, seed=None, missing=None, **kwargs):
super(XGBRanker, self).__init__(max_depth, learning_rate,
n_estimators, silent, objective, booster,
n_jobs, nthread, gamma, min_child_weight, max_delta_step,
subsample, colsample_bytree, colsample_bylevel,
reg_alpha, reg_lambda, scale_pos_weight,
base_score, random_state, seed, missing)
if callable(self.objective):
raise ValueError("custom objective function not supported by XGBRanker")
elif self.objective != "rank:pairwise":
raise ValueError("please use XGBRanker for ranking task")
def fit(self, X, y, group, sample_weight=None, eval_set=None, sample_weight_eval_set=None,
eval_group=None, eval_metric=None, early_stopping_rounds=None,
verbose=False, xgb_model=None):
# pylint: disable = attribute-defined-outside-init,arguments-differ
"""
Fit the gradient boosting model
Parameters
----------
X : array_like
Feature matrix
y : array_like
Labels
group : array_like
group size of training data
sample_weight : array_like
instance weights
eval_set : list, optional
A list of (X, y) tuple pairs to use as a validation set for
early-stopping
sample_weight_eval_set : list, optional
A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
instance weights on the i-th validation set.
eval_group : list of arrays, optional
A list that contains the group size corresponds to each
(X, y) pair in eval_set
eval_metric : str, callable, optional
If a str, should be a built-in evaluation metric to use. See
doc/parameter.rst. If callable, a custom evaluation metric. The call
signature is func(y_predicted, y_true) where y_true will be a
DMatrix object such that you may need to call the get_label
method. It must return a str, value pair where the str is a name
for the evaluation and value is the value of the evaluation
function. This objective is always minimized.
early_stopping_rounds : int
Activates early stopping. Validation error needs to decrease at
least every <early_stopping_rounds> round(s) to continue training.
Requires at least one item in evals. If there's more than one,
will use the last. Returns the model from the last iteration
(not the best one). If early stopping occurs, the model will
have three additional fields: bst.best_score, bst.best_iteration
and bst.best_ntree_limit.
(Use bst.best_ntree_limit to get the correct value if num_parallel_tree
and/or num_class appears in the parameters)
verbose : bool
If `verbose` and an evaluation set is used, writes the evaluation
metric measured on the validation set to stderr.
xgb_model : str
file name of stored xgb model or 'Booster' instance Xgb model to be
loaded before training (allows training continuation).
"""
# check if group information is provided
if group is None:
raise ValueError("group is required for ranking task")
if eval_set is not None:
if eval_group is None:
raise ValueError("eval_group is required if eval_set is not None")
elif len(eval_group) != len(eval_set):
raise ValueError("length of eval_group should match that of eval_set")
elif any(group is None for group in eval_group):
raise ValueError("group is required for all eval datasets for ranking task")
def _dmat_init(group, **params):
ret = DMatrix(**params)
ret.set_group(group)
return ret
if sample_weight is not None:
train_dmatrix = _dmat_init(group, data=X, label=y, weight=sample_weight,
missing=self.missing, nthread=self.n_jobs)
else:
train_dmatrix = _dmat_init(group, data=X, label=y,
missing=self.missing, nthread=self.n_jobs)
evals_result = {}
if eval_set is not None:
if sample_weight_eval_set is None:
sample_weight_eval_set = [None] * len(eval_set)
evals = [_dmat_init(eval_group[i], data=eval_set[i][0], label=eval_set[i][1],
missing=self.missing, weight=sample_weight_eval_set[i],
nthread=self.n_jobs) for i in range(len(eval_set))]
nevals = len(evals)
eval_names = ["eval_{}".format(i) for i in range(nevals)]
evals = list(zip(evals, eval_names))
else:
evals = ()
params = self.get_xgb_params()
feval = eval_metric if callable(eval_metric) else None
if eval_metric is not None:
if callable(eval_metric):
eval_metric = None
else:
params.update({'eval_metric': eval_metric})
self._Booster = train(params, train_dmatrix,
self.n_estimators,
early_stopping_rounds=early_stopping_rounds, evals=evals,
evals_result=evals_result, feval=feval,
verbose_eval=verbose, xgb_model=xgb_model)
self.objective = params["objective"]
if evals_result:
for val in evals_result.items():
evals_result_key = list(val[1].keys())[0]
evals_result[val[0]][evals_result_key] = val[1][evals_result_key]
self.evals_result = evals_result
if early_stopping_rounds is not None:
self.best_score = self._Booster.best_score
self.best_iteration = self._Booster.best_iteration
self.best_ntree_limit = self._Booster.best_ntree_limit
return self
def predict(self, data, output_margin=False, ntree_limit=0):
test_dmatrix = DMatrix(data, missing=self.missing)
if ntree_limit is None:
ntree_limit = getattr(self, "best_ntree_limit", 0)
return self.get_booster().predict(test_dmatrix,
output_margin=output_margin,
ntree_limit=ntree_limit)

View File

@ -77,6 +77,40 @@ def test_multiclass_classification():
check_pred(preds4, labels)
def test_ranking():
tm._skip_if_no_sklearn()
# generate random data
x_train = np.random.rand(1000, 10)
y_train = np.random.randint(5, size=1000)
train_group = np.repeat(50, 20)
x_valid = np.random.rand(200, 10)
y_valid = np.random.randint(5, size=200)
valid_group = np.repeat(50, 4)
x_test = np.random.rand(100, 10)
params = {'objective': 'rank:pairwise', 'learning_rate': 0.1,
'gamma': 1.0, 'min_child_weight': 0.1,
'max_depth': 6, 'n_estimators': 4}
model = xgb.sklearn.XGBRanker(**params)
model.fit(x_train, y_train, train_group,
eval_set=[(x_valid, y_valid)], eval_group=[valid_group])
pred = model.predict(x_test)
train_data = xgb.DMatrix(x_train, y_train)
valid_data = xgb.DMatrix(x_valid, y_valid)
test_data = xgb.DMatrix(x_test)
train_data.set_group(train_group)
valid_data.set_group(valid_group)
params_orig = {'objective': 'rank:pairwise', 'eta': 0.1, 'gamma': 1.0,
'min_child_weight': 0.1, 'max_depth': 6}
xgb_model_orig = xgb.train(params_orig, train_data, num_boost_round=4,
evals=[(valid_data, 'validation')])
pred_orig = xgb_model_orig.predict(test_data)
np.testing.assert_almost_equal(pred, pred_orig)
def test_feature_importances():
tm._skip_if_no_sklearn()
from sklearn.datasets import load_digits