From 89244b4aec1f229b9ba1378389d4dea697389666 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 5 Apr 2015 09:56:53 -0700 Subject: [PATCH] Squashed 'subtree/rabit/' changes from 16975b4..b15f6cd b15f6cd rabit unifires with dmlc 5634ec3 ok 2dd6c2f Merge branch 'master' of ssh://github.com/dmlc/rabit 38d7f99 checkin wormhole spliter 8acb96a Merge pull request #10 from ryanzz/master 911a1f0 fixed a mistake 732d8c3 inteface changing 684ea0a inteface changing 8cb4c02 add dmlc support be2ff70 allow adapting wormhole git-subtree-dir: subtree/rabit git-subtree-split: b15f6cd2ac4ac0e530df2b0a207d26868515f2d5 --- include/dmlc/README.md | 4 + include/dmlc/io.h | 162 ++++++++++++++++++++++++++++++++ include/rabit/io.h | 15 +-- include/rabit_serializable.h | 99 ++----------------- rabit-learn/io/io-inl.h | 31 ++++++ rabit-learn/io/io.h | 14 ++- rabit-learn/io/line_split-inl.h | 4 +- rabit-learn/linear/run-yarn.sh | 2 +- rabit-learn/make/common.mk | 10 ++ rabit-learn/make/config.mk | 3 + rabit-learn/utils/data.h | 2 +- src/allreduce_base.cc | 15 +++ src/allreduce_base.h | 2 + src/allreduce_mock.h | 1 + wrapper/rabit.py | 2 +- 15 files changed, 258 insertions(+), 108 deletions(-) create mode 100644 include/dmlc/README.md create mode 100644 include/dmlc/io.h diff --git a/include/dmlc/README.md b/include/dmlc/README.md new file mode 100644 index 000000000..846cec006 --- /dev/null +++ b/include/dmlc/README.md @@ -0,0 +1,4 @@ +This folder is part of dmlc-core library, this allows rabit to use unified stream interface with other dmlc projects. + +- Since it is only interface dependency DMLC core is not required to compile rabit +- To compile project that uses dmlc-core functions, link to libdmlc.a (provided by dmlc-core) will be required. diff --git a/include/dmlc/io.h b/include/dmlc/io.h new file mode 100644 index 000000000..41bfdf4a8 --- /dev/null +++ b/include/dmlc/io.h @@ -0,0 +1,162 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file io.h + * \brief defines serializable interface of dmlc + */ +#ifndef DMLC_IO_H_ +#define DMLC_IO_H_ +#include +#include +#include + +/*! \brief namespace for dmlc */ +namespace dmlc { +/*! + * \brief interface of stream I/O for serialization + */ +class IStream { + public: + /*! + * \brief reads data from a stream + * \param ptr pointer to a memory buffer + * \param size block size + * \return the size of data read + */ + virtual size_t Read(void *ptr, size_t size) = 0; + /*! + * \brief writes data to a stream + * \param ptr pointer to a memory buffer + * \param size block size + */ + virtual void Write(const void *ptr, size_t size) = 0; + /*! \brief virtual destructor */ + virtual ~IStream(void) {} + /*! + * \brief generic factory function + * create an stream, the stream will close the underlying files + * upon deletion + * \param uri the uri of the input currently we support + * hdfs://, s3://, and file:// by default file:// will be used + * \param flag can be "w", "r", "a" + */ + static IStream *Create(const char *uri, const char* const flag); + // helper functions to write/read different data structures + /*! + * \brief writes a vector + * \param vec vector to be written/serialized + */ + template + inline void Write(const std::vector &vec); + /*! + * \brief loads a vector + * \param out_vec vector to be loaded/deserialized + * \return whether the load was successful + */ + template + inline bool Read(std::vector *out_vec); + /*! + * \brief writes a string + * \param str the string to be written/serialized + */ + inline void Write(const std::string &str); + /*! + * \brief loads a string + * \param out_str string to be loaded/deserialized + * \return whether the load/deserialization was successful + */ + inline bool Read(std::string *out_str); +}; + +/*! \brief interface of i/o stream that support seek */ +class ISeekStream: public IStream { + public: + // virtual destructor + virtual ~ISeekStream(void) {} + /*! \brief seek to certain position of the file */ + virtual void Seek(size_t pos) = 0; + /*! \brief tell the position of the stream */ + virtual size_t Tell(void) = 0; + /*! \return whether we are at end of file */ + virtual bool AtEnd(void) const = 0; +}; + +/*! \brief interface for serializable objects */ +class ISerializable { + public: + /*! + * \brief load the model from a stream + * \param fi stream where to load the model from + */ + virtual void Load(IStream &fi) = 0; + /*! + * \brief saves the model to a stream + * \param fo stream where to save the model to + */ + virtual void Save(IStream &fo) const = 0; +}; + +/*! + * \brief input split header, used to create input split on input dataset + * this class can be used to obtain filesystem invariant splits from input files + */ +class InputSplit { + public: + /*! + * \brief read next line, store into out_data + * \param out_data the string that stores the line data, \n is not included + * \return true of next line was found, false if we read all the lines + */ + virtual bool ReadLine(std::string *out_data) = 0; + /*! \brief destructor*/ + virtual ~InputSplit(void) {} + /*! + * \brief factory function: + * create input split given a uri + * \param uri the uri of the input, can contain hdfs prefix + * \param part_index the part id of current input + * \param num_parts total number of splits + */ + static InputSplit* Create(const char *uri, + unsigned part_index, + unsigned num_parts); +}; + +// implementations of inline functions +template +inline void IStream::Write(const std::vector &vec) { + size_t sz = vec.size(); + this->Write(&sz, sizeof(sz)); + if (sz != 0) { + this->Write(&vec[0], sizeof(T) * sz); + } +} +template +inline bool IStream::Read(std::vector *out_vec) { + size_t sz; + if (this->Read(&sz, sizeof(sz)) == 0) return false; + out_vec->resize(sz); + if (sz != 0) { + if (this->Read(&(*out_vec)[0], sizeof(T) * sz) == 0) return false; + } + return true; +} +inline void IStream::Write(const std::string &str) { + size_t sz = str.length(); + this->Write(&sz, sizeof(sz)); + if (sz != 0) { + this->Write(&str[0], sizeof(char) * sz); + } +} +inline bool IStream::Read(std::string *out_str) { + size_t sz; + if (this->Read(&sz, sizeof(sz)) == 0) return false; + out_str->resize(sz); + if (sz != 0) { + if (this->Read(&(*out_str)[0], sizeof(char) * sz) == 0) { + return false; + } + } + return true; +} +} // namespace dmlc +#endif // DMLC_IO_H_ diff --git a/include/rabit/io.h b/include/rabit/io.h index eb1ca4414..4792d932c 100644 --- a/include/rabit/io.h +++ b/include/rabit/io.h @@ -16,19 +16,8 @@ namespace rabit { namespace utils { -/*! \brief interface of i/o stream that support seek */ -class ISeekStream: public IStream { - public: - // virtual destructor - virtual ~ISeekStream(void) {} - /*! \brief seek to certain position of the file */ - virtual void Seek(size_t pos) = 0; - /*! \brief tell the position of the stream */ - virtual size_t Tell(void) = 0; - /*! \return whether we are at end of file */ - virtual bool AtEnd(void) const = 0; -}; - +/*! \brief re-use definition of dmlc::ISeekStream */ +typedef dmlc::ISeekStream ISeekStream; /*! \brief fixed size memory buffer */ struct MemoryFixSizeBuffer : public ISeekStream { public: diff --git a/include/rabit_serializable.h b/include/rabit_serializable.h index bee125ed8..7314747c0 100644 --- a/include/rabit_serializable.h +++ b/include/rabit_serializable.h @@ -9,98 +9,19 @@ #include #include #include "./rabit/utils.h" +#include "./dmlc/io.h" + namespace rabit { /*! - * \brief interface of stream I/O, used by ISerializable - * \sa ISerializable + * \brief defines stream used in rabit + * see definition of IStream in dmlc/io.h */ -class IStream { - public: - /*! - * \brief reads data from a stream - * \param ptr pointer to a memory buffer - * \param size block size - * \return the size of data read - */ - virtual size_t Read(void *ptr, size_t size) = 0; - /*! - * \brief writes data to a stream - * \param ptr pointer to a memory buffer - * \param size block size - */ - virtual void Write(const void *ptr, size_t size) = 0; - /*! \brief virtual destructor */ - virtual ~IStream(void) {} +typedef dmlc::IStream IStream; +/*! + * \brief defines serializable objects used in rabit + * see definition of ISerializable in dmlc/io.h + */ +typedef dmlc::ISerializable ISerializable; - public: - // helper functions to write/read different data structures - /*! - * \brief writes a vector - * \param vec vector to be written/serialized - */ - template - inline void Write(const std::vector &vec) { - uint64_t sz = static_cast(vec.size()); - this->Write(&sz, sizeof(sz)); - if (sz != 0) { - this->Write(&vec[0], sizeof(T) * sz); - } - } - /*! - * \brief loads a vector - * \param out_vec vector to be loaded/deserialized - * \return whether the load was successful - */ - template - inline bool Read(std::vector *out_vec) { - uint64_t sz; - if (this->Read(&sz, sizeof(sz)) == 0) return false; - out_vec->resize(sz); - if (sz != 0) { - if (this->Read(&(*out_vec)[0], sizeof(T) * sz) == 0) return false; - } - return true; - } - /*! - * \brief writes a string - * \param str the string to be written/serialized - */ - inline void Write(const std::string &str) { - uint64_t sz = static_cast(str.length()); - this->Write(&sz, sizeof(sz)); - if (sz != 0) { - this->Write(&str[0], sizeof(char) * sz); - } - } - /*! - * \brief loads a string - * \param out_str string to be loaded/deserialized - * \return whether the load/deserialization was successful - */ - inline bool Read(std::string *out_str) { - uint64_t sz; - if (this->Read(&sz, sizeof(sz)) == 0) return false; - out_str->resize(sz); - if (sz != 0) { - if (this->Read(&(*out_str)[0], sizeof(char) * sz) == 0) return false; - } - return true; - } -}; - -/*! \brief interface for serializable objects */ -class ISerializable { - public: - /*! - * \brief load the model from a stream - * \param fi stream where to load the model from - */ - virtual void Load(IStream &fi) = 0; - /*! - * \brief saves the model to a stream - * \param fo stream where to save the model to - */ - virtual void Save(IStream &fo) const = 0; -}; } // namespace rabit #endif // RABIT_RABIT_SERIALIZABLE_H_ diff --git a/rabit-learn/io/io-inl.h b/rabit-learn/io/io-inl.h index f6823e247..b8e7562d0 100644 --- a/rabit-learn/io/io-inl.h +++ b/rabit-learn/io/io-inl.h @@ -9,10 +9,13 @@ #include #include "./io.h" + +#if RABIT_USE_WORMHOLE == 0 #if RABIT_USE_HDFS #include "./hdfs-inl.h" #endif #include "./file-inl.h" +#endif namespace rabit { namespace io { @@ -25,6 +28,9 @@ namespace io { inline InputSplit *CreateInputSplit(const char *uri, unsigned part, unsigned nsplit) { +#if RABIT_USE_WORMHOLE + return dmlc::InputSplit::Create(uri, part, nsplit); +#else using namespace std; if (!strcmp(uri, "stdin")) { return new SingleFileSplit(uri); @@ -40,7 +46,28 @@ inline InputSplit *CreateInputSplit(const char *uri, #endif } return new LineSplitter(new FileProvider(uri), part, nsplit); +#endif } + +template +class StreamAdapter : public IStream { + public: + explicit StreamAdapter(TStream *stream) + : stream_(stream) { + } + virtual ~StreamAdapter(void) { + delete stream_; + } + virtual size_t Read(void *ptr, size_t size) { + return stream_->Read(ptr, size); + } + virtual void Write(const void *ptr, size_t size) { + stream_->Write(ptr, size); + } + private: + TStream *stream_; +}; + /*! * \brief create an stream, the stream must be able to close * the underlying resources(files) when deleted @@ -49,6 +76,9 @@ inline InputSplit *CreateInputSplit(const char *uri, * \param mode can be 'w' or 'r' for read or write */ inline IStream *CreateStream(const char *uri, const char *mode) { +#if RABIT_USE_WORMHOLE + return new StreamAdapter(dmlc::IStream::Create(uri, mode)); +#else using namespace std; if (!strncmp(uri, "file://", 7)) { return new FileStream(uri + 7, mode); @@ -62,6 +92,7 @@ inline IStream *CreateStream(const char *uri, const char *mode) { #endif } return new FileStream(uri, mode); +#endif } } // namespace io } // namespace rabit diff --git a/rabit-learn/io/io.h b/rabit-learn/io/io.h index 79d2df12e..ff4b2f5ac 100644 --- a/rabit-learn/io/io.h +++ b/rabit-learn/io/io.h @@ -13,6 +13,13 @@ #define RABIT_USE_HDFS 0 #endif +#ifndef RABIT_USE_WORMHOLE +#define RABIT_USE_WORMHOLE 0 +#endif + +#if RABIT_USE_WORMHOLE +#include +#endif /*! \brief io interface */ namespace rabit { /*! @@ -20,6 +27,10 @@ namespace rabit { */ namespace io { /*! \brief reused ISeekStream's definition */ +#if RABIT_USE_WORMHOLE +typedef dmlc::ISeekStream ISeekStream; +typedef dmlc::InputSplit InputSplit; +#else typedef utils::ISeekStream ISeekStream; /*! * \brief user facing input split helper, @@ -33,10 +44,11 @@ class InputSplit { * \n is not included * \return true of next line was found, false if we read all the lines */ - virtual bool NextLine(std::string *out_data) = 0; + virtual bool ReadLine(std::string *out_data) = 0; /*! \brief destructor*/ virtual ~InputSplit(void) {} }; +#endif /*! * \brief create input split given a uri * \param uri the uri of the input, can contain hdfs prefix diff --git a/rabit-learn/io/line_split-inl.h b/rabit-learn/io/line_split-inl.h index 1f8ae4fdc..a4d27273d 100644 --- a/rabit-learn/io/line_split-inl.h +++ b/rabit-learn/io/line_split-inl.h @@ -51,7 +51,7 @@ class LineSplitter : public InputSplit { delete provider_; } // get next line - virtual bool NextLine(std::string *out_data) { + virtual bool ReadLine(std::string *out_data) { if (file_ptr_ >= file_ptr_end_ && offset_curr_ >= offset_end_) return false; out_data->clear(); @@ -178,7 +178,7 @@ class SingleFileSplit : public InputSplit { virtual ~SingleFileSplit(void) { if (!use_stdin_) std::fclose(fp_); } - virtual bool NextLine(std::string *out_data) { + virtual bool ReadLine(std::string *out_data) { if (end_of_file_) return false; out_data->clear(); while (true) { diff --git a/rabit-learn/linear/run-yarn.sh b/rabit-learn/linear/run-yarn.sh index a9d65bcb2..5e6e90b62 100755 --- a/rabit-learn/linear/run-yarn.sh +++ b/rabit-learn/linear/run-yarn.sh @@ -12,7 +12,7 @@ hadoop fs -mkdir $2/data hadoop fs -put ../data/agaricus.txt.train $2/data # submit to hadoop -../../tracker/rabit_yarn.py -n $1 --vcores 1 ./linear.rabit hdfs://$2/data/agaricus.txt.train model_out=hdfs://$2/mushroom.linear.model "${*:3}" +../../wormhole/tracker/dmlc_yarn.py -n $1 --vcores 1 ./linear.rabit hdfs://$2/data/agaricus.txt.train model_out=hdfs://$2/mushroom.linear.model "${*:3}" # get the final model file hadoop fs -get $2/mushroom.linear.model ./linear.model diff --git a/rabit-learn/make/common.mk b/rabit-learn/make/common.mk index 3431b95f6..ae0a16f73 100644 --- a/rabit-learn/make/common.mk +++ b/rabit-learn/make/common.mk @@ -3,6 +3,15 @@ export LDFLAGS= -L../../lib -pthread -lm -lrt export CFLAGS = -Wall -msse2 -Wno-unknown-pragmas -fPIC -I../../include +# setup opencv +ifeq ($(USE_DMLC),1) + include ../../dmlc-core/make/dmlc.mk + CFLAGS+= -DRABIT_USE_DMLC=1 -I ../../dmlc-core/include $(DMLC_CFLAGS) + LDFLAGS+= -L../../dmlc-core -ldmlc $(DMLC_LDFLAGS) +else + CFLAGS+= -DRABIT_USE_DMLC=0 +endif + # setup opencv ifeq ($(USE_HDFS),1) CFLAGS+= -DRABIT_USE_HDFS=1 -I$(HADOOP_HDFS_HOME)/include -I$(JAVA_HOME)/include @@ -11,6 +20,7 @@ else CFLAGS+= -DRABIT_USE_HDFS=0 endif + .PHONY: clean all lib mpi all: $(BIN) $(MOCKBIN) diff --git a/rabit-learn/make/config.mk b/rabit-learn/make/config.mk index bd711a9cc..6aa0feef6 100644 --- a/rabit-learn/make/config.mk +++ b/rabit-learn/make/config.mk @@ -17,5 +17,8 @@ export MPICXX = mpicxx # whether use HDFS support during compile USE_HDFS = 1 +# whether use dmlc's io utils +USE_DMLC = 0 + # path to libjvm.so LIBJVM=$(JAVA_HOME)/jre/lib/amd64/server diff --git a/rabit-learn/utils/data.h b/rabit-learn/utils/data.h index 3ce665bb5..e72a19d51 100644 --- a/rabit-learn/utils/data.h +++ b/rabit-learn/utils/data.h @@ -56,7 +56,7 @@ struct SparseMat { data.clear(); feat_dim = 0; std::string line; - while (in->NextLine(&line)) { + while (in->ReadLine(&line)) { float label; std::istringstream ss(line); ss >> label; diff --git a/src/allreduce_base.cc b/src/allreduce_base.cc index 0235723a6..d0eff0425 100644 --- a/src/allreduce_base.cc +++ b/src/allreduce_base.cc @@ -31,6 +31,7 @@ AllreduceBase::AllreduceBase(void) { // tracker URL task_id = "NULL"; err_link = NULL; + dmlc_role = "worker"; this->SetParam("rabit_reduce_buffer", "256MB"); // setup possible enviroment variable of intrest env_vars.push_back("rabit_task_id"); @@ -39,6 +40,12 @@ AllreduceBase::AllreduceBase(void) { env_vars.push_back("rabit_reduce_ring_mincount"); env_vars.push_back("rabit_tracker_uri"); env_vars.push_back("rabit_tracker_port"); + // also include dmlc support direct variables + env_vars.push_back("DMLC_TASK_ID"); + env_vars.push_back("DMLC_ROLE"); + env_vars.push_back("DMLC_NUM_ATTEMPT"); + env_vars.push_back("DMLC_TRACKER_URI"); + env_vars.push_back("DMLC_TRACKER_PORT"); } // initialization function @@ -86,6 +93,10 @@ void AllreduceBase::Init(void) { this->SetParam("rabit_world_size", num_task); } } + if (dmlc_role != "worker") { + fprintf(stderr, "Rabit Module currently only work with dmlc worker, quit this program by exit 0\n"); + exit(0); + } // clear the setting before start reconnection this->rank = -1; //--------------------- @@ -150,6 +161,10 @@ void AllreduceBase::SetParam(const char *name, const char *val) { if (!strcmp(name, "rabit_tracker_uri")) tracker_uri = val; if (!strcmp(name, "rabit_tracker_port")) tracker_port = atoi(val); if (!strcmp(name, "rabit_task_id")) task_id = val; + if (!strcmp(name, "DMLC_TRACKER_URI")) tracker_uri = val; + if (!strcmp(name, "DMLC_TRACKER_PORT")) tracker_port = atoi(val); + if (!strcmp(name, "DMLC_TASK_ID")) task_id = val; + if (!strcmp(name, "DMLC_ROLE")) dmlc_role = val; if (!strcmp(name, "rabit_world_size")) world_size = atoi(val); if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = atoi(val); if (!strcmp(name, "rabit_reduce_ring_mincount")) { diff --git a/src/allreduce_base.h b/src/allreduce_base.h index a9eafea39..690c27d8a 100644 --- a/src/allreduce_base.h +++ b/src/allreduce_base.h @@ -496,6 +496,8 @@ class AllreduceBase : public IEngine { std::string host_uri; // uri of tracker std::string tracker_uri; + // role in dmlc jobs + std::string dmlc_role; // port of tracker address int tracker_port; // port of slave process diff --git a/src/allreduce_mock.h b/src/allreduce_mock.h index 67f8d80dd..666acbeef 100644 --- a/src/allreduce_mock.h +++ b/src/allreduce_mock.h @@ -31,6 +31,7 @@ class AllreduceMock : public AllreduceRobust { AllreduceRobust::SetParam(name, val); // additional parameters if (!strcmp(name, "rabit_num_trial")) num_trial = atoi(val); + if (!strcmp(name, "DMLC_NUM_ATTEMPT")) num_trial = atoi(val); if (!strcmp(name, "report_stats")) report_stats = atoi(val); if (!strcmp(name, "force_local")) force_local = atoi(val); if (!strcmp(name, "mock")) { diff --git a/wrapper/rabit.py b/wrapper/rabit.py index a6c579338..6282e5cfd 100644 --- a/wrapper/rabit.py +++ b/wrapper/rabit.py @@ -87,7 +87,7 @@ def get_world_size(): """ Returns get total number of process """ - ret = rbtlib.RabitGetWorlSize() + ret = rbtlib.RabitGetWorldSize() check_err__() return ret