This commit is contained in:
tqchen 2015-02-09 20:26:39 -08:00
parent 12ee049a74
commit 4a5b9e5f78
14 changed files with 8596 additions and 21 deletions

View File

@ -4,7 +4,7 @@ export CC = gcc
export CXX = g++ export CXX = g++
export MPICXX = mpicxx export MPICXX = mpicxx
export LDFLAGS= -pthread -lm -L../../lib export LDFLAGS= -pthread -lm -L../../lib
export CFLAGS = -Wall -msse2 -Wno-unknown-pragmas -fPIC -I../../include -I../common export CFLAGS = -Wall -msse2 -Wno-unknown-pragmas -fPIC -I../../include
.PHONY: clean all lib mpi .PHONY: clean all lib mpi
all: $(BIN) $(MOCKBIN) all: $(BIN) $(MOCKBIN)

View File

@ -0,0 +1,2 @@
This folder contains the processed example datasets used by the demos.
The copyright of each dataset belongs to its original copyright holder.

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,126 @@
0 cap-shape=bell i
1 cap-shape=conical i
2 cap-shape=convex i
3 cap-shape=flat i
4 cap-shape=knobbed i
5 cap-shape=sunken i
6 cap-surface=fibrous i
7 cap-surface=grooves i
8 cap-surface=scaly i
9 cap-surface=smooth i
10 cap-color=brown i
11 cap-color=buff i
12 cap-color=cinnamon i
13 cap-color=gray i
14 cap-color=green i
15 cap-color=pink i
16 cap-color=purple i
17 cap-color=red i
18 cap-color=white i
19 cap-color=yellow i
20 bruises?=bruises i
21 bruises?=no i
22 odor=almond i
23 odor=anise i
24 odor=creosote i
25 odor=fishy i
26 odor=foul i
27 odor=musty i
28 odor=none i
29 odor=pungent i
30 odor=spicy i
31 gill-attachment=attached i
32 gill-attachment=descending i
33 gill-attachment=free i
34 gill-attachment=notched i
35 gill-spacing=close i
36 gill-spacing=crowded i
37 gill-spacing=distant i
38 gill-size=broad i
39 gill-size=narrow i
40 gill-color=black i
41 gill-color=brown i
42 gill-color=buff i
43 gill-color=chocolate i
44 gill-color=gray i
45 gill-color=green i
46 gill-color=orange i
47 gill-color=pink i
48 gill-color=purple i
49 gill-color=red i
50 gill-color=white i
51 gill-color=yellow i
52 stalk-shape=enlarging i
53 stalk-shape=tapering i
54 stalk-root=bulbous i
55 stalk-root=club i
56 stalk-root=cup i
57 stalk-root=equal i
58 stalk-root=rhizomorphs i
59 stalk-root=rooted i
60 stalk-root=missing i
61 stalk-surface-above-ring=fibrous i
62 stalk-surface-above-ring=scaly i
63 stalk-surface-above-ring=silky i
64 stalk-surface-above-ring=smooth i
65 stalk-surface-below-ring=fibrous i
66 stalk-surface-below-ring=scaly i
67 stalk-surface-below-ring=silky i
68 stalk-surface-below-ring=smooth i
69 stalk-color-above-ring=brown i
70 stalk-color-above-ring=buff i
71 stalk-color-above-ring=cinnamon i
72 stalk-color-above-ring=gray i
73 stalk-color-above-ring=orange i
74 stalk-color-above-ring=pink i
75 stalk-color-above-ring=red i
76 stalk-color-above-ring=white i
77 stalk-color-above-ring=yellow i
78 stalk-color-below-ring=brown i
79 stalk-color-below-ring=buff i
80 stalk-color-below-ring=cinnamon i
81 stalk-color-below-ring=gray i
82 stalk-color-below-ring=orange i
83 stalk-color-below-ring=pink i
84 stalk-color-below-ring=red i
85 stalk-color-below-ring=white i
86 stalk-color-below-ring=yellow i
87 veil-type=partial i
88 veil-type=universal i
89 veil-color=brown i
90 veil-color=orange i
91 veil-color=white i
92 veil-color=yellow i
93 ring-number=none i
94 ring-number=one i
95 ring-number=two i
96 ring-type=cobwebby i
97 ring-type=evanescent i
98 ring-type=flaring i
99 ring-type=large i
100 ring-type=none i
101 ring-type=pendant i
102 ring-type=sheathing i
103 ring-type=zone i
104 spore-print-color=black i
105 spore-print-color=brown i
106 spore-print-color=buff i
107 spore-print-color=chocolate i
108 spore-print-color=green i
109 spore-print-color=orange i
110 spore-print-color=purple i
111 spore-print-color=white i
112 spore-print-color=yellow i
113 population=abundant i
114 population=clustered i
115 population=numerous i
116 population=scattered i
117 population=several i
118 population=solitary i
119 habitat=grasses i
120 habitat=leaves i
121 habitat=meadows i
122 habitat=paths i
123 habitat=urban i
124 habitat=waste i
125 habitat=woods i

