diff --git a/rabit-learn/io/hdfs-inl.h b/rabit-learn/io/hdfs-inl.h
index cf5c3f599..a450ee32c 100644
--- a/rabit-learn/io/hdfs-inl.h
+++ b/rabit-learn/io/hdfs-inl.h
@@ -6,6 +6,7 @@
  * \author Tianqi Chen
  */
 #include <hdfs.h>
+#include <cstdlib>
 #include <errno.h>
 #include <fcntl.h>
 #include <string>
@@ -23,7 +24,6 @@ class HDFSStream : public ISeekStream {
              bool disconnect_when_done)
       : fs_(fs), at_end_(false),
         disconnect_when_done_(disconnect_when_done) {
-    fsbk_ = fs_;
     int flag = 0;
     if (!strcmp(mode, "r")) {
       flag = O_RDONLY;
@@ -45,7 +45,6 @@ class HDFSStream : public ISeekStream {
     }
   }
   virtual size_t Read(void *ptr, size_t size) {
-    CheckFS();
     tSize nread = hdfsRead(fs_, fp_, ptr, size);
     if (nread == -1) {
       int errsv = errno;
@@ -57,7 +56,6 @@ class HDFSStream : public ISeekStream {
     return static_cast<size_t>(nread);
   }
   virtual void Write(const void *ptr, size_t size) {
-    CheckFS();
     const char *buf = reinterpret_cast<const char*>(ptr);
     while (size != 0) {
       tSize nwrite = hdfsWrite(fs_, fp_, buf, size);
@@ -70,14 +68,12 @@ class HDFSStream : public ISeekStream {
     }
   }
   virtual void Seek(size_t pos) {
-    CheckFS();
     if (hdfsSeek(fs_, fp_, pos) != 0) {
       int errsv = errno;
       utils::Error("HDFSStream.Seek Error:%s", strerror(errsv));
     }
   }
   virtual size_t Tell(void) {
-    CheckFS();
     tOffset offset = hdfsTell(fs_, fp_);
     if (offset == -1) {
       int errsv = errno;
@@ -89,7 +85,6 @@ class HDFSStream : public ISeekStream {
     return at_end_;
   }
   inline void Close(void) {
-    CheckFS();
     if (fp_ != NULL) {
       if (hdfsCloseFile(fs_, fp_) == -1) {
         int errsv = errno;
@@ -99,24 +94,26 @@ class HDFSStream : public ISeekStream {
     }
   }
- private:
-  inline void CheckFS(void) const {
-    if (fs_ != fsbk_) {
-      rabit::TrackerPrintf("[%d] fs flag inconstent\n", rabit::GetRank());
+  inline static std::string GetNameNode(void) {
+    const char *nn = getenv("rabit_hdfs_namenode");
+    if (nn == NULL) {
+      return std::string("default");
+    } else {
+      return std::string(nn);
     }
   }
+ private:
   hdfsFS fs_;
   hdfsFile fp_;
   bool at_end_;
   bool disconnect_when_done_;
-  hdfsFS fsbk_;
 };
 /*!
  \brief line split from normal file system
  */
 class HDFSProvider : public LineSplitter::IFileProvider {
  public:
  explicit HDFSProvider(const char *uri) {
-    fs_ = hdfsConnect("default", 0);
+    fs_ = hdfsConnect(HDFSStream::GetNameNode().c_str(), 0);
     utils::Check(fs_ != NULL, "error when connecting to default HDFS");
     std::vector<std::string> paths;
     LineSplitter::SplitNames(&paths, uri, "#");
diff --git a/rabit-learn/io/io-inl.h b/rabit-learn/io/io-inl.h
index 53b24ae1d..f6823e247 100644
--- a/rabit-learn/io/io-inl.h
+++ b/rabit-learn/io/io-inl.h
@@ -55,7 +55,8 @@ inline IStream *CreateStream(const char *uri, const char *mode) {
   }
   if (!strncmp(uri, "hdfs://", 7)) {
 #if RABIT_USE_HDFS
-    return new HDFSStream(hdfsConnect("default", 0), uri, mode, true);
+    return new HDFSStream(hdfsConnect(HDFSStream::GetNameNode().c_str(), 0),
+                          uri, mode, true);
 #else
     utils::Error("Please compile with RABIT_USE_HDFS=1");
 #endif
diff --git a/rabit-learn/linear/Makefile b/rabit-learn/linear/Makefile
index ee76b03ce..f26d87189 100644
--- a/rabit-learn/linear/Makefile
+++ b/rabit-learn/linear/Makefile
@@ -5,7 +5,7 @@
 else
 endif
 include $(config)
-BIN = linear.rabit test.rabit
+BIN = linear.rabit
 MOCKBIN= linear.mock
 MPIBIN =
 # objectives that makes up rabit library
diff --git a/tracker/rabit_yarn.py b/tracker/rabit_yarn.py
index 54b9e286e..56b9d1e71 100755
--- a/tracker/rabit_yarn.py
+++ b/tracker/rabit_yarn.py
@@ -61,6 +61,9 @@
 parser.add_argument('-mem', '--memory_mb', default=1024, type=int,
                     'so that each node can occupy all the mapper slots in a machine for maximum performance')
 parser.add_argument('--libhdfs-opts', default='-Xmx128m', type=str,
                     help = 'setting to be passed to libhdfs')
+parser.add_argument('--name-node', default='default', type=str,
+                    help = 'the namenode address of hdfs, libhdfs should connect to, normally leave it as default')
+
 parser.add_argument('command', nargs='+', help = 'command for rabit program')
 args = parser.parse_args()
@@ -118,6 +121,7 @@ def submit_yarn(nworker,
                 worker_args, worker_env):
     env['rabit_memory_mb'] = str(args.memory_mb)
     env['rabit_world_size'] = str(args.nworker)
     env['rabit_hdfs_opts'] = str(args.libhdfs_opts)
+    env['rabit_hdfs_namenode'] = str(args.name_node)
     if args.files != None:
         for flst in args.files: