From be2ff703bca25af500ebd3ed475a4577aa02a969 Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 27 Mar 2015 17:33:51 -0700 Subject: [PATCH] allow adapting wormhole --- rabit-learn/io/io-inl.h | 31 +++++++++++++++++++++++++++++++ rabit-learn/io/io.h | 14 +++++++++++++- rabit-learn/io/line_split-inl.h | 4 ++-- rabit-learn/make/common.mk | 9 +++++++++ rabit-learn/make/config.mk | 3 +++ rabit-learn/utils/data.h | 2 +- 6 files changed, 59 insertions(+), 4 deletions(-) diff --git a/rabit-learn/io/io-inl.h b/rabit-learn/io/io-inl.h index f6823e247..19892c277 100644 --- a/rabit-learn/io/io-inl.h +++ b/rabit-learn/io/io-inl.h @@ -9,10 +9,13 @@ #include #include "./io.h" + +#if RABIT_USE_WORMHOLE == 0 #if RABIT_USE_HDFS #include "./hdfs-inl.h" #endif #include "./file-inl.h" +#endif namespace rabit { namespace io { @@ -25,6 +28,9 @@ namespace io { inline InputSplit *CreateInputSplit(const char *uri, unsigned part, unsigned nsplit) { +#if RABIT_USE_WORMHOLE + return dmlc::InputSplit::Create(uri, part, nsplit); +#else using namespace std; if (!strcmp(uri, "stdin")) { return new SingleFileSplit(uri); @@ -40,7 +46,28 @@ inline InputSplit *CreateInputSplit(const char *uri, #endif } return new LineSplitter(new FileProvider(uri), part, nsplit); +#endif } + +template +class StreamAdapter : public IStream { + public: + explicit StreamAdapter(TStream *stream) + : stream_(stream) { + } + virtual ~StreamAdapter(void) { + delete stream_; + } + virtual size_t Read(void *ptr, size_t size) { + return stream_->Read(ptr, size); + } + virtual void Write(const void *ptr, size_t size) { + stream_->Write(ptr, size); + } + private: + TStream *stream_; +}; + /*! * \brief create an stream, the stream must be able to close * the underlying resources(files) when deleted @@ -49,6 +76,9 @@ inline InputSplit *CreateInputSplit(const char *uri, * \param mode can be 'w' or 'r' for read or write */ inline IStream *CreateStream(const char *uri, const char *mode) { +#if RABIT_USE_WORMHOLE + return new StreamAdapter(dmlc::ISeekStream::Create(uri, mode)); +#else using namespace std; if (!strncmp(uri, "file://", 7)) { return new FileStream(uri + 7, mode); @@ -62,6 +92,7 @@ inline IStream *CreateStream(const char *uri, const char *mode) { #endif } return new FileStream(uri, mode); +#endif } } // namespace io } // namespace rabit diff --git a/rabit-learn/io/io.h b/rabit-learn/io/io.h index 79d2df12e..7193f0dcb 100644 --- a/rabit-learn/io/io.h +++ b/rabit-learn/io/io.h @@ -13,6 +13,13 @@ #define RABIT_USE_HDFS 0 #endif +#ifndef RABIT_USE_WORMHOLE +#define RABIT_USE_WORMHOLE 0 +#endif + +#if RABIT_USE_WORMHOLE +#include +#endif /*! \brief io interface */ namespace rabit { /*! @@ -20,6 +27,10 @@ namespace rabit { */ namespace io { /*! \brief reused ISeekStream's definition */ +#if RABIT_USE_WORMHOLE +typedef dmlc::ISeekStream ISeekStream; +typedef dmlc::InputSplit InputSplit; +#else typedef utils::ISeekStream ISeekStream; /*! * \brief user facing input split helper, @@ -33,10 +44,11 @@ class InputSplit { * \n is not included * \return true of next line was found, false if we read all the lines */ - virtual bool NextLine(std::string *out_data) = 0; + virtual bool ReadLine(std::string *out_data) = 0; /*! \brief destructor*/ virtual ~InputSplit(void) {} }; +#endif /*! * \brief create input split given a uri * \param uri the uri of the input, can contain hdfs prefix diff --git a/rabit-learn/io/line_split-inl.h b/rabit-learn/io/line_split-inl.h index 1f8ae4fdc..a4d27273d 100644 --- a/rabit-learn/io/line_split-inl.h +++ b/rabit-learn/io/line_split-inl.h @@ -51,7 +51,7 @@ class LineSplitter : public InputSplit { delete provider_; } // get next line - virtual bool NextLine(std::string *out_data) { + virtual bool ReadLine(std::string *out_data) { if (file_ptr_ >= file_ptr_end_ && offset_curr_ >= offset_end_) return false; out_data->clear(); @@ -178,7 +178,7 @@ class SingleFileSplit : public InputSplit { virtual ~SingleFileSplit(void) { if (!use_stdin_) std::fclose(fp_); } - virtual bool NextLine(std::string *out_data) { + virtual bool ReadLine(std::string *out_data) { if (end_of_file_) return false; out_data->clear(); while (true) { diff --git a/rabit-learn/make/common.mk b/rabit-learn/make/common.mk index 3431b95f6..2044b1188 100644 --- a/rabit-learn/make/common.mk +++ b/rabit-learn/make/common.mk @@ -3,6 +3,14 @@ export LDFLAGS= -L../../lib -pthread -lm -lrt export CFLAGS = -Wall -msse2 -Wno-unknown-pragmas -fPIC -I../../include +# setup opencv +ifeq ($(USE_WORMHOLE),1) + CFLAGS+= -DRABIT_USE_WORMHOLE=1 -I ../../wormhole/include + LDFLAGS+= -L../../wormhole -lwormhole +else + CFLAGS+= -DRABIT_USE_WORMHOLE=0 +endif + # setup opencv ifeq ($(USE_HDFS),1) CFLAGS+= -DRABIT_USE_HDFS=1 -I$(HADOOP_HDFS_HOME)/include -I$(JAVA_HOME)/include @@ -11,6 +19,7 @@ else CFLAGS+= -DRABIT_USE_HDFS=0 endif + .PHONY: clean all lib mpi all: $(BIN) $(MOCKBIN) diff --git a/rabit-learn/make/config.mk b/rabit-learn/make/config.mk index bd711a9cc..395114fcf 100644 --- a/rabit-learn/make/config.mk +++ b/rabit-learn/make/config.mk @@ -17,5 +17,8 @@ export MPICXX = mpicxx # whether use HDFS support during compile USE_HDFS = 1 +# whether use wormhole's io utils +USE_WORMHOLE = 0 + # path to libjvm.so LIBJVM=$(JAVA_HOME)/jre/lib/amd64/server diff --git a/rabit-learn/utils/data.h b/rabit-learn/utils/data.h index 3ce665bb5..e72a19d51 100644 --- a/rabit-learn/utils/data.h +++ b/rabit-learn/utils/data.h @@ -56,7 +56,7 @@ struct SparseMat { data.clear(); feat_dim = 0; std::string line; - while (in->NextLine(&line)) { + while (in->ReadLine(&line)) { float label; std::istringstream ss(line); ss >> label;