This commit is contained in:
tqchen
2015-02-09 20:26:39 -08:00
parent 12ee049a74
commit 4a5b9e5f78
14 changed files with 8596 additions and 21 deletions

View File

@@ -0,0 +1,4 @@
Linear and Logistic Regression
====
* input format: LibSVM
* Example: [run-linear.sh](run-linear.sh)

View File

@@ -1,4 +1,6 @@
#include "./linear.h"
#include "../utils/io.h"
#include "../utils/base64.h"
namespace rabit {
namespace linear {
@@ -22,9 +24,10 @@ class LinearObjFunction : public solver::IObjFunction<float> {
model.weight = NULL;
task = "train";
model_in = "NULL";
name_pred = "pred.txt";
model_out = "final.model";
}
virtual ~LinearObjFunction(void) {
if (model.weight != NULL) delete [] model.weight;
}
// set parameters
inline void SetParam(const char *name, const char *val) {
@@ -44,20 +47,79 @@ class LinearObjFunction : public solver::IObjFunction<float> {
if (!strcmp(name, "task")) task = val;
if (!strcmp(name, "model_in")) model_in = val;
if (!strcmp(name, "model_out")) model_out = val;
if (!strcmp(name, "name_pred")) name_pred = val;
}
inline void Run(void) {
if (model_in != "NULL") {
this->LoadModel(model_in.c_str());
}
if (task == "train") {
lbfgs.Run();
this->SaveModel(model_out.c_str(), lbfgs.GetWeight());
} else if (task == "pred") {
} else if (task == "eval") {
this->TaskPred();
} else {
utils::Error("unknown task=%s", task.c_str());
}
}
inline void TaskPred(void) {
utils::Check(model_in != "NULL",
"must set model_in for task=pred");
FILE *fp = utils::FopenCheck(name_pred.c_str(), "w");
for (size_t i = 0; i < dtrain.NumRow(); ++i) {
float pred = model.Predict(dtrain[i]);
fprintf(fp, "%g\n", pred);
}
fclose(fp);
printf("Finishing writing to %s\n", name_pred.c_str());
}
inline void LoadModel(const char *fname) {
FILE *fp = utils::FopenCheck(fname, "rb");
std::string header; header.resize(4);
// check header for different binary encode
// can be base64 or binary
utils::FileStream fi(fp);
utils::Check(fi.Read(&header[0], 4) != 0, "invalid model");
// base64 format
if (header == "bs64") {
utils::Base64InStream bsin(fp);
bsin.InitPosition();
model.Load(bsin);
fclose(fp);
return;
} else if (header == "binf") {
model.Load(fi);
fclose(fp);
return;
} else {
utils::Error("invalid model file");
}
}
inline void SaveModel(const char *fname,
const float *wptr,
bool save_base64 = false) {
FILE *fp;
bool use_stdout = false;
if (!strcmp(fname, "stdout")) {
fp = stdout;
use_stdout = true;
} else {
fp = utils::FopenCheck(fname, "wb");
}
utils::FileStream fo(fp);
if (save_base64 != 0|| use_stdout) {
fo.Write("bs64\t", 5);
utils::Base64OutStream bout(fp);
model.Save(bout, wptr);
bout.Finish('\n');
} else {
fo.Write("binf", 4);
model.Save(fo, wptr);
}
if (!use_stdout) {
fclose(fp);
}
}
inline void LoadData(const char *fname) {
dtrain.Load(fname);
}
@@ -137,11 +199,12 @@ class LinearObjFunction : public solver::IObjFunction<float> {
}
}
}
private:
std::string task;
std::string model_in;
std::string model_out;
std::string name_pred;
};
} // namespace linear
} // namespace rabit

View File

@@ -8,7 +8,7 @@
#ifndef RABIT_LINEAR_H_
#define RABIT_LINEAR_H_
#include <omp.h>
#include "../common/toolkit_util.h"
#include "../utils/data.h"
#include "../solver/lbfgs.h"
namespace rabit {
@@ -92,6 +92,7 @@ struct LinearModel {
// weight[num_feature] is bias
float sum = base_score + weight[num_feature];
for (unsigned i = 0; i < v.length; ++i) {
if (v[i].findex >= num_feature) continue;
sum += weight[v[i].findex] * v[i].fvalue;
}
return sum;
@@ -115,12 +116,13 @@ struct LinearModel {
fi.Read(&param, sizeof(param));
if (weight == NULL) {
weight = new float[param.num_feature + 1];
fi.Read(weight, sizeof(float) * (param.num_feature + 1));
}
fi.Read(weight, sizeof(float) * (param.num_feature + 1));
}
inline void Save(rabit::IStream &fo) const {
inline void Save(rabit::IStream &fo, const float *wptr = NULL) const {
fo.Write(&param, sizeof(param));
fo.Write(weight, sizeof(float) * (param.num_feature + 1));
if (wptr == NULL) wptr = weight;
fo.Write(wptr, sizeof(float) * (param.num_feature + 1));
}
inline float Predict(const SparseMat::Vector &v) const {
return param.Predict(weight, v);

View File

@@ -12,4 +12,6 @@ k=$1
python splitrows.py ../data/agaricus.txt.train mushroom $k
# run xgboost mpi
../../tracker/rabit_demo.py -n $k linear.rabit mushroom.row\%d "${*:2}"
../../tracker/rabit_demo.py -n $k linear.rabit mushroom.row\%d "${*:2}" reg_L1=1
./linear.rabit ../data/agaricus.txt.test task=pred model_in=final.model