View File

@ -2,8 +2,8 @@
// facing an exception // facing an exception
#include <rabit.h> #include <rabit.h>
#include <rabit/utils.h> #include <rabit/utils.h>
#include "./toolkit_util.h"
#include <time.h> #include <time.h>
#include "../utils/data.h"
using namespace rabit; using namespace rabit;

View File

@ -0,0 +1,4 @@
Linear and Logistic Regression
====
* input format: LibSVM
* Example: [run-linear.sh](run-linear.sh)

View File

@ -1,4 +1,6 @@
#include "./linear.h" #include "./linear.h"
#include "../utils/io.h"
#include "../utils/base64.h"
namespace rabit { namespace rabit {
namespace linear { namespace linear {
@ -22,9 +24,10 @@ class LinearObjFunction : public solver::IObjFunction<float> {
model.weight = NULL; model.weight = NULL;
task = "train"; task = "train";
model_in = "NULL"; model_in = "NULL";
name_pred = "pred.txt";
model_out = "final.model";
} }
virtual ~LinearObjFunction(void) { virtual ~LinearObjFunction(void) {
if (model.weight != NULL) delete [] model.weight;
} }
// set parameters // set parameters
inline void SetParam(const char *name, const char *val) { inline void SetParam(const char *name, const char *val) {
@ -44,20 +47,79 @@ class LinearObjFunction : public solver::IObjFunction<float> {
if (!strcmp(name, "task")) task = val; if (!strcmp(name, "task")) task = val;
if (!strcmp(name, "model_in")) model_in = val; if (!strcmp(name, "model_in")) model_in = val;
if (!strcmp(name, "model_out")) model_out = val; if (!strcmp(name, "model_out")) model_out = val;
if (!strcmp(name, "name_pred")) name_pred = val;
} }
inline void Run(void) { inline void Run(void) {
if (model_in != "NULL") { if (model_in != "NULL") {
this->LoadModel(model_in.c_str());
} }
if (task == "train") { if (task == "train") {
lbfgs.Run(); lbfgs.Run();
this->SaveModel(model_out.c_str(), lbfgs.GetWeight());
} else if (task == "pred") { } else if (task == "pred") {
this->TaskPred();
} else if (task == "eval") {
} else { } else {
utils::Error("unknown task=%s", task.c_str()); utils::Error("unknown task=%s", task.c_str());
} }
} }
/*!
 * \brief run prediction over every row of the loaded data
 *  Requires model_in to be set; writes one prediction per line
 *  (formatted with %g) to the file named by name_pred.
 */
inline void TaskPred(void) {
  utils::Check(model_in != "NULL",
               "must set model_in for task=pred");
  FILE *fout = utils::FopenCheck(name_pred.c_str(), "w");
  size_t ridx = 0;
  while (ridx < dtrain.NumRow()) {
    const float score = model.Predict(dtrain[ridx]);
    fprintf(fout, "%g\n", score);
    ++ridx;
  }
  fclose(fout);
  printf("Finishing writing to %s\n", name_pred.c_str());
}
inline void LoadModel(const char *fname) {
FILE *fp = utils::FopenCheck(fname, "rb");
std::string header; header.resize(4);
// check header for different binary encode
// can be base64 or binary
utils::FileStream fi(fp);
utils::Check(fi.Read(&header[0], 4) != 0, "invalid model");
// base64 format
if (header == "bs64") {
utils::Base64InStream bsin(fp);
bsin.InitPosition();
model.Load(bsin);
fclose(fp);
return;
} else if (header == "binf") {
model.Load(fi);
fclose(fp);
return;
} else {
utils::Error("invalid model file");
}
}
inline void SaveModel(const char *fname,
const float *wptr,
bool save_base64 = false) {
FILE *fp;
bool use_stdout = false;
if (!strcmp(fname, "stdout")) {
fp = stdout;
use_stdout = true;
} else {
fp = utils::FopenCheck(fname, "wb");
}
utils::FileStream fo(fp);
if (save_base64 != 0|| use_stdout) {
fo.Write("bs64\t", 5);
utils::Base64OutStream bout(fp);
model.Save(bout, wptr);
bout.Finish('\n');
} else {
fo.Write("binf", 4);
model.Save(fo, wptr);
}
if (!use_stdout) {
fclose(fp);
}
}
inline void LoadData(const char *fname) { inline void LoadData(const char *fname) {
dtrain.Load(fname); dtrain.Load(fname);
} }
@ -137,11 +199,12 @@ class LinearObjFunction : public solver::IObjFunction<float> {
} }
} }
} }
private: private:
std::string task; std::string task;
std::string model_in; std::string model_in;
std::string model_out; std::string model_out;
std::string name_pred;
}; };
} // namespace linear } // namespace linear
} // namespace rabit } // namespace rabit

