diff --git a/src/allreduce_base.h b/src/allreduce_base.h index bc7cc26c9..e313cab88 100644 --- a/src/allreduce_base.h +++ b/src/allreduce_base.h @@ -111,8 +111,8 @@ class AllreduceBase : public IEngine { * * \sa CheckPoint, VersionNumber */ - virtual int LoadCheckPoint(utils::ISerializable *global_model, - utils::ISerializable *local_model = NULL) { + virtual int LoadCheckPoint(ISerializable *global_model, + ISerializable *local_model = NULL) { return 0; } /*! @@ -131,8 +131,8 @@ class AllreduceBase : public IEngine { * * \sa LoadCheckPoint, VersionNumber */ - virtual void CheckPoint(const utils::ISerializable *global_model, - const utils::ISerializable *local_model = NULL) { + virtual void CheckPoint(const ISerializable *global_model, + const ISerializable *local_model = NULL) { version_number += 1; } /*! diff --git a/src/allreduce_robust.cc b/src/allreduce_robust.cc index 7d339cf84..538609c62 100644 --- a/src/allreduce_robust.cc +++ b/src/allreduce_robust.cc @@ -11,8 +11,8 @@ #include #include "./io.h" #include "./utils.h" -#include "./rabit.h" #include "./allreduce_robust.h" +#include "./rabit.h" namespace rabit { namespace engine { @@ -141,8 +141,8 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root) * * \sa CheckPoint, VersionNumber */ -int AllreduceRobust::LoadCheckPoint(utils::ISerializable *global_model, - utils::ISerializable *local_model) { +int AllreduceRobust::LoadCheckPoint(ISerializable *global_model, + ISerializable *local_model) { if (num_local_replica == 0) { utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model"); } @@ -198,8 +198,8 @@ int AllreduceRobust::LoadCheckPoint(utils::ISerializable *global_model, * * \sa LoadCheckPoint, VersionNumber */ -void AllreduceRobust::CheckPoint(const utils::ISerializable *global_model, - const utils::ISerializable *local_model) { +void AllreduceRobust::CheckPoint(const ISerializable *global_model, + const ISerializable *local_model) { if (num_local_replica == 0) { utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model"); } diff --git a/src/allreduce_robust.h b/src/allreduce_robust.h index 7888a66f1..fd85e4828 100644 --- a/src/allreduce_robust.h +++ b/src/allreduce_robust.h @@ -75,8 +75,8 @@ class AllreduceRobust : public AllreduceBase { * * \sa CheckPoint, VersionNumber */ - virtual int LoadCheckPoint(utils::ISerializable *global_model, - utils::ISerializable *local_model = NULL); + virtual int LoadCheckPoint(ISerializable *global_model, + ISerializable *local_model = NULL); /*! * \brief checkpoint the model, meaning we finished a stage of execution * every time we call check point, there is a version number which will increase by one @@ -93,8 +93,8 @@ class AllreduceRobust : public AllreduceBase { * * \sa LoadCheckPoint, VersionNumber */ - virtual void CheckPoint(const utils::ISerializable *global_model, - const utils::ISerializable *local_model = NULL); + virtual void CheckPoint(const ISerializable *global_model, + const ISerializable *local_model = NULL); /*! * \brief explicitly re-init everything before calling LoadCheckPoint * call this function when IEngine throw an exception out, diff --git a/src/engine.h b/src/engine.h index 977b0d6ff..0700b2a95 100644 --- a/src/engine.h +++ b/src/engine.h @@ -5,7 +5,7 @@ */ #ifndef RABIT_ENGINE_H #define RABIT_ENGINE_H -#include "./io.h" +#include "./serializable.h" namespace MPI { /*! \brief MPI data type just to be compatible with MPI reduce function*/ @@ -92,8 +92,8 @@ class IEngine { * * \sa CheckPoint, VersionNumber */ - virtual int LoadCheckPoint(utils::ISerializable *global_model, - utils::ISerializable *local_model = NULL) = 0; + virtual int LoadCheckPoint(ISerializable *global_model, + ISerializable *local_model = NULL) = 0; /*! * \brief checkpoint the model, meaning we finished a stage of execution * every time we call check point, there is a version number which will increase by one @@ -110,8 +110,8 @@ class IEngine { * * \sa LoadCheckPoint, VersionNumber */ - virtual void CheckPoint(const utils::ISerializable *global_model, - const utils::ISerializable *local_model = NULL) = 0; + virtual void CheckPoint(const ISerializable *global_model, + const ISerializable *local_model = NULL) = 0; /*! * \return version number of current stored model, * which means how many calls to CheckPoint we made so far diff --git a/src/engine_mpi.cc b/src/engine_mpi.cc index 9e5972e1a..870c93fdb 100644 --- a/src/engine_mpi.cc +++ b/src/engine_mpi.cc @@ -34,12 +34,12 @@ class MPIEngine : public IEngine { virtual void InitAfterException(void) { utils::Error("MPI is not fault tolerant"); } - virtual int LoadCheckPoint(utils::ISerializable *global_model, - utils::ISerializable *local_model = NULL) { + virtual int LoadCheckPoint(ISerializable *global_model, + ISerializable *local_model = NULL) { return 0; } - virtual void CheckPoint(const utils::ISerializable *global_model, - const utils::ISerializable *local_model = NULL) { + virtual void CheckPoint(const ISerializable *global_model, + const ISerializable *local_model = NULL) { version_number += 1; } virtual int VersionNumber(void) const { diff --git a/src/io.h b/src/io.h index ed01545f2..699a93c83 100644 --- a/src/io.h +++ b/src/io.h @@ -5,99 +5,14 @@ #include #include #include "./utils.h" +#include "./serializable.h" /*! * \file io.h - * \brief general stream interface for serialization, I/O + * \brief utilities that implements different serializable interface * \author Tianqi Chen */ namespace rabit { namespace utils { -/*! - * \brief interface of stream I/O, used to serialize model - */ -class IStream { - public: - /*! - * \brief read data from stream - * \param ptr pointer to memory buffer - * \param size size of block - * \return usually is the size of data readed - */ - virtual size_t Read(void *ptr, size_t size) = 0; - /*! - * \brief write data to stream - * \param ptr pointer to memory buffer - * \param size size of block - */ - virtual void Write(const void *ptr, size_t size) = 0; - /*! \brief virtual destructor */ - virtual ~IStream(void) {} - - public: - // helper functions to write various of data structures - /*! - * \brief binary serialize a vector - * \param vec vector to be serialized - */ - template - inline void Write(const std::vector &vec) { - uint64_t sz = static_cast(vec.size()); - this->Write(&sz, sizeof(sz)); - if (sz != 0) { - this->Write(&vec[0], sizeof(T) * sz); - } - } - /*! - * \brief binary load a vector - * \param out_vec vector to be loaded - * \return whether load is successfull - */ - template - inline bool Read(std::vector *out_vec) { - uint64_t sz; - if (this->Read(&sz, sizeof(sz)) == 0) return false; - out_vec->resize(sz); - if (sz != 0) { - if (this->Read(&(*out_vec)[0], sizeof(T) * sz) == 0) return false; - } - return true; - } - /*! - * \brief binary serialize a string - * \param str the string to be serialized - */ - inline void Write(const std::string &str) { - uint64_t sz = static_cast(str.length()); - this->Write(&sz, sizeof(sz)); - if (sz != 0) { - this->Write(&str[0], sizeof(char) * sz); - } - } - /*! - * \brief binary load a string - * \param out_str string to be loaded - * \return whether load is successful - */ - inline bool Read(std::string *out_str) { - uint64_t sz; - if (this->Read(&sz, sizeof(sz)) == 0) return false; - out_str->resize(sz); - if (sz != 0) { - if (this->Read(&(*out_str)[0], sizeof(char) * sz) == 0) return false; - } - return true; - } -}; - -/*! \brief interface of se*/ -class ISerializable { - public: - /*! \brief load the model from file */ - virtual void Load(IStream &fi) = 0; - /*! \brief save the model to the stream*/ - virtual void Save(IStream &fo) const = 0; -}; - /*! \brief interface of i/o stream that support seek */ class ISeekStream: public IStream { public: diff --git a/src/rabit-inl.h b/src/rabit-inl.h index 8d379f920..b6126f47d 100644 --- a/src/rabit-inl.h +++ b/src/rabit-inl.h @@ -6,6 +6,8 @@ */ #ifndef RABIT_RABIT_INL_H #define RABIT_RABIT_INL_H +// use engine for implementation +#include "./engine.h" namespace rabit { namespace engine { @@ -139,13 +141,13 @@ inline void Allreduce(DType *sendrecvbuf, size_t count, std::function pr #endif // C++11 // load latest check point -inline int LoadCheckPoint(utils::ISerializable *global_model, - utils::ISerializable *local_model) { +inline int LoadCheckPoint(ISerializable *global_model, + ISerializable *local_model) { return engine::GetEngine()->LoadCheckPoint(global_model, local_model); } // checkpoint the model, meaning we finished a stage of execution -inline void CheckPoint(const utils::ISerializable *global_model, - const utils::ISerializable *local_model) { +inline void CheckPoint(const ISerializable *global_model, + const ISerializable *local_model) { engine::GetEngine()->CheckPoint(global_model, local_model); } // return the version number of currently stored model diff --git a/src/rabit.h b/src/rabit.h index ac17faec6..cc65e62ae 100644 --- a/src/rabit.h +++ b/src/rabit.h @@ -6,6 +6,7 @@ * The actual implementation is redirected to rabit engine * Code only using this header can also compiled with MPI Allreduce(with no fault recovery), * + * rabit.h and serializable.h is all the user need to use rabit interface * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou */ #include @@ -14,9 +15,8 @@ #if __cplusplus >= 201103L #include #endif // C++11 -// rabit headers -#include "./io.h" -#include "./engine.h" +// contains definition of ISerializable +#include "./serializable.h" /*! \brief namespace of rabit */ namespace rabit { @@ -31,7 +31,6 @@ struct Sum; /*! \brief perform bitwise OR */ struct BitOR; } // namespace op - /*! * \brief intialize the rabit module, call this once function before using anything * \param argc number of arguments in argv @@ -143,8 +142,8 @@ inline void Allreduce(DType *sendrecvbuf, size_t count, std::function pr * * \sa CheckPoint, VersionNumber */ -inline int LoadCheckPoint(utils::ISerializable *global_model, - utils::ISerializable *local_model = NULL); +inline int LoadCheckPoint(ISerializable *global_model, + ISerializable *local_model = NULL); /*! * \brief checkpoint the model, meaning we finished a stage of execution * every time we call check point, there is a version number which will increase by one @@ -159,8 +158,8 @@ inline int LoadCheckPoint(utils::ISerializable *global_model, * So only CheckPoint with global_model if possible * \sa LoadCheckPoint, VersionNumber */ -inline void CheckPoint(const utils::ISerializable *global_model, - const utils::ISerializable *local_model = NULL); +inline void CheckPoint(const ISerializable *global_model, + const ISerializable *local_model = NULL); /*! * \return version number of current stored model, * which means how many calls to CheckPoint we made so far diff --git a/src/serializable.h b/src/serializable.h new file mode 100644 index 000000000..a269dc1c7 --- /dev/null +++ b/src/serializable.h @@ -0,0 +1,99 @@ +#ifndef RABIT_SERIALIZABLE_H +#define RABIT_SERIALIZABLE_H +#include +#include +#include "./utils.h" +/*! + * \file serializable.h + * \brief defines serializable interface of rabit + * \author Tianqi Chen + */ +namespace rabit { +/*! + * \brief interface of stream I/O, used by ISerializable + * \sa ISerializable + */ +class IStream { + public: + /*! + * \brief read data from stream + * \param ptr pointer to memory buffer + * \param size size of block + * \return usually is the size of data readed + */ + virtual size_t Read(void *ptr, size_t size) = 0; + /*! + * \brief write data to stream + * \param ptr pointer to memory buffer + * \param size size of block + */ + virtual void Write(const void *ptr, size_t size) = 0; + /*! \brief virtual destructor */ + virtual ~IStream(void) {} + + public: + // helper functions to write various of data structures + /*! + * \brief binary serialize a vector + * \param vec vector to be serialized + */ + template + inline void Write(const std::vector &vec) { + uint64_t sz = static_cast(vec.size()); + this->Write(&sz, sizeof(sz)); + if (sz != 0) { + this->Write(&vec[0], sizeof(T) * sz); + } + } + /*! + * \brief binary load a vector + * \param out_vec vector to be loaded + * \return whether load is successfull + */ + template + inline bool Read(std::vector *out_vec) { + uint64_t sz; + if (this->Read(&sz, sizeof(sz)) == 0) return false; + out_vec->resize(sz); + if (sz != 0) { + if (this->Read(&(*out_vec)[0], sizeof(T) * sz) == 0) return false; + } + return true; + } + /*! + * \brief binary serialize a string + * \param str the string to be serialized + */ + inline void Write(const std::string &str) { + uint64_t sz = static_cast(str.length()); + this->Write(&sz, sizeof(sz)); + if (sz != 0) { + this->Write(&str[0], sizeof(char) * sz); + } + } + /*! + * \brief binary load a string + * \param out_str string to be loaded + * \return whether load is successful + */ + inline bool Read(std::string *out_str) { + uint64_t sz; + if (this->Read(&sz, sizeof(sz)) == 0) return false; + out_str->resize(sz); + if (sz != 0) { + if (this->Read(&(*out_str)[0], sizeof(char) * sz) == 0) return false; + } + return true; + } +}; + +/*! \brief interface of se*/ +class ISerializable { + public: + /*! \brief load the model from file */ + virtual void Load(IStream &fi) = 0; + /*! \brief save the model to the stream*/ + virtual void Save(IStream &fo) const = 0; +}; +} // namespace rabit +#endif diff --git a/test/mock.h b/test/mock.h index a5ac39c83..17e5b75c9 100644 --- a/test/mock.h +++ b/test/mock.h @@ -33,14 +33,14 @@ public: rabit::Allreduce(sendrecvbuf, count); } -inline int LoadCheckPoint(utils::ISerializable *global_model, - utils::ISerializable *local_model) { +inline int LoadCheckPoint(ISerializable *global_model, + ISerializable *local_model) { utils::Assert(verify(loadCheckpoint), "[%d] error when loading checkpoint", rank); return rabit::LoadCheckPoint(global_model, local_model); } - inline void CheckPoint(const utils::ISerializable *global_model, - const utils::ISerializable *local_model) { + inline void CheckPoint(const ISerializable *global_model, + const ISerializable *local_model) { utils::Assert(verify(checkpoint), "[%d] error when checkpointing", rank); rabit::CheckPoint(global_model, local_model); } diff --git a/test/test_local_recover.cpp b/test/test_local_recover.cpp index d98c6ae48..b9b84f2d1 100644 --- a/test/test_local_recover.cpp +++ b/test/test_local_recover.cpp @@ -29,16 +29,16 @@ inline void CallEnd(const char *fun, int ntrial, int iter) { } // dummy model -class Model : public rabit::utils::ISerializable { +class Model : public rabit::ISerializable { public: // iterations std::vector data; // load from stream - virtual void Load(rabit::utils::IStream &fi) { + virtual void Load(rabit::IStream &fi) { fi.Read(&data); } /*! \brief save the model to the stream */ - virtual void Save(rabit::utils::IStream &fo) const { + virtual void Save(rabit::IStream &fo) const { fo.Write(data); } virtual void InitModel(size_t n, float v) { diff --git a/test/test_model_recover.cpp b/test/test_model_recover.cpp index 6feb56dde..aba107a85 100644 --- a/test/test_model_recover.cpp +++ b/test/test_model_recover.cpp @@ -29,16 +29,16 @@ inline void CallEnd(const char *fun, int ntrial, int iter) { } // dummy model -class Model : public rabit::utils::ISerializable { +class Model : public rabit::ISerializable { public: // iterations std::vector data; // load from stream - virtual void Load(rabit::utils::IStream &fi) { + virtual void Load(rabit::IStream &fi) { fi.Read(&data); } /*! \brief save the model to the stream */ - virtual void Save(rabit::utils::IStream &fo) const { + virtual void Save(rabit::IStream &fo) const { fo.Write(data); } virtual void InitModel(size_t n) { diff --git a/toolkit/kmeans.cpp b/toolkit/kmeans.cpp index c08e50c23..bbd5067af 100644 --- a/toolkit/kmeans.cpp +++ b/toolkit/kmeans.cpp @@ -8,18 +8,18 @@ using namespace rabit; // kmeans model -class Model : public rabit::utils::ISerializable { +class Model : public rabit::ISerializable { public: // matrix of centroids Matrix centroids; // load from stream - virtual void Load(rabit::utils::IStream &fi) { + virtual void Load(rabit::IStream &fi) { fi.Read(¢roids.nrow, sizeof(centroids.nrow)); fi.Read(¢roids.ncol, sizeof(centroids.ncol)); fi.Read(¢roids.data); } /*! \brief save the model to the stream */ - virtual void Save(rabit::utils::IStream &fo) const { + virtual void Save(rabit::IStream &fo) const { fo.Write(¢roids.nrow, sizeof(centroids.nrow)); fo.Write(¢roids.ncol, sizeof(centroids.ncol)); fo.Write(centroids.data); diff --git a/toolkit/toolkit_util.h b/toolkit/toolkit_util.h index cff7b7fe0..a2f8f56ac 100644 --- a/toolkit/toolkit_util.h +++ b/toolkit/toolkit_util.h @@ -2,6 +2,7 @@ #include #include #include +#include #include namespace rabit {