isolate iserializable
This commit is contained in:
parent
8c35cff02c
commit
6bf282c6c2
@ -111,8 +111,8 @@ class AllreduceBase : public IEngine {
|
||||
*
|
||||
* \sa CheckPoint, VersionNumber
|
||||
*/
|
||||
virtual int LoadCheckPoint(utils::ISerializable *global_model,
|
||||
utils::ISerializable *local_model = NULL) {
|
||||
virtual int LoadCheckPoint(ISerializable *global_model,
|
||||
ISerializable *local_model = NULL) {
|
||||
return 0;
|
||||
}
|
||||
/*!
|
||||
@ -131,8 +131,8 @@ class AllreduceBase : public IEngine {
|
||||
*
|
||||
* \sa LoadCheckPoint, VersionNumber
|
||||
*/
|
||||
virtual void CheckPoint(const utils::ISerializable *global_model,
|
||||
const utils::ISerializable *local_model = NULL) {
|
||||
virtual void CheckPoint(const ISerializable *global_model,
|
||||
const ISerializable *local_model = NULL) {
|
||||
version_number += 1;
|
||||
}
|
||||
/*!
|
||||
|
||||
@ -11,8 +11,8 @@
|
||||
#include <utility>
|
||||
#include "./io.h"
|
||||
#include "./utils.h"
|
||||
#include "./rabit.h"
|
||||
#include "./allreduce_robust.h"
|
||||
#include "./rabit.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
@ -141,8 +141,8 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root)
|
||||
*
|
||||
* \sa CheckPoint, VersionNumber
|
||||
*/
|
||||
int AllreduceRobust::LoadCheckPoint(utils::ISerializable *global_model,
|
||||
utils::ISerializable *local_model) {
|
||||
int AllreduceRobust::LoadCheckPoint(ISerializable *global_model,
|
||||
ISerializable *local_model) {
|
||||
if (num_local_replica == 0) {
|
||||
utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model");
|
||||
}
|
||||
@ -198,8 +198,8 @@ int AllreduceRobust::LoadCheckPoint(utils::ISerializable *global_model,
|
||||
*
|
||||
* \sa LoadCheckPoint, VersionNumber
|
||||
*/
|
||||
void AllreduceRobust::CheckPoint(const utils::ISerializable *global_model,
|
||||
const utils::ISerializable *local_model) {
|
||||
void AllreduceRobust::CheckPoint(const ISerializable *global_model,
|
||||
const ISerializable *local_model) {
|
||||
if (num_local_replica == 0) {
|
||||
utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model");
|
||||
}
|
||||
|
||||
@ -75,8 +75,8 @@ class AllreduceRobust : public AllreduceBase {
|
||||
*
|
||||
* \sa CheckPoint, VersionNumber
|
||||
*/
|
||||
virtual int LoadCheckPoint(utils::ISerializable *global_model,
|
||||
utils::ISerializable *local_model = NULL);
|
||||
virtual int LoadCheckPoint(ISerializable *global_model,
|
||||
ISerializable *local_model = NULL);
|
||||
/*!
|
||||
* \brief checkpoint the model, meaning we finished a stage of execution
|
||||
* every time we call check point, there is a version number which will increase by one
|
||||
@ -93,8 +93,8 @@ class AllreduceRobust : public AllreduceBase {
|
||||
*
|
||||
* \sa LoadCheckPoint, VersionNumber
|
||||
*/
|
||||
virtual void CheckPoint(const utils::ISerializable *global_model,
|
||||
const utils::ISerializable *local_model = NULL);
|
||||
virtual void CheckPoint(const ISerializable *global_model,
|
||||
const ISerializable *local_model = NULL);
|
||||
/*!
|
||||
* \brief explicitly re-init everything before calling LoadCheckPoint
|
||||
* call this function when IEngine throw an exception out,
|
||||
|
||||
10
src/engine.h
10
src/engine.h
@ -5,7 +5,7 @@
|
||||
*/
|
||||
#ifndef RABIT_ENGINE_H
|
||||
#define RABIT_ENGINE_H
|
||||
#include "./io.h"
|
||||
#include "./serializable.h"
|
||||
|
||||
namespace MPI {
|
||||
/*! \brief MPI data type just to be compatible with MPI reduce function*/
|
||||
@ -92,8 +92,8 @@ class IEngine {
|
||||
*
|
||||
* \sa CheckPoint, VersionNumber
|
||||
*/
|
||||
virtual int LoadCheckPoint(utils::ISerializable *global_model,
|
||||
utils::ISerializable *local_model = NULL) = 0;
|
||||
virtual int LoadCheckPoint(ISerializable *global_model,
|
||||
ISerializable *local_model = NULL) = 0;
|
||||
/*!
|
||||
* \brief checkpoint the model, meaning we finished a stage of execution
|
||||
* every time we call check point, there is a version number which will increase by one
|
||||
@ -110,8 +110,8 @@ class IEngine {
|
||||
*
|
||||
* \sa LoadCheckPoint, VersionNumber
|
||||
*/
|
||||
virtual void CheckPoint(const utils::ISerializable *global_model,
|
||||
const utils::ISerializable *local_model = NULL) = 0;
|
||||
virtual void CheckPoint(const ISerializable *global_model,
|
||||
const ISerializable *local_model = NULL) = 0;
|
||||
/*!
|
||||
* \return version number of current stored model,
|
||||
* which means how many calls to CheckPoint we made so far
|
||||
|
||||
@ -34,12 +34,12 @@ class MPIEngine : public IEngine {
|
||||
virtual void InitAfterException(void) {
|
||||
utils::Error("MPI is not fault tolerant");
|
||||
}
|
||||
virtual int LoadCheckPoint(utils::ISerializable *global_model,
|
||||
utils::ISerializable *local_model = NULL) {
|
||||
virtual int LoadCheckPoint(ISerializable *global_model,
|
||||
ISerializable *local_model = NULL) {
|
||||
return 0;
|
||||
}
|
||||
virtual void CheckPoint(const utils::ISerializable *global_model,
|
||||
const utils::ISerializable *local_model = NULL) {
|
||||
virtual void CheckPoint(const ISerializable *global_model,
|
||||
const ISerializable *local_model = NULL) {
|
||||
version_number += 1;
|
||||
}
|
||||
virtual int VersionNumber(void) const {
|
||||
|
||||
89
src/io.h
89
src/io.h
@ -5,99 +5,14 @@
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include "./utils.h"
|
||||
#include "./serializable.h"
|
||||
/*!
|
||||
* \file io.h
|
||||
* \brief general stream interface for serialization, I/O
|
||||
* \brief utilities that implements different serializable interface
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
namespace rabit {
|
||||
namespace utils {
|
||||
/*!
|
||||
* \brief interface of stream I/O, used to serialize model
|
||||
*/
|
||||
class IStream {
|
||||
public:
|
||||
/*!
|
||||
* \brief read data from stream
|
||||
* \param ptr pointer to memory buffer
|
||||
* \param size size of block
|
||||
* \return usually is the size of data readed
|
||||
*/
|
||||
virtual size_t Read(void *ptr, size_t size) = 0;
|
||||
/*!
|
||||
* \brief write data to stream
|
||||
* \param ptr pointer to memory buffer
|
||||
* \param size size of block
|
||||
*/
|
||||
virtual void Write(const void *ptr, size_t size) = 0;
|
||||
/*! \brief virtual destructor */
|
||||
virtual ~IStream(void) {}
|
||||
|
||||
public:
|
||||
// helper functions to write various of data structures
|
||||
/*!
|
||||
* \brief binary serialize a vector
|
||||
* \param vec vector to be serialized
|
||||
*/
|
||||
template<typename T>
|
||||
inline void Write(const std::vector<T> &vec) {
|
||||
uint64_t sz = static_cast<uint64_t>(vec.size());
|
||||
this->Write(&sz, sizeof(sz));
|
||||
if (sz != 0) {
|
||||
this->Write(&vec[0], sizeof(T) * sz);
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief binary load a vector
|
||||
* \param out_vec vector to be loaded
|
||||
* \return whether load is successfull
|
||||
*/
|
||||
template<typename T>
|
||||
inline bool Read(std::vector<T> *out_vec) {
|
||||
uint64_t sz;
|
||||
if (this->Read(&sz, sizeof(sz)) == 0) return false;
|
||||
out_vec->resize(sz);
|
||||
if (sz != 0) {
|
||||
if (this->Read(&(*out_vec)[0], sizeof(T) * sz) == 0) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
/*!
|
||||
* \brief binary serialize a string
|
||||
* \param str the string to be serialized
|
||||
*/
|
||||
inline void Write(const std::string &str) {
|
||||
uint64_t sz = static_cast<uint64_t>(str.length());
|
||||
this->Write(&sz, sizeof(sz));
|
||||
if (sz != 0) {
|
||||
this->Write(&str[0], sizeof(char) * sz);
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief binary load a string
|
||||
* \param out_str string to be loaded
|
||||
* \return whether load is successful
|
||||
*/
|
||||
inline bool Read(std::string *out_str) {
|
||||
uint64_t sz;
|
||||
if (this->Read(&sz, sizeof(sz)) == 0) return false;
|
||||
out_str->resize(sz);
|
||||
if (sz != 0) {
|
||||
if (this->Read(&(*out_str)[0], sizeof(char) * sz) == 0) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief interface of se*/
|
||||
class ISerializable {
|
||||
public:
|
||||
/*! \brief load the model from file */
|
||||
virtual void Load(IStream &fi) = 0;
|
||||
/*! \brief save the model to the stream*/
|
||||
virtual void Save(IStream &fo) const = 0;
|
||||
};
|
||||
|
||||
/*! \brief interface of i/o stream that support seek */
|
||||
class ISeekStream: public IStream {
|
||||
public:
|
||||
|
||||
@ -6,6 +6,8 @@
|
||||
*/
|
||||
#ifndef RABIT_RABIT_INL_H
|
||||
#define RABIT_RABIT_INL_H
|
||||
// use engine for implementation
|
||||
#include "./engine.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
@ -139,13 +141,13 @@ inline void Allreduce(DType *sendrecvbuf, size_t count, std::function<void()> pr
|
||||
#endif // C++11
|
||||
|
||||
// load latest check point
|
||||
inline int LoadCheckPoint(utils::ISerializable *global_model,
|
||||
utils::ISerializable *local_model) {
|
||||
inline int LoadCheckPoint(ISerializable *global_model,
|
||||
ISerializable *local_model) {
|
||||
return engine::GetEngine()->LoadCheckPoint(global_model, local_model);
|
||||
}
|
||||
// checkpoint the model, meaning we finished a stage of execution
|
||||
inline void CheckPoint(const utils::ISerializable *global_model,
|
||||
const utils::ISerializable *local_model) {
|
||||
inline void CheckPoint(const ISerializable *global_model,
|
||||
const ISerializable *local_model) {
|
||||
engine::GetEngine()->CheckPoint(global_model, local_model);
|
||||
}
|
||||
// return the version number of currently stored model
|
||||
|
||||
15
src/rabit.h
15
src/rabit.h
@ -6,6 +6,7 @@
|
||||
* The actual implementation is redirected to rabit engine
|
||||
* Code only using this header can also compiled with MPI Allreduce(with no fault recovery),
|
||||
*
|
||||
* rabit.h and serializable.h is all the user need to use rabit interface
|
||||
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
||||
*/
|
||||
#include <string>
|
||||
@ -14,9 +15,8 @@
|
||||
#if __cplusplus >= 201103L
|
||||
#include <functional>
|
||||
#endif // C++11
|
||||
// rabit headers
|
||||
#include "./io.h"
|
||||
#include "./engine.h"
|
||||
// contains definition of ISerializable
|
||||
#include "./serializable.h"
|
||||
|
||||
/*! \brief namespace of rabit */
|
||||
namespace rabit {
|
||||
@ -31,7 +31,6 @@ struct Sum;
|
||||
/*! \brief perform bitwise OR */
|
||||
struct BitOR;
|
||||
} // namespace op
|
||||
|
||||
/*!
|
||||
* \brief intialize the rabit module, call this once function before using anything
|
||||
* \param argc number of arguments in argv
|
||||
@ -143,8 +142,8 @@ inline void Allreduce(DType *sendrecvbuf, size_t count, std::function<void()> pr
|
||||
*
|
||||
* \sa CheckPoint, VersionNumber
|
||||
*/
|
||||
inline int LoadCheckPoint(utils::ISerializable *global_model,
|
||||
utils::ISerializable *local_model = NULL);
|
||||
inline int LoadCheckPoint(ISerializable *global_model,
|
||||
ISerializable *local_model = NULL);
|
||||
/*!
|
||||
* \brief checkpoint the model, meaning we finished a stage of execution
|
||||
* every time we call check point, there is a version number which will increase by one
|
||||
@ -159,8 +158,8 @@ inline int LoadCheckPoint(utils::ISerializable *global_model,
|
||||
* So only CheckPoint with global_model if possible
|
||||
* \sa LoadCheckPoint, VersionNumber
|
||||
*/
|
||||
inline void CheckPoint(const utils::ISerializable *global_model,
|
||||
const utils::ISerializable *local_model = NULL);
|
||||
inline void CheckPoint(const ISerializable *global_model,
|
||||
const ISerializable *local_model = NULL);
|
||||
/*!
|
||||
* \return version number of current stored model,
|
||||
* which means how many calls to CheckPoint we made so far
|
||||
|
||||
99
src/serializable.h
Normal file
99
src/serializable.h
Normal file
@ -0,0 +1,99 @@
|
||||
#ifndef RABIT_SERIALIZABLE_H
|
||||
#define RABIT_SERIALIZABLE_H
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "./utils.h"
|
||||
/*!
|
||||
* \file serializable.h
|
||||
* \brief defines serializable interface of rabit
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
namespace rabit {
|
||||
/*!
|
||||
* \brief interface of stream I/O, used by ISerializable
|
||||
* \sa ISerializable
|
||||
*/
|
||||
class IStream {
|
||||
public:
|
||||
/*!
|
||||
* \brief read data from stream
|
||||
* \param ptr pointer to memory buffer
|
||||
* \param size size of block
|
||||
* \return usually is the size of data readed
|
||||
*/
|
||||
virtual size_t Read(void *ptr, size_t size) = 0;
|
||||
/*!
|
||||
* \brief write data to stream
|
||||
* \param ptr pointer to memory buffer
|
||||
* \param size size of block
|
||||
*/
|
||||
virtual void Write(const void *ptr, size_t size) = 0;
|
||||
/*! \brief virtual destructor */
|
||||
virtual ~IStream(void) {}
|
||||
|
||||
public:
|
||||
// helper functions to write various of data structures
|
||||
/*!
|
||||
* \brief binary serialize a vector
|
||||
* \param vec vector to be serialized
|
||||
*/
|
||||
template<typename T>
|
||||
inline void Write(const std::vector<T> &vec) {
|
||||
uint64_t sz = static_cast<uint64_t>(vec.size());
|
||||
this->Write(&sz, sizeof(sz));
|
||||
if (sz != 0) {
|
||||
this->Write(&vec[0], sizeof(T) * sz);
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief binary load a vector
|
||||
* \param out_vec vector to be loaded
|
||||
* \return whether load is successfull
|
||||
*/
|
||||
template<typename T>
|
||||
inline bool Read(std::vector<T> *out_vec) {
|
||||
uint64_t sz;
|
||||
if (this->Read(&sz, sizeof(sz)) == 0) return false;
|
||||
out_vec->resize(sz);
|
||||
if (sz != 0) {
|
||||
if (this->Read(&(*out_vec)[0], sizeof(T) * sz) == 0) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
/*!
|
||||
* \brief binary serialize a string
|
||||
* \param str the string to be serialized
|
||||
*/
|
||||
inline void Write(const std::string &str) {
|
||||
uint64_t sz = static_cast<uint64_t>(str.length());
|
||||
this->Write(&sz, sizeof(sz));
|
||||
if (sz != 0) {
|
||||
this->Write(&str[0], sizeof(char) * sz);
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief binary load a string
|
||||
* \param out_str string to be loaded
|
||||
* \return whether load is successful
|
||||
*/
|
||||
inline bool Read(std::string *out_str) {
|
||||
uint64_t sz;
|
||||
if (this->Read(&sz, sizeof(sz)) == 0) return false;
|
||||
out_str->resize(sz);
|
||||
if (sz != 0) {
|
||||
if (this->Read(&(*out_str)[0], sizeof(char) * sz) == 0) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief interface of se*/
|
||||
class ISerializable {
|
||||
public:
|
||||
/*! \brief load the model from file */
|
||||
virtual void Load(IStream &fi) = 0;
|
||||
/*! \brief save the model to the stream*/
|
||||
virtual void Save(IStream &fo) const = 0;
|
||||
};
|
||||
} // namespace rabit
|
||||
#endif
|
||||
@ -33,14 +33,14 @@ public:
|
||||
rabit::Allreduce<OP>(sendrecvbuf, count);
|
||||
}
|
||||
|
||||
inline int LoadCheckPoint(utils::ISerializable *global_model,
|
||||
utils::ISerializable *local_model) {
|
||||
inline int LoadCheckPoint(ISerializable *global_model,
|
||||
ISerializable *local_model) {
|
||||
utils::Assert(verify(loadCheckpoint), "[%d] error when loading checkpoint", rank);
|
||||
return rabit::LoadCheckPoint(global_model, local_model);
|
||||
}
|
||||
|
||||
inline void CheckPoint(const utils::ISerializable *global_model,
|
||||
const utils::ISerializable *local_model) {
|
||||
inline void CheckPoint(const ISerializable *global_model,
|
||||
const ISerializable *local_model) {
|
||||
utils::Assert(verify(checkpoint), "[%d] error when checkpointing", rank);
|
||||
rabit::CheckPoint(global_model, local_model);
|
||||
}
|
||||
|
||||
@ -29,16 +29,16 @@ inline void CallEnd(const char *fun, int ntrial, int iter) {
|
||||
}
|
||||
|
||||
// dummy model
|
||||
class Model : public rabit::utils::ISerializable {
|
||||
class Model : public rabit::ISerializable {
|
||||
public:
|
||||
// iterations
|
||||
std::vector<float> data;
|
||||
// load from stream
|
||||
virtual void Load(rabit::utils::IStream &fi) {
|
||||
virtual void Load(rabit::IStream &fi) {
|
||||
fi.Read(&data);
|
||||
}
|
||||
/*! \brief save the model to the stream */
|
||||
virtual void Save(rabit::utils::IStream &fo) const {
|
||||
virtual void Save(rabit::IStream &fo) const {
|
||||
fo.Write(data);
|
||||
}
|
||||
virtual void InitModel(size_t n, float v) {
|
||||
|
||||
@ -29,16 +29,16 @@ inline void CallEnd(const char *fun, int ntrial, int iter) {
|
||||
}
|
||||
|
||||
// dummy model
|
||||
class Model : public rabit::utils::ISerializable {
|
||||
class Model : public rabit::ISerializable {
|
||||
public:
|
||||
// iterations
|
||||
std::vector<float> data;
|
||||
// load from stream
|
||||
virtual void Load(rabit::utils::IStream &fi) {
|
||||
virtual void Load(rabit::IStream &fi) {
|
||||
fi.Read(&data);
|
||||
}
|
||||
/*! \brief save the model to the stream */
|
||||
virtual void Save(rabit::utils::IStream &fo) const {
|
||||
virtual void Save(rabit::IStream &fo) const {
|
||||
fo.Write(data);
|
||||
}
|
||||
virtual void InitModel(size_t n) {
|
||||
|
||||
@ -8,18 +8,18 @@
|
||||
using namespace rabit;
|
||||
|
||||
// kmeans model
|
||||
class Model : public rabit::utils::ISerializable {
|
||||
class Model : public rabit::ISerializable {
|
||||
public:
|
||||
// matrix of centroids
|
||||
Matrix centroids;
|
||||
// load from stream
|
||||
virtual void Load(rabit::utils::IStream &fi) {
|
||||
virtual void Load(rabit::IStream &fi) {
|
||||
fi.Read(¢roids.nrow, sizeof(centroids.nrow));
|
||||
fi.Read(¢roids.ncol, sizeof(centroids.ncol));
|
||||
fi.Read(¢roids.data);
|
||||
}
|
||||
/*! \brief save the model to the stream */
|
||||
virtual void Save(rabit::utils::IStream &fo) const {
|
||||
virtual void Save(rabit::IStream &fo) const {
|
||||
fo.Write(¢roids.nrow, sizeof(centroids.nrow));
|
||||
fo.Write(¢roids.ncol, sizeof(centroids.ncol));
|
||||
fo.Write(centroids.data);
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
#include <vector>
|
||||
#include <cstdlib>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <cmath>
|
||||
|
||||
namespace rabit {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user