compile with dmlc
This commit is contained in:
parent
9b7907eda3
commit
5f902982f2
31
Makefile
31
Makefile
@ -16,18 +16,28 @@ ifeq ($(cxx11),1)
|
|||||||
else
|
else
|
||||||
endif
|
endif
|
||||||
|
|
||||||
ifeq ($(hdfs),1)
|
# handling dmlc
|
||||||
CFLAGS+= -DRABIT_USE_HDFS=1 -I$(HADOOP_HDFS_HOME)/include -I$(JAVA_HOME)/include
|
ifdef dmlc
|
||||||
LDFLAGS+= -L$(HADOOP_HDFS_HOME)/lib/native -L$(JAVA_HOME)/jre/lib/amd64/server -lhdfs -ljvm
|
ifndef config
|
||||||
else
|
ifneq ("$(wildcard $(dmlc)/config.mk)","")
|
||||||
CFLAGS+= -DRABIT_USE_HDFS=0
|
config = $(dmlc)/config.mk
|
||||||
|
else
|
||||||
|
config = $(dmlc)/make/config.mk
|
||||||
|
endif
|
||||||
|
endif
|
||||||
|
include $(config)
|
||||||
|
include $(dmlc)/make/dmlc.mk
|
||||||
|
LDFLAGS+= $(DMLC_LDFLAGS)
|
||||||
|
LIBDMLC=$(dmlc)/libdmlc.a
|
||||||
|
else
|
||||||
|
LIBDMLC=dmlc_simple.o
|
||||||
endif
|
endif
|
||||||
|
|
||||||
# specify tensor path
|
# specify tensor path
|
||||||
BIN = xgboost
|
BIN = xgboost
|
||||||
MOCKBIN = xgboost.mock
|
MOCKBIN = xgboost.mock
|
||||||
OBJ = updater.o gbm.o io.o main.o
|
OBJ = updater.o gbm.o io.o main.o dmlc_simple.o
|
||||||
MPIBIN = xgboost.mpi
|
MPIBIN =
|
||||||
SLIB = wrapper/libxgboostwrapper.so
|
SLIB = wrapper/libxgboostwrapper.so
|
||||||
|
|
||||||
.PHONY: clean all mpi python Rpack
|
.PHONY: clean all mpi python Rpack
|
||||||
@ -38,13 +48,12 @@ mpi: $(MPIBIN)
|
|||||||
python: wrapper/libxgboostwrapper.so
|
python: wrapper/libxgboostwrapper.so
|
||||||
# now the wrapper takes in two files. io and wrapper part
|
# now the wrapper takes in two files. io and wrapper part
|
||||||
updater.o: src/tree/updater.cpp src/tree/*.hpp src/*.h src/tree/*.h src/utils/*.h
|
updater.o: src/tree/updater.cpp src/tree/*.hpp src/*.h src/tree/*.h src/utils/*.h
|
||||||
|
dmlc_simple.o: src/io/dmlc_simple.cpp src/utils/*.h
|
||||||
gbm.o: src/gbm/gbm.cpp src/gbm/*.hpp src/gbm/*.h
|
gbm.o: src/gbm/gbm.cpp src/gbm/*.hpp src/gbm/*.h
|
||||||
io.o: src/io/io.cpp src/io/*.hpp src/utils/*.h src/learner/dmatrix.h src/*.h
|
io.o: src/io/io.cpp src/io/*.hpp src/utils/*.h src/learner/dmatrix.h src/*.h
|
||||||
main.o: src/xgboost_main.cpp src/utils/*.h src/*.h src/learner/*.hpp src/learner/*.h
|
main.o: src/xgboost_main.cpp src/utils/*.h src/*.h src/learner/*.hpp src/learner/*.h
|
||||||
xgboost.mpi: updater.o gbm.o io.o main.o subtree/rabit/lib/librabit_mpi.a
|
xgboost: updater.o gbm.o io.o main.o subtree/rabit/lib/librabit.a $(LIBDMLC)
|
||||||
xgboost.mock: updater.o gbm.o io.o main.o subtree/rabit/lib/librabit_mock.a
|
wrapper/libxgboostwrapper.so: wrapper/xgboost_wrapper.cpp src/utils/*.h src/*.h src/learner/*.hpp src/learner/*.h updater.o gbm.o io.o subtree/rabit/lib/librabit.a $(LIBDMLC)
|
||||||
xgboost: updater.o gbm.o io.o main.o subtree/rabit/lib/librabit.a
|
|
||||||
wrapper/libxgboostwrapper.so: wrapper/xgboost_wrapper.cpp src/utils/*.h src/*.h src/learner/*.hpp src/learner/*.h updater.o gbm.o io.o subtree/rabit/lib/librabit.a
|
|
||||||
|
|
||||||
# dependency on rabit
|
# dependency on rabit
|
||||||
subtree/rabit/lib/librabit.a: subtree/rabit/src/engine.cc
|
subtree/rabit/lib/librabit.a: subtree/rabit/src/engine.cc
|
||||||
|
|||||||
@ -4,4 +4,5 @@ PKGROOT=../../
|
|||||||
PKG_CPPFLAGS= -DXGBOOST_CUSTOMIZE_MSG_ -DXGBOOST_CUSTOMIZE_PRNG_ -DXGBOOST_STRICT_CXX98_ -DRABIT_CUSTOMIZE_MSG_ -DRABIT_STRICT_CXX98_ -I$(PKGROOT)
|
PKG_CPPFLAGS= -DXGBOOST_CUSTOMIZE_MSG_ -DXGBOOST_CUSTOMIZE_PRNG_ -DXGBOOST_STRICT_CXX98_ -DRABIT_CUSTOMIZE_MSG_ -DRABIT_STRICT_CXX98_ -I$(PKGROOT)
|
||||||
PKG_CXXFLAGS= $(SHLIB_OPENMP_CFLAGS)
|
PKG_CXXFLAGS= $(SHLIB_OPENMP_CFLAGS)
|
||||||
PKG_LIBS = $(SHLIB_OPENMP_CFLAGS)
|
PKG_LIBS = $(SHLIB_OPENMP_CFLAGS)
|
||||||
OBJECTS= xgboost_R.o xgboost_assert.o $(PKGROOT)/wrapper/xgboost_wrapper.o $(PKGROOT)/src/io/io.o $(PKGROOT)/src/gbm/gbm.o $(PKGROOT)/src/tree/updater.o $(PKGROOT)/subtree/rabit/src/engine_empty.o
|
OBJECTS= xgboost_R.o xgboost_assert.o $(PKGROOT)/wrapper/xgboost_wrapper.o $(PKGROOT)/src/io/io.o $(PKGROOT)/src/gbm/gbm.o $(PKGROOT)/src/tree/updater.o $(PKGROOT)/subtree/rabit/src/engine_empty.o $(PKGROOT)/src/io/dmlc_simple.o
|
||||||
|
|
||||||
|
|||||||
@ -15,5 +15,5 @@ xgblib:
|
|||||||
PKG_CPPFLAGS= -DXGBOOST_CUSTOMIZE_MSG_ -DXGBOOST_CUSTOMIZE_PRNG_ -DXGBOOST_STRICT_CXX98_ -DRABIT_CUSTOMIZE_MSG_ -DRABIT_STRICT_CXX98_ -I$(PKGROOT) -I../..
|
PKG_CPPFLAGS= -DXGBOOST_CUSTOMIZE_MSG_ -DXGBOOST_CUSTOMIZE_PRNG_ -DXGBOOST_STRICT_CXX98_ -DRABIT_CUSTOMIZE_MSG_ -DRABIT_STRICT_CXX98_ -I$(PKGROOT) -I../..
|
||||||
PKG_CXXFLAGS= $(SHLIB_OPENMP_CFLAGS)
|
PKG_CXXFLAGS= $(SHLIB_OPENMP_CFLAGS)
|
||||||
PKG_LIBS = $(SHLIB_OPENMP_CFLAGS)
|
PKG_LIBS = $(SHLIB_OPENMP_CFLAGS)
|
||||||
OBJECTS= xgboost_R.o xgboost_assert.o $(PKGROOT)/wrapper/xgboost_wrapper.o $(PKGROOT)/src/io/io.o $(PKGROOT)/src/gbm/gbm.o $(PKGROOT)/src/tree/updater.o $(PKGROOT)/subtree/rabit/src/engine_empty.o
|
OBJECTS= xgboost_R.o xgboost_assert.o $(PKGROOT)/wrapper/xgboost_wrapper.o $(PKGROOT)/src/io/io.o $(PKGROOT)/src/gbm/gbm.o $(PKGROOT)/src/tree/updater.o $(PKGROOT)/subtree/rabit/src/engine_empty.o $(PKGROOT)/src/io/dmlc_simple.o
|
||||||
$(OBJECTS) : xgblib
|
$(OBJECTS) : xgblib
|
||||||
|
|||||||
126
src/io/dmlc_simple.cpp
Normal file
126
src/io/dmlc_simple.cpp
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
#define _CRT_SECURE_NO_WARNINGS
|
||||||
|
#define _CRT_SECURE_NO_DEPRECATE
|
||||||
|
#define NOMINMAX
|
||||||
|
#include "../utils/io.h"
|
||||||
|
|
||||||
|
// implements a single no split version of DMLC
|
||||||
|
// in case we want to avoid dependency on dmlc-core
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace utils {
|
||||||
|
class SingleFileSplit : public dmlc::InputSplit {
|
||||||
|
public:
|
||||||
|
explicit SingleFileSplit(const char *fname) {
|
||||||
|
if (!std::strcmp(fname, "stdin")) {
|
||||||
|
#ifndef XGBOOST_STRICT_CXX98_
|
||||||
|
use_stdin_ = true; fp_ = stdin;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
if (!use_stdin_) {
|
||||||
|
fp_ = utils::FopenCheck(fname, "r");
|
||||||
|
}
|
||||||
|
end_of_file_ = false;
|
||||||
|
}
|
||||||
|
virtual ~SingleFileSplit(void) {
|
||||||
|
if (!use_stdin_) std::fclose(fp_);
|
||||||
|
}
|
||||||
|
virtual bool ReadLine(std::string *out_data) {
|
||||||
|
if (end_of_file_) return false;
|
||||||
|
out_data->clear();
|
||||||
|
while (true) {
|
||||||
|
char c = std::fgetc(fp_);
|
||||||
|
if (c == EOF) {
|
||||||
|
end_of_file_ = true;
|
||||||
|
}
|
||||||
|
if (c != '\r' && c != '\n' && c != EOF) {
|
||||||
|
*out_data += c;
|
||||||
|
} else {
|
||||||
|
if (out_data->length() != 0) return true;
|
||||||
|
if (end_of_file_) return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::FILE *fp_;
|
||||||
|
bool use_stdin_;
|
||||||
|
bool end_of_file_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class StdFile : public dmlc::IStream {
|
||||||
|
public:
|
||||||
|
explicit StdFile(const char *fname, const char *mode)
|
||||||
|
: use_stdio(false) {
|
||||||
|
using namespace std;
|
||||||
|
#ifndef XGBOOST_STRICT_CXX98_
|
||||||
|
if (!strcmp(fname, "stdin")) {
|
||||||
|
use_stdio = true; fp = stdin;
|
||||||
|
}
|
||||||
|
if (!strcmp(fname, "stdout")) {
|
||||||
|
use_stdio = true; fp = stdout;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
if (!strncmp(fname, "file://", 7)) fname += 7;
|
||||||
|
if (!use_stdio) {
|
||||||
|
std::string flag = mode;
|
||||||
|
if (flag == "w") flag = "wb";
|
||||||
|
if (flag == "r") flag = "rb";
|
||||||
|
fp = utils::FopenCheck(fname, flag.c_str());
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
virtual ~StdFile(void) {
|
||||||
|
this->Close();
|
||||||
|
}
|
||||||
|
virtual size_t Read(void *ptr, size_t size) {
|
||||||
|
return std::fread(ptr, 1, size, fp);
|
||||||
|
}
|
||||||
|
virtual void Write(const void *ptr, size_t size) {
|
||||||
|
std::fwrite(ptr, size, 1, fp);
|
||||||
|
}
|
||||||
|
virtual void Seek(size_t pos) {
|
||||||
|
std::fseek(fp, static_cast<long>(pos), SEEK_SET);
|
||||||
|
}
|
||||||
|
virtual size_t Tell(void) {
|
||||||
|
return std::ftell(fp);
|
||||||
|
}
|
||||||
|
virtual bool AtEnd(void) const {
|
||||||
|
return std::feof(fp) != 0;
|
||||||
|
}
|
||||||
|
inline void Close(void) {
|
||||||
|
if (fp != NULL && !use_stdio) {
|
||||||
|
std::fclose(fp); fp = NULL;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::FILE *fp;
|
||||||
|
bool use_stdio;
|
||||||
|
};
|
||||||
|
} // namespace utils
|
||||||
|
} // namespace xgboost
|
||||||
|
|
||||||
|
namespace dmlc {
|
||||||
|
InputSplit* InputSplit::Create(const char *uri,
|
||||||
|
unsigned part,
|
||||||
|
unsigned nsplit) {
|
||||||
|
using namespace xgboost;
|
||||||
|
const char *msg = "xgboost is compiled in local mode\n"\
|
||||||
|
"to use hdfs, s3 or distributed version, compile with make dmlc=1";
|
||||||
|
utils::Check(strncmp(uri, "s3://", 5) != 0, msg);
|
||||||
|
utils::Check(strncmp(uri, "hdfs://", 7) != 0, msg);
|
||||||
|
utils::Check(nsplit == 1, msg);
|
||||||
|
return new utils::SingleFileSplit(uri);
|
||||||
|
}
|
||||||
|
|
||||||
|
IStream *IStream::Create(const char *uri, const char * const flag) {
|
||||||
|
using namespace xgboost;
|
||||||
|
const char *msg = "xgboost is compiled in local mode\n"\
|
||||||
|
"to use hdfs, s3 or distributed version, compile with make dmlc=1";
|
||||||
|
utils::Check(strncmp(uri, "s3://", 5) != 0, msg);
|
||||||
|
utils::Check(strncmp(uri, "hdfs://", 7) != 0, msg);
|
||||||
|
return new utils::StdFile(uri, flag);
|
||||||
|
}
|
||||||
|
} // namespace dmlc
|
||||||
|
|
||||||
@ -16,7 +16,10 @@ namespace xgboost {
|
|||||||
namespace io {
|
namespace io {
|
||||||
DataMatrix* LoadDataMatrix(const char *fname, bool silent,
|
DataMatrix* LoadDataMatrix(const char *fname, bool silent,
|
||||||
bool savebuffer, bool loadsplit) {
|
bool savebuffer, bool loadsplit) {
|
||||||
if (!std::strcmp(fname, "stdin") || loadsplit) {
|
if (!std::strcmp(fname, "stdin") ||
|
||||||
|
!std::strncmp(fname, "s3://", 5) ||
|
||||||
|
!std::strncmp(fname, "hdfs://", 7) ||
|
||||||
|
loadsplit) {
|
||||||
DMatrixSimple *dmat = new DMatrixSimple();
|
DMatrixSimple *dmat = new DMatrixSimple();
|
||||||
dmat->LoadText(fname, silent, loadsplit);
|
dmat->LoadText(fname, silent, loadsplit);
|
||||||
return dmat;
|
return dmat;
|
||||||
|
|||||||
@ -90,11 +90,11 @@ class DMatrixSimple : public DataMatrix {
|
|||||||
rank = rabit::GetRank();
|
rank = rabit::GetRank();
|
||||||
npart = rabit::GetWorldSize();
|
npart = rabit::GetWorldSize();
|
||||||
}
|
}
|
||||||
rabit::io::InputSplit *in =
|
dmlc::InputSplit *in =
|
||||||
rabit::io::CreateInputSplit(uri, rank, npart);
|
dmlc::InputSplit::Create(uri, rank, npart);
|
||||||
this->Clear();
|
this->Clear();
|
||||||
std::string line;
|
std::string line;
|
||||||
while (in->NextLine(&line)) {
|
while (in->ReadLine(&line)) {
|
||||||
float label;
|
float label;
|
||||||
std::istringstream ss(line);
|
std::istringstream ss(line);
|
||||||
std::vector<RowBatch::Entry> feats;
|
std::vector<RowBatch::Entry> feats;
|
||||||
|
|||||||
@ -193,7 +193,7 @@ class BoostLearner : public rabit::ISerializable {
|
|||||||
* \param fname file name
|
* \param fname file name
|
||||||
*/
|
*/
|
||||||
inline void LoadModel(const char *fname) {
|
inline void LoadModel(const char *fname) {
|
||||||
utils::IStream *fi = rabit::io::CreateStream(fname, "r");
|
utils::IStream *fi = utils::IStream::Create(fname, "r");
|
||||||
std::string header; header.resize(4);
|
std::string header; header.resize(4);
|
||||||
// check header for different binary encode
|
// check header for different binary encode
|
||||||
// can be base64 or binary
|
// can be base64 or binary
|
||||||
@ -207,7 +207,7 @@ class BoostLearner : public rabit::ISerializable {
|
|||||||
this->LoadModel(*fi);
|
this->LoadModel(*fi);
|
||||||
} else {
|
} else {
|
||||||
delete fi;
|
delete fi;
|
||||||
fi = rabit::io::CreateStream(fname, "r");
|
fi = utils::IStream::Create(fname, "r");
|
||||||
this->LoadModel(*fi);
|
this->LoadModel(*fi);
|
||||||
}
|
}
|
||||||
delete fi;
|
delete fi;
|
||||||
@ -224,7 +224,7 @@ class BoostLearner : public rabit::ISerializable {
|
|||||||
* \param save_base64 whether save in base64 format
|
* \param save_base64 whether save in base64 format
|
||||||
*/
|
*/
|
||||||
inline void SaveModel(const char *fname, bool save_base64 = false) const {
|
inline void SaveModel(const char *fname, bool save_base64 = false) const {
|
||||||
utils::IStream *fo = rabit::io::CreateStream(fname, "w");
|
utils::IStream *fo = utils::IStream::Create(fname, "w");
|
||||||
if (save_base64 != 0 || !strcmp(fname, "stdout")) {
|
if (save_base64 != 0 || !strcmp(fname, "stdout")) {
|
||||||
fo->Write("bs64\t", 5);
|
fo->Write("bs64\t", 5);
|
||||||
utils::Base64OutStream bout(fo);
|
utils::Base64OutStream bout(fo);
|
||||||
|
|||||||
@ -7,7 +7,6 @@
|
|||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
*/
|
*/
|
||||||
#include "../../subtree/rabit/include/rabit.h"
|
#include "../../subtree/rabit/include/rabit.h"
|
||||||
#include "../../subtree/rabit/rabit-learn/io/io.h"
|
|
||||||
#endif // XGBOOST_SYNC_H_
|
#endif // XGBOOST_SYNC_H_
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
264
src/utils/base64-inl.h
Normal file
264
src/utils/base64-inl.h
Normal file
@ -0,0 +1,264 @@
|
|||||||
|
#ifndef XGBOOST_UTILS_BASE64_INL_H_
|
||||||
|
#define XGBOOST_UTILS_BASE64_INL_H_
|
||||||
|
/*!
|
||||||
|
* \file base64.h
|
||||||
|
* \brief data stream support to input and output from/to base64 stream
|
||||||
|
* base64 is easier to store and pass as text format in mapreduce
|
||||||
|
* \author Tianqi Chen
|
||||||
|
*/
|
||||||
|
#include <cctype>
|
||||||
|
#include <cstdio>
|
||||||
|
#include "./io.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace utils {
|
||||||
|
/*! \brief buffer reader of the stream that allows you to get */
|
||||||
|
class StreamBufferReader {
|
||||||
|
public:
|
||||||
|
StreamBufferReader(size_t buffer_size)
|
||||||
|
:stream_(NULL),
|
||||||
|
read_len_(1), read_ptr_(1) {
|
||||||
|
buffer_.resize(buffer_size);
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief set input stream
|
||||||
|
*/
|
||||||
|
inline void set_stream(IStream *stream) {
|
||||||
|
stream_ = stream;
|
||||||
|
read_len_ = read_ptr_ = 1;
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief allows quick read using get char
|
||||||
|
*/
|
||||||
|
inline char GetChar(void) {
|
||||||
|
while (true) {
|
||||||
|
if (read_ptr_ < read_len_) {
|
||||||
|
return buffer_[read_ptr_++];
|
||||||
|
} else {
|
||||||
|
read_len_ = stream_->Read(&buffer_[0], buffer_.length());
|
||||||
|
if (read_len_ == 0) return EOF;
|
||||||
|
read_ptr_ = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/*! \brief whether we are reaching the end of file */
|
||||||
|
inline bool AtEnd(void) const {
|
||||||
|
return read_len_ == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/*! \brief the underlying stream */
|
||||||
|
IStream *stream_;
|
||||||
|
/*! \brief buffer to hold data */
|
||||||
|
std::string buffer_;
|
||||||
|
/*! \brief length of valid data in buffer */
|
||||||
|
size_t read_len_;
|
||||||
|
/*! \brief pointer in the buffer */
|
||||||
|
size_t read_ptr_;
|
||||||
|
};
|
||||||
|
|
||||||
|
/*! \brief namespace of base64 decoding and encoding table */
|
||||||
|
namespace base64 {
|
||||||
|
const char DecodeTable[] = {
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
62, // '+'
|
||||||
|
0, 0, 0,
|
||||||
|
63, // '/'
|
||||||
|
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9'
|
||||||
|
0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
|
||||||
|
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z'
|
||||||
|
0, 0, 0, 0, 0, 0,
|
||||||
|
26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
|
||||||
|
39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z'
|
||||||
|
};
|
||||||
|
static const char EncodeTable[] =
|
||||||
|
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
|
||||||
|
} // namespace base64
|
||||||
|
/*! \brief the stream that reads from base64, note we take from file pointers */
|
||||||
|
class Base64InStream: public IStream {
|
||||||
|
public:
|
||||||
|
explicit Base64InStream(IStream *fs) : reader_(256) {
|
||||||
|
reader_.set_stream(fs);
|
||||||
|
num_prev = 0; tmp_ch = 0;
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief initialize the stream position to beginning of next base64 stream
|
||||||
|
* call this function before actually start read
|
||||||
|
*/
|
||||||
|
inline void InitPosition(void) {
|
||||||
|
// get a charater
|
||||||
|
do {
|
||||||
|
tmp_ch = reader_.GetChar();
|
||||||
|
} while (isspace(tmp_ch));
|
||||||
|
}
|
||||||
|
/*! \brief whether current position is end of a base64 stream */
|
||||||
|
inline bool IsEOF(void) const {
|
||||||
|
return num_prev == 0 && (tmp_ch == EOF || isspace(tmp_ch));
|
||||||
|
}
|
||||||
|
virtual size_t Read(void *ptr, size_t size) {
|
||||||
|
using base64::DecodeTable;
|
||||||
|
if (size == 0) return 0;
|
||||||
|
// use tlen to record left size
|
||||||
|
size_t tlen = size;
|
||||||
|
unsigned char *cptr = static_cast<unsigned char*>(ptr);
|
||||||
|
// if anything left, load from previous buffered result
|
||||||
|
if (num_prev != 0) {
|
||||||
|
if (num_prev == 2) {
|
||||||
|
if (tlen >= 2) {
|
||||||
|
*cptr++ = buf_prev[0];
|
||||||
|
*cptr++ = buf_prev[1];
|
||||||
|
tlen -= 2;
|
||||||
|
num_prev = 0;
|
||||||
|
} else {
|
||||||
|
// assert tlen == 1
|
||||||
|
*cptr++ = buf_prev[0]; --tlen;
|
||||||
|
buf_prev[0] = buf_prev[1];
|
||||||
|
num_prev = 1;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// assert num_prev == 1
|
||||||
|
*cptr++ = buf_prev[0]; --tlen; num_prev = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (tlen == 0) return size;
|
||||||
|
int nvalue;
|
||||||
|
// note: everything goes with 4 bytes in Base64
|
||||||
|
// so we process 4 bytes a unit
|
||||||
|
while (tlen && tmp_ch != EOF && !isspace(tmp_ch)) {
|
||||||
|
// first byte
|
||||||
|
nvalue = DecodeTable[tmp_ch] << 18;
|
||||||
|
{
|
||||||
|
// second byte
|
||||||
|
utils::Check((tmp_ch = reader_.GetChar(), tmp_ch != EOF && !isspace(tmp_ch)),
|
||||||
|
"invalid base64 format");
|
||||||
|
nvalue |= DecodeTable[tmp_ch] << 12;
|
||||||
|
*cptr++ = (nvalue >> 16) & 0xFF; --tlen;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// third byte
|
||||||
|
utils::Check((tmp_ch = reader_.GetChar(), tmp_ch != EOF && !isspace(tmp_ch)),
|
||||||
|
"invalid base64 format");
|
||||||
|
// handle termination
|
||||||
|
if (tmp_ch == '=') {
|
||||||
|
utils::Check((tmp_ch = reader_.GetChar(), tmp_ch == '='), "invalid base64 format");
|
||||||
|
utils::Check((tmp_ch = reader_.GetChar(), tmp_ch == EOF || isspace(tmp_ch)),
|
||||||
|
"invalid base64 format");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
nvalue |= DecodeTable[tmp_ch] << 6;
|
||||||
|
if (tlen) {
|
||||||
|
*cptr++ = (nvalue >> 8) & 0xFF; --tlen;
|
||||||
|
} else {
|
||||||
|
buf_prev[num_prev++] = (nvalue >> 8) & 0xFF;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// fourth byte
|
||||||
|
utils::Check((tmp_ch = reader_.GetChar(), tmp_ch != EOF && !isspace(tmp_ch)),
|
||||||
|
"invalid base64 format");
|
||||||
|
if (tmp_ch == '=') {
|
||||||
|
utils::Check((tmp_ch = reader_.GetChar(), tmp_ch == EOF || isspace(tmp_ch)),
|
||||||
|
"invalid base64 format");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
nvalue |= DecodeTable[tmp_ch];
|
||||||
|
if (tlen) {
|
||||||
|
*cptr++ = nvalue & 0xFF; --tlen;
|
||||||
|
} else {
|
||||||
|
buf_prev[num_prev ++] = nvalue & 0xFF;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// get next char
|
||||||
|
tmp_ch = reader_.GetChar();
|
||||||
|
}
|
||||||
|
if (kStrictCheck) {
|
||||||
|
utils::Check(tlen == 0, "Base64InStream: read incomplete");
|
||||||
|
}
|
||||||
|
return size - tlen;
|
||||||
|
}
|
||||||
|
virtual void Write(const void *ptr, size_t size) {
|
||||||
|
utils::Error("Base64InStream do not support write");
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
StreamBufferReader reader_;
|
||||||
|
int tmp_ch;
|
||||||
|
int num_prev;
|
||||||
|
unsigned char buf_prev[2];
|
||||||
|
// whether we need to do strict check
|
||||||
|
static const bool kStrictCheck = false;
|
||||||
|
};
|
||||||
|
/*! \brief the stream that write to base64, note we take from file pointers */
|
||||||
|
class Base64OutStream: public IStream {
|
||||||
|
public:
|
||||||
|
explicit Base64OutStream(IStream *fp) : fp(fp) {
|
||||||
|
buf_top = 0;
|
||||||
|
}
|
||||||
|
virtual void Write(const void *ptr, size_t size) {
|
||||||
|
using base64::EncodeTable;
|
||||||
|
size_t tlen = size;
|
||||||
|
const unsigned char *cptr = static_cast<const unsigned char*>(ptr);
|
||||||
|
while (tlen) {
|
||||||
|
while (buf_top < 3 && tlen != 0) {
|
||||||
|
buf[++buf_top] = *cptr++; --tlen;
|
||||||
|
}
|
||||||
|
if (buf_top == 3) {
|
||||||
|
// flush 4 bytes out
|
||||||
|
PutChar(EncodeTable[buf[1] >> 2]);
|
||||||
|
PutChar(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F]);
|
||||||
|
PutChar(EncodeTable[((buf[2] << 2) | (buf[3] >> 6)) & 0x3F]);
|
||||||
|
PutChar(EncodeTable[buf[3] & 0x3F]);
|
||||||
|
buf_top = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
virtual size_t Read(void *ptr, size_t size) {
|
||||||
|
utils::Error("Base64OutStream do not support read");
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief finish writing of all current base64 stream, do some post processing
|
||||||
|
* \param endch charater to put to end of stream, if it is EOF, then nothing will be done
|
||||||
|
*/
|
||||||
|
inline void Finish(char endch = EOF) {
|
||||||
|
using base64::EncodeTable;
|
||||||
|
if (buf_top == 1) {
|
||||||
|
PutChar(EncodeTable[buf[1] >> 2]);
|
||||||
|
PutChar(EncodeTable[(buf[1] << 4) & 0x3F]);
|
||||||
|
PutChar('=');
|
||||||
|
PutChar('=');
|
||||||
|
}
|
||||||
|
if (buf_top == 2) {
|
||||||
|
PutChar(EncodeTable[buf[1] >> 2]);
|
||||||
|
PutChar(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F]);
|
||||||
|
PutChar(EncodeTable[(buf[2] << 2) & 0x3F]);
|
||||||
|
PutChar('=');
|
||||||
|
}
|
||||||
|
buf_top = 0;
|
||||||
|
if (endch != EOF) PutChar(endch);
|
||||||
|
this->Flush();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
IStream *fp;
|
||||||
|
int buf_top;
|
||||||
|
unsigned char buf[4];
|
||||||
|
std::string out_buf;
|
||||||
|
const static size_t kBufferSize = 256;
|
||||||
|
|
||||||
|
inline void PutChar(char ch) {
|
||||||
|
out_buf += ch;
|
||||||
|
if (out_buf.length() >= kBufferSize) Flush();
|
||||||
|
}
|
||||||
|
inline void Flush(void) {
|
||||||
|
if (out_buf.length() != 0) {
|
||||||
|
fp->Write(&out_buf[0], out_buf.length());
|
||||||
|
out_buf.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace utils
|
||||||
|
} // namespace rabit
|
||||||
|
#endif // RABIT_LEARN_UTILS_BASE64_INL_H_
|
||||||
@ -18,8 +18,6 @@ typedef rabit::IStream IStream;
|
|||||||
typedef rabit::utils::ISeekStream ISeekStream;
|
typedef rabit::utils::ISeekStream ISeekStream;
|
||||||
typedef rabit::utils::MemoryFixSizeBuffer MemoryFixSizeBuffer;
|
typedef rabit::utils::MemoryFixSizeBuffer MemoryFixSizeBuffer;
|
||||||
typedef rabit::utils::MemoryBufferStream MemoryBufferStream;
|
typedef rabit::utils::MemoryBufferStream MemoryBufferStream;
|
||||||
typedef rabit::io::Base64InStream Base64InStream;
|
|
||||||
typedef rabit::io::Base64OutStream Base64OutStream;
|
|
||||||
|
|
||||||
/*! \brief implementation of file i/o stream */
|
/*! \brief implementation of file i/o stream */
|
||||||
class FileStream : public ISeekStream {
|
class FileStream : public ISeekStream {
|
||||||
@ -54,4 +52,6 @@ class FileStream : public ISeekStream {
|
|||||||
};
|
};
|
||||||
} // namespace utils
|
} // namespace utils
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|
||||||
|
#include "./base64-inl.h"
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user