refactor io, initial hdfs file access need test
This commit is contained in:
parent
19be870562
commit
88ce76767e
@ -1,5 +1,5 @@
|
||||
#ifndef RABIT_LEARN_IO_BASE64_H_
|
||||
#define RABIT_LEARN_IO_BASE64_H_
|
||||
#ifndef RABIT_LEARN_IO_BASE64_INL_H_
|
||||
#define RABIT_LEARN_IO_BASE64_INL_H_
|
||||
/*!
|
||||
* \file base64.h
|
||||
* \brief data stream support to input and output from/to base64 stream
|
||||
@ -9,7 +9,7 @@
|
||||
#include <cctype>
|
||||
#include <cstdio>
|
||||
#include "./io.h"
|
||||
#include "./utils.h"
|
||||
#include "./buffer_reader-inl.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace io {
|
||||
@ -215,4 +215,4 @@ class Base64OutStream: public IStream {
|
||||
};
|
||||
} // namespace utils
|
||||
} // namespace rabit
|
||||
#endif // RABIT_LEARN_UTILS_BASE64_H_
|
||||
#endif // RABIT_LEARN_UTILS_BASE64_INL_H_
|
||||
57
rabit-learn/io/buffer_reader-inl.h
Normal file
57
rabit-learn/io/buffer_reader-inl.h
Normal file
@ -0,0 +1,57 @@
|
||||
#ifndef RABIT_LEARN_IO_BUFFER_READER_INL_H_
|
||||
#define RABIT_LEARN_IO_BUFFER_READER_INL_H_
|
||||
/*!
|
||||
* \file buffer_reader-inl.h
|
||||
* \brief implementation of stream buffer reader
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include "./io.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace io {
|
||||
/*! \brief buffer reader of the stream that allows you to get */
|
||||
class StreamBufferReader {
|
||||
public:
|
||||
StreamBufferReader(size_t buffer_size)
|
||||
:stream_(NULL),
|
||||
read_len_(1), read_ptr_(1) {
|
||||
buffer_.resize(buffer_size);
|
||||
}
|
||||
/*!
|
||||
* \brief set input stream
|
||||
*/
|
||||
inline void set_stream(IStream *stream) {
|
||||
stream_ = stream;
|
||||
read_len_ = read_ptr_ = 1;
|
||||
}
|
||||
/*!
|
||||
* \brief allows quick read using get char
|
||||
*/
|
||||
inline char GetChar(void) {
|
||||
while (true) {
|
||||
if (read_ptr_ < read_len_) {
|
||||
return buffer_[read_ptr_++];
|
||||
} else {
|
||||
read_len_ = stream_->Read(&buffer_[0], buffer_.length());
|
||||
if (read_len_ == 0) return EOF;
|
||||
read_ptr_ = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
inline bool AtEnd(void) const {
|
||||
return read_len_ == 0;
|
||||
}
|
||||
|
||||
private:
|
||||
/*! \brief the underlying stream */
|
||||
IStream *stream_;
|
||||
/*! \brief buffer to hold data */
|
||||
std::string buffer_;
|
||||
/*! \brief length of valid data in buffer */
|
||||
size_t read_len_;
|
||||
/*! \brief pointer in the buffer */
|
||||
size_t read_ptr_;
|
||||
};
|
||||
} // namespace io
|
||||
} // namespace rabit
|
||||
#endif // RABIT_LEARN_IO_BUFFER_READER_INL_H_
|
||||
104
rabit-learn/io/file-inl.h
Normal file
104
rabit-learn/io/file-inl.h
Normal file
@ -0,0 +1,104 @@
|
||||
#ifndef RABIT_LEARN_IO_FILE_INL_H_
|
||||
#define RABIT_LEARN_IO_FILE_INL_H_
|
||||
/*!
|
||||
* \file file-inl.h
|
||||
* \brief normal filesystem I/O
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <cstdio>
|
||||
#include "./io.h"
|
||||
#include "./line_split-inl.h"
|
||||
|
||||
/*! \brief io interface */
|
||||
namespace rabit {
|
||||
namespace io {
|
||||
/*! \brief implementation of file i/o stream */
|
||||
class FileStream : public utils::ISeekStream {
|
||||
public:
|
||||
explicit FileStream(const char *fname, const char *mode)
|
||||
: use_stdio(false) {
|
||||
#ifndef RABIT_STRICT_CXX98_
|
||||
if (!strcmp(fname, "stdin")) {
|
||||
use_stdio = true; fp = stdin;
|
||||
}
|
||||
if (!strcmp(fname, "stdout")) {
|
||||
use_stdio = true; fp = stdout;
|
||||
}
|
||||
#endif
|
||||
if (!strncmp(fname, "file://", 7)) fname += 7;
|
||||
if (!use_stdio) {
|
||||
std::string flag = mode;
|
||||
if (flag == "w") flag = "wb";
|
||||
if (flag == "r") flag = "rb";
|
||||
fp = utils::FopenCheck(fname, flag.c_str());
|
||||
}
|
||||
}
|
||||
virtual ~FileStream(void) {
|
||||
this->Close();
|
||||
}
|
||||
virtual size_t Read(void *ptr, size_t size) {
|
||||
return std::fread(ptr, 1, size, fp);
|
||||
}
|
||||
virtual void Write(const void *ptr, size_t size) {
|
||||
std::fwrite(ptr, size, 1, fp);
|
||||
}
|
||||
virtual void Seek(size_t pos) {
|
||||
std::fseek(fp, static_cast<long>(pos), SEEK_SET);
|
||||
}
|
||||
virtual size_t Tell(void) {
|
||||
return std::ftell(fp);
|
||||
}
|
||||
virtual bool AtEnd(void) const {
|
||||
return feof(fp) != 0;
|
||||
}
|
||||
inline void Close(void) {
|
||||
if (fp != NULL && !use_stdio) {
|
||||
std::fclose(fp); fp = NULL;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
FILE *fp;
|
||||
bool use_stdio;
|
||||
};
|
||||
|
||||
/*! \brief line split from normal file system */
|
||||
class FileSplit : public LineSplitBase {
|
||||
public:
|
||||
explicit FileSplit(const char *uri, unsigned rank, unsigned nsplit) {
|
||||
LineSplitBase::SplitNames(&fnames_, uri, "#");
|
||||
std::vector<size_t> fsize;
|
||||
for (size_t i = 0; i < fnames_.size(); ++i) {
|
||||
if (!strncmp(fnames_[i].c_str(), "file://", 7)) {
|
||||
std::string tmp = fnames_[i].c_str() + 7;
|
||||
fnames_[i] = tmp;
|
||||
}
|
||||
fsize.push_back(GetFileSize(fnames_[i].c_str()));
|
||||
}
|
||||
LineSplitBase::Init(fsize, rank, nsplit);
|
||||
}
|
||||
virtual ~FileSplit(void) {}
|
||||
|
||||
protected:
|
||||
virtual utils::ISeekStream *GetFile(size_t file_index) {
|
||||
utils::Assert(file_index < fnames_.size(), "file index exceed bound");
|
||||
return new FileStream(fnames_[file_index].c_str(), "rb");
|
||||
}
|
||||
// get file size
|
||||
inline static size_t GetFileSize(const char *fname) {
|
||||
FILE *fp = utils::FopenCheck(fname, "rb");
|
||||
// NOTE: fseek may not be good, but serves as ok solution
|
||||
fseek(fp, 0, SEEK_END);
|
||||
size_t fsize = static_cast<size_t>(ftell(fp));
|
||||
fclose(fp);
|
||||
return fsize;
|
||||
}
|
||||
|
||||
private:
|
||||
// file names
|
||||
std::vector<std::string> fnames_;
|
||||
};
|
||||
} // namespace io
|
||||
} // namespace rabit
|
||||
#endif // RABIT_LEARN_IO_FILE_INL_H_
|
||||
|
||||
131
rabit-learn/io/hdfs-inl.h
Normal file
131
rabit-learn/io/hdfs-inl.h
Normal file
@ -0,0 +1,131 @@
|
||||
#ifndef RABIT_LEARN_IO_HDFS_INL_H_
|
||||
#define RABIT_LEARN_IO_HDFS_INL_H_
|
||||
/*!
|
||||
* \file hdfs-inl.h
|
||||
* \brief HDFS I/O
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <hdfs.h>
|
||||
#include <errno.h>
|
||||
#include "./io.h"
|
||||
#include "./line_split-inl.h"
|
||||
|
||||
/*! \brief io interface */
|
||||
namespace rabit {
|
||||
namespace io {
|
||||
class HDFSStream : public utils::ISeekStream {
|
||||
public:
|
||||
HDFSStream(hdfsFS fs, const char *fname, const char *mode)
|
||||
: fs_(fs) {
|
||||
int flag;
|
||||
if (!strcmp(mode, "r")) {
|
||||
flag = O_RDONLY;
|
||||
} else if (!strcmp(mode, "w")) {
|
||||
flag = O_WDONLY;
|
||||
} else if (!strcmp(mode, "a")) {
|
||||
flag = O_WDONLY | O_APPEND;
|
||||
} else {
|
||||
utils::Error("HDFSStream: unknown flag %s", mode);
|
||||
}
|
||||
fp_ = hdfsOpenFile(fs_, fname, flag, 0, 0, 0);
|
||||
utils::Check(fp_ != NULL,
|
||||
"HDFSStream: fail to open %s", fname);
|
||||
}
|
||||
virtual ~FileStream(void) {
|
||||
this->Close();
|
||||
}
|
||||
virtual size_t Read(void *ptr, size_t size) {
|
||||
tSize nread = hdfsRead(fs_, fp_, ptr, size);
|
||||
if (nread == -1) {
|
||||
int errsv = errno;
|
||||
utils::Error("HDFSStream.Read Error:%s", strerror(errsv));
|
||||
}
|
||||
return static_cast<size_t>(nread);
|
||||
}
|
||||
virtual void Write(const void *ptr, size_t size) {
|
||||
const char *buf = reinterpret_cast<const char*>(ptr);
|
||||
while (size != 0) {
|
||||
tSize nwrite = hdfsWrite(fs_, fp_, buf, size);
|
||||
if (nwrite == -1) {
|
||||
int errsv = errno;
|
||||
utils::Error("HDFSStream.Write Error:%s", strerror(errsv));
|
||||
}
|
||||
size_t sz = static_cast<size_t>(nwrite);
|
||||
buf += sz; size -= sz;
|
||||
}
|
||||
}
|
||||
virtual void Seek(size_t pos) {
|
||||
if (hdfsSeek(fs_, fp_, pos) != 0) {
|
||||
int errsv = errno;
|
||||
utils::Error("HDFSStream.Seek Error:%s", strerror(errsv));
|
||||
}
|
||||
}
|
||||
virtual size_t Tell(void) {
|
||||
tOffset offset = hdfsTell(fs_, fp_);
|
||||
if (offset == -1) {
|
||||
int errsv = errno;
|
||||
utils::Error("HDFSStream.Tell Error:%s", strerror(errsv));
|
||||
}
|
||||
return static_cast<size_t>(offset);
|
||||
}
|
||||
inline void Close(void) {
|
||||
if (fp != NULL) {
|
||||
if (hdfsCloseFile(fs_, fp_) == 0) {
|
||||
int errsv = errno;
|
||||
utils::Error("HDFSStream.Close Error:%s", strerror(errsv));
|
||||
}
|
||||
fp_ = NULL;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
hdfsFS fs_;
|
||||
hdfsFile fp_;
|
||||
};
|
||||
|
||||
/*! \brief line split from normal file system */
|
||||
class HDFSSplit : public LineSplitBase {
|
||||
public:
|
||||
explicit FileSplit(const char *uri, unsigned rank, unsigned nsplit) {
|
||||
fs_ = hdfsConnect("default", 0);
|
||||
std::string paths;
|
||||
LineSplitBase::SplitNames(&paths, uri, "#");
|
||||
// get the files
|
||||
std::vector<size_t> fsize;
|
||||
for (size_t i = 0; i < paths.size(); ++i) {
|
||||
hdfsFileInfo *info = hdfsGetPathInfo(fs_, paths[i].c_str());
|
||||
if (info->mKind == 'D') {
|
||||
int nentry;
|
||||
hdfsFileInfo *files = hdfsListDirectory(fs_, info->mName, &nentry);
|
||||
for (int i = 0; i < nentry; ++i) {
|
||||
if (files[i].mKind == 'F') {
|
||||
fsize.push_back(files[i].mSize);
|
||||
fnames_.push_back(std::string(files[i].mName));
|
||||
}
|
||||
}
|
||||
hdfsFileInfo(files, nentry);
|
||||
} else {
|
||||
fsize.push_back(info->mSize);
|
||||
fnames_.push_back(std::string(info->mName));
|
||||
}
|
||||
hdfsFileInfo(info, 1);
|
||||
}
|
||||
LineSplitBase::Init(fsize, rank, nsplit);
|
||||
}
|
||||
virtual ~FileSplit(void) {}
|
||||
|
||||
protected:
|
||||
virtual utils::ISeekStream *GetFile(size_t file_index) {
|
||||
utils::Assert(file_index < fnames_.size(), "file index exceed bound");
|
||||
return new HDFSStream(fs_, fnames_[file_index].c_str(), "r");
|
||||
}
|
||||
|
||||
private:
|
||||
// hdfs handle
|
||||
hdfsFS fs_;
|
||||
// file names
|
||||
std::vector<std::string> fnames_;
|
||||
};
|
||||
} // namespace io
|
||||
} // namespace rabit
|
||||
#endif // RABIT_LEARN_IO_HDFS_INL_H_
|
||||
@ -7,7 +7,13 @@
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <cstring>
|
||||
#include "./line_split.h"
|
||||
|
||||
#include "./io.h"
|
||||
#if RABIT_USE_HDFS
|
||||
#include "./hdfs-inl.h"
|
||||
#endif
|
||||
#include "./file-inl.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace io {
|
||||
/*!
|
||||
@ -26,11 +32,34 @@ inline InputSplit *CreateInputSplit(const char *uri,
|
||||
return new FileSplit(uri, part, nsplit);
|
||||
}
|
||||
if (!strncmp(uri, "hdfs://", 7)) {
|
||||
utils::Error("HDFS reading is not yet supported");
|
||||
return NULL;
|
||||
#if RABIT_USE_HDFS
|
||||
return new HDFSSplit(uri, part, nsplit);
|
||||
#else
|
||||
utils::Error("Please compile with RABIT_USE_HDFS=1");
|
||||
#endif
|
||||
}
|
||||
return new FileSplit(uri, part, nsplit);
|
||||
}
|
||||
/*!
|
||||
* \brief create an stream, the stream must be able to close
|
||||
* the underlying resources(files) when deleted
|
||||
*
|
||||
* \param uri the uri of the input, can contain hdfs prefix
|
||||
* \param mode can be 'w' or 'r' for read or write
|
||||
*/
|
||||
inline IStream *CreateStream(const char *uri, const char *mode) {
|
||||
if (!strncmp(uri, "file://", 7)) {
|
||||
return new FileStream(uri + 7, mode);
|
||||
}
|
||||
if (!strncmp(uri, "hdfs://", 7)) {
|
||||
#if RABIT_USE_HDFS
|
||||
return new HDFSStream(hdfsConnect("default", 0), uri, mode);
|
||||
#else
|
||||
utils::Error("Please compile with RABIT_USE_HDFS=1");
|
||||
#endif
|
||||
}
|
||||
return new FileStream(uri, mode);
|
||||
}
|
||||
} // namespace io
|
||||
} // namespace rabit
|
||||
#endif // RABIT_LEARN_IO_IO_INL_H_
|
||||
|
||||
@ -7,12 +7,19 @@
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include "../../include/rabit_serializable.h"
|
||||
|
||||
/*! \brief whether compile with HDFS support */
|
||||
#ifndef RABIT_USE_HDFS
|
||||
#define RABIT_USE_HDFS 0
|
||||
#endif
|
||||
|
||||
/*! \brief io interface */
|
||||
namespace rabit {
|
||||
/*!
|
||||
* \brief namespace to handle input split and filesystem interfacing
|
||||
*/
|
||||
namespace io {
|
||||
typedef utils::ISeekStream ISeekStream;
|
||||
/*!
|
||||
* \brief user facing input split helper,
|
||||
* can be used to get the partition of data used by current node
|
||||
@ -50,4 +57,5 @@ inline IStream *CreateStream(const char *uri, const char *mode);
|
||||
} // namespace rabit
|
||||
|
||||
#include "./io-inl.h"
|
||||
#include "./base64-inl.h"
|
||||
#endif // RABIT_LEARN_IO_IO_H_
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
#ifndef RABIT_LEARN_IO_LINE_SPLIT_H_
|
||||
#define RABIT_LEARN_IO_LINE_SPLIT_H_
|
||||
#ifndef RABIT_LEARN_IO_LINE_SPLIT_INL_H_
|
||||
#define RABIT_LEARN_IO_LINE_SPLIT_INL_H_
|
||||
/*!
|
||||
* \file line_split.h
|
||||
* \file line_split-inl.h
|
||||
* \brief base implementation of line-spliter
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
@ -12,7 +12,7 @@
|
||||
#include <fstream>
|
||||
#include "../../include/rabit.h"
|
||||
#include "./io.h"
|
||||
#include "./utils.h"
|
||||
#include "./buffer_reader-inl.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace io {
|
||||
@ -175,43 +175,6 @@ class SingleFileSplit : public InputSplit {
|
||||
bool use_stdin_;
|
||||
bool end_of_file_;
|
||||
};
|
||||
|
||||
/*! \brief line split from normal file system */
|
||||
class FileSplit : public LineSplitBase {
|
||||
public:
|
||||
explicit FileSplit(const char *uri, unsigned rank, unsigned nsplit) {
|
||||
LineSplitBase::SplitNames(&fnames_, uri, "#");
|
||||
std::vector<size_t> fsize;
|
||||
for (size_t i = 0; i < fnames_.size(); ++i) {
|
||||
if (!strncmp(fnames_[i].c_str(), "file://", 7)) {
|
||||
std::string tmp = fnames_[i].c_str() + 7;
|
||||
fnames_[i] = tmp;
|
||||
}
|
||||
fsize.push_back(GetFileSize(fnames_[i].c_str()));
|
||||
}
|
||||
LineSplitBase::Init(fsize, rank, nsplit);
|
||||
}
|
||||
virtual ~FileSplit(void) {}
|
||||
|
||||
protected:
|
||||
virtual utils::ISeekStream *GetFile(size_t file_index) {
|
||||
utils::Assert(file_index < fnames_.size(), "file index exceed bound");
|
||||
return new FileStream(fnames_[file_index].c_str(), "rb");
|
||||
}
|
||||
// get file size
|
||||
inline static size_t GetFileSize(const char *fname) {
|
||||
FILE *fp = utils::FopenCheck(fname, "rb");
|
||||
// NOTE: fseek may not be good, but serves as ok solution
|
||||
fseek(fp, 0, SEEK_END);
|
||||
size_t fsize = static_cast<size_t>(ftell(fp));
|
||||
fclose(fp);
|
||||
return fsize;
|
||||
}
|
||||
|
||||
private:
|
||||
// file names
|
||||
std::vector<std::string> fnames_;
|
||||
};
|
||||
} // namespace io
|
||||
} // namespace rabit
|
||||
#endif // RABIT_LEARN_IO_LINE_SPLIT_H_
|
||||
#endif // RABIT_LEARN_IO_LINE_SPLIT_INL_H_
|
||||
@ -1,102 +0,0 @@
|
||||
#ifndef RABIT_LEARN_IO_UTILS_H_
|
||||
#define RABIT_LEARN_IO_UTILS_H_
|
||||
/*!
|
||||
* \file utils.h
|
||||
* \brief some helper utils that can be used to implement IO
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
namespace rabit {
|
||||
namespace io {
|
||||
/*! \brief buffer reader of the stream that allows you to get */
|
||||
class StreamBufferReader {
|
||||
public:
|
||||
StreamBufferReader(size_t buffer_size)
|
||||
:stream_(NULL),
|
||||
read_len_(1), read_ptr_(1) {
|
||||
buffer_.resize(buffer_size);
|
||||
}
|
||||
/*!
|
||||
* \brief set input stream
|
||||
*/
|
||||
inline void set_stream(IStream *stream) {
|
||||
stream_ = stream;
|
||||
read_len_ = read_ptr_ = 1;
|
||||
}
|
||||
/*!
|
||||
* \brief allows quick read using get char
|
||||
*/
|
||||
inline char GetChar(void) {
|
||||
while (true) {
|
||||
if (read_ptr_ < read_len_) {
|
||||
return buffer_[read_ptr_++];
|
||||
} else {
|
||||
read_len_ = stream_->Read(&buffer_[0], buffer_.length());
|
||||
if (read_len_ == 0) return EOF;
|
||||
read_ptr_ = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
inline bool AtEnd(void) const {
|
||||
return read_len_ == 0;
|
||||
}
|
||||
|
||||
private:
|
||||
/*! \brief the underlying stream */
|
||||
IStream *stream_;
|
||||
/*! \brief buffer to hold data */
|
||||
std::string buffer_;
|
||||
/*! \brief length of valid data in buffer */
|
||||
size_t read_len_;
|
||||
/*! \brief pointer in the buffer */
|
||||
size_t read_ptr_;
|
||||
};
|
||||
|
||||
/*! \brief implementation of file i/o stream */
|
||||
class FileStream : public utils::ISeekStream {
|
||||
public:
|
||||
explicit FileStream(const char *fname, const char *mode)
|
||||
: use_stdio(false) {
|
||||
#ifndef RABIT_STRICT_CXX98_
|
||||
if (!strcmp(fname, "stdin")) {
|
||||
use_stdio = true; fp = stdin;
|
||||
}
|
||||
if (!strcmp(fname, "stdout")) {
|
||||
use_stdio = true; fp = stdout;
|
||||
}
|
||||
#endif
|
||||
if (!use_stdio) {
|
||||
fp = utils::FopenCheck(fname, mode);
|
||||
}
|
||||
}
|
||||
virtual ~FileStream(void) {
|
||||
this->Close();
|
||||
}
|
||||
virtual size_t Read(void *ptr, size_t size) {
|
||||
return std::fread(ptr, 1, size, fp);
|
||||
}
|
||||
virtual void Write(const void *ptr, size_t size) {
|
||||
std::fwrite(ptr, size, 1, fp);
|
||||
}
|
||||
virtual void Seek(size_t pos) {
|
||||
std::fseek(fp, static_cast<long>(pos), SEEK_SET);
|
||||
}
|
||||
virtual size_t Tell(void) {
|
||||
return std::ftell(fp);
|
||||
}
|
||||
virtual bool AtEnd(void) const {
|
||||
return feof(fp) != 0;
|
||||
}
|
||||
inline void Close(void) {
|
||||
if (fp != NULL && !use_stdio) {
|
||||
std::fclose(fp); fp = NULL;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
FILE *fp;
|
||||
bool use_stdio;
|
||||
};
|
||||
} // namespace io
|
||||
} // namespace rabit
|
||||
#endif // RABIT_LEARN_IO_UTILS_H_
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
#include "./linear.h"
|
||||
#include "../io/io.h"
|
||||
#include "../io/base64.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace linear {
|
||||
@ -74,37 +73,37 @@ class LinearObjFunction : public solver::IObjFunction<float> {
|
||||
printf("Finishing writing to %s\n", name_pred.c_str());
|
||||
}
|
||||
inline void LoadModel(const char *fname) {
|
||||
io::FileStream fi(fname, "rb");
|
||||
IStream *fi = io::CreateStream(fname, "r");
|
||||
std::string header; header.resize(4);
|
||||
// check header for different binary encode
|
||||
// can be base64 or binary
|
||||
utils::Check(fi.Read(&header[0], 4) != 0, "invalid model");
|
||||
utils::Check(fi->Read(&header[0], 4) != 0, "invalid model");
|
||||
// base64 format
|
||||
if (header == "bs64") {
|
||||
io::Base64InStream bsin(&fi);
|
||||
io::Base64InStream bsin(fi);
|
||||
bsin.InitPosition();
|
||||
model.Load(bsin);
|
||||
return;
|
||||
} else if (header == "binf") {
|
||||
model.Load(fi);
|
||||
return;
|
||||
model.Load(*fi);
|
||||
} else {
|
||||
utils::Error("invalid model file");
|
||||
}
|
||||
delete fi;
|
||||
}
|
||||
inline void SaveModel(const char *fname,
|
||||
const float *wptr,
|
||||
bool save_base64 = false) {
|
||||
io::FileStream fo(fname, "wb");
|
||||
IStream *fo = io::CreateStream(fname, "w");
|
||||
if (save_base64 != 0 || !strcmp(fname, "stdout")) {
|
||||
fo.Write("bs64\t", 5);
|
||||
io::Base64OutStream bout(&fo);
|
||||
fo->Write("bs64\t", 5);
|
||||
io::Base64OutStream bout(fo);
|
||||
model.Save(bout, wptr);
|
||||
bout.Finish('\n');
|
||||
} else {
|
||||
fo.Write("binf", 4);
|
||||
model.Save(fo, wptr);
|
||||
fo->Write("binf", 4);
|
||||
model.Save(*fo, wptr);
|
||||
}
|
||||
delete fo;
|
||||
}
|
||||
inline void LoadData(const char *fname) {
|
||||
dtrain.Load(fname);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user