add all
This commit is contained in:
parent
12ee049a74
commit
4a5b9e5f78
@ -4,7 +4,7 @@ export CC = gcc
|
|||||||
export CXX = g++
|
export CXX = g++
|
||||||
export MPICXX = mpicxx
|
export MPICXX = mpicxx
|
||||||
export LDFLAGS= -pthread -lm -L../../lib
|
export LDFLAGS= -pthread -lm -L../../lib
|
||||||
export CFLAGS = -Wall -msse2 -Wno-unknown-pragmas -fPIC -I../../include -I../common
|
export CFLAGS = -Wall -msse2 -Wno-unknown-pragmas -fPIC -I../../include
|
||||||
|
|
||||||
.PHONY: clean all lib mpi
|
.PHONY: clean all lib mpi
|
||||||
all: $(BIN) $(MOCKBIN)
|
all: $(BIN) $(MOCKBIN)
|
||||||
|
|||||||
2
rabit-learn/data/README.md
Normal file
2
rabit-learn/data/README.md
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
This folder contains the processed example datasets used by the demos.
|
||||||
|
Copyright of each dataset belongs to its original copyright holder.
|
||||||
1611
rabit-learn/data/agaricus.txt.test
Normal file
1611
rabit-learn/data/agaricus.txt.test
Normal file
File diff suppressed because it is too large
Load Diff
6513
rabit-learn/data/agaricus.txt.train
Normal file
6513
rabit-learn/data/agaricus.txt.train
Normal file
File diff suppressed because it is too large
Load Diff
126
rabit-learn/data/featmap.txt
Normal file
126
rabit-learn/data/featmap.txt
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
0 cap-shape=bell i
|
||||||
|
1 cap-shape=conical i
|
||||||
|
2 cap-shape=convex i
|
||||||
|
3 cap-shape=flat i
|
||||||
|
4 cap-shape=knobbed i
|
||||||
|
5 cap-shape=sunken i
|
||||||
|
6 cap-surface=fibrous i
|
||||||
|
7 cap-surface=grooves i
|
||||||
|
8 cap-surface=scaly i
|
||||||
|
9 cap-surface=smooth i
|
||||||
|
10 cap-color=brown i
|
||||||
|
11 cap-color=buff i
|
||||||
|
12 cap-color=cinnamon i
|
||||||
|
13 cap-color=gray i
|
||||||
|
14 cap-color=green i
|
||||||
|
15 cap-color=pink i
|
||||||
|
16 cap-color=purple i
|
||||||
|
17 cap-color=red i
|
||||||
|
18 cap-color=white i
|
||||||
|
19 cap-color=yellow i
|
||||||
|
20 bruises?=bruises i
|
||||||
|
21 bruises?=no i
|
||||||
|
22 odor=almond i
|
||||||
|
23 odor=anise i
|
||||||
|
24 odor=creosote i
|
||||||
|
25 odor=fishy i
|
||||||
|
26 odor=foul i
|
||||||
|
27 odor=musty i
|
||||||
|
28 odor=none i
|
||||||
|
29 odor=pungent i
|
||||||
|
30 odor=spicy i
|
||||||
|
31 gill-attachment=attached i
|
||||||
|
32 gill-attachment=descending i
|
||||||
|
33 gill-attachment=free i
|
||||||
|
34 gill-attachment=notched i
|
||||||
|
35 gill-spacing=close i
|
||||||
|
36 gill-spacing=crowded i
|
||||||
|
37 gill-spacing=distant i
|
||||||
|
38 gill-size=broad i
|
||||||
|
39 gill-size=narrow i
|
||||||
|
40 gill-color=black i
|
||||||
|
41 gill-color=brown i
|
||||||
|
42 gill-color=buff i
|
||||||
|
43 gill-color=chocolate i
|
||||||
|
44 gill-color=gray i
|
||||||
|
45 gill-color=green i
|
||||||
|
46 gill-color=orange i
|
||||||
|
47 gill-color=pink i
|
||||||
|
48 gill-color=purple i
|
||||||
|
49 gill-color=red i
|
||||||
|
50 gill-color=white i
|
||||||
|
51 gill-color=yellow i
|
||||||
|
52 stalk-shape=enlarging i
|
||||||
|
53 stalk-shape=tapering i
|
||||||
|
54 stalk-root=bulbous i
|
||||||
|
55 stalk-root=club i
|
||||||
|
56 stalk-root=cup i
|
||||||
|
57 stalk-root=equal i
|
||||||
|
58 stalk-root=rhizomorphs i
|
||||||
|
59 stalk-root=rooted i
|
||||||
|
60 stalk-root=missing i
|
||||||
|
61 stalk-surface-above-ring=fibrous i
|
||||||
|
62 stalk-surface-above-ring=scaly i
|
||||||
|
63 stalk-surface-above-ring=silky i
|
||||||
|
64 stalk-surface-above-ring=smooth i
|
||||||
|
65 stalk-surface-below-ring=fibrous i
|
||||||
|
66 stalk-surface-below-ring=scaly i
|
||||||
|
67 stalk-surface-below-ring=silky i
|
||||||
|
68 stalk-surface-below-ring=smooth i
|
||||||
|
69 stalk-color-above-ring=brown i
|
||||||
|
70 stalk-color-above-ring=buff i
|
||||||
|
71 stalk-color-above-ring=cinnamon i
|
||||||
|
72 stalk-color-above-ring=gray i
|
||||||
|
73 stalk-color-above-ring=orange i
|
||||||
|
74 stalk-color-above-ring=pink i
|
||||||
|
75 stalk-color-above-ring=red i
|
||||||
|
76 stalk-color-above-ring=white i
|
||||||
|
77 stalk-color-above-ring=yellow i
|
||||||
|
78 stalk-color-below-ring=brown i
|
||||||
|
79 stalk-color-below-ring=buff i
|
||||||
|
80 stalk-color-below-ring=cinnamon i
|
||||||
|
81 stalk-color-below-ring=gray i
|
||||||
|
82 stalk-color-below-ring=orange i
|
||||||
|
83 stalk-color-below-ring=pink i
|
||||||
|
84 stalk-color-below-ring=red i
|
||||||
|
85 stalk-color-below-ring=white i
|
||||||
|
86 stalk-color-below-ring=yellow i
|
||||||
|
87 veil-type=partial i
|
||||||
|
88 veil-type=universal i
|
||||||
|
89 veil-color=brown i
|
||||||
|
90 veil-color=orange i
|
||||||
|
91 veil-color=white i
|
||||||
|
92 veil-color=yellow i
|
||||||
|
93 ring-number=none i
|
||||||
|
94 ring-number=one i
|
||||||
|
95 ring-number=two i
|
||||||
|
96 ring-type=cobwebby i
|
||||||
|
97 ring-type=evanescent i
|
||||||
|
98 ring-type=flaring i
|
||||||
|
99 ring-type=large i
|
||||||
|
100 ring-type=none i
|
||||||
|
101 ring-type=pendant i
|
||||||
|
102 ring-type=sheathing i
|
||||||
|
103 ring-type=zone i
|
||||||
|
104 spore-print-color=black i
|
||||||
|
105 spore-print-color=brown i
|
||||||
|
106 spore-print-color=buff i
|
||||||
|
107 spore-print-color=chocolate i
|
||||||
|
108 spore-print-color=green i
|
||||||
|
109 spore-print-color=orange i
|
||||||
|
110 spore-print-color=purple i
|
||||||
|
111 spore-print-color=white i
|
||||||
|
112 spore-print-color=yellow i
|
||||||
|
113 population=abundant i
|
||||||
|
114 population=clustered i
|
||||||
|
115 population=numerous i
|
||||||
|
116 population=scattered i
|
||||||
|
117 population=several i
|
||||||
|
118 population=solitary i
|
||||||
|
119 habitat=grasses i
|
||||||
|
120 habitat=leaves i
|
||||||
|
121 habitat=meadows i
|
||||||
|
122 habitat=paths i
|
||||||
|
123 habitat=urban i
|
||||||
|
124 habitat=waste i
|
||||||
|
125 habitat=woods i
|
||||||
@ -2,8 +2,8 @@
|
|||||||
// facing an exception
|
// facing an exception
|
||||||
#include <rabit.h>
|
#include <rabit.h>
|
||||||
#include <rabit/utils.h>
|
#include <rabit/utils.h>
|
||||||
#include "./toolkit_util.h"
|
|
||||||
#include <time.h>
|
#include <time.h>
|
||||||
|
#include "../utils/data.h"
|
||||||
|
|
||||||
using namespace rabit;
|
using namespace rabit;
|
||||||
|
|
||||||
|
|||||||
4
rabit-learn/linear/README.md
Normal file
4
rabit-learn/linear/README.md
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
Linear and Logistic Regression
|
||||||
|
====
|
||||||
|
* input format: LibSVM
|
||||||
|
* Example: [run-linear.sh](run-linear.sh)
|
||||||
@ -1,4 +1,6 @@
|
|||||||
#include "./linear.h"
|
#include "./linear.h"
|
||||||
|
#include "../utils/io.h"
|
||||||
|
#include "../utils/base64.h"
|
||||||
|
|
||||||
namespace rabit {
|
namespace rabit {
|
||||||
namespace linear {
|
namespace linear {
|
||||||
@ -22,9 +24,10 @@ class LinearObjFunction : public solver::IObjFunction<float> {
|
|||||||
model.weight = NULL;
|
model.weight = NULL;
|
||||||
task = "train";
|
task = "train";
|
||||||
model_in = "NULL";
|
model_in = "NULL";
|
||||||
|
name_pred = "pred.txt";
|
||||||
|
model_out = "final.model";
|
||||||
}
|
}
|
||||||
virtual ~LinearObjFunction(void) {
|
virtual ~LinearObjFunction(void) {
|
||||||
if (model.weight != NULL) delete [] model.weight;
|
|
||||||
}
|
}
|
||||||
// set parameters
|
// set parameters
|
||||||
inline void SetParam(const char *name, const char *val) {
|
inline void SetParam(const char *name, const char *val) {
|
||||||
@ -44,20 +47,79 @@ class LinearObjFunction : public solver::IObjFunction<float> {
|
|||||||
if (!strcmp(name, "task")) task = val;
|
if (!strcmp(name, "task")) task = val;
|
||||||
if (!strcmp(name, "model_in")) model_in = val;
|
if (!strcmp(name, "model_in")) model_in = val;
|
||||||
if (!strcmp(name, "model_out")) model_out = val;
|
if (!strcmp(name, "model_out")) model_out = val;
|
||||||
|
if (!strcmp(name, "name_pred")) name_pred = val;
|
||||||
}
|
}
|
||||||
inline void Run(void) {
|
inline void Run(void) {
|
||||||
if (model_in != "NULL") {
|
if (model_in != "NULL") {
|
||||||
|
this->LoadModel(model_in.c_str());
|
||||||
}
|
}
|
||||||
if (task == "train") {
|
if (task == "train") {
|
||||||
lbfgs.Run();
|
lbfgs.Run();
|
||||||
|
this->SaveModel(model_out.c_str(), lbfgs.GetWeight());
|
||||||
} else if (task == "pred") {
|
} else if (task == "pred") {
|
||||||
|
this->TaskPred();
|
||||||
} else if (task == "eval") {
|
|
||||||
} else {
|
} else {
|
||||||
utils::Error("unknown task=%s", task.c_str());
|
utils::Error("unknown task=%s", task.c_str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
inline void TaskPred(void) {
|
||||||
|
utils::Check(model_in != "NULL",
|
||||||
|
"must set model_in for task=pred");
|
||||||
|
FILE *fp = utils::FopenCheck(name_pred.c_str(), "w");
|
||||||
|
for (size_t i = 0; i < dtrain.NumRow(); ++i) {
|
||||||
|
float pred = model.Predict(dtrain[i]);
|
||||||
|
fprintf(fp, "%g\n", pred);
|
||||||
|
}
|
||||||
|
fclose(fp);
|
||||||
|
printf("Finishing writing to %s\n", name_pred.c_str());
|
||||||
|
}
|
||||||
|
inline void LoadModel(const char *fname) {
|
||||||
|
FILE *fp = utils::FopenCheck(fname, "rb");
|
||||||
|
std::string header; header.resize(4);
|
||||||
|
// check header for different binary encode
|
||||||
|
// can be base64 or binary
|
||||||
|
utils::FileStream fi(fp);
|
||||||
|
utils::Check(fi.Read(&header[0], 4) != 0, "invalid model");
|
||||||
|
// base64 format
|
||||||
|
if (header == "bs64") {
|
||||||
|
utils::Base64InStream bsin(fp);
|
||||||
|
bsin.InitPosition();
|
||||||
|
model.Load(bsin);
|
||||||
|
fclose(fp);
|
||||||
|
return;
|
||||||
|
} else if (header == "binf") {
|
||||||
|
model.Load(fi);
|
||||||
|
fclose(fp);
|
||||||
|
return;
|
||||||
|
} else {
|
||||||
|
utils::Error("invalid model file");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
inline void SaveModel(const char *fname,
|
||||||
|
const float *wptr,
|
||||||
|
bool save_base64 = false) {
|
||||||
|
FILE *fp;
|
||||||
|
bool use_stdout = false;
|
||||||
|
if (!strcmp(fname, "stdout")) {
|
||||||
|
fp = stdout;
|
||||||
|
use_stdout = true;
|
||||||
|
} else {
|
||||||
|
fp = utils::FopenCheck(fname, "wb");
|
||||||
|
}
|
||||||
|
utils::FileStream fo(fp);
|
||||||
|
if (save_base64 != 0|| use_stdout) {
|
||||||
|
fo.Write("bs64\t", 5);
|
||||||
|
utils::Base64OutStream bout(fp);
|
||||||
|
model.Save(bout, wptr);
|
||||||
|
bout.Finish('\n');
|
||||||
|
} else {
|
||||||
|
fo.Write("binf", 4);
|
||||||
|
model.Save(fo, wptr);
|
||||||
|
}
|
||||||
|
if (!use_stdout) {
|
||||||
|
fclose(fp);
|
||||||
|
}
|
||||||
|
}
|
||||||
inline void LoadData(const char *fname) {
|
inline void LoadData(const char *fname) {
|
||||||
dtrain.Load(fname);
|
dtrain.Load(fname);
|
||||||
}
|
}
|
||||||
@ -142,6 +204,7 @@ class LinearObjFunction : public solver::IObjFunction<float> {
|
|||||||
std::string task;
|
std::string task;
|
||||||
std::string model_in;
|
std::string model_in;
|
||||||
std::string model_out;
|
std::string model_out;
|
||||||
|
std::string name_pred;
|
||||||
};
|
};
|
||||||
} // namespace linear
|
} // namespace linear
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
|
|||||||
@ -8,7 +8,7 @@
|
|||||||
#ifndef RABIT_LINEAR_H_
|
#ifndef RABIT_LINEAR_H_
|
||||||
#define RABIT_LINEAR_H_
|
#define RABIT_LINEAR_H_
|
||||||
#include <omp.h>
|
#include <omp.h>
|
||||||
#include "../common/toolkit_util.h"
|
#include "../utils/data.h"
|
||||||
#include "../solver/lbfgs.h"
|
#include "../solver/lbfgs.h"
|
||||||
|
|
||||||
namespace rabit {
|
namespace rabit {
|
||||||
@ -92,6 +92,7 @@ struct LinearModel {
|
|||||||
// weight[num_feature] is bias
|
// weight[num_feature] is bias
|
||||||
float sum = base_score + weight[num_feature];
|
float sum = base_score + weight[num_feature];
|
||||||
for (unsigned i = 0; i < v.length; ++i) {
|
for (unsigned i = 0; i < v.length; ++i) {
|
||||||
|
if (v[i].findex >= num_feature) continue;
|
||||||
sum += weight[v[i].findex] * v[i].fvalue;
|
sum += weight[v[i].findex] * v[i].fvalue;
|
||||||
}
|
}
|
||||||
return sum;
|
return sum;
|
||||||
@ -115,12 +116,13 @@ struct LinearModel {
|
|||||||
fi.Read(¶m, sizeof(param));
|
fi.Read(¶m, sizeof(param));
|
||||||
if (weight == NULL) {
|
if (weight == NULL) {
|
||||||
weight = new float[param.num_feature + 1];
|
weight = new float[param.num_feature + 1];
|
||||||
|
}
|
||||||
fi.Read(weight, sizeof(float) * (param.num_feature + 1));
|
fi.Read(weight, sizeof(float) * (param.num_feature + 1));
|
||||||
}
|
}
|
||||||
}
|
inline void Save(rabit::IStream &fo, const float *wptr = NULL) const {
|
||||||
inline void Save(rabit::IStream &fo) const {
|
|
||||||
fo.Write(¶m, sizeof(param));
|
fo.Write(¶m, sizeof(param));
|
||||||
fo.Write(weight, sizeof(float) * (param.num_feature + 1));
|
if (wptr == NULL) wptr = weight;
|
||||||
|
fo.Write(wptr, sizeof(float) * (param.num_feature + 1));
|
||||||
}
|
}
|
||||||
inline float Predict(const SparseMat::Vector &v) const {
|
inline float Predict(const SparseMat::Vector &v) const {
|
||||||
return param.Predict(weight, v);
|
return param.Predict(weight, v);
|
||||||
|
|||||||
@ -12,4 +12,6 @@ k=$1
|
|||||||
python splitrows.py ../data/agaricus.txt.train mushroom $k
|
python splitrows.py ../data/agaricus.txt.train mushroom $k
|
||||||
|
|
||||||
# run xgboost mpi
|
# run xgboost mpi
|
||||||
../../tracker/rabit_demo.py -n $k linear.rabit mushroom.row\%d "${*:2}"
|
../../tracker/rabit_demo.py -n $k linear.rabit mushroom.row\%d "${*:2}" reg_L1=1
|
||||||
|
|
||||||
|
./linear.rabit ../data/agaricus.txt.test task=pred model_in=final.model
|
||||||
|
|||||||
@ -5,8 +5,8 @@
|
|||||||
*
|
*
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
*/
|
*/
|
||||||
#ifndef RABIT_LBFGS_H_
|
#ifndef RABIT_LEARN_LBFGS_H_
|
||||||
#define RABIT_LBFGS_H_
|
#define RABIT_LEARN_LBFGS_H_
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <rabit.h>
|
#include <rabit.h>
|
||||||
|
|
||||||
@ -62,7 +62,7 @@ class LBFGSSolver {
|
|||||||
linesearch_c1 = 1e-4;
|
linesearch_c1 = 1e-4;
|
||||||
min_lbfgs_iter = 5;
|
min_lbfgs_iter = 5;
|
||||||
max_lbfgs_iter = 1000;
|
max_lbfgs_iter = 1000;
|
||||||
lbfgs_stop_tol = 1e-5f;
|
lbfgs_stop_tol = 3e-6f;
|
||||||
silent = 0;
|
silent = 0;
|
||||||
}
|
}
|
||||||
virtual ~LBFGSSolver(void) {}
|
virtual ~LBFGSSolver(void) {}
|
||||||
@ -81,6 +81,9 @@ class LBFGSSolver {
|
|||||||
if (!strcmp("reg_L1", name)) {
|
if (!strcmp("reg_L1", name)) {
|
||||||
reg_L1 = static_cast<float>(atof(val));
|
reg_L1 = static_cast<float>(atof(val));
|
||||||
}
|
}
|
||||||
|
if (!strcmp("lbfgs_stop_tol", name)) {
|
||||||
|
lbfgs_stop_tol = static_cast<float>(atof(val));
|
||||||
|
}
|
||||||
if (!strcmp("linesearch_backoff", name)) {
|
if (!strcmp("linesearch_backoff", name)) {
|
||||||
linesearch_backoff = static_cast<float>(atof(val));
|
linesearch_backoff = static_cast<float>(atof(val));
|
||||||
}
|
}
|
||||||
@ -185,8 +188,13 @@ class LBFGSSolver {
|
|||||||
if (this->UpdateOneIter()) break;
|
if (this->UpdateOneIter()) break;
|
||||||
}
|
}
|
||||||
if (silent == 0 && rabit::GetRank() == 0) {
|
if (silent == 0 && rabit::GetRank() == 0) {
|
||||||
|
size_t nonzero = 0;
|
||||||
|
for (size_t i = 0; i < gstate.num_dim; ++i) {
|
||||||
|
if (gstate.weight[i] != 0.0f) nonzero += 1;
|
||||||
|
}
|
||||||
rabit::TrackerPrintf
|
rabit::TrackerPrintf
|
||||||
("L-BFGS: finishes at iteration %d\n", gstate.num_iteration);
|
("L-BFGS: finishes at iteration %d, %lu/%lu active weights\n",
|
||||||
|
gstate.num_iteration, nonzero, gstate.num_dim);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
protected:
|
protected:
|
||||||
@ -625,4 +633,4 @@ class LBFGSSolver {
|
|||||||
};
|
};
|
||||||
} // namespace solver
|
} // namespace solver
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
#endif // RABIT_LBFGS_H_
|
#endif // RABIT_LEARN_LBFGS_H_
|
||||||
|
|||||||
204
rabit-learn/utils/base64.h
Normal file
204
rabit-learn/utils/base64.h
Normal file
@ -0,0 +1,204 @@
|
|||||||
|
#ifndef RABIT_LEARN_UTILS_BASE64_H_
|
||||||
|
#define RABIT_LEARN_UTILS_BASE64_H_
|
||||||
|
/*!
|
||||||
|
* \file base64.h
|
||||||
|
* \brief stream support for reading from and writing to base64-encoded data
|
||||||
|
* base64 is easier to store and pass around as text, e.g. in MapReduce jobs
|
||||||
|
* \author Tianqi Chen
|
||||||
|
*/
|
||||||
|
#include <cctype>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <rabit/io.h>
|
||||||
|
|
||||||
|
namespace rabit {
|
||||||
|
namespace utils {
|
||||||
|
/*! \brief namespace of base64 decoding and encoding table */
|
||||||
|
namespace base64 {
/*!
 * \brief lookup table from a byte value to its 6-bit base64 value
 *  The table covers every possible byte (0-255) so that indexing it with
 *  any character read from a file stays in bounds. Bytes that are not
 *  part of the base64 alphabet, including the padding '=', map to 0.
 */
const char DecodeTable[256] = {
  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  62,  // '+'
  0, 0, 0,
  63,  // '/'
  52, 53, 54, 55, 56, 57, 58, 59, 60, 61,  // '0'-'9'
  0, 0, 0, 0, 0, 0, 0,
  0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
  13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,  // 'A'-'Z'
  0, 0, 0, 0, 0, 0,
  26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
  39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,  // 'a'-'z'
  // entries 123-255 are zero-initialized
};
/*! \brief lookup table from a 6-bit value (0-63) to its base64 character */
static const char EncodeTable[] =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
}  // namespace base64
|
||||||
|
/*! \brief the stream that reads from base64, note we take from file pointers */
|
||||||
|
class Base64InStream: public IStream {
|
||||||
|
public:
|
||||||
|
explicit Base64InStream(FILE *fp) : fp(fp) {
|
||||||
|
num_prev = 0; tmp_ch = 0;
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief initialize the stream position to beginning of next base64 stream
|
||||||
|
* call this function before actually start read
|
||||||
|
*/
|
||||||
|
inline void InitPosition(void) {
|
||||||
|
// get a charater
|
||||||
|
do {
|
||||||
|
tmp_ch = fgetc(fp);
|
||||||
|
} while (isspace(tmp_ch));
|
||||||
|
}
|
||||||
|
/*! \brief whether current position is end of a base64 stream */
|
||||||
|
inline bool IsEOF(void) const {
|
||||||
|
return num_prev == 0 && (tmp_ch == EOF || isspace(tmp_ch));
|
||||||
|
}
|
||||||
|
virtual size_t Read(void *ptr, size_t size) {
|
||||||
|
using base64::DecodeTable;
|
||||||
|
if (size == 0) return 0;
|
||||||
|
// use tlen to record left size
|
||||||
|
size_t tlen = size;
|
||||||
|
unsigned char *cptr = static_cast<unsigned char*>(ptr);
|
||||||
|
// if anything left, load from previous buffered result
|
||||||
|
if (num_prev != 0) {
|
||||||
|
if (num_prev == 2) {
|
||||||
|
if (tlen >= 2) {
|
||||||
|
*cptr++ = buf_prev[0];
|
||||||
|
*cptr++ = buf_prev[1];
|
||||||
|
tlen -= 2;
|
||||||
|
num_prev = 0;
|
||||||
|
} else {
|
||||||
|
// assert tlen == 1
|
||||||
|
*cptr++ = buf_prev[0]; --tlen;
|
||||||
|
buf_prev[0] = buf_prev[1];
|
||||||
|
num_prev = 1;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// assert num_prev == 1
|
||||||
|
*cptr++ = buf_prev[0]; --tlen; num_prev = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (tlen == 0) return size;
|
||||||
|
int nvalue;
|
||||||
|
// note: everything goes with 4 bytes in Base64
|
||||||
|
// so we process 4 bytes a unit
|
||||||
|
while (tlen && tmp_ch != EOF && !isspace(tmp_ch)) {
|
||||||
|
// first byte
|
||||||
|
nvalue = DecodeTable[tmp_ch] << 18;
|
||||||
|
{
|
||||||
|
// second byte
|
||||||
|
Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)),
|
||||||
|
"invalid base64 format");
|
||||||
|
nvalue |= DecodeTable[tmp_ch] << 12;
|
||||||
|
*cptr++ = (nvalue >> 16) & 0xFF; --tlen;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// third byte
|
||||||
|
Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)),
|
||||||
|
"invalid base64 format");
|
||||||
|
// handle termination
|
||||||
|
if (tmp_ch == '=') {
|
||||||
|
Check((tmp_ch = fgetc(fp), tmp_ch == '='), "invalid base64 format");
|
||||||
|
Check((tmp_ch = fgetc(fp), tmp_ch == EOF || isspace(tmp_ch)),
|
||||||
|
"invalid base64 format");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
nvalue |= DecodeTable[tmp_ch] << 6;
|
||||||
|
if (tlen) {
|
||||||
|
*cptr++ = (nvalue >> 8) & 0xFF; --tlen;
|
||||||
|
} else {
|
||||||
|
buf_prev[num_prev++] = (nvalue >> 8) & 0xFF;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// fourth byte
|
||||||
|
Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)),
|
||||||
|
"invalid base64 format");
|
||||||
|
if (tmp_ch == '=') {
|
||||||
|
Check((tmp_ch = fgetc(fp), tmp_ch == EOF || isspace(tmp_ch)),
|
||||||
|
"invalid base64 format");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
nvalue |= DecodeTable[tmp_ch];
|
||||||
|
if (tlen) {
|
||||||
|
*cptr++ = nvalue & 0xFF; --tlen;
|
||||||
|
} else {
|
||||||
|
buf_prev[num_prev ++] = nvalue & 0xFF;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// get next char
|
||||||
|
tmp_ch = fgetc(fp);
|
||||||
|
}
|
||||||
|
if (kStrictCheck) {
|
||||||
|
Check(tlen == 0, "Base64InStream: read incomplete");
|
||||||
|
}
|
||||||
|
return size - tlen;
|
||||||
|
}
|
||||||
|
virtual void Write(const void *ptr, size_t size) {
|
||||||
|
utils::Error("Base64InStream do not support write");
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
FILE *fp;
|
||||||
|
int tmp_ch;
|
||||||
|
int num_prev;
|
||||||
|
unsigned char buf_prev[2];
|
||||||
|
// whether we need to do strict check
|
||||||
|
static const bool kStrictCheck = false;
|
||||||
|
};
|
||||||
|
/*! \brief the stream that write to base64, note we take from file pointers */
|
||||||
|
class Base64OutStream: public IStream {
|
||||||
|
public:
|
||||||
|
explicit Base64OutStream(FILE *fp) : fp(fp) {
|
||||||
|
buf_top = 0;
|
||||||
|
}
|
||||||
|
virtual void Write(const void *ptr, size_t size) {
|
||||||
|
using base64::EncodeTable;
|
||||||
|
size_t tlen = size;
|
||||||
|
const unsigned char *cptr = static_cast<const unsigned char*>(ptr);
|
||||||
|
while (tlen) {
|
||||||
|
while (buf_top < 3 && tlen != 0) {
|
||||||
|
buf[++buf_top] = *cptr++; --tlen;
|
||||||
|
}
|
||||||
|
if (buf_top == 3) {
|
||||||
|
// flush 4 bytes out
|
||||||
|
fputc(EncodeTable[buf[1] >> 2], fp);
|
||||||
|
fputc(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F], fp);
|
||||||
|
fputc(EncodeTable[((buf[2] << 2) | (buf[3] >> 6)) & 0x3F], fp);
|
||||||
|
fputc(EncodeTable[buf[3] & 0x3F], fp);
|
||||||
|
buf_top = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
virtual size_t Read(void *ptr, size_t size) {
|
||||||
|
Error("Base64OutStream do not support read");
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief finish writing of all current base64 stream, do some post processing
|
||||||
|
* \param endch charater to put to end of stream, if it is EOF, then nothing will be done
|
||||||
|
*/
|
||||||
|
inline void Finish(char endch = EOF) {
|
||||||
|
using base64::EncodeTable;
|
||||||
|
if (buf_top == 1) {
|
||||||
|
fputc(EncodeTable[buf[1] >> 2], fp);
|
||||||
|
fputc(EncodeTable[(buf[1] << 4) & 0x3F], fp);
|
||||||
|
fputc('=', fp);
|
||||||
|
fputc('=', fp);
|
||||||
|
}
|
||||||
|
if (buf_top == 2) {
|
||||||
|
fputc(EncodeTable[buf[1] >> 2], fp);
|
||||||
|
fputc(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F], fp);
|
||||||
|
fputc(EncodeTable[(buf[2] << 2) & 0x3F], fp);
|
||||||
|
fputc('=', fp);
|
||||||
|
}
|
||||||
|
buf_top = 0;
|
||||||
|
if (endch != EOF) fputc(endch, fp);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
FILE *fp;
|
||||||
|
int buf_top;
|
||||||
|
unsigned char buf[4];
|
||||||
|
};
|
||||||
|
} // namespace utils
|
||||||
|
} // namespace rabit
|
||||||
|
#endif // RABIT_LEARN_UTILS_BASE64_H_
|
||||||
@ -1,12 +1,12 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright (c) 2015 by Contributors
|
* Copyright (c) 2015 by Contributors
|
||||||
* \file toolkit_util.h
|
* \file data.h
|
||||||
* \brief simple data structure that could be used by model
|
* \brief simple data structure that could be used by model
|
||||||
*
|
*
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
*/
|
*/
|
||||||
#ifndef RABIT_TOOLKIT_UTIL_H_
|
#ifndef RABIT_LEARN_DATA_H_
|
||||||
#define RABIT_TOOLKIT_UTIL_H_
|
#define RABIT_LEARN_DATA_H_
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
@ -140,4 +140,4 @@ inline int Random(int value) {
|
|||||||
return rand() % value;
|
return rand() % value;
|
||||||
}
|
}
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
#endif // RABIT_TOOLKIT_UTIL_H_
|
#endif // RABIT_LEARN_DATA_H_
|
||||||
40
rabit-learn/utils/io.h
Normal file
40
rabit-learn/utils/io.h
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
#ifndef RABIT_LEARN_UTILS_IO_H_
|
||||||
|
#define RABIT_LEARN_UTILS_IO_H_
|
||||||
|
/*!
|
||||||
|
* \file io.h
|
||||||
|
* \brief additional stream interface
|
||||||
|
* \author Tianqi Chen
|
||||||
|
*/
|
||||||
|
namespace rabit {
|
||||||
|
namespace utils {
|
||||||
|
/*! \brief implementation of file i/o stream */
|
||||||
|
class FileStream : public ISeekStream {
|
||||||
|
public:
|
||||||
|
explicit FileStream(FILE *fp) : fp(fp) {}
|
||||||
|
explicit FileStream(void) {
|
||||||
|
this->fp = NULL;
|
||||||
|
}
|
||||||
|
virtual size_t Read(void *ptr, size_t size) {
|
||||||
|
return std::fread(ptr, size, 1, fp);
|
||||||
|
}
|
||||||
|
virtual void Write(const void *ptr, size_t size) {
|
||||||
|
std::fwrite(ptr, size, 1, fp);
|
||||||
|
}
|
||||||
|
virtual void Seek(size_t pos) {
|
||||||
|
std::fseek(fp, static_cast<long>(pos), SEEK_SET);
|
||||||
|
}
|
||||||
|
virtual size_t Tell(void) {
|
||||||
|
return std::ftell(fp);
|
||||||
|
}
|
||||||
|
inline void Close(void) {
|
||||||
|
if (fp != NULL){
|
||||||
|
std::fclose(fp); fp = NULL;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
FILE *fp;
|
||||||
|
};
|
||||||
|
} // namespace utils
|
||||||
|
} // namespace rabit
|
||||||
|
#endif // RABIT_LEARN_UTILS_IO_H_
|
||||||
Loading…
x
Reference in New Issue
Block a user