cleanup code

tqchen 2014-05-15 15:01:41 -07:00
parent 3960ac9cb4
commit 37e1473cea
3 changed files with 15 additions and 28 deletions

View File

@ -1,29 +1,25 @@
# General Parameters, see comment for each definition
# choose the tree booster, 0: tree, 1: linear
booster_type = 0
-# this is the only difference with classification, use 0: linear regression
-# when labels are in [0,1] we can also use 1: logistic regression
-loss_type = 0
#objective="rank:pairwise"
#objective="rank:softmax"
#objective="lambdarank:map"
#objective="lambdarank:ndcg"
+num_feature=50
# Tree Booster Parameters
# step size shrinkage
-bst:eta = 1.0
+bst:eta = 0.1
# minimum loss reduction required to make a further partition
bst:gamma = 1.0
# minimum sum of instance weight(hessian) needed in a child
bst:min_child_weight = 1
# maximum depth of a tree
bst:max_depth = 3
+eval_metric='ndcg'
# Task parameters
# the number of round to do boosting
-num_round = 2
+num_round = 4
# 0 means do not save any model except the final round model
save_period = 0
# The path of training data
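
For readers following along outside the CLI, a rough Python-API equivalent of the updated config is sketched below. This is illustrative only and not part of this commit: the data paths are placeholders, rank:pairwise is picked arbitrarily from the commented-out objective lines, and the bst:-prefixed CLI keys map onto plain parameter names.

    import xgboost as xgb

    # Placeholder paths; the feature/group files are the ones produced by trans_data.py.
    dtrain = xgb.DMatrix("mq.train")
    dtrain.set_group([int(n) for n in open("mq.train.group")])

    params = {
        "booster": "gbtree",           # booster_type = 0 -> tree booster
        "objective": "rank:pairwise",  # one of the commented-out objective choices
        "eta": 0.1,                    # bst:eta, step size shrinkage
        "gamma": 1.0,                  # bst:gamma, minimum loss reduction to split
        "min_child_weight": 1,         # bst:min_child_weight
        "max_depth": 3,                # bst:max_depth
        "eval_metric": "ndcg",
    }
    bst = xgb.train(params, dtrain, num_boost_round=4)  # num_round = 4
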

View File

@ -6,20 +6,22 @@ def save_data(group_data,output_feature,output_group):
    output_group.write(str(len(group_data))+"\n")
    for data in group_data:
-        output_feature.write(data[0] + " " + " ".join(data[2:]) + "\n")
+        # only include nonzero features
+        feats = [ p for p in data[2:] if float(p.split(':')[1]) != 0.0 ]
+        output_feature.write(data[0] + " " + " ".join(feats) + "\n")
if __name__ == "__main__":
    if len(sys.argv) != 4:
        print "Usage: python trans_data.py [Ranksvm Format Input] [Output Feature File] [Output Group File]"
        sys.exit(0)
-    input = open(sys.argv[1])
+    fi = open(sys.argv[1])
    output_feature = open(sys.argv[2],"w")
    output_group = open(sys.argv[3],"w")
    group_data = []
    group = ""
-    for line in input:
+    for line in fi:
        if not line:
            break
        if "#" in line:
@ -33,8 +35,7 @@ if __name__ == "__main__":
    save_data(group_data,output_feature,output_group)
-    input.close()
+    fi.close()
    output_feature.close()
    output_group.close()
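
The substantive change in this script is that zero-valued features are no longer written out, so the feature file becomes genuinely sparse. A tiny standalone illustration of the filter (the sample line is made up, not taken from the demo data):

    # Illustrative only: mirrors the nonzero-feature filter added above.
    # Input lines follow the RankSVM convention "label qid:<id> f1:v1 f2:v2 ...".
    line = "2 qid:10 1:0.0 2:0.7 3:1.0"
    tokens = line.split()
    feats = [t for t in tokens[2:] if float(t.split(':')[1]) != 0.0]
    print(tokens[0] + " " + " ".join(feats))   # -> "2 2:0.7 3:1.0"

One side effect of dropping zero entries is that the largest feature index may never appear in the written data, which is presumably why the config above states the feature count explicitly instead of inferring it from the file.
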

View File

@ -104,6 +104,7 @@ namespace xgboost{
    public:
        EvalAMS(const char *name){
            name_ = name;
+            // note: ams@0 will automatically select which ratio to go
            utils::Assert( sscanf(name, "ams@%f", &ratio_ ) == 1, "invalid ams format" );
        }
        virtual float Eval(const std::vector<float> &preds,
@ -152,14 +153,9 @@ namespace xgboost{
        float ratio_;
    };
-    /*! \brief Error */
+    /*! \brief Error for multi-class classification, need exact match */
    struct EvalMatchError : public IEvaluator{
    public:
-        EvalMatchError(const char *name){
-            name_ = name;
-            abs_ = 0;
-            if(!strcmp("mabserror", name)) abs_ =1;
-        }
        virtual float Eval(const std::vector<float> &preds,
                           const DMatrix::Info &info) const {
            const unsigned ndata = static_cast<unsigned>(preds.size());
@ -168,19 +164,14 @@ namespace xgboost{
            for (unsigned i = 0; i < ndata; ++i){
                const float wt = info.GetWeight(i);
                int label = static_cast<int>(info.labels[i]);
-                if( label < 0 && abs_ != 0 ) label = -label-1;
-                if (static_cast<int>(preds[i]) != label ){
-                    sum += wt;
-                }
+                if (static_cast<int>(preds[i]) != label ) sum += wt;
                wsum += wt;
            }
            return sum / wsum;
        }
        virtual const char *Name(void) const{
-            return name_.c_str();
+            return "merror";
        }
-        int abs_;
-        std::string name_;
    };
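
After the cleanup, EvalMatchError measures exactly one thing: the weight-normalized share of rows whose prediction, truncated to an integer, fails to match the integer label; the mabserror/negative-label handling is gone. For reference, the same computation in a few lines of Python; preds, labels and weights are placeholder names, not an xgboost interface:

    def merror(preds, labels, weights):
        # weighted fraction of rows where int(pred) != int(label),
        # mirroring the simplified EvalMatchError::Eval above
        err, wsum = 0.0, 0.0
        for p, y, w in zip(preds, labels, weights):
            if int(p) != int(y):
                err += w
            wsum += w
        return err / wsum
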
@ -328,7 +319,7 @@ namespace xgboost{
            float idcg = this->CalcDCG(rec);
            std::sort(rec.begin(), rec.end(), CmpSecond);
            float dcg = this->CalcDCG(rec);
            if( idcg == 0.0f ) return 0.0f;
            else return dcg/idcg;
        }
    };
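
The NDCG evaluator returns dcg/idcg, with an all-zero ideal DCG defined as 0. CalcDCG and the ordering of rec before this hunk are not shown here, so the Python sketch below assumes the common exponential-gain DCG, sum((2^rel - 1) / log2(rank + 1)), and a descending sort by prediction score; treat the helper names as illustrative.

    import math

    def dcg(rels):
        # assumed gain/discount: (2^rel - 1) / log2(rank + 1), ranks starting at 1
        return sum((2 ** r - 1) / math.log(i + 2, 2) for i, r in enumerate(rels))

    def ndcg(labels, preds):
        idcg = dcg(sorted(labels, reverse=True))                 # ideal ordering
        ranked = [r for r, _ in sorted(zip(labels, preds), key=lambda t: -t[1])]
        return 0.0 if idcg == 0.0 else dcg(ranked) / idcg        # same guard as above
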
@ -366,8 +357,7 @@ namespace xgboost{
        }
        if (!strcmp(name, "rmse")) evals_.push_back(new EvalRMSE());
        if (!strcmp(name, "error")) evals_.push_back(new EvalError());
-        if (!strcmp(name, "merror")) evals_.push_back(new EvalMatchError("merror"));
-        if (!strcmp(name, "mabserror")) evals_.push_back(new EvalMatchError("mabserror"));
+        if (!strcmp(name, "merror")) evals_.push_back(new EvalMatchError());
        if (!strcmp(name, "logloss")) evals_.push_back(new EvalLogLoss());
        if (!strcmp(name, "auc")) evals_.push_back(new EvalAuc());
        if (!strncmp(name, "ams@",4)) evals_.push_back(new EvalAMS(name));
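
Evaluators are picked by matching the metric name: fixed names use an exact strcmp, while "ams@" is matched as a prefix so the numeric suffix can be handed to EvalAMS as the cutoff ratio. A small Python sketch of the same idea; parse_metric is a hypothetical helper, not part of xgboost:

    def parse_metric(name):
        # "ams@0.15" -> ("ams", 0.15); plain names like "merror" carry no parameter
        if name.startswith("ams@"):
            return ("ams", float(name[len("ams@"):]))
        return (name, None)

    assert parse_metric("ams@0.15") == ("ams", 0.15)
    assert parse_metric("merror") == ("merror", None)
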