This commit is contained in:
tqchen
2015-03-13 22:59:04 -07:00
parent 8cc650847a
commit f165ffbc95
4 changed files with 28 additions and 7 deletions

View File

@@ -17,8 +17,12 @@ namespace rabit {
namespace io {
class HDFSStream : public utils::ISeekStream {
public:
HDFSStream(hdfsFS fs, const char *fname, const char *mode)
: fs_(fs), at_end_(false) {
HDFSStream(hdfsFS fs,
const char *fname,
const char *mode,
bool disconnect_when_done)
: fs_(fs), at_end_(false),
disconnect_when_done_(disconnect_when_done) {
int flag;
if (!strcmp(mode, "r")) {
flag = O_RDONLY;
@@ -35,6 +39,9 @@ class HDFSStream : public utils::ISeekStream {
}
virtual ~HDFSStream(void) {
this->Close();
if (disconnect_when_done_) {
utils::Check(hdfsDisconnect(fs_) == 0, "hdfsDisconnect error");
}
}
virtual size_t Read(void *ptr, size_t size) {
tSize nread = hdfsRead(fs_, fp_, ptr, size);
@@ -90,6 +97,7 @@ class HDFSStream : public utils::ISeekStream {
hdfsFS fs_;
hdfsFile fp_;
bool at_end_;
bool disconnect_when_done_;
};
/*! \brief line split from normal file system */
@@ -124,12 +132,15 @@ class HDFSSplit : public LineSplitBase {
}
LineSplitBase::Init(fsize, rank, nsplit);
}
virtual ~HDFSSplit(void) {}
virtual ~HDFSSplit(void) {
LineSplitBase::Destroy();
utils::Check(hdfsDisconnect(fs_) == 0, "hdfsDisconnect error");
}
protected:
virtual utils::ISeekStream *GetFile(size_t file_index) {
utils::Assert(file_index < fnames_.size(), "file index exceed bound");
return new HDFSStream(fs_, fnames_[file_index].c_str(), "r");
return new HDFSStream(fs_, fnames_[file_index].c_str(), "r", false);
}
private:

View File

@@ -55,7 +55,7 @@ inline IStream *CreateStream(const char *uri, const char *mode) {
}
if (!strncmp(uri, "hdfs://", 7)) {
#if RABIT_USE_HDFS
return new HDFSStream(hdfsConnect("default", 0), uri, mode);
return new HDFSStream(hdfsConnect("default", 0), uri, mode, true);
#else
utils::Error("Please compile with RABIT_USE_HDFS=1");
#endif

View File

@@ -18,7 +18,7 @@ namespace io {
class LineSplitBase : public InputSplit {
public:
virtual ~LineSplitBase() {
if (fs_ != NULL) delete fs_;
this->Destroy();
}
virtual bool NextLine(std::string *out_data) {
if (file_ptr_ >= file_ptr_end_ &&
@@ -57,6 +57,15 @@ class LineSplitBase : public InputSplit {
LineSplitBase(void)
: fs_(NULL), reader_(kBufferSize) {
}
/*!
* \brief destroy all the filesystem resources owned
* can be called by child destructor
*/
inline void Destroy(void) {
if (fs_ != NULL) {
delete fs_; fs_ = NULL;
}
}
/*!
* \brief initialize the line spliter,
* \param file_size, size of each std::FILEs