isolate iserializable
This commit is contained in:
parent
8c35cff02c
commit
6bf282c6c2
@ -111,8 +111,8 @@ class AllreduceBase : public IEngine {
|
|||||||
*
|
*
|
||||||
* \sa CheckPoint, VersionNumber
|
* \sa CheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
virtual int LoadCheckPoint(utils::ISerializable *global_model,
|
virtual int LoadCheckPoint(ISerializable *global_model,
|
||||||
utils::ISerializable *local_model = NULL) {
|
ISerializable *local_model = NULL) {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
@ -131,8 +131,8 @@ class AllreduceBase : public IEngine {
|
|||||||
*
|
*
|
||||||
* \sa LoadCheckPoint, VersionNumber
|
* \sa LoadCheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
virtual void CheckPoint(const utils::ISerializable *global_model,
|
virtual void CheckPoint(const ISerializable *global_model,
|
||||||
const utils::ISerializable *local_model = NULL) {
|
const ISerializable *local_model = NULL) {
|
||||||
version_number += 1;
|
version_number += 1;
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
|
|||||||
@ -11,8 +11,8 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include "./io.h"
|
#include "./io.h"
|
||||||
#include "./utils.h"
|
#include "./utils.h"
|
||||||
#include "./rabit.h"
|
|
||||||
#include "./allreduce_robust.h"
|
#include "./allreduce_robust.h"
|
||||||
|
#include "./rabit.h"
|
||||||
|
|
||||||
namespace rabit {
|
namespace rabit {
|
||||||
namespace engine {
|
namespace engine {
|
||||||
@ -141,8 +141,8 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root)
|
|||||||
*
|
*
|
||||||
* \sa CheckPoint, VersionNumber
|
* \sa CheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
int AllreduceRobust::LoadCheckPoint(utils::ISerializable *global_model,
|
int AllreduceRobust::LoadCheckPoint(ISerializable *global_model,
|
||||||
utils::ISerializable *local_model) {
|
ISerializable *local_model) {
|
||||||
if (num_local_replica == 0) {
|
if (num_local_replica == 0) {
|
||||||
utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model");
|
utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model");
|
||||||
}
|
}
|
||||||
@ -198,8 +198,8 @@ int AllreduceRobust::LoadCheckPoint(utils::ISerializable *global_model,
|
|||||||
*
|
*
|
||||||
* \sa LoadCheckPoint, VersionNumber
|
* \sa LoadCheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
void AllreduceRobust::CheckPoint(const utils::ISerializable *global_model,
|
void AllreduceRobust::CheckPoint(const ISerializable *global_model,
|
||||||
const utils::ISerializable *local_model) {
|
const ISerializable *local_model) {
|
||||||
if (num_local_replica == 0) {
|
if (num_local_replica == 0) {
|
||||||
utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model");
|
utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model");
|
||||||
}
|
}
|
||||||
|
|||||||
@ -75,8 +75,8 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
*
|
*
|
||||||
* \sa CheckPoint, VersionNumber
|
* \sa CheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
virtual int LoadCheckPoint(utils::ISerializable *global_model,
|
virtual int LoadCheckPoint(ISerializable *global_model,
|
||||||
utils::ISerializable *local_model = NULL);
|
ISerializable *local_model = NULL);
|
||||||
/*!
|
/*!
|
||||||
* \brief checkpoint the model, meaning we finished a stage of execution
|
* \brief checkpoint the model, meaning we finished a stage of execution
|
||||||
* every time we call check point, there is a version number which will increase by one
|
* every time we call check point, there is a version number which will increase by one
|
||||||
@ -93,8 +93,8 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
*
|
*
|
||||||
* \sa LoadCheckPoint, VersionNumber
|
* \sa LoadCheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
virtual void CheckPoint(const utils::ISerializable *global_model,
|
virtual void CheckPoint(const ISerializable *global_model,
|
||||||
const utils::ISerializable *local_model = NULL);
|
const ISerializable *local_model = NULL);
|
||||||
/*!
|
/*!
|
||||||
* \brief explicitly re-init everything before calling LoadCheckPoint
|
* \brief explicitly re-init everything before calling LoadCheckPoint
|
||||||
* call this function when IEngine throw an exception out,
|
* call this function when IEngine throw an exception out,
|
||||||
|
|||||||
10
src/engine.h
10
src/engine.h
@ -5,7 +5,7 @@
|
|||||||
*/
|
*/
|
||||||
#ifndef RABIT_ENGINE_H
|
#ifndef RABIT_ENGINE_H
|
||||||
#define RABIT_ENGINE_H
|
#define RABIT_ENGINE_H
|
||||||
#include "./io.h"
|
#include "./serializable.h"
|
||||||
|
|
||||||
namespace MPI {
|
namespace MPI {
|
||||||
/*! \brief MPI data type just to be compatible with MPI reduce function*/
|
/*! \brief MPI data type just to be compatible with MPI reduce function*/
|
||||||
@ -92,8 +92,8 @@ class IEngine {
|
|||||||
*
|
*
|
||||||
* \sa CheckPoint, VersionNumber
|
* \sa CheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
virtual int LoadCheckPoint(utils::ISerializable *global_model,
|
virtual int LoadCheckPoint(ISerializable *global_model,
|
||||||
utils::ISerializable *local_model = NULL) = 0;
|
ISerializable *local_model = NULL) = 0;
|
||||||
/*!
|
/*!
|
||||||
* \brief checkpoint the model, meaning we finished a stage of execution
|
* \brief checkpoint the model, meaning we finished a stage of execution
|
||||||
* every time we call check point, there is a version number which will increase by one
|
* every time we call check point, there is a version number which will increase by one
|
||||||
@ -110,8 +110,8 @@ class IEngine {
|
|||||||
*
|
*
|
||||||
* \sa LoadCheckPoint, VersionNumber
|
* \sa LoadCheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
virtual void CheckPoint(const utils::ISerializable *global_model,
|
virtual void CheckPoint(const ISerializable *global_model,
|
||||||
const utils::ISerializable *local_model = NULL) = 0;
|
const ISerializable *local_model = NULL) = 0;
|
||||||
/*!
|
/*!
|
||||||
* \return version number of current stored model,
|
* \return version number of current stored model,
|
||||||
* which means how many calls to CheckPoint we made so far
|
* which means how many calls to CheckPoint we made so far
|
||||||
|
|||||||
@ -34,12 +34,12 @@ class MPIEngine : public IEngine {
|
|||||||
virtual void InitAfterException(void) {
|
virtual void InitAfterException(void) {
|
||||||
utils::Error("MPI is not fault tolerant");
|
utils::Error("MPI is not fault tolerant");
|
||||||
}
|
}
|
||||||
virtual int LoadCheckPoint(utils::ISerializable *global_model,
|
virtual int LoadCheckPoint(ISerializable *global_model,
|
||||||
utils::ISerializable *local_model = NULL) {
|
ISerializable *local_model = NULL) {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
virtual void CheckPoint(const utils::ISerializable *global_model,
|
virtual void CheckPoint(const ISerializable *global_model,
|
||||||
const utils::ISerializable *local_model = NULL) {
|
const ISerializable *local_model = NULL) {
|
||||||
version_number += 1;
|
version_number += 1;
|
||||||
}
|
}
|
||||||
virtual int VersionNumber(void) const {
|
virtual int VersionNumber(void) const {
|
||||||
|
|||||||
89
src/io.h
89
src/io.h
@ -5,99 +5,14 @@
|
|||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "./utils.h"
|
#include "./utils.h"
|
||||||
|
#include "./serializable.h"
|
||||||
/*!
|
/*!
|
||||||
* \file io.h
|
* \file io.h
|
||||||
* \brief general stream interface for serialization, I/O
|
* \brief utilities that implements different serializable interface
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
*/
|
*/
|
||||||
namespace rabit {
|
namespace rabit {
|
||||||
namespace utils {
|
namespace utils {
|
||||||
/*!
|
|
||||||
* \brief interface of stream I/O, used to serialize model
|
|
||||||
*/
|
|
||||||
class IStream {
|
|
||||||
public:
|
|
||||||
/*!
|
|
||||||
* \brief read data from stream
|
|
||||||
* \param ptr pointer to memory buffer
|
|
||||||
* \param size size of block
|
|
||||||
* \return usually is the size of data readed
|
|
||||||
*/
|
|
||||||
virtual size_t Read(void *ptr, size_t size) = 0;
|
|
||||||
/*!
|
|
||||||
* \brief write data to stream
|
|
||||||
* \param ptr pointer to memory buffer
|
|
||||||
* \param size size of block
|
|
||||||
*/
|
|
||||||
virtual void Write(const void *ptr, size_t size) = 0;
|
|
||||||
/*! \brief virtual destructor */
|
|
||||||
virtual ~IStream(void) {}
|
|
||||||
|
|
||||||
public:
|
|
||||||
// helper functions to write various of data structures
|
|
||||||
/*!
|
|
||||||
* \brief binary serialize a vector
|
|
||||||
* \param vec vector to be serialized
|
|
||||||
*/
|
|
||||||
template<typename T>
|
|
||||||
inline void Write(const std::vector<T> &vec) {
|
|
||||||
uint64_t sz = static_cast<uint64_t>(vec.size());
|
|
||||||
this->Write(&sz, sizeof(sz));
|
|
||||||
if (sz != 0) {
|
|
||||||
this->Write(&vec[0], sizeof(T) * sz);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief binary load a vector
|
|
||||||
* \param out_vec vector to be loaded
|
|
||||||
* \return whether load is successfull
|
|
||||||
*/
|
|
||||||
template<typename T>
|
|
||||||
inline bool Read(std::vector<T> *out_vec) {
|
|
||||||
uint64_t sz;
|
|
||||||
if (this->Read(&sz, sizeof(sz)) == 0) return false;
|
|
||||||
out_vec->resize(sz);
|
|
||||||
if (sz != 0) {
|
|
||||||
if (this->Read(&(*out_vec)[0], sizeof(T) * sz) == 0) return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief binary serialize a string
|
|
||||||
* \param str the string to be serialized
|
|
||||||
*/
|
|
||||||
inline void Write(const std::string &str) {
|
|
||||||
uint64_t sz = static_cast<uint64_t>(str.length());
|
|
||||||
this->Write(&sz, sizeof(sz));
|
|
||||||
if (sz != 0) {
|
|
||||||
this->Write(&str[0], sizeof(char) * sz);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
/*!
|
|
||||||
* \brief binary load a string
|
|
||||||
* \param out_str string to be loaded
|
|
||||||
* \return whether load is successful
|
|
||||||
*/
|
|
||||||
inline bool Read(std::string *out_str) {
|
|
||||||
uint64_t sz;
|
|
||||||
if (this->Read(&sz, sizeof(sz)) == 0) return false;
|
|
||||||
out_str->resize(sz);
|
|
||||||
if (sz != 0) {
|
|
||||||
if (this->Read(&(*out_str)[0], sizeof(char) * sz) == 0) return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/*! \brief interface of se*/
|
|
||||||
class ISerializable {
|
|
||||||
public:
|
|
||||||
/*! \brief load the model from file */
|
|
||||||
virtual void Load(IStream &fi) = 0;
|
|
||||||
/*! \brief save the model to the stream*/
|
|
||||||
virtual void Save(IStream &fo) const = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
/*! \brief interface of i/o stream that support seek */
|
/*! \brief interface of i/o stream that support seek */
|
||||||
class ISeekStream: public IStream {
|
class ISeekStream: public IStream {
|
||||||
public:
|
public:
|
||||||
|
|||||||
@ -6,6 +6,8 @@
|
|||||||
*/
|
*/
|
||||||
#ifndef RABIT_RABIT_INL_H
|
#ifndef RABIT_RABIT_INL_H
|
||||||
#define RABIT_RABIT_INL_H
|
#define RABIT_RABIT_INL_H
|
||||||
|
// use engine for implementation
|
||||||
|
#include "./engine.h"
|
||||||
|
|
||||||
namespace rabit {
|
namespace rabit {
|
||||||
namespace engine {
|
namespace engine {
|
||||||
@ -139,13 +141,13 @@ inline void Allreduce(DType *sendrecvbuf, size_t count, std::function<void()> pr
|
|||||||
#endif // C++11
|
#endif // C++11
|
||||||
|
|
||||||
// load latest check point
|
// load latest check point
|
||||||
inline int LoadCheckPoint(utils::ISerializable *global_model,
|
inline int LoadCheckPoint(ISerializable *global_model,
|
||||||
utils::ISerializable *local_model) {
|
ISerializable *local_model) {
|
||||||
return engine::GetEngine()->LoadCheckPoint(global_model, local_model);
|
return engine::GetEngine()->LoadCheckPoint(global_model, local_model);
|
||||||
}
|
}
|
||||||
// checkpoint the model, meaning we finished a stage of execution
|
// checkpoint the model, meaning we finished a stage of execution
|
||||||
inline void CheckPoint(const utils::ISerializable *global_model,
|
inline void CheckPoint(const ISerializable *global_model,
|
||||||
const utils::ISerializable *local_model) {
|
const ISerializable *local_model) {
|
||||||
engine::GetEngine()->CheckPoint(global_model, local_model);
|
engine::GetEngine()->CheckPoint(global_model, local_model);
|
||||||
}
|
}
|
||||||
// return the version number of currently stored model
|
// return the version number of currently stored model
|
||||||
|
|||||||
15
src/rabit.h
15
src/rabit.h
@ -6,6 +6,7 @@
|
|||||||
* The actual implementation is redirected to rabit engine
|
* The actual implementation is redirected to rabit engine
|
||||||
* Code only using this header can also compiled with MPI Allreduce(with no fault recovery),
|
* Code only using this header can also compiled with MPI Allreduce(with no fault recovery),
|
||||||
*
|
*
|
||||||
|
* rabit.h and serializable.h is all the user need to use rabit interface
|
||||||
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
||||||
*/
|
*/
|
||||||
#include <string>
|
#include <string>
|
||||||
@ -14,9 +15,8 @@
|
|||||||
#if __cplusplus >= 201103L
|
#if __cplusplus >= 201103L
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#endif // C++11
|
#endif // C++11
|
||||||
// rabit headers
|
// contains definition of ISerializable
|
||||||
#include "./io.h"
|
#include "./serializable.h"
|
||||||
#include "./engine.h"
|
|
||||||
|
|
||||||
/*! \brief namespace of rabit */
|
/*! \brief namespace of rabit */
|
||||||
namespace rabit {
|
namespace rabit {
|
||||||
@ -31,7 +31,6 @@ struct Sum;
|
|||||||
/*! \brief perform bitwise OR */
|
/*! \brief perform bitwise OR */
|
||||||
struct BitOR;
|
struct BitOR;
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief intialize the rabit module, call this once function before using anything
|
* \brief intialize the rabit module, call this once function before using anything
|
||||||
* \param argc number of arguments in argv
|
* \param argc number of arguments in argv
|
||||||
@ -143,8 +142,8 @@ inline void Allreduce(DType *sendrecvbuf, size_t count, std::function<void()> pr
|
|||||||
*
|
*
|
||||||
* \sa CheckPoint, VersionNumber
|
* \sa CheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
inline int LoadCheckPoint(utils::ISerializable *global_model,
|
inline int LoadCheckPoint(ISerializable *global_model,
|
||||||
utils::ISerializable *local_model = NULL);
|
ISerializable *local_model = NULL);
|
||||||
/*!
|
/*!
|
||||||
* \brief checkpoint the model, meaning we finished a stage of execution
|
* \brief checkpoint the model, meaning we finished a stage of execution
|
||||||
* every time we call check point, there is a version number which will increase by one
|
* every time we call check point, there is a version number which will increase by one
|
||||||
@ -159,8 +158,8 @@ inline int LoadCheckPoint(utils::ISerializable *global_model,
|
|||||||
* So only CheckPoint with global_model if possible
|
* So only CheckPoint with global_model if possible
|
||||||
* \sa LoadCheckPoint, VersionNumber
|
* \sa LoadCheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
inline void CheckPoint(const utils::ISerializable *global_model,
|
inline void CheckPoint(const ISerializable *global_model,
|
||||||
const utils::ISerializable *local_model = NULL);
|
const ISerializable *local_model = NULL);
|
||||||
/*!
|
/*!
|
||||||
* \return version number of current stored model,
|
* \return version number of current stored model,
|
||||||
* which means how many calls to CheckPoint we made so far
|
* which means how many calls to CheckPoint we made so far
|
||||||
|
|||||||
99
src/serializable.h
Normal file
99
src/serializable.h
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
#ifndef RABIT_SERIALIZABLE_H
|
||||||
|
#define RABIT_SERIALIZABLE_H
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include "./utils.h"
|
||||||
|
/*!
|
||||||
|
* \file serializable.h
|
||||||
|
* \brief defines serializable interface of rabit
|
||||||
|
* \author Tianqi Chen
|
||||||
|
*/
|
||||||
|
namespace rabit {
|
||||||
|
/*!
|
||||||
|
* \brief interface of stream I/O, used by ISerializable
|
||||||
|
* \sa ISerializable
|
||||||
|
*/
|
||||||
|
class IStream {
|
||||||
|
public:
|
||||||
|
/*!
|
||||||
|
* \brief read data from stream
|
||||||
|
* \param ptr pointer to memory buffer
|
||||||
|
* \param size size of block
|
||||||
|
* \return usually is the size of data readed
|
||||||
|
*/
|
||||||
|
virtual size_t Read(void *ptr, size_t size) = 0;
|
||||||
|
/*!
|
||||||
|
* \brief write data to stream
|
||||||
|
* \param ptr pointer to memory buffer
|
||||||
|
* \param size size of block
|
||||||
|
*/
|
||||||
|
virtual void Write(const void *ptr, size_t size) = 0;
|
||||||
|
/*! \brief virtual destructor */
|
||||||
|
virtual ~IStream(void) {}
|
||||||
|
|
||||||
|
public:
|
||||||
|
// helper functions to write various of data structures
|
||||||
|
/*!
|
||||||
|
* \brief binary serialize a vector
|
||||||
|
* \param vec vector to be serialized
|
||||||
|
*/
|
||||||
|
template<typename T>
|
||||||
|
inline void Write(const std::vector<T> &vec) {
|
||||||
|
uint64_t sz = static_cast<uint64_t>(vec.size());
|
||||||
|
this->Write(&sz, sizeof(sz));
|
||||||
|
if (sz != 0) {
|
||||||
|
this->Write(&vec[0], sizeof(T) * sz);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief binary load a vector
|
||||||
|
* \param out_vec vector to be loaded
|
||||||
|
* \return whether load is successfull
|
||||||
|
*/
|
||||||
|
template<typename T>
|
||||||
|
inline bool Read(std::vector<T> *out_vec) {
|
||||||
|
uint64_t sz;
|
||||||
|
if (this->Read(&sz, sizeof(sz)) == 0) return false;
|
||||||
|
out_vec->resize(sz);
|
||||||
|
if (sz != 0) {
|
||||||
|
if (this->Read(&(*out_vec)[0], sizeof(T) * sz) == 0) return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief binary serialize a string
|
||||||
|
* \param str the string to be serialized
|
||||||
|
*/
|
||||||
|
inline void Write(const std::string &str) {
|
||||||
|
uint64_t sz = static_cast<uint64_t>(str.length());
|
||||||
|
this->Write(&sz, sizeof(sz));
|
||||||
|
if (sz != 0) {
|
||||||
|
this->Write(&str[0], sizeof(char) * sz);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief binary load a string
|
||||||
|
* \param out_str string to be loaded
|
||||||
|
* \return whether load is successful
|
||||||
|
*/
|
||||||
|
inline bool Read(std::string *out_str) {
|
||||||
|
uint64_t sz;
|
||||||
|
if (this->Read(&sz, sizeof(sz)) == 0) return false;
|
||||||
|
out_str->resize(sz);
|
||||||
|
if (sz != 0) {
|
||||||
|
if (this->Read(&(*out_str)[0], sizeof(char) * sz) == 0) return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/*! \brief interface of se*/
|
||||||
|
class ISerializable {
|
||||||
|
public:
|
||||||
|
/*! \brief load the model from file */
|
||||||
|
virtual void Load(IStream &fi) = 0;
|
||||||
|
/*! \brief save the model to the stream*/
|
||||||
|
virtual void Save(IStream &fo) const = 0;
|
||||||
|
};
|
||||||
|
} // namespace rabit
|
||||||
|
#endif
|
||||||
@ -33,14 +33,14 @@ public:
|
|||||||
rabit::Allreduce<OP>(sendrecvbuf, count);
|
rabit::Allreduce<OP>(sendrecvbuf, count);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline int LoadCheckPoint(utils::ISerializable *global_model,
|
inline int LoadCheckPoint(ISerializable *global_model,
|
||||||
utils::ISerializable *local_model) {
|
ISerializable *local_model) {
|
||||||
utils::Assert(verify(loadCheckpoint), "[%d] error when loading checkpoint", rank);
|
utils::Assert(verify(loadCheckpoint), "[%d] error when loading checkpoint", rank);
|
||||||
return rabit::LoadCheckPoint(global_model, local_model);
|
return rabit::LoadCheckPoint(global_model, local_model);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void CheckPoint(const utils::ISerializable *global_model,
|
inline void CheckPoint(const ISerializable *global_model,
|
||||||
const utils::ISerializable *local_model) {
|
const ISerializable *local_model) {
|
||||||
utils::Assert(verify(checkpoint), "[%d] error when checkpointing", rank);
|
utils::Assert(verify(checkpoint), "[%d] error when checkpointing", rank);
|
||||||
rabit::CheckPoint(global_model, local_model);
|
rabit::CheckPoint(global_model, local_model);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -29,16 +29,16 @@ inline void CallEnd(const char *fun, int ntrial, int iter) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// dummy model
|
// dummy model
|
||||||
class Model : public rabit::utils::ISerializable {
|
class Model : public rabit::ISerializable {
|
||||||
public:
|
public:
|
||||||
// iterations
|
// iterations
|
||||||
std::vector<float> data;
|
std::vector<float> data;
|
||||||
// load from stream
|
// load from stream
|
||||||
virtual void Load(rabit::utils::IStream &fi) {
|
virtual void Load(rabit::IStream &fi) {
|
||||||
fi.Read(&data);
|
fi.Read(&data);
|
||||||
}
|
}
|
||||||
/*! \brief save the model to the stream */
|
/*! \brief save the model to the stream */
|
||||||
virtual void Save(rabit::utils::IStream &fo) const {
|
virtual void Save(rabit::IStream &fo) const {
|
||||||
fo.Write(data);
|
fo.Write(data);
|
||||||
}
|
}
|
||||||
virtual void InitModel(size_t n, float v) {
|
virtual void InitModel(size_t n, float v) {
|
||||||
|
|||||||
@ -29,16 +29,16 @@ inline void CallEnd(const char *fun, int ntrial, int iter) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// dummy model
|
// dummy model
|
||||||
class Model : public rabit::utils::ISerializable {
|
class Model : public rabit::ISerializable {
|
||||||
public:
|
public:
|
||||||
// iterations
|
// iterations
|
||||||
std::vector<float> data;
|
std::vector<float> data;
|
||||||
// load from stream
|
// load from stream
|
||||||
virtual void Load(rabit::utils::IStream &fi) {
|
virtual void Load(rabit::IStream &fi) {
|
||||||
fi.Read(&data);
|
fi.Read(&data);
|
||||||
}
|
}
|
||||||
/*! \brief save the model to the stream */
|
/*! \brief save the model to the stream */
|
||||||
virtual void Save(rabit::utils::IStream &fo) const {
|
virtual void Save(rabit::IStream &fo) const {
|
||||||
fo.Write(data);
|
fo.Write(data);
|
||||||
}
|
}
|
||||||
virtual void InitModel(size_t n) {
|
virtual void InitModel(size_t n) {
|
||||||
|
|||||||
@ -8,18 +8,18 @@
|
|||||||
using namespace rabit;
|
using namespace rabit;
|
||||||
|
|
||||||
// kmeans model
|
// kmeans model
|
||||||
class Model : public rabit::utils::ISerializable {
|
class Model : public rabit::ISerializable {
|
||||||
public:
|
public:
|
||||||
// matrix of centroids
|
// matrix of centroids
|
||||||
Matrix centroids;
|
Matrix centroids;
|
||||||
// load from stream
|
// load from stream
|
||||||
virtual void Load(rabit::utils::IStream &fi) {
|
virtual void Load(rabit::IStream &fi) {
|
||||||
fi.Read(¢roids.nrow, sizeof(centroids.nrow));
|
fi.Read(¢roids.nrow, sizeof(centroids.nrow));
|
||||||
fi.Read(¢roids.ncol, sizeof(centroids.ncol));
|
fi.Read(¢roids.ncol, sizeof(centroids.ncol));
|
||||||
fi.Read(¢roids.data);
|
fi.Read(¢roids.data);
|
||||||
}
|
}
|
||||||
/*! \brief save the model to the stream */
|
/*! \brief save the model to the stream */
|
||||||
virtual void Save(rabit::utils::IStream &fo) const {
|
virtual void Save(rabit::IStream &fo) const {
|
||||||
fo.Write(¢roids.nrow, sizeof(centroids.nrow));
|
fo.Write(¢roids.nrow, sizeof(centroids.nrow));
|
||||||
fo.Write(¢roids.ncol, sizeof(centroids.ncol));
|
fo.Write(¢roids.ncol, sizeof(centroids.ncol));
|
||||||
fo.Write(centroids.data);
|
fo.Write(centroids.data);
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
|
#include <cstring>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
|
||||||
namespace rabit {
|
namespace rabit {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user