From 37a28376bb28d2be8fcc26607a65547528fb7d89 Mon Sep 17 00:00:00 2001
From: tqchen
Date: Mon, 9 Feb 2015 11:04:19 -0800
Subject: [PATCH] complete lbfgs solver

---
 rabit-learn/solver/lbfgs.h | 350 ++++++++++++++++++++++++++++++-------
 1 file changed, 290 insertions(+), 60 deletions(-)

diff --git a/rabit-learn/solver/lbfgs.h b/rabit-learn/solver/lbfgs.h
index 168a623f8..1e330aa0c 100644
--- a/rabit-learn/solver/lbfgs.h
+++ b/rabit-learn/solver/lbfgs.h
@@ -13,29 +13,180 @@ namespace rabit {
 /*! \brief namespace of solver for general problems */
 namespace solver {
-/*! \brief an L-BFGS solver */
+/*!
+ * \brief objective function interface for optimizers;
+ * the objective function can also implement Save/Load
+ * to persist any state parameters it needs to remember
+ */
+template<typename DType>
+class IObjFunction : public rabit::ISerializable {
+ public:
+  /*!
+   * \brief set parameters from outside
+   * \param name name of the parameter
+   * \param val value of the parameter
+   */
+  virtual void SetParam(const char *name, const char *val) = 0;
+  /*!
+   * \brief evaluate the function value for a given weight
+   * \param weight weight of the function
+   * \param size size of the weight
+   */
+  virtual double Eval(const DType *weight, size_t size) = 0;
+  /*!
+   * \brief initialize the weight before starting the solver
+   */
+  virtual void InitModel(DType *weight, size_t size) = 0;
+  /*!
+   * \brief calculate the gradient for a given weight
+   * \param out_grad used to store the gradient value of the function
+   * \param weight weight of the function
+   * \param size size of the weight
+   */
+  virtual void CalcGrad(DType *out_grad,
+                        const DType *weight,
+                        size_t size) = 0;
+  /*!
+   * \brief add the regularization gradient to the gradient, if any;
+   * this is used to add data-set-invariant regularization
+   * \param out_grad used to store the gradient value of the function
+   * \param weight weight of the function
+   * \param size size of the weight
+   */
+  virtual void AddRegularization(DType *out_grad,
+                                 const DType *weight,
+                                 size_t size) = 0;
+};
+
+/*! \brief a basic version of the L-BFGS solver */
 template<typename DType>
 class LBFGSSolver {
  public:
   LBFGSSolver(void) {
+    // set default values
     reg_L1 = 0.0f;
     max_linesearch_iter = 1000;
     linesearch_backoff = 0.5f;
+    linesearch_c1 = 1e-4;
+    min_lbfgs_iter = 5;
+    max_lbfgs_iter = 1000;
+    lbfgs_stop_tol = 1e-6f;
+    silent = 0;
   }
-  // initialize the L-BFGS solver
-  inline void Init(size_t num_feature, size_t size_memory) {
-    mdot.Init(size_memory_);
-    hist.Init(num_feature, size_memory_);
+  virtual ~LBFGSSolver(void) {}
+  /*!
+   * \brief set parameters from outside
+   * \param name name of the parameter
+   * \param val value of the parameter
+   */
+  virtual void SetParam(const char *name, const char *val) {
+    if (!strcmp("num_feature", name)) {
+      gstate.num_feature = static_cast<size_t>(atol(val));
+    }
+    if (!strcmp("size_memory", name)) {
+      gstate.size_memory = static_cast<size_t>(atol(val));
+    }
+    if (!strcmp("reg_L1", name)) {
+      reg_L1 = atof(val);
+    }
+    if (!strcmp("linesearch_backoff", name)) {
+      linesearch_backoff = atof(val);
+    }
+    if (!strcmp("max_linesearch_iter", name)) {
+      max_linesearch_iter = atoi(val);
+    }
+  }
+  /*!
+   * \brief set the objective function to optimize;
+   * the objective function only needs to evaluate and calculate
+   * the gradient with respect to the current subset of data
+   * \param obj the objective function to optimize
+   */
+  virtual void SetObjFunction(IObjFunction<DType> *obj) {
+    gstate.obj = obj;
+  }
+  /*!
+   * \brief initialize the L-BFGS solver;
+   * the user must have set the objective function already
+   */
+  virtual void Init(void) {
+    utils::Check(gstate.obj != NULL,
+                 "LBFGSSolver.Init: must call SetObjFunction first");
+    if (rabit::LoadCheckPoint(&gstate, &hist) == 0) {
+      gstate.Init();
+      hist.Init(gstate.num_feature, gstate.size_memory);
+      if (rabit::GetRank() == 0) {
+        gstate.obj->InitModel(gstate.weight, gstate.num_feature);
+      }
+      // broadcast the initial model
+      rabit::Broadcast(gstate.weight,
+                       sizeof(DType) * gstate.num_feature, 0);
+      gstate.old_objval = this->Eval(gstate.weight);
+      gstate.init_objval = gstate.old_objval;
+
+      if (silent == 0 && rabit::GetRank() == 0) {
+        rabit::TrackerPrintf
+            ("L-BFGS solver starts, num_feature=%lu, init_objval=%g\n",
+             gstate.num_feature, gstate.init_objval);
+      }
+    }
+  }
+  /*!
+   * \brief get the current weight vector;
+   * note that once the update function is called,
+   * the content of the weight vector is no longer valid
+   * \return the weight vector
+   */
+  virtual DType *GetWeight(void) {
+    return gstate.weight;
+  }
+  /*!
+   * \brief update the weight for one L-BFGS iteration
+   * \return whether the stopping condition is met
+   */
+  virtual bool UpdateOneIter(void) {
+    bool stop = false;
+    GlobalState &g = gstate;
+    g.obj->CalcGrad(g.grad, g.weight, g.num_feature);
+    rabit::Allreduce<op::Sum>(g.grad, g.num_feature);
+    g.obj->AddRegularization(g.grad, g.weight, g.num_feature);
+    double vdot = FindChangeDirection(g.tempw, g.grad, g.weight);
+    int iter = BacktrackLineSearch(g.grad, g.tempw, g.weight, vdot);
+    utils::Check(iter < max_linesearch_iter, "line search failed");
+    std::swap(g.weight, g.grad);
+    if (gstate.num_iteration > min_lbfgs_iter) {
+      if (g.old_objval - g.new_objval < lbfgs_stop_tol * g.init_objval) {
+        return true;
+      }
+    }
+    if (silent == 0 && rabit::GetRank() == 0) {
+      rabit::TrackerPrintf
+          ("[%lu] L-BFGS: linesearch finishes in %d rounds, new_objval=%g, improvement=%g\n",
+           gstate.num_iteration, iter,
+           gstate.new_objval,
+           gstate.old_objval - gstate.new_objval);
+    }
+    gstate.old_objval = gstate.new_objval;
+    rabit::CheckPoint(&gstate, &hist);
+    return stop;
+  }
+  /*! \brief run the optimization */
+  virtual void Run(void) {
+    this->Init();
+    while (gstate.num_iteration < max_lbfgs_iter) {
+      if (this->UpdateOneIter()) break;
+    }
   }
-
  protected:
   // find the delta value, given gradient
   // return dot(dir, l1grad)
   virtual double FindChangeDirection(DType *dir,
                                      const DType *grad,
                                      const DType *weight) {
-    int m = static_cast<int>(size_memory_);
+    int m = static_cast<int>(gstate.size_memory);
     int n = static_cast<int>(hist.num_useful());
+    const size_t num_feature = gstate.num_feature;
     const DType *gsub = grad + range_begin_;
     const size_t nsub = range_end_ - range_begin_;
     double vdot;
@@ -61,7 +212,7 @@ class LBFGSSolver {
     }
     rabit::Allreduce<op::Sum>(BeginPtr(tmp), tmp.size());
     for (size_t i = 0; i < tmp.size(); ++i) {
-      mdot.Get(idxset[i].first, idxset[i].second) = tmp[i];
+      gstate.DotBuf(idxset[i].first, idxset[i].second) = tmp[i];
     }
     // BFGS steps
     std::vector<double> alpha(n);
@@ -71,14 +222,14 @@ class LBFGSSolver {
     for (int j = n - 1; j >= 0; --j) {
       double vsum = 0.0;
       for (size_t k = 0; k < delta.size(); ++k) {
-        vsum += delta[k] * mdot.Get(k, j);
+        vsum += delta[k] * gstate.DotBuf(k, j);
       }
-      alpha[j] = vsum / mdot.Get(j, m + j);
+      alpha[j] = vsum / gstate.DotBuf(j, m + j);
       delta[m + j] = delta[m + j] - alpha[j];
     }
     // scale
-    double scale = mdot.Get(n - 1, m + n - 1) /
-        mdot.Get(m + n - 1, m + n - 1);
+    double scale = gstate.DotBuf(n - 1, m + n - 1) /
+        gstate.DotBuf(m + n - 1, m + n - 1);
     for (size_t k = 0; k < delta.size(); ++k) {
       delta[k] *= scale;
     }
@@ -86,13 +237,13 @@ class LBFGSSolver {
     for (int j = 0; j < n; ++j) {
       double vsum = 0.0;
       for (size_t k = 0; k < delta.size(); ++k) {
-        vsum += delta[k] * mdot.Get(k, m + j);
+        vsum += delta[k] * gstate.DotBuf(k, m + j);
       }
-      double beta = vsum / mdot.Get(j, m + j);
+      double beta = vsum / gstate.DotBuf(j, m + j);
       delta[j] = delta[j] + (alpha[j] - beta);
     }
     // set all to zero
-    std::fill(dir, dir + num_feature_, 0.0f);
+    std::fill(dir, dir + num_feature, 0.0f);
     DType *dirsub = dir + range_begin_;
     for (int i = 0; i < n; ++i) {
       AddScale(dirsub, dirsub, hist[i], delta[i], nsub);
@@ -102,14 +253,14 @@ class LBFGSSolver {
     FixDirL1Sign(dir + range_begin_, hist[2 * m], nsub);
     vdot = -Dot(dir + range_begin_, hist[2 * m], nsub);
     // allreduce to get full direction
-    rabit::Allreduce<op::Sum>(dir, num_feature_);
+    rabit::Allreduce<op::Sum>(dir, num_feature);
     rabit::Allreduce<op::Sum>(&vdot, 1);
   } else {
-    SetL1Dir(dir, grad, weight, num_feature_);
-    vdot = -Dot(dir, dir, num_feature_);
+    SetL1Dir(dir, grad, weight, num_feature);
+    vdot = -Dot(dir, dir, num_feature);
   }
   // shift the history record
-  mdot.Shift(); hist.Shift();
+  gstate.Shift(); hist.Shift();
   // next n
   if (n < m) n += 1;
   hist.set_num_useful(n);
@@ -119,29 +270,41 @@ class LBFGSSolver {
   }
   // line search for given direction
-  // return whether there is a descent
-  virtual bool BacktrackLineSearch(DType *new_weight,
-                                   const DType *dir,
-                                   const DType *weight,
-                                   double dot_dir_l1grad) {
+  // return the number of line-search iterations taken
+  inline int BacktrackLineSearch(DType *new_weight,
+                                 const DType *dir,
+                                 const DType *weight,
+                                 double dot_dir_l1grad) {
     utils::Assert(dot_dir_l1grad < 0.0f, "gradient error");
     double alpha = 1.0;
     double backoff = linesearch_backoff;
     // unit descent direction in first iter
-    if (hist.num_useful() == 1) {
+    if (gstate.num_iteration == 0) {
+      utils::Assert(hist.num_useful() == 1, "hist.num_useful");
       alpha = 1.0f / std::sqrt(-dot_dir_l1grad);
       linesearch_backoff = 0.1f;
     }
-    double c1 = 1e-4;
-    double old_val = this->Eval(weight);
-    for (int iter = 0; true; ++iter) {
-      if (iter >= max_linesearch_iter) return false;
-      AddScale(new_weight, weight, dir, alpha, num_feature_);
-      this->FixWeightL1Sign(new_weight, weight, num_feature_);
+    int iter = 0;
+
+    double old_val = gstate.old_objval;
+    double c1 = this->linesearch_c1;
+    while (true) {
+      const size_t num_feature = gstate.num_feature;
+      if (++iter >= max_linesearch_iter) return iter;
+      AddScale(new_weight, weight, dir, alpha, num_feature);
+      this->FixWeightL1Sign(new_weight, weight, num_feature);
       double new_val = this->Eval(new_weight);
-      if (new_val - old_val <= c1 * dot_dir_l1grad * alpha) break;
+      if (new_val - old_val <= c1 * dot_dir_l1grad * alpha) {
+        gstate.new_objval = new_val; break;
+      }
       alpha *= backoff;
     }
-    return true;
+    // hist[n - 1] = new_weight - weight
+    Minus(hist[hist.num_useful() - 1],
+          new_weight + range_begin_,
+          weight + range_begin_,
+          range_end_ - range_begin_);
+    gstate.num_iteration += 1;
+    return iter;
   }
   inline void SetL1Dir(DType *dst,
                        const DType *grad,
@@ -173,7 +336,7 @@ class LBFGSSolver {
   inline void FixDirL1Sign(DType *dir,
                            const DType *steepdir,
                            size_t size) {
-    if (reg_L1 > 0.0) {
+    if (reg_L1 != 0.0f) {
       for (size_t i = 0; i < size; ++i) {
         if (dir[i] * steepdir[i] <= 0.0f) {
           dir[i] = 0.0f;
@@ -185,7 +348,7 @@ class LBFGSSolver {
   inline void FixWeightL1Sign(DType *new_weight,
                               const DType *weight,
                               size_t size) {
-    if (reg_L1 > 0.0) {
+    if (reg_L1 != 0.0f) {
      for (size_t i = 0; i < size; ++i) {
        if (new_weight[i] * weight[i] < 0.0f) {
          new_weight[i] = 0.0f;
@@ -193,10 +356,21 @@ class LBFGSSolver {
       }
     }
   }
-  virtual double Eval(const DType *weight) {
-    return 0.0f;
+  inline double Eval(const DType *weight) {
+    double val = gstate.obj->Eval(weight, gstate.num_feature);
+    rabit::Allreduce<op::Sum>(&val, 1);
+    if (reg_L1 != 0.0f) {
+      double l1norm = 0.0;
+      for (size_t i = 0; i < gstate.num_feature; ++i) {
+        l1norm += std::abs(weight[i]);
+      }
+      val += l1norm * reg_L1;
+    }
+    return val;
   }
+
+ private:
+  // helper functions
   // dst = lhs + rhs * scale
   inline static void AddScale(DType *dst,
                               const DType *lhs,
@@ -207,7 +381,7 @@ class LBFGSSolver {
       dst[i] = lhs[i] + rhs[i] * scale;
     }
   }
-  // dst = lhs + rhs
+  // dst = lhs - rhs
   inline static void Minus(DType *dst,
                            const DType *lhs,
                            const DType *rhs,
@@ -216,6 +390,7 @@ class LBFGSSolver {
       dst[i] = lhs[i] - rhs[i];
     }
   }
+  // return dot(lhs, rhs)
   inline static double Dot(const DType *lhs,
                            const DType *rhs,
                            size_t size) {
@@ -225,7 +400,6 @@ class LBFGSSolver {
     }
     return res;
   }
-  // map rolling array index
   inline static size_t MapIndex(size_t i, size_t offset,
                                 size_t size_memory) {
     if (i == 2 * size_memory) return i;
@@ -237,42 +411,93 @@ class LBFGSSolver {
       return (i + offset) % size_memory + size_memory;
     }
   }
-  // temp matrix to store the dot product
-  struct DotMatrix : public rabit::ISerializable {
+  // global solver state
+  struct GlobalState : public rabit::ISerializable {
    public:
-    // intilize the space of rolling array
-    inline void Init(size_t size_memory) {
-      size_memory_ = size_memory;
-      size_t n = size_memory_ * 2 + 1;
-      data.resize(n * n, 0.0);
+    // memory size of L-BFGS
+    size_t size_memory;
+    // number of iterations passed
+    size_t num_iteration;
+    // number of features in the solver
+    size_t num_feature;
+    // initial objective value
+    double init_objval;
+    // objective value of the previous iteration
+    double old_objval;
+    // new objective value
+    double new_objval;
+    // objective function
+    IObjFunction<DType> *obj;
+    // temporary storage
+    DType *grad, *weight, *tempw;
+    // constructor
+    GlobalState(void)
+        : obj(NULL), grad(NULL),
+          weight(NULL), tempw(NULL) {
+      size_memory = 10;
+      num_iteration = 0;
+      num_feature = 0;
+      old_objval = 0.0;
+    }
-    inline double &Get(size_t i, size_t j) {
+    ~GlobalState(void) {
+      if (grad != NULL) {
+        delete [] grad;
+        delete [] weight;
+        delete [] tempw;
+      }
+    }
+    // initialize the space of the rolling array
+    inline void Init(void) {
+      size_t n = size_memory * 2 + 1;
+      data.resize(n * n, 0.0);
+      this->AllocSpace();
+    }
+    inline double &DotBuf(size_t i, size_t j) {
       if (i > j) std::swap(i, j);
-      return data[MapIndex(i, offset_, size_memory_) * (size_memory_ * 2 + 1) +
-                  MapIndex(j, offset_, size_memory_)];
+      return data[MapIndex(i, offset_, size_memory) * (size_memory * 2 + 1) +
+                  MapIndex(j, offset_, size_memory)];
     }
     // load the shift array
     virtual void Load(rabit::IStream &fi) {
-      fi.Read(&size_memory_, sizeof(size_memory_));
+      fi.Read(&size_memory, sizeof(size_memory));
+      fi.Read(&num_iteration, sizeof(num_iteration));
+      fi.Read(&num_feature, sizeof(num_feature));
+      fi.Read(&init_objval, sizeof(init_objval));
+      fi.Read(&old_objval, sizeof(old_objval));
       fi.Read(&offset_, sizeof(offset_));
       fi.Read(&data);
+      this->AllocSpace();
+      fi.Read(weight, sizeof(DType) * num_feature);
+      obj->Load(fi);
     }
     // save the shift array
     virtual void Save(rabit::IStream &fo) const {
-      fo.Write(&size_memory_, sizeof(size_memory_));
+      fo.Write(&size_memory, sizeof(size_memory));
+      fo.Write(&num_iteration, sizeof(num_iteration));
+      fo.Write(&num_feature, sizeof(num_feature));
+      fo.Write(&init_objval, sizeof(init_objval));
+      fo.Write(&old_objval, sizeof(old_objval));
       fo.Write(&offset_, sizeof(offset_));
       fo.Write(data);
+      fo.Write(weight, sizeof(DType) * num_feature);
+      obj->Save(fo);
     }
     inline void Shift(void) {
-      offset_ = (offset_ + 1) % size_memory_;
+      offset_ = (offset_ + 1) % size_memory;
    }
-   private:
-    // memory size of L-BFGS
-    size_t size_memory_;
+   private:
    // rolling offset in the current memory
    size_t offset_;
    std::vector<double> data;
+    // allocate space
+    inline void AllocSpace(void) {
+      if (grad == NULL) {
+        grad = new DType[num_feature];
+        weight = new DType[num_feature];
+        tempw = new DType[num_feature];
+      }
+    }
   };
   /*! \brief rolling array that carries history information */
   struct HistoryArray : public rabit::ISerializable {
@@ -337,7 +562,7 @@ class LBFGSSolver {
        fi.Write((*this)[i + size_memory_], num_col_ * sizeof(DType));
      }
    }
-
+
    private:
     // number of columns in each of array
     size_t num_col_;
@@ -351,19 +576,24 @@ class LBFGSSolver {
     size_t offset_;
     // data pointer
     DType *dptr_;
-  };
+  };
   // data structure for LBFGS
-  DotMatrix mdot;
+  GlobalState gstate;
   HistoryArray hist;
-  size_t num_feature_;
-  size_t size_memory_;
+  // silent mode
+  int silent;
+  // the subrange of the current node
   size_t range_begin_;
   size_t range_end_;
-  double old_fval;
   // L1 regularization coefficient
   float reg_L1;
+  // c1 ratio for the line search
+  float linesearch_c1;
   float linesearch_backoff;
   int max_linesearch_iter;
+  int max_lbfgs_iter;
+  int min_lbfgs_iter;
+  float lbfgs_stop_tol;
 };
 } // namespace solver
 } // namespace rabit
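
Usage sketch: the snippet below shows how the interfaces introduced by this patch are meant to compose. It is illustrative, not part of the patch: SquareLossObj, its toy loss, and the main() wiring are hypothetical, and it assumes the standard rabit entry points (rabit::Init, rabit::Finalize, rabit::GetRank) and the header path from the diff. Note the diff does not show how the per-node feature subrange (range_begin_, range_end_) is configured, so the sketch relies on the defaults of the full file.

// usage-sketch.cc (hypothetical example, not part of the patch)
#include <algorithm>
#include <cstddef>
#include <rabit.h>
#include "rabit-learn/solver/lbfgs.h"

// A toy objective: f(w) = 0.5 * ||w - 1||^2, identical on every node.
// A real objective would evaluate on this node's shard of the data;
// the solver Allreduces the values and gradients across nodes.
class SquareLossObj : public rabit::solver::IObjFunction<float> {
 public:
  virtual void SetParam(const char *name, const char *val) {}
  virtual double Eval(const float *weight, size_t size) {
    double sum = 0.0;
    for (size_t i = 0; i < size; ++i) {
      double d = weight[i] - 1.0;
      sum += 0.5 * d * d;
    }
    return sum;
  }
  virtual void InitModel(float *weight, size_t size) {
    std::fill(weight, weight + size, 0.0f);
  }
  virtual void CalcGrad(float *out_grad, const float *weight, size_t size) {
    for (size_t i = 0; i < size; ++i) out_grad[i] = weight[i] - 1.0f;
  }
  // no data-set-invariant regularization in this toy example
  virtual void AddRegularization(float *out_grad,
                                 const float *weight, size_t size) {}
  // no extra state to persist across checkpoints
  virtual void Load(rabit::IStream &fi) {}
  virtual void Save(rabit::IStream &fo) const {}
};

int main(int argc, char *argv[]) {
  rabit::Init(argc, argv);
  SquareLossObj obj;
  rabit::solver::LBFGSSolver<float> solver;
  solver.SetObjFunction(&obj);
  solver.SetParam("num_feature", "100");
  solver.SetParam("size_memory", "10");
  // Run() = Init() + UpdateOneIter() until the relative improvement
  // drops below lbfgs_stop_tol or max_lbfgs_iter is reached
  solver.Run();
  const float *w = solver.GetWeight();
  if (rabit::GetRank() == 0) {
    rabit::TrackerPrintf("w[0] after optimization: %g\n", w[0]);
  }
  rabit::Finalize();
  return 0;
}

Because the solver checkpoints gstate and hist through rabit::CheckPoint after every iteration, a program structured this way can resume mid-optimization after a node failure; rabit::LoadCheckPoint in Init() returns nonzero on recovery and skips re-initialization.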