View File

@ -8,7 +8,7 @@
#ifndef RABIT_LINEAR_H_ #ifndef RABIT_LINEAR_H_
#define RABIT_LINEAR_H_ #define RABIT_LINEAR_H_
#include <omp.h> #include <omp.h>
#include "../common/toolkit_util.h" #include "../utils/data.h"
#include "../solver/lbfgs.h" #include "../solver/lbfgs.h"
namespace rabit { namespace rabit {
@ -92,6 +92,7 @@ struct LinearModel {
// weight[num_feature] is bias // weight[num_feature] is bias
float sum = base_score + weight[num_feature]; float sum = base_score + weight[num_feature];
for (unsigned i = 0; i < v.length; ++i) { for (unsigned i = 0; i < v.length; ++i) {
if (v[i].findex >= num_feature) continue;
sum += weight[v[i].findex] * v[i].fvalue; sum += weight[v[i].findex] * v[i].fvalue;
} }
return sum; return sum;
@ -115,12 +116,13 @@ struct LinearModel {
fi.Read(&param, sizeof(param)); fi.Read(&param, sizeof(param));
if (weight == NULL) { if (weight == NULL) {
weight = new float[param.num_feature + 1]; weight = new float[param.num_feature + 1];
fi.Read(weight, sizeof(float) * (param.num_feature + 1));
} }
fi.Read(weight, sizeof(float) * (param.num_feature + 1));
} }
inline void Save(rabit::IStream &fo) const { inline void Save(rabit::IStream &fo, const float *wptr = NULL) const {
fo.Write(&param, sizeof(param)); fo.Write(&param, sizeof(param));
fo.Write(weight, sizeof(float) * (param.num_feature + 1)); if (wptr == NULL) wptr = weight;
fo.Write(wptr, sizeof(float) * (param.num_feature + 1));
} }
inline float Predict(const SparseMat::Vector &v) const { inline float Predict(const SparseMat::Vector &v) const {
return param.Predict(weight, v); return param.Predict(weight, v);

View File

@ -12,4 +12,6 @@ k=$1
python splitrows.py ../data/agaricus.txt.train mushroom $k python splitrows.py ../data/agaricus.txt.train mushroom $k
# run xgboost mpi # run xgboost mpi
../../tracker/rabit_demo.py -n $k linear.rabit mushroom.row\%d "${*:2}" ../../tracker/rabit_demo.py -n $k linear.rabit mushroom.row\%d "${*:2}" reg_L1=1
./linear.rabit ../data/agaricus.txt.test task=pred model_in=final.model

View File

@ -5,8 +5,8 @@
* *
* \author Tianqi Chen * \author Tianqi Chen
*/ */
#ifndef RABIT_LBFGS_H_ #ifndef RABIT_LEARN_LBFGS_H_
#define RABIT_LBFGS_H_ #define RABIT_LEARN_LBFGS_H_
#include <cmath> #include <cmath>
#include <rabit.h> #include <rabit.h>
@ -62,7 +62,7 @@ class LBFGSSolver {
linesearch_c1 = 1e-4; linesearch_c1 = 1e-4;
min_lbfgs_iter = 5; min_lbfgs_iter = 5;
max_lbfgs_iter = 1000; max_lbfgs_iter = 1000;
lbfgs_stop_tol = 1e-5f; lbfgs_stop_tol = 3e-6f;
silent = 0; silent = 0;
} }
virtual ~LBFGSSolver(void) {} virtual ~LBFGSSolver(void) {}
@ -81,6 +81,9 @@ class LBFGSSolver {
if (!strcmp("reg_L1", name)) { if (!strcmp("reg_L1", name)) {
reg_L1 = static_cast<float>(atof(val)); reg_L1 = static_cast<float>(atof(val));
} }
if (!strcmp("lbfgs_stop_tol", name)) {
lbfgs_stop_tol = static_cast<float>(atof(val));
}
if (!strcmp("linesearch_backoff", name)) { if (!strcmp("linesearch_backoff", name)) {
linesearch_backoff = static_cast<float>(atof(val)); linesearch_backoff = static_cast<float>(atof(val));
} }
@ -185,8 +188,13 @@ class LBFGSSolver {
if (this->UpdateOneIter()) break; if (this->UpdateOneIter()) break;
} }
if (silent == 0 && rabit::GetRank() == 0) { if (silent == 0 && rabit::GetRank() == 0) {
size_t nonzero = 0;
for (size_t i = 0; i < gstate.num_dim; ++i) {
if (gstate.weight[i] != 0.0f) nonzero += 1;
}
rabit::TrackerPrintf rabit::TrackerPrintf
("L-BFGS: finishes at iteration %d\n", gstate.num_iteration); ("L-BFGS: finishes at iteration %d, %lu/%lu active weights\n",
gstate.num_iteration, nonzero, gstate.num_dim);
} }
} }
protected: protected:
@ -625,4 +633,4 @@ class LBFGSSolver {
}; };
} // namespace solver } // namespace solver
} // namespace rabit } // namespace rabit
#endif // RABIT_LBFGS_H_ #endif // RABIT_LEARN_LBFGS_H_

