refactor fileio

This commit is contained in:
tqchen 2015-03-14 16:46:54 -07:00
parent cd9c81be91
commit 5dc843cff3
7 changed files with 107 additions and 81 deletions

View File

@ -38,6 +38,7 @@ class StreamBufferReader {
} }
} }
} }
/*! \brief whether we are reaching the end of file */
inline bool AtEnd(void) const { inline bool AtEnd(void) const {
return read_len_ == 0; return read_len_ == 0;
} }

View File

@ -66,27 +66,36 @@ class FileStream : public utils::ISeekStream {
}; };
/*! \brief line split from normal file system */ /*! \brief line split from normal file system */
class FileSplit : public LineSplitBase { class FileProvider : public LineSplitter::IFileProvider {
public: public:
explicit FileSplit(const char *uri, unsigned rank, unsigned nsplit) { explicit FileProvider(const char *uri) {
LineSplitBase::SplitNames(&fnames_, uri, "#"); LineSplitter::SplitNames(&fnames_, uri, "#");
std::vector<size_t> fsize; std::vector<size_t> fsize;
for (size_t i = 0; i < fnames_.size(); ++i) { for (size_t i = 0; i < fnames_.size(); ++i) {
if (!std::strncmp(fnames_[i].c_str(), "file://", 7)) { if (!std::strncmp(fnames_[i].c_str(), "file://", 7)) {
std::string tmp = fnames_[i].c_str() + 7; std::string tmp = fnames_[i].c_str() + 7;
fnames_[i] = tmp; fnames_[i] = tmp;
} }
fsize.push_back(GetFileSize(fnames_[i].c_str())); size_t fz = GetFileSize(fnames_[i].c_str());
if (fz != 0) {
fsize_.push_back(fz);
}
} }
LineSplitBase::Init(fsize, rank, nsplit);
} }
virtual ~FileSplit(void) {} // destructor
virtual ~FileProvider(void) {}
protected: virtual utils::ISeekStream *Open(size_t file_index) {
virtual utils::ISeekStream *GetFile(size_t file_index) {
utils::Assert(file_index < fnames_.size(), "file index exceed bound"); utils::Assert(file_index < fnames_.size(), "file index exceed bound");
return new FileStream(fnames_[file_index].c_str(), "rb"); return new FileStream(fnames_[file_index].c_str(), "rb");
} }
virtual const std::vector<size_t> &FileSize(void) const {
return fsize_;
}
private:
// file sizes
std::vector<size_t> fsize_;
// file names
std::vector<std::string> fnames_;
// get file size // get file size
inline static size_t GetFileSize(const char *fname) { inline static size_t GetFileSize(const char *fname) {
std::FILE *fp = utils::FopenCheck(fname, "rb"); std::FILE *fp = utils::FopenCheck(fname, "rb");
@ -96,10 +105,6 @@ class FileSplit : public LineSplitBase {
std::fclose(fp); std::fclose(fp);
return fsize; return fsize;
} }
private:
// file names
std::vector<std::string> fnames_;
}; };
} // namespace io } // namespace io
} // namespace rabit } // namespace rabit

View File

