isolate iserializable

This commit is contained in:
tqchen 2014-12-19 17:36:42 -08:00
parent 8c35cff02c
commit 6bf282c6c2
14 changed files with 150 additions and 134 deletions

View File

@ -111,8 +111,8 @@ class AllreduceBase : public IEngine {
*
* \sa CheckPoint, VersionNumber
*/
virtual int LoadCheckPoint(utils::ISerializable *global_model,
utils::ISerializable *local_model = NULL) {
virtual int LoadCheckPoint(ISerializable *global_model,
ISerializable *local_model = NULL) {
return 0;
}
/*!
@ -131,8 +131,8 @@ class AllreduceBase : public IEngine {
*
* \sa LoadCheckPoint, VersionNumber
*/
virtual void CheckPoint(const utils::ISerializable *global_model,
const utils::ISerializable *local_model = NULL) {
virtual void CheckPoint(const ISerializable *global_model,
const ISerializable *local_model = NULL) {
version_number += 1;
}
/*!

View File

@ -11,8 +11,8 @@
#include <utility>
#include "./io.h"
#include "./utils.h"
#include "./rabit.h"
#include "./allreduce_robust.h"
#include "./rabit.h"
namespace rabit {
namespace engine {
@ -141,8 +141,8 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root)
*
* \sa CheckPoint, VersionNumber
*/
int AllreduceRobust::LoadCheckPoint(utils::ISerializable *global_model,
utils::ISerializable *local_model) {
int AllreduceRobust::LoadCheckPoint(ISerializable *global_model,
ISerializable *local_model) {
if (num_local_replica == 0) {
utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model");
}
@ -198,8 +198,8 @@ int AllreduceRobust::LoadCheckPoint(utils::ISerializable *global_model,
*
* \sa LoadCheckPoint, VersionNumber
*/
void AllreduceRobust::CheckPoint(const utils::ISerializable *global_model,
const utils::ISerializable *local_model) {
void AllreduceRobust::CheckPoint(const ISerializable *global_model,
const ISerializable *local_model) {
if (num_local_replica == 0) {
utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model");
}

View File

@ -75,8 +75,8 @@ class AllreduceRobust : public AllreduceBase {
*
* \sa CheckPoint, VersionNumber
*/
virtual int LoadCheckPoint(utils::ISerializable *global_model,
utils::ISerializable *local_model = NULL);
virtual int LoadCheckPoint(ISerializable *global_model,
ISerializable *local_model = NULL);
/*!
* \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one
@ -93,8 +93,8 @@ class AllreduceRobust : public AllreduceBase {
*
* \sa LoadCheckPoint, VersionNumber
*/
virtual void CheckPoint(const utils::ISerializable *global_model,
const utils::ISerializable *local_model = NULL);
virtual void CheckPoint(const ISerializable *global_model,
const ISerializable *local_model = NULL);
/*!
* \brief explicitly re-init everything before calling LoadCheckPoint
* call this function when IEngine throw an exception out,

View File

@ -5,7 +5,7 @@
*/
#ifndef RABIT_ENGINE_H
#define RABIT_ENGINE_H
#include "./io.h"
#include "./serializable.h"
namespace MPI {
/*! \brief MPI data type just to be compatible with MPI reduce function*/
@ -92,8 +92,8 @@ class IEngine {
*
* \sa CheckPoint, VersionNumber
*/
virtual int LoadCheckPoint(utils::ISerializable *global_model,
utils::ISerializable *local_model = NULL) = 0;
virtual int LoadCheckPoint(ISerializable *global_model,
ISerializable *local_model = NULL) = 0;
/*!
* \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one
@ -110,8 +110,8 @@ class IEngine {
*
* \sa LoadCheckPoint, VersionNumber
*/
virtual void CheckPoint(const utils::ISerializable *global_model,
const utils::ISerializable *local_model = NULL) = 0;
virtual void CheckPoint(const ISerializable *global_model,
const ISerializable *local_model = NULL) = 0;
/*!
* \return version number of current stored model,
* which means how many calls to CheckPoint we made so far

View File

@ -34,12 +34,12 @@ class MPIEngine : public IEngine {
virtual void InitAfterException(void) {
utils::Error("MPI is not fault tolerant");
}
virtual int LoadCheckPoint(utils::ISerializable *global_model,
utils::ISerializable *local_model = NULL) {
virtual int LoadCheckPoint(ISerializable *global_model,
ISerializable *local_model = NULL) {
return 0;
}
virtual void CheckPoint(const utils::ISerializable *global_model,
const utils::ISerializable *local_model = NULL) {
virtual void CheckPoint(const ISerializable *global_model,
const ISerializable *local_model = NULL) {
version_number += 1;
}
virtual int VersionNumber(void) const {

View File

@ -5,99 +5,14 @@
#include <cstring>
#include <string>
#include "./utils.h"
#include "./serializable.h"
/*!
* \file io.h
* \brief general stream interface for serialization, I/O
* \brief utilities that implements different serializable interface
* \author Tianqi Chen
*/
namespace rabit {
namespace utils {
/*!
* \brief interface of stream I/O, used to serialize model
*/
class IStream {
public:
/*!
* \brief read data from stream
* \param ptr pointer to memory buffer
* \param size size of block
* \return usually is the size of data readed
*/
virtual size_t Read(void *ptr, size_t size) = 0;
/*!
* \brief write data to stream
* \param ptr pointer to memory buffer
* \param size size of block
*/
virtual void Write(const void *ptr, size_t size) = 0;
/*! \brief virtual destructor */
virtual ~IStream(void) {}
public:
// helper functions to write various of data structures
/*!
* \brief binary serialize a vector
* \param vec vector to be serialized
*/
template<typename T>
inline void Write(const std::vector<T> &vec) {
uint64_t sz = static_cast<uint64_t>(vec.size());
this->Write(&sz, sizeof(sz));
if (sz != 0) {
this->Write(&vec[0], sizeof(T) * sz);
}
}
/*!
* \brief binary load a vector
* \param out_vec vector to be loaded
* \return whether load is successfull
*/
template<typename T>
inline bool Read(std::vector<T> *out_vec) {
uint64_t sz;
if (this->Read(&sz, sizeof(sz)) == 0) return false;
out_vec->resize(sz);
if (sz != 0) {
if (this->Read(&(*out_vec)[0], sizeof(T) * sz) == 0) return false;
}
return true;
}
/*!
* \brief binary serialize a string
* \param str the string to be serialized
*/
inline void Write(const std::string &str) {
uint64_t sz = static_cast<uint64_t>(str.length());
this->Write(&sz, sizeof(sz));
if (sz != 0) {
this->Write(&str[0], sizeof(char) * sz);
}
}
/*!
* \brief binary load a string
* \param out_str string to be loaded
* \return whether load is successful
*/
inline bool Read(std::string *out_str) {
uint64_t sz;
if (this->Read(&sz, sizeof(sz)) == 0) return false;
out_str->resize(sz);
if (sz != 0) {
if (this->Read(&(*out_str)[0], sizeof(char) * sz) == 0) return false;
}
return true;
}
};
/*! \brief interface of se*/
class ISerializable {
public:
/*! \brief load the model from file */
virtual void Load(IStream &fi) = 0;
/*! \brief save the model to the stream*/
virtual void Save(IStream &fo) const = 0;
};
/*! \brief interface of i/o stream that support seek */
class ISeekStream: public IStream {
public:

View File

@ -6,6 +6,8 @@
*/
#ifndef RABIT_RABIT_INL_H
#define RABIT_RABIT_INL_H
// use engine for implementation
#include "./engine.h"
namespace rabit {
namespace engine {
@ -139,13 +141,13 @@ inline void Allreduce(DType *sendrecvbuf, size_t count, std::function<void()> pr
#endif // C++11
// load latest check point
inline int LoadCheckPoint(utils::ISerializable *global_model,
utils::ISerializable *local_model) {
inline int LoadCheckPoint(ISerializable *global_model,
ISerializable *local_model) {
return engine::GetEngine()->LoadCheckPoint(global_model, local_model);
}
// checkpoint the model, meaning we finished a stage of execution
inline void CheckPoint(const utils::ISerializable *global_model,
const utils::ISerializable *local_model) {
inline void CheckPoint(const ISerializable *global_model,
const ISerializable *local_model) {
engine::GetEngine()->CheckPoint(global_model, local_model);
}
// return the version number of currently stored model

View File

@ -6,6 +6,7 @@
* The actual implementation is redirected to rabit engine
* Code only using this header can also compiled with MPI Allreduce(with no fault recovery),
*
* rabit.h and serializable.h is all the user need to use rabit interface
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/
#include <string>
@ -14,9 +15,8 @@
#if __cplusplus >= 201103L
#include <functional>
#endif // C++11
// rabit headers
#include "./io.h"
#include "./engine.h"
// contains definition of ISerializable
#include "./serializable.h"
/*! \brief namespace of rabit */
namespace rabit {
@ -31,7 +31,6 @@ struct Sum;
/*! \brief perform bitwise OR */
struct BitOR;
} // namespace op
/*!
* \brief intialize the rabit module, call this once function before using anything
* \param argc number of arguments in argv
@ -143,8 +142,8 @@ inline void Allreduce(DType *sendrecvbuf, size_t count, std::function<void()> pr
*
* \sa CheckPoint, VersionNumber
*/
inline int LoadCheckPoint(utils::ISerializable *global_model,
utils::ISerializable *local_model = NULL);
inline int LoadCheckPoint(ISerializable *global_model,
ISerializable *local_model = NULL);
/*!
* \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one
@ -159,8 +158,8 @@ inline int LoadCheckPoint(utils::ISerializable *global_model,
* So only CheckPoint with global_model if possible
* \sa LoadCheckPoint, VersionNumber
*/
inline void CheckPoint(const utils::ISerializable *global_model,
const utils::ISerializable *local_model = NULL);
inline void CheckPoint(const ISerializable *global_model,
const ISerializable *local_model = NULL);
/*!
* \return version number of current stored model,
* which means how many calls to CheckPoint we made so far

99
src/serializable.h Normal file
View File

@ -0,0 +1,99 @@
#ifndef RABIT_SERIALIZABLE_H
#define RABIT_SERIALIZABLE_H
#include <vector>
#include <string>
#include "./utils.h"
/*!
* \file serializable.h
* \brief defines serializable interface of rabit
* \author Tianqi Chen
*/
namespace rabit {
/*!
* \brief interface of stream I/O, used by ISerializable
* \sa ISerializable
*/
class IStream {
public:
/*!
* \brief read data from stream
* \param ptr pointer to memory buffer
* \param size size of block
* \return usually is the size of data readed
*/
virtual size_t Read(void *ptr, size_t size) = 0;
/*!
* \brief write data to stream
* \param ptr pointer to memory buffer
* \param size size of block
*/
virtual void Write(const void *ptr, size_t size) = 0;
/*! \brief virtual destructor */
virtual ~IStream(void) {}
public:
// helper functions to write various of data structures
/*!
* \brief binary serialize a vector
* \param vec vector to be serialized
*/
template<typename T>
inline void Write(const std::vector<T> &vec) {
uint64_t sz = static_cast<uint64_t>(vec.size());
this->Write(&sz, sizeof(sz));
if (sz != 0) {
this->Write(&vec[0], sizeof(T) * sz);
}
}
/*!
* \brief binary load a vector
* \param out_vec vector to be loaded
* \return whether load is successfull
*/
template<typename T>
inline bool Read(std::vector<T> *out_vec) {
uint64_t sz;
if (this->Read(&sz, sizeof(sz)) == 0) return false;
out_vec->resize(sz);
if (sz != 0) {
if (this->Read(&(*out_vec)[0], sizeof(T) * sz) == 0) return false;
}
return true;
}
/*!
* \brief binary serialize a string
* \param str the string to be serialized
*/
inline void Write(const std::string &str) {
uint64_t sz = static_cast<uint64_t>(str.length());
this->Write(&sz, sizeof(sz));
if (sz != 0) {
this->Write(&str[0], sizeof(char) * sz);
}
}
/*!
* \brief binary load a string
* \param out_str string to be loaded
* \return whether load is successful
*/
inline bool Read(std::string *out_str) {
uint64_t sz;
if (this->Read(&sz, sizeof(sz)) == 0) return false;
out_str->resize(sz);
if (sz != 0) {
if (this->Read(&(*out_str)[0], sizeof(char) * sz) == 0) return false;
}
return true;
}
};
/*! \brief interface of se*/
class ISerializable {
public:
/*! \brief load the model from file */
virtual void Load(IStream &fi) = 0;
/*! \brief save the model to the stream*/
virtual void Save(IStream &fo) const = 0;
};
} // namespace rabit
#endif

View File

@ -33,14 +33,14 @@ public:
rabit::Allreduce<OP>(sendrecvbuf, count);
}
inline int LoadCheckPoint(utils::ISerializable *global_model,
utils::ISerializable *local_model) {
inline int LoadCheckPoint(ISerializable *global_model,
ISerializable *local_model) {
utils::Assert(verify(loadCheckpoint), "[%d] error when loading checkpoint", rank);
return rabit::LoadCheckPoint(global_model, local_model);
}
inline void CheckPoint(const utils::ISerializable *global_model,
const utils::ISerializable *local_model) {
inline void CheckPoint(const ISerializable *global_model,
const ISerializable *local_model) {
utils::Assert(verify(checkpoint), "[%d] error when checkpointing", rank);
rabit::CheckPoint(global_model, local_model);
}

View File

@ -29,16 +29,16 @@ inline void CallEnd(const char *fun, int ntrial, int iter) {
}
// dummy model
class Model : public rabit::utils::ISerializable {
class Model : public rabit::ISerializable {
public:
// iterations
std::vector<float> data;
// load from stream
virtual void Load(rabit::utils::IStream &fi) {
virtual void Load(rabit::IStream &fi) {
fi.Read(&data);
}
/*! \brief save the model to the stream */
virtual void Save(rabit::utils::IStream &fo) const {
virtual void Save(rabit::IStream &fo) const {
fo.Write(data);
}
virtual void InitModel(size_t n, float v) {

View File

@ -29,16 +29,16 @@ inline void CallEnd(const char *fun, int ntrial, int iter) {
}
// dummy model
class Model : public rabit::utils::ISerializable {
class Model : public rabit::ISerializable {
public:
// iterations
std::vector<float> data;
// load from stream
virtual void Load(rabit::utils::IStream &fi) {
virtual void Load(rabit::IStream &fi) {
fi.Read(&data);
}
/*! \brief save the model to the stream */
virtual void Save(rabit::utils::IStream &fo) const {
virtual void Save(rabit::IStream &fo) const {
fo.Write(data);
}
virtual void InitModel(size_t n) {

View File

@ -8,18 +8,18 @@
using namespace rabit;
// kmeans model
class Model : public rabit::utils::ISerializable {
class Model : public rabit::ISerializable {
public:
// matrix of centroids
Matrix centroids;
// load from stream
virtual void Load(rabit::utils::IStream &fi) {
virtual void Load(rabit::IStream &fi) {
fi.Read(&centroids.nrow, sizeof(centroids.nrow));
fi.Read(&centroids.ncol, sizeof(centroids.ncol));
fi.Read(&centroids.data);
}
/*! \brief save the model to the stream */
virtual void Save(rabit::utils::IStream &fo) const {
virtual void Save(rabit::IStream &fo) const {
fo.Write(&centroids.nrow, sizeof(centroids.nrow));
fo.Write(&centroids.ncol, sizeof(centroids.ncol));
fo.Write(centroids.data);

View File

@ -2,6 +2,7 @@
#include <vector>
#include <cstdlib>
#include <cstdio>
#include <cstring>
#include <cmath>
namespace rabit {