xgboost/demo/rank/trans_data.py
2014-05-12 22:21:07 +08:00

41 lines
999 B
Python

import sys
def save_data(group_data, output_feature, output_group):
    """Write one group of rows to the feature file and record its size.

    Args:
        group_data: sequence of rows belonging to a single group. Each row
            is a list of string tokens: the first token is written
            unchanged, the second token is dropped, and every remaining
            token is written after the first, separated by single spaces.
        output_feature: writable text file-like object that receives one
            line per row.
        output_group: writable text file-like object that receives the
            number of rows in this group on a line of its own.

    Returns:
        None. Nothing is written to either file when ``group_data`` is
        empty.
    """
    if not group_data:
        return

    output_group.write(str(len(group_data)) + "\n")
    # Hand all rows to the file object in one call instead of one write()
    # per row; the bytes produced are identical.
    output_feature.writelines(
        row[0] + " " + " ".join(row[2:]) + "\n" for row in group_data
    )
if __name__ == "__main__":
    # Command-line entry point: read the input named by argv[1] line by line,
    # collect consecutive rows that share the same second field (the group
    # key), and flush each finished group through save_data() into the
    # feature file (argv[2]) and the group-size file (argv[3]).
    if len(sys.argv) != 4:
        # Parenthesised single-argument form prints identically on
        # Python 2 and Python 3.
        print("Usage: python trans_data.py [Ranksvm Format Input] [Output Feature File] [Output Group File]")
        # Exit status is deliberately left at 0 so existing callers of the
        # script see no change in behavior.
        sys.exit(0)

    # `with` guarantees all three files are closed even if a malformed row
    # raises part-way through the conversion.
    with open(sys.argv[1]) as source, \
            open(sys.argv[2], "w") as output_feature, \
            open(sys.argv[3], "w") as output_group:
        group_data = []
        group = None  # no group seen yet; never equal to a real token
        for line_no, line in enumerate(source, 1):
            # Drop a trailing "# ..." comment, then split on any run of
            # whitespace so repeated spaces or tabs do not yield empty tokens.
            fields = line.partition("#")[0].split()
            if not fields:
                # Blank or comment-only line: nothing to convert.
                continue
            if len(fields) < 2:
                raise ValueError(
                    "%s:%d: expected at least 2 fields, got %r"
                    % (sys.argv[1], line_no, line)
                )
            if fields[1] != group:
                # Group key changed: flush the rows collected so far
                # (save_data ignores the initial empty list).
                save_data(group_data, output_feature, output_group)
                group_data = []
                group = fields[1]
            group_data.append(fields)

        # Flush the final group, which no key change follows.
        save_data(group_data, output_feature, output_group)