isolate iserializable

This commit is contained in:
tqchen 2014-12-19 17:36:42 -08:00
parent 8c35cff02c
commit 6bf282c6c2
14 changed files with 150 additions and 134 deletions

View File

@ -111,8 +111,8 @@ class AllreduceBase : public IEngine {
* *
* \sa CheckPoint, VersionNumber * \sa CheckPoint, VersionNumber
*/ */
virtual int LoadCheckPoint(utils::ISerializable *global_model, virtual int LoadCheckPoint(ISerializable *global_model,
utils::ISerializable *local_model = NULL) { ISerializable *local_model = NULL) {
return 0; return 0;
} }
/*! /*!
@ -131,8 +131,8 @@ class AllreduceBase : public IEngine {
* *
* \sa LoadCheckPoint, VersionNumber * \sa LoadCheckPoint, VersionNumber
*/ */
virtual void CheckPoint(const utils::ISerializable *global_model, virtual void CheckPoint(const ISerializable *global_model,
const utils::ISerializable *local_model = NULL) { const ISerializable *local_model = NULL) {
version_number += 1; version_number += 1;
} }
/*! /*!

View File

@ -11,8 +11,8 @@
#include <utility> #include <utility>
#include "./io.h" #include "./io.h"
#include "./utils.h" #include "./utils.h"
#include "./rabit.h"
#include "./allreduce_robust.h" #include "./allreduce_robust.h"
#include "./rabit.h"
namespace rabit { namespace rabit {
namespace engine { namespace engine {
@ -141,8 +141,8 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root)
* *
* \sa CheckPoint, VersionNumber * \sa CheckPoint, VersionNumber
*/ */
int AllreduceRobust::LoadCheckPoint(utils::ISerializable *global_model, int AllreduceRobust::LoadCheckPoint(ISerializable *global_model,
utils::ISerializable *local_model) { ISerializable *local_model) {
if (num_local_replica == 0) { if (num_local_replica == 0) {
utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model"); utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model");
} }
@ -198,8 +198,8 @@ int AllreduceRobust::LoadCheckPoint(utils::ISerializable *global_model,
* *
* \sa LoadCheckPoint, VersionNumber * \sa LoadCheckPoint, VersionNumber
*/ */
void AllreduceRobust::CheckPoint(const utils::ISerializable *global_model, void AllreduceRobust::CheckPoint(const ISerializable *global_model,
const utils::ISerializable *local_model) { const ISerializable *local_model) {
if (num_local_replica == 0) { if (num_local_replica == 0) {
utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model"); utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model");
} }

View File

@ -75,8 +75,8 @@ class AllreduceRobust : public AllreduceBase {
* *
* \sa CheckPoint, VersionNumber * \sa CheckPoint, VersionNumber
*/ */
virtual int LoadCheckPoint(utils::ISerializable *global_model, virtual int LoadCheckPoint(ISerializable *global_model,
utils::ISerializable *local_model = NULL); ISerializable *local_model = NULL);
/*! /*!
* \brief checkpoint the model, meaning we finished a stage of execution * \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one * every time we call check point, there is a version number which will increase by one
@ -93,8 +93,8 @@ class AllreduceRobust : public AllreduceBase {
* *
* \sa LoadCheckPoint, VersionNumber * \sa LoadCheckPoint, VersionNumber
*/ */
virtual void CheckPoint(const utils::ISerializable *global_model, virtual void CheckPoint(const ISerializable *global_model,
const utils::ISerializable *local_model = NULL); const ISerializable *local_model = NULL);
/*! /*!
* \brief explicitly re-init everything before calling LoadCheckPoint * \brief explicitly re-init everything before calling LoadCheckPoint
* call this function when IEngine throw an exception out, * call this function when IEngine throw an exception out,

View File

@ -5,7 +5,7 @@
*/ */
#ifndef RABIT_ENGINE_H #ifndef RABIT_ENGINE_H
#define RABIT_ENGINE_H #define RABIT_ENGINE_H
#include "./io.h" #include "./serializable.h"
namespace MPI { namespace MPI {
/*! \brief MPI data type just to be compatible with MPI reduce function*/ /*! \brief MPI data type just to be compatible with MPI reduce function*/
@ -92,8 +92,8 @@ class IEngine {
* *
* \sa CheckPoint, VersionNumber * \sa CheckPoint, VersionNumber
*/ */
virtual int LoadCheckPoint(utils::ISerializable *global_model, virtual int LoadCheckPoint(ISerializable *global_model,
utils::ISerializable *local_model = NULL) = 0; ISerializable *local_model = NULL) = 0;
/*! /*!
* \brief checkpoint the model, meaning we finished a stage of execution * \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one * every time we call check point, there is a version number which will increase by one
@ -110,8 +110,8 @@ class IEngine {
* *
* \sa LoadCheckPoint, VersionNumber * \sa LoadCheckPoint, VersionNumber
*/ */
virtual void CheckPoint(const utils::ISerializable *global_model, virtual void CheckPoint(const ISerializable *global_model,
const utils::ISerializable *local_model = NULL) = 0; const ISerializable *local_model = NULL) = 0;
/*! /*!
* \return version number of current stored model, * \return version number of current stored model,
* which means how many calls to CheckPoint we made so far * which means how many calls to CheckPoint we made so far

View File

@ -34,12 +34,12 @@ class MPIEngine : public IEngine {
virtual void InitAfterException(void) { virtual void InitAfterException(void) {
utils::Error("MPI is not fault tolerant"); utils::Error("MPI is not fault tolerant");
} }
virtual int LoadCheckPoint(utils::ISerializable *global_model, virtual int LoadCheckPoint(ISerializable *global_model,
utils::ISerializable *local_model = NULL) { ISerializable *local_model = NULL) {
return 0; return 0;
} }
virtual void CheckPoint(const utils::ISerializable *global_model, virtual void CheckPoint(const ISerializable *global_model,
const utils::ISerializable *local_model = NULL) { const ISerializable *local_model = NULL) {
version_number += 1; version_number += 1;
} }
virtual int VersionNumber(void) const { virtual int VersionNumber(void) const {

View File

@ -5,99 +5,14 @@
#include <cstring> #include <cstring>
#include <string> #include <string>
#include "./utils.h" #include "./utils.h"
#include "./serializable.h"
/*! /*!
* \file io.h * \file io.h
* \brief general stream interface for serialization, I/O * \brief utilities that implements different serializable interface
* \author Tianqi Chen * \author Tianqi Chen
*/ */
namespace rabit { namespace rabit {
namespace utils { namespace utils {
/*!
* \brief interface of stream I/O, used to serialize model
*/
class IStream {
public:
/*!
* \brief read data from stream
* \param ptr pointer to memory buffer
* \param size size of block
* \return usually is the size of data readed
*/
virtual size_t Read(void *ptr, size_t size) = 0;
/*!
* \brief write data to stream
* \param ptr pointer to memory buffer
* \param size size of block
*/
virtual void Write(const void *ptr, size_t size) = 0;
/*! \brief virtual destructor */
virtual ~IStream(void) {}
public:
// helper functions to write various of data structures
/*!
* \brief binary serialize a vector
* \param vec vector to be serialized
*/
template<typename T>
inline void Write(const std::vector<T> &vec) {
uint64_t sz = static_cast<uint64_t>(vec.size());
this->Write(&sz, sizeof(sz));
if (sz != 0) {
this->Write(&vec[0], sizeof(T) * sz);
}
}
/*!
* \brief binary load a vector
* \param out_vec vector to be loaded
* \return whether load is successfull
*/
template<typename T>
inline bool Read(std::vector<T> *out_vec) {
uint64_t sz;
if (this->Read(&sz, sizeof(sz)) == 0) return false;
out_vec->resize(sz);
if (sz != 0) {
if (this->Read(&(*out_vec)[0], sizeof(T) * sz) == 0) return false;
}
return true;
}
/*!
* \brief binary serialize a string
* \param str the string to be serialized
*/
inline void Write(const std::string &str) {
uint64_t sz = static_cast<uint64_t>(str.length());
this->Write(&sz, sizeof(sz));
if (sz != 0) {
this->Write(&str[0], sizeof(char) * sz);
}
}
/*!
* \brief binary load a string
* \param out_str string to be loaded
* \return whether load is successful
*/
inline bool Read(std::string *out_str) {
uint64_t sz;
if (this->Read(&sz, sizeof(sz)) == 0) return false;
out_str->resize(sz);
if (sz != 0) {
if (this->Read(&(*out_str)[0], sizeof(char) * sz) == 0) return false;
}
return true;
}
};
/*! \brief interface of se*/
class ISerializable {
public:
/*! \brief load the model from file */
virtual void Load(IStream &fi) = 0;
/*! \brief save the model to the stream*/
virtual void Save(IStream &fo) const = 0;
};
/*! \brief interface of i/o stream that support seek */ /*! \brief interface of i/o stream that support seek */
class ISeekStream: public IStream { class ISeekStream: public IStream {
public: public:

View File

@ -6,6 +6,8 @@
*/ */
#ifndef RABIT_RABIT_INL_H #ifndef RABIT_RABIT_INL_H
#define RABIT_RABIT_INL_H #define RABIT_RABIT_INL_H
// use engine for implementation
#include "./engine.h"
namespace rabit { namespace rabit {
namespace engine { namespace engine {
@ -139,13 +141,13 @@ inline void Allreduce(DType *sendrecvbuf, size_t count, std::function<void()> pr
#endif // C++11 #endif // C++11
// load latest check point // load latest check point
inline int LoadCheckPoint(utils::ISerializable *global_model, inline int LoadCheckPoint(ISerializable *global_model,
utils::ISerializable *local_model) { ISerializable *local_model) {
return engine::GetEngine()->LoadCheckPoint(global_model, local_model); return engine::GetEngine()->LoadCheckPoint(global_model, local_model);
} }
// checkpoint the model, meaning we finished a stage of execution // checkpoint the model, meaning we finished a stage of execution
inline void CheckPoint(const utils::ISerializable *global_model, inline void CheckPoint(const ISerializable *global_model,
const utils::ISerializable *local_model) { const ISerializable *local_model) {
engine::GetEngine()->CheckPoint(global_model, local_model); engine::GetEngine()->CheckPoint(global_model, local_model);
} }
// return the version number of currently stored model // return the version number of currently stored model

View File

@ -6,6 +6,7 @@
* The actual implementation is redirected to rabit engine * The actual implementation is redirected to rabit engine
* Code only using this header can also compiled with MPI Allreduce(with no fault recovery), * Code only using this header can also compiled with MPI Allreduce(with no fault recovery),
* *
* rabit.h and serializable.h is all the user need to use rabit interface
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/ */
#include <string> #include <string>
@ -14,9 +15,8 @@
#if __cplusplus >= 201103L #if __cplusplus >= 201103L
#include <functional> #include <functional>
#endif // C++11 #endif // C++11
// rabit headers // contains definition of ISerializable
#include "./io.h" #include "./serializable.h"
#include "./engine.h"
/*! \brief namespace of rabit */ /*! \brief namespace of rabit */
namespace rabit { namespace rabit {
@ -31,7 +31,6 @@ struct Sum;
/*! \brief perform bitwise OR */ /*! \brief perform bitwise OR */
struct BitOR; struct BitOR;
} // namespace op } // namespace op
/*! /*!
* \brief intialize the rabit module, call this once function before using anything * \brief intialize the rabit module, call this once function before using anything
* \param argc number of arguments in argv * \param argc number of arguments in argv
@ -143,8 +142,8 @@ inline void Allreduce(DType *sendrecvbuf, size_t count, std::function<void()> pr
* *
* \sa CheckPoint, VersionNumber * \sa CheckPoint, VersionNumber
*/ */
inline int LoadCheckPoint(utils::ISerializable *global_model, inline int LoadCheckPoint(ISerializable *global_model,
utils::ISerializable *local_model = NULL); ISerializable *local_model = NULL);
/*! /*!
* \brief checkpoint the model, meaning we finished a stage of execution * \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one * every time we call check point, there is a version number which will increase by one
@ -159,8 +158,8 @@ inline int LoadCheckPoint(utils::ISerializable *global_model,
* So only CheckPoint with global_model if possible * So only CheckPoint with global_model if possible
* \sa LoadCheckPoint, VersionNumber * \sa LoadCheckPoint, VersionNumber
*/ */
inline void CheckPoint(const utils::ISerializable *global_model, inline void CheckPoint(const ISerializable *global_model,
const utils::ISerializable *local_model = NULL); const ISerializable *local_model = NULL);
/*! /*!
* \return version number of current stored model, * \return version number of current stored model,
* which means how many calls to CheckPoint we made so far * which means how many calls to CheckPoint we made so far

99
src/serializable.h Normal file
View File

@ -0,0 +1,99 @@
#ifndef RABIT_SERIALIZABLE_H
#define RABIT_SERIALIZABLE_H
#include <vector>
#include <string>
#include "./utils.h"
/*!
* \file serializable.h
* \brief defines serializable interface of rabit
* \author Tianqi Chen
*/
namespace rabit {
/*!
* \brief interface of stream I/O, used by ISerializable
* \sa ISerializable
*/
class IStream {
public:
/*!
* \brief read data from stream
* \param ptr pointer to memory buffer
* \param size size of block
* \return usually is the size of data readed
*/
virtual size_t Read(void *ptr, size_t size) = 0;
/*!
* \brief write data to stream
* \param ptr pointer to memory buffer
* \param size size of block
*/
virtual void Write(const void *ptr, size_t size) = 0;
/*! \brief virtual destructor */
virtual ~IStream(void) {}
public:
// helper functions to write various of data structures
/*!
* \brief binary serialize a vector
* \param vec vector to be serialized
*/
template<typename T>
inline void Write(const std::vector<T> &vec) {
uint64_t sz = static_cast<uint64_t>(vec.size());
this->Write(&sz, sizeof(sz));
if (sz != 0) {
this->Write(&vec[0], sizeof(T) * sz);
}
}
/*!
* \brief binary load a vector
* \param out_vec vector to be loaded
* \return whether load is successfull
*/
template<typename T>
inline bool Read(std::vector<T> *out_vec) {
uint64_t sz;
if (this->Read(&sz, sizeof(sz)) == 0) return false;
out_vec->resize(sz);
if (sz != 0) {
if (this->Read(&(*out_vec)[0], sizeof(T) * sz) == 0) return false;
}
return true;
}
/*!
* \brief binary serialize a string
* \param str the string to be serialized
*/
inline void Write(const std::string &str) {
uint64_t sz = static_cast<uint64_t>(str.length());
this->Write(&sz, sizeof(sz));
if (sz != 0) {
this->Write(&str[0], sizeof(char) * sz);
}
}
/*!
* \brief binary load a string
* \param out_str string to be loaded
* \return whether load is successful
*/
inline bool Read(std::string *out_str) {
uint64_t sz;
if (this->Read(&sz, sizeof(sz)) == 0) return false;
out_str->resize(sz);
if (sz != 0) {
if (this->Read(&(*out_str)[0], sizeof(char) * sz) == 0) return false;
}
return true;
}
};
/*! \brief interface of se*/
class ISerializable {
public:
/*! \brief load the model from file */
virtual void Load(IStream &fi) = 0;
/*! \brief save the model to the stream*/
virtual void Save(IStream &fo) const = 0;
};
} // namespace rabit
#endif

View File

@ -33,14 +33,14 @@ public:
rabit::Allreduce<OP>(sendrecvbuf, count); rabit::Allreduce<OP>(sendrecvbuf, count);
} }
inline int LoadCheckPoint(utils::ISerializable *global_model, inline int LoadCheckPoint(ISerializable *global_model,
utils::ISerializable *local_model) { ISerializable *local_model) {
utils::Assert(verify(loadCheckpoint), "[%d] error when loading checkpoint", rank); utils::Assert(verify(loadCheckpoint), "[%d] error when loading checkpoint", rank);
return rabit::LoadCheckPoint(global_model, local_model); return rabit::LoadCheckPoint(global_model, local_model);
} }
inline void CheckPoint(const utils::ISerializable *global_model, inline void CheckPoint(const ISerializable *global_model,
const utils::ISerializable *local_model) { const ISerializable *local_model) {
utils::Assert(verify(checkpoint), "[%d] error when checkpointing", rank); utils::Assert(verify(checkpoint), "[%d] error when checkpointing", rank);
rabit::CheckPoint(global_model, local_model); rabit::CheckPoint(global_model, local_model);
} }

View File

@ -29,16 +29,16 @@ inline void CallEnd(const char *fun, int ntrial, int iter) {
} }
// dummy model // dummy model
class Model : public rabit::utils::ISerializable { class Model : public rabit::ISerializable {
public: public:
// iterations // iterations
std::vector<float> data; std::vector<float> data;
// load from stream // load from stream
virtual void Load(rabit::utils::IStream &fi) { virtual void Load(rabit::IStream &fi) {
fi.Read(&data); fi.Read(&data);
} }
/*! \brief save the model to the stream */ /*! \brief save the model to the stream */
virtual void Save(rabit::utils::IStream &fo) const { virtual void Save(rabit::IStream &fo) const {
fo.Write(data); fo.Write(data);
} }
virtual void InitModel(size_t n, float v) { virtual void InitModel(size_t n, float v) {

View File

@ -29,16 +29,16 @@ inline void CallEnd(const char *fun, int ntrial, int iter) {
} }
// dummy model // dummy model
class Model : public rabit::utils::ISerializable { class Model : public rabit::ISerializable {
public: public:
// iterations // iterations
std::vector<float> data; std::vector<float> data;
// load from stream // load from stream
virtual void Load(rabit::utils::IStream &fi) { virtual void Load(rabit::IStream &fi) {
fi.Read(&data); fi.Read(&data);
} }
/*! \brief save the model to the stream */ /*! \brief save the model to the stream */
virtual void Save(rabit::utils::IStream &fo) const { virtual void Save(rabit::IStream &fo) const {
fo.Write(data); fo.Write(data);
} }
virtual void InitModel(size_t n) { virtual void InitModel(size_t n) {

View File

@ -8,18 +8,18 @@
using namespace rabit; using namespace rabit;
// kmeans model // kmeans model
class Model : public rabit::utils::ISerializable { class Model : public rabit::ISerializable {
public: public:
// matrix of centroids // matrix of centroids
Matrix centroids; Matrix centroids;
// load from stream // load from stream
virtual void Load(rabit::utils::IStream &fi) { virtual void Load(rabit::IStream &fi) {
fi.Read(&centroids.nrow, sizeof(centroids.nrow)); fi.Read(&centroids.nrow, sizeof(centroids.nrow));
fi.Read(&centroids.ncol, sizeof(centroids.ncol)); fi.Read(&centroids.ncol, sizeof(centroids.ncol));
fi.Read(&centroids.data); fi.Read(&centroids.data);
} }
/*! \brief save the model to the stream */ /*! \brief save the model to the stream */
virtual void Save(rabit::utils::IStream &fo) const { virtual void Save(rabit::IStream &fo) const {
fo.Write(&centroids.nrow, sizeof(centroids.nrow)); fo.Write(&centroids.nrow, sizeof(centroids.nrow));
fo.Write(&centroids.ncol, sizeof(centroids.ncol)); fo.Write(&centroids.ncol, sizeof(centroids.ncol));
fo.Write(centroids.data); fo.Write(centroids.data);

View File

@ -2,6 +2,7 @@
#include <vector> #include <vector>
#include <cstdlib> #include <cstdlib>
#include <cstdio> #include <cstdio>
#include <cstring>
#include <cmath> #include <cmath>
namespace rabit { namespace rabit {