Add LETOR MQ2008 for rank demo

This commit is contained in:
kalenhaha 2014-05-12 22:21:07 +08:00
parent 6648a15817
commit 5411e2a500
4 changed files with 56 additions and 16 deletions

View File

@ -1,13 +1 @@
Demonstrating how to use XGBoost accomplish regression tasks on computer hardware dataset https://archive.ics.uci.edu/ml/datasets/Computer+Hardware
Run: ./runexp.sh
Format of input: LIBSVM format
Format of ```featmap.txt: <featureid> <featurename> <q or i or int>\n ```:
- Feature id must be from 0 to number of features, in sorted order.
- i means this feature is binary indicator feature
- q means this feature is a quantitative value, such as age, time, can be missing
- int means this feature is integer value (when int is hinted, the decision boundary will be integer)
Explainations: https://github.com/tqchen/xgboost/wiki/Regression
The dataset for ranking demo is from LETOR04 MQ2008 fold1,http://research.microsoft.com/en-us/um/beijing/projects/letor/letor4download.aspx

View File

@ -10,6 +10,7 @@ objective="rank:pairwise"
#objective="lambdarank:map"
#objective="lambdarank:ndcg"
num_feature=50
# Tree Booster Parameters
# step size shrinkage
bst:eta = 1.0
@ -26,10 +27,10 @@ num_round = 2
# 0 means do not save any model except the final round model
save_period = 0
# The path of training data
data = "toy.train"
data = "mq2008.train"
# The path of validation data, used to monitor training process, here [test] sets name of the validation set
eval[test] = "toy.eval"
eval[test] = "mq2008.vali"
# The path of test data
test:data = "toy.test"
test:data = "mq2008.test"

11
demo/rank/runexp.sh Normal file
View File

@ -0,0 +1,11 @@
python trans_data.py train.txt mq2008.train mq2008.train.group
python trans_data.py test.txt mq2008.test mq2008.test.group
python trans_data.py vali.txt mq2008.vali mq2008.vali.group
../../xgboost mq2008.conf
../../xgboost mq2008.conf task=pred model_in=0002.model
../../xgboost mq2008.conf task=dump model_in=0002.model name_dump=dump.raw.txt

40
demo/rank/trans_data.py Normal file
View File

@ -0,0 +1,40 @@
import sys
def save_data(group_data,output_feature,output_group):
if len(group_data) == 0:
return
output_group.write(str(len(group_data))+"\n")
for data in group_data:
output_feature.write(data[0] + " " + " ".join(data[2:]) + "\n")
if __name__ == "__main__":
if len(sys.argv) != 4:
print "Usage: python trans_data.py [Ranksvm Format Input] [Output Feature File] [Output Group File]"
sys.exit(0)
input = open(sys.argv[1])
output_feature = open(sys.argv[2],"w")
output_group = open(sys.argv[3],"w")
group_data = []
group = ""
for line in input:
if not line:
break
if "#" in line:
line = line[:line.index("#")]
splits = line.strip().split(" ")
if splits[1] != group:
save_data(group_data,output_feature,output_group)
group_data = []
group = splits[1]
group_data.append(splits)
save_data(group_data,output_feature,output_group)
input.close()
output_feature.close()
output_group.close()