From 248b2cf74d7a5de3e5aeeaf772a98daaf76ad587 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 6 May 2014 16:49:10 -0700 Subject: [PATCH] right group size --- regrank/xgboost_regrank_data.h | 7 ++++--- regrank/xgboost_regrank_eval.h | 3 ++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/regrank/xgboost_regrank_data.h b/regrank/xgboost_regrank_data.h index 3653021c5..aff70928a 100644 --- a/regrank/xgboost_regrank_data.h +++ b/regrank/xgboost_regrank_data.h @@ -113,6 +113,7 @@ namespace xgboost{ if( fs.Read(&ngptr, sizeof(unsigned) ) != 0 ){ info.group_ptr.resize( ngptr ); utils::Assert( fs.Read(&info.group_ptr[0], sizeof(unsigned) * ngptr) != 0, "Load group file"); + utils::Assert( info.group_ptr.back() == data.NumRow(), "number of group must match number of record" ); } } fs.Close(); @@ -123,7 +124,7 @@ namespace xgboost{ printf("%ux%u matrix with %lu entries is loaded from %s\n", (unsigned)data.NumRow(), (unsigned)data.NumCol(), (unsigned long)data.NumEntry(), fname); if( info.group_ptr.size() != 0 ){ - printf("data contains %u groups\n", (unsigned)info.group_ptr.size() ); + printf("data contains %u groups\n", (unsigned)info.group_ptr.size()-1 ); } } this->TryLoadWeight(fname, silent); @@ -143,7 +144,7 @@ namespace xgboost{ utils::Assert( info.labels.size() == data.NumRow(), "label size is not consistent with feature matrix size" ); fs.Write(&info.labels[0], sizeof(float) * data.NumRow()); {// write out group ptr - unsigned ngptr = static_cast( info.group_ptr.size() ); + unsigned ngptr = static_cast( info.group_ptr.size() ); fs.Write(&ngptr, sizeof(unsigned) ); fs.Write(&info.group_ptr[0], sizeof(unsigned) * ngptr); } @@ -152,7 +153,7 @@ namespace xgboost{ printf("%ux%u matrix with %lu entries is saved to %s\n", (unsigned)data.NumRow(), (unsigned)data.NumCol(), (unsigned long)data.NumEntry(), fname); if( info.group_ptr.size() != 0 ){ - printf("data contains %u groups\n", (unsigned)info.group_ptr.size() ); + printf("data contains %u groups\n", (unsigned)info.group_ptr.size()-1 ); } } } diff --git a/regrank/xgboost_regrank_eval.h b/regrank/xgboost_regrank_eval.h index 0c03a1769..8d3b06e4a 100644 --- a/regrank/xgboost_regrank_eval.h +++ b/regrank/xgboost_regrank_eval.h @@ -160,7 +160,8 @@ namespace xgboost{ virtual float Eval(const std::vector &preds, const DMatrix::Info &info) const { const std::vector &gptr = info.group_ptr; - utils::Assert(gptr.size() != 0 && gptr.back() == preds.size(), "EvalAuc: group structure must match number of prediction"); + utils::Assert(gptr.size() != 0, "must specify group when constructing rank file"); + utils::Assert( gptr.back() == preds.size(), "EvalRanklist: group structure must match number of prediction"); const unsigned ngroup = static_cast(gptr.size() - 1); double sum_metric = 0.0f;