From 37e1473ceaa2f671f7ea19a55045228d4f1cb013 Mon Sep 17 00:00:00 2001
From: tqchen
Date: Thu, 15 May 2014 15:01:41 -0700
Subject: [PATCH] cleanup code

---
 demo/rank/mq2008.conf          | 10 +++-------
 demo/rank/trans_data.py        | 11 ++++++-----
 regrank/xgboost_regrank_eval.h | 22 ++++++----------------
 3 files changed, 15 insertions(+), 28 deletions(-)

diff --git a/demo/rank/mq2008.conf b/demo/rank/mq2008.conf
index bf7dfb47e..8145d77b3 100644
--- a/demo/rank/mq2008.conf
+++ b/demo/rank/mq2008.conf
@@ -1,29 +1,25 @@
 # General Parameters, see comment for each definition
 # choose the tree booster, 0: tree, 1: linear
 booster_type = 0
-# this is the only difference with classification, use 0: linear regression
-# when labels are in [0,1] we can also use 1: logistic regression
-loss_type = 0
 
 #objective="rank:pairwise"
 #objective="rank:softmax"
 #objective="lambdarank:map"
 #objective="lambdarank:ndcg"
-num_feature=50
 
 # Tree Booster Parameters
 # step size shrinkage
-bst:eta = 1.0
+bst:eta = 0.1
 # minimum loss reduction required to make a further partition
 bst:gamma = 1.0
 # minimum sum of instance weight(hessian) needed in a child
 bst:min_child_weight = 1
 # maximum depth of a tree
 bst:max_depth = 3
-
+eval_metric='ndcg'
 # Task parameters
 # the number of round to do boosting
-num_round = 2
+num_round = 4
 # 0 means do not save any model except the final round model
 save_period = 0
 # The path of training data
diff --git a/demo/rank/trans_data.py b/demo/rank/trans_data.py
index fe8fde753..3c9865106 100644
--- a/demo/rank/trans_data.py
+++ b/demo/rank/trans_data.py
@@ -6,20 +6,22 @@ def save_data(group_data,output_feature,output_group):
 
     output_group.write(str(len(group_data))+"\n")
     for data in group_data:
-        output_feature.write(data[0] + " " + " ".join(data[2:]) + "\n")
+        # only include nonzero features
+        feats = [ p for p in data[2:] if float(p.split(':')[1]) != 0.0 ]
+        output_feature.write(data[0] + " " + " ".join(feats) + "\n")
 
 if __name__ == "__main__":
     if len(sys.argv) != 4:
         print "Usage: python trans_data.py [Ranksvm Format Input] [Output Feature File] [Output Group File]"
         sys.exit(0)
 
-    input = open(sys.argv[1])
+    fi = open(sys.argv[1])
     output_feature = open(sys.argv[2],"w")
     output_group = open(sys.argv[3],"w")
 
     group_data = []
     group = ""
-    for line in input:
+    for line in fi:
         if not line:
             break
         if "#" in line:
@@ -33,8 +35,7 @@ if __name__ == "__main__":
 
     save_data(group_data,output_feature,output_group)
 
-    input.close()
+    fi.close()
     output_feature.close()
     output_group.close()
-
diff --git a/regrank/xgboost_regrank_eval.h b/regrank/xgboost_regrank_eval.h
index fcdabd68e..df7a85555 100644
--- a/regrank/xgboost_regrank_eval.h
+++ b/regrank/xgboost_regrank_eval.h
@@ -104,6 +104,7 @@ namespace xgboost{
         public:
             EvalAMS(const char *name){
                 name_ = name;
+                // note: ams@0 will automatically select which ratio to go
                 utils::Assert( sscanf(name, "ams@%f", &ratio_ ) == 1, "invalid ams format" );
             }
             virtual float Eval(const std::vector<float> &preds,
@@ -152,14 +153,9 @@ namespace xgboost{
             float ratio_;
         };
 
-        /*! \brief Error */
+        /*! \brief Error for multi-class classification, need exact match */
         struct EvalMatchError : public IEvaluator{
         public:
-            EvalMatchError(const char *name){
-                name_ = name;
-                abs_ = 0;
-                if(!strcmp("mabserror", name)) abs_ =1;
-            }
             virtual float Eval(const std::vector<float> &preds,
                                const DMatrix::Info &info) const {
                 const unsigned ndata = static_cast<unsigned>(preds.size());
@@ -168,19 +164,14 @@ namespace xgboost{
                 for (unsigned i = 0; i < ndata; ++i){
                     const float wt = info.GetWeight(i);
                     int label = static_cast<int>(info.labels[i]);
-                    if( label < 0 && abs_ != 0 ) label = -label-1;
-                    if (static_cast<int>(preds[i]) != label ){
-                        sum += wt;
-                    }
+                    if (static_cast<int>(preds[i]) != label ) sum += wt;
                     wsum += wt;
                 }
                 return sum / wsum;
             }
             virtual const char *Name(void) const{
-                return name_.c_str();
+                return "merror";
             }
-            int abs_;
-            std::string name_;
         };
 
 
@@ -328,7 +319,7 @@ namespace xgboost{
                 float idcg = this->CalcDCG(rec);
                 std::sort(rec.begin(), rec.end(), CmpSecond);
                 float dcg = this->CalcDCG(rec);
-                if( idcg == 0.0f ) return 0.0f;
+                if( idcg == 0.0f ) return 0.0f;
                 else return dcg/idcg;
             }
         };
@@ -366,8 +357,7 @@ namespace xgboost{
             }
             if (!strcmp(name, "rmse")) evals_.push_back(new EvalRMSE());
             if (!strcmp(name, "error")) evals_.push_back(new EvalError());
-            if (!strcmp(name, "merror")) evals_.push_back(new EvalMatchError("merror"));
-            if (!strcmp(name, "mabserror")) evals_.push_back(new EvalMatchError("mabserror"));
+            if (!strcmp(name, "merror")) evals_.push_back(new EvalMatchError());
             if (!strcmp(name, "logloss")) evals_.push_back(new EvalLogLoss());
             if (!strcmp(name, "auc")) evals_.push_back(new EvalAuc());
             if (!strncmp(name, "ams@",4)) evals_.push_back(new EvalAMS(name));
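
As a quick illustration of the trans_data.py hunk above: save_data now drops zero-valued features when it rewrites each RankSVM-format row as a LibSVM-style feature line, while the group file records how many consecutive rows share the same qid. The snippet below is a standalone sketch of that filtering step with an invented sample row; neither the sample values nor this script are part of the patch or the MQ2008 data.

    # illustrative sketch only; the sample row is made up
    sample = "2 qid:10 1:0.007477 2:0.000000 3:1.000000 4:0.000000 5:0.007470"
    splits = sample.strip().split(" ")          # [label, qid, feat:val, feat:val, ...]
    feats = [p for p in splits[2:] if float(p.split(':')[1]) != 0.0]
    print(splits[0] + " " + " ".join(feats))    # -> 2 1:0.007477 3:1.000000 5:0.007470

The per-qid group counts written next to the feature file are what xgboost's ranking objectives (for example the rank:pairwise option shown commented out in mq2008.conf) use to decide which rows belong to the same query.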