diff --git a/doc/Doxyfile b/doc/Doxyfile index 2e1af0286..694bc35d3 100644 --- a/doc/Doxyfile +++ b/doc/Doxyfile @@ -95,7 +95,7 @@ WARN_LOGFILE = #--------------------------------------------------------------------------- # configuration options related to the input files #--------------------------------------------------------------------------- -INPUT = +INPUT = . dmlc INPUT_ENCODING = UTF-8 FILE_PATTERNS = RECURSIVE = NO diff --git a/doc/mkdoc.sh b/doc/mkdoc.sh index 4bc0284c3..181e280fb 100755 --- a/doc/mkdoc.sh +++ b/doc/mkdoc.sh @@ -1,4 +1,4 @@ #!/bin/bash cd ../include -doxygen ../doc/Doxyfile +doxygen ../doc/Doxyfile cd ../doc diff --git a/include/dmlc/io.h b/include/dmlc/io.h index 41bfdf4a8..e6a6ee566 100644 --- a/include/dmlc/io.h +++ b/include/dmlc/io.h @@ -14,7 +14,7 @@ namespace dmlc { /*! * \brief interface of stream I/O for serialization */ -class IStream { +class Stream { public: /*! * \brief reads data from a stream @@ -30,7 +30,7 @@ class IStream { */ virtual void Write(const void *ptr, size_t size) = 0; /*! \brief virtual destructor */ - virtual ~IStream(void) {} + virtual ~Stream(void) {} /*! * \brief generic factory function * create an stream, the stream will close the underlying files @@ -38,8 +38,9 @@ class IStream { * \param uri the uri of the input currently we support * hdfs://, s3://, and file:// by default file:// will be used * \param flag can be "w", "r", "a" + * \return a created stream */ - static IStream *Create(const char *uri, const char* const flag); + static Stream *Create(const char *uri, const char* const flag); // helper functions to write/read different data structures /*! * \brief writes a vector @@ -68,10 +69,10 @@ class IStream { }; /*! \brief interface of i/o stream that support seek */ -class ISeekStream: public IStream { +class SeekStream: public Stream { public: // virtual destructor - virtual ~ISeekStream(void) {} + virtual ~SeekStream(void) {} /*! \brief seek to certain position of the file */ virtual void Seek(size_t pos) = 0; /*! \brief tell the position of the stream */ @@ -81,18 +82,18 @@ class ISeekStream: public IStream { }; /*! \brief interface for serializable objects */ -class ISerializable { +class Serializable { public: /*! * \brief load the model from a stream * \param fi stream where to load the model from */ - virtual void Load(IStream &fi) = 0; + virtual void Load(Stream *fi) = 0; /*! * \brief saves the model to a stream * \param fo stream where to save the model to */ - virtual void Save(IStream &fo) const = 0; + virtual void Save(Stream *fo) const = 0; }; /*! @@ -115,6 +116,7 @@ class InputSplit { * \param uri the uri of the input, can contain hdfs prefix * \param part_index the part id of current input * \param num_parts total number of splits + * \return a created input split */ static InputSplit* Create(const char *uri, unsigned part_index, @@ -123,7 +125,7 @@ class InputSplit { // implementations of inline functions template -inline void IStream::Write(const std::vector &vec) { +inline void Stream::Write(const std::vector &vec) { size_t sz = vec.size(); this->Write(&sz, sizeof(sz)); if (sz != 0) { @@ -131,7 +133,7 @@ inline void IStream::Write(const std::vector &vec) { } } template -inline bool IStream::Read(std::vector *out_vec) { +inline bool Stream::Read(std::vector *out_vec) { size_t sz; if (this->Read(&sz, sizeof(sz)) == 0) return false; out_vec->resize(sz); @@ -140,14 +142,14 @@ inline bool IStream::Read(std::vector *out_vec) { } return true; } -inline void IStream::Write(const std::string &str) { +inline void Stream::Write(const std::string &str) { size_t sz = str.length(); this->Write(&sz, sizeof(sz)); if (sz != 0) { this->Write(&str[0], sizeof(char) * sz); } } -inline bool IStream::Read(std::string *out_str) { +inline bool Stream::Read(std::string *out_str) { size_t sz; if (this->Read(&sz, sizeof(sz)) == 0) return false; out_str->resize(sz); diff --git a/include/rabit.h b/include/rabit.h index 7e3b88cdf..824b454bb 100644 --- a/include/rabit.h +++ b/include/rabit.h @@ -16,7 +16,7 @@ #if __cplusplus >= 201103L #include #endif // C++11 -// contains definition of ISerializable +// contains definition of Serializable #include "./rabit_serializable.h" // engine definition of rabit, defines internal implementation // to use rabit interface, there is no need to read engine.h @@ -183,8 +183,8 @@ inline void Allreduce(DType *sendrecvbuf, size_t count, * * \sa CheckPoint, VersionNumber */ -inline int LoadCheckPoint(ISerializable *global_model, - ISerializable *local_model = NULL); +inline int LoadCheckPoint(Serializable *global_model, + Serializable *local_model = NULL); /*! * \brief checkpoints the model, meaning a stage of execution has finished. * every time we call check point, a version number will be increased by one @@ -199,8 +199,8 @@ inline int LoadCheckPoint(ISerializable *global_model, * So, only CheckPoint with the global_model if possible * \sa LoadCheckPoint, VersionNumber */ -inline void CheckPoint(const ISerializable *global_model, - const ISerializable *local_model = NULL); +inline void CheckPoint(const Serializable *global_model, + const Serializable *local_model = NULL); /*! * \brief This function can be used to replace CheckPoint for global_model only, * when certain condition is met (see detailed explanation). @@ -222,7 +222,7 @@ inline void CheckPoint(const ISerializable *global_model, * is the same in every node * \sa LoadCheckPoint, CheckPoint, VersionNumber */ -inline void LazyCheckPoint(const ISerializable *global_model); +inline void LazyCheckPoint(const Serializable *global_model); /*! * \return version number of the current stored model, * which means how many calls to CheckPoint we made so far diff --git a/include/rabit/engine.h b/include/rabit/engine.h index e0395cdcd..a2f5da25b 100644 --- a/include/rabit/engine.h +++ b/include/rabit/engine.h @@ -94,8 +94,8 @@ class IEngine { * * \sa CheckPoint, VersionNumber */ - virtual int LoadCheckPoint(ISerializable *global_model, - ISerializable *local_model = NULL) = 0; + virtual int LoadCheckPoint(Serializable *global_model, + Serializable *local_model = NULL) = 0; /*! * \brief checkpoints the model, meaning a stage of execution was finished * every time we call check point, a version number increases by ones @@ -112,8 +112,8 @@ class IEngine { * * \sa LoadCheckPoint, VersionNumber */ - virtual void CheckPoint(const ISerializable *global_model, - const ISerializable *local_model = NULL) = 0; + virtual void CheckPoint(const Serializable *global_model, + const Serializable *local_model = NULL) = 0; /*! * \brief This function can be used to replace CheckPoint for global_model only, * when certain condition is met (see detailed explanation). @@ -134,7 +134,7 @@ class IEngine { * is the same in every node * \sa LoadCheckPoint, CheckPoint, VersionNumber */ - virtual void LazyCheckPoint(const ISerializable *global_model) = 0; + virtual void LazyCheckPoint(const Serializable *global_model) = 0; /*! * \return version number of the current stored model, * which means how many calls to CheckPoint we made so far diff --git a/include/rabit/io.h b/include/rabit/io.h index 4792d932c..a0eb0adb8 100644 --- a/include/rabit/io.h +++ b/include/rabit/io.h @@ -16,10 +16,10 @@ namespace rabit { namespace utils { -/*! \brief re-use definition of dmlc::ISeekStream */ -typedef dmlc::ISeekStream ISeekStream; +/*! \brief re-use definition of dmlc::SeekStream */ +typedef dmlc::SeekStream SeekStream; /*! \brief fixed size memory buffer */ -struct MemoryFixSizeBuffer : public ISeekStream { +struct MemoryFixSizeBuffer : public SeekStream { public: MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size) : p_buffer_(reinterpret_cast(p_buffer)), @@ -61,7 +61,7 @@ struct MemoryFixSizeBuffer : public ISeekStream { }; // class MemoryFixSizeBuffer /*! \brief a in memory buffer that can be read and write as stream interface */ -struct MemoryBufferStream : public ISeekStream { +struct MemoryBufferStream : public SeekStream { public: explicit MemoryBufferStream(std::string *p_buffer) : p_buffer_(p_buffer) { diff --git a/include/rabit/rabit-inl.h b/include/rabit/rabit-inl.h index 21d15d9e1..97c43767d 100644 --- a/include/rabit/rabit-inl.h +++ b/include/rabit/rabit-inl.h @@ -178,17 +178,17 @@ inline void TrackerPrintf(const char *fmt, ...) { } #endif // load latest check point -inline int LoadCheckPoint(ISerializable *global_model, - ISerializable *local_model) { +inline int LoadCheckPoint(Serializable *global_model, + Serializable *local_model) { return engine::GetEngine()->LoadCheckPoint(global_model, local_model); } // checkpoint the model, meaning we finished a stage of execution -inline void CheckPoint(const ISerializable *global_model, - const ISerializable *local_model) { +inline void CheckPoint(const Serializable *global_model, + const Serializable *local_model) { engine::GetEngine()->CheckPoint(global_model, local_model); } // lazy checkpoint the model, only remember the pointer to global_model -inline void LazyCheckPoint(const ISerializable *global_model) { +inline void LazyCheckPoint(const Serializable *global_model) { engine::GetEngine()->LazyCheckPoint(global_model); } // return the version number of currently stored model diff --git a/include/rabit_serializable.h b/include/rabit_serializable.h index 7314747c0..40266575b 100644 --- a/include/rabit_serializable.h +++ b/include/rabit_serializable.h @@ -14,14 +14,14 @@ namespace rabit { /*! * \brief defines stream used in rabit - * see definition of IStream in dmlc/io.h + * see definition of Stream in dmlc/io.h */ -typedef dmlc::IStream IStream; +typedef dmlc::Stream Stream; /*! * \brief defines serializable objects used in rabit - * see definition of ISerializable in dmlc/io.h + * see definition of Serializable in dmlc/io.h */ -typedef dmlc::ISerializable ISerializable; +typedef dmlc::Serializable Serializable; } // namespace rabit #endif // RABIT_RABIT_SERIALIZABLE_H_ diff --git a/rabit-learn/io/base64-inl.h b/rabit-learn/io/base64-inl.h index 7b0154c0d..61581d888 100644 --- a/rabit-learn/io/base64-inl.h +++ b/rabit-learn/io/base64-inl.h @@ -33,9 +33,9 @@ static const char EncodeTable[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; } // namespace base64 /*! \brief the stream that reads from base64, note we take from file pointers */ -class Base64InStream: public IStream { +class Base64InStream: public Stream { public: - explicit Base64InStream(IStream *fs) : reader_(256) { + explicit Base64InStream(Stream *fs) : reader_(256) { reader_.set_stream(fs); num_prev = 0; tmp_ch = 0; } @@ -147,9 +147,9 @@ class Base64InStream: public IStream { static const bool kStrictCheck = false; }; /*! \brief the stream that write to base64, note we take from file pointers */ -class Base64OutStream: public IStream { +class Base64OutStream: public Stream { public: - explicit Base64OutStream(IStream *fp) : fp(fp) { + explicit Base64OutStream(Stream *fp) : fp(fp) { buf_top = 0; } virtual void Write(const void *ptr, size_t size) { @@ -198,7 +198,7 @@ class Base64OutStream: public IStream { } private: - IStream *fp; + Stream *fp; int buf_top; unsigned char buf[4]; std::string out_buf; diff --git a/rabit-learn/io/buffer_reader-inl.h b/rabit-learn/io/buffer_reader-inl.h index c887c5013..78017cb99 100644 --- a/rabit-learn/io/buffer_reader-inl.h +++ b/rabit-learn/io/buffer_reader-inl.h @@ -20,7 +20,7 @@ class StreamBufferReader { /*! * \brief set input stream */ - inline void set_stream(IStream *stream) { + inline void set_stream(Stream *stream) { stream_ = stream; read_len_ = read_ptr_ = 1; } @@ -45,7 +45,7 @@ class StreamBufferReader { private: /*! \brief the underlying stream */ - IStream *stream_; + Stream *stream_; /*! \brief buffer to hold data */ std::string buffer_; /*! \brief length of valid data in buffer */ diff --git a/rabit-learn/io/file-inl.h b/rabit-learn/io/file-inl.h index 6eaa62b33..0495ecb32 100644 --- a/rabit-learn/io/file-inl.h +++ b/rabit-learn/io/file-inl.h @@ -15,7 +15,7 @@ namespace rabit { namespace io { /*! \brief implementation of file i/o stream */ -class FileStream : public utils::ISeekStream { +class FileStream : public utils::SeekStream { public: explicit FileStream(const char *fname, const char *mode) : use_stdio(false) { @@ -84,7 +84,7 @@ class FileProvider : public LineSplitter::IFileProvider { } // destrucor virtual ~FileProvider(void) {} - virtual utils::ISeekStream *Open(size_t file_index) { + virtual utils::SeekStream *Open(size_t file_index) { utils::Assert(file_index < fnames_.size(), "file index exceed bound"); return new FileStream(fnames_[file_index].c_str(), "rb"); } diff --git a/rabit-learn/io/hdfs-inl.h b/rabit-learn/io/hdfs-inl.h index a450ee32c..5d28c5397 100644 --- a/rabit-learn/io/hdfs-inl.h +++ b/rabit-learn/io/hdfs-inl.h @@ -16,7 +16,7 @@ /*! \brief io interface */ namespace rabit { namespace io { -class HDFSStream : public ISeekStream { +class HDFSStream : public SeekStream { public: HDFSStream(hdfsFS fs, const char *fname, @@ -147,7 +147,7 @@ class HDFSProvider : public LineSplitter::IFileProvider { virtual const std::vector &FileSize(void) const { return fsize_; } - virtual ISeekStream *Open(size_t file_index) { + virtual SeekStream *Open(size_t file_index) { utils::Assert(file_index < fnames_.size(), "file index exceed bound"); return new HDFSStream(fs_, fnames_[file_index].c_str(), "r", false); } diff --git a/rabit-learn/io/io-inl.h b/rabit-learn/io/io-inl.h index b8e7562d0..95897ac09 100644 --- a/rabit-learn/io/io-inl.h +++ b/rabit-learn/io/io-inl.h @@ -50,7 +50,7 @@ inline InputSplit *CreateInputSplit(const char *uri, } template -class StreamAdapter : public IStream { +class StreamAdapter : public Stream { public: explicit StreamAdapter(TStream *stream) : stream_(stream) { @@ -75,9 +75,9 @@ class StreamAdapter : public IStream { * \param uri the uri of the input, can contain hdfs prefix * \param mode can be 'w' or 'r' for read or write */ -inline IStream *CreateStream(const char *uri, const char *mode) { +inline Stream *CreateStream(const char *uri, const char *mode) { #if RABIT_USE_WORMHOLE - return new StreamAdapter(dmlc::IStream::Create(uri, mode)); + return new StreamAdapter(dmlc::Stream::Create(uri, mode)); #else using namespace std; if (!strncmp(uri, "file://", 7)) { diff --git a/rabit-learn/io/io.h b/rabit-learn/io/io.h index ff4b2f5ac..dd766bd6d 100644 --- a/rabit-learn/io/io.h +++ b/rabit-learn/io/io.h @@ -26,12 +26,12 @@ namespace rabit { * \brief namespace to handle input split and filesystem interfacing */ namespace io { -/*! \brief reused ISeekStream's definition */ +/*! \brief reused SeekStream's definition */ #if RABIT_USE_WORMHOLE -typedef dmlc::ISeekStream ISeekStream; +typedef dmlc::SeekStream SeekStream; typedef dmlc::InputSplit InputSplit; #else -typedef utils::ISeekStream ISeekStream; +typedef utils::SeekStream SeekStream; /*! * \brief user facing input split helper, * can be used to get the partition of data used by current node @@ -65,7 +65,7 @@ inline InputSplit *CreateInputSplit(const char *uri, * \param uri the uri of the input, can contain hdfs prefix * \param mode can be 'w' or 'r' for read or write */ -inline IStream *CreateStream(const char *uri, const char *mode); +inline Stream *CreateStream(const char *uri, const char *mode); } // namespace io } // namespace rabit diff --git a/rabit-learn/io/line_split-inl.h b/rabit-learn/io/line_split-inl.h index a4d27273d..7e2d9ff87 100644 --- a/rabit-learn/io/line_split-inl.h +++ b/rabit-learn/io/line_split-inl.h @@ -26,7 +26,7 @@ class LineSplitter : public InputSplit { * \return the corresponding seek stream at head of the stream * the seek stream's resource can be freed by calling delete */ - virtual ISeekStream *Open(size_t file_index) = 0; + virtual SeekStream *Open(size_t file_index) = 0; /*! * \return const reference to size of each files */ @@ -142,7 +142,7 @@ class LineSplitter : public InputSplit { /*! \brief FileProvider */ IFileProvider *provider_; /*! \brief current input stream */ - utils::ISeekStream *fs_; + utils::SeekStream *fs_; /*! \brief file pointer of which file to read on */ size_t file_ptr_; /*! \brief file pointer where the end of file lies */ diff --git a/rabit-learn/kmeans/kmeans.cc b/rabit-learn/kmeans/kmeans.cc index 1bff93d34..8a9fbaf71 100644 --- a/rabit-learn/kmeans/kmeans.cc +++ b/rabit-learn/kmeans/kmeans.cc @@ -7,22 +7,54 @@ using namespace rabit; +// simple dense matrix, mshadow or Eigen matrix was better +// this was was OK +struct Matrix { + inline void Init(size_t nrow, size_t ncol, float v = 0.0f) { + this->nrow = nrow; + this->ncol = ncol; + data.resize(nrow * ncol); + std::fill(data.begin(), data.end(), v); + } + inline float *operator[](size_t i) { + return &data[0] + i * ncol; + } + inline const float *operator[](size_t i) const { + return &data[0] + i * ncol; + } + inline void Print(utils::Stream *fo) { + for (size_t i = 0; i < data.size(); ++i) { + std::ostringstream ss; + ss << data[i]; + if ((i+1) % ncol == 0) { + ss << '\n'; + } else { + ss << ' '; + } + } + std::string s = ss.str(); + } + // number of data + size_t nrow, ncol; + std::vector data; +}; + // kmeans model -class Model : public rabit::ISerializable { +class Model : public rabit::Serializable { public: // matrix of centroids Matrix centroids; // load from stream - virtual void Load(rabit::IStream &fi) { - fi.Read(¢roids.nrow, sizeof(centroids.nrow)); - fi.Read(¢roids.ncol, sizeof(centroids.ncol)); - fi.Read(¢roids.data); + virtual void Load(rabit::Stream *fi) { + fi->Read(¢roids.nrow, sizeof(centroids.nrow)); + fi->Read(¢roids.ncol, sizeof(centroids.ncol)); + fi->Read(¢roids.data); } /*! \brief save the model to the stream */ - virtual void Save(rabit::IStream &fo) const { - fo.Write(¢roids.nrow, sizeof(centroids.nrow)); - fo.Write(¢roids.ncol, sizeof(centroids.ncol)); - fo.Write(centroids.data); + virtual void Save(rabit::Stream *fo) const { + fo->Write(¢roids.nrow, sizeof(centroids.nrow)); + fo->Write(¢roids.ncol, sizeof(centroids.ncol)); + fo->Write(centroids.data); } virtual void InitModel(unsigned num_cluster, unsigned feat_dim) { centroids.Init(num_cluster, feat_dim); @@ -153,7 +185,7 @@ int main(int argc, char *argv[]) { } } model.Normalize(); - rabit::CheckPoint(&model); + rabit::LazyCheckPoint(&model); } // output the model file to somewhere if (rabit::GetRank() == 0) { diff --git a/rabit-learn/linear/linear.cc b/rabit-learn/linear/linear.cc index a29c20ca7..a6a41a0fa 100644 --- a/rabit-learn/linear/linear.cc +++ b/rabit-learn/linear/linear.cc @@ -75,7 +75,7 @@ class LinearObjFunction : public solver::IObjFunction { printf("Finishing writing to %s\n", name_pred.c_str()); } inline void LoadModel(const char *fname) { - IStream *fi = io::CreateStream(fname, "r"); + Stream *fi = io::CreateStream(fname, "r"); std::string header; header.resize(4); // check header for different binary encode // can be base64 or binary @@ -84,9 +84,9 @@ class LinearObjFunction : public solver::IObjFunction { if (header == "bs64") { io::Base64InStream bsin(fi); bsin.InitPosition(); - model.Load(bsin); + model.Load(&bsin); } else if (header == "binf") { - model.Load(*fi); + model.Load(fi); } else { utils::Error("invalid model file"); } @@ -95,15 +95,15 @@ class LinearObjFunction : public solver::IObjFunction { inline void SaveModel(const char *fname, const float *wptr, bool save_base64 = false) { - IStream *fo = io::CreateStream(fname, "w"); + Stream *fo = io::CreateStream(fname, "w"); if (save_base64 != 0 || !strcmp(fname, "stdout")) { fo->Write("bs64\t", 5); io::Base64OutStream bout(fo); - model.Save(bout, wptr); + model.Save(&bout, wptr); bout.Finish('\n'); } else { fo->Write("binf", 4); - model.Save(*fo, wptr); + model.Save(fo, wptr); } delete fo; } @@ -128,11 +128,11 @@ class LinearObjFunction : public solver::IObjFunction { } } // load model - virtual void Load(rabit::IStream &fi) { - fi.Read(&model.param, sizeof(model.param)); + virtual void Load(rabit::Stream *fi) { + fi->Read(&model.param, sizeof(model.param)); } - virtual void Save(rabit::IStream &fo) const { - fo.Write(&model.param, sizeof(model.param)); + virtual void Save(rabit::Stream *fo) const { + fo->Write(&model.param, sizeof(model.param)); } virtual double Eval(const float *weight, size_t size) { if (nthread != 0) omp_set_num_threads(nthread); diff --git a/rabit-learn/linear/linear.h b/rabit-learn/linear/linear.h index 67ae32b77..87fadb0d6 100644 --- a/rabit-learn/linear/linear.h +++ b/rabit-learn/linear/linear.h @@ -113,17 +113,17 @@ struct LinearModel { if (weight != NULL) delete [] weight; } // load model - inline void Load(rabit::IStream &fi) { - fi.Read(¶m, sizeof(param)); + inline void Load(rabit::Stream *fi) { + fi->Read(¶m, sizeof(param)); if (weight == NULL) { weight = new float[param.num_feature + 1]; } - fi.Read(weight, sizeof(float) * (param.num_feature + 1)); + fi->Read(weight, sizeof(float) * (param.num_feature + 1)); } - inline void Save(rabit::IStream &fo, const float *wptr = NULL) { - fo.Write(¶m, sizeof(param)); + inline void Save(rabit::Stream *fo, const float *wptr = NULL) { + fo->Write(¶m, sizeof(param)); if (wptr == NULL) wptr = weight; - fo.Write(wptr, sizeof(float) * (param.num_feature + 1)); + fo->Write(wptr, sizeof(float) * (param.num_feature + 1)); } inline float Predict(const SparseMat::Vector &v) const { return param.Predict(weight, v); diff --git a/rabit-learn/solver/lbfgs.h b/rabit-learn/solver/lbfgs.h index dcb2a93c4..cd8f783d2 100644 --- a/rabit-learn/solver/lbfgs.h +++ b/rabit-learn/solver/lbfgs.h @@ -19,7 +19,7 @@ namespace solver { * to remember the state parameters that might need to remember */ template -class IObjFunction : public rabit::ISerializable { +class IObjFunction : public rabit::Serializable { public: // destructor virtual ~IObjFunction(void){} @@ -463,7 +463,7 @@ class LBFGSSolver { } } // global solver state - struct GlobalState : public rabit::ISerializable { + struct GlobalState : public rabit::Serializable { public: // memory size of L-BFGS size_t size_memory; @@ -514,28 +514,28 @@ class LBFGSSolver { MapIndex(j, offset_, size_memory)]; } // load the shift array - virtual void Load(rabit::IStream &fi) { - fi.Read(&size_memory, sizeof(size_memory)); - fi.Read(&num_iteration, sizeof(num_iteration)); - fi.Read(&num_dim, sizeof(num_dim)); - fi.Read(&init_objval, sizeof(init_objval)); - fi.Read(&old_objval, sizeof(old_objval)); - fi.Read(&offset_, sizeof(offset_)); - fi.Read(&data); + virtual void Load(rabit::Stream *fi) { + fi->Read(&size_memory, sizeof(size_memory)); + fi->Read(&num_iteration, sizeof(num_iteration)); + fi->Read(&num_dim, sizeof(num_dim)); + fi->Read(&init_objval, sizeof(init_objval)); + fi->Read(&old_objval, sizeof(old_objval)); + fi->Read(&offset_, sizeof(offset_)); + fi->Read(&data); this->AllocSpace(); - fi.Read(weight, sizeof(DType) * num_dim); + fi->Read(weight, sizeof(DType) * num_dim); obj->Load(fi); } // save the shift array - virtual void Save(rabit::IStream &fo) const { - fo.Write(&size_memory, sizeof(size_memory)); - fo.Write(&num_iteration, sizeof(num_iteration)); - fo.Write(&num_dim, sizeof(num_dim)); - fo.Write(&init_objval, sizeof(init_objval)); - fo.Write(&old_objval, sizeof(old_objval)); - fo.Write(&offset_, sizeof(offset_)); - fo.Write(data); - fo.Write(weight, sizeof(DType) * num_dim); + virtual void Save(rabit::Stream *fo) const { + fo->Write(&size_memory, sizeof(size_memory)); + fo->Write(&num_iteration, sizeof(num_iteration)); + fo->Write(&num_dim, sizeof(num_dim)); + fo->Write(&init_objval, sizeof(init_objval)); + fo->Write(&old_objval, sizeof(old_objval)); + fo->Write(&offset_, sizeof(offset_)); + fo->Write(data); + fo->Write(weight, sizeof(DType) * num_dim); obj->Save(fo); } inline void Shift(void) { @@ -556,7 +556,7 @@ class LBFGSSolver { } }; /*! \brief rolling array that carries history information */ - struct HistoryArray : public rabit::ISerializable { + struct HistoryArray : public rabit::Serializable { public: HistoryArray(void) : dptr_(NULL) { num_useful_ = 0; @@ -609,26 +609,26 @@ class LBFGSSolver { num_useful_ = num_useful; } // load the shift array - virtual void Load(rabit::IStream &fi) { - fi.Read(&num_col_, sizeof(num_col_)); - fi.Read(&stride_, sizeof(stride_)); - fi.Read(&size_memory_, sizeof(size_memory_)); - fi.Read(&num_useful_, sizeof(num_useful_)); + virtual void Load(rabit::Stream *fi) { + fi->Read(&num_col_, sizeof(num_col_)); + fi->Read(&stride_, sizeof(stride_)); + fi->Read(&size_memory_, sizeof(size_memory_)); + fi->Read(&num_useful_, sizeof(num_useful_)); this->Init(num_col_, size_memory_); for (size_t i = 0; i < num_useful_; ++i) { - fi.Read((*this)[i], num_col_ * sizeof(DType)); - fi.Read((*this)[i + size_memory_], num_col_ * sizeof(DType)); + fi->Read((*this)[i], num_col_ * sizeof(DType)); + fi->Read((*this)[i + size_memory_], num_col_ * sizeof(DType)); } } // save the shift array - virtual void Save(rabit::IStream &fi) const { - fi.Write(&num_col_, sizeof(num_col_)); - fi.Write(&stride_, sizeof(stride_)); - fi.Write(&size_memory_, sizeof(size_memory_)); - fi.Write(&num_useful_, sizeof(num_useful_)); + virtual void Save(rabit::Stream *fo) const { + fo->Write(&num_col_, sizeof(num_col_)); + fo->Write(&stride_, sizeof(stride_)); + fo->Write(&size_memory_, sizeof(size_memory_)); + fo->Write(&num_useful_, sizeof(num_useful_)); for (size_t i = 0; i < num_useful_; ++i) { - fi.Write((*this)[i], num_col_ * sizeof(DType)); - fi.Write((*this)[i + size_memory_], num_col_ * sizeof(DType)); + fo->Write((*this)[i], num_col_ * sizeof(DType)); + fo->Write((*this)[i + size_memory_], num_col_ * sizeof(DType)); } } diff --git a/rabit-learn/utils/data.h b/rabit-learn/utils/data.h index e72a19d51..10b862f44 100644 --- a/rabit-learn/utils/data.h +++ b/rabit-learn/utils/data.h @@ -93,43 +93,6 @@ struct SparseMat { std::vector labels; }; -// dense matrix -struct Matrix { - inline void Init(size_t nrow, size_t ncol, float v = 0.0f) { - this->nrow = nrow; - this->ncol = ncol; - data.resize(nrow * ncol); - std::fill(data.begin(), data.end(), v); - } - inline float *operator[](size_t i) { - return &data[0] + i * ncol; - } - inline const float *operator[](size_t i) const { - return &data[0] + i * ncol; - } - inline void Print(const char *fname) { - FILE *fo; - if (!strcmp(fname, "stdout")) { - fo = stdout; - } else { - fo = utils::FopenCheck(fname, "w"); - } - for (size_t i = 0; i < data.size(); ++i) { - fprintf(fo, "%g", data[i]); - if ((i+1) % ncol == 0) { - fprintf(fo, "\n"); - } else { - fprintf(fo, " "); - } - } - // close the filed - if (fo != stdout) fclose(fo); - } - // number of data - size_t nrow, ncol; - std::vector data; -}; - /*!\brief computes a random number modulo the value */ inline int Random(int value) { return rand() % value; diff --git a/src/allreduce_base.h b/src/allreduce_base.h index 690c27d8a..c34eb6042 100644 --- a/src/allreduce_base.h +++ b/src/allreduce_base.h @@ -126,8 +126,8 @@ class AllreduceBase : public IEngine { * * \sa CheckPoint, VersionNumber */ - virtual int LoadCheckPoint(ISerializable *global_model, - ISerializable *local_model = NULL) { + virtual int LoadCheckPoint(Serializable *global_model, + Serializable *local_model = NULL) { return 0; } /*! @@ -146,8 +146,8 @@ class AllreduceBase : public IEngine { * * \sa LoadCheckPoint, VersionNumber */ - virtual void CheckPoint(const ISerializable *global_model, - const ISerializable *local_model = NULL) { + virtual void CheckPoint(const Serializable *global_model, + const Serializable *local_model = NULL) { version_number += 1; } /*! @@ -170,7 +170,7 @@ class AllreduceBase : public IEngine { * is the same in all nodes * \sa LoadCheckPoint, CheckPoint, VersionNumber */ - virtual void LazyCheckPoint(const ISerializable *global_model) { + virtual void LazyCheckPoint(const Serializable *global_model) { version_number += 1; } /*! diff --git a/src/allreduce_mock.h b/src/allreduce_mock.h index 666acbeef..4c271e7ba 100644 --- a/src/allreduce_mock.h +++ b/src/allreduce_mock.h @@ -5,8 +5,8 @@ * * \author Ignacio Cano, Tianqi Chen */ -#ifndef RABIT_ALLREDUCE_MOCK_H -#define RABIT_ALLREDUCE_MOCK_H +#ifndef RABIT_ALLREDUCE_MOCK_H_ +#define RABIT_ALLREDUCE_MOCK_H_ #include #include #include @@ -58,8 +58,8 @@ class AllreduceMock : public AllreduceRobust { this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast"); AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root); } - virtual int LoadCheckPoint(ISerializable *global_model, - ISerializable *local_model) { + virtual int LoadCheckPoint(Serializable *global_model, + Serializable *local_model) { tsum_allreduce = 0.0; time_checkpoint = utils::GetTime(); if (force_local == 0) { @@ -70,8 +70,8 @@ class AllreduceMock : public AllreduceRobust { return AllreduceRobust::LoadCheckPoint(&dum, &com); } } - virtual void CheckPoint(const ISerializable *global_model, - const ISerializable *local_model) { + virtual void CheckPoint(const Serializable *global_model, + const Serializable *local_model) { this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "CheckPoint"); double tstart = utils::GetTime(); double tbet_chkpt = tstart - time_checkpoint; @@ -96,7 +96,7 @@ class AllreduceMock : public AllreduceRobust { tsum_allreduce = 0.0; } - virtual void LazyCheckPoint(const ISerializable *global_model) { + virtual void LazyCheckPoint(const Serializable *global_model) { this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "LazyCheckPoint"); AllreduceRobust::LazyCheckPoint(global_model); } @@ -110,28 +110,28 @@ class AllreduceMock : public AllreduceRobust { double time_checkpoint; private: - struct DummySerializer : public ISerializable { - virtual void Load(IStream &fi) { + struct DummySerializer : public Serializable { + virtual void Load(Stream *fi) { } - virtual void Save(IStream &fo) const { + virtual void Save(Stream *fo) const { } }; - struct ComboSerializer : public ISerializable { - ISerializable *lhs; - ISerializable *rhs; - const ISerializable *c_lhs; - const ISerializable *c_rhs; - ComboSerializer(ISerializable *lhs, ISerializable *rhs) + struct ComboSerializer : public Serializable { + Serializable *lhs; + Serializable *rhs; + const Serializable *c_lhs; + const Serializable *c_rhs; + ComboSerializer(Serializable *lhs, Serializable *rhs) : lhs(lhs), rhs(rhs), c_lhs(lhs), c_rhs(rhs) { } - ComboSerializer(const ISerializable *lhs, const ISerializable *rhs) + ComboSerializer(const Serializable *lhs, const Serializable *rhs) : lhs(NULL), rhs(NULL), c_lhs(lhs), c_rhs(rhs) { } - virtual void Load(IStream &fi) { + virtual void Load(Stream *fi) { if (lhs != NULL) lhs->Load(fi); if (rhs != NULL) rhs->Load(fi); } - virtual void Save(IStream &fo) const { + virtual void Save(Stream *fo) const { if (c_lhs != NULL) c_lhs->Save(fo); if (c_rhs != NULL) c_rhs->Save(fo); } @@ -173,4 +173,4 @@ class AllreduceMock : public AllreduceRobust { }; } // namespace engine } // namespace rabit -#endif // RABIT_ALLREDUCE_MOCK_H +#endif // RABIT_ALLREDUCE_MOCK_H_ diff --git a/src/allreduce_robust.cc b/src/allreduce_robust.cc index 84abaceba..339603498 100644 --- a/src/allreduce_robust.cc +++ b/src/allreduce_robust.cc @@ -158,8 +158,8 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root) * * \sa CheckPoint, VersionNumber */ -int AllreduceRobust::LoadCheckPoint(ISerializable *global_model, - ISerializable *local_model) { +int AllreduceRobust::LoadCheckPoint(Serializable *global_model, + Serializable *local_model) { // skip action in single node if (world_size == 1) return 0; this->LocalModelCheck(local_model != NULL); @@ -175,7 +175,7 @@ int AllreduceRobust::LoadCheckPoint(ISerializable *global_model, // load in local model utils::MemoryFixSizeBuffer fs(BeginPtr(local_chkpt[local_chkpt_version]), local_rptr[local_chkpt_version][1]); - local_model->Load(fs); + local_model->Load(&fs); } else { utils::Assert(nlocal == 0, "[%d] local model inconsistent, nlocal=%d", rank, nlocal); } @@ -189,7 +189,7 @@ int AllreduceRobust::LoadCheckPoint(ISerializable *global_model, } else { utils::Assert(fs.Read(&version_number, sizeof(version_number)) != 0, "read in version number"); - global_model->Load(fs); + global_model->Load(&fs); utils::Assert(local_model == NULL || nlocal == num_local_replica + 1, "local model inconsistent, nlocal=%d", nlocal); } @@ -241,8 +241,8 @@ void AllreduceRobust::LocalModelCheck(bool with_local) { * * \sa CheckPoint, LazyCheckPoint */ -void AllreduceRobust::CheckPoint_(const ISerializable *global_model, - const ISerializable *local_model, +void AllreduceRobust::CheckPoint_(const Serializable *global_model, + const Serializable *local_model, bool lazy_checkpt) { // never do check point in single machine mode if (world_size == 1) { @@ -261,7 +261,7 @@ void AllreduceRobust::CheckPoint_(const ISerializable *global_model, local_chkpt[new_version].clear(); utils::MemoryBufferStream fs(&local_chkpt[new_version]); if (local_model != NULL) { - local_model->Save(fs); + local_model->Save(&fs); } local_rptr[new_version].clear(); local_rptr[new_version].push_back(0); @@ -287,7 +287,7 @@ void AllreduceRobust::CheckPoint_(const ISerializable *global_model, global_checkpoint.resize(0); utils::MemoryBufferStream fs(&global_checkpoint); fs.Write(&version_number, sizeof(version_number)); - global_model->Save(fs); + global_model->Save(&fs); global_lazycheck = NULL; } // reset result buffer @@ -748,7 +748,7 @@ AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) { global_checkpoint.resize(0); utils::MemoryBufferStream fs(&global_checkpoint); fs.Write(&version_number, sizeof(version_number)); - global_lazycheck->Save(fs); + global_lazycheck->Save(&fs); global_lazycheck = NULL; } // recover global checkpoint diff --git a/src/allreduce_robust.h b/src/allreduce_robust.h index 9fc97b9f4..658d6f8c7 100644 --- a/src/allreduce_robust.h +++ b/src/allreduce_robust.h @@ -80,8 +80,8 @@ class AllreduceRobust : public AllreduceBase { * * \sa CheckPoint, VersionNumber */ - virtual int LoadCheckPoint(ISerializable *global_model, - ISerializable *local_model = NULL); + virtual int LoadCheckPoint(Serializable *global_model, + Serializable *local_model = NULL); /*! * \brief checkpoint the model, meaning we finished a stage of execution * every time we call check point, there is a version number which will increase by one @@ -98,8 +98,8 @@ class AllreduceRobust : public AllreduceBase { * * \sa LoadCheckPoint, VersionNumber */ - virtual void CheckPoint(const ISerializable *global_model, - const ISerializable *local_model = NULL) { + virtual void CheckPoint(const Serializable *global_model, + const Serializable *local_model = NULL) { this->CheckPoint_(global_model, local_model, false); } /*! @@ -122,7 +122,7 @@ class AllreduceRobust : public AllreduceBase { * is the same in all nodes * \sa LoadCheckPoint, CheckPoint, VersionNumber */ - virtual void LazyCheckPoint(const ISerializable *global_model) { + virtual void LazyCheckPoint(const Serializable *global_model) { this->CheckPoint_(global_model, NULL, true); } /*! @@ -318,8 +318,8 @@ class AllreduceRobust : public AllreduceBase { * * \sa CheckPoint, LazyCheckPoint */ - void CheckPoint_(const ISerializable *global_model, - const ISerializable *local_model, + void CheckPoint_(const Serializable *global_model, + const Serializable *local_model, bool lazy_checkpt); /*! * \brief reset the all the existing links by sending Out-of-Band message marker @@ -521,7 +521,7 @@ o * the input state must exactly one saved state(local state of current node) // last check point global model std::string global_checkpoint; // lazy checkpoint of global model - const ISerializable *global_lazycheck; + const Serializable *global_lazycheck; // number of replica for local state/model int num_local_replica; // number of default local replica diff --git a/src/engine_mpi.cc b/src/engine_mpi.cc index c434e71d0..5c8a4c372 100644 --- a/src/engine_mpi.cc +++ b/src/engine_mpi.cc @@ -37,15 +37,15 @@ class MPIEngine : public IEngine { virtual void InitAfterException(void) { utils::Error("MPI is not fault tolerant"); } - virtual int LoadCheckPoint(ISerializable *global_model, - ISerializable *local_model = NULL) { + virtual int LoadCheckPoint(Serializable *global_model, + Serializable *local_model = NULL) { return 0; } - virtual void CheckPoint(const ISerializable *global_model, - const ISerializable *local_model = NULL) { + virtual void CheckPoint(const Serializable *global_model, + const Serializable *local_model = NULL) { version_number += 1; } - virtual void LazyCheckPoint(const ISerializable *global_model) { + virtual void LazyCheckPoint(const Serializable *global_model) { version_number += 1; } virtual int VersionNumber(void) const { diff --git a/test/lazy_recover.cc b/test/lazy_recover.cc index d20e4f994..610a20664 100644 --- a/test/lazy_recover.cc +++ b/test/lazy_recover.cc @@ -8,17 +8,17 @@ using namespace rabit; // dummy model -class Model : public rabit::ISerializable { +class Model : public rabit::Serializable { public: // iterations std::vector data; // load from stream - virtual void Load(rabit::IStream &fi) { - fi.Read(&data); + virtual void Load(rabit::Stream *fi) { + fi->Read(&data); } /*! \brief save the model to the stream */ - virtual void Save(rabit::IStream &fo) const { - fo.Write(data); + virtual void Save(rabit::Stream *fo) const { + fo->Write(data); } virtual void InitModel(size_t n) { data.clear(); diff --git a/test/local_recover.cc b/test/local_recover.cc index a601dd2d5..5162d5a2d 100644 --- a/test/local_recover.cc +++ b/test/local_recover.cc @@ -9,17 +9,17 @@ using namespace rabit; // dummy model -class Model : public rabit::ISerializable { +class Model : public rabit::Serializable { public: // iterations std::vector data; // load from stream - virtual void Load(rabit::IStream &fi) { - fi.Read(&data); + virtual void Load(rabit::Stream *fi) { + fi->Read(&data); } /*! \brief save the model to the stream */ - virtual void Save(rabit::IStream &fo) const { - fo.Write(data); + virtual void Save(rabit::Stream *fo) const { + fo->Write(data); } virtual void InitModel(size_t n, float v) { data.clear(); diff --git a/test/model_recover.cc b/test/model_recover.cc index 24012b91f..f833ef295 100644 --- a/test/model_recover.cc +++ b/test/model_recover.cc @@ -8,17 +8,17 @@ using namespace rabit; // dummy model -class Model : public rabit::ISerializable { +class Model : public rabit::Serializable { public: // iterations std::vector data; // load from stream - virtual void Load(rabit::IStream &fi) { - fi.Read(&data); + virtual void Load(rabit::Stream *fi) { + fi->Read(&data); } /*! \brief save the model to the stream */ - virtual void Save(rabit::IStream &fo) const { - fo.Write(data); + virtual void Save(rabit::Stream *fo) const { + fo->Write(data); } virtual void InitModel(size_t n) { data.clear(); diff --git a/wrapper/rabit_wrapper.cc b/wrapper/rabit_wrapper.cc index ac2708f00..704bf4abc 100644 --- a/wrapper/rabit_wrapper.cc +++ b/wrapper/rabit_wrapper.cc @@ -119,38 +119,38 @@ inline void Allreduce(void *sendrecvbuf, // temporal memory for global and local model std::string global_buffer, local_buffer; // wrapper for serialization -struct ReadWrapper : public ISerializable { +struct ReadWrapper : public Serializable { std::string *p_str; explicit ReadWrapper(std::string *p_str) : p_str(p_str) {} - virtual void Load(IStream &fi) { + virtual void Load(Stream *fi) { uint64_t sz; - utils::Assert(fi.Read(&sz, sizeof(sz)) != 0, + utils::Assert(fi->Read(&sz, sizeof(sz)) != 0, "Read pickle string"); p_str->resize(sz); if (sz != 0) { - utils::Assert(fi.Read(&(*p_str)[0], sizeof(char) * sz) != 0, + utils::Assert(fi->Read(&(*p_str)[0], sizeof(char) * sz) != 0, "Read pickle string"); } } - virtual void Save(IStream &fo) const { + virtual void Save(Stream *fo) const { utils::Error("not implemented"); } }; -struct WriteWrapper : public ISerializable { +struct WriteWrapper : public Serializable { const char *data; size_t length; explicit WriteWrapper(const char *data, size_t length) : data(data), length(length) { } - virtual void Load(IStream &fi) { + virtual void Load(Stream *fi) { utils::Error("not implemented"); } - virtual void Save(IStream &fo) const { + virtual void Save(Stream *fo) const { uint64_t sz = static_cast(length); - fo.Write(&sz, sizeof(sz)); - fo.Write(data, length * sizeof(char)); + fo->Write(&sz, sizeof(sz)); + fo->Write(data, length * sizeof(char)); } }; } // namespace wrapper