diff --git a/Makefile b/Makefile
index 80848649d..0715ec379 100644
--- a/Makefile
+++ b/Makefile
@@ -5,9 +5,9 @@ export LDFLAGS= -pthread -lm
 # add include path to Rinternals.h here
 ifeq ($(no_omp),1)
-	export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -DDISABLE_OPENMP
+	export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -DDISABLE_OPENMP -funroll-loops
 else
-	export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fopenmp
+	export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fopenmp -funroll-loops
 endif
 
 # expose these flags to R CMD SHLIB
@@ -18,11 +18,11 @@ BIN = xgboost
 OBJ =
 SLIB = wrapper/libxgboostwrapper.so
 RLIB = wrapper/libxgboostR.so
-.PHONY: clean all R
+.PHONY: clean all R python
 
 all: $(BIN) wrapper/libxgboostwrapper.so
 R: wrapper/libxgboostR.so
-
+python: wrapper/libxgboostwrapper.so
 xgboost: src/xgboost_main.cpp src/io/io.cpp src/data.h src/tree/*.h src/tree/*.hpp src/gbm/*.h src/gbm/*.hpp src/utils/*.h src/learner/*.h src/learner/*.hpp
 # now the wrapper takes in two files. io and wrapper part
 wrapper/libxgboostwrapper.so: wrapper/xgboost_wrapper.cpp src/io/io.cpp src/*.h src/*/*.hpp src/*/*.h
diff --git a/src/gbm/gbtree-inl.hpp b/src/gbm/gbtree-inl.hpp
index 2c48f501c..bd3adac08 100644
--- a/src/gbm/gbtree-inl.hpp
+++ b/src/gbm/gbtree-inl.hpp
@@ -117,17 +117,13 @@ class GBTree : public IGradBooster {
     }
     std::vector<float> &preds = *out_preds;
-    preds.resize(0);
+    const size_t stride = info.num_row * mparam.num_output_group;
+    preds.resize(stride * (mparam.size_leaf_vector+1));
     // start collecting the prediction
     utils::IIterator<SparseBatch> *iter = fmat.RowIterator();
     iter->BeforeFirst();
     while (iter->Next()) {
       const SparseBatch &batch = iter->Value();
-      utils::Assert(batch.base_rowid * mparam.num_output_group == preds.size(),
-                    "base_rowid is not set correctly");
-      // output convention: nrow * k, where nrow is number of rows
-      // k is number of group
-      preds.resize(preds.size() + batch.size * mparam.num_output_group);
       // parallel over local batch
       const unsigned nsize = static_cast<unsigned>(batch.size);
       #pragma omp parallel for schedule(static)
@@ -135,13 +131,13 @@ class GBTree : public IGradBooster {
         const int tid = omp_get_thread_num();
         tree::RegTree::FVec &feats = thread_temp[tid];
         int64_t ridx = static_cast<int64_t>(batch.base_rowid + i);
-        const unsigned root_idx = info.GetRoot(ridx);
+        utils::Assert(static_cast<size_t>(ridx) < info.num_row, "data row index exceed bound");
         // loop over output groups
         for (int gid = 0; gid < mparam.num_output_group; ++gid) {
-          preds[ridx * mparam.num_output_group + gid] =
-              this->Pred(batch[i],
-                         buffer_offset < 0 ? -1 : buffer_offset+ridx,
-                         gid, root_idx, &feats);
+          this->Pred(batch[i],
+                     buffer_offset < 0 ? -1 : buffer_offset + ridx,
+                     gid, info.GetRoot(ridx), &feats,
+                     &preds[ridx * mparam.num_output_group + gid], stride);
         }
       }
     }
@@ -211,24 +207,34 @@ class GBTree : public IGradBooster {
     mparam.num_trees += tparam.num_parallel_tree;
   }
   // make a prediction for a single instance
-  inline float Pred(const SparseBatch::Inst &inst,
-                    int64_t buffer_index,
-                    int bst_group,
-                    unsigned root_index,
-                    tree::RegTree::FVec *p_feats) {
+  inline void Pred(const SparseBatch::Inst &inst,
+                   int64_t buffer_index,
+                   int bst_group,
+                   unsigned root_index,
+                   tree::RegTree::FVec *p_feats,
+                   float *out_pred, size_t stride) {
     size_t itop = 0;
     float psum = 0.0f;
+    // sum of leaf vector
+    std::vector<float> vec_psum(mparam.size_leaf_vector, 0.0f);
     const int64_t bid = mparam.BufferOffset(buffer_index, bst_group);
     // load buffered results if any
     if (bid >= 0) {
       itop = pred_counter[bid];
       psum = pred_buffer[bid];
+      for (int i = 0; i < mparam.size_leaf_vector; ++i) {
+        vec_psum[i] = pred_buffer[bid + i + 1];
+      }
     }
     if (itop != trees.size()) {
       p_feats->Fill(inst);
       for (size_t i = itop; i < trees.size(); ++i) {
         if (tree_info[i] == bst_group) {
-          psum += trees[i]->Predict(*p_feats, root_index);
+          int tid = trees[i]->GetLeafIndex(*p_feats, root_index);
+          psum += (*trees[i])[tid].leaf_value();
+          for (int j = 0; j < mparam.size_leaf_vector; ++j) {
+            vec_psum[j] += trees[i]->leafvec(tid)[j];
+          }
         }
       }
       p_feats->Drop(inst);
@@ -237,8 +243,14 @@ class GBTree : public IGradBooster {
     if (bid >= 0) {
       pred_counter[bid] = static_cast<unsigned>(trees.size());
       pred_buffer[bid] = psum;
+      for (int i = 0; i < mparam.size_leaf_vector; ++i) {
+        pred_buffer[bid + i + 1] = vec_psum[i];
+      }
+    }
+    out_pred[0] = psum;
+    for (int i = 0; i < mparam.size_leaf_vector; ++i) {
+      out_pred[stride * (i + 1)] = vec_psum[i];
     }
-    return psum;
   }
   // --- data structure ---
   /*! \brief training parameters */
@@ -291,14 +303,17 @@ class GBTree : public IGradBooster {
     * suppose we have n instance and k group, output will be k*n
     */
    int num_output_group;
+   /*! \brief size of leaf vector needed in tree */
+   int size_leaf_vector;
    /*! \brief reserved parameters */
-   int reserved[32];
+   int reserved[31];
    /*! \brief constructor */
    ModelParam(void) {
      num_trees = 0;
      num_roots = num_feature = 0;
      num_pbuffer = 0;
      num_output_group = 1;
+     size_leaf_vector = 0;
      memset(reserved, 0, sizeof(reserved));
    }
    /*!
@@ -311,10 +326,11 @@ class GBTree : public IGradBooster {
      if (!strcmp("num_output_group", name)) num_output_group = atol(val);
      if (!strcmp("bst:num_roots", name)) num_roots = atoi(val);
      if (!strcmp("bst:num_feature", name)) num_feature = atoi(val);
+     if (!strcmp("bst:size_leaf_vector", name)) size_leaf_vector = atoi(val);
    }
    /*! \return size of prediction buffer actually needed */
    inline size_t PredBufferSize(void) const {
-     return num_output_group * num_pbuffer;
+     return num_output_group * num_pbuffer * (size_leaf_vector + 1);
    }
    /*!
    * \brief get the buffer offset given a buffer index and group id
@@ -323,7 +339,7 @@
    inline int64_t BufferOffset(int64_t buffer_index, int bst_group) const {
      if (buffer_index < 0) return -1;
      utils::Check(buffer_index < num_pbuffer, "buffer_index exceed num_pbuffer");
-     return buffer_index + num_pbuffer * bst_group;
+     return (buffer_index + num_pbuffer * bst_group) * (size_leaf_vector + 1);
    }
  };
  // training parameter
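
Note on the Predict() change above: the output buffer is now sized once up front instead of growing per batch. It holds (size_leaf_vector + 1) blocks of stride = num_row * num_output_group floats; block 0 carries the usual scalar prediction for each (row, group) pair, and block i + 1 carries the i-th leaf-vector component, which is what Pred() writes through out_pred[stride * (i + 1)]. A minimal standalone sketch of that indexing with made-up sizes (none of this is library code):

    #include <cstddef>
    #include <vector>

    // preds[ridx * num_output_group + gid]                        -> scalar prediction
    // preds[stride * (i + 1) + ridx * num_output_group + gid]     -> i-th leaf-vector entry
    int main() {
      const std::size_t num_row = 4, num_output_group = 2, size_leaf_vector = 3;
      const std::size_t stride = num_row * num_output_group;
      std::vector<float> preds(stride * (size_leaf_vector + 1), 0.0f);

      const std::size_t ridx = 1, gid = 0;
      preds[ridx * num_output_group + gid] = 0.5f;  // base prediction for row 1, group 0
      for (std::size_t i = 0; i < size_leaf_vector; ++i) {
        preds[stride * (i + 1) + ridx * num_output_group + gid] = 0.1f * (i + 1);
      }
      return 0;
    }
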
diff --git a/src/learner/evaluation-inl.hpp b/src/learner/evaluation-inl.hpp
index 0b207f4b9..50827b758 100644
--- a/src/learner/evaluation-inl.hpp
+++ b/src/learner/evaluation-inl.hpp
@@ -24,9 +24,10 @@ template<typename Derived>
 struct EvalEWiseBase : public IEvaluator {
   virtual float Eval(const std::vector<float> &preds,
                      const MetaInfo &info) const {
-    utils::Check(preds.size() == info.labels.size(),
+    utils::Check(info.labels.size() != 0, "label set cannot be empty");
+    utils::Check(preds.size() % info.labels.size() == 0,
                  "label and prediction size not match");
-    const unsigned ndata = static_cast<unsigned>(preds.size());
+    const unsigned ndata = static_cast<unsigned>(info.labels.size());
     float sum = 0.0, wsum = 0.0;
     #pragma omp parallel for reduction(+: sum, wsum) schedule(static)
     for (unsigned i = 0; i < ndata; ++i) {
@@ -99,6 +100,45 @@ struct EvalMatchError : public EvalEWiseBase<EvalMatchError> {
   }
 };
 
+/*! \brief ctest */
+struct EvalCTest: public IEvaluator {
+  EvalCTest(IEvaluator *base, const char *name)
+      : base_(base), name_(name) {}
+  virtual ~EvalCTest(void) {
+    delete base_;
+  }
+  virtual const char *Name(void) const {
+    return name_.c_str();
+  }
+  virtual float Eval(const std::vector<float> &preds,
+                     const MetaInfo &info) const {
+    utils::Check(preds.size() % info.labels.size() == 0,
+                 "label and prediction size not match");
+    size_t ngroup = preds.size() / info.labels.size() - 1;
+    const unsigned ndata = static_cast<unsigned>(info.labels.size());
+    utils::Check(ngroup > 1, "pred size does not meet requirement");
+    utils::Check(ndata == info.info.fold_index.size(), "need fold index");
+    double wsum = 0.0;
+    for (size_t k = 0; k < ngroup; ++k) {
+      std::vector<float> tpred;
+      MetaInfo tinfo;
+      for (unsigned i = 0; i < ndata; ++i) {
+        if (info.info.fold_index[i] == k) {
+          tpred.push_back(preds[i + (k + 1) * ndata]);
+          tinfo.labels.push_back(info.labels[i]);
+          tinfo.weights.push_back(info.GetWeight(i));
+        }
+      }
+      wsum += base_->Eval(tpred, tinfo);
+    }
+    return wsum / ngroup;
+  }
+
+ private:
+  IEvaluator *base_;
+  std::string name_;
+};
+
 /*! \brief AMS: also records best threshold */
 struct EvalAMS : public IEvaluator {
  public:
@@ -109,7 +149,7 @@
   }
   virtual float Eval(const std::vector<float> &preds,
                      const MetaInfo &info) const {
-    const unsigned ndata = static_cast<unsigned>(preds.size());
+    const unsigned ndata = static_cast<unsigned>(info.labels.size());
     utils::Check(info.weights.size() == ndata, "we need weight to evaluate ams");
     std::vector< std::pair<float, unsigned> > rec(ndata);
@@ -206,10 +246,14 @@ struct EvalPrecisionRatio : public IEvaluator{
 struct EvalAuc : public IEvaluator {
   virtual float Eval(const std::vector<float> &preds,
                      const MetaInfo &info) const {
-    utils::Check(preds.size() == info.labels.size(), "label size predict size not match");
-    std::vector<unsigned> tgptr(2, 0); tgptr[1] = static_cast<unsigned>(preds.size());
+    utils::Check(info.labels.size() != 0, "label set cannot be empty");
+    utils::Check(preds.size() % info.labels.size() == 0,
+                 "label size predict size not match");
+    std::vector<unsigned> tgptr(2, 0);
+    tgptr[1] = static_cast<unsigned>(info.labels.size());
+    const std::vector<unsigned> &gptr = info.group_ptr.size() == 0 ? tgptr : info.group_ptr;
-    utils::Check(gptr.back() == preds.size(),
+    utils::Check(gptr.back() == info.labels.size(),
                  "EvalAuc: group structure must match number of prediction");
     const unsigned ngroup = static_cast<unsigned>(gptr.size() - 1);
     // sum statictis
diff --git a/src/learner/evaluation.h b/src/learner/evaluation.h
index d2134bfbd..90f4a5839 100644
--- a/src/learner/evaluation.h
+++ b/src/learner/evaluation.h
@@ -45,7 +45,9 @@ inline IEvaluator* CreateEvaluator(const char *name) {
   if (!strncmp(name, "pre@", 4)) return new EvalPrecision(name);
   if (!strncmp(name, "pratio@", 7)) return new EvalPrecisionRatio(name);
   if (!strncmp(name, "map", 3)) return new EvalMAP(name);
-  if (!strncmp(name, "ndcg", 3)) return new EvalNDCG(name);
+  if (!strncmp(name, "ndcg", 4)) return new EvalNDCG(name);
+  if (!strncmp(name, "ct-", 3)) return new EvalCTest(CreateEvaluator(name+3), name);
+  utils::Error("unknown evaluation metric type: %s", name);
   return NULL;
 }
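
Note on the ct- metrics above: EvalCTest wraps an existing evaluator. It assumes preds holds (k + 1) blocks of labels.size() values, block 0 being the ordinary prediction; for each fold f, the rows with info.info.fold_index[i] == f are scored using block f + 1, the wrapped metric is evaluated on those held-out rows, and the k results are averaged. The new dispatch in evaluation.h means a name such as ct-rmse wraps the rmse evaluator, and unknown metric names now raise an error instead of silently returning NULL. A toy restatement of the fold averaging, assuming a plain function as the base metric rather than the real IEvaluator/MetaInfo types:

    #include <cstddef>
    #include <vector>

    // preds holds (k + 1) * n values, n = labels.size(); block 0 is unused here.
    float CTestAverage(const std::vector<float> &preds,
                       const std::vector<float> &labels,
                       const std::vector<unsigned> &fold_index,
                       float (*base)(const std::vector<float>&, const std::vector<float>&)) {
      const std::size_t n = labels.size();
      const std::size_t k = preds.size() / n - 1;      // number of folds
      double wsum = 0.0;
      for (std::size_t f = 0; f < k; ++f) {
        std::vector<float> p, y;
        for (std::size_t i = 0; i < n; ++i) {
          if (fold_index[i] == f) {                    // held-out rows of fold f
            p.push_back(preds[(f + 1) * n + i]);       // read from that fold's block
            y.push_back(labels[i]);
          }
        }
        wsum += base(p, y);
      }
      return static_cast<float>(wsum / k);
    }
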
diff --git a/src/learner/objective-inl.hpp b/src/learner/objective-inl.hpp
index 7f7f08cc3..4b5b4f014 100644
--- a/src/learner/objective-inl.hpp
+++ b/src/learner/objective-inl.hpp
@@ -123,7 +123,7 @@ class RegLossObj : public IObjFunction{
         float p = loss.PredTransform(preds[i]);
         float w = info.GetWeight(j);
         if (info.labels[j] == 1.0f) w *= scale_pos_weight;
-        gpair[j] = bst_gpair(loss.FirstOrderGradient(p, info.labels[j]) * w,
+        gpair[i] = bst_gpair(loss.FirstOrderGradient(p, info.labels[j]) * w,
                              loss.SecondOrderGradient(p, info.labels[j]) * w);
       }
     }
diff --git a/src/tree/model.h b/src/tree/model.h
index af99a5145..f91e453f8 100644
--- a/src/tree/model.h
+++ b/src/tree/model.h
@@ -270,6 +270,7 @@ class TreeModel {
     param.num_nodes = param.num_roots;
     nodes.resize(param.num_nodes);
     stats.resize(param.num_nodes);
+    leaf_vector.resize(param.num_nodes * param.size_leaf_vector, 0.0f);
     for (int i = 0; i < param.num_nodes; i ++) {
       nodes[i].set_leaf(0.0f);
       nodes[i].set_parent(-1);
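
Note on the two smaller hunks above: in objective-inl.hpp the gradient is now stored at gpair[i] (the prediction index) instead of gpair[j] (the label index), so when preds carries several blocks per label the gradient vector lines up with preds element for element; and in model.h every tree now reserves a flat leaf_vector of num_nodes * size_leaf_vector floats at initialization. A sketch of the flat per-node layout this resize implies; the leafvec accessor is assumed from its use in gbtree-inl.hpp (trees[i]->leafvec(tid)[j]), its exact signature is not shown in this patch:

    #include <vector>

    struct ToyTree {
      int size_leaf_vector;                 // > 0 when leaf vectors are in use
      std::vector<float> leaf_vector;       // flat, length num_nodes * size_leaf_vector
      // assumed accessor: points at the leaf vector of node nid
      float *leafvec(int nid) {
        return &leaf_vector[nid * size_leaf_vector];
      }
    };
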
diff --git a/src/tree/param.h b/src/tree/param.h
index 5f02c065d..52c273749 100644
--- a/src/tree/param.h
+++ b/src/tree/param.h
@@ -22,10 +22,10 @@ struct TrainParam{
   //----- the rest parameters are less important ----
   // minimum amount of hessian(weight) allowed in a child
   float min_child_weight;
-  // weight decay parameter used to control leaf fitting
+  // L2 regularization factor
   float reg_lambda;
-  // reg method
-  int reg_method;
+  // L1 regularization factor
+  float reg_alpha;
   // default direction choice
   int default_direction;
   // whether we want to do subsample
@@ -36,6 +36,8 @@
   float colsample_bytree;
   // speed optimization for dense column
   float opt_dense_col;
+  // leaf vector size
+  int size_leaf_vector;
   // number of threads to be used for tree construction,
   // if OpenMP is enabled, if equals 0, use system default
   int nthread;
@@ -45,13 +47,14 @@
     min_child_weight = 1.0f;
     max_depth = 6;
     reg_lambda = 1.0f;
-    reg_method = 2;
+    reg_alpha = 0.0f;
     default_direction = 0;
     subsample = 1.0f;
     colsample_bytree = 1.0f;
     colsample_bylevel = 1.0f;
     opt_dense_col = 1.0f;
     nthread = 0;
+    size_leaf_vector = 0;
   }
   /*!
   * \brief set parameters from outside
@@ -63,15 +66,17 @@
     if (!strcmp(name, "gamma")) min_split_loss = static_cast<float>(atof(val));
     if (!strcmp(name, "eta")) learning_rate = static_cast<float>(atof(val));
     if (!strcmp(name, "lambda")) reg_lambda = static_cast<float>(atof(val));
+    if (!strcmp(name, "alpha")) reg_alpha = static_cast<float>(atof(val));
     if (!strcmp(name, "learning_rate")) learning_rate = static_cast<float>(atof(val));
     if (!strcmp(name, "min_child_weight")) min_child_weight = static_cast<float>(atof(val));
     if (!strcmp(name, "min_split_loss")) min_split_loss = static_cast<float>(atof(val));
     if (!strcmp(name, "reg_lambda")) reg_lambda = static_cast<float>(atof(val));
-    if (!strcmp(name, "reg_method")) reg_method = atoi(val);
+    if (!strcmp(name, "reg_alpha")) reg_alpha = static_cast<float>(atof(val));
     if (!strcmp(name, "subsample")) subsample = static_cast<float>(atof(val));
     if (!strcmp(name, "colsample_bylevel")) colsample_bylevel = static_cast<float>(atof(val));
     if (!strcmp(name, "colsample_bytree")) colsample_bytree = static_cast<float>(atof(val));
     if (!strcmp(name, "opt_dense_col")) opt_dense_col = static_cast<float>(atof(val));
+    if (!strcmp(name, "size_leaf_vector")) size_leaf_vector = atoi(val);
     if (!strcmp(name, "max_depth")) max_depth = atoi(val);
     if (!strcmp(name, "nthread")) nthread = atoi(val);
     if (!strcmp(name, "default_direction")) {
@@ -82,31 +87,31 @@
   }
   // calculate the cost of loss function
   inline double CalcGain(double sum_grad, double sum_hess) const {
-    if (sum_hess < min_child_weight) {
-      return 0.0;
+    if (sum_hess < min_child_weight) return 0.0;
+    if (reg_alpha == 0.0f) {
+      return Sqr(sum_grad) / (sum_hess + reg_lambda);
+    } else {
+      return Sqr(ThresholdL1(sum_grad, reg_alpha)) / (sum_hess + reg_lambda);
     }
-    switch (reg_method) {
-      case 1 : return Sqr(ThresholdL1(sum_grad, reg_lambda)) / sum_hess;
-      case 2 : return Sqr(sum_grad) / (sum_hess + reg_lambda);
-      case 3 : return
-          Sqr(ThresholdL1(sum_grad, 0.5 * reg_lambda)) /
-          (sum_hess + 0.5 * reg_lambda);
-      default: return Sqr(sum_grad) / sum_hess;
+  }
+  // calculate cost of loss function with four statistics
+  inline double CalcGain(double sum_grad, double sum_hess,
+                         double test_grad, double test_hess) const {
+    double w = CalcWeight(sum_grad, sum_hess);
+    double ret = test_grad * w + 0.5 * (test_hess + reg_lambda) * Sqr(w);
+    if (reg_alpha == 0.0f) {
+      return - 2.0 * ret;
+    } else {
+      return - 2.0 * (ret + reg_alpha * std::abs(w));
     }
   }
   // calculate weight given the statistics
   inline double CalcWeight(double sum_grad, double sum_hess) const {
-    if (sum_hess < min_child_weight) {
-      return 0.0;
+    if (sum_hess < min_child_weight) return 0.0;
+    if (reg_alpha == 0.0f) {
+      return -sum_grad / (sum_hess + reg_lambda);
     } else {
-      switch (reg_method) {
-        case 1: return - ThresholdL1(sum_grad, reg_lambda) / sum_hess;
-        case 2: return - sum_grad / (sum_hess + reg_lambda);
-        case 3: return
-            - ThresholdL1(sum_grad, 0.5 * reg_lambda) /
-            (sum_hess + 0.5 * reg_lambda);
-        default: return - sum_grad / sum_hess;
-      }
+      return -ThresholdL1(sum_grad, reg_alpha) / (sum_hess + reg_lambda);
     }
   }
   /*! \brief whether need forward small to big search: default right */
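
Note on the regularization rewrite above: the old reg_method switch (L1, L2, or a mix driven by reg_lambda) is replaced by independent L2 (reg_lambda) and L1 (reg_alpha) terms. Writing G and H for the gradient and hessian sums in a node, the patch uses weight w = -T(G, alpha) / (H + lambda) and gain T(G, alpha)^2 / (H + lambda), where T is soft-thresholding; ThresholdL1 is assumed to behave as that soft-threshold, consistent with how it was already used under reg_method. The new four-argument CalcGain evaluates a weight fitted on one set of statistics against a second (test) set, which is what the cross-validation statistics below rely on. A standalone restatement of the two-argument formulas:

    // Soft-threshold, assumed to match the ThresholdL1 helper already used in param.h:
    // shrink the gradient sum toward zero by alpha.
    inline double ThresholdL1(double g, double alpha) {
      if (g > +alpha) return g - alpha;
      if (g < -alpha) return g + alpha;
      return 0.0;
    }
    // Leaf weight and split gain with L1 (alpha) and L2 (lambda) terms,
    // restating the new CalcWeight / CalcGain above as free functions.
    inline double CalcWeight(double G, double H, double alpha, double lambda) {
      return -ThresholdL1(G, alpha) / (H + lambda);
    }
    inline double CalcGain(double G, double H, double alpha, double lambda) {
      const double t = ThresholdL1(G, alpha);
      return t * t / (H + lambda);
    }
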
@@ -153,6 +158,9 @@ struct GradStats {
   inline void Clear(void) {
     sum_grad = sum_hess = 0.0f;
   }
+  /*! \brief check if necessary information is ready */
+  inline static void CheckInfo(const BoosterInfo &info) {
+  }
   /*!
    * \brief accumulate statistics,
    * \param gpair the vector storing the gradient statistics
@@ -188,14 +196,88 @@
   }
   /*! \brief set leaf vector value based on statistics */
   inline void SetLeafVec(const TrainParam &param, bst_float *vec) const{
-  }
- protected:
+  }
+  // constructor to allow inheritance
+  GradStats(void) {}
   /*! \brief add statistics to the data */
   inline void Add(double grad, double hess) {
     sum_grad += grad; sum_hess += hess;
   }
 };
+/*! \brief vectorized cv statistics */
+template<int vsize>
+struct CVGradStats : public GradStats {
+  // additional statistics
+  GradStats train[vsize], valid[vsize];
+  // constructor
+  explicit CVGradStats(const TrainParam &param) {
+    utils::Check(param.size_leaf_vector == vsize,
+                 "CVGradStats: vsize must match size_leaf_vector");
+    this->Clear();
+  }
+  /*! \brief check if necessary information is ready */
+  inline static void CheckInfo(const BoosterInfo &info) {
+    utils::Check(info.fold_index.size() != 0,
+                 "CVGradStats: require fold_index");
+  }
+  /*! \brief clear the statistics */
+  inline void Clear(void) {
+    GradStats::Clear();
+    for (unsigned i = 0; i < vsize; ++i) {
+      train[i].Clear(); valid[i].Clear();
+    }
+  }
+  inline void Add(const std::vector<bst_gpair> &gpair,
+                  const BoosterInfo &info,
+                  bst_uint ridx) {
+    GradStats::Add(gpair[ridx].grad, gpair[ridx].hess);
+    const size_t step = info.fold_index.size();
+    for (unsigned i = 0; i < vsize; ++i) {
+      const bst_gpair &b = gpair[(i + 1) * step + ridx];
+      if (info.fold_index[ridx] == i) {
+        valid[i].Add(b.grad, b.hess);
+      } else {
+        train[i].Add(b.grad, b.hess);
+      }
+    }
+  }
+  /*! \brief calculate gain of the solution */
+  inline double CalcGain(const TrainParam &param) const {
+    double ret = 0.0;
+    for (unsigned i = 0; i < vsize; ++i) {
+      ret += param.CalcGain(train[i].sum_grad,
+                            train[i].sum_hess,
+                            vsize * valid[i].sum_grad,
+                            vsize * valid[i].sum_hess);
+    }
+    return ret / vsize;
+  }
+  /*! \brief add statistics to the data */
+  inline void Add(const CVGradStats &b) {
+    GradStats::Add(b);
+    for (unsigned i = 0; i < vsize; ++i) {
+      train[i].Add(b.train[i]);
+      valid[i].Add(b.valid[i]);
+    }
+  }
+  /*! \brief set current value to a - b */
+  inline void SetSubstract(const CVGradStats &a, const CVGradStats &b) {
+    GradStats::SetSubstract(a, b);
+    for (int i = 0; i < vsize; ++i) {
+      train[i].SetSubstract(a.train[i], b.train[i]);
+      valid[i].SetSubstract(a.valid[i], b.valid[i]);
+    }
+  }
+  /*! \brief set leaf vector value based on statistics */
+  inline void SetLeafVec(const TrainParam &param, bst_float *vec) const{
+    for (int i = 0; i < vsize; ++i) {
+      vec[i] = param.learning_rate *
+          param.CalcWeight(train[i].sum_grad, train[i].sum_hess);
+    }
+  }
+};
+
 /*!
  * \brief statistics that is helpful to store
  * and represent a split solution for the tree
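
Note on CVGradStats above: alongside the overall GradStats sums it keeps, for each fold i, separate train[i] and valid[i] accumulators. Add() assumes the gradient vector is laid out as (vsize + 1) blocks of fold_index.size() entries, block 0 being the ordinary gradient and block i + 1 the gradient under fold i's model, and routes a row to valid[i] when fold_index[ridx] == i, otherwise to train[i]. CalcGain then scores a split per fold by fitting the weight on that fold's training statistics and evaluating it on the validation statistics (scaled by the fold count) through the four-argument TrainParam::CalcGain, averaging over folds, while SetLeafVec writes one fitted weight per fold into the tree's leaf vector. A small sketch of just the row-routing rule, with toy types in place of bst_gpair and BoosterInfo:

    #include <cstddef>
    #include <vector>

    struct Pair { double grad, hess; };

    // gpair holds (vsize + 1) blocks of n entries, n = fold_index.size();
    // train and valid must each hold at least vsize accumulators.
    void Route(const std::vector<Pair> &gpair,
               const std::vector<unsigned> &fold_index,
               std::size_t ridx, unsigned vsize,
               std::vector<Pair> &train, std::vector<Pair> &valid) {
      const std::size_t n = fold_index.size();
      for (unsigned i = 0; i < vsize; ++i) {
        const Pair &b = gpair[(i + 1) * n + ridx];   // gradient from fold i's model
        if (fold_index[ridx] == i) {                 // row is held out in fold i
          valid[i].grad += b.grad; valid[i].hess += b.hess;
        } else {
          train[i].grad += b.grad; train[i].hess += b.hess;
        }
      }
    }
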
diff --git a/src/tree/updater.h b/src/tree/updater.h
index b33ee1833..91e9c4079 100644
--- a/src/tree/updater.h
+++ b/src/tree/updater.h
@@ -62,6 +62,8 @@ inline IUpdater* CreateUpdater(const char *name) {
   if (!strcmp(name, "prune")) return new TreePruner();
   if (!strcmp(name, "refresh")) return new TreeRefresher();
   if (!strcmp(name, "grow_colmaker")) return new ColMaker<GradStats>();
+  if (!strcmp(name, "grow_colmaker2")) return new ColMaker<CVGradStats<2> >();
+  if (!strcmp(name, "grow_colmaker5")) return new ColMaker<CVGradStats<5> >();
   utils::Error("unknown updater:%s", name);
   return NULL;
 }
diff --git a/src/tree/updater_colmaker-inl.hpp b/src/tree/updater_colmaker-inl.hpp
index 0c679e748..e26f2ada4 100644
--- a/src/tree/updater_colmaker-inl.hpp
+++ b/src/tree/updater_colmaker-inl.hpp
@@ -27,6 +27,7 @@ class ColMaker: public IUpdater {
                       const FMatrix &fmat,
                       const BoosterInfo &info,
                       const std::vector<RegTree*> &trees) {
+    TStats::CheckInfo(info);
     // rescale learning rate according to size of trees
     float lr = param.learning_rate;
     param.learning_rate = lr / trees.size();
@@ -81,7 +82,6 @@ class ColMaker: public IUpdater {
                      RegTree *p_tree) {
     this->InitData(gpair, fmat, info.root_index, *p_tree);
     this->InitNewNode(qexpand, gpair, fmat, info, *p_tree);
-
     for (int depth = 0; depth < param.max_depth; ++depth) {
       this->FindSplit(depth, this->qexpand, gpair, fmat, info, p_tree);
       this->ResetPosition(this->qexpand, fmat, *p_tree);
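
Note on the updater changes above: ColMaker is parameterized on its statistics type, and the registration makes ColMaker<CVGradStats<2> > and ColMaker<CVGradStats<5> > available as grow_colmaker2 and grow_colmaker5 (size_leaf_vector has to match the fold count, as checked in the CVGradStats constructor). The TStats::CheckInfo(info) call added to Update() lets each statistics type declare what it needs from the input before any tree is grown; GradStats accepts anything, CVGradStats requires fold_index. A generic sketch of that static hook, with stand-in types rather than the xgboost ones:

    #include <stdexcept>
    #include <vector>

    struct Info { std::vector<unsigned> fold_index; };

    struct PlainStats {                       // plays the role of GradStats: no extra input needed
      static void CheckInfo(const Info &) {}
    };
    struct FoldStats {                        // plays the role of CVGradStats: needs fold assignments
      static void CheckInfo(const Info &info) {
        if (info.fold_index.empty())
          throw std::runtime_error("FoldStats: require fold_index");
      }
    };

    template<typename TStats>
    void Update(const Info &info) {
      TStats::CheckInfo(info);                // fail fast before any tree is grown
      // ... grow trees, accumulating statistics in TStats ...
    }
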