add all
This commit is contained in:
parent
12ee049a74
commit
4a5b9e5f78
@ -4,7 +4,7 @@ export CC = gcc
|
||||
export CXX = g++
|
||||
export MPICXX = mpicxx
|
||||
export LDFLAGS= -pthread -lm -L../../lib
|
||||
export CFLAGS = -Wall -msse2 -Wno-unknown-pragmas -fPIC -I../../include -I../common
|
||||
export CFLAGS = -Wall -msse2 -Wno-unknown-pragmas -fPIC -I../../include
|
||||
|
||||
.PHONY: clean all lib mpi
|
||||
all: $(BIN) $(MOCKBIN)
|
||||
|
||||
2
rabit-learn/data/README.md
Normal file
2
rabit-learn/data/README.md
Normal file
@ -0,0 +1,2 @@
|
||||
This folder contains processed example dataset used by the demos.
|
||||
Copyright of the dataset belongs to the original copyright holder
|
||||
1611
rabit-learn/data/agaricus.txt.test
Normal file
1611
rabit-learn/data/agaricus.txt.test
Normal file
File diff suppressed because it is too large
Load Diff
6513
rabit-learn/data/agaricus.txt.train
Normal file
6513
rabit-learn/data/agaricus.txt.train
Normal file
File diff suppressed because it is too large
Load Diff
126
rabit-learn/data/featmap.txt
Normal file
126
rabit-learn/data/featmap.txt
Normal file
@ -0,0 +1,126 @@
|
||||
0 cap-shape=bell i
|
||||
1 cap-shape=conical i
|
||||
2 cap-shape=convex i
|
||||
3 cap-shape=flat i
|
||||
4 cap-shape=knobbed i
|
||||
5 cap-shape=sunken i
|
||||
6 cap-surface=fibrous i
|
||||
7 cap-surface=grooves i
|
||||
8 cap-surface=scaly i
|
||||
9 cap-surface=smooth i
|
||||
10 cap-color=brown i
|
||||
11 cap-color=buff i
|
||||
12 cap-color=cinnamon i
|
||||
13 cap-color=gray i
|
||||
14 cap-color=green i
|
||||
15 cap-color=pink i
|
||||
16 cap-color=purple i
|
||||
17 cap-color=red i
|
||||
18 cap-color=white i
|
||||
19 cap-color=yellow i
|
||||
20 bruises?=bruises i
|
||||
21 bruises?=no i
|
||||
22 odor=almond i
|
||||
23 odor=anise i
|
||||
24 odor=creosote i
|
||||
25 odor=fishy i
|
||||
26 odor=foul i
|
||||
27 odor=musty i
|
||||
28 odor=none i
|
||||
29 odor=pungent i
|
||||
30 odor=spicy i
|
||||
31 gill-attachment=attached i
|
||||
32 gill-attachment=descending i
|
||||
33 gill-attachment=free i
|
||||
34 gill-attachment=notched i
|
||||
35 gill-spacing=close i
|
||||
36 gill-spacing=crowded i
|
||||
37 gill-spacing=distant i
|
||||
38 gill-size=broad i
|
||||
39 gill-size=narrow i
|
||||
40 gill-color=black i
|
||||
41 gill-color=brown i
|
||||
42 gill-color=buff i
|
||||
43 gill-color=chocolate i
|
||||
44 gill-color=gray i
|
||||
45 gill-color=green i
|
||||
46 gill-color=orange i
|
||||
47 gill-color=pink i
|
||||
48 gill-color=purple i
|
||||
49 gill-color=red i
|
||||
50 gill-color=white i
|
||||
51 gill-color=yellow i
|
||||
52 stalk-shape=enlarging i
|
||||
53 stalk-shape=tapering i
|
||||
54 stalk-root=bulbous i
|
||||
55 stalk-root=club i
|
||||
56 stalk-root=cup i
|
||||
57 stalk-root=equal i
|
||||
58 stalk-root=rhizomorphs i
|
||||
59 stalk-root=rooted i
|
||||
60 stalk-root=missing i
|
||||
61 stalk-surface-above-ring=fibrous i
|
||||
62 stalk-surface-above-ring=scaly i
|
||||
63 stalk-surface-above-ring=silky i
|
||||
64 stalk-surface-above-ring=smooth i
|
||||
65 stalk-surface-below-ring=fibrous i
|
||||
66 stalk-surface-below-ring=scaly i
|
||||
67 stalk-surface-below-ring=silky i
|
||||
68 stalk-surface-below-ring=smooth i
|
||||
69 stalk-color-above-ring=brown i
|
||||
70 stalk-color-above-ring=buff i
|
||||
71 stalk-color-above-ring=cinnamon i
|
||||
72 stalk-color-above-ring=gray i
|
||||
73 stalk-color-above-ring=orange i
|
||||
74 stalk-color-above-ring=pink i
|
||||
75 stalk-color-above-ring=red i
|
||||
76 stalk-color-above-ring=white i
|
||||
77 stalk-color-above-ring=yellow i
|
||||
78 stalk-color-below-ring=brown i
|
||||
79 stalk-color-below-ring=buff i
|
||||
80 stalk-color-below-ring=cinnamon i
|
||||
81 stalk-color-below-ring=gray i
|
||||
82 stalk-color-below-ring=orange i
|
||||
83 stalk-color-below-ring=pink i
|
||||
84 stalk-color-below-ring=red i
|
||||
85 stalk-color-below-ring=white i
|
||||
86 stalk-color-below-ring=yellow i
|
||||
87 veil-type=partial i
|
||||
88 veil-type=universal i
|
||||
89 veil-color=brown i
|
||||
90 veil-color=orange i
|
||||
91 veil-color=white i
|
||||
92 veil-color=yellow i
|
||||
93 ring-number=none i
|
||||
94 ring-number=one i
|
||||
95 ring-number=two i
|
||||
96 ring-type=cobwebby i
|
||||
97 ring-type=evanescent i
|
||||
98 ring-type=flaring i
|
||||
99 ring-type=large i
|
||||
100 ring-type=none i
|
||||
101 ring-type=pendant i
|
||||
102 ring-type=sheathing i
|
||||
103 ring-type=zone i
|
||||
104 spore-print-color=black i
|
||||
105 spore-print-color=brown i
|
||||
106 spore-print-color=buff i
|
||||
107 spore-print-color=chocolate i
|
||||
108 spore-print-color=green i
|
||||
109 spore-print-color=orange i
|
||||
110 spore-print-color=purple i
|
||||
111 spore-print-color=white i
|
||||
112 spore-print-color=yellow i
|
||||
113 population=abundant i
|
||||
114 population=clustered i
|
||||
115 population=numerous i
|
||||
116 population=scattered i
|
||||
117 population=several i
|
||||
118 population=solitary i
|
||||
119 habitat=grasses i
|
||||
120 habitat=leaves i
|
||||
121 habitat=meadows i
|
||||
122 habitat=paths i
|
||||
123 habitat=urban i
|
||||
124 habitat=waste i
|
||||
125 habitat=woods i
|
||||
@ -2,8 +2,8 @@
|
||||
// facing an exception
|
||||
#include <rabit.h>
|
||||
#include <rabit/utils.h>
|
||||
#include "./toolkit_util.h"
|
||||
#include <time.h>
|
||||
#include "../utils/data.h"
|
||||
|
||||
using namespace rabit;
|
||||
|
||||
|
||||
4
rabit-learn/linear/README.md
Normal file
4
rabit-learn/linear/README.md
Normal file
@ -0,0 +1,4 @@
|
||||
Linear and Logistic Regression
|
||||
====
|
||||
* input format: LibSVM
|
||||
* Example: [run-linear.sh](run-linear.sh)
|
||||
@ -1,4 +1,6 @@
|
||||
#include "./linear.h"
|
||||
#include "../utils/io.h"
|
||||
#include "../utils/base64.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace linear {
|
||||
@ -22,9 +24,10 @@ class LinearObjFunction : public solver::IObjFunction<float> {
|
||||
model.weight = NULL;
|
||||
task = "train";
|
||||
model_in = "NULL";
|
||||
name_pred = "pred.txt";
|
||||
model_out = "final.model";
|
||||
}
|
||||
virtual ~LinearObjFunction(void) {
|
||||
if (model.weight != NULL) delete [] model.weight;
|
||||
}
|
||||
// set parameters
|
||||
inline void SetParam(const char *name, const char *val) {
|
||||
@ -44,20 +47,79 @@ class LinearObjFunction : public solver::IObjFunction<float> {
|
||||
if (!strcmp(name, "task")) task = val;
|
||||
if (!strcmp(name, "model_in")) model_in = val;
|
||||
if (!strcmp(name, "model_out")) model_out = val;
|
||||
if (!strcmp(name, "name_pred")) name_pred = val;
|
||||
}
|
||||
inline void Run(void) {
|
||||
if (model_in != "NULL") {
|
||||
this->LoadModel(model_in.c_str());
|
||||
}
|
||||
if (task == "train") {
|
||||
lbfgs.Run();
|
||||
|
||||
this->SaveModel(model_out.c_str(), lbfgs.GetWeight());
|
||||
} else if (task == "pred") {
|
||||
|
||||
} else if (task == "eval") {
|
||||
this->TaskPred();
|
||||
} else {
|
||||
utils::Error("unknown task=%s", task.c_str());
|
||||
}
|
||||
}
|
||||
inline void TaskPred(void) {
|
||||
utils::Check(model_in != "NULL",
|
||||
"must set model_in for task=pred");
|
||||
FILE *fp = utils::FopenCheck(name_pred.c_str(), "w");
|
||||
for (size_t i = 0; i < dtrain.NumRow(); ++i) {
|
||||
float pred = model.Predict(dtrain[i]);
|
||||
fprintf(fp, "%g\n", pred);
|
||||
}
|
||||
fclose(fp);
|
||||
printf("Finishing writing to %s\n", name_pred.c_str());
|
||||
}
|
||||
inline void LoadModel(const char *fname) {
|
||||
FILE *fp = utils::FopenCheck(fname, "rb");
|
||||
std::string header; header.resize(4);
|
||||
// check header for different binary encode
|
||||
// can be base64 or binary
|
||||
utils::FileStream fi(fp);
|
||||
utils::Check(fi.Read(&header[0], 4) != 0, "invalid model");
|
||||
// base64 format
|
||||
if (header == "bs64") {
|
||||
utils::Base64InStream bsin(fp);
|
||||
bsin.InitPosition();
|
||||
model.Load(bsin);
|
||||
fclose(fp);
|
||||
return;
|
||||
} else if (header == "binf") {
|
||||
model.Load(fi);
|
||||
fclose(fp);
|
||||
return;
|
||||
} else {
|
||||
utils::Error("invalid model file");
|
||||
}
|
||||
}
|
||||
inline void SaveModel(const char *fname,
|
||||
const float *wptr,
|
||||
bool save_base64 = false) {
|
||||
FILE *fp;
|
||||
bool use_stdout = false;
|
||||
if (!strcmp(fname, "stdout")) {
|
||||
fp = stdout;
|
||||
use_stdout = true;
|
||||
} else {
|
||||
fp = utils::FopenCheck(fname, "wb");
|
||||
}
|
||||
utils::FileStream fo(fp);
|
||||
if (save_base64 != 0|| use_stdout) {
|
||||
fo.Write("bs64\t", 5);
|
||||
utils::Base64OutStream bout(fp);
|
||||
model.Save(bout, wptr);
|
||||
bout.Finish('\n');
|
||||
} else {
|
||||
fo.Write("binf", 4);
|
||||
model.Save(fo, wptr);
|
||||
}
|
||||
if (!use_stdout) {
|
||||
fclose(fp);
|
||||
}
|
||||
}
|
||||
inline void LoadData(const char *fname) {
|
||||
dtrain.Load(fname);
|
||||
}
|
||||
@ -142,6 +204,7 @@ class LinearObjFunction : public solver::IObjFunction<float> {
|
||||
std::string task;
|
||||
std::string model_in;
|
||||
std::string model_out;
|
||||
std::string name_pred;
|
||||
};
|
||||
} // namespace linear
|
||||
} // namespace rabit
|
||||
|
||||
@ -8,7 +8,7 @@
|
||||
#ifndef RABIT_LINEAR_H_
|
||||
#define RABIT_LINEAR_H_
|
||||
#include <omp.h>
|
||||
#include "../common/toolkit_util.h"
|
||||
#include "../utils/data.h"
|
||||
#include "../solver/lbfgs.h"
|
||||
|
||||
namespace rabit {
|
||||
@ -92,6 +92,7 @@ struct LinearModel {
|
||||
// weight[num_feature] is bias
|
||||
float sum = base_score + weight[num_feature];
|
||||
for (unsigned i = 0; i < v.length; ++i) {
|
||||
if (v[i].findex >= num_feature) continue;
|
||||
sum += weight[v[i].findex] * v[i].fvalue;
|
||||
}
|
||||
return sum;
|
||||
@ -115,12 +116,13 @@ struct LinearModel {
|
||||
fi.Read(¶m, sizeof(param));
|
||||
if (weight == NULL) {
|
||||
weight = new float[param.num_feature + 1];
|
||||
fi.Read(weight, sizeof(float) * (param.num_feature + 1));
|
||||
}
|
||||
fi.Read(weight, sizeof(float) * (param.num_feature + 1));
|
||||
}
|
||||
inline void Save(rabit::IStream &fo) const {
|
||||
inline void Save(rabit::IStream &fo, const float *wptr = NULL) const {
|
||||
fo.Write(¶m, sizeof(param));
|
||||
fo.Write(weight, sizeof(float) * (param.num_feature + 1));
|
||||
if (wptr == NULL) wptr = weight;
|
||||
fo.Write(wptr, sizeof(float) * (param.num_feature + 1));
|
||||
}
|
||||
inline float Predict(const SparseMat::Vector &v) const {
|
||||
return param.Predict(weight, v);
|
||||
|
||||
@ -12,4 +12,6 @@ k=$1
|
||||
python splitrows.py ../data/agaricus.txt.train mushroom $k
|
||||
|
||||
# run xgboost mpi
|
||||
../../tracker/rabit_demo.py -n $k linear.rabit mushroom.row\%d "${*:2}"
|
||||
../../tracker/rabit_demo.py -n $k linear.rabit mushroom.row\%d "${*:2}" reg_L1=1
|
||||
|
||||
./linear.rabit ../data/agaricus.txt.test task=pred model_in=final.model
|
||||
|
||||
@ -5,8 +5,8 @@
|
||||
*
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef RABIT_LBFGS_H_
|
||||
#define RABIT_LBFGS_H_
|
||||
#ifndef RABIT_LEARN_LBFGS_H_
|
||||
#define RABIT_LEARN_LBFGS_H_
|
||||
#include <cmath>
|
||||
#include <rabit.h>
|
||||
|
||||
@ -62,7 +62,7 @@ class LBFGSSolver {
|
||||
linesearch_c1 = 1e-4;
|
||||
min_lbfgs_iter = 5;
|
||||
max_lbfgs_iter = 1000;
|
||||
lbfgs_stop_tol = 1e-5f;
|
||||
lbfgs_stop_tol = 3e-6f;
|
||||
silent = 0;
|
||||
}
|
||||
virtual ~LBFGSSolver(void) {}
|
||||
@ -81,6 +81,9 @@ class LBFGSSolver {
|
||||
if (!strcmp("reg_L1", name)) {
|
||||
reg_L1 = static_cast<float>(atof(val));
|
||||
}
|
||||
if (!strcmp("lbfgs_stop_tol", name)) {
|
||||
lbfgs_stop_tol = static_cast<float>(atof(val));
|
||||
}
|
||||
if (!strcmp("linesearch_backoff", name)) {
|
||||
linesearch_backoff = static_cast<float>(atof(val));
|
||||
}
|
||||
@ -185,8 +188,13 @@ class LBFGSSolver {
|
||||
if (this->UpdateOneIter()) break;
|
||||
}
|
||||
if (silent == 0 && rabit::GetRank() == 0) {
|
||||
size_t nonzero = 0;
|
||||
for (size_t i = 0; i < gstate.num_dim; ++i) {
|
||||
if (gstate.weight[i] != 0.0f) nonzero += 1;
|
||||
}
|
||||
rabit::TrackerPrintf
|
||||
("L-BFGS: finishes at iteration %d\n", gstate.num_iteration);
|
||||
("L-BFGS: finishes at iteration %d, %lu/%lu active weights\n",
|
||||
gstate.num_iteration, nonzero, gstate.num_dim);
|
||||
}
|
||||
}
|
||||
protected:
|
||||
@ -625,4 +633,4 @@ class LBFGSSolver {
|
||||
};
|
||||
} // namespace solver
|
||||
} // namespace rabit
|
||||
#endif // RABIT_LBFGS_H_
|
||||
#endif // RABIT_LEARN_LBFGS_H_
|
||||
|
||||
204
rabit-learn/utils/base64.h
Normal file
204
rabit-learn/utils/base64.h
Normal file
@ -0,0 +1,204 @@
|
||||
#ifndef RABIT_LEARN_UTILS_BASE64_H_
|
||||
#define RABIT_LEARN_UTILS_BASE64_H_
|
||||
/*!
|
||||
* \file base64.h
|
||||
* \brief data stream support to input and output from/to base64 stream
|
||||
* base64 is easier to store and pass as text format in mapreduce
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <cctype>
|
||||
#include <cstdio>
|
||||
#include <rabit/io.h>
|
||||
|
||||
namespace rabit {
|
||||
namespace utils {
|
||||
/*! \brief namespace of base64 decoding and encoding table */
|
||||
namespace base64 {
|
||||
const char DecodeTable[] = {
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
62, // '+'
|
||||
0, 0, 0,
|
||||
63, // '/'
|
||||
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9'
|
||||
0, 0, 0, 0, 0, 0, 0,
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
|
||||
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z'
|
||||
0, 0, 0, 0, 0, 0,
|
||||
26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
|
||||
39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z'
|
||||
};
|
||||
static const char EncodeTable[] =
|
||||
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
|
||||
} // namespace base64
|
||||
/*! \brief the stream that reads from base64, note we take from file pointers */
|
||||
class Base64InStream: public IStream {
|
||||
public:
|
||||
explicit Base64InStream(FILE *fp) : fp(fp) {
|
||||
num_prev = 0; tmp_ch = 0;
|
||||
}
|
||||
/*!
|
||||
* \brief initialize the stream position to beginning of next base64 stream
|
||||
* call this function before actually start read
|
||||
*/
|
||||
inline void InitPosition(void) {
|
||||
// get a charater
|
||||
do {
|
||||
tmp_ch = fgetc(fp);
|
||||
} while (isspace(tmp_ch));
|
||||
}
|
||||
/*! \brief whether current position is end of a base64 stream */
|
||||
inline bool IsEOF(void) const {
|
||||
return num_prev == 0 && (tmp_ch == EOF || isspace(tmp_ch));
|
||||
}
|
||||
virtual size_t Read(void *ptr, size_t size) {
|
||||
using base64::DecodeTable;
|
||||
if (size == 0) return 0;
|
||||
// use tlen to record left size
|
||||
size_t tlen = size;
|
||||
unsigned char *cptr = static_cast<unsigned char*>(ptr);
|
||||
// if anything left, load from previous buffered result
|
||||
if (num_prev != 0) {
|
||||
if (num_prev == 2) {
|
||||
if (tlen >= 2) {
|
||||
*cptr++ = buf_prev[0];
|
||||
*cptr++ = buf_prev[1];
|
||||
tlen -= 2;
|
||||
num_prev = 0;
|
||||
} else {
|
||||
// assert tlen == 1
|
||||
*cptr++ = buf_prev[0]; --tlen;
|
||||
buf_prev[0] = buf_prev[1];
|
||||
num_prev = 1;
|
||||
}
|
||||
} else {
|
||||
// assert num_prev == 1
|
||||
*cptr++ = buf_prev[0]; --tlen; num_prev = 0;
|
||||
}
|
||||
}
|
||||
if (tlen == 0) return size;
|
||||
int nvalue;
|
||||
// note: everything goes with 4 bytes in Base64
|
||||
// so we process 4 bytes a unit
|
||||
while (tlen && tmp_ch != EOF && !isspace(tmp_ch)) {
|
||||
// first byte
|
||||
nvalue = DecodeTable[tmp_ch] << 18;
|
||||
{
|
||||
// second byte
|
||||
Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)),
|
||||
"invalid base64 format");
|
||||
nvalue |= DecodeTable[tmp_ch] << 12;
|
||||
*cptr++ = (nvalue >> 16) & 0xFF; --tlen;
|
||||
}
|
||||
{
|
||||
// third byte
|
||||
Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)),
|
||||
"invalid base64 format");
|
||||
// handle termination
|
||||
if (tmp_ch == '=') {
|
||||
Check((tmp_ch = fgetc(fp), tmp_ch == '='), "invalid base64 format");
|
||||
Check((tmp_ch = fgetc(fp), tmp_ch == EOF || isspace(tmp_ch)),
|
||||
"invalid base64 format");
|
||||
break;
|
||||
}
|
||||
nvalue |= DecodeTable[tmp_ch] << 6;
|
||||
if (tlen) {
|
||||
*cptr++ = (nvalue >> 8) & 0xFF; --tlen;
|
||||
} else {
|
||||
buf_prev[num_prev++] = (nvalue >> 8) & 0xFF;
|
||||
}
|
||||
}
|
||||
{
|
||||
// fourth byte
|
||||
Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)),
|
||||
"invalid base64 format");
|
||||
if (tmp_ch == '=') {
|
||||
Check((tmp_ch = fgetc(fp), tmp_ch == EOF || isspace(tmp_ch)),
|
||||
"invalid base64 format");
|
||||
break;
|
||||
}
|
||||
nvalue |= DecodeTable[tmp_ch];
|
||||
if (tlen) {
|
||||
*cptr++ = nvalue & 0xFF; --tlen;
|
||||
} else {
|
||||
buf_prev[num_prev ++] = nvalue & 0xFF;
|
||||
}
|
||||
}
|
||||
// get next char
|
||||
tmp_ch = fgetc(fp);
|
||||
}
|
||||
if (kStrictCheck) {
|
||||
Check(tlen == 0, "Base64InStream: read incomplete");
|
||||
}
|
||||
return size - tlen;
|
||||
}
|
||||
virtual void Write(const void *ptr, size_t size) {
|
||||
utils::Error("Base64InStream do not support write");
|
||||
}
|
||||
|
||||
private:
|
||||
FILE *fp;
|
||||
int tmp_ch;
|
||||
int num_prev;
|
||||
unsigned char buf_prev[2];
|
||||
// whether we need to do strict check
|
||||
static const bool kStrictCheck = false;
|
||||
};
|
||||
/*! \brief the stream that write to base64, note we take from file pointers */
|
||||
class Base64OutStream: public IStream {
|
||||
public:
|
||||
explicit Base64OutStream(FILE *fp) : fp(fp) {
|
||||
buf_top = 0;
|
||||
}
|
||||
virtual void Write(const void *ptr, size_t size) {
|
||||
using base64::EncodeTable;
|
||||
size_t tlen = size;
|
||||
const unsigned char *cptr = static_cast<const unsigned char*>(ptr);
|
||||
while (tlen) {
|
||||
while (buf_top < 3 && tlen != 0) {
|
||||
buf[++buf_top] = *cptr++; --tlen;
|
||||
}
|
||||
if (buf_top == 3) {
|
||||
// flush 4 bytes out
|
||||
fputc(EncodeTable[buf[1] >> 2], fp);
|
||||
fputc(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F], fp);
|
||||
fputc(EncodeTable[((buf[2] << 2) | (buf[3] >> 6)) & 0x3F], fp);
|
||||
fputc(EncodeTable[buf[3] & 0x3F], fp);
|
||||
buf_top = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
virtual size_t Read(void *ptr, size_t size) {
|
||||
Error("Base64OutStream do not support read");
|
||||
return 0;
|
||||
}
|
||||
/*!
|
||||
* \brief finish writing of all current base64 stream, do some post processing
|
||||
* \param endch charater to put to end of stream, if it is EOF, then nothing will be done
|
||||
*/
|
||||
inline void Finish(char endch = EOF) {
|
||||
using base64::EncodeTable;
|
||||
if (buf_top == 1) {
|
||||
fputc(EncodeTable[buf[1] >> 2], fp);
|
||||
fputc(EncodeTable[(buf[1] << 4) & 0x3F], fp);
|
||||
fputc('=', fp);
|
||||
fputc('=', fp);
|
||||
}
|
||||
if (buf_top == 2) {
|
||||
fputc(EncodeTable[buf[1] >> 2], fp);
|
||||
fputc(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F], fp);
|
||||
fputc(EncodeTable[(buf[2] << 2) & 0x3F], fp);
|
||||
fputc('=', fp);
|
||||
}
|
||||
buf_top = 0;
|
||||
if (endch != EOF) fputc(endch, fp);
|
||||
}
|
||||
|
||||
private:
|
||||
FILE *fp;
|
||||
int buf_top;
|
||||
unsigned char buf[4];
|
||||
};
|
||||
} // namespace utils
|
||||
} // namespace rabit
|
||||
#endif // RABIT_LEARN_UTILS_BASE64_H_
|
||||
@ -1,12 +1,12 @@
|
||||
/*!
|
||||
* Copyright (c) 2015 by Contributors
|
||||
* \file toolkit_util.h
|
||||
* \file data.h
|
||||
* \brief simple data structure that could be used by model
|
||||
*
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef RABIT_TOOLKIT_UTIL_H_
|
||||
#define RABIT_TOOLKIT_UTIL_H_
|
||||
#ifndef RABIT_LEARN_DATA_H_
|
||||
#define RABIT_LEARN_DATA_H_
|
||||
|
||||
#include <vector>
|
||||
#include <cstdlib>
|
||||
@ -140,4 +140,4 @@ inline int Random(int value) {
|
||||
return rand() % value;
|
||||
}
|
||||
} // namespace rabit
|
||||
#endif // RABIT_TOOLKIT_UTIL_H_
|
||||
#endif // RABIT_LEARN_DATA_H_
|
||||
40
rabit-learn/utils/io.h
Normal file
40
rabit-learn/utils/io.h
Normal file
@ -0,0 +1,40 @@
|
||||
#ifndef RABIT_LEARN_UTILS_IO_H_
|
||||
#define RABIT_LEARN_UTILS_IO_H_
|
||||
/*!
|
||||
* \file io.h
|
||||
* \brief additional stream interface
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
namespace rabit {
|
||||
namespace utils {
|
||||
/*! \brief implementation of file i/o stream */
|
||||
class FileStream : public ISeekStream {
|
||||
public:
|
||||
explicit FileStream(FILE *fp) : fp(fp) {}
|
||||
explicit FileStream(void) {
|
||||
this->fp = NULL;
|
||||
}
|
||||
virtual size_t Read(void *ptr, size_t size) {
|
||||
return std::fread(ptr, size, 1, fp);
|
||||
}
|
||||
virtual void Write(const void *ptr, size_t size) {
|
||||
std::fwrite(ptr, size, 1, fp);
|
||||
}
|
||||
virtual void Seek(size_t pos) {
|
||||
std::fseek(fp, static_cast<long>(pos), SEEK_SET);
|
||||
}
|
||||
virtual size_t Tell(void) {
|
||||
return std::ftell(fp);
|
||||
}
|
||||
inline void Close(void) {
|
||||
if (fp != NULL){
|
||||
std::fclose(fp); fp = NULL;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
FILE *fp;
|
||||
};
|
||||
} // namespace utils
|
||||
} // namespace rabit
|
||||
#endif // RABIT_LEARN_UTILS_IO_H_
|
||||
Loading…
x
Reference in New Issue
Block a user