204
rabit-learn/utils/base64.h Normal file
View File

@ -0,0 +1,204 @@
#ifndef RABIT_LEARN_UTILS_BASE64_H_
#define RABIT_LEARN_UTILS_BASE64_H_
/*!
* \file base64.h
* \brief data stream support to input and output from/to base64 stream
* base64 is easier to store and pass as text format in mapreduce
* \author Tianqi Chen
*/
#include <cctype>
#include <cstdio>
#include <rabit/io.h>
namespace rabit {
namespace utils {
/*! \brief namespace of base64 decoding and encoding table */
namespace base64 {
/*!
 * \brief maps a character code to its 6-bit base64 value
 *  The table has 256 entries so that any byte value returned by fgetc
 *  (everything except EOF) is a valid index. Entries that are not listed
 *  explicitly, including every code above 'z', are zero-initialized;
 *  characters outside the base64 alphabet therefore decode as 0.
 */
const char DecodeTable[256] = {
  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  62,  // '+'
  0, 0, 0,
  63,  // '/'
  52, 53, 54, 55, 56, 57, 58, 59, 60, 61,  // '0'-'9'
  0, 0, 0, 0, 0, 0, 0,
  0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
  13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,  // 'A'-'Z'
  0, 0, 0, 0, 0, 0,
  26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
  39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,  // 'a'-'z'
};
/*! \brief maps a 6-bit value (0-63) to its base64 character */
static const char EncodeTable[] =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
}  // namespace base64
/*! \brief the stream that reads from base64, note we take from file pointers */
class Base64InStream: public IStream {
public:
explicit Base64InStream(FILE *fp) : fp(fp) {
num_prev = 0; tmp_ch = 0;
}
/*!
* \brief initialize the stream position to beginning of next base64 stream
* call this function before actually start read
*/
inline void InitPosition(void) {
// get a charater
do {
tmp_ch = fgetc(fp);
} while (isspace(tmp_ch));
}
/*! \brief whether current position is end of a base64 stream */
inline bool IsEOF(void) const {
return num_prev == 0 && (tmp_ch == EOF || isspace(tmp_ch));
}
virtual size_t Read(void *ptr, size_t size) {
using base64::DecodeTable;
if (size == 0) return 0;
// use tlen to record left size
size_t tlen = size;
unsigned char *cptr = static_cast<unsigned char*>(ptr);
// if anything left, load from previous buffered result
if (num_prev != 0) {
if (num_prev == 2) {
if (tlen >= 2) {
*cptr++ = buf_prev[0];
*cptr++ = buf_prev[1];
tlen -= 2;
num_prev = 0;
} else {
// assert tlen == 1
*cptr++ = buf_prev[0]; --tlen;
buf_prev[0] = buf_prev[1];
num_prev = 1;
}
} else {
// assert num_prev == 1
*cptr++ = buf_prev[0]; --tlen; num_prev = 0;
}
}
if (tlen == 0) return size;
int nvalue;
// note: everything goes with 4 bytes in Base64
// so we process 4 bytes a unit
while (tlen && tmp_ch != EOF && !isspace(tmp_ch)) {
// first byte
nvalue = DecodeTable[tmp_ch] << 18;
{
// second byte
Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)),
"invalid base64 format");
nvalue |= DecodeTable[tmp_ch] << 12;
*cptr++ = (nvalue >> 16) & 0xFF; --tlen;
}
{
// third byte
Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)),
"invalid base64 format");
// handle termination
if (tmp_ch == '=') {
Check((tmp_ch = fgetc(fp), tmp_ch == '='), "invalid base64 format");
Check((tmp_ch = fgetc(fp), tmp_ch == EOF || isspace(tmp_ch)),
"invalid base64 format");
break;
}
nvalue |= DecodeTable[tmp_ch] << 6;
if (tlen) {
*cptr++ = (nvalue >> 8) & 0xFF; --tlen;
} else {
buf_prev[num_prev++] = (nvalue >> 8) & 0xFF;
}
}
{
// fourth byte
Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)),
"invalid base64 format");
if (tmp_ch == '=') {
Check((tmp_ch = fgetc(fp), tmp_ch == EOF || isspace(tmp_ch)),
"invalid base64 format");
break;
}
nvalue |= DecodeTable[tmp_ch];
if (tlen) {
*cptr++ = nvalue & 0xFF; --tlen;
} else {
buf_prev[num_prev ++] = nvalue & 0xFF;
}
}
// get next char
tmp_ch = fgetc(fp);
}
if (kStrictCheck) {
Check(tlen == 0, "Base64InStream: read incomplete");
}
return size - tlen;
}
virtual void Write(const void *ptr, size_t size) {
utils::Error("Base64InStream do not support write");
}
private:
FILE *fp;
int tmp_ch;
int num_prev;
unsigned char buf_prev[2];
// whether we need to do strict check
static const bool kStrictCheck = false;
};
/*!
 * \brief stream that encodes written bytes as base64 text on a FILE pointer
 *  The stream does not own the file pointer and never closes it.
 *  Call Finish once all data is written so that trailing bytes are flushed.
 */
