tqchen 57b5d7873f Squashed 'subtree/rabit/' changes from d4ec037..28ca7be
28ca7be add linear readme
ca4b20f add linear readme
1133628 add linear readme
6a11676 update docs
a607047 Update build.sh
2c1cfd8 complete yarn
4f28e32 change formater
2fbda81 fix stdin input
3258bcf checkin yarn master
67ebf81 allow setup from env variables
9b6bf57 fix hdfs
395d5c2 add make system
88ce767 refactor io, initial hdfs file access need test
19be870 chgs
a1bd3c6 Merge branch 'master' of ssh://github.com/tqchen/rabit
1a573f9 introduce input split
29476f1 fix timer issue

git-subtree-dir: subtree/rabit
git-subtree-split: 28ca7becbdf6503e6b1398588a969efb164c9701
2015-03-09 13:28:38 -07:00

227 lines
6.6 KiB
C++

#include "./linear.h"
#include "../io/io.h"
namespace rabit {
namespace linear {
class LinearObjFunction : public solver::IObjFunction<float> {
public:
// training threads
int nthread;
// L2 regularization
float reg_L2;
// model
LinearModel model;
// training data
SparseMat dtrain;
// solver
solver::LBFGSSolver<float> lbfgs;
// constructor
LinearObjFunction(void) {
lbfgs.SetObjFunction(this);
nthread = 1;
reg_L2 = 0.0f;
model.weight = NULL;
task = "train";
model_in = "NULL";
name_pred = "pred.txt";
model_out = "final.model";
}
virtual ~LinearObjFunction(void) {
}
// set parameters
inline void SetParam(const char *name, const char *val) {
model.param.SetParam(name, val);
lbfgs.SetParam(name, val);
if (!strcmp(name, "num_feature")) {
char ndigit[30];
sprintf(ndigit, "%lu", model.param.num_feature + 1);
lbfgs.SetParam("num_dim", ndigit);
}
if (!strcmp(name, "reg_L2")) {
reg_L2 = static_cast<float>(atof(val));
}
if (!strcmp(name, "nthread")) {
nthread = atoi(val);
}
if (!strcmp(name, "task")) task = val;
if (!strcmp(name, "model_in")) model_in = val;
if (!strcmp(name, "model_out")) model_out = val;
if (!strcmp(name, "name_pred")) name_pred = val;
}
inline void Run(void) {
if (model_in != "NULL") {
this->LoadModel(model_in.c_str());
}
if (task == "train") {
lbfgs.Run();
if (rabit::GetRank() == 0) {
this->SaveModel(model_out.c_str(), lbfgs.GetWeight());
}
} else if (task == "pred") {
this->TaskPred();
} else {
utils::Error("unknown task=%s", task.c_str());
}
}
inline void TaskPred(void) {
utils::Check(model_in != "NULL",
"must set model_in for task=pred");
FILE *fp = utils::FopenCheck(name_pred.c_str(), "w");
for (size_t i = 0; i < dtrain.NumRow(); ++i) {
float pred = model.Predict(dtrain[i]);
fprintf(fp, "%g\n", pred);
}
fclose(fp);
printf("Finishing writing to %s\n", name_pred.c_str());
}
inline void LoadModel(const char *fname) {
IStream *fi = io::CreateStream(fname, "r");
std::string header; header.resize(4);
// check header for different binary encode
// can be base64 or binary
utils::Check(fi->Read(&header[0], 4) != 0, "invalid model");
// base64 format
if (header == "bs64") {
io::Base64InStream bsin(fi);
bsin.InitPosition();
model.Load(bsin);
} else if (header == "binf") {
model.Load(*fi);
} else {
utils::Error("invalid model file");
}
delete fi;
}
inline void SaveModel(const char *fname,
const float *wptr,
bool save_base64 = false) {
IStream *fo = io::CreateStream(fname, "w");
if (save_base64 != 0 || !strcmp(fname, "stdout")) {
fo->Write("bs64\t", 5);
io::Base64OutStream bout(fo);
model.Save(bout, wptr);
bout.Finish('\n');
} else {
fo->Write("binf", 4);
model.Save(*fo, wptr);
}
delete fo;
}
inline void LoadData(const char *fname) {
dtrain.Load(fname);
}
virtual size_t InitNumDim(void) {
if (model_in == "NULL") {
size_t ndim = dtrain.feat_dim;
rabit::Allreduce<rabit::op::Max>(&ndim, 1);
model.param.num_feature = std::max(ndim, model.param.num_feature);
}
return model.param.num_feature + 1;
}
virtual void InitModel(float *weight, size_t size) {
if (model_in == "NULL") {
memset(weight, 0.0f, size * sizeof(float));
model.param.InitBaseScore();
} else {
rabit::Broadcast(model.weight, size * sizeof(float), 0);
memcpy(weight, model.weight, size * sizeof(float));
}
}
// load model
virtual void Load(rabit::IStream &fi) {
fi.Read(&model.param, sizeof(model.param));
}
virtual void Save(rabit::IStream &fo) const {
fo.Write(&model.param, sizeof(model.param));
}
virtual double Eval(const float *weight, size_t size) {
if (nthread != 0) omp_set_num_threads(nthread);
utils::Check(size == model.param.num_feature + 1,
"size consistency check");
double sum_val = 0.0;
#pragma omp parallel for schedule(static) reduction(+:sum_val)
for (size_t i = 0; i < dtrain.NumRow(); ++i) {
float py = model.param.PredictMargin(weight, dtrain[i]);
float fv = model.param.MarginToLoss(dtrain.labels[i], py);
sum_val += fv;
}
if (rabit::GetRank() == 0) {
// only add regularization once
if (reg_L2 != 0.0f) {
double sum_sqr = 0.0;
for (size_t i = 0; i < model.param.num_feature; ++i) {
sum_sqr += weight[i] * weight[i];
}
sum_val += 0.5 * reg_L2 * sum_sqr;
}
}
utils::Check(!std::isnan(sum_val), "nan occurs");
return sum_val;
}
virtual void CalcGrad(float *out_grad,
const float *weight,
size_t size) {
if (nthread != 0) omp_set_num_threads(nthread);
utils::Check(size == model.param.num_feature + 1,
"size consistency check");
memset(out_grad, 0.0f, sizeof(float) * size);
double sum_gbias = 0.0;
#pragma omp parallel for schedule(static) reduction(+:sum_gbias)
for (size_t i = 0; i < dtrain.NumRow(); ++i) {
SparseMat::Vector v = dtrain[i];
float py = model.param.Predict(weight, v);
float grad = model.param.PredToGrad(dtrain.labels[i], py);
for (index_t j = 0; j < v.length; ++j) {
out_grad[v[j].findex] += v[j].fvalue * grad;
}
sum_gbias += grad;
}
out_grad[model.param.num_feature] = static_cast<float>(sum_gbias);
if (rabit::GetRank() == 0) {
// only add regularization once
if (reg_L2 != 0.0f) {
for (size_t i = 0; i < model.param.num_feature; ++i) {
out_grad[i] += reg_L2 * weight[i];
}
}
}
}
private:
std::string task;
std::string model_in;
std::string model_out;
std::string name_pred;
};
} // namespace linear
} // namespace rabit
int main(int argc, char *argv[]) {
if (argc < 2) {
// intialize rabit engine
rabit::Init(argc, argv);
if (rabit::GetRank() == 0) {
rabit::TrackerPrintf("Usage: <data_in> param=val\n");
}
rabit::Finalize();
return 0;
}
rabit::linear::LinearObjFunction linear;
if (!strcmp(argv[1], "stdin")) {
linear.LoadData(argv[1]);
rabit::Init(argc, argv);
} else {
rabit::Init(argc, argv);
linear.LoadData(argv[1]);
}
for (int i = 2; i < argc; ++i) {
char name[256], val[256];
if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) {
linear.SetParam(name, val);
}
}
linear.Run();
rabit::Finalize();
return 0;
}