init version of lbfgs

tqchen 2015-02-09 17:44:32 -08:00
parent 37a28376bb
commit 12ee049a74
9 changed files with 505 additions and 87 deletions

View File

@ -4,7 +4,7 @@ export CC = gcc
export CXX = g++
export MPICXX = mpicxx
export LDFLAGS= -pthread -lm -L../../lib
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../../include -I../common
export CFLAGS = -Wall -msse2 -Wno-unknown-pragmas -fPIC -I../../include -I../common
.PHONY: clean all lib mpi
all: $(BIN) $(MOCKBIN)

View File

@ -1,24 +1,38 @@
#include <rabit.h>
/*!
* Copyright (c) 2015 by Contributors
* \file toolkit_util.h
* \brief simple data structure that could be used by model
*
* \author Tianqi Chen
*/
#ifndef RABIT_TOOLKIT_UTIL_H_
#define RABIT_TOOLKIT_UTIL_H_
#include <vector>
#include <cstdlib>
#include <cstdio>
#include <cstring>
#include <limits>
#include <cmath>
#include <rabit.h>
namespace rabit {
// typedef index type
typedef unsigned index_t;
/*! \brief sparse matrix, CSR format */
struct SparseMat {
// sparse matrix entry
struct Entry {
// feature index
unsigned findex;
index_t findex;
// feature value
float fvalue;
};
// sparse vector
struct Vector {
const Entry *data;
unsigned length;
index_t length;
inline const Entry &operator[](size_t i) const {
return data[i];
}
@ -26,7 +40,7 @@ struct SparseMat {
inline Vector operator[](size_t i) const {
Vector v;
v.data = &data[0] + row_ptr[i];
v.length = static_cast<unsigned>(row_ptr[i + 1]-row_ptr[i]);
v.length = static_cast<index_t>(row_ptr[i + 1]-row_ptr[i]);
return v;
}
// load data from LibSVM format
@ -35,7 +49,13 @@ struct SparseMat {
if (!strcmp(fname, "stdin")) {
fi = stdin;
} else {
fi = utils::FopenCheck(fname, "r");
if (strchr(fname, '%') != NULL) {
char s_tmp[256];
snprintf(s_tmp, sizeof(s_tmp), fname, rabit::GetRank());
fi = utils::FopenCheck(s_tmp, "r");
} else {
fi = utils::FopenCheck(fname, "r");
}
}
row_ptr.clear();
row_ptr.push_back(0);
@ -45,9 +65,11 @@ struct SparseMat {
char tmp[1024];
while (fscanf(fi, "%s", tmp) == 1) {
Entry e;
if (sscanf(tmp, "%u:%f", &e.findex, &e.fvalue) == 2) {
unsigned long fidx;
if (sscanf(tmp, "%lu:%f", &fidx, &e.fvalue) == 2) {
e.findex = static_cast<index_t>(fidx);
data.push_back(e);
feat_dim = std::max(e.findex, feat_dim);
feat_dim = std::max(fidx, feat_dim);
} else {
if (!init) {
labels.push_back(label);
@ -61,6 +83,9 @@ struct SparseMat {
labels.push_back(label);
row_ptr.push_back(data.size());
feat_dim += 1;
utils::Check(feat_dim < std::numeric_limits<index_t>::max(),
"feature dimension exceed limit of index_t"\
"consider change the index_t to unsigned long");
// close the file
if (fi != stdin) fclose(fi);
}
@ -68,7 +93,7 @@ struct SparseMat {
return row_ptr.size() - 1;
}
// maximum feature dimension
unsigned feat_dim;
size_t feat_dim;
std::vector<size_t> row_ptr;
std::vector<Entry> data;
std::vector<float> labels;
@ -115,3 +140,4 @@ inline int Random(int value) {
return rand() % value;
}
} // namespace rabit
#endif // RABIT_TOOLKIT_UTIL_H_

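The Load() routine above accepts a file-name template: if the name contains a '%' conversion, it is expanded with the caller's rank before opening, so each worker reads its own shard (mushroom.row0, mushroom.row1, ...). A minimal sketch of how the resulting CSR structure is consumed, assuming a dense float weight array; the variable names are illustrative and not part of the commit:

SparseMat dtrain;
dtrain.Load("mushroom.row%d");          // '%d' is replaced by rabit::GetRank()
for (size_t i = 0; i < dtrain.NumRow(); ++i) {
  SparseMat::Vector v = dtrain[i];      // zero-copy view of row i
  float margin = 0.0f;
  for (index_t j = 0; j < v.length; ++j) {
    margin += weight[v[j].findex] * v[j].fvalue;  // weight has dtrain.feat_dim entries
  }
  // dtrain.labels[i] holds the label parsed from the leading token of that line
}

This mirrors how PredictMargin in linear.h (added later in this commit) walks a row.
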
View File

@ -83,9 +83,12 @@ inline size_t GetCluster(const Matrix &centroids,
int main(int argc, char *argv[]) {
if (argc < 5) {
// initialize rabit engine
rabit::Init(argc, argv);
if (rabit::GetRank() == 0) {
rabit::TrackerPrintf("Usage: <data_dir> num_cluster max_iter <out_model>\n");
}
rabit::Finalize();
return 0;
}
clock_t tStart = clock();

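The change above routes the too-few-arguments path through the engine: rabit::TrackerPrintf talks to the tracker, so rabit::Init must have run before it, and rabit::Finalize has to be called on this early-exit path as well. A minimal sketch of the pattern, assuming only the rabit calls already used in this commit:

int main(int argc, char *argv[]) {
  if (argc < 5) {
    rabit::Init(argc, argv);            // engine must be up before TrackerPrintf
    if (rabit::GetRank() == 0) {        // print the usage line once, from rank 0
      rabit::TrackerPrintf("Usage: <data_dir> num_cluster max_iter <out_model>\n");
    }
    rabit::Finalize();                  // every return path shuts the engine down
    return 0;
  }
  // ... normal initialization and training continue here ...
}
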
View File

@ -0,0 +1,15 @@
# binary targets to build
BIN = linear.rabit
MOCKBIN=
MPIBIN =
# object files that make up the program
OBJ = linear.o
# common build script for programs
include ../common.mk
CFLAGS+=-fopenmp
linear.o: linear.cc ../../src/*.h linear.h ../solver/*.h
# dependencies here
linear.rabit: linear.o lib

View File

@ -0,0 +1,176 @@
#include "./linear.h"
namespace rabit {
namespace linear {
class LinearObjFunction : public solver::IObjFunction<float> {
public:
// training threads
int nthread;
// L2 regularization
float reg_L2;
// model
LinearModel model;
// training data
SparseMat dtrain;
// solver
solver::LBFGSSolver<float> lbfgs;
// constructor
LinearObjFunction(void) {
lbfgs.SetObjFunction(this);
nthread = 1;
reg_L2 = 0.0f;
model.weight = NULL;
task = "train";
model_in = "NULL";
}
virtual ~LinearObjFunction(void) {
if (model.weight != NULL) delete [] model.weight;
}
// set parameters
inline void SetParam(const char *name, const char *val) {
model.param.SetParam(name, val);
lbfgs.SetParam(name, val);
if (!strcmp(name, "num_feature")) {
char ndigit[30];
sprintf(ndigit, "%lu", model.param.num_feature + 1);
lbfgs.SetParam("num_dim", ndigit);
}
if (!strcmp(name, "reg_L2")) {
reg_L2 = static_cast<float>(atof(val));
}
if (!strcmp(name, "nthread")) {
nthread = atoi(val);
}
if (!strcmp(name, "task")) task = val;
if (!strcmp(name, "model_in")) model_in = val;
if (!strcmp(name, "model_out")) model_out = val;
}
inline void Run(void) {
if (model_in != "NULL") {
}
if (task == "train") {
lbfgs.Run();
} else if (task == "pred") {
} else if (task == "eval") {
} else {
utils::Error("unknown task=%s", task.c_str());
}
}
inline void LoadData(const char *fname) {
dtrain.Load(fname);
}
virtual size_t InitNumDim(void) {
if (model_in == "NULL") {
size_t ndim = dtrain.feat_dim;
rabit::Allreduce<rabit::op::Max>(&ndim, 1);
model.param.num_feature = std::max(ndim, model.param.num_feature);
}
return model.param.num_feature + 1;
}
virtual void InitModel(float *weight, size_t size) {
if (model_in == "NULL") {
memset(weight, 0, size * sizeof(float));
model.param.InitBaseScore();
} else {
rabit::Broadcast(model.weight, size * sizeof(float), 0);
memcpy(weight, model.weight, size * sizeof(float));
}
}
// load model
virtual void Load(rabit::IStream &fi) {
fi.Read(&model.param, sizeof(model.param));
}
virtual void Save(rabit::IStream &fo) const {
fo.Write(&model.param, sizeof(model.param));
}
virtual double Eval(const float *weight, size_t size) {
if (nthread != 0) omp_set_num_threads(nthread);
utils::Check(size == model.param.num_feature + 1,
"size consistency check");
double sum_val = 0.0;
#pragma omp parallel for schedule(static) reduction(+:sum_val)
for (size_t i = 0; i < dtrain.NumRow(); ++i) {
float py = model.param.PredictMargin(weight, dtrain[i]);
float fv = model.param.MarginToLoss(dtrain.labels[i], py);
sum_val += fv;
}
if (rabit::GetRank() == 0) {
// only add regularization once
if (reg_L2 != 0.0f) {
double sum_sqr = 0.0;
for (size_t i = 0; i < model.param.num_feature; ++i) {
sum_sqr += weight[i] * weight[i];
}
sum_val += 0.5 * reg_L2 * sum_sqr;
}
}
utils::Check(!std::isnan(sum_val), "nan occurs");
return sum_val;
}
virtual void CalcGrad(float *out_grad,
const float *weight,
size_t size) {
if (nthread != 0) omp_set_num_threads(nthread);
utils::Check(size == model.param.num_feature + 1,
"size consistency check");
memset(out_grad, 0, sizeof(float) * size);
double sum_gbias = 0.0;
#pragma omp parallel for schedule(static) reduction(+:sum_gbias)
for (size_t i = 0; i < dtrain.NumRow(); ++i) {
SparseMat::Vector v = dtrain[i];
float py = model.param.Predict(weight, v);
float grad = model.param.PredToGrad(dtrain.labels[i], py);
for (index_t j = 0; j < v.length; ++j) {
out_grad[v[j].findex] += v[j].fvalue * grad;
}
sum_gbias += grad;
}
out_grad[model.param.num_feature] = static_cast<float>(sum_gbias);
if (rabit::GetRank() == 0) {
// only add regularization once
if (reg_L2 != 0.0f) {
for (size_t i = 0; i < model.param.num_feature; ++i) {
out_grad[i] += reg_L2 * weight[i];
}
}
}
}
private:
std::string task;
std::string model_in;
std::string model_out;
};
} // namespace linear
} // namespace rabit
int main(int argc, char *argv[]) {
if (argc < 2) {
// initialize rabit engine
rabit::Init(argc, argv);
if (rabit::GetRank() == 0) {
rabit::TrackerPrintf("Usage: <data_in> param=val\n");
}
rabit::Finalize();
return 0;
}
rabit::linear::LinearObjFunction linear;
if (!strcmp(argv[1], "stdin")) {
linear.LoadData(argv[1]);
rabit::Init(argc, argv);
} else {
rabit::Init(argc, argv);
linear.LoadData(argv[1]);
}
for (int i = 2; i < argc; ++i) {
char name[256], val[256];
if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) {
linear.SetParam(name, val);
}
}
linear.Run();
rabit::Finalize();
return 0;
}

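Command-line parameters are passed as name=value tokens and split by the sscanf pattern "%[^=]=%s": %[^=] consumes everything up to the first '=' and %s takes the remainder, so values must not contain spaces. A small self-contained sketch of this parsing, with illustrative inputs:

#include <cstdio>

int main(void) {
  const char *args[] = {"reg_L2=1.0", "max_lbfgs_iter=500", "task=train"};
  for (int i = 0; i < 3; ++i) {
    char name[256], val[256];
    // split "name=val" at the first '='
    if (sscanf(args[i], "%[^=]=%s", name, val) == 2) {
      printf("SetParam(\"%s\", \"%s\")\n", name, val);
    }
  }
  return 0;
}

In the run.sh script further below, any extra arguments after the process count are forwarded to linear.rabit in exactly this name=value form.
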
rabit-learn/linear/linear.h Normal file
View File

@ -0,0 +1,131 @@
/*!
* Copyright (c) 2015 by Contributors
* \file linear.h
* \brief Linear and Logistic regression
*
* \author Tianqi Chen
*/
#ifndef RABIT_LINEAR_H_
#define RABIT_LINEAR_H_
#include <omp.h>
#include "../common/toolkit_util.h"
#include "../solver/lbfgs.h"
namespace rabit {
namespace linear {
/*! \brief simple linear model */
struct LinearModel {
struct ModelParam {
/*! \brief global bias */
float base_score;
/*! \brief number of features */
size_t num_feature;
/*! \brief loss type*/
int loss_type;
// reserved field
int reserved[16];
// constructor
ModelParam(void) {
base_score = 0.5f;
num_feature = 0;
loss_type = 1;
std::memset(reserved, 0, sizeof(reserved));
}
// initialize base score
inline void InitBaseScore(void) {
utils::Check(base_score > 0.0f && base_score < 1.0f,
"base_score must be in (0,1) for logistic loss");
base_score = -std::log(1.0f / base_score - 1.0f);
}
/*!
* \brief set parameters from outside
* \param name name of the parameter
* \param val value of the parameter
*/
inline void SetParam(const char *name, const char *val) {
using namespace std;
if (!strcmp("base_score", name)) {
base_score = static_cast<float>(atof(val));
}
if (!strcmp("num_feature", name)) {
num_feature = static_cast<size_t>(atol(val));
}
if (!strcmp("objective", name)) {
if (!strcmp("linear", val)) {
loss_type = 0;
} else if (!strcmp("logistic", val)) {
loss_type = 1;
} else {
utils::Error("unknown objective type %s\n", val);
}
}
}
// transform margin to prediction
inline float MarginToPred(float margin) const {
if (loss_type == 1) {
return 1.0f / (1.0f + std::exp(-margin));
} else {
return margin;
}
}
// margin to loss
inline float MarginToLoss(float label, float margin) const {
if (loss_type == 1) {
float nlogprob;
if (margin > 0.0f) {
nlogprob = std::log(1.0f + std::exp(-margin));
} else {
nlogprob = -margin + std::log(1.0f + std::exp(margin));
}
return label * nlogprob +
(1.0f - label) * (margin + nlogprob);
} else {
float diff = margin - label;
return 0.5f * diff * diff;
}
}
inline float PredToGrad(float label, float pred) const {
return pred - label;
}
inline float PredictMargin(const float *weight,
const SparseMat::Vector &v) const {
// weight[num_feature] is bias
float sum = base_score + weight[num_feature];
for (unsigned i = 0; i < v.length; ++i) {
sum += weight[v[i].findex] * v[i].fvalue;
}
return sum;
}
inline float Predict(const float *weight,
const SparseMat::Vector &v) const {
return MarginToPred(PredictMargin(weight, v));
}
};
// model parameter
ModelParam param;
// weight corresponding to the model
float *weight;
LinearModel(void) : weight(NULL) {
}
~LinearModel(void) {
if (weight != NULL) delete [] weight;
}
// load model
inline void Load(rabit::IStream &fi) {
fi.Read(&param, sizeof(param));
if (weight == NULL) {
weight = new float[param.num_feature + 1];
fi.Read(weight, sizeof(float) * (param.num_feature + 1));
}
}
inline void Save(rabit::IStream &fo) const {
fo.Write(&param, sizeof(param));
fo.Write(weight, sizeof(float) * (param.num_feature + 1));
}
inline float Predict(const SparseMat::Vector &v) const {
return param.Predict(weight, v);
}
};
} // namespace linear
} // namespace rabit
#endif // RABIT_LINEAR_H_

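MarginToLoss above computes the negative log-likelihood -[y*log(p) + (1-y)*log(1-p)] with p = sigmoid(margin), branching on the sign of the margin so that std::exp is only evaluated at non-positive arguments and cannot overflow. A standalone sketch contrasting the naive and the branch-stabilized form (illustrative, not part of the commit):

#include <cmath>
#include <cstdio>

// -log(sigmoid(margin)), i.e. the loss contribution of a positive label
static float NLogProbNaive(float margin) {
  return std::log(1.0f + std::exp(-margin));           // exp() overflows for very negative margins
}
static float NLogProbStable(float margin) {
  if (margin > 0.0f) return std::log(1.0f + std::exp(-margin));
  return -margin + std::log(1.0f + std::exp(margin));  // exp() argument is always <= 0
}

int main(void) {
  const float margins[] = {-100.0f, -2.0f, 0.0f, 2.0f, 100.0f};
  for (int i = 0; i < 5; ++i) {
    printf("m=%7.1f  naive=%g  stable=%g\n",
           margins[i], NLogProbNaive(margins[i]), NLogProbStable(margins[i]));
  }
  return 0;
}

At margin = -100 the naive form returns inf while the stable form returns the correct value, approximately 100.
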
View File

@ -0,0 +1,15 @@
#!/bin/bash
if [[ $# -lt 1 ]]
then
echo "Usage: nprocess"
exit -1
fi
rm -rf mushroom.row* *.model
k=$1
# split the lib svm file into k subfiles
python splitrows.py ../data/agaricus.txt.train mushroom $k
# run xgboost mpi
../../tracker/rabit_demo.py -n $k linear.rabit mushroom.row\%d "${*:2}"

View File

@ -0,0 +1,24 @@
#!/usr/bin/python
import sys
import random
# randomly split a LibSVM file into k row-wise shards
if len(sys.argv) < 4:
print('Usage: <fin> <fout_prefix> <k>')
exit(0)
random.seed(10)
k = int(sys.argv[3])
fi = open( sys.argv[1], 'r' )
fos = []
for i in range(k):
fos.append(open( sys.argv[2]+'.row%d' % i, 'w' ))
for l in open(sys.argv[1]):
i = random.randint(0, k-1)
fos[i].write(l)
for f in fos:
f.close()

View File

@ -21,12 +21,8 @@ namespace solver {
template<typename DType>
class IObjFunction : public rabit::ISerializable {
public:
/*!
* \brief set parameters from outside
* \param name name of the parameter
* \param val value of the parameter
*/
virtual void SetParam(const char *name, const char *val) = 0;
// destructor
virtual ~IObjFunction(void){}
/*!
* \brief evaluate function values for a given weight
* \param weight weight of the function
@ -34,7 +30,13 @@ class IObjFunction : public rabit::ISerializable {
*/
virtual double Eval(const DType *weight, size_t size) = 0;
/*!
* \brief initialize the weight before starting the solver
* \return number of feature dimension to be allocated
* only called once during initialization
*/
virtual size_t InitNumDim(void) = 0;
/*!
* \brief initialize the weight before starting the solver
* only called once for initialization
*/
virtual void InitModel(DType *weight, size_t size) = 0;
/*!
@ -45,18 +47,7 @@ class IObjFunction : public rabit::ISerializable {
*/
virtual void CalcGrad(DType *out_grad,
const DType *weight,
size_t size);
/*!
* \brief add regularization gradient to the gradient if any
* this is used to add data set invariant regularization
* \param out_grad used to store the gradient value of the function
* \param weight weight of the function
* \param size size of the weight
*/
virtual void AddRegularization(DType *out_grad,
const DType *weight,
size_t size);
size_t size) = 0;
};
/*! \brief a basic version L-BFGS solver */
@ -71,7 +62,7 @@ class LBFGSSolver {
linesearch_c1 = 1e-4;
min_lbfgs_iter = 5;
max_lbfgs_iter = 1000;
lbfgs_stop_tol = 1e-6f;
lbfgs_stop_tol = 1e-5f;
silent = 0;
}
virtual ~LBFGSSolver(void) {}
@ -81,17 +72,17 @@ class LBFGSSolver {
* \param val value of the parameter
*/
virtual void SetParam(const char *name, const char *val) {
if (!strcmp("num_feature", name)) {
gstate.num_feature = static_cast<size_t>(atol(val));
if (!strcmp("num_dim", name)) {
gstate.num_dim = static_cast<size_t>(atol(val));
}
if (!strcmp("size_memory", name)) {
gstate.size_memory = static_cast<size_t>(atol(val));
}
if (!strcmp("reg_L1", name)) {
reg_L1 = atof(val);
reg_L1 = static_cast<float>(atof(val));
}
if (!strcmp("linesearch_backoff", name)) {
linesearch_backoff = atof(val);
linesearch_backoff = static_cast<float>(atof(val));
}
if (!strcmp("max_linesearch_iter", name)) {
max_linesearch_iter = atoi(val);
@ -113,22 +104,35 @@ class LBFGSSolver {
virtual void Init(void) {
utils::Check(gstate.obj != NULL,
"LBFGSSolver.Init must SetObjFunction first");
if (rabit::LoadCheckPoint(&gstate, &hist) == 0) {
int version = rabit::LoadCheckPoint(&gstate, &hist);
if (version == 0) {
gstate.num_dim = gstate.obj->InitNumDim();
}
{
// decide parameter partition
size_t nproc = rabit::GetWorldSize();
size_t rank = rabit::GetRank();
size_t step = (gstate.num_dim + nproc - 1) / nproc;
// upper align
step = (step + 7) / 8 * 8;
utils::Assert(step * nproc >= gstate.num_dim, "BUG");
range_begin_ = std::min(rank * step, gstate.num_dim);
range_end_ = std::min((rank + 1) * step, gstate.num_dim);
}
if (version == 0) {
gstate.Init();
hist.Init(gstate.num_feature, gstate.size_memory);
if (rabit::GetRank() == 0) {
gstate.obj->InitModel(gstate.weight, gstate.num_feature);
}
hist.Init(range_end_ - range_begin_, gstate.size_memory);
gstate.obj->InitModel(gstate.weight, gstate.num_dim);
// broadcast initialize model
rabit::Broadcast(gstate.weight,
sizeof(DType) * gstate.num_feature, 0);
sizeof(DType) * gstate.num_dim, 0);
gstate.old_objval = this->Eval(gstate.weight);
gstate.init_objval = gstate.old_objval;
if (silent == 0 && rabit::GetRank() == 0) {
rabit::TrackerPrintf
("L-BFGS solver starts, num_feature=%lu, init_objval=%g\n",
gstate.num_feature, gstate.init_objval);
("L-BFGS solver starts, num_dim=%lu, init_objval=%g, size_memory=%lu\n",
gstate.num_dim, gstate.init_objval, gstate.size_memory);
}
}
}
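The partition block above slices the num_dim weights across workers: step is ceil(num_dim / nproc) rounded up to a multiple of 8, and each rank owns [rank*step, (rank+1)*step) clamped to num_dim, so the last ranks may get shorter (possibly empty) slices. A standalone sketch of the same arithmetic with illustrative sizes:

#include <algorithm>
#include <cstdio>

int main(void) {
  const size_t num_dim = 1001, nproc = 4;        // illustrative, not from the commit
  size_t step = (num_dim + nproc - 1) / nproc;   // ceil(1001 / 4) = 251
  step = (step + 7) / 8 * 8;                     // round up to a multiple of 8 -> 256
  for (size_t rank = 0; rank < nproc; ++rank) {
    size_t begin = std::min(rank * step, num_dim);
    size_t end = std::min((rank + 1) * step, num_dim);
    printf("rank %zu owns [%zu, %zu): %zu weights\n", rank, begin, end, end - begin);
  }
  return 0;
}

With these numbers the ranks own 256, 256, 256 and 233 weights respectively, which is the local slice size the hist buffer is initialized with above.
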
@ -148,13 +152,16 @@ class LBFGSSolver {
virtual bool UpdateOneIter(void) {
bool stop = false;
GlobalState &g = gstate;
g.obj->CalcGrad(g.grad, g.weight, g.num_feature);
rabit::Allreduce<rabit::op::Sum>(g.grad, g.num_feature);
g.obj->AddRegularization(g.grad, g.weight, g.num_feature);
g.obj->CalcGrad(g.grad, g.weight, g.num_dim);
rabit::Allreduce<rabit::op::Sum>(g.grad, g.num_dim);
// find change direction
double vdot = FindChangeDirection(g.tempw, g.grad, g.weight);
// line-search, g.grad is now new weight
int iter = BacktrackLineSearch(g.grad, g.tempw, g.weight, vdot);
utils::Check(iter < max_linesearch_iter, "line search failed");
// swap new weight
std::swap(g.weight, g.grad);
// check stop condition
if (gstate.num_iteration > min_lbfgs_iter) {
if (g.old_objval - g.new_objval < lbfgs_stop_tol * g.init_objval) {
return true;
@ -177,6 +184,10 @@ class LBFGSSolver {
while (gstate.num_iteration < max_lbfgs_iter) {
if (this->UpdateOneIter()) break;
}
if (silent == 0 && rabit::GetRank() == 0) {
rabit::TrackerPrintf
("L-BFGS: finishes at iteration %d\n", gstate.num_iteration);
}
}
protected:
// find the delta value, given gradient
@ -186,7 +197,13 @@ class LBFGSSolver {
const DType *weight) {
int m = static_cast<int>(gstate.size_memory);
int n = static_cast<int>(hist.num_useful());
const size_t num_feature = gstate.num_feature;
if (n < m) {
utils::Assert(hist.num_useful() == gstate.num_iteration,
"BUG2");
} else {
utils::Assert(n == m, "BUG3");
}
const size_t num_dim = gstate.num_dim;
const DType *gsub = grad + range_begin_;
const size_t nsub = range_end_ - range_begin_;
double vdot;
@ -214,10 +231,11 @@ class LBFGSSolver {
for (size_t i = 0; i < tmp.size(); ++i) {
gstate.DotBuf(idxset[i].first, idxset[i].second) = tmp[i];
}
// BFGS steps
// BFGS steps, use vector-free update
// parameterize vector using basis in hist
std::vector<double> alpha(n);
std::vector<double> delta(2 * n + 1, 0.0);
delta[2 * n] = 1.0;
std::vector<double> delta(2 * m + 1, 0.0);
delta[2 * m] = 1.0;
// backward step
for (int j = n - 1; j >= 0; --j) {
double vsum = 0.0;
@ -243,26 +261,30 @@ class LBFGSSolver {
delta[j] = delta[j] + (alpha[j] - beta);
}
// set all to zero
std::fill(dir, dir + num_feature, 0.0f);
std::fill(dir, dir + num_dim, 0.0f);
DType *dirsub = dir + range_begin_;
for (int i = 0; i < n; ++i) {
AddScale(dirsub, dirsub, hist[i], delta[i], nsub);
AddScale(dirsub, dirsub, hist[m + i], delta[m + i], nsub);
}
AddScale(dirsub, dirsub, hist[2 * m], delta[2 * m], nsub);
FixDirL1Sign(dir + range_begin_, hist[2 * m], nsub);
vdot = -Dot(dir + range_begin_, hist[2 * m], nsub);
for (int i = 0; i < n; ++i) {
AddScale(dirsub, dirsub, hist[i], delta[i], nsub);
}
FixDirL1Sign(dirsub, hist[2 * m], nsub);
vdot = -Dot(dirsub, hist[2 * m], nsub);
// allreduce to get full direction
rabit::Allreduce<rabit::op::Sum>(dir, num_feature);
rabit::Allreduce<rabit::op::Sum>(dir, num_dim);
rabit::Allreduce<rabit::op::Sum>(&vdot, 1);
} else {
SetL1Dir(dir, grad, weight, num_feature);
vdot = -Dot(dir, dir, num_feature);
} else {
SetL1Dir(dir, grad, weight, num_dim);
vdot = -Dot(dir, dir, num_dim);
}
// shift the history record
if (n < m) {
n += 1;
} else {
gstate.Shift(); hist.Shift();
}
// shift the history record
gstate.Shift(); hist.Shift();
// next n
if (n < m) n += 1;
hist.set_num_useful(n);
// copy gradient to hist[m + n - 1]
memcpy(hist[m + n - 1], gsub, nsub * sizeof(DType));
@ -274,24 +296,25 @@ class LBFGSSolver {
const DType *dir,
const DType *weight,
double dot_dir_l1grad) {
utils::Assert(dot_dir_l1grad < 0.0f, "gradient error");
utils::Assert(dot_dir_l1grad < 0.0f,
"gradient error, dotv=%g", dot_dir_l1grad);
double alpha = 1.0;
double backoff = linesearch_backoff;
// unit descent direction in first iter
if (gstate.num_iteration == 0) {
utils::Assert(hist.num_useful() == 1, "hist.nuseful");
alpha = 1.0f / std::sqrt(-dot_dir_l1grad);
linesearch_backoff = 0.1f;
backoff = 0.1f;
}
int iter = 0;
double old_val = gstate.old_objval;
double c1 = this->linesearch_c1;
while (true) {
const size_t num_feature = gstate.num_feature;
const size_t num_dim = gstate.num_dim;
if (++iter >= max_linesearch_iter) return iter;
AddScale(new_weight, weight, dir, alpha, num_feature);
this->FixWeightL1Sign(new_weight, weight, num_feature);
AddScale(new_weight, weight, dir, alpha, num_dim);
this->FixWeightL1Sign(new_weight, weight, num_dim);
double new_val = this->Eval(new_weight);
if (new_val - old_val <= c1 * dot_dir_l1grad * alpha) {
gstate.new_objval = new_val; break;
@ -306,15 +329,16 @@ class LBFGSSolver {
gstate.num_iteration += 1;
return iter;
}
// OWL-QN step for L1 regularization
inline void SetL1Dir(DType *dst,
const DType *grad,
const DType *weight,
size_t size) {
const DType *grad,
const DType *weight,
size_t size) {
if (reg_L1 == 0.0) {
for (size_t i = 0; i < size; ++i) {
dst[i] = -grad[i];
}
} else{
} else {
for (size_t i = 0; i < size; ++i) {
if (weight[i] > 0.0f) {
dst[i] = -grad[i] - reg_L1;
@ -332,7 +356,7 @@ class LBFGSSolver {
}
}
}
// fix direction sign to be consistent with proposal
// OWL-QN step: fix direction sign to be consistent with proposal
inline void FixDirL1Sign(DType *dir,
const DType *steepdir,
size_t size) {
@ -344,7 +368,7 @@ class LBFGSSolver {
}
}
}
// fix direction sign to be consistent with proposal
// OWL-QN step: fix direction sign to be consistent with proposal
inline void FixWeightL1Sign(DType *new_weight,
const DType *weight,
size_t size) {
@ -357,11 +381,11 @@ class LBFGSSolver {
}
}
inline double Eval(const DType *weight) {
double val = gstate.obj->Eval(weight, gstate.num_feature);
double val = gstate.obj->Eval(weight, gstate.num_dim);
rabit::Allreduce<rabit::op::Sum>(&val, 1);
if (reg_L1 != 0.0f) {
double l1norm = 0.0;
for (size_t i = 0; i < gstate.num_feature; ++i) {
for (size_t i = 0; i < gstate.num_dim; ++i) {
l1norm += std::abs(weight[i]);
}
val += l1norm * reg_L1;
@ -401,13 +425,14 @@ class LBFGSSolver {
return res;
}
// map rolling array index
inline static size_t MapIndex(size_t i, size_t offset, size_t size_memory) {
inline static size_t MapIndex(size_t i, size_t offset,
size_t size_memory) {
if (i == 2 * size_memory) return i;
if (i < size_memory) {
return (i + offset) % size_memory;
} else {
utils::Assert(i < 2 * size_memory,
"MapIndex: index exceed bound");
"MapIndex: index exceed bound, i=%lu", i);
return (i + offset) % size_memory + size_memory;
}
}
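MapIndex realizes the rolling window: the first 2*size_memory logical slots are rotated by offset within their respective halves, while the last slot (index 2*size_memory) is fixed, so Shift() only has to advance the offset instead of copying history vectors. A tiny sketch of the mapping for size_memory = 3 (reproduced here for illustration only):

#include <cstdio>

// same arithmetic as the MapIndex above
static size_t MapIndex(size_t i, size_t offset, size_t size_memory) {
  if (i == 2 * size_memory) return i;                      // fixed extra slot
  if (i < size_memory) return (i + offset) % size_memory;  // first half rotates
  return (i + offset) % size_memory + size_memory;         // second half rotates in place
}

int main(void) {
  const size_t m = 3;
  for (size_t offset = 0; offset < 2; ++offset) {
    printf("offset=%zu:", offset);
    for (size_t i = 0; i <= 2 * m; ++i) {
      printf(" %zu->%zu", i, MapIndex(i, offset, m));
    }
    printf("\n");
  }
  return 0;
}

For offset = 1 the logical slots 0..6 map to physical slots 1 2 0 4 5 3 6, i.e. both halves rotate by one while slot 6 stays put.
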
@ -419,7 +444,7 @@ class LBFGSSolver {
// number of iterations passed
size_t num_iteration;
// number of features in the solver
size_t num_feature;
size_t num_dim;
// initialize objective value
double init_objval;
// history objective value
@ -436,7 +461,7 @@ class LBFGSSolver {
weight(NULL), tempw(NULL) {
size_memory = 10;
num_iteration = 0;
num_feature = 0;
num_dim = 0;
old_objval = 0.0;
}
~GlobalState(void) {
@ -461,25 +486,25 @@ class LBFGSSolver {
virtual void Load(rabit::IStream &fi) {
fi.Read(&size_memory, sizeof(size_memory));
fi.Read(&num_iteration, sizeof(num_iteration));
fi.Read(&num_feature, sizeof(num_feature));
fi.Read(&num_dim, sizeof(num_dim));
fi.Read(&init_objval, sizeof(init_objval));
fi.Read(&old_objval, sizeof(old_objval));
fi.Read(&offset_, sizeof(offset_));
fi.Read(&data);
this->AllocSpace();
fi.Read(weight, sizeof(DType) * num_feature);
fi.Read(weight, sizeof(DType) * num_dim);
obj->Load(fi);
}
// save the shift array
virtual void Save(rabit::IStream &fo) const {
fo.Write(&size_memory, sizeof(size_memory));
fo.Write(&num_iteration, sizeof(num_iteration));
fo.Write(&num_feature, sizeof(num_feature));
fo.Write(&num_dim, sizeof(num_dim));
fo.Write(&init_objval, sizeof(init_objval));
fo.Write(&old_objval, sizeof(old_objval));
fo.Write(&offset_, sizeof(offset_));
fo.Write(data);
fo.Write(weight, sizeof(DType) * num_feature);
fo.Write(weight, sizeof(DType) * num_dim);
obj->Save(fo);
}
inline void Shift(void) {
@ -493,16 +518,18 @@ class LBFGSSolver {
// allocate space
inline void AllocSpace(void) {
if (grad == NULL) {
grad = new DType[num_feature];
weight = new DType[num_feature];
tempw = new DType[num_feature];
grad = new DType[num_dim];
weight = new DType[num_dim];
tempw = new DType[num_dim];
}
}
};
/*! \brief rolling array that carries history information */
struct HistoryArray : public rabit::ISerializable {
public:
HistoryArray(void) : dptr_(NULL) {}
HistoryArray(void) : dptr_(NULL) {
num_useful_ = 0;
}
~HistoryArray(void) {
if (dptr_ != NULL) delete [] dptr_;
}
@ -516,7 +543,8 @@ class LBFGSSolver {
size_memory_ = size_memory;
stride_ = num_col_;
offset_ = 0;
dptr_ = new DType[num_col_ * stride_];
size_t n = size_memory * 2 + 1;
dptr_ = new DType[n * stride_];
}
// fetch element from rolling array
inline const DType *operator[](size_t i) const {
@ -541,7 +569,7 @@ class LBFGSSolver {
}
// set number of useful memory
inline void set_num_useful(size_t num_useful) {
utils::Assert(num_useful < size_memory_,
utils::Assert(num_useful <= size_memory_,
"num_useful exceed bound");
num_useful_ = num_useful;
}
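The alpha/delta/beta loops in FindChangeDirection are the vector-free form of the standard L-BFGS two-loop recursion (the code comment calls it a "vector-free update"): the direction is carried as 2m+1 coefficients over the stored history vectors, only the final combination touches the parameter slice [range_begin_, range_end_) owned by this rank, and the result is Allreduce'd to form the full direction. For reference, a compact sketch of the classical, single-machine two-loop recursion this corresponds to, written against plain std::vector storage rather than HistoryArray (illustrative only):

#include <cstddef>
#include <vector>

// classical L-BFGS two-loop recursion (reference sketch, not the code above):
// given the current gradient g and the stored pairs s_k = x_{k+1} - x_k,
// y_k = g_{k+1} - g_k (oldest first), return the search direction d = -H*g.
std::vector<double> TwoLoop(const std::vector<double> &g,
                            const std::vector<std::vector<double> > &s,
                            const std::vector<std::vector<double> > &y) {
  const size_t n = g.size(), m = s.size();
  auto dot = [](const std::vector<double> &a, const std::vector<double> &b) {
    double v = 0.0;
    for (size_t i = 0; i < a.size(); ++i) v += a[i] * b[i];
    return v;
  };
  std::vector<double> q = g, alpha(m), rho(m);
  for (size_t k = 0; k < m; ++k) rho[k] = 1.0 / dot(y[k], s[k]);
  for (size_t k = m; k-- > 0;) {                 // backward pass, newest first
    alpha[k] = rho[k] * dot(s[k], q);
    for (size_t i = 0; i < n; ++i) q[i] -= alpha[k] * y[k][i];
  }
  if (m > 0) {                                   // initial Hessian scaling
    double gamma = dot(s[m - 1], y[m - 1]) / dot(y[m - 1], y[m - 1]);
    for (size_t i = 0; i < n; ++i) q[i] *= gamma;
  }
  for (size_t k = 0; k < m; ++k) {               // forward pass, oldest first
    double beta = rho[k] * dot(y[k], q);
    for (size_t i = 0; i < n; ++i) q[i] += (alpha[k] - beta) * s[k][i];
  }
  for (size_t i = 0; i < n; ++i) q[i] = -q[i];   // descent direction
  return q;
}

The OWL-QN pieces in the commit (SetL1Dir, FixDirL1Sign, FixWeightL1Sign) wrap this recursion so that L1 regularization is handled through a pseudo-gradient and orthant-constrained line search rather than through the smooth gradient alone.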