sklearn api for ranking (#3560)
* added xgbranker * fixed predict method and ranking test * reformatted code in accordance with pep8 * fixed lint error * fixed docstring and added checks on objective * added ranking demo for python * fixed suffix in rank.py
This commit is contained in:
committed by
Philip Hyunsu Cho
parent
b13c3a8bcc
commit
24a268a2e3
@@ -1,6 +1,6 @@
|
||||
Learning to rank
|
||||
====
|
||||
XGBoost supports accomplishing ranking tasks. In ranking scenario, data are often grouped and we need the [group information file](../../doc/input_format.md#group-input-format) to specify ranking tasks. The model used in XGBoost for ranking is the LambdaRank, this function is not yet completed. Currently, we provide pairwise rank.
|
||||
XGBoost supports accomplishing ranking tasks. In ranking scenario, data are often grouped and we need the [group information file](../../doc/tutorials/input_format.md#group-input-format) to specify ranking tasks. The model used in XGBoost for ranking is the LambdaRank, this function is not yet completed. Currently, we provide pairwise rank.
|
||||
|
||||
### Parameters
|
||||
The configuration setting is similar to the regression and binary classification setting, except user need to specify the objectives:
|
||||
@@ -15,14 +15,27 @@ For more usage details please refer to the [binary classification demo](../binar
|
||||
Instructions
|
||||
====
|
||||
The dataset for ranking demo is from LETOR04 MQ2008 fold1.
|
||||
You can use the following command to run the example:
|
||||
Before running the examples, you need to get the data by running:
|
||||
|
||||
Get the data:
|
||||
```
|
||||
./wgetdata.sh
|
||||
```
|
||||
|
||||
### Command Line
|
||||
Run the example:
|
||||
```
|
||||
./runexp.sh
|
||||
```
|
||||
|
||||
### Python
|
||||
There are two ways of doing ranking in python.
|
||||
|
||||
Run the example using `xgboost.train`:
|
||||
```
|
||||
python rank.py
|
||||
```
|
||||
|
||||
Run the example using `XGBRanker`:
|
||||
```
|
||||
python rank_sklearn.py
|
||||
```
|
||||
|
||||
41
demo/rank/rank.py
Normal file
41
demo/rank/rank.py
Normal file
@@ -0,0 +1,41 @@
|
||||
#!/usr/bin/python
|
||||
import xgboost as xgb
|
||||
from xgboost import DMatrix
|
||||
from sklearn.datasets import load_svmlight_file
|
||||
|
||||
|
||||
# This script demonstrate how to do ranking with xgboost.train
|
||||
x_train, y_train = load_svmlight_file("mq2008.train")
|
||||
x_valid, y_valid = load_svmlight_file("mq2008.vali")
|
||||
x_test, y_test = load_svmlight_file("mq2008.test")
|
||||
|
||||
group_train = []
|
||||
with open("mq2008.train.group", "r") as f:
|
||||
data = f.readlines()
|
||||
for line in data:
|
||||
group_train.append(int(line.split("\n")[0]))
|
||||
|
||||
group_valid = []
|
||||
with open("mq2008.vali.group", "r") as f:
|
||||
data = f.readlines()
|
||||
for line in data:
|
||||
group_valid.append(int(line.split("\n")[0]))
|
||||
|
||||
group_test = []
|
||||
with open("mq2008.test.group", "r") as f:
|
||||
data = f.readlines()
|
||||
for line in data:
|
||||
group_test.append(int(line.split("\n")[0]))
|
||||
|
||||
train_dmatrix = DMatrix(x_train, y_train)
|
||||
valid_dmatrix = DMatrix(x_valid, y_valid)
|
||||
test_dmatrix = DMatrix(x_test)
|
||||
|
||||
train_dmatrix.set_group(group_train)
|
||||
valid_dmatrix.set_group(group_valid)
|
||||
|
||||
params = {'objective': 'rank:pairwise', 'eta': 0.1, 'gamma': 1.0,
|
||||
'min_child_weight': 0.1, 'max_depth': 6}
|
||||
xgb_model = xgb.train(params, train_dmatrix, num_boost_round=4,
|
||||
evals=[(valid_dmatrix, 'validation')])
|
||||
pred = xgb_model.predict(test_dmatrix)
|
||||
35
demo/rank/rank_sklearn.py
Normal file
35
demo/rank/rank_sklearn.py
Normal file
@@ -0,0 +1,35 @@
|
||||
#!/usr/bin/python
|
||||
import xgboost as xgb
|
||||
from sklearn.datasets import load_svmlight_file
|
||||
|
||||
|
||||
# This script demonstrate how to do ranking with XGBRanker
|
||||
x_train, y_train = load_svmlight_file("mq2008.train")
|
||||
x_valid, y_valid = load_svmlight_file("mq2008.vali")
|
||||
x_test, y_test = load_svmlight_file("mq2008.test")
|
||||
|
||||
group_train = []
|
||||
with open("mq2008.train.group", "r") as f:
|
||||
data = f.readlines()
|
||||
for line in data:
|
||||
group_train.append(int(line.split("\n")[0]))
|
||||
|
||||
group_valid = []
|
||||
with open("mq2008.vali.group", "r") as f:
|
||||
data = f.readlines()
|
||||
for line in data:
|
||||
group_valid.append(int(line.split("\n")[0]))
|
||||
|
||||
group_test = []
|
||||
with open("mq2008.test.group", "r") as f:
|
||||
data = f.readlines()
|
||||
for line in data:
|
||||
group_test.append(int(line.split("\n")[0]))
|
||||
|
||||
params = {'objective': 'rank:pairwise', 'learning_rate': 0.1,
|
||||
'gamma': 1.0, 'min_child_weight': 0.1,
|
||||
'max_depth': 6, 'n_estimators': 4}
|
||||
model = xgb.sklearn.XGBRanker(**params)
|
||||
model.fit(x_train, y_train, group_train,
|
||||
eval_set=[(x_valid, y_valid)], eval_group=[group_valid])
|
||||
pred = model.predict(x_test)
|
||||
Reference in New Issue
Block a user