diff --git a/src/learner/evaluation-inl.hpp b/src/learner/evaluation-inl.hpp index 50827b758..14e66c5b7 100644 --- a/src/learner/evaluation-inl.hpp +++ b/src/learner/evaluation-inl.hpp @@ -115,6 +115,7 @@ struct EvalCTest: public IEvaluator { utils::Check(preds.size() % info.labels.size() == 0, "label and prediction size not match"); size_t ngroup = preds.size() / info.labels.size() - 1; + ngroup = 1; const unsigned ndata = static_cast(info.labels.size()); utils::Check(ngroup > 1, "pred size does not meet requirement"); utils::Check(ndata == info.info.fold_index.size(), "need fold index"); @@ -208,9 +209,11 @@ struct EvalPrecisionRatio : public IEvaluator{ } virtual float Eval(const std::vector &preds, const MetaInfo &info) const { - utils::Assert(preds.size() == info.labels.size(), "label size predict size not match"); + utils::Check(info.labels.size() != 0, "label set cannot be empty"); + utils::Assert(preds.size() % info.labels.size() == 0, + "label size predict size not match"); std::vector< std::pair > rec; - for (size_t j = 0; j < preds.size(); ++j) { + for (size_t j = 0; j < info.labels.size(); ++j) { rec.push_back(std::make_pair(preds[j], static_cast(j))); } std::sort(rec.begin(), rec.end(), CmpFirst); diff --git a/src/tree/param.h b/src/tree/param.h index 52c273749..b1bfe69a3 100644 --- a/src/tree/param.h +++ b/src/tree/param.h @@ -244,6 +244,7 @@ struct CVGradStats : public GradStats { } /*! \brief calculate gain of the solution */ inline double CalcGain(const TrainParam ¶m) const { + return param.CalcGain(train[0].sum_grad, train[0].sum_hess); double ret = 0.0; for (unsigned i = 0; i < vsize; ++i) { ret += param.CalcGain(train[i].sum_grad, diff --git a/src/tree/updater.h b/src/tree/updater.h index 91e9c4079..fa8605594 100644 --- a/src/tree/updater.h +++ b/src/tree/updater.h @@ -63,7 +63,7 @@ inline IUpdater* CreateUpdater(const char *name) { if (!strcmp(name, "refresh")) return new TreeRefresher(); if (!strcmp(name, "grow_colmaker")) return new ColMaker(); if (!strcmp(name, "grow_colmaker2")) return new ColMaker >(); - if (!strcmp(name, "grow_colmaker5")) return new ColMaker >(); + // if (!strcmp(name, "grow_colmaker5")) return new ColMaker >(); utils::Error("unknown updater:%s", name); return NULL; } diff --git a/wrapper/xgboost_wrapper.cpp b/wrapper/xgboost_wrapper.cpp index 975d48015..3c97e4475 100644 --- a/wrapper/xgboost_wrapper.cpp +++ b/wrapper/xgboost_wrapper.cpp @@ -154,6 +154,9 @@ extern "C"{ if (src.info.info.root_index.size() != 0) { ret.info.info.root_index.push_back(src.info.info.root_index[ridx]); } + if (src.info.info.fold_index.size() != 0) { + ret.info.info.fold_index.push_back(src.info.info.fold_index[ridx]); + } } return p_ret; }