From 2ac8cdb873f9abed0b7a33cefc6c0f7070ac0fad Mon Sep 17 00:00:00 2001
From: tqchen
Date: Fri, 22 Aug 2014 19:27:33 -0700
Subject: [PATCH] check in linear model

---
 python/xgboost_wrapper.cpp        |   2 +-
 src/data.h                        |   8 +-
 src/gbm/gblinear-inl.hpp          | 262 ++++++++++++++++++++++++++++++
 src/gbm/gbm.h                     |  13 +-
 src/gbm/gbtree-inl.hpp            |   7 +-
 src/learner/evaluation-inl.hpp    |   2 +-
 src/learner/learner-inl.hpp       |   2 +-
 src/tree/updater_colmaker-inl.hpp |  10 +-
 8 files changed, 287 insertions(+), 19 deletions(-)
 create mode 100644 src/gbm/gblinear-inl.hpp

diff --git a/python/xgboost_wrapper.cpp b/python/xgboost_wrapper.cpp
index df05d9521..7bc25eb40 100644
--- a/python/xgboost_wrapper.cpp
+++ b/python/xgboost_wrapper.cpp
@@ -37,7 +37,7 @@ class Booster: public learner::BoostLearner {
     for (unsigned j = 0; j < ndata; ++j) {
       gpair_[j] = bst_gpair(grad[j], hess[j]);
     }
-    gbm_->DoBoost(gpair_, train.fmat, train.info.info);
+    gbm_->DoBoost(train.fmat, train.info.info, &gpair_);
   }
   inline void CheckInitModel(void) {
     if (!init_model) {
diff --git a/src/data.h b/src/data.h
index 61d61e6a0..6f8297311 100644
--- a/src/data.h
+++ b/src/data.h
@@ -217,7 +217,7 @@ class FMatrixS : public FMatrixInterface<FMatrixS>{
     utils::Check(this->HaveColAccess(), "NumCol:need column access");
     return col_ptr_.size() - 1;
   }
-  /*! \brief get number of buffered rows */ 
+  /*! \brief get number of buffered rows */
   inline const std::vector<bst_uint> &buffered_rowset(void) const {
     return buffered_rowset_;
   }
@@ -333,7 +333,7 @@ class FMatrixS : public FMatrixInterface<FMatrixS>{
     while (iter_->Next()) {
       const SparseBatch &batch = iter_->Value();
       for (size_t i = 0; i < batch.size; ++i) {
-        if (pkeep==1.0f || random::SampleBinary(pkeep)) {
+        if (pkeep == 1.0f || random::SampleBinary(pkeep)) {
           buffered_rowset_.push_back(batch.base_rowid+i);
           SparseBatch::Inst inst = batch[i];
           for (bst_uint j = 0; j < inst.length; ++j) {
@@ -349,9 +349,9 @@ class FMatrixS : public FMatrixInterface<FMatrixS>{
     while (iter_->Next()) {
       const SparseBatch &batch = iter_->Value();
       for (size_t i = 0; i < batch.size; ++i) {
-        if (ktop < buffered_rowset_.size() && 
+        if (ktop < buffered_rowset_.size() &&
             buffered_rowset_[ktop] == batch.base_rowid+i) {
-          ++ ktop;
+          ++ktop;
           SparseBatch::Inst inst = batch[i];
           for (bst_uint j = 0; j < inst.length; ++j) {
             builder.PushElem(inst[j].findex,
diff --git a/src/gbm/gblinear-inl.hpp b/src/gbm/gblinear-inl.hpp
new file mode 100644
index 000000000..0c346d687
--- /dev/null
+++ b/src/gbm/gblinear-inl.hpp
@@ -0,0 +1,262 @@
+#ifndef XGBOOST_GBM_GBLINEAR_INL_HPP_
+#define XGBOOST_GBM_GBLINEAR_INL_HPP_
+/*!
+ * \file gblinear-inl.hpp
+ * \brief Implementation of Linear booster, with L1/L2 regularization: Elastic Net
+ *        the update rule is parallel coordinate descent (shotgun)
+ * \author Tianqi Chen
+ */
+#include <vector>
+#include <cstring>
+#include "./gbm.h"
+#include "../tree/updater.h"
+
+namespace xgboost {
+namespace gbm {
+/*!
+ * \brief gradient boosted linear model
+ * \tparam FMatrix the data type updater taking
+ */
+template<typename FMatrix>
+class GBLinear : public IGradBooster<FMatrix> {
+ public:
+  virtual ~GBLinear(void) {
+  }
+  // set model parameters
+  virtual void SetParam(const char *name, const char *val) {
+    if (!strncmp(name, "bst:", 4)) {
+      param.SetParam(name + 4, val);
+    }
+    if (model.weight.size() == 0) {
+      model.param.SetParam(name, val);
+    }
+  }
+  virtual void LoadModel(utils::IStream &fi) {
+    model.LoadModel(fi);
+  }
+  virtual void SaveModel(utils::IStream &fo) const {
+    model.SaveModel(fo);
+  }
+  virtual void InitModel(void) {
+    model.InitModel();
+  }
+  virtual void DoBoost(const FMatrix &fmat,
+                       const BoosterInfo &info,
+                       std::vector<bst_gpair> *in_gpair) {
+    this->InitFeatIndex(fmat);
+    std::vector<bst_gpair> &gpair = *in_gpair;
+    const int ngroup = model.param.num_output_group;
+    const std::vector<bst_uint> &rowset = fmat.buffered_rowset();
+    // for all the output group
+    for (int gid = 0; gid < ngroup; ++gid) {
+      double sum_grad = 0.0, sum_hess = 0.0;
+      const unsigned ndata = static_cast<unsigned>(rowset.size());
+      #pragma omp parallel for schedule(static) reduction(+: sum_grad, sum_hess)
+      for (unsigned i = 0; i < ndata; ++i) {
+        bst_gpair &p = gpair[rowset[i] * ngroup + gid];
+        if (p.hess >= 0.0f) {
+          sum_grad += p.grad; sum_hess += p.hess;
+        }
+      }
+      // remove bias effect
+      double dw = param.learning_rate * param.CalcDeltaBias(sum_grad, sum_hess, model.bias()[gid]);
+      model.bias()[gid] += dw;
+      // update grad value
+      #pragma omp parallel for schedule(static)
+      for (unsigned i = 0; i < ndata; ++i) {
+        bst_gpair &p = gpair[rowset[i] * ngroup + gid];
+        if (p.hess >= 0.0f) {
+          p.grad += p.hess * dw;
+        }
+      }
+    }
+    // number of features
+    const unsigned nfeat = static_cast<unsigned>(feat_index.size());
+    #pragma omp parallel for schedule(static)
+    for (unsigned i = 0; i < nfeat; ++i) {
+      const bst_uint fid = feat_index[i];
+      for (int gid = 0; gid < ngroup; ++gid) {
+        double sum_grad = 0.0, sum_hess = 0.0;
+        for (typename FMatrix::ColIter it = fmat.GetSortedCol(fid); it.Next();) {
+          const float v = it.fvalue();
+          bst_gpair &p = gpair[it.rindex() * ngroup + gid];
+          if (p.hess < 0.0f) continue;
+          sum_grad += p.grad * v;
+          sum_hess += p.hess * v * v;
+        }
+        float &w = model[fid][gid];
+        double dw = param.learning_rate * param.CalcDelta(sum_grad, sum_hess, w);
+        w += dw;
+        // update grad value
+        for (typename FMatrix::ColIter it = fmat.GetSortedCol(fid); it.Next();) {
+          bst_gpair &p = gpair[it.rindex() * ngroup + gid];
+          if (p.hess < 0.0f) continue;
+          p.grad += p.hess * it.fvalue() * dw;
+        }
+      }
+    }
+  }
+
+  virtual void Predict(const FMatrix &fmat,
+                       int64_t buffer_offset,
+                       const BoosterInfo &info,
+                       std::vector<float> *out_preds) {
+    std::vector<float> &preds = *out_preds;
+    preds.resize(0);
+    // start collecting the prediction
+    utils::IIterator<SparseBatch> *iter = fmat.RowIterator();
+    iter->BeforeFirst();
+    const int ngroup = model.param.num_output_group;
+    while (iter->Next()) {
+      const SparseBatch &batch = iter->Value();
+      utils::Assert(batch.base_rowid * ngroup == preds.size(),
+                    "base_rowid is not set correctly");
+      // output convention: nrow * k, where nrow is number of rows
+      // k is number of group
+      preds.resize(preds.size() + batch.size * ngroup);
+      // parallel over local batch
+      const unsigned nsize = static_cast<unsigned>(batch.size);
+      #pragma omp parallel for schedule(static)
+      for (unsigned i = 0; i < nsize; ++i) {
+        const size_t ridx = batch.base_rowid + i;
+        // loop over output groups
+        for (int gid = 0; gid < ngroup; ++gid) {
+          this->Pred(batch[i], &preds[ridx * ngroup]);
+        }
+      }
+    }
+  }
+  virtual std::vector<std::string> DumpModel(const utils::FeatMap& fmap, int option) {
+    utils::Error("gblinear does not support dump model");
+    return std::vector<std::string>();
+  }
+
+ protected:
+  inline void InitFeatIndex(const FMatrix &fmat) {
+    if (feat_index.size() != 0) return;
+    // initialize feature index
+    unsigned ncol = static_cast<unsigned>(fmat.NumCol());
+    feat_index.reserve(ncol);
+    for (unsigned i = 0; i < ncol; ++i) {
+      if (fmat.GetColSize(i) != 0) {
+        feat_index.push_back(i);
+      }
+    }
+    random::Shuffle(feat_index);
+  }
+  inline void Pred(const SparseBatch::Inst &inst, float *preds) {
+    for (int gid = 0; gid < model.param.num_output_group; ++gid) {
+      float psum = model.bias()[gid];
+      for (bst_uint i = 0; i < inst.length; ++i) {
+        psum += inst[i].fvalue * model[inst[i].findex][gid];
+      }
+      preds[gid] = psum;
+    }
+  }
+  // training parameter
+  struct ParamTrain {
+    /*! \brief learning_rate */
+    float learning_rate;
+    /*! \brief regularization weight for L2 norm */
+    float reg_lambda;
+    /*! \brief regularization weight for L1 norm */
+    float reg_alpha;
+    /*! \brief regularization weight for L2 norm in bias */
+    float reg_lambda_bias;
+    // parameter
+    ParamTrain(void) {
+      reg_alpha = 0.0f;
+      reg_lambda = 0.0f;
+      reg_lambda_bias = 0.0f;
+      learning_rate = 1.0f;
+    }
+    inline void SetParam(const char *name, const char *val) {
+      // sync-names
+      if (!strcmp("eta", name)) learning_rate = static_cast<float>(atof(val));
+      if (!strcmp("lambda", name)) reg_lambda = static_cast<float>(atof(val));
+      if (!strcmp("alpha", name)) reg_alpha = static_cast<float>(atof(val));
+      if (!strcmp("lambda_bias", name)) reg_lambda_bias = static_cast<float>(atof(val));
+      // real names
+      if (!strcmp("learning_rate", name)) learning_rate = static_cast<float>(atof(val));
+      if (!strcmp("reg_lambda", name)) reg_lambda = static_cast<float>(atof(val));
+      if (!strcmp("reg_alpha", name)) reg_alpha = static_cast<float>(atof(val));
+      if (!strcmp("reg_lambda_bias", name)) reg_lambda_bias = static_cast<float>(atof(val));
+    }
+    // given original weight calculate delta
+    inline double CalcDelta(double sum_grad, double sum_hess, double w) {
+      if (sum_hess < 1e-5f) return 0.0f;
+      double tmp = w - (sum_grad + reg_lambda * w) / (sum_hess + reg_lambda);
+      if (tmp >= 0) {
+        return std::max(-(sum_grad + reg_lambda * w + reg_alpha) / (sum_hess + reg_lambda), -w);
+      } else {
+        return std::min(-(sum_grad + reg_lambda * w - reg_alpha) / (sum_hess + reg_lambda), -w);
+      }
+    }
+    // given original weight calculate delta bias
+    inline double CalcDeltaBias(double sum_grad, double sum_hess, double w) {
+      return -(sum_grad + reg_lambda_bias * w) / (sum_hess + reg_lambda_bias);
+    }
+  };
+  // model for linear booster
+  class Model {
+   public:
+    // model parameter
+    struct Param {
+      // number of feature dimension
+      int num_feature;
+      // number of output group
+      int num_output_group;
+      // reserved field
+      int reserved[32];
+      // constructor
+      Param(void) {
+        num_feature = 0;
+        num_output_group = 1;
+        memset(reserved, 0, sizeof(reserved));
+      }
+      inline void SetParam(const char *name, const char *val) {
+        if (!strcmp(name, "bst:num_feature")) num_feature = atoi(val);
+        if (!strcmp(name, "num_output_group")) num_output_group = atoi(val);
+      }
+    };
+    // parameter
+    Param param;
+    // weight for each feature, bias is the last one
+    std::vector<float> weight;
+    // initialize the model parameter
+    inline void InitModel(void) {
+      // bias is the last weight
+      weight.resize((param.num_feature + 1) * param.num_output_group);
+      std::fill(weight.begin(), weight.end(), 0.0f);
+    }
+    // save the model to file
+    inline void SaveModel(utils::IStream &fo) const {
+      fo.Write(&param, sizeof(Param));
+      fo.Write(weight);
+    }
+    // load model from file
+    inline void LoadModel(utils::IStream &fi) {
+      utils::Assert(fi.Read(&param, sizeof(Param)) != 0, "Load LinearBooster");
+      fi.Read(&weight);
+    }
+    // model bias
+    inline float* bias(void) {
+      return &weight[param.num_feature * param.num_output_group];
+    }
+    // get i-th weight
+    inline float* operator[](size_t i) {
+      return &weight[i * param.num_output_group];
+    }
+  };
+  // model field
+  Model model;
+  // training parameter
+  ParamTrain param;
+  // Per feature: shuffle index of each feature index
+  std::vector<bst_uint> feat_index;
+};
+
+}  // namespace gbm
+}  // namespace xgboost
+#endif  // XGBOOST_GBM_GBLINEAR_INL_HPP_
diff --git a/src/gbm/gbm.h b/src/gbm/gbm.h
index f47adfdd2..7b551553a 100644
--- a/src/gbm/gbm.h
+++ b/src/gbm/gbm.h
@@ -41,13 +41,14 @@ class IGradBooster {
   virtual void InitModel(void) = 0;
   /*!
    * \brief peform update to the model(boosting)
-   * \param gpair the gradient pair statistics of the data
    * \param fmat feature matrix that provide access to features
    * \param info meta information about training
+   * \param in_gpair address of the gradient pair statistics of the data
+   *        the booster may change content of gpair
    */
-  virtual void DoBoost(const std::vector<bst_gpair> &gpair,
-                       const FMatrix &fmat,
-                       const BoosterInfo &info) = 0;
+  virtual void DoBoost(const FMatrix &fmat,
+                       const BoosterInfo &info,
+                       std::vector<bst_gpair> *in_gpair) = 0;
   /*!
    * \brief generate predictions for given feature matrix
    * \param fmat feature matrix
@@ -74,12 +75,16 @@ class IGradBooster {
 };
 }  // namespace gbm
 }  // namespace xgboost
+
 #include "gbtree-inl.hpp"
+#include "gblinear-inl.hpp"
+
 namespace xgboost {
 namespace gbm {
 template<typename FMatrix>
 inline IGradBooster<FMatrix>* CreateGradBooster(const char *name) {
   if (!strcmp("gbtree", name)) return new GBTree<FMatrix>();
+  if (!strcmp("gblinear", name)) return new GBLinear<FMatrix>();
   utils::Error("unknown booster type: %s", name);
   return NULL;
 }
diff --git a/src/gbm/gbtree-inl.hpp b/src/gbm/gbtree-inl.hpp
index 3fa0f4dd7..0e001a4e8 100644
--- a/src/gbm/gbtree-inl.hpp
+++ b/src/gbm/gbtree-inl.hpp
@@ -82,9 +82,10 @@ class GBTree : public IGradBooster<FMatrix> {
     utils::Assert(mparam.num_trees == 0, "GBTree: model already initialized");
     utils::Assert(trees.size() == 0, "GBTree: model already initialized");
   }
-  virtual void DoBoost(const std::vector<bst_gpair> &gpair,
-                       const FMatrix &fmat,
-                       const BoosterInfo &info) {
+  virtual void DoBoost(const FMatrix &fmat,
+                       const BoosterInfo &info,
+                       std::vector<bst_gpair> *in_gpair) {
+    const std::vector<bst_gpair> &gpair = *in_gpair;
     if (mparam.num_output_group == 1) {
       this->BoostNewTrees(gpair, fmat, info, 0);
     } else {
diff --git a/src/learner/evaluation-inl.hpp b/src/learner/evaluation-inl.hpp
index 69f0bb4d9..72085be46 100644
--- a/src/learner/evaluation-inl.hpp
+++ b/src/learner/evaluation-inl.hpp
@@ -28,7 +28,7 @@ struct EvalEWiseBase : public IEvaluator {
                  "label and prediction size not match");
     const unsigned ndata = static_cast<unsigned>(preds.size());
     float sum = 0.0, wsum = 0.0;
-    #pragma omp parallel for reduction(+:sum, wsum) schedule(static)
+    #pragma omp parallel for reduction(+: sum, wsum) schedule(static)
     for (unsigned i = 0; i < ndata; ++i) {
       const float wt = info.GetWeight(i);
       sum += Derived::EvalRow(info.labels[i], preds[i]) * wt;
diff --git a/src/learner/learner-inl.hpp b/src/learner/learner-inl.hpp
index 6d00c3090..9150b5379 100644
--- a/src/learner/learner-inl.hpp
+++ b/src/learner/learner-inl.hpp
@@ -164,7 +164,7 @@ class BoostLearner {
   inline void UpdateOneIter(int iter, const DMatrix &train) {
     this->PredictRaw(train, &preds_);
     obj_->GetGradient(preds_, train.info, iter, &gpair_);
-    gbm_->DoBoost(gpair_, train.fmat, train.info.info);
+    gbm_->DoBoost(train.fmat, train.info.info, &gpair_);
   }
   /*!
    * \brief evaluate the model for specific iteration
diff --git a/src/tree/updater_colmaker-inl.hpp b/src/tree/updater_colmaker-inl.hpp
index 919dfcc28..afeccb206 100644
--- a/src/tree/updater_colmaker-inl.hpp
+++ b/src/tree/updater_colmaker-inl.hpp
@@ -81,7 +81,7 @@ class ColMaker: public IUpdater {
                       RegTree *p_tree) {
     this->InitData(gpair, fmat, info.root_index, *p_tree);
     this->InitNewNode(qexpand, gpair, fmat, *p_tree);
-    
+
     for (int depth = 0; depth < param.max_depth; ++depth) {
       this->FindSplit(depth, this->qexpand, gpair, fmat, p_tree);
       this->ResetPosition(this->qexpand, fmat, *p_tree);
@@ -89,7 +89,7 @@ class ColMaker: public IUpdater {
       this->InitNewNode(qexpand, gpair, fmat, *p_tree);
       // if nothing left to be expand, break
       if (qexpand.size() == 0) break;
-    } 
+    }
     // set all the rest expanding nodes to leaf
     for (size_t i = 0; i < qexpand.size(); ++i) {
       const int nid = qexpand[i];
@@ -182,7 +182,7 @@ class ColMaker: public IUpdater {
       }
       snode.resize(tree.param.num_nodes, NodeEntry());
     }
-    const std::vector<bst_uint> &rowset = fmat.buffered_rowset(); 
+    const std::vector<bst_uint> &rowset = fmat.buffered_rowset();
     // setup position
     const unsigned ndata = static_cast<unsigned>(rowset.size());
     #pragma omp parallel for schedule(static)
@@ -316,8 +316,8 @@ class ColMaker: public IUpdater {
     // step 1, set default direct nodes to default, and leaf nodes to -1
     const unsigned ndata = static_cast<unsigned>(rowset.size());
     #pragma omp parallel for schedule(static)
-    for (unsigned i = 0; i < ndata; ++i) { 
-      const bst_uint ridx = rowset[i]; 
+    for (unsigned i = 0; i < ndata; ++i) {
+      const bst_uint ridx = rowset[i];
       const int nid = position[ridx];
       if (nid >= 0) {
         if (tree[nid].is_leaf()) {
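
Note on the update rule used above: ParamTrain::CalcDelta is the per-weight elastic-net
coordinate-descent step. Given the accumulated gradient sum_grad and hessian sum_hess of
one feature column, it takes the L2-regularized Newton step, shifts it toward zero by the
L1 weight reg_alpha, and clamps the step at -w so the weight never crosses zero in a single
update (soft-thresholding). Below is a minimal standalone sketch of that step; the column
statistics in main() are made-up example values, not taken from the patch.

#include <algorithm>
#include <cstdio>

// Elastic-net coordinate-descent step, mirroring ParamTrain::CalcDelta above.
// sum_grad / sum_hess are the gradient and hessian sums over one feature column,
// w is the current weight, reg_lambda the L2 weight, reg_alpha the L1 weight.
double CalcDelta(double sum_grad, double sum_hess, double w,
                 double reg_lambda, double reg_alpha) {
  if (sum_hess < 1e-5) return 0.0;
  // L2-regularized Newton step, ignoring the L1 term for the sign decision
  double tmp = w - (sum_grad + reg_lambda * w) / (sum_hess + reg_lambda);
  if (tmp >= 0.0) {
    // L1 shrinks the step toward zero; clamping at -w keeps w + delta >= 0
    return std::max(-(sum_grad + reg_lambda * w + reg_alpha) / (sum_hess + reg_lambda), -w);
  } else {
    return std::min(-(sum_grad + reg_lambda * w - reg_alpha) / (sum_hess + reg_lambda), -w);
  }
}

int main() {
  // hypothetical column statistics: sum_i g_i * x_i = -4, sum_i h_i * x_i^2 = 10
  double dw = CalcDelta(-4.0, 10.0, /*w=*/0.0, /*reg_lambda=*/1.0, /*reg_alpha=*/0.5);
  std::printf("delta = %g\n", dw);  // (4 - 0.5) / (10 + 1) ~= 0.318
  return 0;
}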