class Base64OutStream: public IStream {
 public:
  explicit Base64OutStream(FILE *fp) : fp(fp) {
    buf_top = 0;
  }
  /*!
   * \brief encode size bytes from ptr
   *  Bytes are collected in groups of three; each full group is emitted
   *  as four base64 characters, leftovers stay buffered until the next
   *  Write or Finish.
   * \param ptr source buffer
   * \param size number of bytes to encode
   */
  virtual void Write(const void *ptr, size_t size) {
    using base64::EncodeTable;
    size_t tlen = size;
    const unsigned char *cptr = static_cast<const unsigned char*>(ptr);
    while (tlen) {
      // buf is filled at slots 1..3, slot 0 is unused
      while (buf_top < 3 && tlen != 0) {
        buf[++buf_top] = *cptr++; --tlen;
      }
      if (buf_top == 3) {
        // flush 4 bytes out
        fputc(EncodeTable[buf[1] >> 2], fp);
        fputc(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F], fp);
        fputc(EncodeTable[((buf[2] << 2) | (buf[3] >> 6)) & 0x3F], fp);
        fputc(EncodeTable[buf[3] & 0x3F], fp);
        buf_top = 0;
      }
    }
  }
  /*! \brief not supported, always reports an error and returns 0 */
  virtual size_t Read(void *ptr, size_t size) {
    Error("Base64OutStream do not support read");
    return 0;
  }
  /*!
   * \brief finish writing of all current base64 stream, do some post processing
   *  Emits the buffered one or two bytes together with '=' padding.
   * \param endch character to put to end of stream, if it is EOF, then nothing will be done
   *  The parameter is an int so that the EOF default compares correctly
   *  even where plain char is unsigned.
   */
  inline void Finish(int endch = EOF) {
    using base64::EncodeTable;
    if (buf_top == 1) {
      fputc(EncodeTable[buf[1] >> 2], fp);
      fputc(EncodeTable[(buf[1] << 4) & 0x3F], fp);
      fputc('=', fp);
      fputc('=', fp);
    }
    if (buf_top == 2) {
      fputc(EncodeTable[buf[1] >> 2], fp);
      fputc(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F], fp);
      fputc(EncodeTable[(buf[2] << 2) & 0x3F], fp);
      fputc('=', fp);
    }
    buf_top = 0;
    if (endch != EOF) fputc(endch, fp);
  }

 private:
  // destination file, not owned by the stream
  FILE *fp;
  // number of bytes currently buffered in buf[1..buf_top]
  int buf_top;
  // pending input bytes, slots 1..3 are used
  unsigned char buf[4];
};
} // namespace utils
} // namespace rabit
#endif // RABIT_LEARN_UTILS_BASE64_H_