@ -15,7 +15,7 @@
/*! \brief io interface */ /*! \brief io interface */
namespace rabit { namespace rabit {
namespace io { namespace io {
class HDFSStream : public utils::ISeekStream { class HDFSStream : public ISeekStream {
public: public:
HDFSStream(hdfsFS fs, HDFSStream(hdfsFS fs,
const char *fname, const char *fname,
@ -93,7 +93,7 @@ class HDFSStream : public utils::ISeekStream {
} }
} }
private: private:
hdfsFS fs_; hdfsFS fs_;
hdfsFile fp_; hdfsFile fp_;
bool at_end_; bool at_end_;
@ -101,15 +101,14 @@ class HDFSStream : public utils::ISeekStream {
}; };
/*! \brief line split from normal file system */ /*! \brief line split from normal file system */
class HDFSSplit : public LineSplitBase { class HDFSProvider : public LineSplitter::IFileProvider {
public: public:
explicit HDFSSplit(const char *uri, unsigned rank, unsigned nsplit) { explicit HDFSProvider(const char *uri) {
fs_ = hdfsConnect("default", 0); fs_ = hdfsConnect("default", 0);
utils::Check(fs_ != NULL, "error when connecting to default HDFS"); utils::Check(fs_ != NULL, "error when connecting to default HDFS");
std::vector<std::string> paths; std::vector<std::string> paths;
LineSplitBase::SplitNames(&paths, uri, "#"); LineSplitter::SplitNames(&paths, uri, "#");
// get the files // get the files
std::vector<size_t> fsize;
for (size_t i = 0; i < paths.size(); ++i) { for (size_t i = 0; i < paths.size(); ++i) {
hdfsFileInfo *info = hdfsGetPathInfo(fs_, paths[i].c_str()); hdfsFileInfo *info = hdfsGetPathInfo(fs_, paths[i].c_str());
utils::Check(info != NULL, "path %s do not exist", paths[i].c_str()); utils::Check(info != NULL, "path %s do not exist", paths[i].c_str());
@ -118,34 +117,37 @@ class HDFSSplit : public LineSplitBase {
hdfsFileInfo *files = hdfsListDirectory(fs_, info->mName, &nentry); hdfsFileInfo *files = hdfsListDirectory(fs_, info->mName, &nentry);
utils::Check(files != NULL, "error when ListDirectory %s", info->mName); utils::Check(files != NULL, "error when ListDirectory %s", info->mName);
for (int i = 0; i < nentry; ++i) { for (int i = 0; i < nentry; ++i) {
if (files[i].mKind == 'F') { if (files[i].mKind == 'F' && files[i].mSize != 0) {
fsize.push_back(files[i].mSize); fsize_.push_back(files[i].mSize);
fnames_.push_back(std::string(files[i].mName)); fnames_.push_back(std::string(files[i].mName));
} }
} }
hdfsFreeFileInfo(files, nentry); hdfsFreeFileInfo(files, nentry);
} else { } else {
fsize.push_back(info->mSize); if (info->mSize != 0) {
fnames_.push_back(std::string(info->mName)); fsize_.push_back(info->mSize);
fnames_.push_back(std::string(info->mName));
}
} }
hdfsFreeFileInfo(info, 1); hdfsFreeFileInfo(info, 1);
} }
LineSplitBase::Init(fsize, rank, nsplit);
} }
virtual ~HDFSSplit(void) { virtual ~HDFSProvider(void) {
LineSplitBase::Destroy();
utils::Check(hdfsDisconnect(fs_) == 0, "hdfsDisconnect error"); utils::Check(hdfsDisconnect(fs_) == 0, "hdfsDisconnect error");
}
virtual const std::vector<size_t> &FileSize(void) const {
return fsize_;
} }
virtual ISeekStream *Open(size_t file_index) {
protected:
virtual utils::ISeekStream *GetFile(size_t file_index) {
utils::Assert(file_index < fnames_.size(), "file index exceed bound"); utils::Assert(file_index < fnames_.size(), "file index exceed bound");
return new HDFSStream(fs_, fnames_[file_index].c_str(), "r", false); return new HDFSStream(fs_, fnames_[file_index].c_str(), "r", false);
} }
private: private:
// hdfs handle // hdfs handle
hdfsFS fs_; hdfsFS fs_;
// file sizes
std::vector<size_t> fsize_;
// file names // file names
std::vector<std::string> fnames_; std::vector<std::string> fnames_;
}; };

View File

@ -30,16 +30,16 @@ inline InputSplit *CreateInputSplit(const char *uri,
return new SingleFileSplit(uri); return new SingleFileSplit(uri);
} }
if (!strncmp(uri, "file://", 7)) { if (!strncmp(uri, "file://", 7)) {
return new FileSplit(uri, part, nsplit); return new LineSplitter(new FileProvider(uri), part, nsplit);
} }
if (!strncmp(uri, "hdfs://", 7)) { if (!strncmp(uri, "hdfs://", 7)) {
#if RABIT_USE_HDFS #if RABIT_USE_HDFS
return new HDFSSplit(uri, part, nsplit); return new LineSplitter(new HDFSProvider(uri), part, nsplit);
#else #else
utils::Error("Please compile with RABIT_USE_HDFS=1"); utils::Error("Please compile with RABIT_USE_HDFS=1");
#endif #endif
} }
return new FileSplit(uri, part, nsplit); return new LineSplitter(new FileProvider(uri), part, nsplit);
} }
/*! /*!
* \brief create a stream, the stream must be able to close * \brief create a stream, the stream must be able to close

View File

@ -19,6 +19,7 @@ namespace rabit {
* \brief namespace to handle input split and filesystem interfacing * \brief namespace to handle input split and filesystem interfacing
*/ */
namespace io { namespace io {
/*! \brief reused ISeekStream's definition */
typedef utils::ISeekStream ISeekStream; typedef utils::ISeekStream ISeekStream;
/*! /*!
* \brief user facing input split helper, * \brief user facing input split helper,

View File

@ -15,11 +15,42 @@
namespace rabit { namespace rabit {
namespace io { namespace io {
class LineSplitBase : public InputSplit {
/*! \brief class that split the files by line */
class LineSplitter : public InputSplit {
public: public:
virtual ~LineSplitBase() { class IFileProvider {
this->Destroy(); public:
/*!
* \brief get the seek stream of given file_index
* \return the corresponding seek stream at head of the stream
* the seek stream's resource can be freed by calling delete
*/
virtual ISeekStream *Open(size_t file_index) = 0;
/*!
* \return const reference to the size of each file
*/
virtual const std::vector<size_t> &FileSize(void) const = 0;
// virtual destructor
virtual ~IFileProvider() {}
};
// constructor
explicit LineSplitter(IFileProvider *provider,
unsigned rank,
unsigned nsplit)
: provider_(provider), fs_(NULL),
reader_(kBufferSize) {
this->Init(provider_->FileSize(), rank, nsplit);
} }
// destructor
virtual ~LineSplitter() {
if (fs_ != NULL) {
delete fs_; fs_ = NULL;
}
// delete provider after destructing the streams
delete provider_;
}
// get next line
virtual bool NextLine(std::string *out_data) { virtual bool NextLine(std::string *out_data) {
if (file_ptr_ >= file_ptr_end_ && if (file_ptr_ >= file_ptr_end_ &&
offset_curr_ >= offset_end_) return false; offset_curr_ >= offset_end_) return false;
@ -29,15 +60,15 @@ class LineSplitBase : public InputSplit {
if (reader_.AtEnd()) { if (reader_.AtEnd()) {
if (out_data->length() != 0) return true; if (out_data->length() != 0) return true;
file_ptr_ += 1; file_ptr_ += 1;
if (offset_curr_ >= offset_end_) return false;
if (offset_curr_ != file_offset_[file_ptr_]) { if (offset_curr_ != file_offset_[file_ptr_]) {
utils::Error("warning:std::FILE size not calculated correctly\n"); utils::Error("warning: FILE size not calculated correctly\n");
offset_curr_ = file_offset_[file_ptr_]; offset_curr_ = file_offset_[file_ptr_];
} }
if (offset_curr_ >= offset_end_) return false;
utils::Assert(file_ptr_ + 1 < file_offset_.size(), utils::Assert(file_ptr_ + 1 < file_offset_.size(),
"boundary check"); "boundary check");
delete fs_; delete fs_;
fs_ = this->GetFile(file_ptr_); fs_ = provider_->Open(file_ptr_);
reader_.set_stream(fs_); reader_.set_stream(fs_);
} else { } else {
++offset_curr_; ++offset_curr_;
@ -51,24 +82,27 @@ class LineSplitBase : public InputSplit {
} }
} }
} }
protected:
// constructor
LineSplitBase(void)
: fs_(NULL), reader_(kBufferSize) {
}
/*! /*!
* \brief destroy all the filesystem resources owned * \brief split names given
* can be called by child destructor * \param out_fname output std::FILE names
* \param uri_ the input uri std::FILE
* \param dlm delimiter
*/ */
inline void Destroy(void) { inline static void SplitNames(std::vector<std::string> *out_fname,
if (fs_ != NULL) { const char *uri_,
delete fs_; fs_ = NULL; const char *dlm) {
std::string uri = uri_;
char *p = std::strtok(BeginPtr(uri), dlm);
while (p != NULL) {
out_fname->push_back(std::string(p));
p = std::strtok(NULL, dlm);
} }
} }
private:
/*! /*!
* \brief initialize the line splitter, * \brief initialize the line splitter,
* \param file_size, size of each std::FILEs * \param file_size, size of each files
* \param rank the current rank of the data * \param rank the current rank of the data
* \param nsplit number of split we will divide the data into * \param nsplit number of split we will divide the data into
*/ */
@ -91,7 +125,7 @@ class LineSplitBase : public InputSplit {
file_ptr_end_ = std::upper_bound(file_offset_.begin(), file_ptr_end_ = std::upper_bound(file_offset_.begin(),
file_offset_.end(), file_offset_.end(),
offset_end_) - file_offset_.begin() - 1; offset_end_) - file_offset_.begin() - 1;
fs_ = GetFile(file_ptr_); fs_ = provider_->Open(file_ptr_);
reader_.set_stream(fs_); reader_.set_stream(fs_);
// try to set the starting position correctly // try to set the starting position correctly
if (file_offset_[file_ptr_] != offset_begin_) { if (file_offset_[file_ptr_] != offset_begin_) {
@ -103,33 +137,15 @@ class LineSplitBase : public InputSplit {
} }
} }
} }
/*!
* \brief get the seek stream of given file_index
* \return the corresponding seek stream at head of std::FILE
*/
virtual utils::ISeekStream *GetFile(size_t file_index) = 0;
/*!
* \brief split names given
* \param out_fname output std::FILE names
* \param uri_ the input uri std::FILE
* \param dlm delimiter
*/
inline static void SplitNames(std::vector<std::string> *out_fname,
const char *uri_,
const char *dlm) {
std::string uri = uri_;
char *p = std::strtok(BeginPtr(uri), dlm);
while (p != NULL) {
out_fname->push_back(std::string(p));
p = std::strtok(NULL, dlm);
}
}
private: private:
/*! \brief FileProvider */
IFileProvider *provider_;
/*! \brief current input stream */ /*! \brief current input stream */
utils::ISeekStream *fs_; utils::ISeekStream *fs_;
/*! \brief std::FILE pointer of which std::FILE to read on */ /*! \brief file pointer of which file to read on */
size_t file_ptr_; size_t file_ptr_;
/*! \brief std::FILE pointer where the end of std::FILE lies */ /*! \brief file pointer where the end of file lies */
size_t file_ptr_end_; size_t file_ptr_end_;
/*! \brief get the current offset */ /*! \brief get the current offset */
size_t offset_curr_; size_t offset_curr_;
@ -137,7 +153,7 @@ class LineSplitBase : public InputSplit {
size_t offset_begin_; size_t offset_begin_;
/*! \brief end of the offset */ /*! \brief end of the offset */
size_t offset_end_; size_t offset_end_;
/*! \brief byte-offset of each std::FILE */ /*! \brief byte-offset of each file */
std::vector<size_t> file_offset_; std::vector<size_t> file_offset_;
/*! \brief buffer reader */ /*! \brief buffer reader */
StreamBufferReader reader_; StreamBufferReader reader_;

View File

@ -206,21 +206,22 @@ int main(int argc, char *argv[]) {
rabit::Finalize(); rabit::Finalize();
return 0; return 0;
} }
rabit::linear::LinearObjFunction linear; rabit::linear::LinearObjFunction *linear = new rabit::linear::LinearObjFunction();
if (!strcmp(argv[1], "stdin")) { if (!strcmp(argv[1], "stdin")) {
linear.LoadData(argv[1]); linear->LoadData(argv[1]);
rabit::Init(argc, argv); rabit::Init(argc, argv);
} else { } else {
rabit::Init(argc, argv); rabit::Init(argc, argv);
linear.LoadData(argv[1]); linear->LoadData(argv[1]);
} }
for (int i = 2; i < argc; ++i) { for (int i = 2; i < argc; ++i) {
char name[256], val[256]; char name[256], val[256];
if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) { if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) {
linear.SetParam(name, val); linear->SetParam(name, val);
} }
} }
linear.Run(); linear->Run();
delete linear;
rabit::Finalize(); rabit::Finalize();
return 0; return 0;
} }