View File

@ -1,12 +1,12 @@
/*! /*!
* Copyright (c) 2015 by Contributors * Copyright (c) 2015 by Contributors
* \file toolkit_util.h * \file data.h
* \brief simple data structure that could be used by model * \brief simple data structure that could be used by model
* *
* \author Tianqi Chen * \author Tianqi Chen
*/ */
#ifndef RABIT_TOOLKIT_UTIL_H_ #ifndef RABIT_LEARN_DATA_H_
#define RABIT_TOOLKIT_UTIL_H_ #define RABIT_LEARN_DATA_H_
#include <vector> #include <vector>
#include <cstdlib> #include <cstdlib>
@ -140,4 +140,4 @@ inline int Random(int value) {
return rand() % value; return rand() % value;
} }
} // namespace rabit } // namespace rabit
#endif // RABIT_TOOLKIT_UTIL_H_ #endif // RABIT_LEARN_DATA_H_

40
rabit-learn/utils/io.h Normal file
View File

@ -0,0 +1,40 @@
#ifndef RABIT_LEARN_UTILS_IO_H_
#define RABIT_LEARN_UTILS_IO_H_
/*!
* \file io.h
* \brief additional stream interface
* \author Tianqi Chen
*/
namespace rabit {
namespace utils {
/*!
 * \brief seekable stream backed by a C stdio FILE handle
 *  The class declares no destructor; the handle is only closed when
 *  Close is called explicitly.
 */
class FileStream : public ISeekStream {
 public:
  /*! \brief create a stream with no file attached */
  explicit FileStream(void) : fp(NULL) {}
  /*! \brief wrap an already opened file handle */
  explicit FileStream(FILE *fp) : fp(fp) {}
  /*!
   * \brief read size bytes into ptr as one item
   * \return the fread item count: 1 when the whole block was read, else 0
   *  NOTE(review): this is not a byte count; callers that compare the
   *  result against 0 work, callers expecting bytes do not - confirm.
   */
  virtual size_t Read(void *ptr, size_t size) {
    const size_t nitem = std::fread(ptr, size, 1, fp);
    return nitem;
  }
  /*! \brief write size bytes from ptr, the result of fwrite is not checked */
  virtual void Write(const void *ptr, size_t size) {
    std::fwrite(ptr, size, 1, fp);
  }
  /*! \brief move to the absolute position pos, counted from the file start */
  virtual void Seek(size_t pos) {
    const long offset = static_cast<long>(pos);
    std::fseek(fp, offset, SEEK_SET);
  }
  /*! \brief current position in the file as reported by ftell */
  virtual size_t Tell(void) {
    return static_cast<size_t>(std::ftell(fp));
  }
  /*! \brief close the file if one is attached, safe to call repeatedly */
  inline void Close(void) {
    if (fp == NULL) return;
    std::fclose(fp);
    fp = NULL;
  }

 private:
  // underlying file handle, NULL when nothing is attached
  FILE *fp;
};
} // namespace utils
} // namespace rabit
#endif // RABIT_LEARN_UTILS_IO_H_